Basic Usage#
Drinx (Dataclass Registry in JAX) makes Python dataclasses work as JAX pytree nodes, so they pass through jit, grad, vmap, and other JAX transforms seamlessly.
Two usage styles are available:
Decorator: @drinx.dataclass
Inheritance: class Foo(drinx.DataClass)
import jax
import jax.numpy as jnp
import drinx
1. Decorator style#
@drinx.dataclass wraps dataclasses.dataclass and registers the class as a JAX pytree. All fields are dynamic (traced) by default.
@drinx.dataclass
class Params:
    """Two-array parameter container registered as a JAX pytree.

    Both fields are dynamic (traced) because neither is marked static.
    """

    weights: jax.Array
    bias: jax.Array


params = Params(weights=jnp.ones((3,)), bias=jnp.zeros((3,)))
print(params)
An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
Params(weights=Array([1., 1., 1.], dtype=float32), bias=Array([0., 0., 0.], dtype=float32))
# jax.tree_util.tree_map works out of the box
# The lambda is applied to every array leaf of `params` (weights and bias),
# and the result is rebuilt with the same container type.
doubled = jax.tree_util.tree_map(lambda x: x * 2, params)
print(doubled)
print(type(doubled)) # still a Params instance
Params(weights=Array([2., 2., 2.], dtype=float32), bias=Array([0., 0., 0.], dtype=float32))
<class '__main__.Params'>
2. Inheritance style#
Subclassing drinx.DataClass applies the transform automatically — no decorator needed.
class Model(drinx.DataClass):
    """Same kind of parameter container, defined by subclassing ``drinx.DataClass``.

    No decorator is applied; the base class is what turns this into a
    dataclass and pytree.
    """

    weights: jax.Array
    bias: jax.Array


model = Model(weights=jnp.ones((4, 4)), bias=jnp.zeros((4,)))
print(model)
Model(weights=Array([[1., 1., 1., 1.],
[1., 1., 1., 1.],
[1., 1., 1., 1.],
[1., 1., 1., 1.]], dtype=float32), bias=Array([0., 0., 0., 0.], dtype=float32))
3. Static fields#
Mark a field as static to exclude it from JAX tracing. Static values are treated as compile-time constants by jit — changing a static field triggers recompilation.
@drinx.dataclass
class LinearLayer:
    """Layer holding traced weights and a static ``hidden_size``."""

    weights: jax.Array
    # hidden_size is a compile-time constant — not traced by JAX
    hidden_size: int = drinx.static_field(default=128)


layer = LinearLayer(weights=jnp.ones((128, 32)))


@jax.jit
def forward(layer, x):
    """Multiply the first ``layer.hidden_size`` rows of the weights by ``x``.

    The slice bound comes from a static field, so it is a plain Python int
    at trace time and the output shape is known to ``jit``.
    """
    return layer.weights[: layer.hidden_size] @ x


x = jnp.ones((32,))
result = forward(layer, x)
print(result.shape)  # (128,)
(128,)
Note: Changing a static field causes jit to recompile for the new value.
# Static field can vary per-instance at Python level
small_layer = LinearLayer(weights=jnp.ones((64, 32)), hidden_size=64)
# Reuses the jitted `forward` and the input `x` from the previous snippet;
# slicing to hidden_size=64 rows gives a (64,) result.
result_small = forward(small_layer, x)
print(result_small.shape) # (64,) — recompiled for hidden_size=64
(64,)
4. JAX transforms#
4a. jax.grad#
Gradients have the same pytree structure as the input.
class State(drinx.DataClass):
    """Optimisation state: a traced array ``x`` and a static ``step_size``."""

    x: jax.Array
    step_size: float = drinx.static_field(default=0.1)


def loss(state):
    """Return the sum of squares of ``state.x``."""
    return jnp.sum(state.x ** 2)


state = State(x=jnp.array([1.0, 2.0, 3.0]))
# Differentiating w.r.t. the dataclass yields gradients in the same container.
grads = jax.grad(loss)(state)
print(type(grads))  # State — same type as input
print(grads.x)  # [2. 4. 6.] (gradient of sum(x^2) = 2x)
<class '__main__.State'>
[2. 4. 6.]
4b. jax.vmap#
Batch over dynamic fields by stacking arrays along a new leading dimension.
@jax.vmap
def scale(state):
    """Double ``state.x``; ``jax.vmap`` maps the call over the leading axis."""
    return state.x * 2


# Each row is one element of the batch
batched = State(x=jnp.array([[1.0, 2.0], [3.0, 4.0]]))
result = scale(batched)
print(result)  # [[2. 4.] [6. 8.]]
[[2. 4.]
[6. 8.]]
4c. jax.lax.scan#
Dataclasses work as the carry in jax.lax.scan, enabling stateful loops without Python-level iteration.
class ScanState(drinx.DataClass):
    """Scan carry: a traced value ``x`` and a static ``step_size``."""

    x: jax.Array
    step_size: float = drinx.static_field(default=0.1)


def step(carry, _):
    """Advance the carry by one iteration of the scan.

    Returns the new carry and the per-step output (the updated ``x``),
    which ``jax.lax.scan`` stacks into the history.
    """
    new_x = carry.x - carry.step_size  # move x down by one fixed step
    return ScanState(x=new_x, step_size=carry.step_size), new_x


init = ScanState(x=jnp.array(1.0))
final, history = jax.lax.scan(step, init, None, length=5)
print("history:", history)  # approximately [0.9, 0.8, 0.7, 0.6, 0.5] in float32
print("final x:", final.x)
history: [0.9 0.79999995 0.6999999 0.5999999 0.4999999 ]
final x: 0.4999999
5. Nested dataclasses#
Drinx dataclasses compose naturally — nest them to represent hierarchical model parameters.
class Inner(drinx.DataClass):
    """Innermost node holding a single weight array."""

    w: jax.Array


class Outer(drinx.DataClass):
    """Parent node that nests an ``Inner`` alongside its own bias."""

    inner: Inner
    bias: jax.Array


@jax.jit
def apply(outer, x):
    """Return ``outer.inner.w @ x + outer.bias``."""
    return outer.inner.w @ x + outer.bias


outer = Outer(inner=Inner(w=jnp.eye(3)), bias=jnp.ones((3,)))
x = jnp.array([1.0, 2.0, 3.0])
print(apply(outer, x))  # [2. 3. 4.]
[2. 3. 4.]
6. Functional updates with aset#
Since drinx dataclasses are frozen, use .aset(path, value) to get an updated copy. Nested fields use -> as a separator. Note that this function is only available when using the inheritance style to create the dataclass.
# Start from an Outer whose nested weights are ones and whose bias is zeros.
outer = Outer(inner=Inner(w=jnp.ones((3,))), bias=jnp.zeros((3,)))
# Update a top-level field
outer2 = outer.aset("bias", jnp.ones((3,)))
# `outer` keeps its original bias; the change is only visible on `outer2`.
print("original bias:", outer.bias) # [0. 0. 0.]
print("updated bias: ", outer2.bias) # [1. 1. 1.]
# Update a nested field using '->' path syntax
outer3 = outer.aset("inner->w", jnp.full((3,), 99.0))
print("nested update:", outer3.inner.w) # [99. 99. 99.]
original bias: [0. 0. 0.]
updated bias: [1. 1. 1.]
nested update: [99. 99. 99.]
7. Fluent updates with .at[].set()#
An alternative to aset is the .at[key].set(value) API, which mirrors JAX array indexing and supports chaining.
class Simple(drinx.DataClass):
    """Minimal two-field dataclass used to demonstrate ``.at[...].set(...)``."""

    a: jax.Array
    b: jax.Array


tree = Simple(a=jnp.array(1.0), b=jnp.array(2.0))

# Single field update
result = tree.at["a"].set(jnp.array(99.0))
print(result)  # Simple(a=99.0, b=2.0)

# Chained updates: each .set(...) returns an object that supports .at again
result2 = tree.at["a"].set(jnp.array(10.0)).at["b"].set(jnp.array(20.0))
print(result2)
Simple(a=Array(99., dtype=float32, weak_type=True), b=Array(2., dtype=float32, weak_type=True))
Simple(a=Array(10., dtype=float32, weak_type=True), b=Array(20., dtype=float32, weak_type=True))
# Mask-based update: zero out all elements greater than 5
tree = Simple(a=jnp.array([1.0, 6.0, 3.0]), b=jnp.array([7.0, 2.0, 8.0]))
# Build a pytree of boolean arrays by comparing every leaf of `tree` with 5.
mask = jax.tree.map(lambda x: x > 5, tree)
# Presumably writes 0.0 wherever the mask is True, per the expected values
# below — confirm against the drinx `.at` documentation.
result = tree.at[mask].set(0.0)
print(result.a) # [1. 0. 3.]
print(result.b) # [0. 2. 0.]
8. Combining jit + grad for a training step#
A minimal example of a JAX training loop using drinx to hold model parameters.
class LinearModel(drinx.DataClass):
    """Linear regression parameters plus a static learning rate."""

    w: jax.Array  # weights (shape (4,) in this example)
    b: jax.Array  # bias (a scalar in this example)
    lr: float = drinx.static_field(default=0.01)


def mse_loss(model, x, y):
    """Return the mean squared error of ``x @ model.w + model.b`` against ``y``."""
    pred = x @ model.w + model.b
    return jnp.mean((pred - y) ** 2)


@jax.jit
def train_step(model, x, y):
    """Run one gradient-descent update.

    Returns the updated model and the loss evaluated at the model that was
    passed in (i.e. before the update).
    """
    loss, grads = jax.value_and_grad(mse_loss)(model, x, y)
    # Gradient descent: subtract lr * grad for each dynamic field
    new_model = jax.tree_util.tree_map(lambda p, g: p - model.lr * g, model, grads)
    return new_model, loss


# Toy dataset: y = x @ true_w + 0.1
key = jax.random.PRNGKey(0)
true_w = jnp.array([2.0, -1.0, 0.5, 1.5])
x_data = jax.random.normal(key, (32, 4))
y_data = x_data @ true_w + 0.1

model = LinearModel(w=jnp.zeros((4,)), b=jnp.zeros(()))
for _ in range(500):
    model, loss = train_step(model, x_data, y_data)

print("Final loss:", float(loss))
print("Learned w: ", model.w)
print("True w: ", true_w)
Final loss: 6.751089676981792e-05
Learned w: [ 1.9897425 -1.0014862 0.49940124 1.4957309 ]
True w: [ 2. -1. 0.5 1.5]