[split_dataset] migrating from tf.keras to keras_core #505
Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA). View this failed invocation of the CLA check for more information. For the most up to date status, view the checks section at the bottom of the pull request.
Thanks for the PR! Please also bring in the unit tests and add tests for torch datasets.
QQ: rather than special-casing torch datasets, could we just support any object that implements `__len__` and `__getitem__`?
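To make the question concrete, here is a minimal sketch of what such a duck-typed check could look like. The names `is_indexable_dataset`, `to_list`, and `ToyDataset` are hypothetical and are not part of keras_core; this is one possible reading of the suggestion, not the code that was merged.

```python
def is_indexable_dataset(obj):
    """Return True if `obj` defines both `__len__` and `__getitem__`."""
    cls = type(obj)
    return hasattr(cls, "__len__") and hasattr(cls, "__getitem__")


class ToyDataset:
    """Hypothetical stand-in for a map-style dataset."""

    def __init__(self, samples):
        self._samples = list(samples)

    def __len__(self):
        return len(self._samples)

    def __getitem__(self, index):
        return self._samples[index]


def to_list(dataset):
    """Materialize an indexable dataset into a plain list of samples."""
    if not is_indexable_dataset(dataset):
        raise TypeError(
            "Expected an object implementing `__len__` and `__getitem__`, "
            f"got {type(dataset).__name__}."
        )
    # Index by position, which is what map-style datasets support.
    return [dataset[i] for i in range(len(dataset))]
```

One caveat with this approach: `dict` and `str` also define both methods, so a real implementation would probably need to exclude them explicitly.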
keras_core/utils/dataset_utils.py
Outdated
@@ -1,3 +1,11 @@
import tensorflow as tf
import torch
from torch.utils.data import Dataset as torchDataset
Only import torch, access Dataset from there
@fchollet Sure, but I would argue that this specific import is less error-prone.
Added changes for this.
keras_core/utils/dataset_utils.py
Outdated

Args:
dataset: A `tf.data.Dataset` object, or a list/tuple of arrays with the
same length.
same length.
Use 4 space indent.
@fchollet Yes, I will apply black at a later stage.
Done
keras_core/utils/dataset_utils.py
Outdated
@@ -30,20 +36,440 @@ def split_dataset(
Example:

>>> data = np.random.random(size=(1000, 4))
>>> left_ds, right_ds = split_dataset(data, left_size=0.8)
>>> left_ds, right_ds = tf.keras.utils.split_dataset(data, left_size=0.8)
No TF references
@fchollet I will add a separate commit for the docstring.
Done
    start_time,
):
    if dataset_type_spec in [tuple, list]:
        # The try-except here is for NumPy 1.24 compatibility, see:
Is this actually needed?
@fchollet Yes: since we allow the dataset type spec to be a tuple, this will be required. I'm not sure about list; I will test it out.
@fchollet Hmm 🤔, I need to think about it. I will add further details to this comment.
keras_core/utils/dataset_utils.py
Outdated
@@ -1,3 +1,11 @@
import tensorflow as tf
Another thing -- we should not import tf or torch at the top of the file since that would make them required dependencies. They should be imported when needed (e.g. only import torch when you need to process a torch dataset).
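As an illustration of the point above, an optional dependency can be imported inside the code path that needs it, so that importing the module itself never requires it. The helper names `import_optional` and `describe` below are hypothetical and exist only for this sketch; they are not keras_core functions.

```python
import importlib


def import_optional(module_name):
    """Import `module_name` on demand, with a clearer error if missing."""
    try:
        return importlib.import_module(module_name)
    except ImportError as exc:
        raise ImportError(
            f"`{module_name}` is required for this code path "
            "but is not installed."
        ) from exc


def describe(dataset, kind):
    """Describe a dataset, touching torch only for torch inputs."""
    if kind == "torch":
        # Deferred import: only users who pass torch data need torch.
        torch = import_optional("torch")
        return (
            f"torch {torch.__version__} dataset "
            f"with {len(dataset)} samples"
        )
    return f"plain sequence with {len(dataset)} samples"
```

With this layout, calling `describe([1, 2, 3], "list")` works on a machine without torch installed, and the import cost is paid only on the torch branch.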
Hmm, OK, but since they are used in multiple functions, importing them there would mean repeating the import statement in each one.
Done for torch; tf is now also imported from `module_utils`.
keras_core/utils/dataset_utils.py
Outdated
@@ -1,3 +1,8 @@
import torch
This should not be imported here. Torch is not a dependency of the package.
Done
keras_core/utils/dataset_utils.py
Outdated
If integer, it signifies the number of samples to pack
in the left dataset. If `None`, it defaults to the complement
to `right_size`.
the fraction of the data to pack in the left dataset. If integer, it
Use 4 space indent
Done
The indent is still 2 spaces.
You shouldn't actually need to modify the docstring at all, mind you. The original docstring was fine.
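For readers following the docstring under discussion, the `left_size` / `right_size` semantics it describes (a float is a fraction of the data, an integer is a sample count, `None` is the complement of the other side) can be sketched as below. `resolve_split_sizes` is a hypothetical name, and rounding fractions down is an assumption of this sketch; the PR's own `_rescale_dataset_split_sizes` may behave differently.

```python
def resolve_split_sizes(total, left_size=None, right_size=None):
    """Turn fractional or absolute split sizes into (left, right) counts."""
    if left_size is None and right_size is None:
        raise ValueError(
            "Specify at least one of `left_size` or `right_size`."
        )

    def to_count(size):
        if size is None:
            return None
        if isinstance(size, float):
            if not 0.0 <= size <= 1.0:
                raise ValueError("A float size must be in [0, 1].")
            # Assumption of this sketch: fractions are rounded down.
            return int(size * total)
        if isinstance(size, int):
            if not 0 <= size <= total:
                raise ValueError("An int size must be in [0, total].")
            return size
        raise TypeError("Sizes must be int, float, or None.")

    left, right = to_count(left_size), to_count(right_size)
    # `None` on one side means "the complement of the other side".
    if left is None:
        left = total - right
    elif right is None:
        right = total - left
    if left + right > total:
        raise ValueError("The two splits exceed the number of samples.")
    return left, right
```

For example, `resolve_split_sizes(100, left_size=30)` yields 30 samples on the left and the remaining 70 on the right.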

class DatasetUtilsTest(test_case.TestCase):
    def test_split_dataset_list(self):
        n_sample, n_cols, n_pred, left_size, right_size = 100, 2, 1, 0.2, 0.8
Make sure to run `sh shell/format.sh` and keep lines under 80 chars.
Done
Thanks for the update!
keras_core/utils/dataset_utils.py
Outdated
right_size=right_size,
shuffle=shuffle,
seed=seed,
from torch.utils.data import Dataset as torchDataset
- Do not import torch unless the dataset passed is a torch dataset.
- Just `import torch` then access `torch.utils.data.Dataset`.
To check whether it's a torch dataset, you can do something similar to this https://github.com/keras-team/keras-core/blob/main/keras_core/trainers/epoch_iterator.py#L228
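One way to perform such a check without importing torch at all is to inspect the class hierarchy of the object by name. The sketch below is an assumption about what "something similar" could look like; `is_torch_dataset` here is written for illustration and is not copied from the linked file.

```python
def is_torch_dataset(obj):
    """Guess whether `obj` is a torch dataset without importing torch."""
    # Walk the method resolution order and match on class name and module,
    # so torch is never imported just to run an isinstance check.
    for parent in type(obj).__mro__:
        module = str(getattr(parent, "__module__", ""))
        if parent.__name__ == "Dataset" and module.startswith(
            "torch.utils.data"
        ):
            return True
    return False
```

Because the match is purely name-based, any class named `Dataset` that lives in a module whose name starts with `torch.utils.data` passes the check, which is the trade-off for avoiding the import.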
keras_core/utils/dataset_utils.py
Outdated
right_split = right_split.prefetch(tf.data.AUTOTUNE)
return left_split, right_split

elif dataset_type_spec == torchDataset:
On second thought, I don't think we should support torch datasets at all here, because the API is becoming completely inconsistent:
- pass numpy arrays, get back tf.data.Dataset
- pass tf.data.Dataset, get back tf.data.Dataset
- pass torch dataset, get back torch dataset
Let's just stick to always returning a tf.data.Dataset IMO.
@fchollet But if that's the case, how can this function be used in a torch workflow, since it would need to return a dataset that is compatible with torch?
@fchollet Since returning a torch dataset is confined to the torch backend and will not impact the others, it should be safe; going forward, we would need to add the same for jax too. What do you think? 🧐 Also, if API consistency is the issue, one possible solution is to return only numpy arrays and leave it to the user to convert them as needed. That way we would support every framework, but I'm not sure how well this would work.
`tf.data.Dataset`s are supported by Keras models with all backends, so it's fine IMO. It's also the only way to stay backwards compatible with tf.keras, which is very important.
@fchollet Ah, got it. Then let's only use `tf.data.Dataset`s.
@fchollet Please review now; the following changes have been made:
- `torch.utils.data.Dataset` is only imported once.
- The return type is limited to `tf.data.Dataset`s.
- Fixed indentation.
@fchollet I need a clarification from you: I have moved the logic for detecting torch tensors and torch datasets to `is_torch_tensor` and `is_torch_dataset`. Since `is_torch_dataloader` is under trainers, should I move these functions there or keep them here?
Apologies, I missed it; it is resolved now.
@fchollet Could you approve the test workflow? All the changes are done on my end.
Thanks for the updates! There are various docstring issues left but I'll take it from here.
Hi Team,
This PR makes changes to the following functions: `keras` data-utils for `split_dataset` from here, making sure all the settings stay intact. `tf.Dataset`, and we need to support `torch` and `Jax`.
- `split_dataset`
- `_convert_dataset_to_list`
- `_get_data_iterator_from_dataset`
- `_get_next_sample`
- `_get_type_spec`
- `_rescale_dataset_split_sizes`
- `_restore_dataset_from_list`