Added manual partitioning to JAX TPU embedding to support the use case of `shard_map`.

PiperOrigin-RevId: 677110682
Dateng Lin authored and The jax_tpu_embedding Authors committed Sep 21, 2024
1 parent 7a703de commit f4095d2
Showing 2 changed files with 28 additions and 6 deletions.
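
A minimal, hypothetical sketch of the usage this flag is meant to enable, per the commit message: with manual partitioning, the embedding tensors are not automatically SPMD-partitioned, so the per-core computation that consumes the activations can be written with `jax.experimental.shard_map`. The mesh layout and the body of `dense_step` below are illustrative placeholders, not code from this commit.

import functools

import jax
import numpy as np
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec as P

# One-dimensional device mesh; each core owns one shard of the batch dimension.
mesh = Mesh(np.array(jax.devices()), ("data",))

@functools.partial(shard_map, mesh=mesh, in_specs=P("data"), out_specs=P("data"))
def dense_step(activations):
  # Runs per shard: `activations` is this core's manually partitioned slice
  # of the embedding activations.
  return activations * 2.0  # placeholder TensorCore computation

out = dense_step(np.ones((jax.device_count() * 4, 16), np.float32))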
24 changes: 20 additions & 4 deletions jax_tpu_embedding/config_utils.py
@@ -50,8 +50,10 @@ class TpuEmbeddingConfigSpecs:
overlap with the TensorCore computations.
num_hosts: number of hosts.
num_tensor_cores: number of tensor cores.
cores_per_replica: number of cores per replica, use for spmd when it's
not None.
cores_per_replica: number of cores per replica, use for spmd when it's not
None.
manual_partitioning: If True, the tensors are manually partitioned.
Otherwise, use the automatic SPMD partitioning.
"""
feature_config: NestedFeatureConfig
output_shapes: List[OutputShape]
@@ -61,6 +63,7 @@ class TpuEmbeddingConfigSpecs:
num_hosts: int
num_tensor_cores: int
cores_per_replica: Optional[int]
manual_partitioning: bool


def create_tpu_embedding_config(
@@ -69,7 +72,9 @@ def create_tpu_embedding_config(
pipeline_execution_with_tensor_core: bool,
num_hosts: int,
num_tensor_cores: int,
cores_per_replica: Optional[int] = None) -> TpuEmbeddingConfigSpecs:
cores_per_replica: Optional[int] = None,
manual_partitioning: bool = False,
) -> TpuEmbeddingConfigSpecs:
"""Creates TpuEmbeddingConfigSpecs.
Args:
@@ -82,6 +87,8 @@ def create_tpu_embedding_config(
num_tensor_cores: number of tensor cores.
cores_per_replica: number of cores for one replica. If None, config would be
for data parallelism only, if not None config will be set for SPMD.
manual_partitioning: If True, the tensors are manually partitioned.
Otherwise, use the automatic SPMD partitioning.
Raises:
ValueError: If optimizer is not one of tf.tpu.experimental.embedding.(SGD,
@@ -136,7 +143,9 @@ def create_tpu_embedding_config(
pipeline_execution_with_tensor_core=pipeline_execution_with_tensor_core,
num_hosts=num_hosts,
num_tensor_cores=num_tensor_cores,
cores_per_replica=cores_per_replica)
cores_per_replica=cores_per_replica,
manual_partitioning=manual_partitioning,
)


def create_tpu_embedding_configs(
@@ -146,6 +155,7 @@ def create_tpu_embedding_config(
num_hosts: int,
num_tensor_cores: int,
cores_per_replica: Optional[int] = None,
manual_partitioning: bool = False,
) -> List[TpuEmbeddingConfigSpecs]:
"""Creates a list of TpuEmbeddingConfigSpecs.
@@ -159,6 +169,8 @@ def create_tpu_embedding_configs(
num_tensor_cores: number of tensor cores.
cores_per_replica: number of cores for one replica. If None, config would be
for data parallelism only, if not None config will be set for SPMD.
manual_partitioning: If True, the tensors are manually partitioned.
Otherwise, use the automatic SPMD partitioning.
Raises:
ValueError: If optimizer is not one of tf.tpu.experimental.embedding.(SGD,
@@ -177,6 +189,7 @@ def create_tpu_embedding_configs(
num_hosts,
num_tensor_cores,
cores_per_replica,
manual_partitioning,
)
tpu_embedding_configs.append(tpu_embedding_config)
return tpu_embedding_configs
@@ -342,6 +355,9 @@ def create_config_proto(
config_proto.spmd_sharding.enabled = True
config_proto.spmd_sharding.num_cores_per_replica = (
tpu_embedding_config.cores_per_replica)
config_proto.spmd_sharding.manual_partitioning = (
tpu_embedding_config.manual_partitioning
)

return config_proto
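
As a usage sketch for the updated helper in config_utils.py: the call below mirrors the keyword arguments visible in this diff, with the new flag carried into the returned TpuEmbeddingConfigSpecs and, through create_config_proto, into config_proto.spmd_sharding.manual_partitioning. The feature_config and optimizer parameters sit above the visible hunk, so their names and the concrete table sizes here are assumptions.

import tensorflow as tf
from jax_tpu_embedding import config_utils

# Hypothetical single-table feature config; any valid FeatureConfig works here.
feature_config = tf.tpu.experimental.embedding.FeatureConfig(
    table=tf.tpu.experimental.embedding.TableConfig(
        vocabulary_size=1024, dim=16, name="video"),
    name="watched")

tpu_embedding_config = config_utils.create_tpu_embedding_config(
    feature_config=feature_config,
    optimizer=tf.tpu.experimental.embedding.SGD(learning_rate=0.01),
    pipeline_execution_with_tensor_core=False,
    num_hosts=1,
    num_tensor_cores=8,
    cores_per_replica=8,        # not None, so the config is set up for SPMD
    manual_partitioning=True,   # new: mark tensors as manually partitioned
)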

10 changes: 8 additions & 2 deletions jax_tpu_embedding/tpu_embedding.py
@@ -139,13 +139,14 @@ def __init__(
use_pathways: bool = False,
num_shards: int | None = None,
input_split_fn: Callable[..., dict[str, Any]] | None = None,
manual_partitioning: bool = False,
):
"""Creates JAX TPUEmbedding object.
Args:
feature_configs: A nested structure, or a list of nested structure, or a
standalone instance of `tf.tpu.experimental.embedding.FeatureConfig`
configs. It must be a list of nested structure if using Pathways (see
configs. It must be a list of nested structure if using Pathways (see
below); and otherwise it must be the others.
optimizer: An instance of one of embedding optimizers like
`tf.tpu.experimental.embedding.SGD`,
@@ -164,6 +165,8 @@ def __init__(
input_split_fn: A callable function takes elements from iterator, yields
splits pytree of host and device batches in a dictionary. This should be
supplied if users want to call `experimental_get_next`.
manual_partitioning: If True, the tensors are manually partitioned.
Otherwise, use the automatic SPMD partitioning.
Raises:
ValueError: when cores_per_replica is not legal.
@@ -184,6 +187,7 @@ def __init__(
pipeline_execution_with_tensor_core
)
self._cores_per_replica = cores_per_replica
self._manual_partitioning = manual_partitioning

# Create config_utils.TpuEmbeddingConfig instance.
if use_pathways:
@@ -194,6 +198,7 @@ def __init__(
num_hosts=self._num_hosts,
num_tensor_cores=self._num_tensor_cores,
cores_per_replica=self._cores_per_replica,
manual_partitioning=self._manual_partitioning,
)
# We assume the output shapes and dynamic learning rates are the same
# across all the tasks. This may be relaxed in the future.
@@ -212,6 +217,7 @@ def __init__(
num_hosts=self._num_hosts,
num_tensor_cores=self._num_tensor_cores,
cores_per_replica=self._cores_per_replica,
manual_partitioning=self._manual_partitioning,
)
self._table_config_list = self._tpu_embedding_config.table_config_list
self._output_shapes = self._tpu_embedding_config.output_shapes
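
A construction sketch against the updated __init__, assuming the class is exposed as tpu_embedding.TPUEmbedding (the docstring calls it the JAX TPUEmbedding object); the feature config, optimizer, and sizes are placeholders, and any arguments not visible in this diff are assumptions.

import tensorflow as tf
from jax_tpu_embedding import tpu_embedding

# Placeholder feature config; in practice this mirrors the model's tables.
feature_configs = tf.tpu.experimental.embedding.FeatureConfig(
    table=tf.tpu.experimental.embedding.TableConfig(
        vocabulary_size=1024, dim=16, name="video"),
    name="watched")

embedding = tpu_embedding.TPUEmbedding(
    feature_configs=feature_configs,
    optimizer=tf.tpu.experimental.embedding.SGD(learning_rate=0.01),
    pipeline_execution_with_tensor_core=False,
    cores_per_replica=8,        # SPMD with 8 cores per replica
    manual_partitioning=True,   # tensors are manually partitioned (for shard_map)
)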
@@ -716,7 +722,7 @@ def _gradients_fn(
local_shape = gradient.shape.as_list()

# When self._core_per_replica is not None, it uses BC spmd.
if self._cores_per_replica:
if self._cores_per_replica and not self._manual_partitioning:
local_shape[0] = local_shape[0] // self._cores_per_replica
if local_shape != full_output_shape:
raise ValueError("Found gradient of shape {} at path {}. Expected "
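
A standalone sketch of the shape check in _gradients_fn above, to make the new branch concrete: under automatic SPMD the gradient's leading dimension is divided by cores_per_replica before comparison against the expected output shape, while with manual partitioning the gradient is compared as-is. The shapes below are illustrative.

def check_gradient_shape(local_shape, full_output_shape, cores_per_replica,
                         manual_partitioning):
  local_shape = list(local_shape)
  # Mirrors the branch above: only automatic SPMD rescales the leading dim.
  if cores_per_replica and not manual_partitioning:
    local_shape[0] = local_shape[0] // cores_per_replica
  if local_shape != list(full_output_shape):
    raise ValueError(
        f"Found gradient of shape {local_shape}. Expected {full_output_shape}.")

# Automatic SPMD: the gradient is cores_per_replica times larger on axis 0.
check_gradient_shape([1024, 16], [128, 16], cores_per_replica=8,
                     manual_partitioning=False)
# Manual partitioning: the gradient already matches the expected shape.
check_gradient_shape([128, 16], [128, 16], cores_per_replica=8,
                     manual_partitioning=True)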
