Skip to content

Commit

Permalink
Fix or ignore some pytype errors.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 675513581
  • Loading branch information
Martin Huschenbett authored and The jax_tpu_embedding Authors committed Sep 17, 2024
1 parent e59add0 commit 7a703de
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion jax_tpu_embedding/examples/singlehost_pjit_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ def build_embedding_configs(batch_size_per_device: int,


def create_global_mesh(mesh_shape: Tuple[int, ...],
axis_names: Sequence[jax.pxla.MeshAxisName]) -> Mesh:
axis_names: Sequence[jax.pxla.MeshAxisName]) -> Mesh: # pytype: disable=module-attr
size = np.prod(mesh_shape)
if len(jax.devices()) < size:
raise ValueError(f'Test requires {size} global devices.')
Expand Down

0 comments on commit 7a703de

Please sign in to comment.