Advanced Usage
import jax.numpy as jnp
import drinx
from drinx import DataClass, private_field, static_field
`aset_inplace` for in-place updates
Drinx dataclasses are always frozen, so you can't assign to attributes after construction.
The one exception is inside `__post_init__`, where the object already exists but has not yet been
handed to JAX. `aset_inplace` gives you a path-aware escape hatch for this window.
Warning:
`aset_inplace` mutates `self` directly. Use it only inside `__post_init__`. Everywhere else, use `aset` for safe functional updates.
Deriving a cached field
Use `private_field()` to declare a field that is excluded from `__init__`
and computed from other fields in `__post_init__`.
class LinearLayer(DataClass):
    """Weights and bias of a layer, plus a derived parameter count.

    ``n_params`` is declared with ``private_field()``, so it is not an
    ``__init__`` argument; ``__post_init__`` fills it in from the sizes of
    ``weights`` and ``bias``.
    """

    weights: jnp.ndarray
    bias: jnp.ndarray
    # Derived field: not accepted by __init__, populated in __post_init__.
    n_params: int = private_field()

    def __post_init__(self) -> None:
        # The instance is frozen, so aset_inplace is the way to write here.
        param_count = sum(arr.size for arr in (self.weights, self.bias))
        self.aset_inplace("n_params", param_count)
# weights hold 4 * 8 = 32 entries and bias holds 8, so n_params is 40.
layer = LinearLayer(
    weights=jnp.ones((4, 8)),
    bias=jnp.zeros((8,)),
)
print(layer.n_params)  # 40
An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
40
Validation in `__post_init__`
Normalise a field during construction, and raise a `ValueError` if the input is degenerate (here, a zero-length vector that cannot be normalised).
class UnitVector(DataClass):
    """A vector rescaled to unit Euclidean length when it is constructed.

    Raises:
        ValueError: if ``vec`` has zero norm.
    """

    vec: jnp.ndarray

    def __post_init__(self) -> None:
        length = jnp.linalg.norm(self.vec)
        # Reject the zero vector before dividing; it has no direction.
        is_degenerate = float(length) == 0.0
        if is_degenerate:
            raise ValueError("vec must be non-zero")
        self.aset_inplace("vec", self.vec / length)
# |(3, 0, 4)| = 5, so the stored vector becomes (0.6, 0, 0.8).
u = UnitVector(vec=jnp.array([3.0, 0.0, 4.0]))
print(u.vec)  # [0.6 0. 0.8] — normalised during construction

# A zero vector cannot be normalised, so construction raises.
try:
    UnitVector(vec=jnp.zeros(3))
except ValueError as e:
    print(f"Caught: {e}")
[0.6 0. 0.8]
Caught: vec must be non-zero
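Nested path
`aset_inplace` accepts the same `->` path syntax as `aset`, so you can reach
into a mutable container (list, dict) stored as a field. The example below builds the `summary` dict in `__post_init__` and assigns it through a plain top-level path.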
class Stats(DataClass):
    """Raw ``values`` plus a derived ``summary`` of their mean and std.

    ``summary`` is a private field computed in ``__post_init__``; it maps
    ``"mean"`` and ``"std"`` to Python floats (``jnp.std`` defaults to the
    population standard deviation, ``ddof=0``).
    """

    values: list
    # Derived field: not an __init__ argument.
    summary: dict = private_field()

    def __post_init__(self) -> None:
        data = jnp.array(self.values)
        described = {"mean": float(data.mean()), "std": float(data.std())}
        self.aset_inplace("summary", described)
# values are [1.0, 2.0, 3.0, 4.0, 5.0]: mean 3.0, population std sqrt(2).
s = Stats(values=[float(k) for k in range(1, 6)])
print(s.summary)
{'mean': 3.0, 'std': 1.4142135381698608}