drinx.tree_summary

drinx.tree_summary(tree, max_depth=None)

Render a JAX pytree’s leaves as a tabular summary.

Columns: Name (the leaf's key path), Type, Count (number of elements), Size (size in bytes).
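To make the four columns concrete, here is a minimal sketch of how one row per leaf could be derived with `jax.tree_util.tree_flatten_with_path` and `jax.tree_util.keystr`. The helper `leaf_rows` is hypothetical and is not drinx's implementation. The path formatting and the handling of non-array leaves are assumptions.

```python
import jax
import jax.numpy as jnp


def leaf_rows(tree):
    """Hypothetical helper (not part of drinx): one (name, type, count, size) row per leaf."""
    rows = []
    # tree_flatten_with_path pairs each leaf with its key path from the root.
    leaves_with_paths, _ = jax.tree_util.tree_flatten_with_path(tree)
    for path, leaf in leaves_with_paths:
        name = jax.tree_util.keystr(path)  # e.g. "['w']" for a dict entry
        type_name = type(leaf).__name__
        # Arrays expose .size (element count) and .nbytes (byte size).
        # Assumption: leaves without these attributes (e.g. Python scalars)
        # count as one element with an unknown (0) byte size.
        count = int(getattr(leaf, "size", 1))
        size = int(getattr(leaf, "nbytes", 0))
        rows.append((name, type_name, count, size))
    return rows


params = {
    "w": jnp.zeros((3, 4), dtype=jnp.float32),  # 12 elements, 48 bytes
    "b": jnp.zeros((4,), dtype=jnp.float32),    # 4 elements, 16 bytes
}
for row in leaf_rows(params):
    print(row)
```

Dict keys are flattened in sorted order, so the `b` row comes before the `w` row.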

Parameters:
  • tree (Any) – Any JAX pytree.

  • max_depth (int | None) – Maximum depth to expand. None (the default) means unlimited, so every leaf gets its own row. When a subtree is truncated at max_depth, it appears as a single row whose count and size are totals over all of its leaves.
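The truncation rule for max_depth can be sketched as follows: leaves are grouped by the first max_depth entries of their key path, and element counts and byte sizes are summed within each group. The function `truncated_rows` is hypothetical and is not drinx's code. That "aggregated" means a plain sum is an assumption here.

```python
from collections import defaultdict

import jax
import jax.numpy as jnp


def truncated_rows(tree, max_depth=None):
    """Hypothetical sketch of max_depth truncation (not drinx's implementation).

    Leaves deeper than max_depth are merged into one row keyed by their path
    prefix of length max_depth; counts and byte sizes are summed (assumption).
    """
    counts = defaultdict(int)
    sizes = defaultdict(int)
    leaves_with_paths, _ = jax.tree_util.tree_flatten_with_path(tree)
    for path, leaf in leaves_with_paths:
        # None means no truncation: every leaf keeps its full path.
        prefix = path if max_depth is None else path[:max_depth]
        name = jax.tree_util.keystr(prefix)
        counts[name] += int(getattr(leaf, "size", 1))
        sizes[name] += int(getattr(leaf, "nbytes", 0))
    # Dicts keep insertion order, i.e. the pytree's flattening order.
    return [(name, counts[name], sizes[name]) for name in counts]


params = {
    "dense": {"w": jnp.zeros((3, 4), jnp.float32), "b": jnp.zeros((4,), jnp.float32)},
    "out": {"w": jnp.zeros((4, 2), jnp.float32)},
}
print(truncated_rows(params, max_depth=1))  # one row per top-level subtree
print(truncated_rows(params))               # one row per leaf
```

With max_depth=1 the two leaves under `dense` collapse into a single row of 16 elements and 64 bytes. A leaf that sits shallower than max_depth is unaffected, because slicing its path leaves it whole.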

Return type:

str

Returns:

A multi-line string with the summary table.