Visualization
Drinx ships three functions for inspecting pytrees at a glance:
- visualize_leaf(val) — compact one-line summary of a single leaf (dtype, shape, statistics)
- tree_diagram(tree) — ASCII tree showing the full pytree structure with leaf values
- tree_summary(tree) — tabular overview of all leaves with element counts and byte sizes
import jax
import jax.numpy as jnp
import numpy as np
import drinx
from drinx import visualize_leaf, tree_diagram, tree_summary
1. visualize_leaf
visualize_leaf produces a compact, human-readable string for any pytree leaf.
The output format adapts to the leaf's type and dtype.
Python scalars
Plain Python scalars (int, float, bool, complex) are shown via their repr().
# Plain Python scalars are rendered with repr().
print(visualize_leaf(42)) # int
print(visualize_leaf(3.14)) # float
print(visualize_leaf(True)) # bool
print(visualize_leaf(1+2j)) # complex
42
3.14
True
(1+2j)
0-d (scalar) arrays
Zero-dimensional arrays show their dtype tag (e.g. f32, i32), an empty shape [], and the scalar value.
# 0-d arrays: dtype tag, empty shape [], then the scalar value.
print(visualize_leaf(jnp.array(2.5))) # f32[] 2.5
print(visualize_leaf(jnp.array(7, dtype=jnp.int32))) # i32[] 7
An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
f32[] 2.5
i32[] 7
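Numeric arrays
For non-empty numeric arrays the output shows dtype, shape, value range [min, max], mean μ, and standard deviation σ.
Statistics are rounded to two significant digits, so, for example, a maximum of 255 is displayed as 2.6e+02.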
key = jax.random.PRNGKey(0) # fixed seed; this same key is reused by later examples
x = jax.random.normal(key, (4, 8)) # float32 matrix
print(visualize_leaf(x)) # dtype[shape] ∈ [min, max], μ=mean, σ=std
i = jnp.arange(10, dtype=jnp.int16) # integer 1-D array
print(visualize_leaf(i))
u = jnp.array([0, 1, 255], dtype=jnp.uint8) # unsigned 8-bit
print(visualize_leaf(u)) # values shown to 2 significant digits (255 -> 2.6e+02)
f32[4,8] ∈ [-2, 2.2], μ=0.17, σ=1.1
i16[10] ∈ [0, 9], μ=4.5, σ=2.9
u8[3] ∈ [0, 2.6e+02], μ=85, σ=1.2e+02
Boolean arrays
Boolean arrays report the number of True (#T) and False (#F) elements instead of statistics.
# Boolean arrays: True/False counts replace the range/mean/std statistics.
mask = jnp.array([True, False, True, True, False])
print(visualize_leaf(mask)) # bool[5] #T=3, #F=2
bool[5] #T=3, #F=2
Complex arrays
Statistics for complex arrays are computed on the element-wise magnitude, which the output marks with the |·| prefix.
# Magnitudes are |1+2j| = √5, |3+4j| = 5, |0+1j| = 1; the statistics use these.
c = jnp.array([1+2j, 3+4j, 0+1j], dtype=jnp.complex64)
print(visualize_leaf(c)) # c64[3] |·| ∈ [1, 5], μ=2.7, σ=1.7
c64[3] |·| ∈ [1, 5], μ=2.7, σ=1.7
Empty arrays
Arrays with no elements show their dtype and shape followed by (empty) instead of statistics.
empty = jnp.zeros((0, 4)) # shape with a zero dimension -> no elements
print(visualize_leaf(empty)) # f32[0,4] (empty)
f32[0,4] (empty)
2. tree_diagram
tree_diagram renders the full pytree structure as an ASCII tree.
Intermediate nodes show their class name; leaves show the visualize_leaf summary.
Simple drinx dataclass
# A drinx dataclass whose two array fields appear as leaves in the diagram.
class Params(drinx.DataClass):
    weights: jax.Array
    bias: jax.Array


params = Params(
    weights=jax.random.normal(key, (4, 8)), # same key and shape as `x` above, hence identical stats
    bias=jnp.zeros((8,)),
)
print(tree_diagram(params))
Params
├── .weights=f32[4,8] ∈ [-2, 2.2], μ=0.17, σ=1.1
└── .bias=f32[8] ∈ [0, 0], μ=0, σ=0
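Nested dataclasses
Nesting is reflected in the tree indentation: each intermediate node is labelled with its field name and class (for example .encoder:Encoder), and its leaves are indented beneath it.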
# Sub-modules are themselves drinx dataclasses, so they render as nested subtrees.
class Encoder(drinx.DataClass):
    w: jax.Array
    b: jax.Array


class Decoder(drinx.DataClass):
    w: jax.Array
    b: jax.Array


class Autoencoder(drinx.DataClass):
    encoder: Encoder
    decoder: Decoder


# NOTE: both weight matrices are drawn from the same `key` with the same number
# of elements, which is why encoder.w and decoder.w show identical statistics
# below. In real code, split the key (jax.random.split) so layers differ.
model = Autoencoder(
    encoder=Encoder(w=jax.random.normal(key, (16, 32)), b=jnp.zeros((16,))),
    decoder=Decoder(w=jax.random.normal(key, (32, 16)), b=jnp.zeros((32,))),
)
print(tree_diagram(model))
Autoencoder
├── .encoder:Encoder
│ ├── .w=f32[16,32] ∈ [-3.3, 2.6], μ=0.012, σ=0.98
│ └── .b=f32[16] ∈ [0, 0], μ=0, σ=0
└── .decoder:Decoder
├── .w=f32[32,16] ∈ [-3.3, 2.6], μ=0.012, σ=0.98
└── .b=f32[32] ∈ [0, 0], μ=0, σ=0
Limiting depth with max_depth
Pass max_depth to stop expanding the tree beyond a certain level.
Truncated subtrees are marked with a trailing ellipsis (...).
print(tree_diagram(model, max_depth=1)) # stops at Encoder/Decoder level; their leaves are elided as "..."
Autoencoder
├── .encoder:Encoder ...
└── .decoder:Decoder ...
Showing static fields with static_leaves=True
By default, static fields are omitted from the diagram because they are not pytree leaves and JAX does not trace them.
Set static_leaves=True to include them; they are listed in field-declaration order.
# `activation` and `hidden_size` are declared with drinx.static_field, so
# tree_diagram hides them unless static_leaves=True is passed.
class MLP(drinx.DataClass):
    weights: jax.Array
    bias: jax.Array
    activation: str = drinx.static_field(default="relu")
    hidden_size: int = drinx.static_field(default=64)


mlp = MLP(weights=jax.random.normal(key, (64, 32)), bias=jnp.zeros((64,)))
print("--- dynamic only (default) ---")
print(tree_diagram(mlp))
print()
print("--- including static fields ---")
print(tree_diagram(mlp, static_leaves=True))
--- dynamic only (default) ---
MLP
├── .weights=f32[64,32] ∈ [-3.9, 3.6], μ=-0.019, σ=1
└── .bias=f32[64] ∈ [0, 0], μ=0, σ=0
--- including static fields ---
MLP
├── .weights=f32[64,32] ∈ [-3.9, 3.6], μ=-0.019, σ=1
├── .bias=f32[64] ∈ [0, 0], μ=0, σ=0
├── .activation='relu'
└── .hidden_size=64
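Standard JAX pytrees
tree_diagram works on any JAX-compatible pytree, not just drinx dataclasses.
Note that dictionary entries appear in sorted key order (['b'] before ['w']), because JAX flattens dicts by sorted key.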
# A plain nested dict of arrays; no drinx types involved.
nested = {
    "layer1": {"w": jnp.ones((4, 4)), "b": jnp.zeros((4,))},
    "layer2": {"w": jnp.ones((2, 4)), "b": jnp.zeros((2,))},
}
print(tree_diagram(nested)) # keys are listed in sorted order ('b' before 'w')
dict
├── ['layer1']:dict
│ ├── ['b']=f32[4] ∈ [0, 0], μ=0, σ=0
│ └── ['w']=f32[4,4] ∈ [1, 1], μ=1, σ=0
└── ['layer2']:dict
├── ['b']=f32[2] ∈ [0, 0], μ=0, σ=0
└── ['w']=f32[2,4] ∈ [1, 1], μ=1, σ=0
# Lists and tuples are pytree nodes too; their children are labelled by index.
mixed = [jnp.array([1.0, 2.0]), (jnp.eye(3), jnp.zeros((3,)))]
print(tree_diagram(mixed))
list
├── [0]=f32[2] ∈ [1, 2], μ=1.5, σ=0.5
└── [1]:tuple
├── [0]=f32[3,3] ∈ [0, 1], μ=0.33, σ=0.47
└── [1]=f32[3] ∈ [0, 0], μ=0, σ=0
3. tree_summary
tree_summary renders a tabular overview of all pytree leaves.
Each row shows the leaf’s path, type (dtype and shape), element count, and byte size.
A totals row (Σ) is appended at the bottom.
Basic summary
print(tree_summary(model)) # one row per leaf, plus a Σ totals row
┌──────────┬──────────┬─────┬───────┐
│Name │Type │Count│Size │
├──────────┼──────────┼─────┼───────┤
│.encoder.w│f32[16,32]│512 │2.00KB │
├──────────┼──────────┼─────┼───────┤
│.encoder.b│f32[16] │16 │64.00B │
├──────────┼──────────┼─────┼───────┤
│.decoder.w│f32[32,16]│512 │2.00KB │
├──────────┼──────────┼─────┼───────┤
│.decoder.b│f32[32] │32 │128.00B│
├──────────┼──────────┼─────┼───────┤
│Σ │Tree │1072 │4.19KB │
└──────────┴──────────┴─────┴───────┘
Limiting depth
With max_depth, each subtree below the depth limit is collapsed into a single row that shows the subtree's class as its type and aggregates the element count and byte size of all its leaves.
print(tree_summary(model, max_depth=1)) # one aggregated row per sub-module (Encoder, Decoder)
┌────────┬───────┬─────┬──────┐
│Name │Type │Count│Size │
├────────┼───────┼─────┼──────┤
│.encoder│Encoder│528 │2.06KB│
├────────┼───────┼─────┼──────┤
│.decoder│Decoder│544 │2.12KB│
├────────┼───────┼─────┼──────┤
│Σ │Tree │1072 │4.19KB│
└────────┴───────┴─────┴──────┘
Larger model
The summary is especially useful for larger models where you want a quick parameter count.
Static fields (such as d_model below) are not listed and do not contribute to the totals.
class TransformerBlock(drinx.DataClass):
    q_proj: jax.Array  # query projection
    k_proj: jax.Array  # key projection
    v_proj: jax.Array  # value projection
    out_proj: jax.Array  # output projection
    ff1: jax.Array  # feed-forward layer 1
    ff2: jax.Array  # feed-forward layer 2
    d_model: int = drinx.static_field(default=128)  # static: not listed in the summary


d = 128
# Give every weight matrix its own subkey. Reusing a single key for arrays of
# the same shape (as the four (d, d) projections are) would make them identical.
kq, kk, kv, ko, kf1, kf2 = jax.random.split(key, 6)
block = TransformerBlock(
    q_proj=jax.random.normal(kq, (d, d)),
    k_proj=jax.random.normal(kk, (d, d)),
    v_proj=jax.random.normal(kv, (d, d)),
    out_proj=jax.random.normal(ko, (d, d)),
    ff1=jax.random.normal(kf1, (d, 4 * d)),
    ff2=jax.random.normal(kf2, (4 * d, d)),
    d_model=d,  # keep the static field in sync with the actual width
)
# Counts and byte sizes depend only on shapes/dtypes, so the table is unaffected
# by which keys were used.
print(tree_summary(block))
┌─────────┬────────────┬──────┬────────┐
│Name │Type │Count │Size │
├─────────┼────────────┼──────┼────────┤
│.q_proj │f32[128,128]│16384 │64.00KB │
├─────────┼────────────┼──────┼────────┤
│.k_proj │f32[128,128]│16384 │64.00KB │
├─────────┼────────────┼──────┼────────┤
│.v_proj │f32[128,128]│16384 │64.00KB │
├─────────┼────────────┼──────┼────────┤
│.out_proj│f32[128,128]│16384 │64.00KB │
├─────────┼────────────┼──────┼────────┤
│.ff1 │f32[128,512]│65536 │256.00KB│
├─────────┼────────────┼──────┼────────┤
│.ff2 │f32[512,128]│65536 │256.00KB│
├─────────┼────────────┼──────┼────────┤
│Σ │Tree │196608│768.00KB│
└─────────┴────────────┴──────┴────────┘