module

class flax.experimental.nnx.Module(*args, **kwargs)
sow(variable_type, name, value, reduce_fn=<function <lambda>>, init_fn=<function <lambda>>)
iter_modules()

Iterates over all nested Modules of the current Module, including the current Module.

iter_modules creates a generator that yields the path and the Module instance, where the path is a tuple of strings or integers representing the path to the Module from the root Module.

Example:

>>> from flax.experimental import nnx
...
>>> class Block(nnx.Module):
...   def __init__(self, din, dout, *, rngs: nnx.Rngs):
...     self.linear = nnx.Linear(din, dout, rngs=rngs)
...     self.dropout = nnx.Dropout(0.5)
...     self.batch_norm = nnx.BatchNorm(10, rngs=rngs)
...
...
>>> model = Block(2, 5, rngs=nnx.Rngs(0))
>>> for path, module in model.iter_modules():
...   print(path, type(module).__name__)
...
('batch_norm',) BatchNorm
('dropout',) Dropout
('linear',) Linear
() Block
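
The traversal described above can be pictured with a small, library-independent sketch. ToyModule and iter_toy_modules are hypothetical names and are not part of nnx. The sketch assumes that children are visited in sorted attribute order and that the current module is yielded last, which matches the sample output above but is not confirmed as the library's implementation. It only follows plain attributes, so it never produces integer path elements, and it does not guard against cycles.

```python
class ToyModule:
    """Hypothetical stand-in for a module; not part of nnx."""


def iter_toy_modules(module, path=()):
    """Yield (path, module) pairs for every nested ToyModule, then `module` itself."""
    for name in sorted(vars(module)):
        value = vars(module)[name]
        if isinstance(value, ToyModule):
            # Extend the path with the attribute name and recurse.
            yield from iter_toy_modules(value, (*path, name))
    # Assumption: the current module is yielded after its children.
    yield path, module


class Leaf(ToyModule):
    pass


class Block(ToyModule):
    def __init__(self):
        self.linear = Leaf()
        self.dropout = Leaf()
        self.rate = 0.5  # non-module attributes are skipped


class Model(ToyModule):
    def __init__(self):
        self.block = Block()


paths = [path for path, _ in iter_toy_modules(Model())]
print(paths)  # [('block', 'dropout'), ('block', 'linear'), ('block',), ()]
```

The empty tuple is the path of the root, which is why the root Module appears with () in the example output.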
eval(**attributes)

Sets the Module to evaluation mode.

eval uses set_attributes to recursively set deterministic=True and use_running_average=True on all nested Modules that have these attributes. It is primarily used to control the runtime behavior of the Dropout and BatchNorm Modules.

Example:

>>> from flax.experimental import nnx
...
>>> class Block(nnx.Module):
...   def __init__(self, din, dout, *, rngs: nnx.Rngs):
...     self.linear = nnx.Linear(din, dout, rngs=rngs)
...     self.dropout = nnx.Dropout(0.5)
...     self.batch_norm = nnx.BatchNorm(10, rngs=rngs)
...
>>> block = Block(2, 5, rngs=nnx.Rngs(0))
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(False, False)
>>> block.eval()
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(True, True)
Parameters

**attributes – additional attributes passed to set_attributes.

property init

Calls a method in initialization mode.

When a method is called using init, the is_initializing method will return True. This is useful to implement Modules that support lazy initialization.

Example:

>>> from flax.experimental import nnx
>>> import jax
>>> import jax.numpy as jnp
...
>>> class Linear(nnx.Module):
...   def __init__(self, dout, rngs: nnx.Rngs):
...     self.dout = dout
...     self.rngs = rngs
...
...   def __call__(self, x):
...     if self.is_initializing():
...       din = x.shape[-1]
...       if not hasattr(self, 'w'):
...         key = self.rngs.params()
...         self.w = nnx.Param(jax.random.uniform(key, (din, self.dout)))
...       if not hasattr(self, 'b'):
...         self.b = nnx.Param(jnp.zeros((self.dout,)))
...
...     return x @ self.w + self.b
...
>>> linear = Linear(3, nnx.Rngs(0))
>>> x = jnp.ones((5, 2))
>>> y = linear.init(x)
>>> linear.w.value.shape
(2, 3)
>>> linear.b.value.shape
(3,)
>>> y.shape
(5, 3)
is_initializing()

Returns whether the Module is initializing.

is_initializing returns True if the Module is currently being run under init.
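
One way to picture the relationship between init and is_initializing is a flag that is switched on only for the duration of the call. The sketch below is plain Python and makes no claim about how nnx implements it: LazyLinear, the _initializing attribute, and the run helper are all hypothetical, and only __call__ is routed through init here.

```python
class LazyLinear:
    """Hypothetical toy class illustrating an initialization-mode flag; not nnx code."""

    def __init__(self, dout):
        self.dout = dout
        self._initializing = False

    def is_initializing(self):
        return self._initializing

    @property
    def init(self):
        # Return a callable that runs __call__ with the flag switched on.
        def run(*args, **kwargs):
            self._initializing = True
            try:
                return self(*args, **kwargs)
            finally:
                # Always leave initialization mode, even if the call raises.
                self._initializing = False
        return run

    def __call__(self, x):
        if self.is_initializing() and not hasattr(self, "w"):
            din = len(x[0])
            # Shapes are inferred from the first input; zeros keep the toy simple.
            self.w = [[0.0] * self.dout for _ in range(din)]
            self.b = [0.0] * self.dout
        # Plain-list matrix product x @ w + b.
        return [
            [sum(xi * wij for xi, wij in zip(row, col)) + bj
             for col, bj in zip(zip(*self.w), self.b)]
            for row in x
        ]


layer = LazyLinear(3)
x = [[1.0, 1.0]] * 5
y = layer.init(x)
shape_w = (len(layer.w), len(layer.w[0]))
shape_y = (len(y), len(y[0]))
print(shape_w, shape_y, layer.is_initializing())  # (2, 3) (5, 3) False
```

Resetting the flag in a finally clause means a later ordinary call never creates parameters by accident.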

set_attributes(*filters, raise_if_not_found=True, **attributes)

Sets attributes on nested Modules, including the current Module. A Module that does not have a given attribute is skipped for that attribute.

Example:

>>> from flax.experimental import nnx
...
>>> class Block(nnx.Module):
...   def __init__(self, din, dout, *, rngs: nnx.Rngs):
...     self.linear = nnx.Linear(din, dout, rngs=rngs)
...     self.dropout = nnx.Dropout(0.5, deterministic=False)
...     self.batch_norm = nnx.BatchNorm(10, use_running_average=False, rngs=rngs)
...
>>> block = Block(2, 5, rngs=nnx.Rngs(0))
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(False, False)
>>> block.set_attributes(deterministic=True, use_running_average=True)
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(True, True)

Filters can be used to set the attributes of specific Modules:

>>> block = Block(2, 5, rngs=nnx.Rngs(0))
>>> block.set_attributes(nnx.Dropout, deterministic=True)
>>> # Only the dropout will be modified
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(True, False)
Parameters
  • *filters – Filters to select the Modules to set the attributes of.

  • raise_if_not_found – If True (default), raises a ValueError if at least one attribute instance is not found in one of the selected Modules.

  • **attributes – The attributes to set.
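
The interplay between filters, skipped Modules, and raise_if_not_found can be sketched in plain Python. Everything below is hypothetical and independent of nnx: the Toy classes, walk, and set_toy_attributes are illustrative names. The sketch assumes that a filter is simply a class and that the error is raised only when no selected module has the requested attribute. Real nnx filters may be more general, and the exact error condition should be checked against the library.

```python
class Toy:
    """Hypothetical stand-in for a module; not part of nnx."""


class Dropout(Toy):
    def __init__(self):
        self.deterministic = False


class BatchNorm(Toy):
    def __init__(self):
        self.use_running_average = False


class Block(Toy):
    def __init__(self):
        self.dropout = Dropout()
        self.batch_norm = BatchNorm()


def walk(module):
    """Yield `module` and every Toy nested in its attributes."""
    yield module
    for value in vars(module).values():
        if isinstance(value, Toy):
            yield from walk(value)


def set_toy_attributes(root, *filters, raise_if_not_found=True, **attributes):
    # Assumption: a filter is a class, and no filters means "select everything".
    remaining = set(attributes)
    for module in walk(root):
        if filters and not isinstance(module, filters):
            continue
        for name, value in attributes.items():
            if hasattr(module, name):
                setattr(module, name, value)
                remaining.discard(name)
            # Modules lacking the attribute are skipped silently.
    # Assumption: the error fires only if no selected module had the attribute.
    if remaining and raise_if_not_found:
        raise ValueError(f"attributes not found: {sorted(remaining)}")


block = Block()
set_toy_attributes(block, Dropout, deterministic=True)
state = (block.dropout.deterministic, block.batch_norm.use_running_average)
print(state)  # (True, False)

try:
    # Dropout has no use_running_average, so nothing matches.
    set_toy_attributes(block, Dropout, use_running_average=True)
    raised = False
except ValueError:
    raised = True
print(raised)  # True
```

Under these assumptions, passing raise_if_not_found=False turns the second call into a silent no-op.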

train(**attributes)

Sets the Module to training mode.

train uses set_attributes to recursively set deterministic=False and use_running_average=False on all nested Modules that have these attributes. It is primarily used to control the runtime behavior of the Dropout and BatchNorm Modules.

Example:

>>> from flax.experimental import nnx
...
>>> class Block(nnx.Module):
...   def __init__(self, din, dout, *, rngs: nnx.Rngs):
...     self.linear = nnx.Linear(din, dout, rngs=rngs)
...     # initialize Dropout and BatchNorm in eval mode
...     self.dropout = nnx.Dropout(0.5, deterministic=True)
...     self.batch_norm = nnx.BatchNorm(10, use_running_average=True, rngs=rngs)
...
>>> block = Block(2, 5, rngs=nnx.Rngs(0))
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(True, True)
>>> block.train()
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(False, False)
Parameters

**attributes – additional attributes passed to set_attributes.