Returns a distributed TPU mesh optimized for AllReduce ring reductions.
tf.experimental.dtensor.create_tpu_mesh(
    mesh_dim_names: List[str],
    mesh_shape: List[int],
    mesh_name: str,
    ring_dims: Optional[int] = None,
    ring_axes: Optional[List[str]] = None,
    ring_bounds: Optional[List[int]] = None,
    can_split_host_across_rings: bool = True,
    build_ring_across_rings: bool = False,
    rotate_ring_across_rings: bool = False,
    use_xla_spmd: bool = layout_lib.USE_XLA_SPMD
) -> tf.experimental.dtensor.Mesh
Only as many leading axes of ring_axes as necessary are used to build rings,
as long as the subslice formed by these axes has enough cores to contain a
ring of the required size. Any leftover axes in ring_axes do not affect the
result.
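The selection rule for ring_axes can be illustrated with a small pure-Python sketch. This is not TensorFlow's implementation; it only assumes that "as many leading axes as necessary" means the shortest prefix of ring_axes whose combined size reaches the ring size. The helper name leading_axes_for_ring, the axis names, and the per-axis sizes are hypothetical.

```python
from math import prod


def leading_axes_for_ring(ring_axes, axis_sizes, ring_size):
    """Return the shortest prefix of ring_axes whose subslice holds ring_size cores.

    Illustrative only: a hypothetical helper, not part of TensorFlow.
    """
    chosen = []
    for axis in ring_axes:
        # Stop as soon as the axes picked so far already span enough cores.
        if prod(axis_sizes[a] for a in chosen) >= ring_size:
            break
        chosen.append(axis)
    if prod(axis_sizes[a] for a in chosen) < ring_size:
        raise ValueError("ring_axes do not span enough cores for the ring")
    return chosen


# Hypothetical topology: 2 cores per chip on a 4 x 4 x 1 grid of chips.
axis_sizes = {"core": 2, "x": 4, "y": 4, "z": 1}
print(leading_axes_for_ring(["core", "x", "y", "z"], axis_sizes, ring_size=8))
# prints ['core', 'x']
```

In this sketch the trailing "y" and "z" entries are never consulted for a ring of size 8, which mirrors the statement that leftover axes do not affect the result.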
This function always uses all TPU devices, and offers more customization than
tf.experimental.dtensor.create_distributed_mesh.
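A minimal usage sketch follows, assuming a single-client TPU setup with 8 cores. The helper tpu_mesh_kwargs, the dimension names "batch" and "model", the mesh name, and the core count are all hypothetical. The size check assumes that "always uses all TPU devices" implies the product of mesh_shape must equal the total number of TPU cores; the documentation above does not state this explicitly.

```python
import math


def tpu_mesh_kwargs(mesh_dim_names, mesh_shape, mesh_name, num_tpu_cores):
    """Check and package the required arguments (hypothetical helper)."""
    if len(mesh_dim_names) != len(mesh_shape):
        raise ValueError("mesh_dim_names and mesh_shape must have equal length")
    # Assumption: since every TPU device is used, the mesh must cover all cores.
    if math.prod(mesh_shape) != num_tpu_cores:
        raise ValueError("mesh_shape does not cover all TPU cores")
    return {
        "mesh_dim_names": list(mesh_dim_names),
        "mesh_shape": list(mesh_shape),
        "mesh_name": mesh_name,
    }


def build_mesh(num_tpu_cores=8):
    """Create a 2D mesh on TPU; only call this where TPUs are attached."""
    # Imported lazily so the helper above can run without TensorFlow.
    from tensorflow.experimental import dtensor

    dtensor.initialize_accelerator_system("TPU")
    kwargs = tpu_mesh_kwargs(
        ["batch", "model"], [num_tpu_cores // 2, 2], "tpu_mesh", num_tpu_cores
    )
    # The ring_* arguments keep their defaults (None) in this sketch.
    return dtensor.create_tpu_mesh(**kwargs)
```

Keeping the argument check separate from the TensorFlow call makes it possible to validate a planned mesh layout on a machine without TPUs before launching the job.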
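| Args | Type and default |
|---|---|
| mesh_dim_names | List[str], required |
| mesh_shape | List[int], required |
| mesh_name | str, required |
| ring_dims | Optional[int], default None |
| ring_axes | Optional[List[str]], default None |
| ring_bounds | Optional[List[int]], default None |
| can_split_host_across_rings | bool, default True |
| build_ring_across_rings | bool, default False |
| rotate_ring_across_rings | bool, default False |
| use_xla_spmd | bool, default layout_lib.USE_XLA_SPMD |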