Enable checkpointing in TensorFlow plugin (CPU only) #5334

Merged: 7 commits, Feb 27, 2024
@@ -0,0 +1,95 @@
from nvidia.dali import pipeline_def
import nvidia.dali.fn as fn
from nvidia.dali.types import DALIDataType
import nvidia.dali.plugin.tf as dali_tf
import tensorflow as tf
import numpy as np
import tempfile
from test_utils import get_dali_extra_path
import os
from nose_utils import assert_raises

data_root = get_dali_extra_path()
images_dir = os.path.join(data_root, "db", "single", "jpeg")


def check_dataset_checkpointing(dali_dataset, *, warmup_iters, test_iters):
it = iter(dali_dataset)
mgr = tf.train.Checkpoint(it)

def read_data(it, iters):
data = None
for _ in range(iters):
out = next(it)
if data is None:
data = [[] for _ in range(len(out))]
assert len(data) == len(out)
for i, x in enumerate(out):
data[i].append(np.asarray(x))
return data

def compare_data(data1, data2):
assert len(data1) == len(data2)
for output1, output2 in zip(data1, data2):
assert len(output1) == len(output2)
for x1, x2 in zip(output1, output2):
assert (x1 == x2).all()

read_data(it, warmup_iters)

with tempfile.TemporaryDirectory() as cpt_dir:
cpt = mgr.save(cpt_dir)
data = read_data(it, test_iters)
mgr.restore(cpt)

data_restored = read_data(it, test_iters)
compare_data(data, data_restored)


def check_pipeline_checkpointing(pipeline_factory, output_dtypes, **kwargs):
p = pipeline_factory()
p.build()
with tf.device("cpu"):
dataset = dali_tf.DALIDataset(pipeline=p, output_dtypes=output_dtypes)
check_dataset_checkpointing(dataset, **kwargs)


def test_random():
@pipeline_def(num_threads=4, device_id=0, batch_size=4, enable_checkpointing=True)
def pipeline():
return fn.random.uniform(dtype=DALIDataType.FLOAT)

check_pipeline_checkpointing(pipeline, (tf.float32,), warmup_iters=7, test_iters=10)


def test_reader():
@pipeline_def(num_threads=4, device_id=0, batch_size=4, enable_checkpointing=True)
def pipeline():
jpeg, label = fn.readers.file(
file_root=images_dir, pad_last_batch=False, random_shuffle=True
)
return (jpeg, label)

check_pipeline_checkpointing(pipeline, (tf.uint8, tf.int32), warmup_iters=7, test_iters=10)


def test_inputs_unsupported():
@pipeline_def(num_threads=4, device_id=0, batch_size=4, enable_checkpointing=True)
def external_source_pipe():
return fn.external_source(source=lambda x: np.array(x.iteration), batch=False)

p = external_source_pipe()
p.build()
with tf.device("cpu"):
dataset = dali_tf.experimental.DALIDatasetWithInputs(
pipeline=p,
output_dtypes=(tf.int64,),
batch_size=5,
output_shapes=(5,),
num_threads=4,
device_id=0,
)
with assert_raises(
Exception, regex="Checkpointing is not supported for DALI dataset with inputs."
):
check_dataset_checkpointing(dataset, warmup_iters=1, test_iters=1)
57 changes: 52 additions & 5 deletions dali_tf_plugin/dali_dataset_op.cc
@@ -1,4 +1,4 @@
// Copyright (c) 2017-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// Copyright (c) 2017-2022, 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -403,12 +403,45 @@ class DALIDatasetOp::Dataset::Iterator : public DatasetIterator<Dataset> {
}

#if TF_MAJOR_VERSION > 2 || (TF_MAJOR_VERSION == 2 && TF_MINOR_VERSION >= 3)
Status SaveInternal(SerializationContext *ctx, IteratorStateWriter *writer) override {
return errors::Unimplemented("SaveInternal is not supported for DALI dataset.");
Status SaveInternal(SerializationContext *ctx,
IteratorStateWriter *writer) override {
TF_RETURN_IF_ERROR(checkCheckpointingSupport());
tensorflow::mutex_lock l(mu_);

char *cpt;
size_t n;
daliExternalContextCheckpoint external_context;
TF_DALI_CALL(daliGetSerializedCheckpoint(
&pipeline_handle_, &external_context, &cpt, &n));

tensorflow::Tensor cpt_tensor(DT_UINT8, {n});
memcpy(cpt_tensor.data(), cpt, n);
free(cpt);
TF_RETURN_IF_ERROR(writer->WriteTensor(prefix(), "checkpoint", cpt_tensor));

return OkStatus();
}

Status RestoreInternal(IteratorContext *ctx, IteratorStateReader *reader) override {
return errors::Unimplemented("RestoreInternal is not supported for DALI dataset");
Status RestoreInternal(IteratorContext *ctx,
IteratorStateReader *reader) override {
TF_RETURN_IF_ERROR(checkCheckpointingSupport());
tensorflow::mutex_lock l(mu_);

tensorflow::Tensor cpt_tensor;
TF_RETURN_IF_ERROR(reader->ReadTensor(prefix(), "checkpoint", &cpt_tensor));
auto cpt_data = cpt_tensor.tensor_data();

TF_DALI_CALL(daliDeletePipeline(&pipeline_handle_));
TF_RETURN_IF_ERROR(dataset()->InitPipeline(&pipeline_handle_));
daliExternalContextCheckpoint external_context;
TF_DALI_CALL(daliRestoreFromSerializedCheckpoint(
&pipeline_handle_, cpt_data.data(), cpt_data.size(), &external_context));

// Checkpointing is not supported with separated queues, so we can just prefetch uniformly
TF_DALI_CALL(daliPrefetchUniform(&pipeline_handle_,
dataset()->pipeline_def_.prefetch_queue_depth));

return OkStatus();
}
#endif

@@ -872,6 +905,20 @@ class DALIDatasetOp::Dataset::Iterator : public DatasetIterator<Dataset> {
return 0;
}

/**
* @brief Check checkpointing support and return an error if it is not supported
*/
Status checkCheckpointingSupport() {
if (dataset()->device_type_ == GPU) {
// Current TensorFlow (2.15) will raise an error before even getting here
return errors::Unimplemented("Checkpointing is not supported for DALI GPU dataset.");
}
if (dataset()->HasInputs()) {
return errors::Unimplemented("Checkpointing is not supported for DALI dataset with inputs.");
}
return Status();
}

enum class InputState {
in_progress, // we can still use inputs, none have ended
stop_pending, // input signalled end, we stop reading them, some might be in pipeline
16 changes: 15 additions & 1 deletion docs/advanced_topics_checkpointing.rst
@@ -86,4 +86,18 @@ Other kinds of ``source`` don't support checkpointing.
Their state won't be saved in a checkpoint and
after restoring from a checkpoint, they will start from the beginning.
If you want to use checkpointing, we recommend you rewrite your source
to be a supported callable.
to be a supported callable.

Checkpointing in TensorFlow plugin
----------------------------------

:class:`plugin.tf.DALIDataset` is integrated with TensorFlow's ``tf.train.Checkpoint``.
Please refer to the
`TensorFlow checkpointing documentation <https://www.tensorflow.org/guide/checkpoint#manual_checkpointing>`_
for more details.

.. warning::
Checkpointing is currently not supported for :class:`plugin.tf.experimental.DALIDatasetWithInputs`.

.. warning::
Checkpointing is currently not supported for GPU datasets.
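
A minimal usage sketch, modeled on the test added in this PR (the pipeline definition
and the checkpoint path below are illustrative)::

    import tensorflow as tf
    import nvidia.dali.plugin.tf as dali_tf
    import nvidia.dali.fn as fn
    from nvidia.dali import pipeline_def
    from nvidia.dali.types import DALIDataType

    @pipeline_def(batch_size=4, num_threads=4, device_id=0, enable_checkpointing=True)
    def pipeline():
        return fn.random.uniform(dtype=DALIDataType.FLOAT)

    p = pipeline()
    p.build()
    with tf.device("cpu"):
        dataset = dali_tf.DALIDataset(pipeline=p, output_dtypes=(tf.float32,))

    it = iter(dataset)
    ckpt = tf.train.Checkpoint(it)   # track the dataset iterator

    for _ in range(10):              # run some warm-up iterations
        next(it)

    path = ckpt.save("/tmp/dali_tf_ckpt")   # save the iterator state
    # ... more iterations ...
    ckpt.restore(path)                       # resume from the saved state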
3 changes: 2 additions & 1 deletion qa/TL0_python-self-test-core/test_body.sh
@@ -16,7 +16,8 @@ test_py_with_framework() {
${python_invoke_test} --attr '!slow,!pytorch,!mxnet,!cupy' ${test_script}
done
${python_new_invoke_test} -A 'numba' -s type_annotations
${python_new_invoke_test} -A '!slow,numba' -s checkpointing
${python_new_invoke_test} -A '!slow,numba' checkpointing.test_dali_checkpointing
${python_new_invoke_test} -A '!slow,numba' checkpointing.test_dali_stateless_operators
}

test_py() {
1 change: 1 addition & 0 deletions qa/TL0_tensorflow_plugin/test.sh
@@ -31,6 +31,7 @@ test_body() {
# DALI TF DATASET run
${python_invoke_test} test_dali_tf_dataset.py
${python_invoke_test} test_dali_tf_conditionals.py
${python_new_invoke_test} checkpointing.test_dali_checkpointing_tf_plugin
if [ -z "$DALI_ENABLE_SANITIZERS" ]; then
${python_invoke_test} test_dali_tf_dataset_shape.py
${python_invoke_test} test_dali_tf_dataset_eager.py