drinx.dataclass
- drinx.dataclass(cls=None, /, *, init=True, repr=True, eq=True, order=False, unsafe_hash=False, match_args=True, kw_only=False, slots=False, weakref_slot=False)
- Overloads:
cls (type[T]), init (bool), repr (bool), eq (bool), order (bool), unsafe_hash (bool), match_args (bool), kw_only (bool), slots (bool), weakref_slot (bool) → type[T]
cls (None), init (bool), repr (bool), eq (bool), order (bool), unsafe_hash (bool), match_args (bool), kw_only (bool), slots (bool), weakref_slot (bool) → Callable[[type[T]], type[T]]
Decorator that converts a class into a frozen dataclass registered as a JAX pytree node.
Wraps dataclasses.dataclass() with frozen=True and then calls _register_jax_tree() so the class can be used transparently with JAX transformations (jit, vmap, grad, etc.).

Fields marked with static_field() (or field(static=True)) are placed in the pytree auxiliary data and are excluded from JAX tracing. All other fields become pytree leaves and are traced normally.

Can be used with or without arguments:
    import jax
    import drinx
    from drinx import static_field

    @drinx.dataclass
    class Params:
        weights: jax.Array
        lr: float = static_field(default=1e-3)

    @drinx.dataclass(kw_only=True)
    class Config:
        hidden_size: int = static_field(default=128)
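The leaf/static split described above can be illustrated with a minimal sketch of how such a decorator might work. It assumes nothing about drinx internals: frozen_pytree_dataclass and the local static_field below are hypothetical stand-ins (not drinx's _register_jax_tree), built on dataclasses.dataclass(frozen=True) and jax.tree_util.register_pytree_node, with the static flag stored in field metadata.

```python
import dataclasses

import jax
import jax.numpy as jnp


def static_field(**kwargs):
    # Hypothetical helper: flag a field as static through dataclass metadata.
    return dataclasses.field(metadata={"static": True}, **kwargs)


def frozen_pytree_dataclass(cls):
    # Hypothetical stand-in for drinx.dataclass: frozen dataclass + pytree node.
    cls = dataclasses.dataclass(frozen=True)(cls)
    fields = dataclasses.fields(cls)
    leaf_names = tuple(f.name for f in fields if not f.metadata.get("static"))
    static_names = tuple(f.name for f in fields if f.metadata.get("static"))

    def flatten(obj):
        # Non-static fields become leaves that JAX traces.
        children = tuple(getattr(obj, n) for n in leaf_names)
        # Static fields go into aux data; JAX compares and hashes aux data
        # (e.g. for jit caching), so these values should be hashable.
        aux = tuple(getattr(obj, n) for n in static_names)
        return children, aux

    def unflatten(aux, children):
        kwargs = dict(zip(leaf_names, children))
        kwargs.update(zip(static_names, aux))
        return cls(**kwargs)

    jax.tree_util.register_pytree_node(cls, flatten, unflatten)
    return cls


@frozen_pytree_dataclass
class Params:
    weights: jax.Array
    lr: float = static_field(default=1e-3)


p = Params(weights=jnp.ones(3))
leaves = jax.tree_util.tree_leaves(p)  # only p.weights; lr is not a leaf

seen_lr_types = []


@jax.jit
def step(params):
    # Appended at trace time only: records what jit hands us for lr.
    seen_lr_types.append(type(params.lr))
    return Params(weights=params.weights * (1.0 - params.lr), lr=params.lr)


out = step(p)
```

Only weights shows up in tree_leaves, and inside the jitted step p.lr arrives as a plain Python float rather than a tracer, because it travels in the treedef. One consequence of this design is that changing a static value changes the pytree structure, so jit retraces.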
- Parameters:
cls (type[T] | None) – The class to decorate when used without arguments (@dataclass). None when called with arguments (@dataclass(...)).
init (bool) – Generate __init__.
repr (bool) – Generate __repr__.
eq (bool) – Generate __eq__ and __hash__.
order (bool) – Generate comparison methods (<, <=, >, >=).
unsafe_hash (bool) – Force generation of __hash__ even when eq=True.
match_args (bool) – Set __match_args__ for structural pattern matching.
kw_only (bool) – Make all fields keyword-only in __init__.
slots (bool) – Not supported. Slot-based dataclasses are excluded because __slots__ attribute-access speedups are negligible compared to JAX kernel dispatch overhead, and supporting them would add significant complexity to pytree flatten/unflatten.
weakref_slot (bool) – Add a __weakref__ slot.
- Returns:
The decorated class when cls is provided, or a one-argument decorator when cls is None (i.e. when called with parentheses, as in @dataclass(...)).
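The two overloads and the return behaviour follow the common "decorator with optional arguments" pattern. The sketch below shows that dispatch on cls under stated assumptions: frozen_dataclass, A, B and C are hypothetical names, and the body only forwards to the standard library's dataclasses.dataclass, omitting drinx's pytree registration entirely.

```python
import dataclasses
from typing import Callable, TypeVar, overload

T = TypeVar("T")


@overload
def frozen_dataclass(cls: type[T], /, **kwargs: bool) -> type[T]: ...
@overload
def frozen_dataclass(cls: None = None, /, **kwargs: bool) -> Callable[[type[T]], type[T]]: ...


def frozen_dataclass(cls=None, /, **kwargs):
    """Hypothetical stand-in showing the cls=None dispatch only."""

    def wrap(c):
        # frozen=True is always passed; a caller-supplied frozen=... collides
        # with it and raises TypeError, so it cannot be overridden.
        return dataclasses.dataclass(c, frozen=True, **kwargs)

    if cls is None:
        # Called with parentheses, e.g. @frozen_dataclass(order=True):
        # return a one-argument decorator.
        return wrap
    # Called bare, e.g. @frozen_dataclass: decorate immediately.
    return wrap(cls)


@frozen_dataclass
class A:
    x: int


@frozen_dataclass(order=True)
class B:
    x: int
```

Because cls is positional-only, a keyword-only call can never be mistaken for the class being decorated, which keeps the two overloads unambiguous.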
Note
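frozen=True is always enforced and cannot be overridden. Mutability would break JAX's pytree contract, which treats pytrees as immutable values that transformations rebuild rather than modify in place.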
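Since instances are frozen, "updating" a field means building a new instance. A minimal sketch, using a plain frozen dataclasses.dataclass as a stand-in (the Params class here is illustrative, not a drinx class); assuming drinx classes remain ordinary dataclasses, as the wrapping of dataclasses.dataclass() suggests, dataclasses.replace should apply to them the same way.

```python
import dataclasses


# Illustrative stand-in: a plain frozen dataclass, not a drinx class.
@dataclasses.dataclass(frozen=True)
class Params:
    weights: tuple
    lr: float = 1e-3


p = Params(weights=(1.0, 2.0))

try:
    p.lr = 0.1  # assignment on a frozen instance is rejected
    mutated = True
except dataclasses.FrozenInstanceError:
    mutated = False

# Build a new instance with one field changed; p itself is untouched.
p2 = dataclasses.replace(p, lr=0.1)
```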