NNX

NNX is a JAX-based neural network library designed for simplicity and power. Its modular approach follows standard Python conventions, making it both intuitive and compatible with the broader JAX ecosystem.

Note

NNX is currently in an experimental state and is subject to change. Linen is still the recommended option for large-scale projects. Feedback and contributions are welcome!

Features

Pythonic

Modules are standard Python classes, promoting ease of use and a more familiar development experience.

Compatible

Convert between Modules and pytrees with the Functional API, so Modules can be passed through standard JAX functions and transformations.

Control

Manage a Module’s state with precision using typed Variable collections, enabling fine-grained control over JAX transformations.

User-friendly

NNX prioritizes simplicity for common use cases, building upon lessons learned from Linen to provide a streamlined experience.

Basic usage

from flax.experimental import nnx
import optax


class Model(nnx.Module):
  def __init__(self, din, dmid, dout, rngs: nnx.Rngs):
    self.linear = nnx.Linear(din, dmid, rngs=rngs)
    self.bn = nnx.BatchNorm(dmid, rngs=rngs)
    self.dropout = nnx.Dropout(0.2, rngs=rngs)
    self.linear_out = nnx.Linear(dmid, dout, rngs=rngs)

  def __call__(self, x):
    x = nnx.relu(self.dropout(self.bn(self.linear(x))))
    return self.linear_out(x)

model = Model(2, 64, 3, rngs=nnx.Rngs(0))  # eager initialization
optimizer = nnx.Optimizer(model, optax.adam(1e-3))  # keeps a reference to model

@nnx.jit  # automatic state management
def train_step(model, optimizer, x, y):
  def loss_fn(model):
    y_pred = model(x)  # call methods directly
    return ((y_pred - y) ** 2).mean()

  loss, grads = nnx.value_and_grad(loss_fn)(model)
  optimizer.update(grads)  # updates the model's parameters in place

  return loss

Installation

NNX is under active development, so we recommend installing the latest version from Flax’s GitHub repository:

pip install git+https://github.com/google/flax.git

Learn more

NNX Basics
MNIST Tutorial
NNX vs JAX Transformations
API reference