drinx.field

drinx.field(*, default=dataclasses.MISSING, default_factory=dataclasses.MISSING, init=True, repr=True, hash=None, compare=True, metadata=None, kw_only=dataclasses.MISSING, static=False, on_setattr=(), on_getattr=())

Define a dataclass field with optional JAX static marking.

Thin wrapper around dataclasses.field() that injects the jax_static key into the field’s metadata. When static=True, the field is excluded from JAX tracing: its value is stored in the pytree auxiliary data instead of being a leaf. When static=False (the default), the field is a traced pytree leaf.
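The static/traced split can be illustrated with the standard library alone. In the sketch below, field_sketch and split_fields are hypothetical helpers written for this example, not drinx's implementation: field_sketch forwards to dataclasses.field() and adds a jax_static entry (assuming caller metadata is merged first and jax_static written last), and split_fields mimics how a flattening step could route fields into leaves or auxiliary data by reading that entry.

```python
from __future__ import annotations

import dataclasses
from typing import Any


def field_sketch(*, static: bool = False, metadata: dict[str, Any] | None = None, **kwargs: Any) -> Any:
    # Hypothetical stand-in for drinx.field: forward to dataclasses.field and
    # add a "jax_static" entry. Assumption: jax_static is written after the
    # caller's metadata, so it wins on a key clash.
    merged = {**(metadata or {}), "jax_static": static}
    return dataclasses.field(metadata=merged, **kwargs)


def split_fields(obj: Any) -> tuple[dict[str, Any], dict[str, Any]]:
    # Hypothetical flattening step: fields marked jax_static go to the
    # auxiliary data; every other field becomes a (traced) leaf.
    leaves: dict[str, Any] = {}
    aux: dict[str, Any] = {}
    for f in dataclasses.fields(obj):
        target = aux if f.metadata.get("jax_static", False) else leaves
        target[f.name] = getattr(obj, f.name)
    return leaves, aux


@dataclasses.dataclass
class Dense:
    weight: list[float] = field_sketch(default_factory=lambda: [0.5, -0.5])
    bias: float = field_sketch(default=0.0)
    activation: str = field_sketch(default="relu", static=True, metadata={"doc": "nonlinearity"})


leaves, aux = split_fields(Dense())
print(leaves)  # {'weight': [0.5, -0.5], 'bias': 0.0}
print(aux)     # {'activation': 'relu'}
```

Under JAX, values placed in auxiliary data are part of the tree structure rather than traced arrays, which is why strings and other non-array configuration are natural candidates for static=True.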

Parameters:
  • default (Any) – Default value for the field.

  • default_factory (Union[Callable[[], Any], Any]) – Zero-argument callable returning the default value. Mutually exclusive with default.

  • init (bool) – Include the field in the generated __init__.

  • repr (bool) – Include the field in the generated __repr__.

  • hash (bool | None) – Include the field when computing __hash__. None defers to the value of compare.

  • compare (bool) – Include the field in __eq__ and ordering methods.

  • metadata (dict[str, Any] | None) – Additional metadata dict merged with the jax_static entry.

  • kw_only (Any) – Override the class-level kw_only setting for this field.

  • static (bool) – When True, mark the field as JAX-static (excluded from tracing). Defaults to False.
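Since drinx.field is described as a thin wrapper around dataclasses.field(), the standard parameters above are assumed to keep their standard-library meaning. The sketch below uses dataclasses.field() directly (it does not exercise drinx) to show default_factory, repr, compare, and the mutual exclusivity of default and default_factory; the Config class is a hypothetical example.

```python
import dataclasses


@dataclasses.dataclass
class Config:
    tags: list = dataclasses.field(default_factory=list)                  # fresh list per instance
    secret: str = dataclasses.field(default="", repr=False)               # hidden from __repr__
    cache: dict = dataclasses.field(default_factory=dict, compare=False)  # ignored by __eq__


a, b = Config(), Config(cache={"hit": 1})
print(a == b)                     # True: cache has compare=False
a.tags.append("x")
print(b.tags)                     # []: default_factory gave b its own list
print(repr(Config(secret="s")))   # Config(tags=[], cache={})

# Passing both default and default_factory is rejected by dataclasses.field.
try:
    dataclasses.field(default=0, default_factory=int)
except ValueError as err:
    message = str(err)
print(message)  # cannot specify both default and default_factory
```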

Return type:

Any

Returns:

A dataclasses.Field descriptor (typed as Any for compatibility with type checkers).
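The Any annotation matters because the returned object is assigned to attributes annotated with other types (for example depth: int), which a type checker would otherwise flag. At runtime it is an ordinary dataclasses.Field that the dataclass machinery consumes. The sketch below uses a hypothetical stand-in, field_sketch, rather than drinx itself, and assumes the metadata layout described above.

```python
from __future__ import annotations

import dataclasses
from typing import Any


def field_sketch(*, static: bool = False, **kwargs: Any) -> Any:
    # Hypothetical stand-in for drinx.field. Returning Any lets
    # `depth: int = field_sketch(...)` type-check, even though the object
    # returned at runtime is a dataclasses.Field, not an int.
    return dataclasses.field(metadata={"jax_static": static}, **kwargs)


raw = field_sketch(default=3, static=True)
print(type(raw).__name__)          # Field
print(raw.metadata["jax_static"])  # True


@dataclasses.dataclass
class Model:
    depth: int = field_sketch(default=3, static=True)


print(Model().depth)  # 3: the decorator consumed the Field and applied its default
```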