from __future__ import annotations
from drinx.jax_utils import is_traced
from typing import Any
import dataclasses
import numpy as np
import jax
import jax.tree_util
def _fmt(x: float) -> str:
"""Format a float to 4 significant figures."""
return f"{x:.2g}"
def _is_array(val: Any) -> bool:
"""Return True if val is a NumPy or JAX array (not a Python scalar)."""
return isinstance(val, (np.ndarray, jax.Array))
def _dtype_str(dtype: np.dtype) -> str:
"""Return compact dtype label: 'bool' for booleans, else '<kind><bits>' e.g. 'f32'."""
return "bool" if dtype.kind == "b" else f"{dtype.kind}{dtype.itemsize * 8}"
def visualize_leaf(val: int | float | complex | bool | np.ndarray | jax.Array) -> str:
"""Return a compact human-readable summary string for a JAX pytree leaf.
Produces a type-annotated string representation with statistics appropriate
for the value's kind:
- **Python scalars** (``bool``, ``int``, ``float``, ``complex``): ``repr(val)``
- **Tracers**: ``"<dtype>[<shape>] (Tracer)"``
- **Scalar arrays** (0-d): ``"<dtype>[] <value>"``
- **Empty arrays**: ``"<dtype>[<shape>] (empty)"``
- **Boolean arrays**: ``"bool[<shape>] #T=<n_true>, #F=<n_false>"``
- **Complex arrays**: ``"c<bits>[<shape>] |·| ∈ [min, max], μ=mean, σ=std"`` (stats on magnitude)
- **Numeric arrays**: ``"<dtype>[<shape>] ∈ [min, max], μ=mean, σ=std"``
The dtype string uses ``kind + bit-width`` notation (e.g. ``f32``, ``i64``, ``u8``, ``c128``),
except booleans which are shown as ``bool``.
Args:
val: A pytree leaf — either a Python scalar or a NumPy/JAX array.
Returns:
A compact summary string describing the value's type, shape, and statistics.
Raises:
AssertionError: If ``val`` is not a supported type.
"""
# 1. Handle Python built-in scalars
if isinstance(val, (bool, int, float, complex)):
return repr(val)
if not _is_array(val):
return repr(val)
dtype, shape = val.dtype, val.shape
# 2. Build compact dtype string (NumPy's dtype.kind already returns 'f', 'i', 'u', 'c', 'b')
dtype_str = _dtype_str(dtype)
prefix = f"{dtype_str}[{','.join(str(d) for d in shape)}]"
# 3. Handle Tracers
if is_traced(val):
return f"{prefix} (Tracer)"
arr = np.asarray(val)
# 4. Handle edge-case array shapes
if arr.ndim == 0:
return f"{prefix} {repr(arr.item())}"
if arr.size == 0:
return f"{prefix} (empty)"
# 5. Handle Boolean arrays
if dtype.kind == "b":
n_true = int(arr.sum())
return f"{prefix} #T={n_true}, #F={arr.size - n_true}"
target = np.abs(arr) if dtype.kind == "c" else arr
lo, hi = float(target.min()), float(target.max())
# Calculate mean and std directly as floats to prevent overflow on smaller dtypes
mu = float(target.mean(dtype=float))
sigma = float(target.std(dtype=float))
sym = "|·| " if dtype.kind == "c" else ""
return f"{prefix} {sym}∈ [{_fmt(lo)}, {_fmt(hi)}], μ={_fmt(mu)}, σ={_fmt(sigma)}"
def _format_key(key: Any) -> str:
"""Convert a JAX path key to a display label."""
if isinstance(key, jax.tree_util.GetAttrKey):
return f".{key.name}"
elif isinstance(key, jax.tree_util.SequenceKey):
return f"[{key.idx}]"
elif isinstance(key, jax.tree_util.DictKey):
return f"['{key.key}']" if isinstance(key.key, str) else f"[{key.key}]"
elif isinstance(key, jax.tree_util.FlattenedIndexKey):
return f"[{key.key}]"
else:
return str(key)
def _get_one_level(
node: Any, static_leaves: bool = False
) -> list[tuple[str, Any]] | None:
"""Get one level of children from a pytree node.
Returns None if node is a leaf.
"""
results, _ = jax.tree_util.tree_flatten_with_path(
node, is_leaf=lambda x: x is not node
)
# A leaf has a single entry with an empty path
if len(results) == 1 and len(results[0][0]) == 0:
return None
children = [(_format_key(path[0]), child) for path, child in results]
if static_leaves and dataclasses.is_dataclass(node) and not isinstance(node, type):
dynamic_dict = dict(children)
ordered = []
for f in dataclasses.fields(node):
key = f".{f.name}"
if f.metadata.get("jax_static"):
ordered.append((key, getattr(node, f.name)))
elif key in dynamic_dict:
ordered.append((key, dynamic_dict[key]))
children = ordered
return children
def _build_lines(
    node: Any,
    depth: int,
    max_depth: int | None,
    prefix: str,
    lines: list[str],
    static_leaves: bool = False,
) -> None:
    """Append one diagram line per descendant of ``node`` to ``lines``.

    Each child is drawn as ``<prefix><connector><label>…``:
    leaves as ``label=<visualize_leaf summary>``; containers as
    ``label:<TypeName>`` followed by their own children one level deeper; and
    containers sitting at depth ``>= max_depth`` as ``label:<TypeName> ...``
    without expansion.

    Args:
        node: Pytree node whose children are drawn. Nothing is appended if it
            is a leaf.
        depth: Depth of ``node`` (``tree_diagram`` passes 0 for the root).
        max_depth: Expansion limit; ``None`` means unlimited.
        prefix: Text prepended to every emitted line (the vertical guides of
            the ancestor levels).
        lines: Output list; mutated in place.
        static_leaves: Forwarded to ``_get_one_level``.
    """
    children = _get_one_level(node, static_leaves)
    if children is None:
        # ``node`` is a leaf: nothing below it to draw.
        return
    _append_children(children, depth, max_depth, prefix, lines, static_leaves)


def _append_children(
    children: list[tuple[str, Any]],
    depth: int,
    max_depth: int | None,
    prefix: str,
    lines: list[str],
    static_leaves: bool,
) -> None:
    """Draw the already-flattened ``children`` of a node at ``depth`` into ``lines``.

    Receiving the children pre-flattened lets the grandchild list computed to
    classify each child be reused for the recursion, so each node is flattened
    once rather than twice.
    """
    n_children = len(children)
    for i, (key_label, child) in enumerate(children):
        last = i == n_children - 1
        connector = "└── " if last else "├── "
        grandchildren = _get_one_level(child, static_leaves)
        if grandchildren is None:
            lines.append(f"{prefix}{connector}{key_label}={visualize_leaf(child)}")
        elif max_depth is not None and depth + 1 >= max_depth:
            # The child lies at or beyond max_depth: show it, but do not expand it.
            lines.append(f"{prefix}{connector}{key_label}:{type(child).__name__} ...")
        else:
            lines.append(f"{prefix}{connector}{key_label}:{type(child).__name__}")
            # Guide for the child's subtree. It must be exactly as wide as the
            # 4-column connector in BOTH cases; otherwise the subtree of the last
            # child and the subtrees of its siblings are indented differently.
            ext = "    " if last else "│   "
            _append_children(
                grandchildren, depth + 1, max_depth, prefix + ext, lines, static_leaves
            )
def _leaf_type_str(val: Any) -> str:
if isinstance(val, (bool, int, float, complex)):
return type(val).__name__
if _is_array(val):
dtype = val.dtype
return f"{_dtype_str(dtype)}[{','.join(map(str, val.shape))}]"
return type(val).__name__
def _leaf_count(val: Any) -> int:
    """Return how many elements a leaf contributes to the summary's Count column.

    Arrays contribute their element count (``size``); any other leaf counts as one.
    """
    if not _is_array(val):
        return 1
    return val.size
def _format_bytes(nbytes: int) -> str:
if nbytes < 1024:
return f"{nbytes:.2f}B"
elif nbytes < 1024**2:
return f"{nbytes / 1024:.2f}KB"
elif nbytes < 1024**3:
return f"{nbytes / 1024**2:.2f}MB"
else:
return f"{nbytes / 1024**3:.2f}GB"
def _leaf_size_str(val: Any) -> str:
    """Return the Size-column text for a leaf.

    Arrays get their formatted ``nbytes``; any other leaf gets an empty string.
    """
    return _format_bytes(val.nbytes) if _is_array(val) else ""
def _collect_summary_entries(
    node: Any, path_str: str, depth: int, max_depth: int | None
) -> list[tuple[str, Any]]:
    """Flatten ``node`` into ``(path, value)`` rows for ``tree_summary``.

    Recursion stops, and ``node`` itself becomes a single row, when ``node`` is
    a leaf or when ``depth`` has reached ``max_depth``. Otherwise the rows of all
    children are concatenated in child order, with each child's label
    (``.attr``, ``[i]``, ``['key']``) appended to ``path_str``.
    """
    if max_depth is not None and depth >= max_depth:
        return [(path_str, node)]
    children = _get_one_level(node)
    if children is None:
        return [(path_str, node)]
    return [
        entry
        for label, child in children
        for entry in _collect_summary_entries(
            child, path_str + label, depth + 1, max_depth
        )
    ]
[docs]
def tree_summary(tree: Any, max_depth: int | None = None) -> str:
    """Render a JAX pytree as a box-drawn table, one row per leaf or truncated subtree.

    The columns are Name (the key path, e.g. ``.params['w']``), Type, Count
    (number of elements) and Size (bytes, formatted). A final ``Σ`` row holds
    the totals.

    Args:
        tree: Any JAX pytree.
        max_depth: Maximum depth to expand; ``None`` expands down to the leaves.
            A subtree cut off at ``max_depth`` becomes a single row whose count
            and size are summed over all of its leaves.

    Returns:
        The table as a multi-line string.
    """
    header = ("Name", "Type", "Count", "Size")
    table: list[tuple[str, str, str, str]] = []
    grand_count = 0
    grand_bytes = 0

    for path, sub in _collect_summary_entries(tree, "", 0, max_depth):
        if _get_one_level(sub) is not None:
            # Subtree truncated by max_depth: aggregate over all of its leaves.
            leaves = jax.tree_util.tree_leaves(sub)
            n_elems = sum(_leaf_count(leaf) for leaf in leaves)
            n_bytes = sum(leaf.nbytes for leaf in leaves if _is_array(leaf))
            label = type(sub).__name__
            size_cell = _format_bytes(n_bytes) if n_bytes > 0 else ""
        else:
            n_elems = _leaf_count(sub)
            n_bytes = sub.nbytes if _is_array(sub) else 0
            label = _leaf_type_str(sub)
            size_cell = _leaf_size_str(sub)
        table.append((path, label, str(n_elems), size_cell))
        grand_count += n_elems
        grand_bytes += n_bytes

    grand_size = _format_bytes(grand_bytes) if grand_bytes > 0 else ""
    table.append(("Σ", "Tree", str(grand_count), grand_size))

    # Each column is as wide as its longest cell, header included.
    widths = [max(len(cell) for cell in column) for column in zip(header, *table)]

    def fmt_row(cells: tuple[str, ...]) -> str:
        padded = (cell.ljust(width) for cell, width in zip(cells, widths))
        return "│" + "│".join(padded) + "│"

    def rule(left: str, joint: str, right: str) -> str:
        return left + joint.join("─" * width for width in widths) + right

    out = [rule("┌", "┬", "┐"), fmt_row(header)]
    for row in table:
        out += [rule("├", "┼", "┤"), fmt_row(row)]
    out.append(rule("└", "┴", "┘"))
    return "\n".join(out)
[docs]
def tree_diagram(
    tree: Any, max_depth: int | None = None, static_leaves: bool = False
) -> str:
    """Draw a JAX pytree as a text diagram with one node per line.

    The first line is the root's class name. Every descendant follows on its
    own line, attached with ``├──`` / ``└──`` branches.

    Args:
        tree: Any JAX pytree.
        max_depth: Maximum depth to expand; ``None`` means no limit.
        static_leaves: When ``True``, static fields of drinx dataclasses are
            listed too, in declaration order alongside the dynamic fields.

    Returns:
        The diagram as a multi-line string.
    """
    diagram = [type(tree).__name__]
    _build_lines(
        tree,
        depth=0,
        max_depth=max_depth,
        prefix="",
        lines=diagram,
        static_leaves=static_leaves,
    )
    return "\n".join(diagram)