
Fix support for multi node JAX sharding #5242

Merged
merged 9 commits into from
Jan 29, 2024

Conversation

@awolant (Contributor) commented Dec 13, 2023

Category:

Bug fix

Description:

In some situations the `data_iterator` for JAX did not work correctly in a multiprocess environment. This PR fixes that.

Additional information:

Affected modules and functionalities:

Iterator for JAX. Some adjustments were made on the code path where the `sharding` argument is provided.

Tests:

  • Existing tests apply
  • New tests added
    • Python tests
    • GTests
    • Benchmark
    • Other
  • N/A

Checklist

Documentation

  • Existing documentation applies
  • Documentation updated
    • Docstring
    • Doxygen
    • RST
    • Jupyter
    • Other
  • N/A

DALI team only

Requirements

  • Implements new requirements
  • Affects existing requirements
  • N/A

REQ IDs: N/A

JIRA TASK: DALI-3670

Signed-off-by: Albert Wolant <awolant@nvidia.com>
Signed-off-by: Albert Wolant <awolant@nvidia.com>
Signed-off-by: Albert Wolant <awolant@nvidia.com>
@@ -172,7 +172,7 @@ def _next_impl(self):
 for category_id, category_name in enumerate(self.output_map):
     category_outputs = self._gather_outputs_for_category(pipelines_outputs, category_id)

-    if self._num_gpus == 1:
+    if self._num_gpus == 1 and self._sharding is None:
awolant (Contributor, Author) commented:
self._num_gpus is equal to the number of pipelines run by this instance of the iterator, so we need to distinguish multi node training with one GPU per node (process) from single GPU training.
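The distinction described above can be sketched as a small predicate. The helper name below is illustrative, not the PR's actual code:

```python
def runs_on_single_gpu(num_gpus, sharding):
    """Illustrative helper: `num_gpus` counts the pipelines run by THIS
    iterator instance, so it is 1 both for true single-GPU training and
    for multi node runs with one GPU per process. Only the absence of a
    `sharding` argument identifies the genuinely single-GPU case."""
    return num_gpus == 1 and sharding is None

# Single GPU, no sharding: the simple code path applies.
print(runs_on_single_gpu(1, None))        # True
# One GPU per process, but sharding is set: the multi node path applies.
print(runs_on_single_gpu(1, object()))    # False
```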

Comment on lines +233 to +236
if isinstance(self._sharding, NamedSharding):
global_shape = (self._sharding.mesh.size * shard_shape[0], *shard_shape[1:])
else:
global_shape = (self._sharding.shape[0] * shard_shape[0], *shard_shape[1:])
awolant (Contributor, Author) commented:

Sharding variants have inconsistent APIs when it comes to getting the global shape.
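The API difference can be sketched as follows. `FakePositional` is a stand-in so the example runs without `PositionalSharding`, and the shard shapes are illustrative, assuming JAX is installed:

```python
# Sketch of the global-shape logic from this PR: NamedSharding exposes the
# device count via mesh.size, while positional-style shardings expose it
# via a .shape attribute, hence the isinstance branch.
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

def global_shape(sharding, shard_shape):
    """Scale the leading (batch) dimension of a per-process shard shape
    by the number of devices encoded in the sharding."""
    if isinstance(sharding, NamedSharding):
        factor = sharding.mesh.size   # total devices in the mesh
    else:
        factor = sharding.shape[0]    # positional-style API
    return (factor * shard_shape[0], *shard_shape[1:])

devices = np.array(jax.devices())
named = NamedSharding(Mesh(devices, ("batch",)), PartitionSpec("batch"))

class FakePositional:                 # stand-in for a positional sharding
    shape = (devices.size,)

print(global_shape(named, (8, 3)))
print(global_shape(FakePositional(), (8, 3)))
```

Both branches produce the same result; the branch exists only because the two sharding classes expose the device count through different attributes.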

awolant commented Dec 18, 2023

!build

@dali-automaton (Collaborator)

CI MESSAGE: [11566280]: BUILD STARTED

@awolant awolant changed the title [WIP] Add support for multi node JAX sharding Add support for multi node JAX sharding Dec 18, 2023
@awolant awolant changed the title Add support for multi node JAX sharding Fix support for multi node JAX sharding Dec 18, 2023
@awolant awolant marked this pull request as ready for review December 18, 2023 07:50
@dali-automaton

CI MESSAGE: [11566280]: BUILD PASSED

Comment on lines 318 to 321
assert jax.local_device_count() == jax.device_count(), (
"Iterator compatible with pmapped JAX functions does not support "
"running in multiprocess mode. Use `sharding` argument instead."
)
A reviewer (Contributor) commented:

Assertions != error checking. If you intend this to be a proper error, please use an appropriate exception with an explicit raise.

awolant (Contributor, Author) replied:

Done
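The requested change can be sketched as below; the exact exception type in the merged code is not shown in this thread, so `RuntimeError` here is an assumption:

```python
import jax

def check_single_process_pmap():
    # Explicit raise instead of assert: assert statements are stripped
    # under `python -O`, so real argument validation should raise.
    if jax.local_device_count() != jax.device_count():
        raise RuntimeError(
            "Iterator compatible with pmapped JAX functions does not support "
            "running in multiprocess mode. Use `sharding` argument instead.")

check_single_process_pmap()  # no-op in a single-process setup
```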

Signed-off-by: Albert Wolant <awolant@nvidia.com>
awolant commented Dec 18, 2023

!build

@dali-automaton

CI MESSAGE: [11573496]: BUILD STARTED

@dali-automaton

CI MESSAGE: [11573496]: BUILD PASSED

awolant commented Dec 20, 2023

!build

@dali-automaton

CI MESSAGE: [11618023]: BUILD STARTED

@dali-automaton

CI MESSAGE: [11618023]: BUILD FAILED

Signed-off-by: Albert Wolant <awolant@nvidia.com>
awolant commented Jan 7, 2024

!build

@dali-automaton

CI MESSAGE: [11946905]: BUILD STARTED

@dali-automaton

CI MESSAGE: [11946905]: BUILD PASSED

Signed-off-by: Albert Wolant <awolant@nvidia.com>
Signed-off-by: Albert Wolant <awolant@nvidia.com>
awolant commented Jan 23, 2024

!build

@dali-automaton

CI MESSAGE: [12275250]: BUILD STARTED

@dali-automaton

CI MESSAGE: [12275250]: BUILD PASSED

@awolant awolant merged commit b4c83b9 into NVIDIA:main Jan 29, 2024
7 checks passed
This pull request was closed.