Visualization#

Drinx ships three functions for inspecting pytrees at a glance:

  • visualize_leaf(val) — compact one-line summary of a single leaf (dtype, shape, statistics)

  • tree_diagram(tree) — ASCII tree showing the full pytree structure with leaf values

  • tree_summary(tree) — tabular overview of all leaves with element counts and byte sizes

import jax
import jax.numpy as jnp
import numpy as np

import drinx
from drinx import visualize_leaf, tree_diagram, tree_summary

1. visualize_leaf#

visualize_leaf produces a compact human-readable string for any pytree leaf. The output format adapts to the type and dtype of the value.

Python scalars#

Plain Python scalars are shown via repr.

print(visualize_leaf(42))         # int
print(visualize_leaf(3.14))       # float
print(visualize_leaf(True))       # bool
print(visualize_leaf(1+2j))       # complex
42
3.14
True
(1+2j)

0-d (scalar) arrays#

Zero-dimensional arrays show their dtype and the scalar value.

print(visualize_leaf(jnp.array(2.5)))           # f32[] 2.5
print(visualize_leaf(jnp.array(7, dtype=jnp.int32)))  # i32[] 7
f32[] 2.5
i32[] 7

Numeric arrays#

For non-empty numeric arrays, the output shows the dtype, shape, value range [min, max], mean μ, and standard deviation σ. In the examples below, the statistics are printed to two significant digits.

key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (4, 8))          # float32 matrix
print(visualize_leaf(x))

i = jnp.arange(10, dtype=jnp.int16)         # integer 1-D array
print(visualize_leaf(i))

u = jnp.array([0, 1, 255], dtype=jnp.uint8) # unsigned 8-bit
print(visualize_leaf(u))
f32[4,8] ∈ [-2, 2.2], μ=0.17, σ=1.1
i16[10] ∈ [0, 9], μ=4.5, σ=2.9
u8[3] ∈ [0, 2.6e+02], μ=85, σ=1.2e+02
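
The statistics above look like they are rounded to two significant digits, which would explain why the uint8 maximum of 255 prints as 2.6e+02. As a rough illustration only, the plain-Python sketch below reproduces the statistics part of the integer lines with the standard `.2g` format specifier. The helper name `summarize` is hypothetical, and the idea that drinx formats values this way internally is an assumption inferred from the printed output.

```python
import statistics

def summarize(values):
    """Format min, max, mean, and population std with two significant digits."""
    lo, hi = min(values), max(values)
    mean = statistics.fmean(values)
    std = statistics.pstdev(values)  # population std (divides by N, not N-1)
    return f"[{lo:.2g}, {hi:.2g}], μ={mean:.2g}, σ={std:.2g}"

print(summarize(list(range(10))))  # [0, 9], μ=4.5, σ=2.9
print(summarize([0, 1, 255]))      # [0, 2.6e+02], μ=85, σ=1.2e+02
```

The `g` presentation type switches to scientific notation once the value has more digits than the requested precision, so any three-digit maximum would be shown in exponent form.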

Boolean arrays#

Boolean arrays report the count of True and False elements instead of statistics.

mask = jnp.array([True, False, True, True, False])
print(visualize_leaf(mask))  # bool[5] #T=3, #F=2
bool[5] #T=3, #F=2

Complex arrays#

Statistics for complex arrays are computed on the magnitude |·|.

c = jnp.array([1+2j, 3+4j, 0+1j], dtype=jnp.complex64)
print(visualize_leaf(c))  # c64[3] |·| ∈ [...]
c64[3] |·| ∈ [1, 5], μ=2.7, σ=1.7
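
To make the magnitude rule concrete, the following plain-Python sketch recomputes the numbers in the line above from the absolute values of the three entries. The function name `magnitude_summary` and the `.2g` formatting are assumptions chosen to match the printed output, not a description of drinx internals.

```python
import statistics

def magnitude_summary(values):
    """Summarize complex values through their magnitudes |z|."""
    magnitudes = [abs(z) for z in values]  # |z| = sqrt(re^2 + im^2)
    lo, hi = min(magnitudes), max(magnitudes)
    mean = statistics.fmean(magnitudes)
    std = statistics.pstdev(magnitudes)    # population standard deviation
    return f"|·| ∈ [{lo:.2g}, {hi:.2g}], μ={mean:.2g}, σ={std:.2g}"

print(magnitude_summary([1 + 2j, 3 + 4j, 0 + 1j]))
# |·| ∈ [1, 5], μ=2.7, σ=1.7
```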

Empty arrays#

Arrays with no elements are flagged with (empty).

empty = jnp.zeros((0, 4))
print(visualize_leaf(empty))  # f32[0,4] (empty)
f32[0,4] (empty)

2. tree_diagram#

tree_diagram renders the full pytree structure as an ASCII tree. Intermediate nodes show their class name; leaves show the visualize_leaf summary.

Simple drinx dataclass#

class Params(drinx.DataClass):
    weights: jax.Array
    bias: jax.Array

params = Params(
    weights=jax.random.normal(key, (4, 8)),
    bias=jnp.zeros((8,)),
)

print(tree_diagram(params))
Params
├── .weights=f32[4,8] ∈ [-2, 2.2], μ=0.17, σ=1.1
└── .bias=f32[8] ∈ [0, 0], μ=0, σ=0

Nested dataclasses#

Nesting is reflected in the tree indentation.

class Encoder(drinx.DataClass):
    w: jax.Array
    b: jax.Array

class Decoder(drinx.DataClass):
    w: jax.Array
    b: jax.Array

class Autoencoder(drinx.DataClass):
    encoder: Encoder
    decoder: Decoder

model = Autoencoder(
    encoder=Encoder(w=jax.random.normal(key, (16, 32)), b=jnp.zeros((16,))),
    decoder=Decoder(w=jax.random.normal(key, (32, 16)), b=jnp.zeros((32,))),
)

print(tree_diagram(model))
Autoencoder
├── .encoder:Encoder
│   ├── .w=f32[16,32] ∈ [-3.3, 2.6], μ=0.012, σ=0.98
│   └── .b=f32[16] ∈ [0, 0], μ=0, σ=0
└── .decoder:Decoder
    ├── .w=f32[32,16] ∈ [-3.3, 2.6], μ=0.012, σ=0.98
    └── .b=f32[32] ∈ [0, 0], μ=0, σ=0

Limiting depth with max_depth#

Pass max_depth to stop expanding the tree beyond a given level. Truncated subtrees are marked with a trailing ellipsis (...).

print(tree_diagram(model, max_depth=1))  # stops at Encoder/Decoder level
Autoencoder
├── .encoder:Encoder ...
└── .decoder:Decoder ...

Showing static fields with static_leaves=True#

By default, static fields are omitted from the diagram because JAX does not trace them. Set static_leaves=True to include them in their declaration order.

class MLP(drinx.DataClass):
    weights: jax.Array
    bias: jax.Array
    activation: str = drinx.static_field(default="relu")
    hidden_size: int = drinx.static_field(default=64)

mlp = MLP(weights=jax.random.normal(key, (64, 32)), bias=jnp.zeros((64,)))

print("--- dynamic only (default) ---")
print(tree_diagram(mlp))

print()
print("--- including static fields ---")
print(tree_diagram(mlp, static_leaves=True))
--- dynamic only (default) ---
MLP
├── .weights=f32[64,32] ∈ [-3.9, 3.6], μ=-0.019, σ=1
└── .bias=f32[64] ∈ [0, 0], μ=0, σ=0

--- including static fields ---
MLP
├── .weights=f32[64,32] ∈ [-3.9, 3.6], μ=-0.019, σ=1
├── .bias=f32[64] ∈ [0, 0], μ=0, σ=0
├── .activation='relu'
└── .hidden_size=64
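
The filtering behaviour can be pictured with a small stand-in built on the standard library. This sketch does not use drinx at all: the class `MLPSketch`, the helper `shown_fields`, and the `"static"` metadata key are all hypothetical, and drinx's own `static_field` mechanism may work differently. It only shows the idea of hiding marked fields by default while keeping declaration order when they are included.

```python
from dataclasses import dataclass, field, fields

@dataclass
class MLPSketch:
    weights: list
    bias: list
    # The "static" metadata key is an illustrative convention, not drinx's API.
    activation: str = field(default="relu", metadata={"static": True})
    hidden_size: int = field(default=64, metadata={"static": True})

def shown_fields(obj, static_leaves=False):
    """List field names in declaration order, skipping static ones by default."""
    return [
        f.name
        for f in fields(obj)
        if static_leaves or not f.metadata.get("static", False)
    ]

mlp_sketch = MLPSketch(weights=[0.0], bias=[0.0])
print(shown_fields(mlp_sketch))
# ['weights', 'bias']
print(shown_fields(mlp_sketch, static_leaves=True))
# ['weights', 'bias', 'activation', 'hidden_size']
```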

Standard JAX pytrees#

tree_diagram works on any JAX-compatible pytree, not just drinx dataclasses.

nested = {
    "layer1": {"w": jnp.ones((4, 4)), "b": jnp.zeros((4,))},
    "layer2": {"w": jnp.ones((2, 4)), "b": jnp.zeros((2,))},
}

print(tree_diagram(nested))
dict
├── ['layer1']:dict
│   ├── ['b']=f32[4] ∈ [0, 0], μ=0, σ=0
│   └── ['w']=f32[4,4] ∈ [1, 1], μ=1, σ=0
└── ['layer2']:dict
    ├── ['b']=f32[2] ∈ [0, 0], μ=0, σ=0
    └── ['w']=f32[2,4] ∈ [1, 1], μ=1, σ=0

Lists and tuples are labelled by index:

mixed = [jnp.array([1.0, 2.0]), (jnp.eye(3), jnp.zeros((3,)))]
print(tree_diagram(mixed))
list
├── [0]=f32[2] ∈ [1, 2], μ=1.5, σ=0.5
└── [1]:tuple
    ├── [0]=f32[3,3] ∈ [0, 1], μ=0.33, σ=0.47
    └── [1]=f32[3] ∈ [0, 0], μ=0, σ=0
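
The branch characters and labels in these diagrams follow a simple recursive pattern. The sketch below is not drinx's implementation; it is a minimal pure-Python renderer for dicts, lists, and tuples only, with a hypothetical `diagram` function whose `max_depth` argument mimics the truncation shown earlier. Dict keys are visited in sorted order, which matches how JAX flattens dicts and is why 'b' appears before 'w' above. Leaves are printed with `repr` instead of a `visualize_leaf` summary.

```python
def children_of(node):
    """Return (label, child) pairs for a container, or None for a leaf."""
    if isinstance(node, dict):
        return [(f"[{key!r}]", node[key]) for key in sorted(node)]
    if isinstance(node, (list, tuple)):
        return [(f"[{index}]", child) for index, child in enumerate(node)]
    return None

def diagram(tree, max_depth=None):
    """Render nested dicts/lists/tuples as a box-drawing tree."""
    if children_of(tree) is None:
        return repr(tree)
    lines = [type(tree).__name__]

    def walk(node, prefix, depth):
        kids = children_of(node)
        for position, (label, child) in enumerate(kids):
            last = position == len(kids) - 1
            branch = "└── " if last else "├── "
            if children_of(child) is None:
                lines.append(f"{prefix}{branch}{label}={child!r}")
                continue
            # Containers sitting at the depth limit are named but not expanded.
            truncated = max_depth is not None and depth + 1 >= max_depth
            suffix = " ..." if truncated else ""
            lines.append(f"{prefix}{branch}{label}:{type(child).__name__}{suffix}")
            if not truncated:
                walk(child, prefix + ("    " if last else "│   "), depth + 1)

    walk(tree, "", 0)
    return "\n".join(lines)

example = {"layer2": [2, (3, 4)], "layer1": {"w": 1, "b": 0}}
print(diagram(example))
print(diagram(example, max_depth=1))
```

The continuation prefix is the only state carried down the recursion: a vertical bar is kept while the parent still has siblings to draw, and plain spaces are used once the parent was the last child.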

3. tree_summary#

tree_summary renders a tabular overview of all pytree leaves. Each row shows the leaf’s path, type, element count, and byte size. A totals row (Σ) is appended at the bottom.

Basic summary#

print(tree_summary(model))
┌──────────┬──────────┬─────┬───────┐
│Name      │Type      │Count│Size   │
├──────────┼──────────┼─────┼───────┤
│.encoder.w│f32[16,32]│512  │2.00KB │
├──────────┼──────────┼─────┼───────┤
│.encoder.b│f32[16]   │16   │64.00B │
├──────────┼──────────┼─────┼───────┤
│.decoder.w│f32[32,16]│512  │2.00KB │
├──────────┼──────────┼─────┼───────┤
│.decoder.b│f32[32]   │32   │128.00B│
├──────────┼──────────┼─────┼───────┤
│Σ         │Tree      │1072 │4.19KB │
└──────────┴──────────┴─────┴───────┘
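
The Count and Size columns can be checked by hand: the count is the product of the shape, and the size is the count times the bytes per element (4 for float32). The sketch below redoes that arithmetic in plain Python. The helper `format_size`, the two-decimal format, and the 1024-based units are assumptions inferred from the table, not drinx's documented behaviour.

```python
import math

def format_size(num_bytes):
    """Render a byte count with two decimals, stepping units every 1024."""
    size = float(num_bytes)
    for unit in ("B", "KB", "MB", "GB"):
        if size < 1024 or unit == "GB":
            return f"{size:.2f}{unit}"
        size /= 1024

# Shapes copied from the table above; every leaf is float32 (4 bytes per element).
shapes = {
    ".encoder.w": (16, 32),
    ".encoder.b": (16,),
    ".decoder.w": (32, 16),
    ".decoder.b": (32,),
}
ITEMSIZE = 4

counts = {name: math.prod(shape) for name, shape in shapes.items()}
for name, count in counts.items():
    print(name, count, format_size(count * ITEMSIZE))

total_count = sum(counts.values())
print("Σ", total_count, format_size(total_count * ITEMSIZE))
```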

Limiting depth#

With max_depth, subtrees that exceed the depth limit are collapsed into a single aggregated row showing combined element count and size.

print(tree_summary(model, max_depth=1))
┌────────┬───────┬─────┬──────┐
│Name    │Type   │Count│Size  │
├────────┼───────┼─────┼──────┤
│.encoder│Encoder│528  │2.06KB│
├────────┼───────┼─────┼──────┤
│.decoder│Decoder│544  │2.12KB│
├────────┼───────┼─────┼──────┤
│Σ       │Tree   │1072 │4.19KB│
└────────┴───────┴─────┴──────┘
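
One way to picture the collapsing step is to group leaves by the first max_depth components of their path and sum the counts. The sketch below does this in plain Python with a hypothetical `collapse` helper and leaf counts copied from the full table; how drinx actually aggregates is not shown in this document.

```python
from collections import defaultdict

# Leaf paths (split into components) and element counts from the full summary.
leaf_counts = {
    (".encoder", ".w"): 512,
    (".encoder", ".b"): 16,
    (".decoder", ".w"): 512,
    (".decoder", ".b"): 32,
}

def collapse(leaf_counts, max_depth):
    """Sum element counts over all leaves sharing the same path prefix."""
    totals = defaultdict(int)
    for path, count in leaf_counts.items():
        totals["".join(path[:max_depth])] += count
    return dict(totals)

print(collapse(leaf_counts, max_depth=1))
# {'.encoder': 528, '.decoder': 544}
```

At 4 bytes per float32 element, 528 and 544 elements correspond to 2112 and 2176 bytes, which is where the 2.06KB and 2.12KB entries come from.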

Larger model#

The summary is especially useful for larger models where you want a quick parameter count.

class TransformerBlock(drinx.DataClass):
    q_proj: jax.Array  # query projection
    k_proj: jax.Array  # key projection
    v_proj: jax.Array  # value projection
    out_proj: jax.Array
    ff1: jax.Array     # feed-forward layer 1
    ff2: jax.Array     # feed-forward layer 2
    d_model: int = drinx.static_field(default=128)

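d = 128
# Split the key so that each matrix is drawn from an independent random stream.
k_q, k_k, k_v, k_out, k_ff1, k_ff2 = jax.random.split(key, 6)
block = TransformerBlock(
    q_proj=jax.random.normal(k_q, (d, d)),
    k_proj=jax.random.normal(k_k, (d, d)),
    v_proj=jax.random.normal(k_v, (d, d)),
    out_proj=jax.random.normal(k_out, (d, d)),
    ff1=jax.random.normal(k_ff1, (d, 4 * d)),
    ff2=jax.random.normal(k_ff2, (4 * d, d)),
)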

print(tree_summary(block))
┌─────────┬────────────┬──────┬────────┐
│Name     │Type        │Count │Size    │
├─────────┼────────────┼──────┼────────┤
│.q_proj  │f32[128,128]│16384 │64.00KB │
├─────────┼────────────┼──────┼────────┤
│.k_proj  │f32[128,128]│16384 │64.00KB │
├─────────┼────────────┼──────┼────────┤
│.v_proj  │f32[128,128]│16384 │64.00KB │
├─────────┼────────────┼──────┼────────┤
│.out_proj│f32[128,128]│16384 │64.00KB │
├─────────┼────────────┼──────┼────────┤
│.ff1     │f32[128,512]│65536 │256.00KB│
├─────────┼────────────┼──────┼────────┤
│.ff2     │f32[512,128]│65536 │256.00KB│
├─────────┼────────────┼──────┼────────┤
│Σ        │Tree        │196608│768.00KB│
└─────────┴────────────┴──────┴────────┘
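
As a sanity check on the totals row, the block has four d×d projections and two d×4d feed-forward matrices, so it should hold 12·d² parameters. The short plain-Python calculation below confirms this for d = 128, assuming 4 bytes per float32 element; the static field d_model does not appear in the table and is not counted.

```python
d = 128

# (rows, columns) of each weight matrix in the block above.
shapes = {
    "q_proj": (d, d),
    "k_proj": (d, d),
    "v_proj": (d, d),
    "out_proj": (d, d),
    "ff1": (d, 4 * d),
    "ff2": (4 * d, d),
}

total_params = sum(rows * cols for rows, cols in shapes.values())
total_kib = total_params * 4 / 1024  # float32 = 4 bytes, 1 KB = 1024 bytes here

print(total_params)  # 196608
print(total_kib)     # 768.0
```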