Source code for drinx.jax_utils
from typing import Any
from jax import core
[docs]
def is_traced(x: Any) -> bool:
"""
Checks if an object is a JAX Tracer.
In JAX, tracers are used during transformations (like `jit`, `grad`, or `vmap`)
to represent abstract values rather than concrete arrays. This function
identifies if the input is currently being tracked by the JAX dispatcher.
Args:
x: The object to check.
Returns:
True if `x` is a `jax.core.Tracer`, False otherwise.
"""
return isinstance(x, core.Tracer)