Generalize iterator checkpointing tests #5278

Merged (10 commits) on Jan 19, 2024
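Note on the change, inferred from the diff below: the PyTorch-specific checkpointing helpers and tests (check_pipeline_checkpointing_pytorch, check_single_input_operator_pytorch, the @attr("pytorch") test variants, and the external-source counterparts) are removed from test_dali_checkpointing.py, presumably so that iterator checkpointing is exercised once, generically, across framework plugins. As a rough illustration only, a generalized helper could look like the sketch below. The helper name and the make_iterator parameter are assumptions for illustration, not the PR's actual API; the body simply mirrors the removed check_pipeline_checkpointing_pytorch, relies on the test module's existing pipeline_args and warmup_epochs globals, and assumes PyTorch-style dict outputs for the comparison.

def check_pipeline_checkpointing_iterator(pipeline_factory, make_iterator, reader_name=None, size=-1):
    # Build the original pipeline and run a few warm-up epochs through the iterator.
    pipe = pipeline_factory(**pipeline_args)
    pipe.build()
    it = make_iterator(pipe, ["data"], auto_reset=True, reader_name=reader_name, size=size)
    for _ in range(warmup_epochs):
        for _ in it:
            pass

    # Restore a second pipeline from the iterator's checkpoint and verify that
    # both iterators produce identical outputs from this point on.
    restored = pipeline_factory(**pipeline_args, checkpoint=it.checkpoints()[0])
    restored.build()
    it2 = make_iterator(restored, ["data"], auto_reset=True, reader_name=reader_name, size=size)
    for out1, out2 in zip(it, it2):
        for d1, d2 in zip(out1, out2):
            for key in d1.keys():
                assert (d1[key] == d2[key]).all()
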
161 changes: 1 addition & 160 deletions dali/test/python/checkpointing/test_dali_checkpointing.py
@@ -1,4 +1,4 @@
# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -29,7 +29,6 @@
)
from nose_utils import assert_warns
from nose2.tools import params, cartesian_params
from nose.plugins.attrib import attr
from dataclasses import dataclass
from nvidia.dali import tfrecord as tfrec
from nvidia.dali.auto_aug import auto_augment as aa
@@ -83,29 +82,6 @@ def check_pipeline_checkpointing_native(pipeline_factory):
compare_pipelines(pipe, restored, pipeline_args["batch_size"], comparsion_iterations)


def check_pipeline_checkpointing_pytorch(pipeline_factory, reader_name=None, size=-1):
from nvidia.dali.plugin.pytorch import DALIGenericIterator

pipe = pipeline_factory(**pipeline_args)
pipe.build()

iter = DALIGenericIterator(pipe, ["data"], auto_reset=True, reader_name=reader_name, size=size)
for _ in range(warmup_epochs):
for _ in iter:
pass

restored = pipeline_factory(**pipeline_args, checkpoint=iter.checkpoints()[0])
restored.build()
iter2 = DALIGenericIterator(
restored, ["data"], auto_reset=True, reader_name=reader_name, size=size
)

for out1, out2 in zip(iter, iter2):
for d1, d2 in zip(out1, out2):
for key in d1.keys():
assert (d1[key] == d2[key]).all()


def check_single_input_operator_pipeline(op, device, **kwargs):
@pipeline_def
def pipeline():
@@ -126,11 +102,6 @@ def check_single_input_operator(op, device, **kwargs):
check_pipeline_checkpointing_native(pipeline_factory)


def check_single_input_operator_pytorch(op, device, **kwargs):
pipeline_factory = check_single_input_operator_pipeline(op, device, **kwargs)
check_pipeline_checkpointing_pytorch(pipeline_factory, reader_name="Reader")


def check_no_input_operator(op, device, **kwargs):
@pipeline_def
def pipeline_factory():
@@ -139,14 +110,6 @@ def pipeline_factory():
check_pipeline_checkpointing_native(pipeline_factory)


def check_no_input_operator_pytorch(op, device, **kwargs):
@pipeline_def
def pipeline_factory():
return op(device=device, **kwargs)

check_pipeline_checkpointing_pytorch(pipeline_factory, size=8)


# Readers section


@@ -668,68 +631,6 @@ def test_numpy_reader(
)


@attr("pytorch")
@params(
(1, 3, 0, 1, True, False, False),
(5, 10, 0, 2, True, False, False),
(3, 64, 3, 4, False, False, False),
(0, 32, 1, 4, False, False, True),
(3, 64, 3, 4, False, False, True),
(1, 8, 0, 2, False, True, False),
(1, 8, 1, 2, False, True, False),
(1, 8, 3, 4, False, True, False),
(1, 3, 0, 1, True, False, False, 1),
(5, 10, 0, 2, True, False, False, 2),
(3, 64, 3, 4, False, False, True, 3),
)
def test_file_reader_pytorch(
num_epochs,
batch_size,
shard_id,
num_shards,
random_shuffle,
shuffle_after_epoch,
stick_to_shard,
iters_into_epoch=None,
):
from nvidia.dali.plugin.pytorch import DALIGenericIterator

@pipeline_def(batch_size=batch_size, device_id=0, num_threads=4, enable_checkpointing=True)
def pipeline():
data, label = fn.readers.file(
name="Reader",
file_root=images_dir,
pad_last_batch=True,
random_shuffle=random_shuffle,
shard_id=shard_id,
num_shards=num_shards,
shuffle_after_epoch=shuffle_after_epoch,
stick_to_shard=stick_to_shard,
)
image = fn.decoders.image_random_crop(data, device="mixed")
image = fn.resize(image, size=(200, 200))
return image, label

p = pipeline()
p.build()

iter = DALIGenericIterator(p, ["data", "labels"], auto_reset=True, reader_name="Reader")
for epoch in range(num_epochs):
for i, _ in enumerate(iter):
if iters_into_epoch is not None:
if epoch == num_epochs - 1 and i == iters_into_epoch - 1:
break

restored = pipeline(checkpoint=iter.checkpoints()[0])
restored.build()
iter2 = DALIGenericIterator(restored, ["data", "labels"], auto_reset=True, reader_name="Reader")

for out1, out2 in zip(iter, iter2):
for d1, d2 in zip(out1, out2):
for key in d1.keys():
assert (d1[key] == d2[key]).all()


@params(0, 1, 2, 3, 4, 5, 6, 7, 8)
def test_multiple_readers(num_iters):
my_images = os.path.join(images_dir, "134")
@@ -969,36 +870,18 @@ def test_random_coin_flip(device, shape):
check_no_input_operator(fn.random.coin_flip, device, shape=shape)


@attr("pytorch")
@cartesian_params(("cpu", "gpu"), (None, (1,), (10,)))
def test_random_coin_flip_pytorch(device, shape):
check_no_input_operator_pytorch(fn.random.coin_flip, device, shape=shape)


@cartesian_params(("cpu",), (None, (1,), (10,)))
@random_signed_off("random.normal", "normal_distribution")
def test_random_normal(device, shape):
check_no_input_operator(fn.random.normal, device, shape=shape)


@attr("pytorch")
@cartesian_params(("cpu", "gpu"), (None, (1,), (10,)))
def test_random_normal_pytorch(device, shape):
check_no_input_operator_pytorch(fn.random.normal, device, shape=shape)


@cartesian_params(("cpu", "gpu"), (None, (1,), (10,)))
@random_signed_off("random.uniform", "uniform")
def test_random_uniform(device, shape):
check_no_input_operator(fn.random.uniform, device, shape=shape)


@attr("pytorch")
@cartesian_params(("cpu", "gpu"), (None, (1,), (10,)))
def test_random_uniform_pytorch(device, shape):
check_no_input_operator_pytorch(fn.random.uniform, device, shape=shape)


@random_signed_off("segmentation.random_object_bbox")
def test_random_object_bbox():
check_single_input_operator(fn.segmentation.random_object_bbox, "cpu", format="box")
@@ -1114,34 +997,6 @@ def compare_external_source_pipelines(pipe1, pipe2, steps):
compare_external_source_pipelines(p1, p2, compare_iterations)


def check_external_source_pipeline_checkpointing_pytorch(pipeline_factory, iterations, *, size=-1):
from nvidia.dali.plugin.pytorch import DALIGenericIterator

def run(iterator, iterations):
completed_iterations = 0
while completed_iterations < iterations:
for _ in iterator:
completed_iterations += 1
if completed_iterations == iterations:
break

pipeline = pipeline_factory()
pipeline.build()

iter = DALIGenericIterator(pipeline, ["data"], auto_reset=True, size=size)

run(iter, iterations)

restored = pipeline_factory(checkpoint=iter.checkpoints()[0])
restored.build()
iter2 = DALIGenericIterator(restored, ["data"], auto_reset=True, size=size)

for out1, out2 in zip(iter, iter2):
for d1, d2 in zip(out1, out2):
for key in d1.keys():
assert (d1[key] == d2[key]).all()


def make_external_source_test_pipeline_factory(source, mode, batch_size, parallel, **kwargs):
kwargs["parallel"] = parallel
if mode == "idx":
@@ -1208,20 +1063,6 @@ def test_external_source_checkpointing(dataset_info, iterations, mode, parallel)
check_external_source_pipeline_checkpointing(pf, iterations, 2 * epoch_size)


@attr("pytorch")
@cartesian_params(
((1, 1), (4, 5)), # (epoch size, batch size)
(0, 4, 11), # test iterations
("idx", "batch_info", "sample_info"), # indexing mode
(True, False), # parallel
)
def test_external_source_checkpointing_pytorch(dataset_info, iterations, mode, parallel):
epoch_size, batch_size = dataset_info
source = make_dummy_source(epoch_size, batch_size, mode)
pf = make_external_source_test_pipeline_factory(source, mode, batch_size, parallel)
check_external_source_pipeline_checkpointing_pytorch(pf, iterations)


@cartesian_params(
("iterator", "iterable", "callable"), # source kind
(True, False), # parallel
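For completeness, a hedged sketch of how one of the removed per-framework cases above (for example test_random_coin_flip_pytorch) could be re-expressed through the generalized helper sketched near the top of this page. Here make_iterator again stands for a hypothetical per-framework iterator factory, not the PR's actual API, and the snippet assumes the test module's existing pipeline_def and fn imports.

def check_no_input_operator_iterator(op, device, make_iterator, **kwargs):
    # Mirrors the removed check_no_input_operator_pytorch, with the framework
    # iterator injected instead of hard-coding DALIGenericIterator.
    @pipeline_def
    def pipeline_factory():
        return op(device=device, **kwargs)

    check_pipeline_checkpointing_iterator(pipeline_factory, make_iterator, size=8)

# e.g. the removed test_random_coin_flip_pytorch("cpu", (10,)) case becomes roughly
# (make_pytorch_iterator is a hypothetical wrapper around DALIGenericIterator):
# check_no_input_operator_iterator(fn.random.coin_flip, "cpu", make_pytorch_iterator, shape=(10,))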