Skip to content

Commit

Permalink
Rework Flax and Paxml training tutorials (#5205)
Browse files Browse the repository at this point in the history
Description:
Rework training examples for Flax and Paxml to use data_iterator.

Additional information:
Flax and Paxml training examples were adjusted to the new API with data_iterator.

Signed-off-by: Albert Wolant <awolant@nvidia.com>
  • Loading branch information
awolant committed Dec 1, 2023
1 parent 6d45431 commit ac92c6f
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 172 deletions.
199 changes: 50 additions & 149 deletions docs/examples/frameworks/jax/flax-basic_example.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,14 @@
"\n",
"This simple example shows how to train a neural network implemented in Flax with DALI pipelines. If you want to learn more about training neural networks with Flax, look into [Flax Getting Started](https://flax.readthedocs.io/en/latest/getting_started.html) example.\n",
"\n",
"DALI setup is very similar to the [training example with pure JAX](jax-basic_example.ipynb). The only difference is the addition of a trailing dimension to the returned image to make it compatible with Flax convolutions. If you are familiar with how to use DALI with JAX you can skip this part and move to the training section of this notebook.\n",
"DALI setup is very similar to the [training example with pure JAX](jax-basic_example.ipynb). The only difference is the addition of a trailing dimension to the returned image to make it compatible with Flax convolutions. If you are not familiar with how to use DALI with JAX you can learn more in the [DALI and JAX Getting Started](jax-getting_started.ipynb) example.\n",
"\n",
"We will use MNIST in Caffe2 format from [DALI_extra](https://github.com/NVIDIA/DALI_extra)."
"We use MNIST in Caffe2 format from [DALI_extra](https://github.com/NVIDIA/DALI_extra)."
]
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 1,
"metadata": {
"execution": {
"iopub.execute_input": "2023-07-28T07:43:41.850101Z",
Expand All @@ -38,14 +38,14 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"First step is to create a pipeline definition function that will later be used to create instances of DALI pipelines. It defines all steps of the preprocessing. In this simple example we have `fn.readers.caffe2` for reading data in Caffe2 format, `fn.decoders.image` for image decoding, `fn.crop_mirror_normalize` used to normalize the images and `fn.reshape` to adjust the shape of the output tensors. We also move the labels from the CPU to the GPU memory with `labels.gpu()` and apply one hot encoding to them for training with `fn.one_hot`.\n",
"First step is to create an iterator definition function that will later be used to create instances of DALI iterators. It defines all steps of the preprocessing. In this simple example we have `fn.readers.caffe2` for reading data in Caffe2 format, `fn.decoders.image` for image decoding, `fn.crop_mirror_normalize` used to normalize the images and `fn.reshape` to adjust the shape of the output tensors. We also move the labels from the CPU to the GPU memory with `labels.gpu()` and apply one hot encoding to them for training with `fn.one_hot`.\n",
"\n",
"This example focuses on how to use DALI pipeline with JAX. For more information on DALI pipeline look into [Getting started](../../getting_started.ipynb) and [pipeline documentation](../../../pipeline.rst)"
"This example focuses on how to use DALI pipeline with JAX. For more information on DALI iterator look into [DALI and JAX getting started](jax-getting_started.ipynb) and [pipeline documentation](../../../pipeline.rst)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 2,
"metadata": {
"execution": {
"iopub.execute_input": "2023-07-28T07:43:41.855441Z",
Expand All @@ -56,21 +56,22 @@
},
"outputs": [],
"source": [
"from nvidia.dali import pipeline_def\n",
"import nvidia.dali.fn as fn\n",
"import nvidia.dali.types as types\n",
"\n",
"from nvidia.dali.plugin.jax import data_iterator\n",
"\n",
"batch_size = 200\n",
"\n",
"batch_size = 50\n",
"image_size = 28\n",
"num_classes = 10\n",
"\n",
"\n",
"@pipeline_def(device_id=0, batch_size=batch_size, num_threads=4, seed=0)\n",
"def mnist_pipeline(data_path, random_shuffle):\n",
"@data_iterator(output_map=[\"images\", \"labels\"], reader_name=\"mnist_caffe2_reader\")\n",
"def mnist_iterator(data_path, is_training):\n",
" jpegs, labels = fn.readers.caffe2(\n",
" path=data_path,\n",
" random_shuffle=random_shuffle,\n",
" random_shuffle=is_training,\n",
" name=\"mnist_caffe2_reader\")\n",
" images = fn.decoders.image(\n",
" jpegs, device='mixed', output_type=types.GRAY)\n",
Expand All @@ -80,73 +81,22 @@
"\n",
" labels = labels.gpu()\n",
"\n",
" if random_shuffle:\n",
" if is_training:\n",
" labels = fn.one_hot(labels, num_classes=num_classes)\n",
"\n",
" return images, labels"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Next step is to instantiate DALI pipelines and build them. Building creates and initializes pipeline internals."
"With the iterator definition function we can now create DALI iterators."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"execution": {
"iopub.execute_input": "2023-07-28T07:43:41.989183Z",
"iopub.status.busy": "2023-07-28T07:43:41.988964Z",
"iopub.status.idle": "2023-07-28T07:43:42.104446Z",
"shell.execute_reply": "2023-07-28T07:43:42.103668Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Building pipelines\n",
"<nvidia.dali.pipeline.Pipeline object at 0x7fb455793940>\n",
"<nvidia.dali.pipeline.Pipeline object at 0x7fb455791390>\n"
]
}
],
"source": [
"training_pipeline = mnist_pipeline(data_path=training_data_path, random_shuffle=True)\n",
"validation_pipeline = mnist_pipeline(data_path=validation_data_path, random_shuffle=False)\n",
"\n",
"print('Building pipelines')\n",
"training_pipeline.build()\n",
"validation_pipeline.build()\n",
"\n",
"print(training_pipeline)\n",
"print(validation_pipeline)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"DALI pipeline needs to be wrapped with appropriate DALI iterator to work with JAX. To get the iterator compatible with JAX we need to import it from DALI JAX plugin. In addition to the DALI pipeline object we can pass the `output_map`, `reader_name` and `auto_reset` parameters to the iterator. \n",
"\n",
"**Here is a quick explnation of how these parameters work:**\n",
"\n",
" - `output_map`: iterators return a dictionary with outputs of the pipeline as its values. Keys in this dictionary are defined by `output_map`. For example, `labels` output returned from the DALI pipeline defined above will be accessible as `iterator_output['labels']`,\n",
" - `reader_name`: setting this parameter introduces the notion of an epoch to our iterator. DALI pipeline itself is infinite, it will return the data indefinately, wrapping around the dataset. DALI readers (such as `fn.readers.caffe2` used in this example) have access to the information about the size of the dataset. If we want to pass this information to the iterator, we need to point to the operator that should be queried for the dataset size. We do it by naming the operator (note `name=\"mnist_caffe2_reader\"`) and passing the same name as the value for `reader_name` argument,\n",
" - `auto_reset`: this argument controls the behaviour of the iterator after the end of an epoch. If set to `True`, it will automatically reset the state of the iterator and prepare it to start the next epoch.\n",
"\n",
"If you want to know more about iterator arguments you can look into [JAX iterator documentation](https://docs.nvidia.com/deeplearning/dali/user-guide/docs/plugins/jax_plugin_api.html#nvidia.dali.plugin.jax.DALIGenericIterator)."
]
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 3,
"metadata": {
"execution": {
"iopub.execute_input": "2023-07-28T07:43:42.109692Z",
Expand All @@ -161,30 +111,21 @@
"output_type": "stream",
"text": [
"Creating iterators\n",
"<nvidia.dali.plugin.jax.DALIGenericIterator object at 0x7fb450301b70>\n",
"Number of batches in training iterator = 300\n",
"Number of batches in validation iterator = 50\n"
"<nvidia.dali.plugin.jax.iterator.DALIGenericIterator object at 0x7fdc240f4e50>\n",
"<nvidia.dali.plugin.jax.iterator.DALIGenericIterator object at 0x7fdc1c78e020>\n",
"Number of batches in training iterator = 1200\n",
"Number of batches in validation iterator = 200\n"
]
}
],
"source": [
"from nvidia.dali.plugin import jax as dax\n",
"\n",
"\n",
"print('Creating iterators')\n",
"training_iterator = dax.DALIGenericIterator(\n",
" training_pipeline,\n",
" output_map=[\"images\", \"labels\"],\n",
" reader_name=\"mnist_caffe2_reader\",\n",
" auto_reset=True)\n",
"\n",
"validation_iterator = dax.DALIGenericIterator(\n",
" validation_pipeline,\n",
" output_map=[\"images\", \"labels\"],\n",
" reader_name=\"mnist_caffe2_reader\",\n",
" auto_reset=True)\n",
"training_iterator = mnist_iterator(data_path=training_data_path, is_training=True, batch_size=batch_size)\n",
"validation_iterator = mnist_iterator(data_path=validation_data_path, is_training=False, batch_size=batch_size)\n",
"\n",
"print(training_iterator)\n",
"print(validation_iterator)\n",
"\n",
"print(f\"Number of batches in training iterator = {len(training_iterator)}\")\n",
"print(f\"Number of batches in validation iterator = {len(validation_iterator)}\")"
]
Expand All @@ -203,7 +144,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 4,
"metadata": {
"execution": {
"iopub.execute_input": "2023-07-28T07:43:43.559575Z",
Expand Down Expand Up @@ -276,7 +217,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 5,
"metadata": {
"execution": {
"iopub.execute_input": "2023-07-28T07:43:43.622376Z",
Expand Down Expand Up @@ -305,7 +246,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 6,
"metadata": {},
"outputs": [
{
Expand All @@ -314,15 +255,15 @@
"text": [
"Starting training\n",
"Epoch 0\n",
"Accuracy = 0.9637000560760498\n",
"Accuracy = 0.9551000595092773\n",
"Epoch 1\n",
"Accuracy = 0.9690000414848328\n",
"Accuracy = 0.9691000580787659\n",
"Epoch 2\n",
"Accuracy = 0.975100040435791\n",
"Accuracy = 0.9738000631332397\n",
"Epoch 3\n",
"Accuracy = 0.9761000275611877\n",
"Accuracy = 0.9622000455856323\n",
"Epoch 4\n",
"Accuracy = 0.9765000343322754\n"
"Accuracy = 0.9604000449180603\n"
]
}
],
Expand All @@ -347,33 +288,31 @@
"\n",
"This section shows how to extend the example above to use multiple GPUs.\n",
"\n",
"Again, we start with creating a pipeline definition function. The pipeline was slightly modified to support multiple GPUs.\n",
"Again, we start with creating an iterator definition function. It is a slightly modified version of the function we have seen before.\n",
"\n",
"Note the new arguments passed to `fn.readers.caffe2`: `num_shards` and `shard_id`. They are used to control sharding:\n",
"Note the new arguments passed to `fn.readers.caffe2`, `num_shards` and `shard_id`. They are used to control sharding:\n",
" - `num_shards` sets the total number of shards\n",
" - `shard_id` tells the pipeline for which shard in the training it is responsible. \n",
"\n",
" Also, the `device_id` argument was removed from the decorator. Since we want these pipelines to run on different GPUs we will pass particular `device_id` in pipeline creation. Most often, `device_id` and `shard_id` will have the same value but it is not a requirement. In this example we want the total batch size to be the same as in the single GPU version. That is why we define `batch_size_per_gpu` as `batch_size // jax.device_count()`. Note, that if `batch_size` is not divisible by the number of devices this might require some adjustment to make sure all samples are used in every epoch of the training.\n",
" If you want to learn more about DALI sharding behaviour look into [DALI sharding docs page](../../general/multigpu.ipynb)."
"We add `devices` argument to the decorator to specify which devices we want to use. Here we use all GPUs available to JAX on the machine."
]
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"batch_size = 200\n",
"image_size = 28\n",
"num_classes = 10\n",
"batch_size_per_gpu = batch_size // jax.device_count()\n",
"\n",
"\n",
"@pipeline_def(batch_size=batch_size_per_gpu, num_threads=4, seed=0)\n",
"def mnist_sharded_pipeline(data_path, random_shuffle, num_shards, shard_id):\n",
"@data_iterator(output_map=[\"images\", \"labels\"], reader_name=\"mnist_caffe2_reader\", devices=jax.devices())\n",
"def mnist_sharded_iterator(data_path, is_training, num_shards, shard_id):\n",
" jpegs, labels = fn.readers.caffe2(\n",
" path=data_path,\n",
" random_shuffle=random_shuffle,\n",
" random_shuffle=is_training,\n",
" name=\"mnist_caffe2_reader\",\n",
" num_shards=num_shards,\n",
" shard_id=shard_id)\n",
Expand All @@ -385,56 +324,22 @@
"\n",
" labels = labels.gpu()\n",
" \n",
" if random_shuffle:\n",
" if is_training:\n",
" labels = fn.one_hot(labels, num_classes=num_classes)\n",
"\n",
" return images, labels\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Note the `device_id` values that are passed to place a pipeline on a different device."
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Pipeline <nvidia.dali.pipeline.Pipeline object at 0x7db28c2b2da0> working on device 0\n",
"Pipeline <nvidia.dali.pipeline.Pipeline object at 0x7db28c2b30d0> working on device 1\n"
]
}
],
"source": [
"pipelines = []\n",
"for id, device in enumerate(jax.devices()):\n",
" pipeline = mnist_sharded_pipeline(\n",
" data_path=training_data_path, random_shuffle=True, num_shards=jax.device_count(), shard_id=id, device_id=id)\n",
" print(f'Pipeline {pipeline} working on device {pipeline.device_id}')\n",
" pipelines.append(pipeline)\n",
" "
" return images, labels"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We created multiple DALI pipelines. Each will run its computations on a different GPU. Each of them will start the preprocessing from a different shard of the training dataset. In this configuration each pipeline will move to the next shard in the next epoch. If you want to control this you can look into `stick_to_shard` argument in the readers.\n",
"\n",
"Like in the single GPU example, we create training iterator. It will encapsulate all the pipelines that we created and return a dictionary of JAX arrays. With this simple configuration it will return arrays compatible with JAX `pmap`ed functions. Leaves of the returned dictionary will have shape `(num_devices, batch_per_device, ...)` and each slice across the first dimension of the array will reside on a different GPU."
"With the iterator definition function we can now create DALI iterators for training on multiple GPUs. This iterator will return outputs compatible with `pmapped` JAX functions. "
]
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 8,
"metadata": {},
"outputs": [
{
Expand All @@ -448,11 +353,7 @@
],
"source": [
"print('Creating training iterator')\n",
"training_iterator = dax.DALIGenericIterator(\n",
" pipelines,\n",
" output_map=[\"images\", \"labels\"],\n",
" reader_name=\"mnist_caffe2_reader\",\n",
" auto_reset=True)\n",
"training_iterator = mnist_sharded_iterator(data_path=training_data_path, is_training=True, batch_size=batch_size)\n",
"\n",
"print(f\"Number of batches in training iterator = {len(training_iterator)}\")"
]
Expand All @@ -469,7 +370,7 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -500,23 +401,23 @@
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 0\n",
"Accuracy = 0.9445000290870667\n",
"Accuracy = 0.9509000182151794\n",
"Epoch 1\n",
"Accuracy = 0.9641000628471375\n",
"Accuracy = 0.9643000364303589\n",
"Epoch 2\n",
"Accuracy = 0.9654000401496887\n",
"Epoch 3\n",
"Accuracy = 0.9724000692367554\n",
"Epoch 3\n",
"Accuracy = 0.9701000452041626\n",
"Epoch 4\n",
"Accuracy = 0.9760000705718994\n"
"Accuracy = 0.9758000373840332\n"
]
}
],
Expand Down
Loading

0 comments on commit ac92c6f

Please sign in to comment.