Basic Usage#

Drinx (Dataclass Registry in JAX) makes Python dataclasses work as JAX pytree nodes, so they pass through jit, grad, vmap, and other JAX transforms seamlessly.

Two usage styles are available:

  • Decorator: @drinx.dataclass

  • Inheritance: class Foo(drinx.DataClass)

import jax
import jax.numpy as jnp

import drinx

1. Decorator style#

@drinx.dataclass wraps dataclasses.dataclass and registers the class as a JAX pytree. All fields are dynamic (traced) by default.

@drinx.dataclass
class Params:
    """Example parameter container created with the decorator style."""
    weights: jax.Array  # dynamic (traced) field, the default for all fields
    bias: jax.Array  # dynamic (traced) field

# Construct it with keyword arguments, one per declared field.
params = Params(weights=jnp.ones((3,)), bias=jnp.zeros((3,)))
print(params)
An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
Params(weights=Array([1., 1., 1.], dtype=float32), bias=Array([0., 0., 0.], dtype=float32))
# jax.tree_util.tree_map works out of the box: the lambda is applied to each
# array field and the result is rebuilt as the same dataclass type.
doubled = jax.tree_util.tree_map(lambda x: x * 2, params)
print(doubled)
print(type(doubled))  # still a Params instance
Params(weights=Array([2., 2., 2.], dtype=float32), bias=Array([0., 0., 0.], dtype=float32))
<class '__main__.Params'>

2. Inheritance style#

Subclassing drinx.DataClass applies the transform automatically — no decorator needed.

class Model(drinx.DataClass):
    """Example model container created by subclassing drinx.DataClass."""
    weights: jax.Array  # array field; (4, 4) in the example below
    bias: jax.Array  # array field; (4,) in the example below

# Instantiate with a 4x4 weight matrix of ones and a length-4 bias of zeros.
model = Model(weights=jnp.ones((4, 4)), bias=jnp.zeros((4,)))
print(model)
Model(weights=Array([[1., 1., 1., 1.],
       [1., 1., 1., 1.],
       [1., 1., 1., 1.],
       [1., 1., 1., 1.]], dtype=float32), bias=Array([0., 0., 0., 0.], dtype=float32))

3. Static fields#

Mark a field as static to exclude it from JAX tracing. Static values are treated as compile-time constants by jit — changing a static field triggers recompilation.

@drinx.dataclass
class LinearLayer:
    """Example layer mixing a traced array field with a static int field."""
    weights: jax.Array  # dynamic field, traced by JAX
    # hidden_size is a compile-time constant — not traced by JAX
    hidden_size: int = drinx.static_field(default=128)

# hidden_size is not passed, so it takes its default of 128.
layer = LinearLayer(weights=jnp.ones((128, 32)))

@jax.jit
def forward(layer, x):
    """Multiply the first ``layer.hidden_size`` rows of the weights by ``x``."""
    # hidden_size is a static field, so it is used here as a plain slice bound
    row_count = layer.hidden_size
    active_rows = layer.weights[:row_count]
    return active_rows @ x

# weights is (128, 32) and x is (32,), so the product has one entry per row.
x = jnp.ones((32,))
result = forward(layer, x)
print(result.shape)  # (128,)
(128,)

Note: Changing a static field causes jit to recompile for the new value.

# Static field can vary per-instance at Python level
# Here hidden_size=64 matches the 64 rows of the smaller weight matrix.
small_layer = LinearLayer(weights=jnp.ones((64, 32)), hidden_size=64)
result_small = forward(small_layer, x)
print(result_small.shape)  # (64,) — recompiled for hidden_size=64
(64,)

4. JAX transforms#

4a. jax.grad#

Gradients have the same pytree structure as the input.

class State(drinx.DataClass):
    """Example state with one array field and one static float field."""
    x: jax.Array  # dynamic field; gradients are reported for it as grads.x
    step_size: float = drinx.static_field(default=0.1)  # static, defaults to 0.1

def loss(state):
    """Return the sum of the squared entries of ``state.x``."""
    squared = state.x ** 2
    return jnp.sum(squared)

# Differentiate the loss with respect to the whole dataclass argument.
state = State(x=jnp.array([1.0, 2.0, 3.0]))
grads = jax.grad(loss)(state)

print(type(grads))   # State — same type as input
print(grads.x)       # [2. 4. 6.]  (gradient of sum(x^2) = 2x)
<class '__main__.State'>
[2. 4. 6.]

4b. jax.vmap#

Batch over dynamic fields by stacking arrays along a new leading dimension.

@jax.vmap
def scale(state):
    """Double ``state.x``; vmap applies this to each batch element."""
    factor = 2
    return state.x * factor

# Each row is one element of the batch
# (x has shape (2, 2): a batch of two length-2 vectors).
batched = State(x=jnp.array([[1.0, 2.0], [3.0, 4.0]]))
result = scale(batched)
print(result)  # [[2. 4.] [6. 8.]]
[[2. 4.]
 [6. 8.]]

4c. jax.lax.scan#

Dataclasses work as the carry in jax.lax.scan, enabling stateful loops without Python-level iteration.

class ScanState(drinx.DataClass):
    """Carry for the jax.lax.scan example: a value plus a static step size."""
    x: jax.Array  # dynamic field updated on every scan iteration
    step_size: float = drinx.static_field(default=0.1)  # static, defaults to 0.1

def step(carry, _):
    """Advance the scan by one iteration.

    Returns the next carry and the value that scan stacks into its output.
    The second argument (the per-iteration input) is ignored.
    """
    size = carry.step_size
    next_x = carry.x - size  # move x down by one fixed step
    next_carry = ScanState(x=next_x, step_size=size)
    return next_carry, next_x

# Run five iterations; xs is None, so step receives no per-iteration input.
init = ScanState(x=jnp.array(1.0))
final, history = jax.lax.scan(step, init, None, length=5)

print("history:", history)  # approximately [0.9, 0.8, 0.7, 0.6, 0.5] (float32 rounding)
print("final x:", final.x)
history: [0.9        0.79999995 0.6999999  0.5999999  0.4999999 ]
final x: 0.4999999

5. Nested dataclasses#

Drinx dataclasses compose naturally — nest them to represent hierarchical model parameters.

class Inner(drinx.DataClass):
    """Innermost example node holding a single array."""
    w: jax.Array  # array field, reached from Outer as outer.inner.w

class Outer(drinx.DataClass):
    """Example parent node that nests an Inner dataclass next to an array."""
    inner: Inner  # nested drinx dataclass
    bias: jax.Array  # array field

@jax.jit
def apply(outer, x):
    """Compute ``outer.inner.w @ x + outer.bias`` from the nested parameters."""
    weight = outer.inner.w
    projected = weight @ x
    return projected + outer.bias

# w is the 3x3 identity, so the result is x + bias.
outer = Outer(inner=Inner(w=jnp.eye(3)), bias=jnp.ones((3,)))
x = jnp.array([1.0, 2.0, 3.0])
print(apply(outer, x))  # [2. 3. 4.]
[2. 3. 4.]

6. Functional updates with aset#

Since drinx dataclasses are frozen, use .aset(path, value) to get an updated copy instead of assigning to a field. Nested fields are addressed with -> as a separator (for example "inner->w"). Note that this method is only available on classes created with the inheritance style (subclasses of drinx.DataClass).

# aset returns a new instance; `outer` itself is left unchanged, as the
# "original bias" print below shows.
outer = Outer(inner=Inner(w=jnp.ones((3,))), bias=jnp.zeros((3,)))

# Update a top-level field
outer2 = outer.aset("bias", jnp.ones((3,)))
print("original bias:", outer.bias)   # [0. 0. 0.]
print("updated bias: ", outer2.bias)  # [1. 1. 1.]

# Update a nested field using '->' path syntax
outer3 = outer.aset("inner->w", jnp.full((3,), 99.0))
print("nested update:", outer3.inner.w)  # [99. 99. 99.]
original bias: [0. 0. 0.]
updated bias:  [1. 1. 1.]
nested update: [99. 99. 99.]

7. Fluent updates with .at[].set()#

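An alternative to aset is the .at[key].set(value) API, which mirrors JAX array indexing and supports chaining. As the last example in this section shows, the key can also be a mask: a pytree of boolean arrays with the same structure as the dataclass, in which case only the elements where the mask is True are replaced.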

class Simple(drinx.DataClass):
    """Two-field example used to demonstrate the .at[...].set(...) API."""
    a: jax.Array  # array field addressed as .at["a"]
    b: jax.Array  # array field addressed as .at["b"]

tree = Simple(a=jnp.array(1.0), b=jnp.array(2.0))

# Single field update
result = tree.at["a"].set(jnp.array(99.0))
print(result)  # a is now 99.0, b is still 2.0 (the repr shows them as Array(...))

# Chained updates: .at is applied again to the result of the first .set()
result2 = tree.at["a"].set(jnp.array(10.0)).at["b"].set(jnp.array(20.0))
print(result2)
Simple(a=Array(99., dtype=float32, weak_type=True), b=Array(2., dtype=float32, weak_type=True))
Simple(a=Array(10., dtype=float32, weak_type=True), b=Array(20., dtype=float32, weak_type=True))
# Mask-based update: zero out all elements greater than 5
tree = Simple(a=jnp.array([1.0, 6.0, 3.0]), b=jnp.array([7.0, 2.0, 8.0]))
# The mask is built per field, so it is itself a Simple holding boolean arrays.
mask = jax.tree.map(lambda x: x > 5, tree)
# Positions where the mask is True are replaced by 0.0; the rest are kept.
result = tree.at[mask].set(0.0)
print(result.a)  # [1. 0. 3.]
print(result.b)  # [0. 2. 0.]

8. Combining jit + grad for a training step#

A minimal example of a JAX training loop using drinx to hold model parameters.

class LinearModel(drinx.DataClass):
    """Parameters of the linear model ``x @ w + b`` plus a static learning rate."""
    w: jax.Array  # weights, applied as x @ w (a length-4 vector in this example)
    b: jax.Array  # bias added to the prediction (a scalar in this example)
    lr: float = drinx.static_field(default=0.01)  # learning rate; static, not traced

def mse_loss(model, x, y):
    """Return the mean squared error between ``x @ model.w + model.b`` and ``y``."""
    residual = (x @ model.w + model.b) - y
    return jnp.mean(residual ** 2)

@jax.jit
def train_step(model, x, y):
    """Run one gradient-descent update and return ``(new_model, loss)``."""
    loss_and_grad = jax.value_and_grad(mse_loss)
    loss, grads = loss_and_grad(model, x, y)

    def descend(param, grad):
        # Plain gradient descent: step against the gradient, scaled by lr
        return param - model.lr * grad

    # Applied once per dynamic field, pairing each parameter with its gradient
    new_model = jax.tree_util.tree_map(descend, model, grads)
    return new_model, loss

# Toy dataset: y = x @ true_w + 0.1 (linear in x with a constant offset of 0.1)
true_w = jnp.array([2.0, -1.0, 0.5, 1.5])
key = jax.random.PRNGKey(0)
x_data = jax.random.normal(key, (32, 4))
y_data = x_data @ true_w + 0.1

# Start from all-zero parameters; lr keeps its default of 0.01.
model = LinearModel(w=jnp.zeros((4,)), b=jnp.zeros(()))

# Each call returns a new model; rebinding `model` carries it to the next step.
for _ in range(500):
    model, loss = train_step(model, x_data, y_data)

print("Final loss:", float(loss))
print("Learned w: ", model.w)
print("True w:    ", true_w)
Final loss: 6.751089676981792e-05
Learned w:  [ 1.9897425  -1.0014862   0.49940124  1.4957309 ]
True w:     [ 2.  -1.   0.5  1.5]