[split_dataset] migrating from tf.keras to keras_core #505

Merged: 16 commits merged into keras-team:main on Jul 28, 2023

Conversation

asingh9530 (Contributor)

Hi Team,

This PR makes the following changes:

  • Moves the Keras data-utils code for split_dataset from here, making sure all existing behavior stays intact.
  • Rewrites the following functions, since all of them accept a tf.data.Dataset and we need to support torch and JAX (a brief usage sketch follows this list):
    • split_dataset
    • _convert_dataset_to_list
    • _get_data_iterator_from_dataset
    • _get_next_sample
    • _get_type_spec
    • _rescale_dataset_split_sizes
    • _restore_dataset_from_list
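
For context, here is a minimal usage sketch of the migrated utility, assuming it ends up exposed as `keras_core.utils.split_dataset` with the same signature as `tf.keras.utils.split_dataset` (illustrative only, not part of the diff):

```python
import numpy as np

# Assumed import path after this PR; mirrors tf.keras.utils.split_dataset.
from keras_core.utils import split_dataset

features = np.random.random(size=(1000, 4))
labels = np.random.randint(0, 2, size=(1000, 1))

# left_size=0.8 puts 80% of the samples in the left split; the rest go right.
left_ds, right_ds = split_dataset((features, labels), left_size=0.8)

# Both returned objects are tf.data.Dataset instances.
print(left_ds.cardinality().numpy())   # 800
print(right_ds.cardinality().numpy())  # 200
```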

google-cla bot commented Jul 16, 2023

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

fchollet (Member) left a comment

Thanks for the PR! Please also bring in the unit tests and add tests for torch datasets.

QQ: rather than special-casing torch datasets, could we just support any object that implements `__len__` and `__getitem__`?
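
(A hypothetical sketch of that duck-typing idea, not code from the PR: treat anything exposing `__len__` and `__getitem__` as an indexable dataset instead of special-casing `torch.utils.data.Dataset`.)

```python
def _is_indexable_dataset(obj):
    # Hypothetical helper: duck-type rather than checking for torch classes.
    return hasattr(obj, "__len__") and hasattr(obj, "__getitem__")


def _iterate_indexable_dataset(obj):
    # Works for lists, numpy arrays, and torch map-style datasets alike.
    for i in range(len(obj)):
        yield obj[i]
```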

@@ -1,3 +1,11 @@
import tensorflow as tf
import torch
from torch.utils.data import Dataset as torchDataset
fchollet (Member)

Only import torch, access Dataset from there

asingh9530 (Contributor Author) commented Jul 16, 2023

@fchollet sure, but I'd argue that using this specific import is more error-proof when importing.

asingh9530 (Contributor Author)

Added changes for this.


Args:
dataset: A `tf.data.Dataset` object, or a list/tuple of arrays with the
same length.
same length.
fchollet (Member)

Use 4 space indent.

asingh9530 (Contributor Author)

@fchollet yes, will apply black in a later stage.

asingh9530 (Contributor Author)

Done

@@ -30,20 +36,440 @@ def split_dataset(
Example:

>>> data = np.random.random(size=(1000, 4))
>>> left_ds, right_ds = split_dataset(data, left_size=0.8)
>>> left_ds, right_ds = tf.keras.utils.split_dataset(data, left_size=0.8)
fchollet (Member)

No TF references

asingh9530 (Contributor Author)

@fchollet will add a separate commit for the docstring.

asingh9530 (Contributor Author)

Done

start_time,
):
if dataset_type_spec in [tuple, list]:
# The try-except here is for NumPy 1.24 compatibility, see:
fchollet (Member)

Is this actually needed?

asingh9530 (Contributor Author)

@fchollet yes, since we allow the data spec to be a tuple this will be required; not sure about list, will test it out.

asingh9530 (Contributor Author)

> Thanks for the PR! Please also bring in the unit tests and add tests for torch datasets.
>
> QQ: rather than special-casing torch datasets, could we just support any object that implements `__len__` and `__getitem__`?

@fchollet hmm 🤔 need to think about it; will add further details to this comment.

asingh9530 changed the title from "added keras utils -> keras-core utils" to "[split_dataset] migrating from tf.keras to keras_core" on Jul 17, 2023
@@ -1,3 +1,11 @@
import tensorflow as tf
fchollet (Member)

Another thing -- we should not import tf or torch at the top of the file since that would make them required dependencies. They should be imported when needed (e.g. only import torch when you need to process a torch dataset).
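
(A hedged sketch of the deferred-import pattern being asked for; the helper name is hypothetical and not taken from the PR.)

```python
def _torch_dataset_to_list(dataset):
    # Deferred import: torch is only required when a torch dataset is actually
    # passed, so it never becomes a hard dependency of the package.
    from torch.utils.data import Dataset as TorchDataset

    if not isinstance(dataset, TorchDataset):
        raise ValueError(
            f"Expected a `torch.utils.data.Dataset`, received: {type(dataset)}"
        )
    return [dataset[i] for i in range(len(dataset))]
```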

asingh9530 (Contributor Author)

umm ok, but since they are being used in multiple functions, importing them where needed would mean having multiple import statements.

asingh9530 (Contributor Author)

Done for torch; also importing tf from module_utils now.

@@ -1,3 +1,8 @@
import torch
fchollet (Member)

This should not be imported here. Torch is not a dependency of the package.

asingh9530 (Contributor Author)

Done

If integer, it signifies the number of samples to pack
in the left dataset. If `None`, it defaults to the complement
to `right_size`.
the fraction of the data to pack in the left dataset. If integer, it
fchollet (Member)

Use 4 space indent

asingh9530 (Contributor Author)

Done

fchollet (Member)

The indent is still 2 spaces.

You shouldn't actually need to modify the docstring at all, mind you. The original docstring was fine.


class DatasetUtilsTest(test_case.TestCase):
def test_split_dataset_list(self):
n_sample, n_cols, n_pred, left_size, right_size = 100, 2, 1, 0.2, 0.8
fchollet (Member)

Make sure to run sh shell/format.sh and keep lines under 80 chars

asingh9530 (Contributor Author)

Done

fchollet (Member) left a comment

Thanks for the update!

If integer, it signifies the number of samples to pack
in the left dataset. If `None`, it defaults to the complement
to `right_size`.
the fraction of the data to pack in the left dataset. If integer, it
fchollet (Member)

The indent is still 2 spaces.

You shouldn't actually need to modify the docstring at all, mind you. The original docstring was fine.

right_size=right_size,
shuffle=shuffle,
seed=seed,
from torch.utils.data import Dataset as torchDataset
fchollet (Member)

  • Do not import torch unless the dataset passed is a torch dataset.
  • Just import torch then access torch.utils.data.Dataset.

To check whether it's a torch dataset, you can do something similar to this: https://github.com/keras-team/keras-core/blob/main/keras_core/trainers/epoch_iterator.py#L228
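
(One hedged way to perform that check without importing torch at module level, similar in spirit to the linked epoch_iterator code; the exact implementation may differ.)

```python
def _is_torch_dataset(dataset):
    # Inspect the class hierarchy by name so torch never has to be imported
    # just to answer the question.
    if hasattr(dataset, "__class__"):
        for parent in dataset.__class__.__mro__:
            if parent.__name__ == "Dataset" and str(parent.__module__).startswith(
                "torch.utils.data"
            ):
                return True
    return False
```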

asingh9530 (Contributor Author) commented Jul 23, 2023

  • @fchollet a nested import like torch.utils.data.Dataset is not working, as the module never gets added to globals and hence throws an import error; this is the same issue I mentioned here, and it is still reproducible.
  • Sure, will change the logic for how we detect a torch dataset.

right_split = right_split.prefetch(tf.data.AUTOTUNE)
return left_split, right_split

elif dataset_type_spec == torchDataset:
fchollet (Member)

On second thought, I don't think we should support torch datasets at all here, because the API is becoming completely inconsistent:

  • pass numpy arrays, get back tf.data.Dataset
  • pass tf.data.Dataset, get back tf.data.Dataset
  • pass torch dataset, get back torch dataset

Let's just stick to always returning a tf.data.Dataset IMO.

asingh9530 (Contributor Author) commented Jul 23, 2023

@fchollet but if that's the case, how can this function be used in a torch workflow? It would need to return a dataset that is compatible with torch.

asingh9530 (Contributor Author) commented Jul 23, 2023

@fchollet and since returning a torch dataset is only constrained to the torch backend and does not impact the others, it should be safe; even going forward with JAX we would need to add support for that too. What do you think? 🧐 Also, if API consistency is the issue, one possible solution is to only return numpy arrays and leave it to the user to convert as needed; that way we would support every framework, but I'm not sure how well that would work.

fchollet (Member)

tf.data.Datasets are supported by Keras models with all backends, so it's fine IMO. It's also the only way to stay backwards compatible with tf.keras, which is very important.
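
(An illustrative end-to-end sketch of this point, assuming the keras_core API of the time: a `tf.data.Dataset` produced by `split_dataset` feeds `fit()` under any backend, including torch.)

```python
import os

os.environ["KERAS_BACKEND"] = "torch"  # could equally be "jax" or "tensorflow"

import numpy as np
import keras_core as keras

x = np.random.random((1000, 4)).astype("float32")
y = np.random.randint(0, 2, size=(1000, 1)).astype("float32")

# split_dataset returns two unbatched tf.data.Dataset objects.
train_ds, val_ds = keras.utils.split_dataset((x, y), left_size=0.8)

model = keras.Sequential([keras.layers.Dense(1, activation="sigmoid")])
model.compile(optimizer="adam", loss="binary_crossentropy")
model.fit(train_ds.batch(32), validation_data=val_ds.batch(32), epochs=1)
```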

asingh9530 (Contributor Author)

@fchollet aah got it, then let's only use tf.data.Datasets.

asingh9530 (Contributor Author)

@fchollet please review; the following changes have been made:

  • torch.utils.data.Dataset is only imported once.
  • the return type is limited to tf.data.Datasets.
  • fixed indentation.

@fchollet I need clarification from you: I have moved the logic for detecting a torch tensor and a torch dataset into is_torch_tensor and is_torch_dataset. Since is_torch_dataloader is under trainers, should I move these functions there as well, or keep them here?
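
(For reference, a hedged sketch of what an `is_torch_tensor` helper in the same class-name-inspection style could look like; the actual implementation in the PR may differ.)

```python
def is_torch_tensor(value):
    # Same MRO-inspection trick as the dataset check above, so torch is never
    # imported just to identify a tensor.
    if hasattr(value, "__class__"):
        for parent in value.__class__.__mro__:
            if parent.__name__ == "Tensor" and str(parent.__module__).endswith("torch"):
                return True
    return False
```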

asingh9530 (Contributor Author)

> The indent is still 2 spaces.
>
> You shouldn't actually need to modify the docstring at all, mind you. The original docstring was fine.

Apologies, I missed it; it is resolved now.

asingh9530 (Contributor Author)

@fchollet could you approve the test workflow? All the changes have been made from my end.

@asingh9530 asingh9530 requested a review from fchollet July 27, 2023 12:47
fchollet (Member) left a comment

Thanks for the updates! There are various docstring issues left but I'll take it from here.

@fchollet fchollet merged commit 8e96f94 into keras-team:main Jul 28, 2023
6 checks passed