-
Notifications
You must be signed in to change notification settings - Fork 615
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add basic jax.Sharding support for the iterator #4969
Conversation
Signed-off-by: Albert Wolant <awolant@nvidia.com>
!build |
CI MESSAGE: [9206164]: BUILD STARTED |
CI MESSAGE: [9206164]: BUILD PASSED |
Signed-off-by: Albert Wolant <awolant@nvidia.com>
@@ -140,13 +143,15 @@ def __init__( | |||
auto_reset=False, | |||
last_batch_padded=False, | |||
last_batch_policy=LastBatchPolicy.FILL, | |||
prepare_first_batch=True): | |||
prepare_first_batch=True, | |||
sharding=None): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
sharding=None): | |
sharding: jax.sharding.Sharding=None): |
Small suggestion, you may consider adding type hint here
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What are the semantics of type hints? The documentation says it only needs to be jax.sharding.Sharding
compatible, not necessarily an instance of this type.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
jax.sharding.Sharding
is abstract base class. This should accept NamesSharding
, PositionalSharding
and maybe others in the future.
When it comes to inheritance and type hints this should be done as
from typing import Type
...
sharding; Type[jax.sharding.Sharding]=None):
I added straight up type assertion against NamesSharding
and PositionalSharding
since these are the ones we are testing against. If there is new type of sharding in the future we will add it and add tests for it.
We are not using type hints anywhere so I didn't want to add them just here.
sharding : ``jax.sharding.Sharding`` comaptible object that if present will be used to build | ||
output jax.Array for each category. If ``None`` iterator returns values compatible | ||
with pmaped JAX functions. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
sharding : ``jax.sharding.Sharding`` comaptible object that if present will be used to build | |
output jax.Array for each category. If ``None`` iterator returns values compatible | |
with pmaped JAX functions. | |
sharding : ``jax.sharding.Sharding`` comaptible object that, if present, will be used to build an | |
output jax.Array for each category. If ``None``, the iterator returns values compatible | |
with pmapped JAX functions. |
Not sure about "pmapped" - double 'p' if it's somehow derived from a verb "to map" (map -> mapped).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What about the other changes?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
They weren't there yet :)
Done
Signed-off-by: Albert Wolant <awolant@nvidia.com>
sharding : ``jax.sharding.Sharding`` comaptible object that if present will be used to build | ||
output jax.Array for each category. If ``None`` iterator returns values compatible |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
These still apply:
sharding : ``jax.sharding.Sharding`` comaptible object that if present will be used to build | |
output jax.Array for each category. If ``None`` iterator returns values compatible | |
sharding : ``jax.sharding.Sharding`` comaptible object that, if present, will be used to build an | |
output jax.Array for each category. If ``None``, the iterator returns values compatible |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
|
||
return category_outputs | ||
|
||
def _build_output_with_devices(self, next_output, category_name, category_outputs): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't quite follow what this function does. An error message in L234 mentions sharding, but the function is invoked when _sharding is None
.
Some comment explaining high level functionality would be nice to avoid making this code "write only".
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Generally, there are two mechanisms of building sharded output: jax.sharding.Sharding
and jax.device_put_sharded()
. We want to support both. The error mentioned shard
as a general concept not sharding
object.
I changed the name to _build_output_with_device_put
to indicate which function takes which path. This should be enough to distinguish them.
Signed-off-by: Albert Wolant <awolant@nvidia.com>
Signed-off-by: Albert Wolant <awolant@nvidia.com>
!build |
CI MESSAGE: [9223789]: BUILD STARTED |
CI MESSAGE: [9223789]: BUILD PASSED |
Add basic jax.Sharding support for the iterator Signed-off-by: Albert Wolant <awolant@nvidia.com>
Category:
New feature
Description:
Add basic jax.Sharding support for the iterator
Additional information:
Affected modules and functionalities:
JAX iterator has new argument. If it is provided we have different output format.
Tests:
Checklist
Documentation
DALI team only
Requirements
REQ IDs: N/A
JIRA TASK: DALI-3558