Skip to content

Commit

Permalink
Debug mode direct ExternalSource (NVIDIA#3605)
Browse files Browse the repository at this point in the history
Add direct external_source in debug mode

Adds a debug version of the ExternalSource operator, removing the additional callback
to the backend. Previously, in debug mode, we created a separate pipeline for
external source just like for any other operator, which seems like overkill
considering that all of the ExternalSource implementation (relevant for
debug mode) is in Python.

Adds a TensorList constructor that takes a list of Tensors.

Signed-off-by: ksztenderski <ksztenderski@nvidia.com>
  • Loading branch information
ksztenderski authored and cyyever committed Jan 23, 2022
1 parent ac14fdd commit eb3560f
Show file tree
Hide file tree
Showing 6 changed files with 336 additions and 93 deletions.
49 changes: 49 additions & 0 deletions dali/python/backend_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -593,6 +593,29 @@ std::unique_ptr<Tensor<Backend> > TensorListGetItemImpl(TensorList<Backend> &t,
return ptr;
}

/**
 * @brief Creates a new TensorList holding a copy of the samples from a Python list of Tensors.
 *
 * Each list element is cast to `Tensor<Backend>&` and shared into a temporary TensorVector,
 * which is then copied into a freshly allocated TensorList.
 *
 * @param list_of_tensors Python list whose elements are `Tensor<Backend>` objects.
 * @param layout          Layout assigned to the resulting TensorList.
 * @return Shared pointer to the newly created TensorList.
 */
template <typename Backend>
std::shared_ptr<TensorList<Backend>> TensorListFromListOfTensors(py::list &list_of_tensors,
                                                                 string &layout) {
  constexpr bool is_gpu = std::is_same<Backend, GPUBackend>::value;
  const size_t num_samples = list_of_tensors.size();

  auto tl = std::make_shared<TensorList<Backend>>(num_samples);

  // Temporary, non-owning view of the input samples used as the copy source.
  TensorVector<Backend> tv(num_samples);
  for (size_t i = 0; i < num_samples; ++i) {
    auto &t = list_of_tensors[i].cast<Tensor<Backend>&>();
    tv[i].ShareData(t);
  }

  // For GPU data the copy stream is looked up from the first sample; otherwise stream 0 is used.
  cudaStream_t stream = 0;
  if (is_gpu && num_samples > 0) {
    auto &t = list_of_tensors[0].cast<Tensor<GPUBackend>&>();
    stream = UserStream::Get()->GetStream(t);
  }

  tl->Copy(tv, stream);
  tl->SetLayout(layout);

  // Wait for the copy issued on `stream` before handing the result back to Python.
  // The synchronization is skipped for the CPU backend so that building a CPU TensorList
  // does not issue a CUDA runtime call (which would fail on hosts without a usable CUDA
  // driver). NOTE(review): assumes the CPU -> CPU Copy completes synchronously - confirm.
  if (is_gpu) {
    CUDA_CALL(cudaStreamSynchronize(stream));
  }

  return tl;
}

#if 0 // TODO(spanev): figure out which return_value_policy to choose
template <typename Backend>
py::tuple TensorListGetItemSliceImpl(TensorList<Backend> &t, py::slice slice) {
Expand Down Expand Up @@ -682,6 +705,19 @@ void ExposeTensorList(py::module &m) {
is_pinned : bool
If provided memory is page-locked (pinned)
)code")
.def(py::init([](py::list &list_of_tensors, string layout = "") {
return TensorListFromListOfTensors<CPUBackend>(list_of_tensors, layout);
}),
"list_of_tensors"_a,
"layout"_a = "",
R"code(
List of tensors residing in the CPU memory.
list_of_tensors : [TensorCPU]
Python list of TensorCPU objects
layout : str
Layout of the data
)code")
.def("_as_gpu", [](TensorList<CPUBackend> &t) {
auto ret = std::make_shared<TensorList<GPUBackend>>();
int dev = -1;
Expand Down Expand Up @@ -881,6 +917,19 @@ void ExposeTensorList(py::module &m) {
}),
"tl"_a,
"layout"_a = py::none())
.def(py::init([](py::list &list_of_tensors, string layout = "") {
return TensorListFromListOfTensors<GPUBackend>(list_of_tensors, layout);
}),
"list_of_tensors"_a,
"layout"_a = "",
R"code(
List of tensors residing in the GPU memory.
list_of_tensors : [TensorGPU]
Python list of TensorGPU objects
layout : str
Layout of the data
)code")
.def(py::init([](const py::object object, string layout = "", int device_id = -1) {
auto t = std::make_shared<TensorList<GPUBackend>>();
FillTensorFromCudaArray(object, t.get(), device_id, layout);
Expand Down
171 changes: 143 additions & 28 deletions dali/python/nvidia/dali/_debug_mode.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -16,8 +16,13 @@
import nvidia.dali.pipeline as _pipeline
import nvidia.dali.tensors as _Tensors
import nvidia.dali.types as _types
from nvidia.dali.data_node import DataNode as _DataNode, _arithm_op
from nvidia.dali.data_node import DataNode as _DataNode, _arithm_op, _check
from nvidia.dali.external_source import _prep_data_for_feed_input
from nvidia.dali._utils.external_source_impl import \
get_callback_from_source as _get_callback_from_source, \
accepted_arg_count as _accepted_arg_count
import inspect
from queue import Queue


class DataNodeDebug(_DataNode):
Expand Down Expand Up @@ -114,6 +119,117 @@ def __rxor__(self, other):
return _PipelineDebug.current()._wrap_op_call(_arithm_op, ["bitxor", other, self], {})


def _transform_data_to_tensorlist(data, batch_size, layout=None):
    """Prepares `data` for feeding and wraps a list result in a TensorList.

    The input is first passed through `_prep_data_for_feed_input`. If that yields a Python
    list, the list is wrapped in `TensorListGPU` when its first element is a `TensorGPU`,
    and in `TensorListCPU` otherwise. Any non-list result is returned as is.

    Args:
        data: Data to convert.
        batch_size: Batch size forwarded to `_prep_data_for_feed_input`.
        layout: Optional layout; an empty string is used for the TensorList when it is falsy.
    """
    prepared = _prep_data_for_feed_input(data, batch_size, layout)
    if not isinstance(prepared, list):
        return prepared

    # The device of the first sample decides which TensorList flavour is created.
    if isinstance(prepared[0], _Tensors.TensorGPU):
        tensor_list_cls = _Tensors.TensorListGPU
    else:
        tensor_list_cls = _Tensors.TensorListCPU
    return tensor_list_cls(prepared, layout or "")


class _ExternalSourceDebug:
"""Debug mode version of ExternalSource operator."""

def __init__(
self, source=None, num_outputs=None, batch_size=-1, cycle=None, name=None, layout=None,
batch=None, batch_info=None):
if name is not None and num_outputs is not None:
raise ValueError("`num_outputs` is not compatible with named `ExternalSource`")

callback, source_desc = _get_callback_from_source(source, cycle, batch_info or False)

self._name = name
self._layout = layout
self._num_outputs = num_outputs
self._batch = batch
self._batch_size = batch_size
self._callback = callback
self._source_desc = source_desc
self._batch_info = batch_info
self._current_iter = 0
self._current_sample = 0
self._feed_inputs = Queue()

if callback is not None:
arg_count = _accepted_arg_count(callback)
if arg_count not in [0, 1]:
raise TypeError("External source callback must be a callable with 0 or 1 argument")
self.accepts_arg = arg_count > 0

def _callback_args(self, idx_in_batch, epoch_idx):
if not self.accepts_arg:
return ()
if idx_in_batch is not None:
arg = _types.SampleInfo(
self._current_sample + idx_in_batch,
idx_in_batch,
self._current_iter,
epoch_idx)
elif self._batch_info:
arg = _types.BatchInfo(
self._current_iter,
epoch_idx)
else:
arg = self._current_iter
return (arg,)

def _get_batch(self, epoch_idx):
try:
if self._batch:
callback_out = self._callback(*self._callback_args(None, epoch_idx))
else:
callback_out = [self._callback(*self._callback_args(i, epoch_idx))
for i in range(self._batch_size)]
self._current_sample += self._batch_size
self._current_iter += 1
except StopIteration:
self._current_iter = 0
self._current_sample = 0
raise
return callback_out

def _feed_input(self, data, kwargs):
if self._callback is not None:
raise RuntimeError(f"Cannot use `feed_input` on the external source '{self._name}' with a `source`"
" argument specified.")

self._feed_inputs.put((data, kwargs))

def _fetch(self, epoch_idx):
"""Fetches data from callback or provided with feed_input."""

def to_data_node_debug(data):
data = _transform_data_to_tensorlist(data, self._batch_size, layout)
device = 'gpu' if isinstance(data, _Tensors.TensorListGPU) else 'cpu'

return DataNodeDebug(data, self._name, device, self._source_desc)

if self._callback is not None:
callback_out = self._get_batch(epoch_idx)
layout = self._layout
if self._num_outputs is not None:
raw_data = []
for idx in range(self._num_outputs):
if self._batch:
raw_data.append(callback_out[idx])
else:
raw_data.append([callback_out[i][idx] for i in range(self._batch_size)])
else:
raw_data = callback_out
else:
raw_data, feed_input_params = self._feed_inputs.get()
layout = feed_input_params.get('layout', self._layout)

if self._num_outputs is not None:
return [to_data_node_debug(data) for data in raw_data]

return to_data_node_debug(raw_data)


class _PipelineDebug(_pipeline.Pipeline):
"""Debug mode for pipeline. Allows access to data inside the pipeline execution by wrapping all
operators inside their pipelines"""
Expand All @@ -123,12 +239,11 @@ def __init__(self, exec_func, **kwargs):
kwargs['exec_pipelined'] = False
kwargs['exec_async'] = False
self._debug_on = False
self._external_source_debug = False
self._es_input_name = 'input_'
self._es_kwarg_name = 'kwarg_'
self._subpipeline_kwargs = kwargs
self._subpipelines = {}
self._external_source_pipelines = {}
self._external_sources = {}
self._feed_input_data = {}
self._subpipelines_built = False
self._cur_subpipeline_id = -1
Expand All @@ -155,9 +270,8 @@ def run(self):
import numpy as np
if not self._built:
raise RuntimeError('Pipeline must be built first.')

self._debug_on = True
self._external_source_debug = True
self._cur_subpipeline_id = -1
_pipeline.Pipeline.push_current(self)

Expand All @@ -166,9 +280,8 @@ def run(self):
res = ()
elif not isinstance(res, tuple):
res = (res,)

self._debug_on = False
self._external_source_debug = False
if not self._subpipelines_built:
self._subpipelines_built = True
_pipeline.Pipeline.pop_current()
Expand All @@ -183,14 +296,19 @@ def feed_input(self, data_node, data, **kwargs):
Refer to :meth:`Pipeline.feed_input() <nvidia.dali.Pipeline.feed_input>` for details."""
if not self._built:
raise RuntimeError("Pipeline must be built first.")
if isinstance(data_node, str):
name = data_node
else:
_check(data_node)
name = data_node.name

if data_node not in self._external_source_pipelines:
if name not in self._external_sources:
# Saving data, because pipeline hasn't been run yet.
if data_node not in self._feed_input_data:
self._feed_input_data[data_node] = []
self._feed_input_data[data_node].append((data, kwargs))
if name not in self._feed_input_data:
self._feed_input_data[name] = []
self._feed_input_data[name].append((data, kwargs))
else:
self._external_source_pipelines[data_node].feed_input(data_node, data, **kwargs)
self._external_sources[name]._feed_input(name, data, kwargs)

@staticmethod
def _classify_data(data):
Expand Down Expand Up @@ -218,7 +336,7 @@ def is_primitive_type(x):
return True, device, data_list
else:
if is_primitive_type(data) or _types._is_numpy_array(data) or \
_types._is_mxnet_array(data) or isinstance(data, _Tensors.TensorCPU):
_types._is_mxnet_array(data) or isinstance(data, _Tensors.TensorCPU):
return False, 'cpu', data
if _types._is_torch_tensor(data):
return False, 'gpu' if data.is_cuda else 'cpu', data
Expand Down Expand Up @@ -271,30 +389,25 @@ def pipe():

return tuple(res) if isinstance(res, list) else res

self._external_source_debug = False
p = pipe()
p.build()
self._external_source_debug = True
return (p, inputs_external_source, kwargs_external_source)

def _external_source(self, op_wrapper, *inputs, **kwargs):
# TODO(ksztenderski): Possibly remove this wrapper to avoid running data through the backend and back.
def _external_source(self, name=None, **kwargs):
self._cur_subpipeline_id += 1
key = inspect.getframeinfo(
inspect.currentframe().f_back.f_back)[:3] + (self._cur_subpipeline_id,)
name = kwargs['name']
if not self._subpipelines_built:
self._subpipelines[key] = self._create_subpipeline(op_wrapper, inputs, kwargs)
pipe = self._subpipelines[key][0]
es = _ExternalSourceDebug(batch_size=self._max_batch_size, name=name, **kwargs)

# feed_input all data collected after build and before run
for (data, fi_kwargs) in self._feed_input_data.pop(name, []):
pipe.feed_input(name, data, **fi_kwargs)
es._feed_input(data, fi_kwargs)

self._external_source_pipelines[name] = pipe
self._external_sources[key] = es

if key in self._subpipelines:
return self._run_subpipeline(self._subpipelines[key], 'ExternalSource', inputs, kwargs)
if key in self._external_sources:
return self._external_sources[key]._fetch(self._epoch_idx)
else:
raise RuntimeError(f"Unexpected operator 'ExternalSource'. Debug mode does not support"
" changing the order of operators executed within the pipeline.")
Expand All @@ -309,7 +422,7 @@ def _run_subpipeline(self, pipe_tuple, op_name, inputs, kwargs):
f" it was built. {len(inputs_es)} != {len(inputs)}")
if len(kwargs_es) != len(kwargs.items()):
raise RuntimeError(f"Trying to use operator '{op_name}' with different number of keyward arguments"
" than when it was built.")
" than when it was built.")

def unexpected_argument_msg(to_external_source):
return f"recognized as {'batch' if to_external_source else 'constant'} but when built value" \
Expand All @@ -318,13 +431,15 @@ def unexpected_argument_msg(to_external_source):
for i, input in enumerate(inputs):
to_external_source, _, data = _PipelineDebug._classify_data(input)
if to_external_source != inputs_es[i]:
raise RuntimeError(f"In operator '{op_name}' input {input} {unexpected_argument_msg(to_external_source)}.")
raise RuntimeError(
f"In operator '{op_name}' input {input} {unexpected_argument_msg(to_external_source)}.")
if to_external_source:
pipe.feed_input(f'{self._es_input_name}{i}', data)
for key, value in kwargs.items():
to_external_source, _, data = _PipelineDebug._classify_data(value)
if to_external_source != kwargs_es[key]:
raise RuntimeError(f"In operator '{op_name}' argument '{key}' {unexpected_argument_msg(to_external_source)}.")
raise RuntimeError(
f"In operator '{op_name}' argument '{key}' {unexpected_argument_msg(to_external_source)}.")
if to_external_source:
pipe.feed_input(f'{self._es_kwarg_name}{key}', data)

Expand Down
Loading

0 comments on commit eb3560f

Please sign in to comment.