Python module

engine

The APIs in this module allow you to run inference with MAX Engine—a graph compiler and runtime that accelerates your AI models on a wide variety of hardware.

InferenceSession

class max.engine.InferenceSession(devices, num_threads=None, *, custom_extensions=None)

Manages an inference session in which you can load and run models.

You need an instance of this to load a model as a Model object. For example:

from pathlib import Path
from max import engine
from max.driver import CPU
session = engine.InferenceSession(devices=[CPU()])
model_path = Path('bert-base-uncased')
model = session.load(model_path)

Construct an inference session.

Parameters:

  • devices (Iterable[Device]) – A list of devices on which to run inference. Defaults to the host CPU only.
  • num_threads (int | None) – Number of threads to use for the inference session. Defaults to the number of physical cores on your machine.
  • custom_extensions (CustomExtensionsType | None) – The extensions to load for the model. Supports paths to a .mojopkg custom ops library or a .mojo source file. For an example, see the sketch after this list.
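
For example, you can pass a custom ops library when constructing the session (a minimal sketch; 'custom_ops.mojopkg' is a hypothetical path):

from max import engine
from max.driver import CPU

# 'custom_ops.mojopkg' is a hypothetical compiled Mojo ops package.
session = engine.InferenceSession(
    devices=[CPU()],
    custom_extensions='custom_ops.mojopkg',
)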

devices

property devices: list[Device]

A list of available devices.
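
For example, to list the devices attached to a session:

for device in session.devices:
    print(device)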

gpu_profiling()

gpu_profiling(mode)

Enables GPU profiling instrumentation for the session.

The instrumentation works with NVIDIA Nsight Systems and Nsight Compute. When enabled, the runtime adds CUDA driver calls and NVTX markers that let profiling tools correlate GPU kernel executions with host-side code.

For example, to enable detailed profiling for Nsight Systems analysis, call gpu_profiling() before load():

from max.engine import InferenceSession
from max.driver import Accelerator

session = InferenceSession(devices=[Accelerator()])
session.gpu_profiling("detailed")
model = session.load(my_graph)

Then run it with nsys:

nsys profile --trace=cuda,nvtx python example.py

Or, instead of calling session.gpu_profiling() in the code, you can set the MODULAR_ENABLE_PROFILING environment variable when you call nsys profile:

MODULAR_ENABLE_PROFILING=detailed nsys profile --trace=cuda,nvtx python script.py

Note that gpu_profiling() overrides the MODULAR_ENABLE_PROFILING environment variable if both are used.

Parameters:

mode (Literal['off', 'on', 'detailed']) –

The profiling mode to set. One of:

  • "off": Disable profiling (default).
  • "on": Enable basic profiling with NVTX markers for kernel correlation.
  • "detailed": Enable detailed profiling with additional Python-level NVTX markers.

Return type:

None

load()

load(model, *, custom_extensions=None, weights_registry=None)

Loads a trained model and compiles it for inference.

Parameters:

  • model (str | Path | Graph) – A path to a saved model, or a Graph to compile.
  • custom_extensions (CustomExtensionsType | None) – The extensions to load for the model. Supports paths to a .mojopkg custom ops library.
  • weights_registry (Mapping[str, DLPackArray] | None) – A mapping from model weight names to their values. The values are currently expected to be DLPack arrays. If an array is a read-only NumPy array, you must ensure that its lifetime extends beyond the lifetime of the model (see the example at the end of this section).

Returns:

The loaded model, compiled and ready to execute.

Raises:

RuntimeError – If the path provided is invalid.

Return type:

Model
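
For example, you can supply externally managed weights through weights_registry (a minimal sketch: my_graph and the weight name 'linear.weight' are hypothetical, and NumPy arrays stand in for the DLPack-compatible values):

import numpy as np
from max.engine import InferenceSession
from max.driver import CPU

session = InferenceSession(devices=[CPU()])
# 'my_graph' is a hypothetical max.graph.Graph that declares an external
# weight named 'linear.weight'.
model = session.load(
    my_graph,
    weights_registry={'linear.weight': np.zeros((16, 16), dtype=np.float32)},
)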

set_mojo_assert_level()

set_mojo_assert_level(level)

Sets which Mojo asserts are kept in the compiled model.

Parameters:

level (AssertLevel)

Return type:

None

set_mojo_log_level()

set_mojo_log_level(level)

Sets the verbosity of Mojo logging in the compiled model.

Parameters:

level (str | LogLevel)

Return type:

None
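
For example, to surface debug-level logs from Mojo ops (a minimal sketch; both the enum and its string value are accepted per the signature above):

from max.engine import InferenceSession, LogLevel
from max.driver import CPU

session = InferenceSession(devices=[CPU()])
session.set_mojo_log_level(LogLevel.DEBUG)  # equivalently: session.set_mojo_log_level('debug')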

set_split_k_reduction_precision()

set_split_k_reduction_precision(precision)

Sets the accumulation precision for split-k reductions in large matmuls.

Parameters:

precision (str | SplitKReductionPrecision)

Return type:

None

use_old_top_k_kernel()

use_old_top_k_kernel(mode)

Enables or disables the old top-k kernel.

By default, the new top-k kernel is used, for consistency with max/kernels/src/nn/topk.mojo.

Parameters:

mode (str) – String to enable or disable the old kernel. Accepts "false", "off", "no", or "0" to disable; any other value enables.

Return type:

None
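
For example (a sketch; any value other than the disabling strings enables the old kernel):

session.use_old_top_k_kernel('true')  # opt in to the old top-k kernel
session.use_old_top_k_kernel('off')   # restore the default new kernel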

Model

class max.engine.Model

A loaded model that you can execute.

Do not instantiate this class directly. Instead, create it with InferenceSession.

__call__()

__call__(*args, **kwargs)

Call self as a function.

Return type:

list[Buffer]

capture()

capture(graph_key, *inputs)

Capture execution into a device graph for a caller-provided key.

Capture is best-effort and model-dependent. If the model issues capture-unsafe operations (for example, host-device synchronization), graph capture may fail. Callers should choose capture-safe execution paths.

Parameters:

  • graph_key (int) – Caller-provided key identifying the captured graph.
  • inputs (Buffer) – Input buffers that establish the capture's input signature (shapes and dtypes).

Return type:

list[Buffer]

debug_verify_replay()

debug_verify_replay(graph_key, *inputs)

Execute eagerly and verify the launch trace matches the captured graph.

This method validates that graph capture correctly represents eager execution by running the model and comparing kernel launch sequences against a previously captured device graph.

Parameters:

  • graph_key (int) – Caller-provided key identifying the captured graph.
  • inputs (Buffer) – Input buffers matching the captured input signature (same shapes and dtypes used during capture).

Raises:

  • TypeError – If graph_key is not an integer.
  • ValueError – If graph_key is out of uint64 range.
  • ValueError – If no input buffers are provided.
  • RuntimeError – If no graph has been captured for graph_key.
  • RuntimeError – If the eager execution trace doesn’t match the captured graph.

Return type:

None

Example:

>>> model.capture(1, input_tensor)
>>> model.debug_verify_replay(1, input_tensor)  # Validates capture
>>> model.replay(1, input_tensor)  # Safe to use optimized replay

execute()

execute(*args)

Runs inference on the given inputs and returns the output buffers.

Return type:

list[Buffer]
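
For example, to run inference on a single input (a sketch; input_tensor is a hypothetical input buffer matching the model's input_metadata):

outputs = model.execute(input_tensor)  # input_tensor is hypothetical
print(len(outputs))  # outputs is a list[Buffer]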

input_metadata

property input_metadata

Metadata about the model’s input tensors, as a list of TensorSpec objects.

For example, you can print the input tensor names, shapes, and dtypes:

for tensor in model.input_metadata:
    print(f'name: {tensor.name}, shape: {tensor.shape}, dtype: {tensor.dtype}')

output_metadata

property output_metadata

Metadata about the model’s output tensors, as a list of TensorSpec objects.

For example, you can print the output tensor names, shapes, and dtypes:

for tensor in model.output_metadata:
    print(f'name: {tensor.name}, shape: {tensor.shape}, dtype: {tensor.dtype}')

replay()

replay(graph_key, *inputs)

Replay the captured device graph for a caller-provided key.

Parameters:

  • graph_key (int) – Caller-provided key identifying the captured graph.
  • inputs (Buffer) – Input buffers matching the captured input signature (same shapes and dtypes used during capture).

Return type:

None

GPUProfilingMode

max.engine.GPUProfilingMode

alias of Literal['off', 'on', 'detailed']

LogLevel

class max.engine.LogLevel(value, names=<not given>, *values, module=None, qualname=None, type=None, start=1, boundary=None)

Specifies the log level used by Mojo ops.

CRITICAL

CRITICAL = 'critical'

DEBUG

DEBUG = 'debug'

ERROR

ERROR = 'error'

INFO

INFO = 'info'

NOTSET

NOTSET = 'notset'

TRACE

TRACE = 'trace'

WARNING

WARNING = 'warning'

TensorSpec

class max.engine.TensorSpec

Defines the properties of a tensor, including its name, shape, and data type.

For usage examples, see Model.input_metadata.

dtype

property dtype

The tensor's data type.

name

property name

The tensor's name.

shape

property shape

The shape of the tensor as a list of integers.

If a dimension size is unknown/dynamic (such as the batch size), its value is None.
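
For example, to detect dynamic dimensions before preallocating input buffers (a sketch building on the input_metadata example above):

for spec in model.input_metadata:
    if any(dim is None for dim in spec.shape):
        print(f'{spec.name} has a dynamic dimension: {spec.shape}')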

CustomExtensionsType

max.engine.CustomExtensionsType = Sequence[str | Path] | str | Path

A PEP 604 union of the accepted custom-extension inputs: a path to a .mojopkg custom ops library or a .mojo source file, or a sequence of such paths.
