drinx.DataClass

class drinx.DataClass

Bases: object

Base class alternative to the @drinx.dataclass decorator.

Subclassing DataClass automatically applies the @drinx.dataclass transform, which turns the subclass into a frozen dataclass and registers it as a JAX pytree node. Fields declared with static_field() (or field(static=True)) are stored in the pytree's auxiliary data and are not traced by JAX; all other fields become pytree leaves.

Usage:

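import jax
from drinx import DataClass, static_field  # assumed top-level exports

class MyModel(DataClass):
    weights: jax.Array  # pytree leaf, traced by JAX
    learning_rate: float = static_field(default=1e-3)  # static auxiliary data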
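
The following stdlib-only sketch illustrates the leaf/static split described above without depending on drinx or JAX. It is an approximation under stated assumptions: static fields are tagged through dataclasses field metadata, and the names static_field_sketch, MyModelSketch and flatten_sketch are hypothetical, not part of drinx. flatten_sketch mimics what a pytree flatten function returns: the children JAX would trace and the auxiliary data it would not.

```python
import dataclasses
from typing import Any


def static_field_sketch(**kwargs: Any) -> Any:
    """Hypothetical stand-in for static_field(): tag the field via metadata."""
    return dataclasses.field(metadata={"static": True}, **kwargs)


@dataclasses.dataclass(frozen=True)
class MyModelSketch:
    weights: tuple                                            # would be traced
    learning_rate: float = static_field_sketch(default=1e-3)  # would be static


def flatten_sketch(obj: Any) -> tuple[list, tuple]:
    """Split a dataclass instance into (children, aux_data), pytree-style."""
    children, aux = [], []
    for f in dataclasses.fields(obj):
        value = getattr(obj, f.name)
        if f.metadata.get("static", False):
            aux.append((f.name, value))  # static: kept out of tracing
        else:
            children.append(value)       # dynamic: JAX would flatten it further
    return children, tuple(aux)


model = MyModelSketch(weights=(0.5, -1.0))
children, aux = flatten_sketch(model)
print(children)  # [(0.5, -1.0)]
print(aux)       # (('learning_rate', 0.001),)
```

With drinx, subclassing DataClass performs the corresponding registration with JAX for you; the sketch only makes visible which fields end up on which side.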

Dataclass keyword arguments (init, repr, eq, etc.) can be forwarded via the class definition:

class MyModel(DataClass, order=True, kw_only=True):
    ...

DataClass also provides aset() for functional updates of nested attributes and updated_copy() as a convenience wrapper around dataclasses.replace().

__init__()

Methods

  • __init__()
  • aset(attr_name, val[, create_new_ok, ...]) – Sets an attribute of this class.
  • aset_inplace(attr_name, val[, ...]) – Sets an attribute of this dataclass in place by bypassing the frozen restriction via object.__setattr__.
  • updated_copy(**kwargs) – Returns an updated copy of the tree with modified top-level attributes.

Attributes

  • at – Returns an _AtProxy for JAX-style .at[].set() fluent updates.

aset(attr_name, val, create_new_ok=False, allow_private=False, bypass_callbacks=False)

Sets an attribute of this instance and returns an updated copy; self is not modified. In contrast to the classical .at[].set(), which only operates on JAX pytree leaf nodes, this method replaces the entire attribute with the new value.

attr_name can be the name of an attribute of this instance or, for nested objects, a path through several levels with steps separated by ->. For example, "a->b->[0]->['name']" refers to attribute a of this instance, then its attribute b, which is a list indexed at position 0; that element is a dictionary, which is indexed with the key 'name'.

Note that dictionary keys cannot contain square brackets or single quotes (even if they are escaped).
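
A hypothetical tokeniser for this path syntax can be written with two regular expressions. parse_path_sketch and its (kind, value) step format are assumptions for illustration, not drinx internals. Because the key pattern uses square brackets and single quotes as delimiters and has no escape mechanism, a key containing those characters cannot be expressed, which matches the restriction noted above.

```python
import re
from typing import Any

# A list index such as [0], and a dictionary key such as ['name'].
# The key pattern forbids [, ] and ' inside the key itself.
_INDEX = re.compile(r"\[(\d+)\]")
_KEY = re.compile(r"\['([^\[\]']*)'\]")


def parse_path_sketch(path: str) -> list[tuple[str, Any]]:
    """Split "a->b->[0]->['name']" into typed steps (hypothetical helper)."""
    steps: list[tuple[str, Any]] = []
    for part in path.split("->"):
        if m := _INDEX.fullmatch(part):
            steps.append(("index", int(m.group(1))))
        elif m := _KEY.fullmatch(part):
            steps.append(("key", m.group(1)))
        elif part.isidentifier():
            steps.append(("attr", part))
        else:
            raise ValueError(f"cannot parse path segment {part!r}")
    return steps


print(parse_path_sketch("a->b->[0]->['name']"))
# [('attr', 'a'), ('attr', 'b'), ('index', 0), ('key', 'name')]
```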

Parameters:
  • attr_name (str) – Name of the attribute to set, or a path to a nested attribute using the syntax described above

  • val (Any) – Value to set the attribute to

  • create_new_ok (bool) – If False (default), raise an error if the attribute does not exist. If True, create the attribute if it does not exist yet.

  • bypass_callbacks (bool) – If True, skip on_setattr callbacks for all attribute operations in the path. If False (default), each attribute write runs the callbacks registered on that field, mirroring the behaviour of __setattr__ during __init__.

Returns:

Updated instance with new attribute value

Return type:

Self
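
Because aset() returns an updated copy, a nested write has to rebuild every object along the path while leaving the original untouched. The sketch below illustrates that idea on an already-parsed path using only the standard library; set_path_sketch, Inner and Outer are hypothetical names, and neither on_setattr callbacks nor create_new_ok are modelled.

```python
import dataclasses
from typing import Any


def set_path_sketch(obj: Any, steps: list[tuple[str, Any]], val: Any) -> Any:
    """Return a copy of obj with the value at `steps` replaced (obj is not mutated)."""
    if not steps:
        return val
    (kind, name), rest = steps[0], steps[1:]
    if kind == "attr":
        # Frozen dataclass: rebuild it with one field swapped.
        new_child = set_path_sketch(getattr(obj, name), rest, val)
        return dataclasses.replace(obj, **{name: new_child})
    # "index" or "key": shallow-copy the list/tuple/dict, then replace one slot.
    copy = list(obj) if kind == "index" else dict(obj)
    copy[name] = set_path_sketch(obj[name], rest, val)
    return type(obj)(copy)


@dataclasses.dataclass(frozen=True)
class Inner:
    b: list


@dataclasses.dataclass(frozen=True)
class Outer:
    a: Inner


tree = Outer(a=Inner(b=[{"name": "old"}]))
steps = [("attr", "a"), ("attr", "b"), ("index", 0), ("key", "name")]
new_tree = set_path_sketch(tree, steps, "new")

print(new_tree.a.b[0]["name"])  # new
print(tree.a.b[0]["name"])      # old (the original is unchanged)
```

The replacement value can have any type, which is the sense in which aset() replaces the whole attribute rather than only pytree leaves.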

aset_inplace(attr_name, val, create_new_ok=False, bypass_callbacks=False)

Sets an attribute of this dataclass in place by bypassing the frozen restriction via object.__setattr__.

Warning

This method is NOT functional and is potentially very unsafe. It mutates self directly, violating the immutability contract that JAX pytree registration relies on. Using it outside of __post_init__ — or on an object that is already part of a JAX computation — can lead to silent correctness bugs, broken JAX caches, and undefined behaviour. Prefer aset() for all normal use.

The primary intended use case is setting derived or cached fields inside __post_init__, before the instance is passed to JAX.
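
The __post_init__ use case can be shown with a plain frozen dataclass, where object.__setattr__ is the standard way to fill a derived field during initialisation. This is a stdlib sketch that does not use drinx; Circle and its area field are hypothetical. aset_inplace() is documented as performing the same kind of bypass.

```python
import dataclasses
import math


@dataclasses.dataclass(frozen=True)
class Circle:
    radius: float
    area: float = dataclasses.field(init=False)  # derived, not passed to __init__

    def __post_init__(self) -> None:
        # A normal `self.area = ...` would raise FrozenInstanceError here.
        # object.__setattr__ skips the frozen check, which is the bypass
        # aset_inplace() is documented to use.
        object.__setattr__(self, "area", math.pi * self.radius ** 2)


c = Circle(radius=2.0)
print(round(c.area, 4))  # 12.5664
```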

Supports the same path syntax as aset(): "a->b->[0]->['key']".

Parameters:
  • attr_name (str) – Path string (see aset() for syntax).

  • val (Any) – Value to assign at the target location.

  • create_new_ok (bool) – If False (default), raise if the target attribute or dictionary key does not already exist. If True, allow setting new attributes or creating new dictionary keys.

  • bypass_callbacks (bool) – If True, skip on_setattr callbacks. If False (default), fire the callbacks registered on the target field before writing, mirroring the behaviour of aset().

Return type:

None

property at: _AtProxy[Self]

Returns an _AtProxy for JAX-style .at[].set() fluent updates.

Example:

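import jax

tree = tree.at["weights"].set(new_weights)  # replace the weights field
tree = tree.at["layer"][0].set(new_layer)  # index into a nested field
mask = jax.tree_util.tree_map(lambda x: x > 0, tree)  # boolean pytree of the same structure
tree = tree.at[mask].set(0.0)  # boolean-mask update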

updated_copy(**kwargs)

Returns an updated copy of the tree with modified top-level attributes.

Parameters:

**kwargs (Any) – Top-level attribute names mapped to their new values, passed as keyword arguments.

Returns:

A newly instantiated object with the updated attributes.

Return type:

Self
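
Since updated_copy() is described as a convenience wrapper around dataclasses.replace(), its behaviour can be previewed with the standard library alone, assuming it forwards its keyword arguments unchanged: only top-level fields are replaced, the original is left untouched, and an unknown field name raises TypeError. Config is a hypothetical example class.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class Config:
    lr: float
    layers: tuple


cfg = Config(lr=1e-3, layers=(64, 64))
new_cfg = dataclasses.replace(cfg, lr=3e-4)  # roughly cfg.updated_copy(lr=3e-4)

print(new_cfg)  # Config(lr=0.0003, layers=(64, 64))
print(cfg.lr)   # 0.001 (the original is untouched)

try:
    dataclasses.replace(cfg, dropout=0.1)  # not a field of Config
except TypeError as err:
    print("rejected:", err)
```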