drinx.field
- drinx.field(*, default=dataclasses.MISSING, default_factory=dataclasses.MISSING, init=True, repr=True, hash=None, compare=True, metadata=None, kw_only=dataclasses.MISSING, static=False, on_setattr=(), on_getattr=())
Define a dataclass field with optional JAX static marking.
Thin wrapper around dataclasses.field() that injects the jax_static key into the field's metadata. When static=True, the field is excluded from JAX tracing and placed in the pytree auxiliary data. When static=False (the default), the field is a traced pytree leaf.
- Parameters:
  - default (Any) – Default value for the field.
  - default_factory (Union[Callable[[], Any], Any]) – Zero-argument callable returning the default value. Mutually exclusive with default.
  - init (bool) – Include the field in the generated __init__.
  - repr (bool) – Include the field in the generated __repr__.
  - hash (bool | None) – Include the field when computing __hash__. None defers to the value of compare.
  - compare (bool) – Include the field in __eq__ and the ordering methods.
  - metadata (dict[str, Any] | None) – Additional metadata dict, merged with the jax_static entry.
  - kw_only (Any) – Override the class-level kw_only setting for this field.
  - static (bool) – When True, mark the field as JAX-static (excluded from tracing). Defaults to False.
- Return type:
  Any
- Returns:
  A dataclasses.Field descriptor. It is typed as Any so that type checkers accept it as the default value of an attribute with any annotation.
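The following is a minimal sketch of the behavior described above. It is not drinx's implementation. It assumes only what this page states: the standard arguments are forwarded to `dataclasses.field()` and the `static` flag is recorded under the `"jax_static"` metadata key. The names `sketch_field`, `partition_fields`, and `Layer` are hypothetical. `on_setattr` and `on_getattr` are left out because their behavior is not described here. The sketch also assumes that `static` overrides any `"jax_static"` key already present in `metadata`; the real merge order may differ. `partition_fields` shows how a pytree registration could read that key to separate traced leaves from static auxiliary data.

```python
import dataclasses
from typing import Any, Dict, List, Optional, Tuple


def sketch_field(
    *,
    static: bool = False,
    metadata: Optional[Dict[str, Any]] = None,
    **kwargs: Any,
) -> Any:
    """Hypothetical stand-in for drinx.field (not the real implementation).

    Forwards default, default_factory, init, repr, hash, compare and kw_only
    to dataclasses.field() and records the static flag under "jax_static".
    Assumption: the static argument wins if metadata already has that key.
    """
    merged = dict(metadata or {})  # copy so the caller's dict is not mutated
    merged["jax_static"] = static
    return dataclasses.field(metadata=merged, **kwargs)


def partition_fields(cls: type) -> Tuple[List[str], List[str]]:
    """Split field names into traced leaves and static auxiliary data."""
    leaves: List[str] = []
    static: List[str] = []
    for f in dataclasses.fields(cls):
        # Fields without the key are treated as traced leaves.
        target = static if f.metadata.get("jax_static", False) else leaves
        target.append(f.name)
    return leaves, static


@dataclasses.dataclass
class Layer:
    weights: Any = sketch_field(default=0.0)  # traced leaf
    activation: str = sketch_field(default="relu", static=True)  # static aux data
    tags: list = sketch_field(
        default_factory=list, metadata={"doc": "free-form labels"}
    )  # traced leaf; the extra metadata is kept alongside "jax_static"


leaves, static = partition_fields(Layer)
# leaves == ["weights", "tags"], static == ["activation"]
```

The two lists have the shape that `jax.tree_util.register_dataclass` expects for its `data_fields` and `meta_fields` arguments. That is one plausible way such metadata could be consumed, but this page does not say whether drinx uses that function.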