Python module
engine
The APIs in this module allow you to run inference with MAX Engine, a graph compiler and runtime that accelerates your AI models on a wide variety of hardware.
InferenceSession
class max.engine.InferenceSession(devices, num_threads=None, *, custom_extensions=None)
Manages an inference session in which you can load and run models.
You need an instance of this to load a model as a Model object.
For example:
```python
session = engine.InferenceSession(devices=[CPU()])
model_path = Path('bert-base-uncased')
model = session.load(model_path)
```
Construct an inference session.
Parameters:
- devices (Iterable[Device]) – A list of devices on which to run inference. Default is the host CPU only.
- num_threads (int | None) – Number of threads to use for the inference session. Defaults to the number of physical cores on your machine.
- custom_extensions (CustomExtensionsType | None) – The extensions to load for the model. Supports paths to a .mojopkg custom ops library or a .mojo source file.
devices
A list of available devices.
gpu_profiling()
gpu_profiling(mode)
Enables GPU profiling instrumentation for the session.
This enables GPU profiling instrumentation that works with NVIDIA Nsight Systems and Nsight Compute. When enabled, the runtime adds CUDA driver calls and NVTX markers that allow profiling tools to correlate GPU kernel executions with host-side code.
For example, to enable detailed profiling for Nsight Systems analysis, call gpu_profiling() before load():
```python
from max.engine import InferenceSession
from max.driver import Accelerator

session = InferenceSession(devices=[Accelerator()])
session.gpu_profiling("detailed")
model = session.load(my_graph)
```
Then run it with nsys:
```shell
nsys profile --trace=cuda,nvtx python example.py
```
Or, instead of calling session.gpu_profiling() in the code, you can set the MODULAR_ENABLE_PROFILING environment variable when you call nsys profile:
```shell
MODULAR_ENABLE_PROFILING=detailed nsys profile --trace=cuda,nvtx python script.py
```
Beware that gpu_profiling() overrides the MODULAR_ENABLE_PROFILING environment variable if both are used.
Parameters:
mode (Literal['off', 'on', 'detailed']) – The profiling mode to set. One of:
- "off": Disable profiling (default).
- "on": Enable basic profiling with NVTX markers for kernel correlation.
- "detailed": Enable detailed profiling with additional Python-level NVTX markers.
Return type:
None
load()
load(model, *, custom_extensions=None, weights_registry=None)
Loads a trained model and compiles it for inference.
Parameters:
- model (str | Path | Graph) – Path to a model.
- custom_extensions (CustomExtensionsType | None) – The extensions to load for the model. Supports paths to .mojopkg custom ops.
- weights_registry (Mapping[str, DLPackArray] | None) – A mapping from names of model weights to their values. The values are currently expected to be DLPack arrays. If an array is a read-only NumPy array, the user must ensure that its lifetime extends beyond the lifetime of the model.
Returns:
The loaded model, compiled and ready to execute.
Raises:
RuntimeError – If the path provided is invalid.
Return type:
Model
set_mojo_assert_level()
set_mojo_assert_level(level)
Sets which mojo asserts are kept in the compiled model.
Parameters:
level (AssertLevel)
Return type:
None
set_mojo_log_level()
set_mojo_log_level(level)
Sets the verbosity of mojo logging in the compiled model.
set_split_k_reduction_precision()
set_split_k_reduction_precision(precision)
Sets the accumulation precision for split-k reductions in large matmuls.
Parameters:
precision (str | SplitKReductionPrecision)
Return type:
None
use_old_top_k_kernel()
use_old_top_k_kernel(mode)
Enables the old top-k kernel.
By default, the new top-k kernel is used, to keep it consistent with max/kernels/src/nn/topk.mojo.
Parameters:
mode (str) – String to enable or disable. Accepts "false", "off", "no", or "0" to disable; any other value enables.
Return type:
None
Model
class max.engine.Model
A loaded model that you can execute.
Do not instantiate this class directly. Instead, create it with
InferenceSession.
__call__()
__call__(*args, **kwargs)
Call self as a function.
capture()
capture(graph_key, *inputs)
Capture execution into a device graph for a caller-provided key.
Capture is best-effort and model-dependent. If the model issues capture-unsafe operations (for example, host-device synchronization), graph capture may fail. Callers should choose capture-safe execution paths.
debug_verify_replay()
debug_verify_replay(graph_key, *inputs)
Execute eagerly and verify the launch trace matches the captured graph.
This method validates that graph capture correctly represents eager execution by running the model and comparing kernel launch sequences against a previously captured device graph.
Raises:
- TypeError – If graph_key is not an integer.
- ValueError – If graph_key is out of uint64 range.
- ValueError – If no input buffers are provided.
- RuntimeError – If no graph has been captured for graph_key.
- RuntimeError – If the eager execution trace doesn't match the captured graph.
Return type:
None
Example:
```python
>>> model.capture(1, input_tensor)
>>> model.debug_verify_replay(1, input_tensor)  # Validates capture
>>> model.replay(1, input_tensor)  # Safe to use optimized replay
```
execute()
execute(*args)
input_metadata
property input_metadata
Metadata about the modelâs input tensors, as a list of
TensorSpec objects.
For example, you can print the input tensor names, shapes, and dtypes:
```python
for tensor in model.input_metadata:
    print(f'name: {tensor.name}, shape: {tensor.shape}, dtype: {tensor.dtype}')
```
output_metadata
property output_metadata
Metadata about the modelâs output tensors, as a list of
TensorSpec objects.
For example, you can print the output tensor names, shapes, and dtypes:
```python
for tensor in model.output_metadata:
    print(f'name: {tensor.name}, shape: {tensor.shape}, dtype: {tensor.dtype}')
```
replay()
replay(graph_key, *inputs)
Replay the captured device graph for a caller-provided key.
GPUProfilingMode
max.engine.GPUProfilingMode
alias of Literal["off", "on", "detailed"]
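Because this is a typing alias rather than an enum, you can introspect its allowed values with typing.get_args. A minimal sketch, using a local re-declaration of the alias for illustration:

```python
from typing import Literal, get_args

# Local re-declaration of the alias, for illustration only.
GPUProfilingMode = Literal["off", "on", "detailed"]

def is_valid_mode(mode: str) -> bool:
    # Check a string against the alias's allowed values.
    return mode in get_args(GPUProfilingMode)

print(is_valid_mode("detailed"))  # True
print(is_valid_mode("verbose"))   # False
```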
LogLevel
class max.engine.LogLevel(value, names=<not given>, *values, module=None, qualname=None, type=None, start=1, boundary=None)
The LogLevel specifies the log level used by the Mojo Ops.
CRITICAL
CRITICAL = 'critical'
DEBUG
DEBUG = 'debug'
ERROR
ERROR = 'error'
INFO
INFO = 'info'
NOTSET
NOTSET = 'notset'
TRACE
TRACE = 'trace'
WARNING
WARNING = 'warning'
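The members carry plain string values, so (assuming a standard string-valued Enum, which the class signature above suggests) you can look a level up by its value. A local mirror of the documented values, for illustration only:

```python
from enum import Enum

class LogLevel(Enum):
    # Local mirror of the documented max.engine.LogLevel values;
    # not the real class, just an illustration of value-based lookup.
    NOTSET = 'notset'
    DEBUG = 'debug'
    TRACE = 'trace'
    INFO = 'info'
    WARNING = 'warning'
    ERROR = 'error'
    CRITICAL = 'critical'

print(LogLevel('debug') is LogLevel.DEBUG)  # True
print(LogLevel.ERROR.value)                 # error
```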
TensorSpec
class max.engine.TensorSpec
Defines the properties of a tensor, including its name, shape and data type.
For usage examples, see Model.input_metadata.
dtype
property dtype
A tensor data type.
name
property name
A tensor name.
shape
property shape
The shape of the tensor as a list of integers.
If a dimension size is unknown/dynamic (such as the batch size), its
value is None.
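Since dynamic dimensions are reported as None, a small helper can pick them out of a spec's shape. This is a hypothetical convenience (not part of the API), shown on plain lists:

```python
def dynamic_axes(shape):
    # Return the indices of unknown/dynamic dimensions (reported as None).
    return [i for i, dim in enumerate(shape) if dim is None]

# e.g. a shape with a dynamic batch and sequence dimension:
print(dynamic_axes([None, None, 768]))  # [0, 1]
print(dynamic_axes([1, 3, 224, 224]))   # []
```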
CustomExtensionsType
max.engine.CustomExtensionsType = collections.abc.Sequence[str | pathlib._local.Path] | str | pathlib._local.Path
Represents a PEP 604 union type, e.g. int | str.
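The alias accepts either a single path (str or Path) or a sequence of them. A hypothetical normalizer over those shapes (not part of the API), assuming only that the sequence elements are str or Path:

```python
from pathlib import Path

def normalize_extensions(ext):
    # Hypothetical helper: coerce any of the CustomExtensionsType shapes
    # (single str/Path, or a sequence of them) to a flat list of strings.
    if isinstance(ext, (str, Path)):
        return [str(ext)]
    return [str(e) for e in ext]

print(normalize_extensions("ops.mojopkg"))                  # ['ops.mojopkg']
print(normalize_extensions([Path("a.mojopkg"), "b.mojo"]))  # ['a.mojopkg', 'b.mojo']
```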