API

drinx.dataclass([cls, init, repr, eq, ...])

Decorator that creates a drinx dataclass; the decorator counterpart of drinx.DataClass.

drinx.field(*[, default, default_factory, ...])

Define a dataclass field with optional JAX static marking.

drinx.static_field(*[, default, ...])

Define a JAX-static dataclass field.
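This sketch uses JAX's own `jax.tree_util.register_dataclass` instead of drinx, whose internals are not documented here, to show what "JAX-static" usually means. Ordinary fields become pytree leaves. Static fields are stored in the tree structure, must be hashable, and stay concrete Python values under `jax.jit`. The `Layer` class and `apply` function are hypothetical, and it is an assumption that drinx static fields behave the same way.

```python
import dataclasses

import jax
import jax.numpy as jnp


# Plain dataclass standing in for a drinx dataclass (illustration only).
@dataclasses.dataclass
class Layer:
    weight: jax.Array         # ordinary field: becomes a pytree leaf
    activation: str = "relu"  # static field: kept in the treedef, not a leaf


# data_fields become leaves; meta_fields are static, hashable metadata.
jax.tree_util.register_dataclass(
    Layer, data_fields=["weight"], meta_fields=["activation"]
)

layer = Layer(weight=jnp.ones((2, 2)))
leaves, treedef = jax.tree_util.tree_flatten(layer)
print(len(leaves))  # 1: only `weight` is a leaf


@jax.jit
def apply(layer, x):
    # `activation` is a concrete Python string even under jit,
    # so it can drive ordinary Python control flow.
    if layer.activation == "relu":
        return jax.nn.relu(layer.weight @ x)
    return layer.weight @ x
```

The static value is part of the treedef, so calling `apply` with a `Layer` that has a different `activation` triggers a retrace instead of reusing the cached compilation.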

drinx.private_field(*[, default, ...])

Define a private (non-init) dataclass field with optional JAX static marking.

drinx.static_private_field(*[, default, ...])

Define a private (non-init), JAX-static dataclass field.
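The "private (non-init)" part matches the standard-library idea of a dataclass field declared with `init=False`. Such a field is not an `__init__` parameter and is typically filled in `__post_init__`. This sketch uses plain `dataclasses` only. The `Normalizer` class and the leading-underscore naming are illustrative assumptions, and it does not show how drinx combines non-init fields with static marking.

```python
import dataclasses


@dataclasses.dataclass
class Normalizer:
    scale: float
    # Non-init field: not accepted by __init__, derived after construction.
    # repr=False also hides it from the generated __repr__.
    _inv_scale: float = dataclasses.field(init=False, repr=False)

    def __post_init__(self):
        self._inv_scale = 1.0 / self.scale


n = Normalizer(scale=4.0)
print(n._inv_scale)  # 0.25
```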

drinx.DataClass()

Base class alternative to the @drinx.dataclass decorator.
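A base class can stand in for a decorator by applying the decorator inside `__init_subclass__`. This is a common Python pattern, shown here with a hypothetical `DataClassBase` and plain `dataclasses.dataclass`. It is not necessarily how `drinx.DataClass` is implemented.

```python
import dataclasses


class DataClassBase:
    """Hypothetical base class that makes every subclass a dataclass."""

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        # Runs once per subclass at class-creation time; dataclasses.dataclass
        # adds __init__, __repr__, __eq__, etc. to cls in place.
        dataclasses.dataclass(cls)


class Point(DataClassBase):
    x: float
    y: float = 0.0


p = Point(1.0)
print(p.x, p.y)  # 1.0 0.0
```

With this pattern, subclassing replaces decorating. The trade-off is that per-class options such as `eq` or `repr` must be passed some other way, for example as class keyword arguments.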

drinx.tree_diagram(tree[, max_depth, ...])

Render a JAX pytree as an ASCII tree diagram.
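The kind of output described for `tree_diagram` can be approximated in a few lines of plain Python. The hypothetical `ascii_tree` below walks nested dicts, lists, and tuples, respects a `max_depth` cut-off, and labels array-like leaves by dtype and shape. It is not drinx's renderer, and its output format is an assumption. It also keeps dict insertion order, whereas JAX flattens dicts in sorted-key order.

```python
def _leaf_label(name, leaf):
    # Array-like leaves (anything with .shape and .dtype) are shown as dtype+shape.
    if hasattr(leaf, "shape") and hasattr(leaf, "dtype"):
        return f"{name}: {leaf.dtype}{tuple(leaf.shape)}"
    return f"{name}={leaf!r}"


def ascii_tree(node, name="root", max_depth=None, _depth=0):
    """Hypothetical helper: return the lines of an ASCII diagram of nested containers."""
    if not isinstance(node, (dict, list, tuple)):
        return [_leaf_label(name, node)]
    label = f"{name}: {type(node).__name__}"
    # Stop descending once max_depth is reached and mark the cut with "...".
    if max_depth is not None and _depth >= max_depth:
        return [label + " ..."]
    items = list(node.items()) if isinstance(node, dict) else list(enumerate(node))
    lines = [label]
    for i, (key, child) in enumerate(items):
        last = i == len(items) - 1
        child_lines = ascii_tree(child, str(key), max_depth, _depth + 1)
        # Connector for the child itself, then indent the child's own subtree.
        lines.append(("`-- " if last else "|-- ") + child_lines[0])
        pad = "    " if last else "|   "
        lines.extend(pad + line for line in child_lines[1:])
    return lines


tree = {"w": 1.0, "layers": [2, 3]}
print("\n".join(ascii_tree(tree)))
# root: dict
# |-- w=1.0
# `-- layers: list
#     |-- 0=2
#     `-- 1=3
```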

drinx.tree_summary(tree[, max_depth])

Render a JAX pytree's leaves as a tabular summary.
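A leaf-level summary table can be built with `jax.tree_util.tree_flatten_with_path` and `jax.tree_util.keystr`, which together give each leaf a readable path. The helpers `leaf_rows` and `format_rows` and their columns are hypothetical and need not match what `drinx.tree_summary` prints.

```python
import jax
import jax.numpy as jnp


def leaf_rows(tree):
    """Hypothetical helper: one (path, shape, dtype, size) row per pytree leaf."""
    leaves_with_paths, _ = jax.tree_util.tree_flatten_with_path(tree)
    rows = []
    for path, leaf in leaves_with_paths:
        arr = jnp.asarray(leaf)
        rows.append((jax.tree_util.keystr(path), arr.shape, str(arr.dtype), arr.size))
    return rows


def format_rows(rows):
    """Left-align the rows into columns and append the total element count."""
    header = ("path", "shape", "dtype", "size")
    table = [header] + [tuple(str(cell) for cell in row) for row in rows]
    widths = [max(len(r[i]) for r in table) for i in range(len(header))]
    lines = ["  ".join(cell.ljust(w) for cell, w in zip(r, widths)) for r in table]
    lines.append(f"total size: {sum(row[3] for row in rows)}")
    return "\n".join(lines)


params = {"w": jnp.ones((3, 2)), "b": jnp.zeros(3)}
rows = leaf_rows(params)  # dict leaves come out in sorted-key order: b, then w
print(format_rows(rows))
```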

drinx.is_traced(x)

Check whether an object is a JAX Tracer.
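A check like this can be written against JAX's public `jax.core.Tracer` class. The `is_tracer` function below is a hypothetical stand-in, not drinx's implementation. The list `seen` records the result during tracing, because the Python body of a jitted function only runs while it is being traced.

```python
import jax
import jax.numpy as jnp


def is_tracer(x) -> bool:
    """Hypothetical equivalent: True when x is an abstract JAX tracer."""
    return isinstance(x, jax.core.Tracer)


seen = []


@jax.jit
def f(x):
    # Python side effect: executes once, while jit traces the function.
    seen.append(is_tracer(x))
    return x * 2


f(jnp.arange(3.0))
print(is_tracer(jnp.arange(3.0)), seen)  # False [True]
```

A check like this is typically used to skip Python-side work, such as printing or validating concrete values, when a function is being traced.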