Advanced Usage

import jax.numpy as jnp
import drinx
from drinx import DataClass, private_field, static_field

aset_inplace for in-place updates

Drinx dataclasses are always frozen, so you can't assign to their attributes once they are constructed. The one window where mutation is still acceptable is __post_init__, where the object exists but hasn't yet been handed to JAX. aset_inplace is a path-aware escape hatch for exactly that window.

Warning: aset_inplace mutates self directly. Use it only inside __post_init__. Everywhere else, use aset for safe functional updates.
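Drinx's internals aren't shown in this guide, but a frozen standard-library dataclass follows the same rules and makes a handy mental model. In the plain-Python sketch below (no drinx or JAX needed), object.__setattr__ inside __post_init__ stands in for aset_inplace, and dataclasses.replace stands in for aset. The pairing is an analogy, not a description of how drinx implements either method; the Optimizer class is hypothetical.

```python
import dataclasses
from dataclasses import dataclass


@dataclass(frozen=True)
class Optimizer:
    name: str
    lr: float

    def __post_init__(self) -> None:
        # Escape hatch, analogous to aset_inplace: the frozen __setattr__
        # would reject self.name = ..., so go through object.__setattr__.
        object.__setattr__(self, "name", self.name.strip().lower())


opt = Optimizer(name="  Adam ", lr=1e-3)
print(opt.name)  # adam

# After construction, plain assignment is rejected.
try:
    opt.lr = 1e-2
except dataclasses.FrozenInstanceError as err:
    print(f"Rejected: {err}")

# Functional update, analogous to aset: a new object, the old one untouched.
faster = dataclasses.replace(opt, lr=1e-2)
print(faster.lr, opt.lr)  # 0.01 0.001
```

Note that replace() calls __init__, and therefore __post_init__, again; that is why the normalisation above is written so that repeating it is harmless.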

Deriving a cached field

Use private_field() to declare a field that is excluded from __init__ and computed from other fields in __post_init__.

class LinearLayer(DataClass):
    weights: jnp.ndarray
    bias: jnp.ndarray
    # Derived — not passed to __init__, set in __post_init__
    n_params: int = private_field()

    def __post_init__(self) -> None:
        total = self.weights.size + self.bias.size
        self.aset_inplace("n_params", total)

layer = LinearLayer(weights=jnp.ones((4, 8)), bias=jnp.zeros((8,)))
print(layer.n_params)  # 40
40
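If private_field() behaves like the standard library's dataclasses.field(init=False), which matches the description above (excluded from __init__, set in __post_init__) but is an assumption here, LinearLayer has a close plain-Python counterpart. So that it runs without JAX, the sketch swaps the arrays for hypothetical shape tuples and uses math.prod in place of .size.

```python
import dataclasses
import math
from dataclasses import dataclass


@dataclass(frozen=True)
class LinearLayerShapes:
    weights_shape: tuple[int, ...]
    bias_shape: tuple[int, ...]
    # Excluded from __init__, like private_field(); filled in by __post_init__.
    n_params: int = dataclasses.field(init=False)

    def __post_init__(self) -> None:
        total = math.prod(self.weights_shape) + math.prod(self.bias_shape)
        object.__setattr__(self, "n_params", total)


layer = LinearLayerShapes(weights_shape=(4, 8), bias_shape=(8,))
print(layer.n_params)  # 40

# Passing the derived field to the constructor is an error.
try:
    LinearLayerShapes(weights_shape=(4, 8), bias_shape=(8,), n_params=0)
except TypeError as err:
    print(f"Rejected: {err}")
```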

Validation in __post_init__

__post_init__ is also a natural place to validate input. This example raises on a degenerate (zero-length) vector and otherwise replaces the field with its normalised version.

class UnitVector(DataClass):
    vec: jnp.ndarray

    def __post_init__(self) -> None:
        norm = jnp.linalg.norm(self.vec)
        if float(norm) == 0.0:
            raise ValueError("vec must be non-zero")
        self.aset_inplace("vec", self.vec / norm)

u = UnitVector(vec=jnp.array([3.0, 0.0, 4.0]))
print(u.vec)  # [0.6 0.  0.8] (normalised in __post_init__)

try:
    UnitVector(vec=jnp.zeros(3))
except ValueError as e:
    print(f"Caught: {e}")
[0.6 0.  0.8]
Caught: vec must be non-zero
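The same validate-then-normalise pattern carries over to a frozen standard-library dataclass. This plain-Python sketch swaps the JAX array for a tuple (an assumption so it runs without JAX), computes the norm with math.hypot, and uses object.__setattr__ where the drinx version calls aset_inplace. The UnitVec name is hypothetical, and the sketch illustrates the pattern rather than drinx internals.

```python
import math
from dataclasses import dataclass


@dataclass(frozen=True)
class UnitVec:
    vec: tuple[float, ...]

    def __post_init__(self) -> None:
        norm = math.hypot(*self.vec)
        # Validate first: a degenerate input raises before any field is rewritten.
        if norm == 0.0:
            raise ValueError("vec must be non-zero")
        object.__setattr__(self, "vec", tuple(x / norm for x in self.vec))


u = UnitVec(vec=(3.0, 0.0, 4.0))
print(u.vec)  # (0.6, 0.0, 0.8)

try:
    UnitVec(vec=(0.0, 0.0, 0.0))
except ValueError as e:
    print(f"Caught: {e}")  # Caught: vec must be non-zero
```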

Nested path

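aset_inplace accepts the same -> path syntax as aset, so its first argument can point inside a mutable container (a list or dict) stored in a field, not just at the field itself. The example below builds such a container: it derives a summary dict from the values list.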

class Stats(DataClass):
    values: list
    summary: dict = private_field()

    def __post_init__(self) -> None:
        arr = jnp.array(self.values)
        self.aset_inplace("summary", {"mean": float(arr.mean()), "std": float(arr.std())})

s = Stats(values=[1.0, 2.0, 3.0, 4.0, 5.0])
print(s.summary)
{'mean': 3.0, 'std': 1.4142135381698608}
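The example above sets the whole summary field; a -> path is what lets you target an entry inside such a container instead. This guide doesn't spell out drinx's exact path grammar, so the sketch below assumes a simplified one of its own: segments are separated by ->, and each segment is used as an integer index on a list, a key on a dict, or an attribute name on any other object. The set_by_path helper is hypothetical and only shows how such a path can be walked and assigned in place.

```python
from typing import Any


def set_by_path(root: Any, path: str, value: Any) -> None:
    """Hypothetical helper: walk a '->'-separated path and assign in place.

    Assumed grammar (not drinx's documented one): lists are indexed with
    int(segment), dicts with the segment as a key, other objects by attribute.
    """
    *parents, last = path.split("->")

    def step(obj: Any, seg: str) -> Any:
        if isinstance(obj, list):
            return obj[int(seg)]
        if isinstance(obj, dict):
            return obj[seg]
        return getattr(obj, seg)

    # Walk down to the container that holds the final segment.
    target = root
    for seg in parents:
        target = step(target, seg)

    # Assign into that container, mutating it in place.
    if isinstance(target, list):
        target[int(last)] = value
    elif isinstance(target, dict):
        target[last] = value
    else:
        setattr(target, last, value)


state = {"summary": {"mean": 0.0, "std": 0.0}, "values": [1.0, 2.0, 3.0]}
set_by_path(state, "summary->mean", 2.0)
set_by_path(state, "values->0", 10.0)
print(state["summary"]["mean"], state["values"])  # 2.0 [10.0, 2.0, 3.0]
```

Because the root here is a plain dict, the final assignment is ordinary item assignment; on a frozen object, a top-level attribute would instead need the kind of escape hatch aset_inplace provides.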