Source code for drinx.base

from __future__ import annotations
from collections.abc import Sequence as ABCSequence
import dataclasses
import jax
import jax.numpy as jnp
from drinx.transform import dataclass
from typing import dataclass_transform, Any, Callable, Generic, Sequence, Self, TypeVar
from dataclasses import field as orig_field
from drinx.attribute import (
    DRINX_ON_GETATTR,
    DRINX_ON_SETATTR,
    field,
    private_field,
    static_field,
    static_private_field,
)


@dataclass_transform(
    field_specifiers=(
        orig_field,
        field,
        static_field,
        private_field,
        static_private_field,
    )
)
class DataClass:
    """Base class alternative to the ``@drinx.dataclass`` decorator.

    Subclassing ``DataClass`` automatically applies the ``@drinx.dataclass``
    transform, registering the subclass as a frozen dataclass and a JAX pytree
    node. Fields annotated with :func:`static_field` (or
    ``field(static=True)``) are placed in the pytree auxiliary data (not traced
    by JAX); all other fields become pytree leaves.

    Usage::

        class MyModel(DataClass):
            weights: jax.Array
            learning_rate: float = static_field(default=1e-3)

    Dataclass keyword arguments (``init``, ``repr``, ``eq``, etc.) can be
    forwarded via the class definition::

        class MyModel(DataClass, order=True, kw_only=True):
            ...

    Also provides :meth:`aset` for functional nested updates and
    :meth:`updated_copy` as a convenience wrapper around
    :func:`dataclasses.replace`.
    """

    def __init_subclass__(
        cls,
        /,
        *,
        init: bool = True,
        repr: bool = True,
        eq: bool = True,
        order: bool = False,
        unsafe_hash: bool = False,
        match_args: bool = True,
        kw_only: bool = False,
        slots: bool = False,
        weakref_slot: bool = False,
    ):
        """Apply the ``@drinx.dataclass`` transform to every subclass automatically.

        Called by Python whenever a new subclass of :class:`DataClass` is
        defined. Accepts the same keyword arguments as the standard
        :func:`dataclasses.dataclass` decorator (except ``frozen``, which is
        always ``True``).

        Args:
            init: Generate ``__init__``.
            repr: Generate ``__repr__``.
            eq: Generate ``__eq__`` and ``__hash__``.
            order: Generate comparison methods (``<``, ``<=``, ``>``, ``>=``).
            unsafe_hash: Force generation of ``__hash__`` even when ``eq=True``.
            match_args: Set ``__match_args__`` for structural pattern matching.
            kw_only: Make all fields keyword-only in ``__init__``.
            slots: Not supported; ignored. Slot-based dataclasses are excluded
                because ``__slots__`` attribute-access speedups are negligible
                compared to JAX kernel dispatch overhead, and supporting them
                would add significant complexity to pytree flatten/unflatten.
            weakref_slot: Add a ``__weakref__`` slot (ignored; kept for API
                compatibility).
        """
        del slots, weakref_slot
        super().__init_subclass__()

        # Only a __post_init__ defined directly on this subclass is wrapped, so
        # the framework-level __post_init__ still runs after the user's hook.
        user_post_init = cls.__dict__.get("__post_init__")
        if user_post_init is not None and user_post_init is not DataClass.__post_init__:
            setattr(
                cls,
                "__post_init__",
                DataClass._wrap_user_post_init(user_post_init),
            )

        # Programmatically apply our custom dataclass wrapper to the subclass.
        # The local is deliberately not called ``dataclass_transform`` so it
        # does not shadow the ``typing`` import used as the class decorator.
        apply_dataclass = dataclass(
            init=init,
            repr=repr,
            eq=eq,
            order=order,
            unsafe_hash=unsafe_hash,
            match_args=match_args,
            kw_only=kw_only,
            weakref_slot=False,
        )
        apply_dataclass(cls)

        # Re-install the attribute hooks after the transform has run.
        setattr(cls, "__setattr__", DataClass.__setattr__)
        setattr(cls, "__getattribute__", DataClass.__getattribute__)

    @staticmethod
    def _wrap_user_post_init(
        user_post_init: Callable[..., Any],
    ) -> Callable[..., None]:
        """Wrap a user-defined ``__post_init__`` so the framework-level
        initialization logic runs afterward."""

        def wrapped(self: DataClass, *args: Any, **kwargs: Any) -> None:
            user_post_init(self, *args, **kwargs)
            DataClass.__post_init__(self)

        return wrapped

    @staticmethod
    def _normalize_callbacks(callbacks: Any) -> tuple[Callable[..., Any], ...]:
        """Standardize callback metadata into a tuple of callables.

        ``None`` becomes an empty tuple, a non-string sequence is converted to
        a tuple, and any other single value is wrapped in a 1-tuple.
        """
        if callbacks is None:
            return ()
        if isinstance(callbacks, ABCSequence) and not isinstance(
            callbacks, (str, bytes)
        ):
            return tuple(callbacks)
        return (callbacks,)

    @staticmethod
    def _run_callbacks(value: Any, callbacks: Sequence[Callable[..., Any]]) -> Any:
        """Apply *callbacks* in order, feeding each result into the next.

        A callback that returns ``None`` leaves the running value unchanged.
        """
        result = value
        for callback in callbacks:
            callback_result = callback(result)
            if callback_result is not None:
                result = callback_result
        return result

    @staticmethod
    def _get_field_definition(instance: DataClass, name: str) -> Any:
        """Return the dataclass field object for *name*, or ``None``.

        Uses ``object.__getattribute__`` so the lookup does not re-enter the
        overridden :meth:`__getattribute__`.
        """
        try:
            dataclass_fields = object.__getattribute__(instance, "__dataclass_fields__")
        except AttributeError:
            return None
        return dataclass_fields.get(name)

    def __post_init__(self) -> None:
        """Run the ``on_setattr`` callbacks of every field once, then set the
        initialization flag that locks the instance against assignment."""
        object.__setattr__(self, "_drinx_initialized", False)
        instance_dict = object.__getattribute__(self, "__dict__")
        # The callbacks must run only once per instance; a second call just
        # re-locks the instance.
        if instance_dict.get("_drinx_active_logic_applied", False):
            object.__setattr__(self, "_drinx_initialized", True)
            return
        object.__setattr__(self, "_drinx_active_logic_applied", True)

        for dc_field in dataclasses.fields(self):
            try:
                value = object.__getattribute__(self, dc_field.name)
            except AttributeError:
                # Field has no value on this instance; nothing to transform.
                continue
            callbacks = DataClass._normalize_callbacks(
                dc_field.metadata.get(DRINX_ON_SETATTR, ())
            )
            value = DataClass._run_callbacks(value, callbacks)
            object.__setattr__(self, dc_field.name, value)

        object.__setattr__(self, "_drinx_initialized", True)

    def __setattr__(self, name: str, value: Any) -> None:
        """Reject assignment after initialization; before that, apply the
        field's ``on_setattr`` callbacks and store the result.

        Raises:
            dataclasses.FrozenInstanceError: If the instance is already
                initialized.
            AttributeError: If *name* is neither a dataclass field, an existing
                instance attribute, nor an attribute of the class.
        """
        try:
            is_initialized = object.__getattribute__(self, "_drinx_initialized")
        except AttributeError:
            is_initialized = False
        if is_initialized:
            raise dataclasses.FrozenInstanceError(f"cannot assign to field {name!r}")

        dc_field = DataClass._get_field_definition(self, name)
        if dc_field is None:
            try:
                instance_dict = object.__getattribute__(self, "__dict__")
            except AttributeError:
                instance_dict = {}
            if name not in instance_dict and not hasattr(type(self), name):
                raise AttributeError(
                    f"{type(self).__name__!r} has no attribute {name!r}"
                )
            object.__setattr__(self, name, value)
            return

        callbacks = DataClass._normalize_callbacks(
            dc_field.metadata.get(DRINX_ON_SETATTR, ())
        )
        value = DataClass._run_callbacks(value, callbacks)
        object.__setattr__(self, name, value)

    def __getattribute__(self, name: str) -> Any:
        """Intercept attribute access and apply the field's ``on_getattr``
        callbacks (e.g. for transformations like automatic unfreezing).

        Dunder names and non-field attributes are returned untouched.
        """
        value = object.__getattribute__(self, name)
        if name.startswith("__") and name.endswith("__"):
            return value
        dc_field = DataClass._get_field_definition(self, name)
        if dc_field is None:
            return value
        callbacks = DataClass._normalize_callbacks(
            dc_field.metadata.get(DRINX_ON_GETATTR, ())
        )
        return DataClass._run_callbacks(value, callbacks)

    @staticmethod
    def _parse_operations(s: str) -> list[tuple[str | int, str]]:
        """Parse a path string into a sequence of typed operations.

        Splits an ``aset``-style path such as ``"a->b->[0]->['key']"`` into an
        ordered list of ``(operand, operation_type)`` pairs understood by
        :meth:`aset`.

        Operation types:

        * ``"attribute"`` — attribute access (``getattr``). Operand is the
          attribute name as a :class:`str`.
        * ``"index"`` — integer subscript (``obj[n]``). Operand is an
          :class:`int`.
        * ``"key"`` — string subscript (``obj['k']``). Operand is a
          :class:`str`.

        Args:
            s: Path string. Steps are separated by ``"->"``. Integer indices
                are written as ``[n]`` and string keys as ``['k']``.

        Returns:
            Ordered list of ``(operand, operation_type)`` pairs.

        Raises:
            ValueError: If *s* is empty, malformed, or contains invalid
                identifiers or bracket expressions.
        """
        if not s:
            raise ValueError("Empty string is not valid")

        operations: list[tuple[str | int, str]] = []
        i = 0
        while i < len(s):
            if i > 0:
                # Expect "->" separator
                if not s.startswith("->", i):
                    raise ValueError(f"Expected '->' at position {i}")
                i += 2  # Skip "->"
                if i >= len(s):
                    raise ValueError("String ends with '->'")

            # Parse the next operation
            if s[i] == "[":
                # Find the closing bracket
                j = s.find("]", i + 1)
                if j == -1:
                    raise ValueError(f"Unclosed bracket starting at position {i}")
                bracket_content = s[i + 1 : j].strip()

                # Determine if it's an integer or string. ``isdecimal`` (not
                # ``isdigit``) is used because ``isdigit`` also accepts
                # characters such as superscripts that ``int()`` rejects.
                if bracket_content.isdecimal() or (
                    bracket_content.startswith("-")
                    and bracket_content[1:].isdecimal()
                ):
                    operations.append((int(bracket_content), "index"))
                elif bracket_content.startswith("'") and bracket_content.endswith("'"):
                    # A lone "'" satisfies both checks above; reject it here.
                    if len(bracket_content) < 2:
                        raise ValueError(
                            f"Invalid string format in brackets: [{bracket_content}]"
                        )
                    string_content = bracket_content[1:-1]
                    # Check for forbidden characters
                    if "'" in string_content:
                        raise ValueError(
                            f"String keys cannot contain single quotes: '{string_content}'"
                        )
                    if "[" in string_content or "]" in string_content:
                        raise ValueError(
                            f"String keys cannot contain square brackets: '{string_content}'"
                        )
                    operations.append((string_content, "key"))
                else:
                    raise ValueError(f"Invalid bracket content: [{bracket_content}]")
                i = j + 1
            else:
                # Parse attribute name: everything up to the next "->" (or end)
                j = s.find("->", i)
                if j == -1:
                    j = len(s)
                attr_name = s[i:j]
                if not attr_name:
                    raise ValueError(f"Empty attribute at position {i}")
                # Check if it's a valid Python identifier
                if not attr_name.isidentifier():
                    raise ValueError(f"Invalid attribute name: '{attr_name}'")
                operations.append((attr_name, "attribute"))
                i = j

        return operations

    def _traverse_path(
        self,
        ops: list[tuple[str | int, str]],
    ) -> list[Any]:
        """Top-down traversal along *ops*, returning all intermediate parents.

        Returns a list of length ``len(ops)`` where index ``i`` is the parent
        object on which ``ops[i]`` should be applied. The list always starts
        with ``self``.

        Navigates through ``ops[:-1]`` (all but the final step), so the caller
        is responsible for handling the final operation. All intermediate
        missing attributes/keys raise ``Exception``.

        Args:
            ops: Parsed operation list from :meth:`_parse_operations`.

        Returns:
            List of parent objects, one per operation.

        Raises:
            Exception: If any intermediate attribute or key is missing.
        """
        attr_list: list[Any] = [self]
        current_parent: Any = self
        for op, op_type in ops[:-1]:
            if op_type == "attribute":
                if not hasattr(current_parent, str(op)):
                    raise Exception(
                        f"Attribute: {op} does not exist for {current_parent.__class__}"
                    )
                current_parent = getattr(current_parent, str(op))
            elif op_type == "index":
                if not hasattr(current_parent, "__getitem__"):
                    raise Exception(
                        f"{current_parent.__class__} does not implement __getitem__"
                    )
                current_parent = current_parent[int(op)]
            elif op_type == "key":
                if not hasattr(current_parent, "__getitem__"):
                    raise Exception(
                        f"{current_parent.__class__} does not implement __getitem__"
                    )
                if op not in current_parent:
                    raise Exception(f"Key: {op} does not exist for {current_parent}")
                current_parent = current_parent[op]
            else:
                raise Exception(
                    f"Invalid operation type: {op_type}. This is an internal bug!"
                )
            attr_list.append(current_parent)
        return attr_list

    @staticmethod
    def _copy_and_set(obj: Any, field_name: str, value: Any) -> Any:
        """Return a shallow copy of *obj* with *field_name* set to *value*.

        The copy is created with ``object.__new__`` and populated through
        ``object.__setattr__``, so neither ``__init__``/``__post_init__`` nor
        the overridden ``__setattr__`` run. Every entry of ``obj.__dict__`` is
        carried over as-is.
        """
        cls = type(obj)
        new_obj = object.__new__(cls)
        instance_dict = object.__getattribute__(obj, "__dict__")
        for k, v in instance_dict.items():
            object.__setattr__(new_obj, k, v)
        object.__setattr__(new_obj, field_name, value)
        return new_obj

    def aset(
        self,
        attr_name: str,
        val: Any,
        create_new_ok: bool = False,
        allow_private: bool = False,
        bypass_callbacks: bool = False,
    ) -> Self:
        """Sets an attribute of this class.

        In contrast to the classical ``.at[].set()``, this method updates the
        class attribute directly and does not only operate on jax pytree leaf
        nodes. Instead, replaces the full attribute with the new value.

        The attribute can either be the attribute name of this class, or for
        nested classes it can also be the attribute name of a class, which
        itself is an attribute of this class. The syntax for this operation
        could look like this: ``"a->b->[0]->['name']"``. Here, the current
        class has an attribute a, which has an attribute b, which is a list,
        which we index at index 0, which is an element of type dictionary,
        which we index using the dictionary key 'name'.

        Note that dictionary keys cannot contain square brackets or single
        quotes (even if they are escaped).

        Args:
            attr_name (str): Name of attribute to set
            val (Any): Value to set the attribute to
            create_new_ok (bool, optional): If false (default), throw an error
                if the attribute or key does not exist. If true, a missing
                dictionary key is created.
                NOTE(review): for attribute steps the bottom-up update below
                still rejects names that are not dataclass fields with
                ``TypeError``, so this flag currently only takes effect for
                dictionary keys — confirm whether that is intended.
            allow_private (bool, optional): If False (default), refuse to
                update dataclass fields declared with ``init=False``. If True,
                such fields may be updated.
            bypass_callbacks (bool, optional): If True, skip ``on_setattr``
                callbacks for all attribute operations in the path. If False
                (default), each attribute write runs the callbacks registered
                on that field, mirroring the behaviour of ``__setattr__``
                during ``__init__``.

        Returns:
            Self: Updated instance with new attribute value

        Raises:
            ValueError: If *attr_name* is not a valid path string.
            TypeError: If an attribute step names something that is not a
                dataclass field, or an ``init=False`` field without
                ``allow_private=True``.
            Exception: If a step of the path does not exist, a parent does not
                support the requested operation, or a container lacks
                ``.copy()``.
        """
        ops = self._parse_operations(attr_name)

        # 1. Top-down traversal: Find final attribute and save intermediate parents
        attr_list = self._traverse_path(ops)
        final_op, final_op_type = ops[-1]
        parent = attr_list[-1]

        # Validate the final step (respecting create_new_ok)
        if final_op_type == "attribute":
            if dataclasses.is_dataclass(parent):
                dc_field_names = {f.name for f in dataclasses.fields(parent)}
                if str(final_op) not in dc_field_names and not create_new_ok:
                    raise Exception(
                        f"Attribute: {final_op} does not exist for {parent.__class__}"
                    )
            elif not hasattr(parent, str(final_op)) and not create_new_ok:
                raise Exception(
                    f"Attribute: {final_op} does not exist for {parent.__class__}"
                )
        elif final_op_type == "index":
            if not hasattr(parent, "__getitem__"):
                raise Exception(f"{parent.__class__} does not implement __getitem__")
        elif final_op_type == "key":
            if not hasattr(parent, "__getitem__"):
                raise Exception(f"{parent.__class__} does not implement __getitem__")
            if final_op not in parent:
                if not create_new_ok:
                    raise Exception(f"Key: {final_op} does not exist for {parent}")

        # 2. Bottom-up copy: Set attributes functionally returning a brand-new
        # top-level instance. ``ops`` and ``attr_list`` have the same length,
        # so they are walked together from the innermost step outwards.
        cur_attr = val
        for (op, op_type), current_parent in zip(reversed(ops), reversed(attr_list)):
            if op_type == "attribute":
                if not dataclasses.is_dataclass(current_parent):
                    raise Exception(
                        f"Can only set attribute functionally on a dataclass, but got {current_parent.__class__}"
                    )
                dc_fields = {f.name: f for f in dataclasses.fields(current_parent)}
                target_field = dc_fields.get(str(op))
                if target_field is None:
                    raise TypeError(
                        f"Field {str(op)!r} is not a dataclass field of {type(current_parent).__name__!r}."
                    )
                if not target_field.init and not allow_private:
                    raise TypeError(
                        f"Field {str(op)!r} has init=False (non-init/private field). "
                        "Pass allow_private=True to allow updating non-init fields."
                    )
                if not bypass_callbacks:
                    callbacks = DataClass._normalize_callbacks(
                        target_field.metadata.get(DRINX_ON_SETATTR, ())
                    )
                    cur_attr = DataClass._run_callbacks(cur_attr, callbacks)
                cur_attr = DataClass._copy_and_set(current_parent, str(op), cur_attr)
            elif op_type in ("index", "key"):
                if not hasattr(current_parent, "copy"):
                    raise Exception(
                        f"Target {current_parent.__class__} must implement a .copy() method for functional updates."
                    )
                # Copy the dictionary/list to avoid mutating the original frozen structure
                cpy = current_parent.copy()
                if op_type == "index":
                    cpy[int(op)] = cur_attr
                else:
                    cpy[op] = cur_attr
                cur_attr = cpy
            else:
                raise Exception(
                    f"Invalid operation type: {op_type}. This is an internal bug!"
                )

        # Internal sanity check: the outermost rebuilt object replaces ``self``.
        assert cur_attr.__class__ == self.__class__
        return cur_attr

    def aset_inplace(
        self,
        attr_name: str,
        val: Any,
        create_new_ok: bool = False,
        bypass_callbacks: bool = False,
    ) -> None:
        """Sets an attribute of this dataclass **in place** by bypassing the
        frozen restriction via ``object.__setattr__``.

        .. warning::
            **This method is NOT functional and is potentially very unsafe.**
            It mutates ``self`` directly, violating the immutability contract
            that JAX pytree registration relies on. Using it outside of
            ``__post_init__`` — or on an object that is already part of a JAX
            computation — can lead to silent correctness bugs, broken JAX
            caches, and undefined behaviour. Prefer :meth:`aset` for all
            normal use.

        The primary intended use case is setting derived or cached fields
        inside ``__post_init__``, before the instance is passed to JAX.

        Supports the same path syntax as :meth:`aset`:
        ``"a->b->[0]->['key']"``.

        Args:
            attr_name: Path string (see :meth:`aset` for syntax).
            val: Value to assign at the target location.
            create_new_ok (bool, optional): If False (default), raise if the
                target attribute or dictionary key does not already exist. If
                True, allow setting new attributes or creating new dictionary
                keys.
            bypass_callbacks (bool, optional): If True, skip ``on_setattr``
                callbacks. If False, fire the callbacks registered on the
                target field before writing, mirroring the behaviour of
                :meth:`aset` (default).

        Raises:
            ValueError: If *attr_name* is not a valid path string.
            Exception: If a step of the path does not exist (and, for the
                final step, ``create_new_ok`` is False).
        """
        ops = self._parse_operations(attr_name)
        attr_list = self._traverse_path(ops)
        final_op, final_op_type = ops[-1]
        parent = attr_list[-1]

        # Validate the final step (respecting create_new_ok)
        if final_op_type == "attribute":
            if dataclasses.is_dataclass(parent):
                dc_field_names = {f.name for f in dataclasses.fields(parent)}
                if str(final_op) not in dc_field_names and not create_new_ok:
                    raise Exception(
                        f"Attribute: {final_op} does not exist for {parent.__class__}"
                    )
            elif not hasattr(parent, str(final_op)) and not create_new_ok:
                raise Exception(
                    f"Attribute: {final_op} does not exist for {parent.__class__}"
                )
        elif final_op_type == "key":
            if not hasattr(parent, "__getitem__"):
                raise Exception(f"{parent.__class__} does not implement __getitem__")
            if final_op not in parent:
                if not create_new_ok:
                    raise Exception(f"Key: {final_op} does not exist for {parent}")

        # Only the final attribute write fires callbacks, and only when the
        # target is a declared field of a dataclass parent.
        if final_op_type == "attribute" and not bypass_callbacks:
            if dataclasses.is_dataclass(parent):
                dc_fields = {f.name: f for f in dataclasses.fields(parent)}
                target_field = dc_fields.get(str(final_op))
                if target_field is not None:
                    callbacks = DataClass._normalize_callbacks(
                        target_field.metadata.get(DRINX_ON_SETATTR, ())
                    )
                    val = DataClass._run_callbacks(val, callbacks)

        if final_op_type == "attribute":
            # object.__setattr__ sidesteps the frozen-instance guard.
            object.__setattr__(parent, str(final_op), val)
        elif final_op_type == "index":
            parent[int(final_op)] = val
        elif final_op_type == "key":
            parent[final_op] = val

    @property
    def at(self) -> _AtProxy[Self]:
        """Returns an ``_AtProxy`` for JAX-style ``.at[].set()`` fluent updates.

        Example::

            tree = tree.at["weights"].set(new_weights)
            tree = tree.at["layer"][0].set(new_layer)

            mask = jax.tree_map(lambda x: x > 0, tree)
            tree = tree.at[mask].set(0.0)
        """
        return _AtProxy(self)

    def updated_copy(self, **kwargs: Any) -> Self:
        """Returns an updated copy of the tree with modified top-level attributes.

        Args:
            **kwargs: Dictionary mapping immediate attribute names to their
                new values.

        Returns:
            Self: A newly instantiated object with the updated attributes.
        """
        # Directly utilize dataclasses.replace for standard functional updates
        return dataclasses.replace(self, **kwargs)
_DC = TypeVar("_DC", bound="DataClass") class _AtIndexer(Generic[_DC]): """Accumulates a path of keys/indices and dispatches ``.set()`` on a DataClass.""" def __init__(self, obj: _DC, path: list[str | int | Any]) -> None: self._obj = obj self._path = path def __getitem__(self, key: Any) -> "_AtIndexer[_DC]": return _AtIndexer(self._obj, self._path + [key]) def set( self, value: Any, allow_private: bool = False, bypass_callbacks: bool = False ) -> _DC: """Apply a functional update and return the new DataClass instance. If the accumulated path contains only ``str``/``int`` keys, delegates to :meth:`DataClass.aset`. If the single key is a DataClass instance of the same type (mask case), applies element-wise ``jnp.where``. Args: value: The new value. For mask updates this may be a scalar or a DataClass tree of the same type as the mask/object. bypass_callbacks: If True, skip ``on_setattr`` callbacks. Default False runs them. Returns: Updated DataClass instance. Raises: TypeError: If the path contains a key of an unsupported type. """ path = self._path # Mask case: single key that is an instance of the same DataClass type if len(path) == 1 and isinstance(path[0], type(self._obj)): mask = path[0] if isinstance(value, type(self._obj)): return jax.tree.map( lambda leaf, m, v: jnp.where(m, v, leaf), self._obj, mask, value, ) return jax.tree.map( lambda leaf, m: jnp.where(m, value, leaf), self._obj, mask, ) # Path case: build an aset-compatible path string parts: list[str] = [] for key in path: if isinstance(key, str): parts.append(key) elif isinstance(key, int): parts.append(f"[{key}]") else: raise TypeError( f"Unsupported key type in .at[] path: {type(key).__name__!r}. " "Expected str, int, or a DataClass mask." 
) path_str = "->".join(parts) return self._obj.aset( path_str, value, allow_private=allow_private, bypass_callbacks=bypass_callbacks, ) class _AtProxy(Generic[_DC]): """Returned by ``DataClass.at``; entry point for ``.at[key].set()`` syntax.""" def __init__(self, obj: _DC) -> None: self._obj = obj def __getitem__(self, key: Any) -> _AtIndexer[_DC]: return _AtIndexer(self._obj, [key])