drinx.static_field

drinx.static_field(*, default=dataclasses.MISSING, default_factory=dataclasses.MISSING, init=True, repr=True, hash=None, compare=True, metadata=None, kw_only=dataclasses.MISSING, on_setattr=(), on_getattr=())

Define a JAX-static dataclass field.

Convenience wrapper around field() with static=True pre-set. The field is excluded from JAX tracing and stored as pytree auxiliary data, so changing its value triggers recompilation under jit. Use it for configuration values, shapes, and other compile-time constants.

Parameters:
  • default (Any) – Default value for the field.

  • default_factory (Union[Callable[[], Any], Any]) – Zero-argument callable returning the default value.

  • init (bool) – Include the field in the generated __init__.

  • repr (bool) – Include the field in the generated __repr__.

  • hash (bool | None) – Include the field in the generated __hash__ (None uses the value of compare).

  • compare (bool) – Include the field in __eq__ and ordering methods.

  • metadata (dict[str, Any] | None) – Additional metadata merged with the jax_static entry.

  • kw_only (Any) – Override the class-level kw_only setting for this field.

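Returns:
  A dataclasses.Field object, typed as Any so it can be assigned to an annotated class attribute (for example, size: int = static_field(default=3)) without type-checker errors.

Return type:
  Any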