API

Define a dataclass field with optional JAX static marking.

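A minimal sketch of how a dataclass field can carry an optional static flag, assuming the flag lives in the field's `metadata` under a key named `"static"`. The helper name `jax_field` and the metadata key are hypothetical; this section does not show the library's actual function name or signature.

```python
import dataclasses
from typing import Any


def jax_field(*, static: bool = False, **kwargs: Any) -> Any:
    """Hypothetical helper: a dataclasses.field that records a JAX static flag.

    The flag is stored in the field metadata under the assumed key "static",
    where pytree-flattening code can look it up later.
    """
    metadata = dict(kwargs.pop("metadata", None) or {})
    metadata["static"] = static
    return dataclasses.field(metadata=metadata, **kwargs)


@dataclasses.dataclass
class Layer:
    weight: Any = jax_field()                             # dynamic: would be a pytree leaf
    activation: str = jax_field(static=True, default="relu")  # static: would go into aux data


print({f.name: f.metadata["static"] for f in dataclasses.fields(Layer)})
# {'weight': False, 'activation': True}
```

Storing the flag in `metadata` keeps the field an ordinary dataclass field, so `dataclasses.fields` and `dataclasses.replace` keep working unchanged.
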
Define a JAX-static dataclass field.

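To show what "static" means for a pytree, here is a sketch that registers a dataclass with `jax.tree_util.register_pytree_node`, routing fields marked static (assumed `"static"` metadata key) into the treedef's auxiliary data instead of the leaves. The names `static_field` and `register_with_static_fields` are hypothetical illustrations, not the library's API.

```python
import dataclasses
from typing import Any

import jax
import jax.numpy as jnp


def static_field(**kwargs: Any) -> Any:
    """Hypothetical shorthand: a dataclass field marked static in its metadata."""
    return dataclasses.field(metadata={"static": True}, **kwargs)


def register_with_static_fields(cls):
    """Register a dataclass as a pytree; static fields become (hashable) aux data."""
    fields = dataclasses.fields(cls)
    static = [f.name for f in fields if f.metadata.get("static", False)]
    dynamic = [f.name for f in fields if f.name not in static]

    def flatten(obj):
        children = tuple(getattr(obj, n) for n in dynamic)  # traced leaves
        aux = tuple(getattr(obj, n) for n in static)        # part of the treedef
        return children, aux

    def unflatten(aux, children):
        return cls(**dict(zip(dynamic, children)), **dict(zip(static, aux)))

    jax.tree_util.register_pytree_node(cls, flatten, unflatten)
    return cls


@register_with_static_fields
@dataclasses.dataclass(frozen=True)
class Dense:
    weight: jax.Array
    activation: str = static_field(default="relu")


layer = Dense(weight=jnp.ones((2, 3)))
print(len(jax.tree_util.tree_leaves(layer)))  # 1: only `weight` is a leaf
doubled = jax.tree_util.tree_map(lambda x: 2 * x, layer)
print(doubled.activation)  # relu
```

Because static values live in the treedef, transformations such as `jax.jit` treat them as compile-time constants: they must be hashable, and changing them triggers a retrace.
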
Define a private (non-init) dataclass field with optional JAX static marking.

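A "private (non-init)" field is one excluded from the generated `__init__`, which standard `dataclasses` expresses with `init=False`. The sketch below assumes the library follows that convention; `private_field` and the `"static"` metadata key are hypothetical.

```python
import dataclasses
from typing import Any


def private_field(*, static: bool = False, **kwargs: Any) -> Any:
    """Hypothetical helper: a field that is not an __init__ argument.

    Its value must come from a default or be set in __post_init__.
    The static flag is recorded under an assumed "static" metadata key.
    """
    return dataclasses.field(init=False, metadata={"static": static}, **kwargs)


@dataclasses.dataclass
class Normalizer:
    values: list
    total: float = private_field()  # derived value, not a constructor argument

    def __post_init__(self):
        self.total = float(sum(self.values))


n = Normalizer([1.0, 2.0, 3.0])
print(n.total)  # 6.0
# Normalizer([1.0], total=5.0) raises TypeError: `total` is not an __init__ argument.
```
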
Define a private (non-init), JAX-static dataclass field.

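Combining `init=False` with static marking has a practical consequence: when JAX rebuilds the object from its leaves, the private field cannot be passed to `__init__`. This sketch (with a hypothetical `register_bypassing_init` helper and an assumed `"static"` metadata key) reconstructs instances with `object.__new__` and `object.__setattr__`, which also works for frozen dataclasses.

```python
import dataclasses

import jax
import jax.numpy as jnp


def register_bypassing_init(cls):
    """Hypothetical: register a dataclass whose unflatten skips __init__."""
    fields = dataclasses.fields(cls)
    dynamic = [f.name for f in fields if not f.metadata.get("static", False)]
    static = [f.name for f in fields if f.metadata.get("static", False)]

    def flatten(obj):
        return (tuple(getattr(obj, n) for n in dynamic),
                tuple(getattr(obj, n) for n in static))

    def unflatten(aux, children):
        # Private (init=False) fields cannot go through __init__, so set every
        # field directly; object.__setattr__ also bypasses frozen=True.
        obj = object.__new__(cls)
        for name, value in zip(dynamic + static, tuple(children) + tuple(aux)):
            object.__setattr__(obj, name, value)
        return obj

    jax.tree_util.register_pytree_node(cls, flatten, unflatten)
    return cls


@register_bypassing_init
@dataclasses.dataclass(frozen=True)
class Scaled:
    x: jax.Array
    # Private and static: not an __init__ argument, kept in the treedef.
    tag: str = dataclasses.field(init=False, default="scaled", metadata={"static": True})


s = Scaled(jnp.arange(3.0))
out = jax.jit(lambda t: jax.tree_util.tree_map(lambda a: a + 1, t))(s)
print(out.tag)  # scaled
```
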
Base class alternative to the

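The description above is cut off in the source, so what this base class replaces is not stated here. A common pattern, assumed for this sketch, is a base class whose `__init_subclass__` turns every subclass into a frozen dataclass and registers it as a pytree, so users subclass instead of stacking decorators. `PyTreeBase` and the `"static"` metadata key are hypothetical.

```python
import dataclasses

import jax
import jax.numpy as jnp


class PyTreeBase:
    """Hypothetical base class: subclasses become frozen dataclasses and pytrees."""

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        dataclasses.dataclass(frozen=True)(cls)  # modifies cls in place
        fields = dataclasses.fields(cls)
        static = [f.name for f in fields if f.metadata.get("static", False)]
        dynamic = [f.name for f in fields if f.name not in static]

        def flatten(obj):
            return (tuple(getattr(obj, n) for n in dynamic),
                    tuple(getattr(obj, n) for n in static))

        def unflatten(aux, children):
            obj = object.__new__(cls)
            for name, value in zip(dynamic + static, tuple(children) + tuple(aux)):
                object.__setattr__(obj, name, value)
            return obj

        jax.tree_util.register_pytree_node(cls, flatten, unflatten)


class Linear(PyTreeBase):
    w: jax.Array
    b: jax.Array
    name: str = dataclasses.field(default="linear", metadata={"static": True})


lin = Linear(w=jnp.ones((2, 2)), b=jnp.zeros(2))
print(len(jax.tree_util.tree_leaves(lin)))  # 2: `w` and `b`; `name` is static
```
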
Render a JAX pytree as an ASCII tree diagram.

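A minimal sketch of ASCII tree rendering, assuming only built-in containers (`dict`, `list`, `tuple`) are expanded and everything else is a leaf. Dict keys are sorted to mirror the order in which JAX flattens dicts; leaves with `.shape` and `.dtype` are summarized as `dtype[shape]`. `tree_diagram_sketch` is a hypothetical name, not the library's function.

```python
from typing import Any


def _label(leaf: Any) -> str:
    # Array-like leaves are summarized; other leaves fall back to repr.
    if hasattr(leaf, "shape") and hasattr(leaf, "dtype"):
        return f"{leaf.dtype}{list(leaf.shape)}"
    return repr(leaf)


def tree_diagram_sketch(tree: Any, name: str = "root") -> str:
    """Render nested dicts/lists/tuples as an ASCII tree (hypothetical helper)."""
    lines = [name]

    def walk(node: Any, prefix: str) -> None:
        if isinstance(node, dict):
            items = [(repr(k), v) for k, v in sorted(node.items(), key=lambda kv: kv[0])]
        else:
            items = [(f"[{i}]", v) for i, v in enumerate(node)]
        for idx, (key, child) in enumerate(items):
            last = idx == len(items) - 1
            branch = "`-- " if last else "|-- "
            if isinstance(child, (dict, list, tuple)):
                lines.append(prefix + branch + key)
                walk(child, prefix + ("    " if last else "|   "))
            else:
                lines.append(prefix + branch + f"{key}: {_label(child)}")

    if isinstance(tree, (dict, list, tuple)):
        walk(tree, "")
    else:
        lines[0] = f"{name}: {_label(tree)}"
    return "\n".join(lines)


print(tree_diagram_sketch({"w": 1.0, "layers": [2, 3]}))
# root
# |-- 'layers'
# |   |-- [0]: 2
# |   `-- [1]: 3
# `-- 'w': 1.0
```
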
Render a JAX pytree's leaves as a tabular summary.

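A tabular summary can be built from `jax.tree_util.tree_flatten_with_path`, which yields each leaf together with its key path, and `jax.tree_util.keystr`, which formats that path. This sketch (hypothetical `tree_summary_sketch`) lists path, shape, dtype and size per leaf plus a total; non-array leaves are assumed to have shape `()` and size 1.

```python
import jax
import jax.numpy as jnp


def tree_summary_sketch(tree) -> str:
    """Tabulate each leaf's path, shape, dtype and size (hypothetical helper)."""
    leaves_with_paths, _ = jax.tree_util.tree_flatten_with_path(tree)
    rows = [("Path", "Shape", "Dtype", "Size")]
    total = 0
    for path, leaf in leaves_with_paths:
        shape = getattr(leaf, "shape", ())
        dtype = getattr(leaf, "dtype", type(leaf).__name__)
        size = int(getattr(leaf, "size", 1))  # Python scalars count as one element
        total += size
        rows.append((jax.tree_util.keystr(path), str(shape), str(dtype), str(size)))
    rows.append(("Total", "", "", str(total)))
    widths = [max(len(row[i]) for row in rows) for i in range(4)]
    return "\n".join(
        "  ".join(cell.ljust(w) for cell, w in zip(row, widths)).rstrip() for row in rows
    )


params = {"dense": {"w": jnp.ones((3, 4)), "b": jnp.zeros(4)}, "step": 0}
print(tree_summary_sketch(params))  # header, 3 leaf rows, and a total of 17 elements
```
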
Check whether an object is a JAX Tracer.
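Under a transformation such as `jax.jit`, function arguments are replaced by tracer objects, all subclasses of `jax.core.Tracer`, so an `isinstance` check is one way to tell abstract values from concrete ones. The helper name `is_tracer_sketch` is hypothetical; the side effect below runs only while `jit` traces the function.

```python
import jax
import jax.numpy as jnp


def is_tracer_sketch(x) -> bool:
    """Return True if `x` is a tracer produced by a JAX transformation."""
    return isinstance(x, jax.core.Tracer)


observed = []


@jax.jit
def double(x):
    # Python side effects execute during tracing, not on cached calls.
    observed.append(is_tracer_sketch(x))
    return 2 * x


double(jnp.ones(3))
print(observed)                          # [True]
print(is_tracer_sketch(jnp.ones(3)))     # False
```

A check like this lets code skip Python-side validation that needs concrete values (for example, inspecting array contents) when it is being traced.
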