drinx.dataclass

drinx.dataclass(cls=None, /, *, init=True, repr=True, eq=True, order=False, unsafe_hash=False, match_args=True, kw_only=False, slots=False, weakref_slot=False)
Overloads:
  • cls (type[T]), init (bool), repr (bool), eq (bool), order (bool), unsafe_hash (bool), match_args (bool), kw_only (bool), slots (bool), weakref_slot (bool) → type[T]

  • cls (None), init (bool), repr (bool), eq (bool), order (bool), unsafe_hash (bool), match_args (bool), kw_only (bool), slots (bool), weakref_slot (bool) → Callable[[type[T]], type[T]]

Decorator that converts a class into a frozen dataclass registered as a JAX pytree node.

Applies dataclasses.dataclass() with frozen=True and then registers the resulting class as a pytree node via the internal helper _register_jax_tree(), so instances can be passed transparently through JAX transformations such as jax.jit, jax.vmap, and jax.grad.

Fields marked with static_field() (or field(static=True)) are placed in the pytree auxiliary data and are excluded from JAX tracing. All other fields become pytree leaves and are traced normally.
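The split between traced leaves and auxiliary data can be sketched with jax.tree_util.register_pytree_node. This is not drinx's implementation of _register_jax_tree(): the helpers static_field and register_frozen_pytree below are hypothetical stand-ins, and marking static fields with a "static" key in the field metadata is an assumption made for illustration.

```python
import dataclasses

import jax
import jax.numpy as jnp


def static_field(*, metadata=None, **kwargs):
    # Hypothetical stand-in: tag the field as static via dataclass metadata.
    return dataclasses.field(metadata={**(metadata or {}), "static": True}, **kwargs)


def register_frozen_pytree(cls):
    # Hypothetical stand-in for drinx's private _register_jax_tree().
    fields = dataclasses.fields(cls)
    leaf_names = [f.name for f in fields if not f.metadata.get("static", False)]
    static_names = [f.name for f in fields if f.metadata.get("static", False)]

    def flatten(obj):
        # Leaves are traced by JAX; aux data is kept as plain Python values.
        children = tuple(getattr(obj, name) for name in leaf_names)
        aux_data = tuple(getattr(obj, name) for name in static_names)
        return children, aux_data

    def unflatten(aux_data, children):
        # Rebuild through __init__ (assumes every field has init=True).
        return cls(**dict(zip(leaf_names, children)), **dict(zip(static_names, aux_data)))

    jax.tree_util.register_pytree_node(cls, flatten, unflatten)
    return cls


@register_frozen_pytree
@dataclasses.dataclass(frozen=True)
class Params:
    weights: jax.Array
    lr: float = static_field(default=1e-3)


params = Params(weights=jnp.ones(3))
leaves = jax.tree_util.tree_leaves(params)  # only the `weights` array; lr is aux data


@jax.jit
def sgd_step(p: Params, grad: jax.Array) -> Params:
    # p.lr is a concrete Python float here, not a tracer.
    return dataclasses.replace(p, weights=p.weights - p.lr * grad)


new_params = sgd_step(params, jnp.ones(3))
```

Because lr travels in the auxiliary data, it becomes part of the pytree structure that jax.jit keys its cache on: calling sgd_step with a Params that has a different lr triggers a retrace instead of reusing the compiled function.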

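Can be used with or without arguments:

```python
import jax

import drinx
from drinx import static_field  # assumes static_field is exported at the top level

@drinx.dataclass
class Params:
    weights: jax.Array
    lr: float = static_field(default=1e-3)

@drinx.dataclass(kw_only=True)
class Config:
    hidden_size: int = static_field(default=128)
```
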
Parameters:
  • cls (type[T] | None) – The class to decorate when used without arguments (@dataclass); None when called with arguments (@dataclass(...)).

  • init (bool) – Generate __init__.

  • repr (bool) – Generate __repr__.

  • eq (bool) – Generate __eq__. Because frozen=True is always enforced, eq=True also generates __hash__.

  • order (bool) – Generate ordering methods (<, <=, >, >=), which compare fields as tuples in definition order. Requires eq=True.

  • unsafe_hash (bool) – Generate __hash__ unconditionally, even when eq=False. With eq=True, the enforced frozen=True already yields a __hash__.

  • match_args (bool) – Set __match_args__ for structural pattern matching.

  • kw_only (bool) – Make all fields keyword-only in __init__.

  • slots (bool) – Not supported. Slot-based dataclasses are excluded because the attribute-access speedup from __slots__ is negligible next to JAX kernel-dispatch overhead, and supporting them would add significant complexity to pytree flattening and unflattening.

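  • weakref_slot (bool) – Add a __weakref__ slot. In dataclasses.dataclass(), weakref_slot=True is only valid together with slots=True, which this decorator does not support.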

Returns:

The decorated class (when cls is provided), or a one-argument decorator (when called with parentheses, e.g. @dataclass(kw_only=True)).

Note

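frozen=True is always enforced and cannot be overridden, because mutable instances would break JAX’s pytree contract. To change a field, create a modified copy instead, for example with dataclasses.replace().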
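With frozen=True, assigning to a field raises dataclasses.FrozenInstanceError, and updates are expressed by building a new instance. The sketch uses a plain dataclasses.dataclass(frozen=True) class (the name State is illustrative); that drinx classes behave the same way follows from the decorator wrapping dataclasses.dataclass(), but is assumed here rather than shown.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class State:
    step: int
    loss: float


s = State(step=0, loss=1.0)

try:
    s.step = 1  # in-place mutation is rejected
except dataclasses.FrozenInstanceError as err:
    print("rejected:", err)

# Functional update: build a new instance and leave `s` unchanged.
s2 = dataclasses.replace(s, step=1)
print(s, s2)  # State(step=0, loss=1.0) State(step=1, loss=1.0)
```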