import jax
from dataclasses import dataclass as dataclass_orig
from dataclasses import fields, field as orig_field
from typing import TypeVar
from typing import dataclass_transform
from drinx.attribute import field, static_field, private_field, static_private_field
from typing import Callable, overload
# Type variable for the class being decorated: the helpers in this module
# accept ``type[T]`` and hand back the same ``type[T]``.
T = TypeVar("T")
def _register_jax_tree(cls_: type[T]) -> type[T]:
    """Register a dataclass as a JAX pytree node, at most once per class.

    Fields whose metadata contains a truthy ``"jax_static"`` entry are stored
    in the pytree auxiliary data; every other field becomes a leaf keyed by
    ``jax.tree_util.GetAttrKey(<field name>)``.

    Args:
        cls_: A class already processed by :func:`dataclasses.dataclass`.

    Returns:
        ``cls_`` itself, registered via
        :func:`jax.tree_util.register_pytree_with_keys` (or untouched if this
        exact class was registered before).
    """
    # Guard against double registration (e.g. a class first handled by
    # __init_subclass__ and then by the ``dataclass`` decorator). Only the
    # class's *own* namespace is consulted: the marker is an ordinary class
    # attribute, so getattr() would also find it on a registered parent and
    # wrongly skip registering the subclass.
    if cls_.__dict__.get("_jax_tree_registered", False):
        return cls_

    all_fields = fields(cls_)
    static_fields = tuple(f.name for f in all_fields if f.metadata.get("jax_static"))
    dynamic_fields = tuple(f.name for f in all_fields if not f.metadata.get("jax_static"))
    # Fields declared with init=False are not accepted by the generated
    # __init__, so unflatten must restore them separately.
    non_init_fields = tuple(f.name for f in all_fields if not f.init)

    def read_field(obj, name):
        """Return ``getattr(obj, name)``, raising a descriptive error if unset."""
        try:
            return getattr(obj, name)
        except AttributeError:
            raise AttributeError(
                f"Field '{name}' of '{type(obj).__name__}' has not been set. "
                "Non-init fields must be assigned a default value or set in __post_init__."
            ) from None

    def flatten_with_keys(obj):
        # Dynamic fields become keyed leaves; static fields form the hashable aux tuple.
        keyed_leaves = [
            (jax.tree_util.GetAttrKey(name), read_field(obj, name))
            for name in dynamic_fields
        ]
        aux = tuple(read_field(obj, name) for name in static_fields)
        return keyed_leaves, aux

    def unflatten(aux, leaves):
        values = {**dict(zip(static_fields, aux)), **dict(zip(dynamic_fields, leaves))}
        obj = cls_(**{k: v for k, v in values.items() if k not in non_init_fields})
        # Restore init=False fields from the flattened data; object.__setattr__
        # bypasses the frozen dataclass __setattr__ (the same technique used
        # for assignments in __post_init__ of frozen dataclasses).
        for name in non_init_fields:
            object.__setattr__(obj, name, values[name])
        return obj

    jax.tree_util.register_pytree_with_keys(cls_, flatten_with_keys, unflatten)
    cls_._jax_tree_registered = True  # ty:ignore[unresolved-attribute]
    return cls_
# Typing overload 1 -- bare decorator form ``@dataclass``: the class itself is
# passed positionally and the decorated class comes back.
@overload
def dataclass(
    cls: type[T],
    /,
    *,
    init: bool = True,
    repr: bool = True,
    eq: bool = True,
    order: bool = False,
    unsafe_hash: bool = False,
    match_args: bool = True,
    kw_only: bool = False,
    slots: bool = False,
    weakref_slot: bool = False,
) -> type[T]:
    """Static signature for ``@dataclass`` applied directly to a class."""
# Typing overload 2 -- parameterised form such as ``@dataclass(kw_only=True)``:
# no class is given yet, so a one-argument class decorator is returned.
@overload
def dataclass(
    cls: None = None,
    /,
    *,
    init: bool = True,
    repr: bool = True,
    eq: bool = True,
    order: bool = False,
    unsafe_hash: bool = False,
    match_args: bool = True,
    kw_only: bool = False,
    slots: bool = False,
    weakref_slot: bool = False,
) -> Callable[[type[T]], type[T]]:
    """Static signature for ``@dataclass(...)`` called with options only."""
@dataclass_transform(
    # Every class produced here is frozen; tell type checkers so they reject
    # attribute assignment on instances.
    frozen_default=True,
    field_specifiers=(
        orig_field,
        field,
        static_field,
        private_field,
        static_private_field,
    ),
)
def dataclass(
    cls: type[T] | None = None,
    /,
    *,
    init: bool = True,
    repr: bool = True,
    eq: bool = True,
    order: bool = False,
    unsafe_hash: bool = False,
    match_args: bool = True,
    kw_only: bool = False,
    slots: bool = False,
    weakref_slot: bool = False,
) -> type[T] | Callable[[type[T]], type[T]]:
    """Decorator that converts a class into a frozen dataclass registered as a JAX pytree node.

    Wraps :func:`dataclasses.dataclass` with ``frozen=True`` and then calls
    :func:`_register_jax_tree` so the class can be used transparently with JAX
    transformations (``jit``, ``vmap``, ``grad``, etc.).

    Fields marked with :func:`static_field` (or ``field(static=True)``) are
    placed in the pytree auxiliary data and are excluded from JAX tracing. All
    other fields become pytree leaves and are traced normally.

    Can be used with or without arguments::

        @drinx.dataclass
        class Params:
            weights: jax.Array
            lr: float = static_field(default=1e-3)

        @drinx.dataclass(kw_only=True)
        class Config:
            hidden_size: int = static_field(default=128)

    Args:
        cls: The class to decorate when used without arguments (``@dataclass``).
            ``None`` when called with arguments (``@dataclass(...)``).
        init: Generate ``__init__``.
        repr: Generate ``__repr__``.
        eq: Generate ``__eq__`` and ``__hash__``.
        order: Generate comparison methods (``<``, ``<=``, ``>``, ``>=``).
        unsafe_hash: Force generation of ``__hash__`` even when ``eq=True``.
        match_args: Set ``__match_args__`` for structural pattern matching.
        kw_only: Make all fields keyword-only in ``__init__``.
        slots: Not supported. Passing ``True`` emits a :class:`UserWarning`
            and the class is still created without ``__slots__``. Slot-based
            dataclasses are excluded because ``__slots__`` attribute-access
            speedups are negligible compared to JAX kernel dispatch overhead,
            and supporting them would add significant complexity to pytree
            flatten/unflatten.
        weakref_slot: Ignored. It only has meaning together with ``slots``;
            non-slotted classes normally support weak references already.

    Returns:
        The decorated class (when ``cls`` is provided), or a one-argument
        decorator (when called with keyword arguments only).

    Note:
        ``frozen=True`` is always enforced and cannot be overridden. Mutability
        would break JAX's pytree contract.
    """
    if slots:
        import warnings

        # Previously slots=True was dropped silently; make that visible.
        warnings.warn(
            "drinx.dataclass does not support slots=True; the option is ignored "
            "and the class is created without __slots__.",
            stacklevel=2,
        )
    del slots, weakref_slot

    def wrapper(cls_: type[T]) -> type[T]:
        """Apply the frozen stdlib dataclass transform, then register the pytree."""
        # A class already processed by dataclasses (e.g. by __init_subclass__)
        # carries __dataclass_params__ in its *own* namespace. The attribute is
        # inherited, so getattr() would also fire for a fresh subclass of a
        # dataclass and strip __setattr__/__delattr__ that subclass defined.
        if "__dataclass_params__" in cls_.__dict__:
            # The earlier frozen pass generated __setattr__ and __delattr__;
            # delete them so the second pass doesn't throw a
            # "Cannot overwrite attribute" TypeError.
            # NOTE(review): the second pass rebuilds fields from the class
            # attributes left behind by the first pass, which hold plain
            # default values rather than Field objects, so field metadata
            # (e.g. static markers) is not visible to this pass. Pytree
            # registration presumably still uses the first pass's field split
            # because _register_jax_tree skips already-registered classes —
            # confirm this is intended.
            for dunder in ("__setattr__", "__delattr__"):
                if dunder in cls_.__dict__:
                    delattr(cls_, dunder)

        decorator = dataclass_orig(
            init=init,
            repr=repr,
            eq=eq,
            order=order,
            unsafe_hash=unsafe_hash,
            frozen=True,
            match_args=match_args,
            kw_only=kw_only,
            slots=False,
            weakref_slot=False,
        )
        return _register_jax_tree(decorator(cls_))

    if cls is None:
        return wrapper
    return wrapper(cls)