
Python package

nn

APIs for building neural network components for deep learning models in Python.

The MAX neural network API provides two namespaces:

  • max.nn: Graph-based API for building computational graphs.
  • max.nn.module_v3: Eager-style execution with PyTorch-style syntax.
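To show what separates the two styles, here is a minimal plain-Python sketch that does not use MAX. In graph style, code first records operations and runs them later with concrete inputs. In eager style, each line computes its value immediately. The `Node`, `graph_input`, `apply`, and `evaluate` names are hypothetical stand-ins, not part of `max.nn` or `max.nn.module_v3`.

```python
import operator
from dataclasses import dataclass
from typing import Callable, Optional


@dataclass(frozen=True)
class Node:
    """Hypothetical graph node: an op applied to input nodes, or a named input."""
    op: Optional[Callable[..., float]] = None
    inputs: tuple["Node", ...] = ()
    name: str = ""


def graph_input(name: str) -> Node:
    # A named placeholder whose value is supplied only at execution time.
    return Node(name=name)


def apply(op: Callable[..., float], *inputs: Node) -> Node:
    # Building the graph only records the operation; nothing is computed yet.
    return Node(op=op, inputs=inputs)


def evaluate(node: Node, feeds: dict[str, float]) -> float:
    # Execution walks the recorded graph and substitutes concrete input values.
    if node.op is None:
        return feeds[node.name]
    return node.op(*(evaluate(i, feeds) for i in node.inputs))


# Graph style: describe y = relu(x * w + b) once, then run it with real values.
x, w, b = graph_input("x"), graph_input("w"), graph_input("b")
y = apply(lambda v: max(v, 0.0), apply(operator.add, apply(operator.mul, x, w), b))
graph_result = evaluate(y, {"x": 2.0, "w": 3.0, "b": -1.0})

# Eager style: every expression produces a concrete value immediately.
eager_result = max(2.0 * 3.0 + -1.0, 0.0)

print(graph_result, eager_result)  # 5.0 5.0
```

Because the graph is described before any values exist, you can run it again with different inputs without rebuilding it. That separation is what lets a graph compiler inspect and optimize the whole computation before it executes.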

For functional operations such as relu and softmax, see the functional module.
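As a reference for what these two operations compute, here is a plain-Python sketch of ReLU and a numerically stable softmax over a list of floats. It is not the MAX functional API. The function names are hypothetical, and they operate on Python lists rather than tensors.

```python
import math


def relu(xs: list[float]) -> list[float]:
    # ReLU keeps positive values and replaces everything else with zero.
    return [x if x > 0.0 else 0.0 for x in xs]


def softmax(xs: list[float]) -> list[float]:
    # Subtracting the maximum before exponentiating avoids overflow without
    # changing the result, because softmax is invariant to a constant shift.
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


print(relu([-1.0, 0.5, 2.0]))  # [0.0, 0.5, 2.0]
print(softmax([1.0, 1.0]))     # [0.5, 0.5]
```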

Graph API

Use these modules for building graph-based neural networks.

  • attention: Attention mechanisms for sequence modeling.
  • clamp: Value clamping utilities for tensor operations.
  • comm: Communication primitives for distributed training.
  • conv: Convolutional layers for spatial processing.
  • conv_transpose: Transposed convolution for upsampling.
  • data_parallelism: Utilities for splitting batches across devices.
  • embedding: Embedding layers with vocabulary support.
  • float8_config: Configuration for FP8 quantization.
  • hooks: Extension hooks for layer customization.
  • identity: Identity layer that passes inputs through unchanged.
  • kernels: Custom kernel implementations.
  • kv_cache: Key-value cache for efficient generation.
  • layer: Base classes for building graph-based layers.
  • linear: Linear transformation layers with optional parallelism.
  • lora: Low-Rank Adaptation for efficient fine-tuning.
  • moe: Mixture of Experts layer implementations.
  • norm: Normalization layers for training stability.
  • rotary_embedding: Rotary position embeddings for sequences.
  • sampling: Sampling strategies for generation.
  • sequential: Container for sequential layer composition.
  • transformer: Transformer building blocks and layers.
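The `sequential` and `identity` entries describe a common pattern. A sequential container feeds each layer's output into the next layer, and an identity layer returns its input unchanged. The sketch below shows that pattern with plain Python callables. The `Sequential` and `Identity` classes here are hypothetical illustrations, not the classes defined in `max.nn`, and they operate on lists instead of graph values.

```python
from typing import Callable, Sequence

# Hypothetical layer type: any callable mapping a vector to a vector.
Layer = Callable[[list[float]], list[float]]


class Identity:
    """Hypothetical layer that passes its input through unchanged."""

    def __call__(self, x: list[float]) -> list[float]:
        return x


class Sequential:
    """Hypothetical container that applies its layers in order."""

    def __init__(self, layers: Sequence[Layer]) -> None:
        self.layers = list(layers)

    def __call__(self, x: list[float]) -> list[float]:
        # Each layer receives the previous layer's output.
        for layer in self.layers:
            x = layer(x)
        return x


def scale_by_two(x: list[float]) -> list[float]:
    return [2.0 * v for v in x]


def add_one(x: list[float]) -> list[float]:
    return [v + 1.0 for v in x]


model = Sequential([scale_by_two, Identity(), add_one])
print(model([1.0, -3.0]))  # [3.0, -5.0]
```

Order matters: `Sequential([add_one, scale_by_two])` maps the same input to `[4.0, -4.0]`. An identity layer can serve as a placeholder when a configuration turns off an optional step, because it lets the container keep a fixed structure.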

Eager API (module_v3)

  • module: Base class for all neural network modules.
  • Conv2d: 2D convolution layer.
  • Embedding: Vector embedding layer for token representation.
  • Linear: Linear transformation layer with weights and bias.
  • sequential: Containers for composing modules sequentially.
  • norm: Normalization layers (GemmaRMSNorm, RMSNorm, LayerNorm, GroupNorm).
  • rope: Rotary position embeddings (RotaryEmbedding, TransposedRotaryEmbedding).
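The `norm` entry lists several normalization variants. To illustrate how the two RMS variants differ, here is a plain-Python sketch of RMS normalization over a single vector. Standard RMSNorm scales the normalized values by a learned `weight`. The Gemma-style variant follows the convention in public Gemma implementations, which scale by `1 + weight`, so an all-zero weight means no extra scaling. The function names, the `eps` default, and the list-based inputs are assumptions for illustration. They are not the signatures of the `module_v3` classes.

```python
import math


def rms_norm(x: list[float], weight: list[float], eps: float = 1e-6) -> list[float]:
    # Divide by the root mean square of x. Unlike LayerNorm, no mean is subtracted.
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [w * v / rms for v, w in zip(x, weight)]


def gemma_rms_norm(x: list[float], weight: list[float], eps: float = 1e-6) -> list[float]:
    # Gemma-style variant: the stored weight is an offset from 1.
    return rms_norm(x, [1.0 + w for w in weight], eps)


print(rms_norm([2.0, -2.0], [1.0, 1.0], eps=0.0))        # [1.0, -1.0]
print(gemma_rms_norm([2.0, -2.0], [0.0, 0.0], eps=0.0))  # [1.0, -1.0]
```

Both calls print the same result: a Gemma-style weight of zeros behaves like a standard weight of ones.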
