drinx.tree_summary
- drinx.tree_summary(tree, max_depth=None)
Render a JAX pytree’s leaves as a tabular summary.
Columns: Name (the leaf's path), Type, Count (number of elements), and Size (size in bytes).
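As a rough illustration of where these four columns could come from, the sketch below flattens a pytree with `jax.tree_util.tree_flatten_with_path` and reads each leaf's `size` and `nbytes`. The helper name `leaf_rows`, the use of `jax.tree_util.keystr` for the Name column, and the Python type name for the Type column are assumptions made for illustration; drinx may compute or format these values differently.

```python
import jax
import jax.numpy as jnp


def leaf_rows(tree):
    """Hypothetical helper (not part of drinx): one row per pytree leaf.

    Each row is (name, type, count, size), mirroring the documented columns.
    """
    path_leaves, _ = jax.tree_util.tree_flatten_with_path(tree)
    rows = []
    for path, leaf in path_leaves:
        # Assumption: the Name column is the key path rendered by keystr,
        # e.g. "['params']['w']" for a nested dict entry.
        name = jax.tree_util.keystr(path)
        # Promote Python scalars to arrays so size/nbytes are always available.
        arr = jnp.asarray(leaf)
        # Assumption: the Type column is the leaf's Python type name.
        rows.append((name, type(leaf).__name__, int(arr.size), int(arr.nbytes)))
    return rows
```

For a float32 leaf of shape `(2, 3)`, this yields a Count of 6 elements and a Size of 24 bytes (4 bytes per element).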
- Parameters:
  - tree (Any) – Any JAX pytree.
  - max_depth (int | None) – Maximum depth to expand. None means unlimited (all leaves are shown). When a subtree is truncated at max_depth, it appears as a single row with aggregated count and size.
- Return type:
  str
- Returns:
  A multi-line string containing the summary table.
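The `max_depth` truncation and the string return value can be approximated as follows. `summary_sketch` is a hypothetical helper, not part of drinx; it assumes that depth counts key-path entries from the root (so `max_depth=1` yields one row per top-level key), merges all leaves that share a truncated prefix by summing their counts and byte sizes, and omits the Type column for brevity.

```python
import jax
import jax.numpy as jnp


def summary_sketch(tree, max_depth=None):
    """Hypothetical re-creation of the documented behaviour, not drinx's code.

    Leaves deeper than max_depth are merged into one row per truncated
    path prefix, with element counts and byte sizes summed.
    """
    path_leaves, _ = jax.tree_util.tree_flatten_with_path(tree)
    groups = {}  # name -> [count, size]; dicts keep insertion order
    for path, leaf in path_leaves:
        # Keep at most max_depth path entries; None keeps the full path.
        prefix = path if max_depth is None else path[:max_depth]
        name = jax.tree_util.keystr(prefix) or "<root>"
        arr = jnp.asarray(leaf)
        entry = groups.setdefault(name, [0, 0])
        entry[0] += int(arr.size)
        entry[1] += int(arr.nbytes)
    # Render as fixed-width text: header, separator, then one line per row.
    rows = [("Name", "Count", "Size")]
    rows += [(name, str(c), str(s)) for name, (c, s) in groups.items()]
    widths = [max(len(r[i]) for r in rows) for i in range(3)]
    lines = ["  ".join(cell.ljust(w) for cell, w in zip(r, widths)).rstrip() for r in rows]
    lines.insert(1, "  ".join("-" * w for w in widths))
    return "\n".join(lines)


tree = {"enc": {"b": jnp.zeros(2), "w": jnp.ones((2, 2))}, "head": jnp.ones(4)}
print(summary_sketch(tree, max_depth=1))  # the "enc" subtree collapses to one row
```

With drinx itself, the corresponding call would be `drinx.tree_summary(tree, max_depth=1)`; its exact column layout and formatting may differ from this sketch.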