from __future__ import annotations

import dataclasses
import functools
import operator
from collections.abc import Mapping
from collections.abc import Sequence as ABCSequence
from dataclasses import field as orig_field
from typing import dataclass_transform, Any, Callable, Generic, Sequence, Self, TypeVar

import jax
import jax.numpy as jnp

from drinx.attribute import (
    DRINX_ON_GETATTR,
    DRINX_ON_SETATTR,
    field,
    private_field,
    static_field,
    static_private_field,
)
from drinx.transform import dataclass
@dataclass_transform(
    field_specifiers=(
        orig_field,
        field,
        static_field,
        private_field,
        static_private_field,
    )
)
class DataClass:
    """Base class alternative to the ``@drinx.dataclass`` decorator.

    Subclassing ``DataClass`` automatically applies the ``@drinx.dataclass``
    transform, registering the subclass as a frozen dataclass and a JAX pytree
    node. Fields annotated with :func:`static_field` (or ``field(static=True)``)
    are placed in the pytree auxiliary data (not traced by JAX); all other fields
    become pytree leaves.

    Usage::

        class MyModel(DataClass):
            weights: jax.Array
            learning_rate: float = static_field(default=1e-3)

    Dataclass keyword arguments (``init``, ``repr``, ``eq``, etc.) can be
    forwarded via the class definition::

        class MyModel(DataClass, order=True, kw_only=True):
            ...

    Also provides :meth:`aset` for functional nested updates,
    :meth:`aset_inplace` for in-place updates during ``__post_init__``, the
    :attr:`at` proxy for ``.at[...].set(...)`` updates, and
    :meth:`updated_copy` as a convenience wrapper around
    :func:`dataclasses.replace`.
    """

    def __init_subclass__(
        cls,
        /,
        *,
        init: bool = True,
        repr: bool = True,
        eq: bool = True,
        order: bool = False,
        unsafe_hash: bool = False,
        match_args: bool = True,
        kw_only: bool = False,
        slots: bool = False,
        weakref_slot: bool = False,
    ):
        """Apply the ``@drinx.dataclass`` transform to every subclass automatically.

        Called by Python whenever a new subclass of :class:`DataClass` is
        defined. Accepts the same keyword arguments as the standard
        :func:`dataclasses.dataclass` decorator (except ``frozen``, which is
        always ``True``).

        Args:
            init: Generate ``__init__``.
            repr: Generate ``__repr__``.
            eq: Generate ``__eq__`` and ``__hash__``.
            order: Generate comparison methods (``<``, ``<=``, ``>``, ``>=``).
            unsafe_hash: Force generation of ``__hash__`` even when ``eq=True``.
            match_args: Set ``__match_args__`` for structural pattern matching.
            kw_only: Make all fields keyword-only in ``__init__``.
            slots: Not supported; ignored. Slot-based dataclasses are excluded
                because ``__slots__`` attribute-access speedups are negligible
                compared to JAX kernel dispatch overhead, and supporting them
                would add significant complexity to pytree flatten/unflatten.
            weakref_slot: Add a ``__weakref__`` slot (ignored; kept for API
                compatibility).
        """
        # Accepted only so the signature mirrors dataclasses.dataclass.
        del slots, weakref_slot
        super().__init_subclass__()
        # Only a __post_init__ defined directly on this class needs wrapping;
        # an inherited one was already wrapped when its own class was created.
        user_post_init = cls.__dict__.get("__post_init__")
        if user_post_init is not None and user_post_init is not DataClass.__post_init__:
            setattr(
                cls,
                "__post_init__",
                DataClass._wrap_user_post_init(user_post_init),
            )
        # Programmatically apply drinx's dataclass wrapper to the subclass.
        # (Named ``transform`` so it does not shadow typing.dataclass_transform.)
        transform = dataclass(
            init=init,
            repr=repr,
            eq=eq,
            order=order,
            unsafe_hash=unsafe_hash,
            match_args=match_args,
            kw_only=kw_only,
            weakref_slot=False,
        )
        transform(cls)
        # Re-install DataClass's attribute hooks on the subclass, in case the
        # transform generated its own (e.g. a frozen ``__setattr__``).
        setattr(cls, "__setattr__", DataClass.__setattr__)
        setattr(cls, "__getattribute__", DataClass.__getattribute__)

    @staticmethod
    def _wrap_user_post_init(
        user_post_init: Callable[..., Any],
    ) -> Callable[..., None]:
        """Wrap a user-defined ``__post_init__`` so DataClass's own post-init runs afterwards.

        All positional/keyword arguments (e.g. ``InitVar`` values) are forwarded
        to the user function; :meth:`DataClass.__post_init__` is then called to
        run ``on_setattr`` callbacks and lock the instance. ``functools.wraps``
        keeps the user function's name, qualname and docstring on the wrapper.
        """

        @functools.wraps(user_post_init)
        def wrapped(self: DataClass, *args: Any, **kwargs: Any) -> None:
            user_post_init(self, *args, **kwargs)
            DataClass.__post_init__(self)

        return wrapped

    @staticmethod
    def _normalize_callbacks(callbacks: Any) -> tuple[Callable[..., Any], ...]:
        """Normalize callback metadata into a tuple of callables.

        ``None`` becomes an empty tuple, a non-string sequence is converted to a
        tuple, and any other value is treated as a single callback.
        """
        if callbacks is None:
            return ()
        if isinstance(callbacks, ABCSequence) and not isinstance(
            callbacks, (str, bytes)
        ):
            return tuple(callbacks)
        return (callbacks,)

    @staticmethod
    def _run_callbacks(value: Any, callbacks: Sequence[Callable[..., Any]]) -> Any:
        """Thread *value* through *callbacks* in order.

        A callback returning ``None`` leaves the current value unchanged (so
        pure validators can be mixed with transformers); any other return value
        replaces it for the next callback.
        """
        result = value
        for callback in callbacks:
            callback_result = callback(result)
            if callback_result is not None:
                result = callback_result
        return result

    @staticmethod
    def _get_field_definition(instance: DataClass, name: str) -> Any:
        """Return the ``dataclasses.Field`` named *name* on *instance*, or ``None``.

        Uses ``object.__getattribute__`` so the lookup does not recurse into
        :meth:`__getattribute__`. Returns ``None`` when *instance* has no
        ``__dataclass_fields__`` or no field of that name.
        """
        try:
            dataclass_fields = object.__getattribute__(instance, "__dataclass_fields__")
        except AttributeError:
            return None
        return dataclass_fields.get(name)

    def __post_init__(self) -> None:
        """Run ``on_setattr`` callbacks on every field once, then lock the instance.

        The work is guarded by the ``_drinx_active_logic_applied`` flag, so a
        repeated call (e.g. when wrapped ``__post_init__`` methods of a subclass
        and its parent both chain into this one) only re-locks the instance.
        After this method, :meth:`__setattr__` raises ``FrozenInstanceError``.
        """
        object.__setattr__(self, "_drinx_initialized", False)
        instance_dict = object.__getattribute__(self, "__dict__")
        if instance_dict.get("_drinx_active_logic_applied", False):
            object.__setattr__(self, "_drinx_initialized", True)
            return
        object.__setattr__(self, "_drinx_active_logic_applied", True)
        for dc_field in dataclasses.fields(self):
            try:
                value = object.__getattribute__(self, dc_field.name)
            except AttributeError:
                # Field was never assigned (presumably init=False without a default).
                continue
            callbacks = DataClass._normalize_callbacks(
                dc_field.metadata.get(DRINX_ON_SETATTR, ())
            )
            value = DataClass._run_callbacks(value, callbacks)
            object.__setattr__(self, dc_field.name, value)
        object.__setattr__(self, "_drinx_initialized", True)

    def __setattr__(self, name: str, value: Any) -> None:
        """Reject writes after initialization; before it, apply ``on_setattr`` callbacks.

        Raises:
            dataclasses.FrozenInstanceError: If the instance is already locked.
            AttributeError: If *name* is neither a dataclass field, an existing
                instance attribute, nor a class attribute.
        """
        try:
            is_initialized = object.__getattribute__(self, "_drinx_initialized")
        except AttributeError:
            is_initialized = False
        if is_initialized:
            raise dataclasses.FrozenInstanceError(f"cannot assign to field {name!r}")
        dc_field = DataClass._get_field_definition(self, name)
        if dc_field is None:
            try:
                instance_dict = object.__getattribute__(self, "__dict__")
            except AttributeError:
                instance_dict = {}
            if name not in instance_dict and not hasattr(type(self), name):
                raise AttributeError(
                    f"{type(self).__name__!r} has no attribute {name!r}"
                )
            object.__setattr__(self, name, value)
            return
        callbacks = DataClass._normalize_callbacks(
            dc_field.metadata.get(DRINX_ON_SETATTR, ())
        )
        value = DataClass._run_callbacks(value, callbacks)
        object.__setattr__(self, name, value)

    def __getattribute__(self, name: str) -> Any:
        """Return the attribute, passed through the field's ``on_getattr`` callbacks.

        Dunder names and non-field attributes are returned unchanged.
        """
        value = object.__getattribute__(self, name)
        if name.startswith("__") and name.endswith("__"):
            return value
        dc_field = DataClass._get_field_definition(self, name)
        if dc_field is None:
            return value
        callbacks = DataClass._normalize_callbacks(
            dc_field.metadata.get(DRINX_ON_GETATTR, ())
        )
        return DataClass._run_callbacks(value, callbacks)

    @staticmethod
    def _parse_operations(s: str) -> list[tuple[str | int, str]]:
        """Parse a path string into a sequence of typed operations.

        Splits an ``aset``-style path such as ``"a->b->[0]->['key']"`` into an
        ordered list of ``(operand, operation_type)`` pairs understood by
        :meth:`aset`.

        Operation types:

        * ``"attribute"`` — attribute access (``getattr``). Operand is the
          attribute name as a :class:`str`.
        * ``"index"`` — integer subscript (``obj[n]``). Operand is an
          :class:`int`.
        * ``"key"`` — string subscript (``obj['k']``). Operand is a
          :class:`str`.

        Args:
            s: Path string. Steps are separated by ``"->"``. Integer indices
                are written as ``[n]`` and string keys as ``['k']``.

        Returns:
            Ordered list of ``(operand, operation_type)`` pairs.

        Raises:
            ValueError: If *s* is empty, malformed, or contains invalid
                identifiers or bracket expressions.
        """
        if not s:
            raise ValueError("Empty string is not valid")
        operations: list[tuple[str | int, str]] = []
        length = len(s)
        i = 0
        while i < length:
            if i > 0:
                # Every step after the first must be introduced by "->".
                if not s.startswith("->", i):
                    raise ValueError(f"Expected '->' at position {i}")
                i += 2
                if i >= length:
                    raise ValueError("String ends with '->'")
            if s[i] == "[":
                # Bracketed step: integer index or single-quoted string key.
                j = s.find("]", i + 1)
                if j == -1:
                    raise ValueError(f"Unclosed bracket starting at position {i}")
                bracket_content = s[i + 1 : j].strip()
                digits = (
                    bracket_content[1:]
                    if bracket_content.startswith("-")
                    else bracket_content
                )
                # isdecimal (not isdigit): characters such as superscript
                # digits pass isdigit() but are rejected by int().
                if digits.isdecimal():
                    operations.append((int(bracket_content), "index"))
                elif bracket_content.startswith("'") and bracket_content.endswith("'"):
                    if len(bracket_content) < 2:
                        raise ValueError(
                            f"Invalid string format in brackets: [{bracket_content}]"
                        )
                    string_content = bracket_content[1:-1]
                    if "'" in string_content:
                        raise ValueError(
                            f"String keys cannot contain single quotes: '{string_content}'"
                        )
                    if "[" in string_content or "]" in string_content:
                        raise ValueError(
                            f"String keys cannot contain square brackets: '{string_content}'"
                        )
                    operations.append((string_content, "key"))
                else:
                    raise ValueError(f"Invalid bracket content: [{bracket_content}]")
                i = j + 1
            else:
                # Bare step: attribute name running up to the next "->" (or the end).
                j = s.find("->", i)
                if j == -1:
                    j = length
                attr_name = s[i:j]
                if not attr_name:
                    raise ValueError(f"Empty attribute at position {i}")
                if not attr_name.isidentifier():
                    raise ValueError(f"Invalid attribute name: '{attr_name}'")
                operations.append((attr_name, "attribute"))
                i = j
        return operations

    def _traverse_path(
        self,
        ops: list[tuple[str | int, str]],
    ) -> list[Any]:
        """Top-down traversal along *ops*, returning all intermediate parents.

        Returns a list of length ``len(ops)`` where index ``i`` is the parent
        object on which ``ops[i]`` should be applied. The list always starts
        with ``self``.

        Navigates through ``ops[:-1]`` (all but the final step), so the caller
        is responsible for handling the final operation. Attribute steps use
        ``getattr``, so ``on_getattr`` callbacks of DataClass fields apply to
        the values collected here.

        Args:
            ops: Parsed operation list from :meth:`_parse_operations`.

        Returns:
            List of parent objects, one per operation.

        Raises:
            Exception: If any intermediate attribute or key is missing, or a
                subscripted object does not implement ``__getitem__``.
        """
        attr_list: list[Any] = [self]
        current_parent: Any = self
        for op, op_type in ops[:-1]:
            if op_type == "attribute":
                name = str(op)
                if not hasattr(current_parent, name):
                    raise Exception(
                        f"Attribute: {op} does not exist for {current_parent.__class__}"
                    )
                current_parent = getattr(current_parent, name)
            elif op_type in ("index", "key"):
                if not hasattr(current_parent, "__getitem__"):
                    raise Exception(
                        f"{current_parent.__class__} does not implement __getitem__"
                    )
                if op_type == "index":
                    current_parent = current_parent[int(op)]
                else:
                    if op not in current_parent:
                        raise Exception(f"Key: {op} does not exist for {current_parent}")
                    current_parent = current_parent[op]
            else:
                raise Exception(
                    f"Invalid operation type: {op_type}. This is an internal bug!"
                )
            attr_list.append(current_parent)
        return attr_list

    @staticmethod
    def _copy_and_set(obj: Any, field_name: str, value: Any) -> Any:
        """Return a shallow copy of dataclass *obj* with *field_name* set to *value*.

        The copy is created with ``object.__new__`` and filled from *obj*'s
        instance ``__dict__`` (including the ``_drinx_*`` flags) via
        ``object.__setattr__``, so neither ``__init__``/``__post_init__``, the
        frozen ``__setattr__`` guard, nor ``on_setattr`` callbacks run here.
        *obj* must have an instance ``__dict__``.
        """
        cls = type(obj)
        new_obj = object.__new__(cls)
        instance_dict = object.__getattribute__(obj, "__dict__")
        for k, v in instance_dict.items():
            object.__setattr__(new_obj, k, v)
        object.__setattr__(new_obj, field_name, value)
        return new_obj

    @staticmethod
    def _copy_and_set_item(
        container: Any, key: str | int, op_type: str, value: Any
    ) -> Any:
        """Return a copy of *container* whose entry *key* is replaced by *value*.

        Tuples and namedtuples are immutable and have no ``.copy()``, so for an
        index step they are rebuilt from a list. Every other container must
        provide ``.copy()`` returning an object that supports item assignment.

        Raises:
            Exception: If *container* needs ``.copy()`` but does not have it.
        """
        if op_type == "index" and isinstance(container, tuple):
            items = list(container)
            items[int(key)] = value
            container_type = type(container)
            if hasattr(container_type, "_make") and hasattr(container_type, "_fields"):
                return container_type._make(items)  # namedtuple
            if container_type is tuple:
                return tuple(items)
            # Other tuple subclasses fall through to the .copy() requirement.
        if not hasattr(container, "copy"):
            raise Exception(
                f"Target {container.__class__} must implement a .copy() method for functional updates."
            )
        # Copy the dictionary/list to avoid mutating the original (shared) structure.
        cpy = container.copy()
        if op_type == "index":
            cpy[int(key)] = value
        else:
            cpy[key] = value
        return cpy

    def aset(
        self,
        attr_name: str,
        val: Any,
        create_new_ok: bool = False,
        allow_private: bool = False,
        bypass_callbacks: bool = False,
    ) -> Self:
        """Functionally set a (possibly nested) attribute, index, or key.

        In contrast to the classical ``.at[].set()`` on JAX arrays, this
        replaces the whole target value (not only pytree leaf nodes) and
        returns a new instance; ``self`` is left unchanged.

        The target can be an attribute of this class or, for nested
        structures, a path such as ``"a->b->[0]->['name']"``: this class has
        an attribute ``a``, which has an attribute ``b``, which is a list
        indexed at ``0``, whose element is a dictionary indexed with the key
        ``'name'``.

        Every container along the path is shallow-copied: dataclasses via
        their instance ``__dict__``, tuples/namedtuples by rebuilding them,
        and any other container via its ``.copy()`` method.

        Note that dictionary keys cannot contain square brackets or single
        quotes (even if they are escaped).

        Args:
            attr_name: Path of the value to set (see above).
            val: Value to store at the target location.
            create_new_ok: If False (default), raise if the final attribute or
                dictionary key does not exist. If True, a missing dictionary
                key is created. Attribute steps on dataclasses must still name
                an existing dataclass field.
            allow_private: If False (default), refuse to update dataclass
                fields declared with ``init=False``.
            bypass_callbacks: If True, skip ``on_setattr`` callbacks for all
                attribute operations in the path. If False (default), each
                attribute write runs the callbacks registered on that field,
                mirroring the behaviour of ``__setattr__`` during ``__init__``.

        Returns:
            Self: A new instance with the updated value.

        Raises:
            ValueError: If *attr_name* is not a valid path.
            TypeError: If an attribute step is not a dataclass field, is an
                ``init=False`` field without ``allow_private``, or the update
                does not produce an instance of this class.
            Exception: If an intermediate step does not exist, an attribute is
                set on a non-dataclass, or a container cannot be copied.
        """
        ops = self._parse_operations(attr_name)
        # 1. Top-down traversal: collect the parent object for every step.
        attr_list = self._traverse_path(ops)
        final_op, final_op_type = ops[-1]
        parent = attr_list[-1]
        # Validate the final step (respecting create_new_ok).
        if final_op_type == "attribute":
            if dataclasses.is_dataclass(parent):
                dc_field_names = {f.name for f in dataclasses.fields(parent)}
                if str(final_op) not in dc_field_names and not create_new_ok:
                    raise Exception(
                        f"Attribute: {final_op} does not exist for {parent.__class__}"
                    )
            elif not hasattr(parent, str(final_op)) and not create_new_ok:
                raise Exception(
                    f"Attribute: {final_op} does not exist for {parent.__class__}"
                )
        elif final_op_type in ("index", "key"):
            if not hasattr(parent, "__getitem__"):
                raise Exception(f"{parent.__class__} does not implement __getitem__")
            if final_op_type == "key" and final_op not in parent and not create_new_ok:
                raise Exception(f"Key: {final_op} does not exist for {parent}")
        # 2. Bottom-up copy: rebuild each parent around its updated child,
        #    ending with a brand-new top-level instance.
        cur_attr = val
        for idx in reversed(range(len(attr_list))):
            op, op_type = ops[idx]
            current_parent = attr_list[idx]
            if op_type == "attribute":
                if not dataclasses.is_dataclass(current_parent):
                    raise Exception(
                        f"Can only set attribute functionally on a dataclass, but got {current_parent.__class__}"
                    )
                dc_fields = {f.name: f for f in dataclasses.fields(current_parent)}
                target_field = dc_fields.get(str(op))
                if target_field is None:
                    raise TypeError(
                        f"Field {str(op)!r} is not a dataclass field of {type(current_parent).__name__!r}."
                    )
                if not target_field.init and not allow_private:
                    raise TypeError(
                        f"Field {str(op)!r} has init=False (non-init/private field). "
                        "Pass allow_private=True to allow updating non-init fields."
                    )
                if not bypass_callbacks:
                    callbacks = DataClass._normalize_callbacks(
                        target_field.metadata.get(DRINX_ON_SETATTR, ())
                    )
                    cur_attr = DataClass._run_callbacks(cur_attr, callbacks)
                cur_attr = DataClass._copy_and_set(current_parent, str(op), cur_attr)
            elif op_type in ("index", "key"):
                cur_attr = DataClass._copy_and_set_item(
                    current_parent, op, op_type, cur_attr
                )
            else:
                raise Exception(
                    f"Invalid operation type: {op_type}. This is an internal bug!"
                )
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if cur_attr.__class__ is not self.__class__:
            raise TypeError(
                f"aset produced an instance of {cur_attr.__class__.__name__!r}, "
                f"expected {self.__class__.__name__!r}."
            )
        return cur_attr

    def aset_inplace(
        self,
        attr_name: str,
        val: Any,
        create_new_ok: bool = False,
        bypass_callbacks: bool = False,
    ) -> None:
        """Sets an attribute of this dataclass **in place** by bypassing the frozen
        restriction via ``object.__setattr__``.

        .. warning::
            **This method is NOT functional and is potentially very unsafe.**
            It mutates ``self`` directly, violating the immutability contract that
            JAX pytree registration relies on. Using it outside of
            ``__post_init__`` — or on an object that is already part of a JAX
            computation — can lead to silent correctness bugs, broken JAX caches,
            and undefined behaviour. Prefer :meth:`aset` for all normal use.

        The primary intended use case is setting derived or cached fields inside
        ``__post_init__``, before the instance is passed to JAX.

        Supports the same path syntax as :meth:`aset`:
        ``"a->b->[0]->['key']"``.

        Nested paths are resolved with ``getattr``, so ``on_getattr`` callbacks
        on intermediate fields apply; if such a callback returns a different
        object than the stored one, the write lands on that returned object.

        Args:
            attr_name: Path string (see :meth:`aset` for syntax).
            val: Value to assign at the target location.
            create_new_ok (bool, optional): If False (default), raise if the target
                attribute or dictionary key does not already exist. If True, allow
                setting new attributes or creating new dictionary keys.
            bypass_callbacks (bool, optional): If True, skip ``on_setattr``
                callbacks. If False (default), fire the callbacks registered on
                the target field before writing, mirroring :meth:`aset`.
        """
        ops = self._parse_operations(attr_name)
        attr_list = self._traverse_path(ops)
        final_op, final_op_type = ops[-1]
        parent = attr_list[-1]
        target_field = None
        if final_op_type == "attribute":
            if dataclasses.is_dataclass(parent):
                dc_fields = {f.name: f for f in dataclasses.fields(parent)}
                target_field = dc_fields.get(str(final_op))
                if target_field is None and not create_new_ok:
                    raise Exception(
                        f"Attribute: {final_op} does not exist for {parent.__class__}"
                    )
            elif not hasattr(parent, str(final_op)) and not create_new_ok:
                raise Exception(
                    f"Attribute: {final_op} does not exist for {parent.__class__}"
                )
        elif final_op_type == "key":
            if not hasattr(parent, "__getitem__"):
                raise Exception(f"{parent.__class__} does not implement __getitem__")
            if final_op not in parent and not create_new_ok:
                raise Exception(f"Key: {final_op} does not exist for {parent}")
        if final_op_type == "attribute":
            # Callbacks only exist for declared dataclass fields.
            if target_field is not None and not bypass_callbacks:
                callbacks = DataClass._normalize_callbacks(
                    target_field.metadata.get(DRINX_ON_SETATTR, ())
                )
                val = DataClass._run_callbacks(val, callbacks)
            object.__setattr__(parent, str(final_op), val)
        elif final_op_type == "index":
            parent[int(final_op)] = val
        elif final_op_type == "key":
            parent[final_op] = val

    @property
    def at(self) -> _AtProxy[Self]:
        """Return an ``_AtProxy`` for JAX-style ``.at[...].set(...)`` updates.

        String and integer keys are translated into an :meth:`aset` path; a
        DataClass of the same type used as the single key acts as an
        element-wise mask.

        Example::

            tree = tree.at["weights"].set(new_weights)
            tree = tree.at["layer"][0].set(new_layer)
            mask = jax.tree.map(lambda x: x > 0, tree)
            tree = tree.at[mask].set(0.0)
        """
        return _AtProxy(self)

    def updated_copy(self, **kwargs: Any) -> Self:
        """Return an updated copy of the tree with modified top-level attributes.

        Thin wrapper around :func:`dataclasses.replace`, so ``__init__`` and
        ``__post_init__`` run again on the new instance.

        Args:
            **kwargs: Mapping of immediate attribute names to their new values.

        Returns:
            Self: A newly instantiated object with the updated attributes.
        """
        return dataclasses.replace(self, **kwargs)
_DC = TypeVar("_DC", bound="DataClass")


class _AtIndexer(Generic[_DC]):
    """Accumulate a path of keys/indices and dispatch ``.set()`` on a DataClass.

    Created by ``DataClass.at[...]``. Every further ``[...]`` returns a *new*
    indexer with the key appended, so a partially built path can be reused
    without affecting other indexers derived from it.
    """

    def __init__(self, obj: _DC, path: list[str | int | Any]) -> None:
        self._obj = obj
        self._path = path

    def __getitem__(self, key: Any) -> "_AtIndexer[_DC]":
        # Concatenate rather than append so this indexer's path stays untouched.
        return _AtIndexer(self._obj, self._path + [key])

    @staticmethod
    def _follow(node: Any, key: str | int, as_attribute: bool) -> tuple[bool, Any]:
        """Take one best-effort path step; return ``(True, child)`` or ``(False, None)``.

        Only used to classify later ``str`` keys. A step that does not resolve
        is not an error here: ``aset`` reports the real error itself.
        """
        try:
            child = getattr(node, key) if as_attribute else node[key]
        except (AttributeError, KeyError, IndexError, TypeError):
            return False, None
        return True, child

    def _to_path_string(self) -> str:
        """Translate the accumulated keys into a ``DataClass.aset`` path string.

        * ``str`` keys become attribute steps (``name``), except when the
          object reached so far is a :class:`~collections.abc.Mapping` that is
          not a dataclass. In that case they become dictionary-key steps
          (``['name']``), which is what ``aset`` needs to update dict entries.
        * ``int`` keys, and other objects implementing ``__index__`` (e.g.
          NumPy integer scalars), become index steps (``[n]``).

        Raises:
            TypeError: If a key has any other type.
        """
        parts: list[str] = []
        node: Any = self._obj
        resolved = True  # False once the path leaves the actual object tree.
        for key in self._path:
            if isinstance(key, str):
                as_key = (
                    resolved
                    and isinstance(node, Mapping)
                    and not dataclasses.is_dataclass(node)
                )
                parts.append(f"['{key}']" if as_key else key)
                step: str | int = key
                as_attribute = not as_key
            else:
                if isinstance(key, int):
                    step = key
                else:
                    try:
                        step = operator.index(key)
                    except TypeError:
                        raise TypeError(
                            f"Unsupported key type in .at[] path: {type(key).__name__!r}. "
                            "Expected str, int (or an object implementing __index__), "
                            "or a DataClass mask."
                        ) from None
                parts.append(f"[{step}]")
                as_attribute = False
            if resolved:
                resolved, node = self._follow(node, step, as_attribute)
        return "->".join(parts)

    def set(
        self, value: Any, allow_private: bool = False, bypass_callbacks: bool = False
    ) -> _DC:
        """Apply a functional update and return the new DataClass instance.

        If the single key is a DataClass instance of the same type as the
        object (mask case), applies element-wise ``jnp.where`` over the pytree
        leaves. Otherwise the accumulated keys are converted into a path string
        and the update is delegated to :meth:`DataClass.aset`.

        Args:
            value: The new value. For mask updates this may be a scalar or a
                DataClass tree of the same type as the mask/object.
            allow_private: Forwarded to :meth:`DataClass.aset`; allows updating
                ``init=False`` fields. Ignored in the mask case.
            bypass_callbacks: If True, skip ``on_setattr`` callbacks. Default
                False runs them. Ignored in the mask case.

        Returns:
            Updated DataClass instance.

        Raises:
            TypeError: If the path contains a key of an unsupported type.
        """
        path = self._path
        # Mask case: single key that is an instance of the same DataClass type.
        if len(path) == 1 and isinstance(path[0], type(self._obj)):
            mask = path[0]
            if isinstance(value, type(self._obj)):
                return jax.tree.map(
                    lambda leaf, m, v: jnp.where(m, v, leaf),
                    self._obj,
                    mask,
                    value,
                )
            return jax.tree.map(
                lambda leaf, m: jnp.where(m, value, leaf),
                self._obj,
                mask,
            )
        # Path case: delegate to aset with an equivalent path string.
        return self._obj.aset(
            self._to_path_string(),
            value,
            allow_private=allow_private,
            bypass_callbacks=bypass_callbacks,
        )
class _AtProxy(Generic[_DC]):
    """Entry point of the ``tree.at[key].set(value)`` update syntax.

    ``DataClass.at`` hands out one of these; subscripting it starts a fresh
    ``_AtIndexer`` whose path contains only the given key.
    """

    def __init__(self, obj: _DC) -> None:
        # The DataClass instance that subsequent updates will be applied to.
        self._obj = obj

    def __getitem__(self, key: Any) -> _AtIndexer[_DC]:
        initial_path = [key]
        return _AtIndexer(self._obj, initial_path)