Ensure sample encapsulation in Tensor Vector (NVIDIA#3701)
Add APIs matching TensorList to TensorVector:
* sample pointer accessors
* Set/GetMeta (see the usage sketch below)
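
A minimal usage sketch of the added per-sample accessors and metadata APIs, based on the call sites updated in this commit (e.g. audio_decoder_op.cc, nvjpeg_decoder_decoupled_api.h); the include path and the wrapper function are illustrative assumptions:

```cpp
// Sketch only - not part of the change; calls mirror the updated operators.
#include "dali/pipeline/data/tensor_vector.h"  // assumed include path

namespace dali {

void CopySampleMeta(const TensorVector<CPUBackend> &in,
                    TensorVector<CPUBackend> &out, int i) {
  const void *raw = in.raw_tensor(i);             // raw per-sample pointer
  const uint8_t *src = in.tensor<uint8_t>(i);     // typed const per-sample pointer
  uint8_t *dst = out.mutable_tensor<uint8_t>(i);  // typed mutable per-sample pointer
  out.SetMeta(i, in.GetMeta(i));                  // per-sample metadata round-trip
  (void)raw; (void)src; (void)dst;
}

}  // namespace dali
```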

Change operator[] to return [Const]SampleView.
Introduce UnsafeSetSample and UnsafeCopySample
to replace TensorVector[i].ShareData(tensor)
and TensorVector[i].Copy(tensor). They work with the current
code base, but a proper sample-based data structure will need
additional checks - these are intended for a follow-up.
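
A minimal migration sketch using the overloads that appear in the updated call sites (decoder_test.h, nvjpeg_decoder_decoupled_api.h, permute_batch.cc); the names and include path are illustrative assumptions:

```cpp
// Sketch only - overloads mirror the call sites changed in this commit.
// Assumes `tv` has already been set up with a compatible type and shape
// (e.g. via SetupLike/Resize), as the changed operators do.
#include "dali/pipeline/data/tensor_vector.h"  // assumed include path

namespace dali {

void MigrateSharing(TensorVector<CPUBackend> &tv, Tensor<CPUBackend> &tensor,
                    const TensorVector<CPUBackend> &batch, int i, int src) {
  // Before: tv[i].ShareData(tensor);  tv[i].Copy(batch[src]);
  tv.UnsafeSetSample(i, tensor);       // share the Tensor's allocation as sample i
  tv.UnsafeSetSample(i, batch, src);   // share sample `src` of another batch as sample i
  tv.UnsafeCopySample(i, batch, src);  // copy sample `src` of another batch into sample i
}

}  // namespace dali
```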

Adjust code where necessary:
* Where possible, use data accessors directly on the TensorVector
    instead of on the sample, as it should be faster than creating a temporary, so:
    `tv[i].mutable_data<T>()` -> `tv.mutable_tensor<T>(i)` etc. (see the sketch below).
* Using SampleViews is compatible with code that uses `view<T>`,
    as `view<T>(Tensor)` is equivalent to `view<T>(sample_view(Tensor))`.
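
A sketch of both points above, assuming a float-typed, already resized TensorVector; the include paths, the function name, and the mutable `view<float>` overload on SampleView are assumptions:

```cpp
// Sketch only - accessor usage mirrors the call sites changed in this commit.
#include "dali/pipeline/data/tensor_vector.h"  // assumed include path
#include "dali/pipeline/data/views.h"          // assumed include path for view<>

namespace dali {

void FillSample(TensorVector<CPUBackend> &tv, int i) {
  // Preferred: access the sample through the batch object, no temporary SampleView.
  float *ptr = tv.mutable_tensor<float>(i);       // was: tv[i].mutable_data<float>()
  int64_t n = tv.tensor_shape(i).num_elements();  // was: tv[i].shape().num_elements()
  for (int64_t j = 0; j < n; j++)
    ptr[j] = 0.0f;

  // Still valid: view<> accepts the SampleView returned by operator[].
  auto v = view<float>(tv[i]);
  (void)v;
}

}  // namespace dali
```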

Adjustments:
* allow views to work with scalar Tensors (previously they were treated as empty)
* introduce distinct SampleView and ConstSampleView, as they need to
    be returned by value and we need sensible overloads for `view<>`
* allow access to `capacity` and `nbytes` of individual samples;
    introduce _chunks_capacity and _chunks_nbytes for that

Next steps are written as TODOs in the TensorVector docstring.

Current naming:
The `Unsafe` prefix in SetSample and CopySample is intended to stay temporarily
to discourage the introduction of new use cases until the follow-up introduces the
remaining checks. Capacity and nbytes of individual allocations have a leading
underscore, as that API is to be reworked and is not intended for new uses.

Signed-off-by: Krzysztof Lecki <klecki@nvidia.com>
klecki authored and cyyever committed May 13, 2022
1 parent 83ee47b commit 57ab01c
Showing 72 changed files with 992 additions and 402 deletions.
4 changes: 2 additions & 2 deletions dali/benchmark/displacement_cpu_bench.cc
@@ -1,4 +1,4 @@
// Copyright (c) 2017-2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// Copyright (c) 2017-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -88,7 +88,7 @@ void DisplacementBench(benchmark::State& st) {//NOLINT
// tensor out is resized by operator itself in DisplacementFilter::DataDependentSetup()

// TODO(klecki) Accomodate to use different inputs from test data
auto *ptr = (*tensor_in)[0].template mutable_data<T>();
auto *ptr = (*tensor_in).template mutable_tensor<T>(0);
for (int i = 0; i < N; i++) {
ptr[i] = i;
}
15 changes: 6 additions & 9 deletions dali/benchmark/operator_bench.h
@@ -1,4 +1,4 @@
// Copyright (c) 2019-2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// Copyright (c) 2019-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -54,16 +54,13 @@ class OperatorBench : public DALIBenchmark {
auto op_ptr = InstantiateOperator(op_spec);

auto data_in = std::make_shared<TensorVector<CPUBackend>>(batch_size);
for (auto &in_ptr : *data_in) {
in_ptr = std::make_shared<Tensor<CPUBackend>>();
in_ptr->set_type<T>();
in_ptr->Resize({H, W, C});
in_ptr->SetLayout("HWC");
}
data_in->set_type<T>();
data_in->Resize(uniform_list_shape(batch_size, TensorShape<>{H, W, C}));
data_in->SetLayout("HWC");

if (fill_in_data) {
for (auto &in_ptr : *data_in) {
auto *ptr = in_ptr->template mutable_data<T>();
for (int sample_idx = 0; sample_idx < batch_size; sample_idx++) {
auto *ptr = data_in->template mutable_tensor<T>(sample_idx);
for (int i = 0; i < N; i++) {
ptr[i] = static_cast<T>(i);
}
16 changes: 10 additions & 6 deletions dali/c_api/c_api.cc
@@ -129,7 +129,6 @@ void SetExternalInputTensors(daliPipelineHandle *pipe_handle, const char *name,
if (layout_str != nullptr) {
layout = dali::TensorLayout(layout_str);
}
dali::TensorVector<Backend> data(curr_batch_size);
auto type_id = static_cast<dali::DALIDataType>(data_type);
auto elem_sizeof = dali::TypeTable::GetTypeInfo(type_id).size();

@@ -139,15 +138,20 @@ void SetExternalInputTensors(daliPipelineHandle *pipe_handle, const char *name,
else
order = AccessOrder::host();

dali::TensorVector<Backend> data(curr_batch_size);
data.set_pinned(flags & DALI_ext_pinned);
data.set_sample_dim(sample_dim);
data.set_type(type_id);
data.set_order(order);
data.SetLayout(layout);

for (int i = 0; i < curr_batch_size; i++) {
// We cast away the const from data_ptr, as there is no other way of passing it to the
// Tensor as we must also set the shape and type metadata.
// The vector that we pass to pipeline is const.
data[i].set_pinned(flags & DALI_ext_pinned);
data[i].set_order(order);
data[i].ShareData(const_cast<void *>(data_ptr[i]), tl_shape[i].num_elements() * elem_sizeof);
data[i].Resize(tl_shape[i], type_id);
data[i].SetLayout(layout);
std::shared_ptr<void> ptr(const_cast<void *>(data_ptr[i]), [](void *){}); // no deleter
data.UnsafeSetSample(i, ptr, tl_shape[i].num_elements() * elem_sizeof, flags & DALI_ext_pinned,
tl_shape[i], type_id, order, layout);
}
pipeline->SetExternalInput(name, data, order,
flags & DALI_ext_force_sync,
6 changes: 3 additions & 3 deletions dali/operators/audio/nonsilence_op.h
@@ -1,4 +1,4 @@
// Copyright (c) 2020-2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// Copyright (c) 2020-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -228,8 +228,8 @@ class NonsilenceOperatorCpu : public NonsilenceOperator<CPUBackend> {
args.reset_interval = reset_interval_;

auto res = DetectNonsilenceRegion(intermediate_buffers_[thread_id], args);
auto beg_ptr = output_begin[sample_id].mutable_data<int>();
auto len_ptr = output_length[sample_id].mutable_data<int>();
auto *beg_ptr = output_begin.mutable_tensor<int>(sample_id);
auto *len_ptr = output_length.mutable_tensor<int>(sample_id);
*beg_ptr = res.first;
*len_ptr = res.second;
}, in_shape.tensor_size(sample_id));
10 changes: 5 additions & 5 deletions dali/operators/audio/preemphasis_filter_op.cc
@@ -1,4 +1,4 @@
// Copyright (c) 2019-2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// Copyright (c) 2019-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -65,11 +65,11 @@ void PreemphasisFilterCPU::RunImplTyped(workspace_t<CPUBackend> &ws) {
for (int sample_id = 0; sample_id < nsamples; sample_id++) {
tp.AddWork(
[this, &output, &input, sample_id](int thread_id) {
const auto in_ptr = input[sample_id].data<InputType>();
auto out_ptr = output[sample_id].mutable_data<OutputType>();
DALI_ENFORCE(input[sample_id].shape() == output[sample_id].shape(),
const auto *in_ptr = input.tensor<InputType>(sample_id);
auto *out_ptr = output.mutable_tensor<OutputType>(sample_id);
DALI_ENFORCE(input.tensor_shape(sample_id) == output.tensor_shape(sample_id),
"Input and output shapes don't match");
auto n = volume(output[sample_id].shape());
auto n = volume(output.tensor_shape(sample_id));
auto coeff = preemph_coeff_[sample_id];
if (coeff == 0.0f) {
for (int64_t j = 0; j < n; j++) {
8 changes: 4 additions & 4 deletions dali/operators/decoder/audio/audio_decoder_op.cc
@@ -1,4 +1,4 @@
// Copyright (c) 2019-2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// Copyright (c) 2019-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -88,13 +88,13 @@ AudioDecoderCpu::SetupImpl(std::vector<OutputDesc> &output_desc, const workspace

for (int i = 0; i < batch_size; i++) {
auto &meta = sample_meta_[i] =
decoders_[i]->Open({static_cast<const char *>(input[i].raw_data()),
input[i].shape().num_elements()});
decoders_[i]->Open({static_cast<const char *>(input.raw_tensor(i)),
input.tensor_shape(i).num_elements()});
TensorShape<> data_sample_shape = DecodedAudioShape(
meta, use_resampling_ ? target_sample_rates_[i] : -1.0f, downmix_);
shape_data.set_tensor_shape(i, data_sample_shape);
shape_rate.set_tensor_shape(i, {});
files_names_[i] = input[i].GetSourceInfo();
files_names_[i] = input.GetMeta(i).GetSourceInfo();
}

output_desc[0] = { shape_data, output_type_ };
16 changes: 14 additions & 2 deletions dali/operators/decoder/decoder_test.h
@@ -1,4 +1,4 @@
// Copyright (c) 2019-2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// Copyright (c) 2019-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -18,6 +18,7 @@
#include <string>
#include <vector>
#include <memory>
#include "dali/pipeline/data/types.h"
#include "dali/test/dali_test_decoder.h"

namespace dali {
@@ -64,6 +65,7 @@ class DecodeTestBase : public GenericDecoderTest<ImgType> {
// single input - encoded images
// single output - decoded images
TensorVector<CPUBackend> out(inputs[0]->num_samples());
std::vector<Tensor<CPUBackend>> tmp_out(inputs[0]->num_samples());
const TensorList<CPUBackend> &encoded_data = *inputs[0];
const int c = this->GetNumColorComp();

@@ -72,7 +74,17 @@
auto data_size = volume(encoded_data.tensor_shape(i));
this->DecodeImage(
data, data_size, c, this->ImageType(),
&out[i], GetCropWindowGenerator(i));
&tmp_out[i], GetCropWindowGenerator(i));
}

TensorListShape<> out_shape(inputs[0]->num_samples(), 3);
for (size_t i = 0; i < encoded_data.num_samples(); ++i) {
out_shape.set_tensor_shape(i, tmp_out[i].shape());
}
out.SetupLike(tmp_out[0]);
out.Resize(out_shape, DALI_UINT8);
for (size_t i = 0; i < encoded_data.num_samples(); ++i) {
out.UnsafeSetSample(i, tmp_out[i]);
}

vector<std::shared_ptr<TensorList<CPUBackend>>> outputs;
38 changes: 22 additions & 16 deletions dali/operators/decoder/nvjpeg/nvjpeg_decoder_decoupled_api.h
@@ -1,4 +1,4 @@
// Copyright (c) 2019-2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// Copyright (c) 2019-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -554,15 +554,16 @@ class nvJPEGDecoder : public Operator<MixedBackend>, CachedDecoderImpl {
samples_jpeg2k_.clear();
#endif // NVJPEG2K_ENABLED

const auto &input = ws.Input<CPUBackend>(0);
for (int i = 0; i < curr_batch_size; i++) {
const auto &in = ws.Input<CPUBackend>(0)[i];
const auto in_size = in.size();
thread_pool_.AddWork([this, i, &in, in_size](int tid) {
auto *input_data = in.data<uint8_t>();
auto *input_data = input.tensor<uint8_t>(i);
const auto in_size = input.tensor_shape(i).num_elements();
const auto &source_info = input.GetMeta(i).GetSourceInfo();
thread_pool_.AddWork([this, i, input_data, in_size, source_info](int tid) {
SampleData &data = sample_data_[i];
data.clear();
data.sample_idx = i;
data.file_name = in.GetSourceInfo();
data.file_name = source_info;
data.encoded_length = in_size;

auto cached_shape = CacheImageShape(data.file_name);
@@ -704,15 +705,17 @@ class nvJPEGDecoder : public Operator<MixedBackend>, CachedDecoderImpl {

void ProcessImagesCuda(MixedWorkspace &ws) {
auto& output = ws.Output<GPUBackend>(0);
const auto &input = ws.Input<CPUBackend>(0);
for (auto *sample : samples_single_) {
assert(sample);
auto i = sample->sample_idx;
auto *output_data = output.mutable_tensor<uint8_t>(i);
const auto &in = ws.Input<CPUBackend>(0)[i];
const auto *in_data = input.tensor<uint8_t>(i);
const auto in_size = input.tensor_shape(i).num_elements();
thread_pool_.AddWork(
[this, sample, &in, output_data](int tid) {
SampleWorker(sample->sample_idx, sample->file_name, in.size(), tid,
in.data<uint8_t>(), output_data, streams_[tid]);
[this, sample, in_data, in_size, output_data](int tid) {
SampleWorker(sample->sample_idx, sample->file_name, in_size, tid,
in_data, output_data, streams_[tid]);
}, task_priority_seq_--); // FIFO order, since the samples were already ordered
}
}
@@ -808,15 +811,17 @@ class nvJPEGDecoder : public Operator<MixedBackend>, CachedDecoderImpl {
}

void ProcessImagesHost(MixedWorkspace &ws) {
const auto &input = ws.Input<CPUBackend>(0);
auto& output = ws.Output<GPUBackend>(0);
for (auto *sample : samples_host_) {
auto i = sample->sample_idx;
const auto *input_data = input.tensor<uint8_t>(i);
auto in_size = input.tensor_shape(i).num_elements();
auto *output_data = output.mutable_tensor<uint8_t>(i);
const auto &in = ws.Input<CPUBackend>(0)[i];
ImageCache::ImageShape shape = output_shape_[i].to_static<3>();
thread_pool_.AddWork(
[this, sample, &in, output_data, shape](int tid) {
HostFallback<StorageGPU>(in.data<uint8_t>(), in.size(), output_image_type_, output_data,
[this, sample, input_data, in_size, output_data, shape](int tid) {
HostFallback<StorageGPU>(input_data, in_size, output_image_type_, output_data,
streams_[tid], sample->file_name, sample->roi, use_fast_idct_);
CacheStore(sample->file_name, output_data, shape, streams_[tid]);
}, task_priority_seq_--); // FIFO order, since the samples were already ordered
@@ -846,13 +851,14 @@ class nvJPEGDecoder : public Operator<MixedBackend>, CachedDecoderImpl {
int j = 0;
TensorVector<CPUBackend> tv(samples_hw_batched_.size());

const auto &input = ws.Input<CPUBackend>(0);
tv.SetupLike(input);
for (auto *sample : samples_hw_batched_) {
int i = sample->sample_idx;
const auto &in = ws.Input<CPUBackend>(0)[i];
const auto &out_shape = output_shape_.tensor_shape(i);

tv[j].ShareData(const_cast<Tensor<CPUBackend> &>(in));
in_lengths_[j] = in.size();
tv.UnsafeSetSample(j, input, i);
in_lengths_[j] = input.tensor_shape(i).num_elements();
nvjpeg_destinations_[j].channel[0] = output.mutable_tensor<uint8_t>(i);
nvjpeg_destinations_[j].pitch[0] = out_shape[1] * out_shape[2];
nvjpeg_params_[j] = sample->params;
4 changes: 2 additions & 2 deletions dali/operators/generic/cast.cc
@@ -51,8 +51,8 @@ void CastCPU::RunImpl(HostWorkspace &ws) {
TYPE_SWITCH(itype, type2id, IType, CAST_ALLOWED_TYPES, (

for (int sample_id = 0; sample_id < num_samples; sample_id++) {
auto *out = output[sample_id].mutable_data<OType>();
const auto *in = input[sample_id].data<IType>();
auto *out = output.mutable_tensor<OType>(sample_id);
const auto *in = input.tensor<IType>(sample_id);
auto size = input_shape.tensor_size(sample_id);
tp.AddWork([out, in, size](int thread_id) { CpuHelper<OType, IType>(out, in, size); },
size);
8 changes: 4 additions & 4 deletions dali/operators/generic/constant.cc
@@ -1,4 +1,4 @@
// Copyright (c) 2020-2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// Copyright (c) 2020-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -80,7 +80,7 @@ void FillTensorVector(
assert(is_uniform(shape));
int64_t n = shape[0].num_elements();
assert(src.size() == static_cast<size_t>(n) || src.size() == 1);
Dst *out = dst[0].mutable_data<Dst>();
Dst *out = dst.mutable_tensor<Dst>(0);
if (src.size() == 1) {
Dst val = ConvertSat<Dst>(src[0]);
for (int64_t i = 0; i < n; i++) {
@@ -92,7 +92,7 @@
}
}
for (int i = 1; i < shape.num_samples(); i++) {
dst[i].ShareData(dst[0]);
dst.UnsafeSetSample(i, dst, 0);
}
}
} // namespace
@@ -116,7 +116,7 @@ void Constant<CPUBackend>::RunImpl(HostWorkspace &ws) {
out.Resize(output_shape_);
int N = output_shape_.num_samples();
for (int i = 0; i < N; i++) {
assert(out[i].raw_data() == output_[i].raw_data());
assert(out.raw_tensor(i) == output_.raw_tensor(i));
}
out.SetLayout(layout_);
}
18 changes: 9 additions & 9 deletions dali/operators/generic/erase/erase_utils.h
@@ -1,4 +1,4 @@
// Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
// Copyright (c) 2019-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -95,17 +95,17 @@ std::vector<kernels::EraseArgs<T, Dims>> GetEraseArgs(const OpSpec &spec,

for (int i = 0; i < nsamples; i++) {
if (has_tensor_roi_anchor) {
const auto& anchor = ws.ArgumentInput("anchor")[i];
assert(anchor.size() > 0);
roi_anchor.resize(anchor.size());
std::memcpy(roi_anchor.data(), anchor.data<float>(), sizeof(float) * roi_anchor.size());
auto anchor = view<const float>(ws.ArgumentInput("anchor")[i]);
assert(anchor.shape.num_elements() > 0);
roi_anchor.resize(anchor.shape.num_elements());
std::memcpy(roi_anchor.data(), anchor.data, sizeof(float) * roi_anchor.size());
}

if (has_tensor_roi_shape) {
const auto& shape = ws.ArgumentInput("shape")[i];
assert(shape.size() > 0);
roi_shape.resize(shape.size());
std::memcpy(roi_shape.data(), shape.data<float>(), sizeof(float) * roi_shape.size());
auto shape = view<const float>(ws.ArgumentInput("shape")[i]);
assert(shape.shape.num_elements() > 0);
roi_shape.resize(shape.num_elements());
std::memcpy(roi_shape.data(), shape.data, sizeof(float) * roi_shape.size());
}

DALI_ENFORCE(roi_anchor.size() == roi_shape.size());
6 changes: 3 additions & 3 deletions dali/operators/generic/lookup_table.cc
@@ -1,4 +1,4 @@
// Copyright (c) 2019-2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// Copyright (c) 2019-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -25,8 +25,8 @@ void LookupValuesImpl(ThreadPool &tp, TensorVector<CPUBackend> &output,
const Output *lookup_table, const Output default_value) {
for (int sample_idx = 0; sample_idx < shape.num_samples(); sample_idx++) {
auto data_size = shape.tensor_size(sample_idx);
auto *out_data = output[sample_idx].mutable_data<Output>();
const auto *in_data = input[sample_idx].data<Input>();
auto *out_data = output.mutable_tensor<Output>(sample_idx);
const auto *in_data = input.tensor<Input>(sample_idx);
tp.AddWork(
[=](int thread_id) {
for (int64_t i = 0; i < data_size; i++) {
5 changes: 3 additions & 2 deletions dali/operators/generic/permute_batch.cc
@@ -1,4 +1,4 @@
// Copyright (c) 2020-2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// Copyright (c) 2020-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -45,7 +45,8 @@ void PermuteBatch<CPUBackend>::RunImpl(HostWorkspace &ws) {
int src = indices_[i];
tp.AddWork([&, i, src](int tid) {
output.SetMeta(i, input.GetMeta(i));
output[i].Copy(input[src]);
// TODO(klecki): SetSample
output.UnsafeCopySample(i, input, src);
}, size);
}
tp.RunAll();
6 changes: 3 additions & 3 deletions dali/operators/generic/reshape.cc
@@ -1,4 +1,4 @@
// Copyright (c) 2019-2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// Copyright (c) 2019-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -393,8 +393,8 @@ void Reshape<CPUBackend>::RunImpl(HostWorkspace &ws) {
out.Resize(output_shape_, output_type_->id());
int N = output_shape_.num_samples();
for (int i = 0; i < N; i++) {
assert(out[i].raw_data() == in[i].raw_data());
assert(out[i].shape() == output_shape_[i]);
assert(out.raw_tensor(i) == in.raw_tensor(i));
assert(out.tensor_shape(i) == output_shape_[i]);
}
out.SetLayout(layout);
}