Add input type validation to feed_ndarray in MXNet and PyTorch #3308
Conversation
- Adds type validation between a DALI Tensor/TensorList and an MXNet/PyTorch tensor inside `feed_ndarray`, in case anyone wants to use `feed_ndarray` directly.
- It doesn't cover PaddlePaddle, as `feed_ndarray` there accepts a raw pointer to the PaddlePaddle tensor and there is no API to check the type.
- Updates the usage of `raises`/`assert_raises` in test_fw_iterators.py to use the implementation from nose_utils.

Signed-off-by: Janusz Lisiecki <jlisiecki@nvidia.com>
!build

CI MESSAGE: [2901568]: BUILD STARTED
CI MESSAGE: [2901568]: BUILD PASSED
```diff
@@ -52,6 +52,15 @@ def feed_ndarray(dali_tensor, arr, cuda_stream = None):
     In most cases, using the default internal user stream or stream 0
     is expected.
     """
+    if isinstance(dali_tensor, (TensorListCPU, TensorListGPU)):
```
We cannot do such a thing for PaddlePaddle, as `feed_ndarray` there accepts a raw pointer. We cannot accept the tensor itself, because we cannot extract a pointer to it without providing the placement and type, so we would need to extend the `feed_ndarray` signature and move the data allocation there as well (setting the shape and placement does that). See the API.
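For frameworks whose destination objects do expose an element type, the added check boils down to comparing dtypes before the copy. A minimal numpy-only sketch of the idea (`feed_array` and both arrays are illustrative, not DALI API):

```python
import numpy as np

# Illustrative sketch (not the DALI API): validate the element type before
# copying, so a mismatch fails loudly instead of silently reinterpreting
# the destination buffer.
def feed_array(src, dst):
    assert src.dtype == dst.dtype, (
        "Type of source array doesn't match destination type: "
        "{} vs {}".format(src.dtype, dst.dtype))
    np.copyto(dst, src)

src = np.arange(4, dtype=np.float32)
dst = np.empty(4, dtype=np.float32)
feed_array(src, dst)          # same dtype: copy succeeds

bad_dst = np.empty(4, dtype=np.int8)
try:
    feed_array(src, bad_dst)  # dtype mismatch: AssertionError
except AssertionError as e:
    print(e)
```

With a raw pointer (the PaddlePaddle path) there is no dtype to inspect on the destination, which is why this validation cannot be applied there.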
```python
pipe.build()
out = pipe.run()[0]
torch_tensor = torch.empty((1), dtype=torch.int8, device='cpu')
assert_raises(AssertionError, feed_ndarray, out, torch_tensor,
              glob="Type of DALI Tensor/TensorList doesn't match Torch tensor type:*")
```
Just a nitpick, error message checking looks for any occurrence of the pattern, so stars are not necessary at the beginning and end.
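The behavior described here (matching any occurrence of the pattern, so leading/trailing stars are redundant) can be sketched with a hypothetical minimal re-implementation of such a helper; `assert_raises_glob` and `fail` are illustrative names, not the actual nose_utils code:

```python
import fnmatch

# Hypothetical minimal version of an assert_raises helper with glob-based
# message checking: implicit '*' on both ends means explicit stars at the
# beginning and end of the pattern are unnecessary.
def assert_raises_glob(exc_type, func, *args, glob="", **kwargs):
    try:
        func(*args, **kwargs)
    except exc_type as e:
        assert fnmatch.fnmatch(str(e), "*" + glob + "*"), (
            "message {!r} doesn't match pattern {!r}".format(str(e), glob))
        return
    raise AssertionError("{} not raised".format(exc_type.__name__))

def fail():
    raise AssertionError("Type of DALI Tensor/TensorList doesn't match "
                         "Torch tensor type: int8 vs float32")

# Matches even without stars at the beginning or end of the pattern:
assert_raises_glob(AssertionError, fail, glob="doesn't match Torch tensor type")
```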
Done
!build

CI MESSAGE: [2909038]: BUILD STARTED
CI MESSAGE: [2909038]: BUILD FAILED
CI MESSAGE: [2909038]: BUILD PASSED
```diff
@@ -344,7 +345,8 @@ def check_mxnet_iterator_pass_reader_name(shards_num, pipes_number, batch_size,
     if batch_size > data_set_size // shards_num and last_batch_policy == LastBatchPolicy.DROP:
         assert_raises(AssertionError, MXNetIterator, pipes, [
-            ("ids", MXNetIterator.DATA_TAG)], reader_name="Reader", last_batch_policy=last_batch_policy)
+            ("ids", MXNetIterator.DATA_TAG)], reader_name="Reader", last_batch_policy=last_batch_policy,
+            glob="It seems that there is no data in the pipeline. This may happen if `last_batch_policy` is set to PARTIAL and the requested batch size is greater than the shard size.")
```
I don't think we should put verbatim copies of the entire message here - just enough to make sure it's the error we expect - like:
```diff
-            glob="It seems that there is no data in the pipeline. This may happen if `last_batch_policy` is set to PARTIAL and the requested batch size is greater than the shard size.")
+            glob="It seems that there is no data in the pipeline*last_batch_policy*")
```
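Assuming the helper uses glob-style matching, a quick check with Python's `fnmatch` confirms the shortened pattern still identifies the full error message:

```python
import fnmatch

# The full error message raised by the iterator (copied from the test above).
msg = ("It seems that there is no data in the pipeline. This may happen if "
       "`last_batch_policy` is set to PARTIAL and the requested batch size "
       "is greater than the shard size.")

# The shortened pattern from the review suggestion still matches it.
short = "It seems that there is no data in the pipeline*last_batch_policy*"
assert fnmatch.fnmatchcase(msg, short)
```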
Fixed
```python
assert dali_type == arr.dtype, ("Type of DALI Tensor/TensorList"
    " doesn't match MXNet tensor type: {} vs {}".format(dali_type, np.dtype(arr.dtype)))
```
```diff
-assert dali_type == arr.dtype, ("Type of DALI Tensor/TensorList"
-    " doesn't match MXNet tensor type: {} vs {}".format(dali_type, np.dtype(arr.dtype)))
+assert dali_type == arr.dtype, ("The element type of DALI Tensor/TensorList"
+    " doesn't match the element type of the target MXNet NDArray: {} vs {}".format(dali_type, np.dtype(arr.dtype)))
```
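The bare `dali_type == arr.dtype` comparison works here because numpy dtype equality is permissive; a small numpy-only illustration (the `arr_dtype` value stands in for what an MXNet NDArray would report):

```python
import numpy as np

# A np.dtype instance compares equal to the corresponding type class or
# string, so no extra conversion is needed before the assertion.
assert np.dtype("float32") == np.float32
assert np.dtype("float32") == "float32"
assert not (np.dtype("float32") == np.int8)

dali_type = np.dtype(np.int8)
arr_dtype = np.float32   # e.g. what mxnet.nd.empty([1]).dtype would report
assert dali_type != arr_dtype   # the mismatch the new check catches
```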
Fixed
```python
assert to_torch_type[dali_type] == arr.dtype, ("Type of DALI Tensor/TensorList"
    " doesn't match Torch tensor type: {} vs {}".format(to_torch_type[dali_type], arr.dtype))
```
```diff
-assert to_torch_type[dali_type] == arr.dtype, ("Type of DALI Tensor/TensorList"
-    " doesn't match Torch tensor type: {} vs {}".format(to_torch_type[dali_type], arr.dtype))
+assert to_torch_type[dali_type] == arr.dtype, ("The element type of DALI Tensor/TensorList"
+    " doesn't match the element type of the target PyTorch Tensor: {} vs {}".format(to_torch_type[dali_type], arr.dtype))
```
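The lookup table translates the DALI element type into the framework's dtype before comparing. A hypothetical stand-in for that table, with torch dtype names represented as strings so the sketch runs without torch installed:

```python
import numpy as np

# Hypothetical stand-in for the to_torch_type lookup: numpy dtypes stand in
# for DALI type enums, strings for torch dtypes.
to_torch_type = {
    np.dtype(np.float32): "torch.float32",
    np.dtype(np.int8):    "torch.int8",
    np.dtype(np.uint8):   "torch.uint8",
}

dali_type = np.dtype(np.int8)   # element type of the pipeline output
arr_dtype = "torch.float32"     # element type of the destination tensor
try:
    assert to_torch_type[dali_type] == arr_dtype, (
        "The element type of DALI Tensor/TensorList doesn't match the element"
        " type of the target PyTorch Tensor: {} vs {}".format(
            to_torch_type[dali_type], arr_dtype))
except AssertionError as e:
    print(e)
```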
Fixed
```python
pipe.build()
out = pipe.run()[0]
mxnet_tensor = mxnet.nd.empty([1], None, np.int8)
assert_raises(AssertionError, feed_ndarray, out, mxnet_tensor,
              glob="Type of DALI Tensor/TensorList doesn't match MXNet tensor type:")
```
Update the pattern here if you update the message as suggested.
Adjusted
Signed-off-by: Janusz Lisiecki <jlisiecki@nvidia.com>

Force-pushed from 8a5dfef to 2cacd47
!build

CI MESSAGE: [2925593]: BUILD STARTED
CI MESSAGE: [2925593]: BUILD PASSED
Description
What happened in this PR
- Adds type validation between DALI Tensor/TensorList and MXNet/PyTorch tensors inside `feed_ndarray`, in case anyone wants to use `feed_ndarray` directly.
- Doesn't cover PaddlePaddle, as `feed_ndarray` accepts a raw pointer to the PaddlePaddle tensor and there is no API to check the type.
- Updates the usage of `raises`/`assert_raises` in test_fw_iterators.py to use the implementation from nose_utils.
Additional information
Checklist
Tests
Documentation
DALI team only
Requirements
REQ IDs: N/A
JIRA TASK: N/A