Skip to content

Commit

Permalink
Almost API handling
Browse files Browse the repository at this point in the history
Signed-off-by: Krzysztof Lecki <klecki@nvidia.com>
  • Loading branch information
klecki committed Feb 19, 2024
1 parent 6458765 commit 7b09753
Show file tree
Hide file tree
Showing 9 changed files with 97 additions and 55 deletions.
4 changes: 2 additions & 2 deletions dali/pipeline/operator/builtin/conditional/merge.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ class Merge : public StatelessOperator<Backend> {
"The 'predicate' argument is required to be present as argument input.");
RegisterTestsDiagnostics();

auto origin_stack_trace = GetOperatorOriginInfo(spec_);
std::cout << "Merge>>> " << FormatStack(origin_stack_trace, true) << std::endl;
// auto origin_stack_trace = GetOperatorOriginInfo(spec_);
// std::cout << "Merge>>> " << FormatStack(origin_stack_trace, true) << std::endl;
}

~Merge() override = default;
Expand Down
4 changes: 2 additions & 2 deletions dali/pipeline/operator/builtin/conditional/split.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ class Split : public StatelessOperator<Backend> {
DALI_ENFORCE(spec.HasTensorArgument("predicate"),
"The 'predicate' argument is required to be present as argument input.");
RegisterTestsDiagnostics();
auto origin_stack_trace = GetOperatorOriginInfo(spec_);
std::cout << "SPLIT>>> " << FormatStack(origin_stack_trace, true) << std::endl;
// auto origin_stack_trace = GetOperatorOriginInfo(spec_);
// std::cout << "SPLIT>>> " << FormatStack(origin_stack_trace, true) << std::endl;
}

~Split() override = default;
Expand Down
9 changes: 6 additions & 3 deletions dali/pipeline/operator/op_schema.cc
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,9 @@ graph even if its outputs are not used.)code",
false);


// For simplicity we pass StackSummary as 4 separate arguments so we don't need to extend DALI
// with support for special FrameSummary type.
// List of FrameSummaries can be reconstructed using utility functions.
// For simplicity we pass StackSummary as 4 separate arguments so we don't need to extend DALI
// with support for special FrameSummary type.
// List of FrameSummaries can be reconstructed using utility functions.
AddOptionalArg("_origin_stack_filename", R"code(Every operator defined in Python captures and
processes the StackSummary (a List[FrameSummary], defined in Python traceback module) that describe
the callstack between the start of pipeline definition tracing and the "call" to the operator
Expand All @@ -109,6 +109,9 @@ _origin_stack_filename for more information.)code",
AddOptionalArg("_pipeline_internal", R"code(Boolean specifying if this operator was defined within
a pipeline scope. False if it was defined without pipeline being set as current.)code",
true);

AddOptionalArg("_api", "String representing API used to create operator: \"fn\" or \"ops\".",
"fn");
}


Expand Down
43 changes: 22 additions & 21 deletions dali/python/nvidia/dali/_utils/dali_trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,34 +90,35 @@ def _filter_autograph_frames(stack_summary, frame_map, frame_filter):
skip = frame_filter.is_filtered(origin_frame_entry.filename)

# Detect repeated appearance of function transformed by AG
if _collapse_ag_frames:
# AutoGraph is wrapping a function call - entry point
if is_frame_converted_call(frame_entry):
is_ag_function_call_start = True
# It quits to a non-AG code, treat it as normal from now-on
if is_frame_call_unconverted(frame_entry):
is_ag_function_call_start = False
current_function_region = None
# We are in the first part of the converted_func call that is not skipped
# (as we are in user code, remember the function)
if is_ag_function_call_start and not skip:
is_ag_function_call_start = False
current_function_region = origin_frame_entry
skip = True
origin_stack_summary.append(origin_frame_entry)
# AutoGraph is wrapping a function call - entry point
if is_frame_converted_call(frame_entry):
is_ag_function_call_start = True
# It quits to a non-AG code, treat it as normal from now-on
if is_frame_call_unconverted(frame_entry):
is_ag_function_call_start = False
current_function_region = None
# We are in the first part of the converted_func call that is not skipped
# (as we are in user code, remember the function)
if is_ag_function_call_start and not skip:
is_ag_function_call_start = False
current_function_region = origin_frame_entry
skip = True
origin_stack_summary.append(origin_frame_entry)

# User code - not filtered out
if not skip:
# If we are in the same function region, we replace previous entry so we keep only the
# last one
if _collapse_ag_frames:
assert origin_stack_summary
if _is_matching_function(origin_stack_summary[-1], current_function_region):
assert origin_stack_summary
if _is_matching_function(origin_stack_summary[-1], current_function_region):
if _collapse_ag_frames:
origin_stack_summary.pop()
else:
current_function_region = None
else:
current_function_region = None
origin_stack_summary.append(origin_frame_entry)
elif not _filter_ag_frames:
elif not _filter_ag_frames and (
len(origin_stack_summary) == 0 or origin_stack_summary[-1] != origin_frame_entry
):
origin_stack_summary.append(origin_frame_entry)
return origin_stack_summary

Expand Down
4 changes: 3 additions & 1 deletion dali/python/nvidia/dali/fn/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2020-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2020-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -81,6 +81,8 @@ def op_wrapper(*inputs, **kwargs):
def fn_wrapper(*inputs, **kwargs):
from nvidia.dali._debug_mode import _PipelineDebug

kwargs = {**kwargs, "_api": "_fn"}

current_pipeline = _PipelineDebug.current()
if getattr(current_pipeline, "_debug_on", False):
return current_pipeline._wrap_op_call(op_class, wrapper_name, *inputs, **kwargs)
Expand Down
40 changes: 24 additions & 16 deletions dali/python/nvidia/dali/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,29 +375,14 @@ def __init__(self, inputs, arg_inputs, arguments, _processed_arguments, op):
self._spec = op.spec.copy()
self._relation_id = self._counter.id

if _dali_trace.is_tracing_enabled():
if _Pipeline.current():
skip_bottom = _Pipeline.current()._definition_stack_frame
else:
skip_bottom = 0
# For fn API it is 4, for ops around 2
stack_summary = _dali_trace.extract_stack(
skip_bottom_frames=skip_bottom, skip_top_frames=4
)
filenames, linenos, names, lines = _dali_trace.separate_stack_summary(stack_summary)

self._spec.AddArg("_origin_stack_filename", filenames)
self._spec.AddArg("_origin_stack_lineno", linenos)
self._spec.AddArg("_origin_stack_name", names)
self._spec.AddArg("_origin_stack_line", lines)

# TODO(klecki): Replace "type(op).__name__" with proper name formatting based on backend

if _conditionals.conditionals_enabled():
inputs, arg_inputs = _conditionals.apply_conditional_split_to_args(inputs, arg_inputs)
_conditionals.inject_implicit_scope_argument(op._schema, arg_inputs)

self._process_instance_name(arguments)
self._process_trace(arguments)
_process_arguments(op._schema, self._spec, arguments, type(op).__name__)

self._inputs = _process_inputs(op._schema, self._spec, inputs, type(op).__name__)
Expand Down Expand Up @@ -430,6 +415,23 @@ def _process_instance_name(self, arguments):
else:
self._name = "__" + type(self._op).__name__ + "_" + str(self._counter.id)

def _process_trace(self, arguments):
    """Record the Python call stack of this operator's definition in ``arguments``.

    Does nothing unless origin tracing is enabled via ``_dali_trace``. When it is
    enabled, the stack is extracted, split into four parallel lists and stored in
    ``arguments`` (mutated in place) under the ``_origin_stack_filename``,
    ``_origin_stack_lineno``, ``_origin_stack_name`` and ``_origin_stack_line`` keys.

    Parameters
    ----------
    arguments : dict
        Operator arguments; the four ``_origin_stack_*`` entries are added to it.
    """
    if not _dali_trace.is_tracing_enabled():
        return

    # Frames below the pipeline definition are dropped when a pipeline is current.
    current_pipeline = _Pipeline.current()
    skip_bottom = current_pipeline._definition_stack_frame if current_pipeline else 0

    # The fn wrapper tags its operators with "_fn", whereas the schema documentation and
    # the ops default use "fn"/"ops". Accept both spellings so that fn-originated
    # operators get the fn frame count. Operators that never set `_api` are treated as
    # "ops" instead of raising AttributeError.
    api = getattr(self._op, "_api", "ops")
    # NOTE(review): 7 and 3 presumably are the numbers of DALI-internal frames between the
    # user's call site and this point for the fn and ops APIs - confirm when the call chain
    # changes. extract_stack must stay called directly from this method; wrapping it in a
    # helper would add a frame and shift these counts.
    skip_top = 7 if api in ("fn", "_fn") else 3

    stack_summary = _dali_trace.extract_stack(
        skip_bottom_frames=skip_bottom, skip_top_frames=skip_top
    )
    filenames, linenos, names, lines = _dali_trace.separate_stack_summary(stack_summary)

    arguments["_origin_stack_filename"] = filenames
    arguments["_origin_stack_lineno"] = linenos
    arguments["_origin_stack_name"] = names
    arguments["_origin_stack_line"] = lines

def _generate_outputs(self):
pipeline = _Pipeline.current()
if pipeline is None and self._op.preserve:
Expand Down Expand Up @@ -528,6 +530,12 @@ def __init__(self, *, device="cpu", **kwargs):

self._init_args, self._call_args = _separate_kwargs(kwargs)

# It would be more generic to handle this in OperatorInstance, but we need it
# for error messages in the constructor of Operator (meta)class.
if "_api" not in self._init_args:
self._init_args["_api"] = "ops"
self._api = self._init_args["_api"]

for k in self._call_args.keys():
_check_arg_input(self._schema, type(self).__name__, k)

Expand Down
5 changes: 4 additions & 1 deletion dali/python/nvidia/dali/ops/_operators/python_function.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -40,6 +40,9 @@ def __init__(self, impl_name, function, num_outputs=1, device="cpu", **kwargs):

self._init_args, self._call_args = ops._separate_kwargs(kwargs)
self._name = self._init_args.pop("name", None)
if "_api" not in self._init_args:
self._init_args["_api"] = "ops"
self._api = self._init_args["_api"]

for key, value in self._init_args.items():
self._spec.AddArg(key, value)
Expand Down
5 changes: 4 additions & 1 deletion dali/python/nvidia/dali/ops/_operators/tfrecord.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -49,6 +49,9 @@ def __init__(self, path, index_path, features, **kwargs):
self._init_args.update({"path": self._path, "index_path": self._index_path})
self._name = self._init_args.pop("name", None)
self._preserve = self._init_args.get("preserve", False)
if "_api" not in self._init_args:
self._init_args["_api"] = "ops"
self._api = self._init_args["_api"]

for key, value in self._init_args.items():
self._spec.AddArg(key, value)
Expand Down
38 changes: 30 additions & 8 deletions dali/test/python/operator_1/test_operator_origin_trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,17 @@
from nvidia.dali.auto_aug import auto_augment, augmentations
from nvidia.dali.auto_aug.core import augmentation, Policy
from nvidia.dali._utils import dali_trace
from test_utils import load_test_operator_plugin
from nvidia.dali.pipeline import do_not_convert
from nose2.tools import params
from test_utils import load_test_operator_plugin


dali_trace.set_tracing(enabled=True)

load_test_operator_plugin()


op_mode = "dali"
op_mode = "dali.fn"
extracted_stacks = []
base_frame = 0

Expand Down Expand Up @@ -74,8 +75,10 @@ def capture_dali_traces(pipe_def):
def origin_trace():
"""Either return trace using test operator or capture it via Python API"""
global op_mode
if op_mode == "dali":
if op_mode == "dali.fn":
return fn.origin_trace_dump()
if op_mode == "dali.ops":
return ops.OriginTraceDump()() # Yup, super obvious __init__ + __call__
# elif op_mode == "python":
global extracted_stacks
global base_frame
Expand All @@ -88,10 +91,17 @@ def compare_traces(dali_tbs, python_tbs):
assert len(dali_tbs) == len(python_tbs)
for dali_tb, python_tb in zip(dali_tbs, python_tbs):
err = f"Comparing dali_tb:\n{dali_tb}\nvs python_tb:\n{python_tb}"
print(err)
assert dali_tb.startswith(python_tb), err


def test_trace_almost_trivial():
test_modes = ["dali.fn", "dali.ops"]


@params(*test_modes)
def test_trace_almost_trivial(test_mode):
global op_mode
op_mode = test_mode
def pipe():
return origin_trace()

Expand All @@ -102,7 +112,10 @@ def pipe():
compare_traces(dali_regular_tbs, python_tbs)


def test_trace_recursive():
@params(*test_modes)
def test_trace_recursive(test_mode):
global op_mode
op_mode = test_mode

def recursive_helper(n=2):
if n:
Expand Down Expand Up @@ -132,7 +145,10 @@ def pipe():
compare_traces(dali_cond_tbs, python_tbs)


def test_trace_recursive_do_not_convert():
@params(*test_modes)
def test_trace_recursive_do_not_convert(test_mode):
global op_mode
op_mode = test_mode

@do_not_convert
def recursive_helper(n=2):
Expand All @@ -153,7 +169,10 @@ def pipe():
compare_traces(dali_cond_tbs, python_tbs)


def test_trace_if():
@params(*test_modes)
def test_trace_if(test_mode):
global op_mode
op_mode = test_mode

# dali_trace.set_tracing(options={"filter_ag_frames": False})
# dali_trace.set_tracing(options={"remap_ag_frames": False})
Expand Down Expand Up @@ -184,7 +203,10 @@ def pipe():
# compare_traces(dali_cond_tbs, python_tbs)


def test_trace_auto_aug():
@params(*test_modes)
def test_trace_auto_aug(test_mode):
global op_mode
op_mode = test_mode

# TODO(klecki): AutoGraph loses mapping for the trace_aug and points to a transformed file
# Find out if we can somehow propagate that code mapping back. Do not convert helps with
Expand Down

0 comments on commit 7b09753

Please sign in to comment.