diff --git a/dali/python/nvidia/dali/_autograph/impl/api.py b/dali/python/nvidia/dali/_autograph/impl/api.py index 7e3d7795d5..5a1c460d3e 100644 --- a/dali/python/nvidia/dali/_autograph/impl/api.py +++ b/dali/python/nvidia/dali/_autograph/impl/api.py @@ -178,7 +178,7 @@ def __init__(self, name, operator_overload): self._extra_locals = None def get_transformed_name(self, node): - return self._name + super(PyToLib, self).get_transformed_name(node) + return self._name + "__" + super(PyToLib, self).get_transformed_name(node) def get_extra_locals(self): if self._extra_locals is None: @@ -931,7 +931,7 @@ def to_code(entity, recursive=True, experimental_optional_features=None): def initialize_autograph(operator_overload=hooks.OperatorBase(), converter_name="autograph", - filtered_library_modules=["nvidia.dali._autograph"]): + do_not_convert_modules=["nvidia.dali._autograph"]): """Initialize the AutoGraph with custom operator overloads. Parameters @@ -943,7 +943,7 @@ def initialize_autograph(operator_overload=hooks.OperatorBase(), converter_name : str, optional Name that is used to generated converted function names and as a fake module under which the AutoGraph is inserted into them, by default "autograph". - filtered_library_modules : list, optional + do_not_convert_modules : list, optional AutoGraph needs to filter the module that should not be converted. By default it will only filter out its own functions, provide the list of module that should be ignored. If the autograph is used under different name (for example included in the source as @@ -954,7 +954,6 @@ def initialize_autograph(operator_overload=hooks.OperatorBase(), raise RuntimeError("AutoGraph already initialized") _TRANSPILER = PyToLib(converter_name, operator_overload) # Add the name of the initialized library to know libraries to stop recursive conversion - do_not_convert_rules = tuple( - config.DoNotConvert(name) for name in filtered_library_modules) - config.CONVERSION_RULES = ((config.DoNotConvert(converter_name),) + - do_not_convert_rules + config.CONVERSION_RULES) + do_not_convert_rules = tuple(config.DoNotConvert(name) for name in do_not_convert_modules) + config.CONVERSION_RULES = ((config.DoNotConvert(converter_name),) + do_not_convert_rules + + config.CONVERSION_RULES) diff --git a/dali/python/nvidia/dali/_autograph/operators/variables.py b/dali/python/nvidia/dali/_autograph/operators/variables.py index 115c447017..d342169420 100644 --- a/dali/python/nvidia/dali/_autograph/operators/variables.py +++ b/dali/python/nvidia/dali/_autograph/operators/variables.py @@ -14,11 +14,15 @@ # ============================================================================== """Utilities used to capture Python idioms.""" +from nvidia.dali._autograph.utils import hooks + def ld(v): """Load variable operator.""" if isinstance(v, Undefined): return v.read() + if hooks._DISPATCH.detect_overload_ld(v): + return hooks._DISPATCH.ld(v) return v diff --git a/dali/python/nvidia/dali/_autograph/utils/hooks.py b/dali/python/nvidia/dali/_autograph/utils/hooks.py index 0d50895764..4c2c122f9a 100644 --- a/dali/python/nvidia/dali/_autograph/utils/hooks.py +++ b/dali/python/nvidia/dali/_autograph/utils/hooks.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -53,6 +53,12 @@ def detect_overload(self, object): """ return False + def detect_overload_ld(self, v): + return self.detect_overload(v) + + def ld(self, v): + pass + def detect_overload_if_exp(self, cond): return self.detect_overload(cond) diff --git a/dali/python/nvidia/dali/_conditionals.py b/dali/python/nvidia/dali/_conditionals.py new file mode 100644 index 0000000000..3a1fce1d77 --- /dev/null +++ b/dali/python/nvidia/dali/_conditionals.py @@ -0,0 +1,515 @@ +# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""This module contains the implementation of DALI if statement. + +It initializes AutoGraph with the DaliOperatorOverload that provides the overload for the if_stmt +and adjust the filtered modules so DALI code is not converted. + +The if_stmt provides access to both branches as callables and the set_state/get_state functions +that allows to capture and adjust all symbols modified within those branches. This allows to +checkpoint the state and visit the code of both branches. + +if_stmt highlights which state variables are considered the outputs of the if/else pair - we can +use the state captured after visiting if and else branches and produce fn._conditional.merge +nodes for all of them. + +When visiting the if/else scopes, we are tracking tha path that we took and the predicates that +were used via the _ConditionStack. As it is not easy to detect which state variables would be +consumed as inputs to DALI operators, we inject additional code to the operator function. +Every time a DataNode is consumed, we look up in which scope it was produced and travel the +path from that point to the current scope in the _ConditionStack, applying necessary splits. +All the return values are registered to the current scope for further lookups. +""" + +from nvidia.dali import _autograph +from nvidia.dali.data_node import DataNode as _DataNode +from nvidia.dali import fn + +from nvidia.dali._autograph.utils import ag_logging as logging +from nvidia.dali._autograph.operators import variables + +from contextlib import contextmanager + +from enum import Enum + + +def _data_node_repr(data_node): + return f"DataNode(name={data_node.name}, device={data_node.device}, source={data_node.source})" + + +class _Branch(Enum): + TrueBranch = 0 + FalseBranch = 1 + Undefined = 2 + + +class _StackEntry: + """Information about 1 nesting level of if/else statement. + + Keeps the current branch (if we entered if/else branch) and the data nodes that were + produced in their scopes. Keeps the mapping of DataNodes produced in higher scopes that + were already split for use in this scope. + """ + + def __init__(self, predicate): + self.predicate = predicate + self.branch = _Branch.Undefined + self.splits = {} + self.produced_true = set() + self.produced_false = set() + # The produced_special handles the case of producing something visible on the same nesting + # level, but not in one of the branches and is used by merge code. + self.produced_special = set() + + @property + def produced(self): + """Access the set of hashes of DataNodes produced in the scope of currently selected branch. + """ + if self.branch == _Branch.TrueBranch: + return self.produced_true + elif self.branch == _Branch.FalseBranch: + return self.produced_false + else: + return self.produced_special | self.produced_true | self.produced_false + + @produced.setter + def produced(self, value): + """Access the set of hashes of DataNodes produced in the scope of currently selected branch. + """ + if self.branch == _Branch.TrueBranch: + self.produced_true = value + elif self.branch == _Branch.FalseBranch: + self.produced_false = value + else: + self.produced_special = value + + def add_produced(self, data_node): + """Add the DataNode or DataNodes to produced in the scope of currently selected branch.""" + if isinstance(data_node, _DataNode): + self.produced |= {_data_node_repr(data_node)} + elif isinstance(data_node, list): + if not data_node: + return + if isinstance(data_node[0], _DataNode): + self.produced |= set(_data_node_repr(dn) for dn in data_node) + elif isinstance(data_node[0], list): + flat_list = [item for sublist in data_node for item in sublist] + self.add_produced(flat_list) + else: + raise ValueError(f"Unexpected operator result to register: {data_node}. Expected up to" + " two-level nesting of DataNode.") + + def add_split(self, source_data_node, producer_node, true_node, false_node): + """Register the outputs of split node that were produced from the source_data_node + (or its descendant on this scope, the shortcut node). + + Parameters + ---------- + source_data_node : DataNode + Original source node that was looked up, record for faster consecutive lookups + producer_node : DataNode + The closest node on the path from source_data_node to this split + true_node : DataNode + True branch split + false_node : DataNode + False branch split + """ + self.splits[_data_node_repr(source_data_node)] = (true_node, false_node) + # Record the direct preceding node as the producer: + self.splits[_data_node_repr(producer_node)] = (true_node, false_node) + self.produced_true |= {_data_node_repr(true_node)} + self.produced_false |= {_data_node_repr(false_node)} + + def __str__(self): + return (f"StackEntry: pred={self.predicate}, branch={self.branch}, splits={self.splits}," + f" produced={self.produced}") + + def has(self, data_node): + """Check if this DataNode was either produced in this scope or already split for this scope. + """ + if _data_node_repr(data_node) in self.produced: + return True + elif _data_node_repr(data_node) in self.splits: + return True + else: + return False + + def get(self, data_node): + """Return the `data_node` if it was produced in this scope, or the appropriate split node + that was created for accessing the `data_node` in this scope. + """ + assert self.has(data_node) + if _data_node_repr(data_node) in self.produced: + return data_node + else: + assert self.branch in {_Branch.TrueBranch, _Branch.FalseBranch} + return self.splits[_data_node_repr(data_node)][self.branch.value] + + +class _ConditionStack: + """Tracks the current if/else scope with the path that we took. Captures the used and produced + data nodes, applying the necessary splits based on the scope level where they were produced + and where they are used. + """ + + def __init__(self): + self._stack = [_StackEntry(None)] + self._is_registration_allowed = True + + def push_predicate(self, predicate): + """Add next level of if/else scope that is predicated with the `predicate`. + The user might have provided a predicate from a scope of higher level, which means + that `predicate` might be subject to additional slicing. Apply that slicing and return + the actual predicate that will be used for slicing when entering this scope. + + The situation will happen for example in a case like this, where both predicates are + produced in global scope: + + pred_0 = ... + pred_1 = ... + + if pred_0: # push_pred(pred_0) -> returns pred_0 + if pred_1: # push_pred(pred_1) -> + # -> returns fn._conditional.slice(pred_1, predicate=pred_0) + + Parameters + ---------- + predicate : DataNode + Predicate guarding this scope. + + Returns + ------- + DataNode + Actual predicate after applying necessary slices to use it in this scope. + """ + new_pred = self.preprocess_input(predicate) + new_entry = _StackEntry(new_pred) + self._stack.append(new_entry) + return new_pred + + def top(self): + """Get the top scope in the stack""" + return self._stack[-1] + + def pop(self): + """Remove the top scope from the stack""" + result = self._stack.pop() + return result + + def stack_depth(self): + """Get the depth of the stack. Note, that by default there is at least one element + - the global scope.""" + return len(self._stack) + + def _find_closest(self, data_node): + """Find the closest scope level in the stack where we can access this node as produced + (or the split of this node closest to us). + """ + for level in range(self.stack_depth() - 1, -1, -1): + if self._stack[level].has(data_node): + return level + raise ValueError(f"{data_node} was not produced within this trace.") + + def _realize_split(self, data_node, stack_level): + """The data_node was produced (or last accessed as via split) in scope earlier than the + current one, traverse the scopes between that level and current one, and insert split nodes. + + Parameters + ---------- + data_node : DataNode + The data node that we want to use in the current scope. + stack_level : int + Stack level where the data_node was last "seen". + + Returns + ------- + DataNode + New node that can be used in current branch and scope. + """ + assert 0 <= stack_level and stack_level < self.stack_depth() - 1 + produced_data_node = self._stack[stack_level].get(data_node) + bottom = self._stack[:stack_level + 1] + top = self._stack[stack_level + 1:] + self._stack = bottom + while top: + current_entry = top.pop(0) + predicate = current_entry.predicate + + # Do not automatically register the outputs in the current scope, we track them below + # in their respective branches. + logging.log(9, (f"{self._indent()}[IF] Inserting split" + f" at {self.stack_depth() -1}:" + f" split({produced_data_node}, predicate={predicate}.")) + self._is_registration_allowed = False + true_node, false_node = fn._conditional.split(produced_data_node, predicate=predicate) + self._is_registration_allowed = True + + # Record the result of splitting the `data_node` that we are trying to look up + # (short-cut for consecutive lookups) + current_entry.add_split(data_node, produced_data_node, true_node, false_node) + if current_entry.branch == _Branch.TrueBranch: + produced_data_node = true_node + else: + produced_data_node = false_node + self._stack.append(current_entry) + return produced_data_node + + def preprocess_input(self, data_node): + """Process the DataNode that is an input to an operator call. Detect if the DataNode was + produced on the same nesting level. If not, split accordingly to the stack of the previous + conditions. Caches the previously processed DataNodes to not do repeated splitting. + """ + stack_level = self._find_closest(data_node) + logging.log(8, (f"{self._indent()}[IF/Input] {data_node} accessed at level" + f" {self.stack_depth() - 1} found at {stack_level}.")) + # We already have it cached or produced in this scope. + if stack_level == self.stack_depth() - 1: + return self.top().get(data_node) + # otherwise, we need to fill in the splits. + return self._realize_split(data_node, stack_level) + + def register_data_nodes(self, data_node, global_scope=False): + """Register the data nodes as produced in current scope, otherwise if `global_scope` is True + put them in the outermost scope. + """ + if not self._is_registration_allowed: + return + logging.log(8, (f"{self._indent()}[IF/Register] {data_node} at {self.stack_depth() -1}")) + scope = self._stack[0] if global_scope else self.top() + scope.add_produced(data_node) + + def track_true_branch(self): + """Mark `if` (true) branch as current scope.""" + self.top().branch = _Branch.TrueBranch + + def track_false_branch(self): + """Mark `else` (false) branch as current scope.""" + self.top().branch = _Branch.FalseBranch + + def no_branch(self): + """Mark no branch being tracked, the scope "level" stays related to the same if/else + statement.""" + self.top().branch = _Branch.Undefined + + def track_merge(self, split_predicate): + """Enter the merge section of the if/else statement. It adds the corresponding + split_predicate to the nodes visible as produced in the current scope, so all data nodes + are directly accessible in this scope when looked up by the merge operator. + We don't care about removing it as it's the last thing happening in that statement. + """ + self.no_branch() + self.top().add_produced(split_predicate) + + def _indent(self): + """Helper for indenting the log messages to resemble visited scopes""" + return ' ' * (self.stack_depth() - 1) + + +@contextmanager +def _cond_manager(predicate): + actual_predicate = this_condition_stack().push_predicate(predicate) + logging.log(7, (f"{this_condition_stack()._indent()}[IF]: {predicate}" + f" at {this_condition_stack().stack_depth() - 1}")) + # Return it so we can use it in merge + yield actual_predicate + this_condition_stack().pop() + + +@contextmanager +def _cond_true(): + this_condition_stack().track_true_branch() + logging.log(7, (f"{this_condition_stack()._indent()}[IF]: `if` branch" + f" at {this_condition_stack().stack_depth() - 1}")) + yield + this_condition_stack().no_branch() + + +@contextmanager +def _cond_false(): + this_condition_stack().track_false_branch() + logging.log(7, (f"{this_condition_stack()._indent()}[IF]: `else` branch" + f" at {this_condition_stack().stack_depth() - 1}")) + yield + this_condition_stack().no_branch() + + +@contextmanager +def _cond_merge(split_predicate): + this_condition_stack().track_merge(split_predicate) + yield + this_condition_stack().no_branch() + + +def conditionals_enabled(): + """Check (within a Pipeline context) if the conditionals are enabled. + """ + from nvidia.dali._debug_mode import _PipelineDebug + current_pipeline = _PipelineDebug.current() + enabled = getattr(current_pipeline, '_conditionals_enabled', False) + return enabled + + +def this_condition_stack(): + """Return the condition stack of current Pipeline""" + from nvidia.dali._debug_mode import _PipelineDebug + current_pipeline = _PipelineDebug.current() + if current_pipeline._condition_stack is None: + raise ValueError("Cannot access current condition stack when conditionals" + " were not enabled for a given pipeline.") + return current_pipeline._condition_stack + + +def register_data_nodes(data_node, inputs=[]): + """Register the outputs of the operator as produced in the scope of the current conditional + branch. + + Parameters + ---------- + data_node : DataNode or a list/tuple of DataNode + The output of the operator to be registered. + inputs : List of DataNode + Optional list of inputs of the operator whose outputs we are registering. + If there were no inputs, the outputs are considered as produced in global scope. + """ + + any_input = any(isinstance(input, _DataNode) for input in inputs) + # TODO(klecki): In theory we have two approaches for inputless operators. Here we insert their + # outputs to top level and let the automatic splitting handle the situation. Otherwise we could + # pass the scope information and batch_size within that scope to all operators that are invoked + # within that scope. + this_condition_stack().register_data_nodes(data_node, global_scope=not any_input) + + +def apply_conditional_split(input): + """Preprocess the DataNode to obtain correctly split batch for the current if scope.""" + return this_condition_stack().preprocess_input(input) + + +def apply_conditional_split_to_branch_outputs(branch_outputs, promote_constants=True): + """Apply splitting to the branch outputs. This may be necessary for DataNodes that are + branch outputs but were not touched in that branch (for example that branch is no-op). + + Parameters + ---------- + branch_outputs : tuple of DataNode + Outputs of the branch + promote_constants : bool, optional + Whether to promote constants to cpu-based Constant op, by default True + + Returns + ------- + tuple of DataNode + """ + from nvidia.dali.types import Constant + inputs_bkp = list(branch_outputs) + for i, input in enumerate(branch_outputs): + if isinstance(input, _DataNode): + inputs_bkp[i] = apply_conditional_split(input) + elif promote_constants: + # We assume that any return from the branch must be merged, so constants are promoted + # to batches using constant op, and thus can be used in merge. + constant_node = Constant(input, device="cpu") + register_data_nodes(constant_node) + inputs_bkp[i] = apply_conditional_split(constant_node) + return tuple(inputs_bkp) + + +def apply_conditional_split_to_args(inputs, kwargs): + """Preprocess the inputs and kwargs of the operator to obtain correctly split inputs for the + current if scope.""" + inputs = apply_conditional_split_to_branch_outputs(inputs, False) + for key, arg in kwargs.items(): + if isinstance(arg, _DataNode): + kwargs[key] = apply_conditional_split(arg) + return inputs, kwargs + + +def _verify_branch_outputs(outputs, symbol_names, branch_name): + """Verifies variables output by a conditional branch for consistency.""" + common_explanation = ( + "Encountered inconsistent outputs out of the `if/else` control flow statement." + " Variables need to be initialized in every code path (both `if` branches).") + for name, output in zip(symbol_names, outputs): + if isinstance(output, variables.Undefined): + raise RuntimeError(f"{common_explanation} Variable '{name}' must also be initialized" + f" in the `{branch_name}` branch.") + if isinstance(output, variables.UndefinedReturnValue): + raise RuntimeError(f"{common_explanation} The `{branch_name}` branch must also have" + f" a return statement.") + + +class DaliOperatorOverload(_autograph.OperatorBase): + + def detect_overload_ld(self, v): + return isinstance(v, _DataNode) + + def ld(self, v): + branch_v = apply_conditional_split(v) + return branch_v + + def detect_overload_if_stmt(self, cond): + return isinstance(cond, _DataNode) + + def if_stmt(self, cond, body, orelse, get_state, set_state, symbol_names, nouts): + # Initial checkpoint before if + init_state = get_state() + with _cond_manager(cond) as split_predicate: + # Set the state for the body inputs, execute the body and collect the outputs. + # Verify if all outputs are initialized within the branch, split the outputs if they + # were just passed through, so they can be merged with the other branch. + with _cond_true(): + body() + + body_state = get_state() + _verify_branch_outputs(body_state, symbol_names, "if") + body_outputs = body_state[:nouts] + body_outputs = apply_conditional_split_to_branch_outputs(body_outputs) + + # Do the same for else block. + set_state(init_state) + with _cond_false(): + orelse() + + orelse_state = get_state() + _verify_branch_outputs(orelse_state, symbol_names, "else") + orelse_outputs = orelse_state[:nouts] + orelse_outputs = apply_conditional_split_to_branch_outputs(orelse_outputs) + + # Build the state that is the combination of both branches. Only the actual outputs + # should be affected by the if/else blocks, the rest can be reused from-before split. + output_values = [] + # We execute the merge _after_ both branches, and pretend for a moment, that it + # can see those values produced in child scopes. + with _cond_merge(split_predicate): + for new_body_val, new_orelse_val in zip(body_outputs, orelse_outputs): + logging.log(9, (f"{this_condition_stack()._indent()}[IF] Inserting merge" + f" at {this_condition_stack().stack_depth() -1}:" + f" merge({new_body_val}, {new_orelse_val}, predicate=" + f"{split_predicate}.")) + merged = fn._conditional.merge(new_body_val, new_orelse_val, + predicate=split_predicate) + output_values.append(merged) + + # Register the new nodes outside of the conditional scope, they will be used in subsequent + # calls. + this_condition_stack().register_data_nodes(output_values, False) + # No point in propagating the split/merged values that won't be read later. + output_values += init_state[nouts:] + set_state(output_values) + + +_OVERLOADS = DaliOperatorOverload() + +_autograph.initialize_autograph(_OVERLOADS, + do_not_convert_modules=["nvidia.dali._autograph", "nvidia.dali"]) diff --git a/dali/python/nvidia/dali/external_source.py b/dali/python/nvidia/dali/external_source.py index 0ed228e896..fe1d0d75d2 100644 --- a/dali/python/nvidia/dali/external_source.py +++ b/dali/python/nvidia/dali/external_source.py @@ -1,4 +1,4 @@ -# Copyright (c) 2020-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2020-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -849,6 +849,7 @@ def external_source(source=None, num_outputs=None, *, cycle=None, name=None, dev """ from nvidia.dali._debug_mode import _PipelineDebug + from nvidia.dali import _conditionals def _external_source(source=None, num_outputs=None, *, cycle=None, name=None, device="cpu", layout=None, dtype=None, ndim=None, cuda_stream=None, use_copy_kernel=None, @@ -875,9 +876,12 @@ def _external_source(source=None, num_outputs=None, *, cycle=None, name=None, de source=source, num_outputs=num_outputs, cycle=cycle, name=name, device=device, layout=layout, batch=batch, **kwargs) else: - return _external_source(source, num_outputs, cycle=cycle, name=name, device=device, - layout=layout, dtype=dtype, ndim=ndim, cuda_stream=cuda_stream, - use_copy_kernel=use_copy_kernel, batch=batch, **kwargs) + result = _external_source(source, num_outputs, cycle=cycle, name=name, device=device, + layout=layout, dtype=dtype, ndim=ndim, cuda_stream=cuda_stream, + use_copy_kernel=use_copy_kernel, batch=batch, **kwargs) + if _conditionals.conditionals_enabled(): + _conditionals.register_data_nodes(result) + return result external_source.__doc__ += ExternalSource._args_doc diff --git a/dali/python/nvidia/dali/ops.py b/dali/python/nvidia/dali/ops.py index ea2deba46c..9b18bcb8ec 100644 --- a/dali/python/nvidia/dali/ops.py +++ b/dali/python/nvidia/dali/ops.py @@ -1,4 +1,4 @@ -# Copyright (c) 2017-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2017-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -32,6 +32,7 @@ CUDAStream as _CUDAStream, \ ScalarConstant as _ScalarConstant, \ Constant as _Constant +from nvidia.dali import _conditionals cupy = None @@ -418,6 +419,9 @@ def __init__(self, inputs, op, **kwargs): inputs[i] = _instantiate_constant_node(default_input_device, inp) inputs = tuple(inputs) + if _conditionals.conditionals_enabled(): + inputs, kwargs = _conditionals.apply_conditional_split_to_args(inputs, kwargs) + self._inputs = inputs spec_args, kwargs = _separate_kwargs(kwargs) @@ -656,11 +660,17 @@ def __call__(self, *inputs, **kwargs): # If we don't have multiple input sets, flatten the result if len(op_instances) == 1: - return op_instances[0].unwrapped_outputs - outputs = [] - for op in op_instances: - outputs.append(op.outputs) - return self._repack_output_sets(outputs) + result = op_instances[0].unwrapped_outputs + else: + outputs = [] + for op in op_instances: + outputs.append(op.outputs) + result = self._repack_output_sets(outputs) + if _conditionals.conditionals_enabled(): + if len(op_instances) != 1: + raise ValueError("Multiple input sets are not supported with conditionals.") + _conditionals.register_data_nodes(result, input_sets[0]) + return result # Check if any of inputs is a list def _detect_multiple_input_sets(self, inputs): @@ -1326,8 +1336,12 @@ def _arithm_op(name, *inputs): dev_inputs = list(edge.gpu() for edge in edges) else: dev_inputs = edges + # Call it immediately - return op(*dev_inputs) + result = op(*dev_inputs) + if _conditionals.conditionals_enabled(): + _conditionals.register_data_nodes(result, dev_inputs) + return result def cpu_ops(): diff --git a/dali/python/nvidia/dali/pipeline.py b/dali/python/nvidia/dali/pipeline.py index 4041376f4b..44c0f16553 100644 --- a/dali/python/nvidia/dali/pipeline.py +++ b/dali/python/nvidia/dali/pipeline.py @@ -19,6 +19,7 @@ from nvidia.dali import internal from nvidia.dali._multiproc.pool import WorkerPool from nvidia.dali import pickling as dali_pickle +from nvidia.dali import _conditionals from threading import local as tls from . import data_node as _data_node import functools @@ -252,6 +253,8 @@ def __init__(self, self._gpu_queue_size = prefetch_queue_depth else: raise TypeError("Expected prefetch_queue_depth to be either int or Dict[int, int]") + self._conditionals_enabled = False + self._condition_stack = None # Assign and validate output_dtype if isinstance(output_dtype, (list, tuple)): @@ -1399,6 +1402,10 @@ def _discriminate_args(func, **func_kwargs): if 'debug' not in func_argspec.args and 'debug' not in func_argspec.kwonlyargs: func_kwargs.pop('debug', False) + if ('enable_conditionals' not in func_argspec.args + and 'enable_conditionals' not in func_argspec.kwonlyargs): + func_kwargs.pop('enable_conditionals', False) + ctor_args = {} fn_args = {} @@ -1526,20 +1533,41 @@ def create_pipeline(*args, **kwargs): def _pipeline_def_experimental(fn=None, **pipeline_kwargs): from nvidia.dali._debug_mode import _PipelineDebug pipeline_debug = pipeline_kwargs.pop('debug', False) + pipeline_conditionals = pipeline_kwargs.pop('enable_conditionals', False) def actual_decorator(func): @functools.wraps(func) def create_pipeline(*args, **kwargs): debug_mode_on = kwargs.get('debug', pipeline_debug) - ctor_args, fn_kwargs = _discriminate_args(func, **kwargs) + conditionals_on = kwargs.get('enable_conditionals', pipeline_conditionals) + if conditionals_on: + pipe_func = _conditionals._autograph.to_graph(func) + else: + pipe_func = func + ctor_args, fn_kwargs = _discriminate_args(pipe_func, **kwargs) pipeline_args = {**pipeline_kwargs, **ctor_args} # Merge and overwrite dict if debug_mode_on: - pipe = _PipelineDebug(functools.partial(func, *args, **fn_kwargs), **pipeline_args) + # TODO(klecki): cross-validate conditionals with eager mode + if conditionals_on: + raise NotImplementedError("Conditionals are not supported in debug mode yet.") + pipe = _PipelineDebug(functools.partial(pipe_func, *args, **fn_kwargs), + **pipeline_args) else: pipe = Pipeline(**pipeline_args) + if conditionals_on: + pipe._conditionals_enabled = True + pipe._condition_stack = _conditionals._ConditionStack() with pipe: - pipe_outputs = func(*args, **fn_kwargs) + if conditionals_on: + # Add all parameters to the pipeline as "know" nodes in the top scope. + for arg in args: + if isinstance(arg, DataNode): + _conditionals.register_data_nodes(arg) + for arg in fn_kwargs: + if isinstance(arg, DataNode): + _conditionals.register_data_nodes(arg) + pipe_outputs = pipe_func(*args, **fn_kwargs) if isinstance(pipe_outputs, tuple): po = pipe_outputs elif pipe_outputs is None: diff --git a/dali/test/python/autograph/impl/test_api.py b/dali/test/python/autograph/impl/test_api.py index 141ce74907..31fa4a79e2 100644 --- a/dali/test/python/autograph/impl/test_api.py +++ b/dali/test/python/autograph/impl/test_api.py @@ -47,8 +47,6 @@ # from nvidia.dali._autograph.utils.all_utils import custom_constant -api.initialize_autograph() - global_n = 2 DEFAULT_RECURSIVE = converter.ConversionOptions(recursive=True) @@ -66,6 +64,18 @@ def custom_constant(val): class ApiTest(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls._transpiler_bkp = api._TRANSPILER + cls._conversion_rules_bkp = api.config.CONVERSION_RULES + api._TRANSPILER = None + api.initialize_autograph() + + @classmethod + def tearDownClass(cls): + api._TRANSPILER = cls._transpiler_bkp + api.config.CONVERSION_RULES = cls._conversion_rules_bkp + def evaluate(self, x): return x diff --git a/dali/test/python/conditionals/test_pipeline_conditionals.py b/dali/test/python/conditionals/test_pipeline_conditionals.py new file mode 100644 index 0000000000..1e0ec0f1f7 --- /dev/null +++ b/dali/test/python/conditionals/test_pipeline_conditionals.py @@ -0,0 +1,574 @@ +# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from nvidia.dali.pipeline import pipeline_def, experimental +import nvidia.dali.fn as fn +import nvidia.dali.types as types +from nvidia.dali.types import SampleInfo +from nvidia.dali import _conditionals +from nvidia.dali.data_node import DataNode + +import numpy as np +import os + +from test_utils import check_batch, compare_pipelines +from nose_utils import assert_raises +from test_utils import get_dali_extra_path +from nose2.tools import params + +import itertools + + +def test_condition_stack(): + test_stack = _conditionals._ConditionStack() + pred_node = DataNode("PredOp") + pred_nested = DataNode("PredOp2") + some_op = DataNode("SomeOp") + some_nested_op = DataNode("SomeOp2") + + # model: + # if pred_node: + # some_op() + # if pred_nested: + # some_other_op() + + test_stack.register_data_nodes(pred_node) + test_stack.register_data_nodes(pred_nested) + # Both visible in global scope + assert test_stack._find_closest(pred_node) == 0 + assert test_stack._find_closest(pred_nested) == 0 + # First predicate, no splitting required, as this is the first nesting level + first_level = test_stack.push_predicate(pred_node) + assert _conditionals._data_node_repr(pred_node) == _conditionals._data_node_repr(first_level) + + test_stack.track_true_branch() + test_stack.register_data_nodes(some_op) + assert test_stack._find_closest(some_op) == 1 + + assert test_stack._find_closest(pred_nested) == 0 + assert test_stack.stack_depth() == 2 + + true_split = test_stack._realize_split(pred_nested, 0) + second_level = test_stack.push_predicate(pred_nested) + # Second predicate require splitting + assert _conditionals._data_node_repr(true_split) == _conditionals._data_node_repr(second_level) + test_stack.track_true_branch() + test_stack.register_data_nodes(some_nested_op) + assert test_stack._find_closest(some_nested_op) == 2 + + # It's already on this level + assert len(test_stack.top().produced) == 1 + preprocessed = test_stack.preprocess_input(some_nested_op) + assert (_conditionals._data_node_repr(some_nested_op) + == _conditionals._data_node_repr(preprocessed)) + assert len(test_stack.top().produced) == 1 + + # This one is not + assert len(test_stack.top().produced) == 1 + preprocessed = test_stack.preprocess_input(some_op) + assert _conditionals._data_node_repr(some_op) != _conditionals._data_node_repr(some_nested_op) + assert len(test_stack.top().produced) == 2 + + test_stack.pop() + test_stack.pop() + assert len(test_stack.top().produced) == 2 + + +rng = np.random.default_rng() + +# Predicates +num_gens = [ + lambda x: np.int32(x.idx_in_batch - 3), lambda x: np.int32(-1 + if x.idx_in_batch % 2 == 0 else 1), + lambda x: np.int32((x.idx_in_batch % 3 == 0) - 1), lambda _: np.int32(1), lambda _: np.int32(0), + lambda _: np.int32(-1), + lambda _: rng.choice([np.int32(-2), np.int32(0), np.int32(2)]) +] + +pred_gens = [ + lambda x: np.array(x.idx_in_batch < 3), lambda x: np.array(x.idx_in_batch % 2 == 0), + lambda x: np.array(x.idx_in_batch % 3 == 0), lambda x: np.array( + (x.idx_in_batch + (x.iteration % 2)) % 2 == 0), lambda _: np.array(False), + lambda _: rng.choice([np.array(True), np.array(False)]) +] + +input_gens = [lambda x: np.array(0), lambda x: np.array(x.idx_in_epoch)] + + +def generic_execute(function, input_gen_list, optional_params=None): + """Given a Python `function` (taking some positional arguments) and a list of sample generators, + execute the function twice on batches of data generated by the generator and compare to test + the conditional execution. + + The function is executed both as a: + * DALI Pipeline with conditional execution enabled. External source nodes are passed + as positional parameters and fed with the generated batches. + * Regular function, where we pass the batches sample-by-sample to build output batches. + + Parameters + ---------- + function : callable + function used for testing + input_gen_list : list of sample generators + Possibly a stateful generator + optional_params : list of dictionaries, optional + Optional kwargs for external source associated with given input position, by default None + """ + if optional_params is None: + optional_params = [{} for _ in input_gen_list] + assert len(input_gen_list) == len(optional_params), ("Optional param should be provided for" + " every external source node.") + bs = 10 + iters = 5 + kwargs = { + "batch_size": bs, + "num_threads": 4, + "device_id": 0, + "prefetch_queue_depth": 1 # so that it's easier to use external source + } + + # Prepare external source nodes with placeholder names, convert + es_inputs = [ + fn.external_source(name=f"input_{i}", **params) for i, params in enumerate(optional_params) + ] + + pipeline_definition = experimental.pipeline_def(enable_conditionals=True)(function) + + def gen_batch(generator, bs, iter): + return [generator(SampleInfo(bs * iter + i, i, iter, 0)) for i in range(bs)] + + pipe = pipeline_definition(*es_inputs, **kwargs) + pipe.build() + + for iter in range(iters): + batches = [gen_batch(gen, bs, iter) for gen in input_gen_list] + for i, batch in enumerate(batches): + pipe.feed_input(f"input_{i}", batch) + + outputs = pipe.run() + + baseline_outputs = [] + for inputs_i in zip(*batches): + outputs_i = function(*inputs_i) + # make it a tad more generic + if not isinstance(outputs_i, tuple): + outputs_i = outputs_i, + baseline_outputs.append(outputs_i) + + # Repack list of tuples into tuple of lists. + baseline_outputs = tuple(zip(*baseline_outputs)) + # make the elements actually lists: + baseline_outputs = (list(baseline) for baseline in baseline_outputs) + + for out, baseline in zip(outputs, baseline_outputs): + check_batch(out, baseline, bs) + + +# Tests below are ported from dali/test/python/autograph/converters/test_control_flow.py + + +@params(*num_gens) +def test_basic(num_gen): + + def f(n): + a = np.int32(0) + b = np.int32(0) + if n > 0: + a = -n + else: + b = 2 * n + return a, b + + generic_execute(f, [num_gen]) + + +@params(*num_gens) +def test_complex_outputs(num_gen): + + class DataClass(object): + + def __init__(self, a, b): + self.a = a + self.b = b + + def f(n, obj): + obj.a = np.int32(0) + obj.b = np.int32(0) + if n > 0: + obj.a = -n + else: + obj.b = 2 * n + return obj.a, obj.b + + generic_execute(lambda input: f(input, DataClass(np.int32(0), np.int32(0))), [num_gen]) + + +@params(*num_gens) +def test_single_output(num_gen): + + def f(n): + if n > 0: + n = -n + return n + + generic_execute(f, [num_gen]) + + +@params(*num_gens) +def test_unbalanced(num_gen): + + def f(n): + if n > 0: + n = np.int32(3) + return n + + generic_execute(f, [num_gen]) + + +@params(*num_gens) +def test_local_var(num_gen): + + def f(n): + if n > 0: + b = np.int32(4) + n = b + 1 + return n + + generic_execute(f, [num_gen]) + + +@params(*num_gens) +def test_local_remains_local(num_gen): + + def f(n): + if n > 0: + b = np.int32(4) + n = b + 1 + return n + + generic_execute(f, [num_gen]) + + +@params(*num_gens) +def test_no_outputs(num_gen): + + def f(n): + if n > 0: + b = np.int32(4) # pylint:disable=unused-variable # noqa: F841 + return n + + generic_execute(f, [num_gen]) + + +@params(*num_gens) +def test_created_outputs(num_gen): + + def f(i): + if i == 0: + result = i - 1 + else: + result = i + 1 + return result + + generic_execute(f, [num_gen]) + + +# Simple cases, where we produce new data node in the branch + + +@params(*num_gens) +def test_one_branch_new_node(num_gen): + + def f(n): + result = n * 0 + if n >= 0: + result = n + 10 + return result + + generic_execute(f, [num_gen]) + + +@params(*num_gens) +def test_both_branches_new_node(num_gen): + + def f(n): + if n >= 0: + result = n + 10 + else: + result = n - 10 + return result + + generic_execute(f, [num_gen]) + + +@params(*num_gens) +def test_chain_branches_new_node(num_gen): + + def f(n): + if n == 0: + result = n + 10 + elif n > 0: + result = n + 100 + else: + result = n - 50 + return result + + generic_execute(f, [num_gen]) + + +# Cases where we do only assignment and no new node is produced within branch, so we need to +# detect usage in other way than looking at operator inputs + + +@params(*pred_gens) +def test_one_branch_only_assign(pred): + + def f(pred, base, true_branch): + result = base + if pred: + result = true_branch + return result + + generic_execute(f, [pred, lambda _: np.int32(42), lambda _: np.int32(7)]) + + +@params(*pred_gens) +def test_both_branches_only_assign(pred): + + def f(pred, true_branch, false_branch): + if pred: + result = true_branch + else: + result = false_branch + return result + + generic_execute(f, [pred, lambda _: np.int32(6), lambda _: np.int32(9)]) + + +@params(*itertools.product(pred_gens, pred_gens)) +def test_chain_branches_only_assign(pred_1, pred_2): + + def f(pred_1, pred_2, true_branch, elif_branch, else_branch): + if pred_1: + result = true_branch + elif pred_2: + result = elif_branch + else: + result = else_branch + return result + + generic_execute( + f, [pred_1, pred_2, lambda _: np.int32(42), lambda _: np.int32(6), lambda _: np.int32(9)]) + + +# More ifs - nesting and sequences + + +@params(*itertools.product(["cpu", "gpu"], input_gens, pred_gens, pred_gens)) +def test_consecutive(dev, input, pred_0, pred_1): + + def f(input, pred_0, pred_1): + if pred_0: + output = input + 1 + else: + output = input + 2 + + if pred_1: + output2 = output + 3 + else: + output2 = output + 4 + return output, output2 + + generic_execute(f, [input, pred_0, pred_1], [{"device": dev}, {}, {}]) + + +@params(*itertools.product(["cpu", "gpu"], input_gens, pred_gens, pred_gens)) +def test_nested(dev, input, pred_0, pred_1): + + def f(input, pred_0, pred_1): + if pred_0: + if pred_1: + output = input + 10 + else: + output = input + 200 + else: + output = input + 3000 + return output + + generic_execute(f, [input, pred_0, pred_1], [{"device": dev}, {}, {}]) + + +@params(*itertools.product(["cpu", "gpu"], input_gens, pred_gens, pred_gens)) +def test_nested_with_assignment(dev, input, pred_0, pred_1): + + def f(input, pred_0, pred_1): + to_assign = input * -5 + if pred_0: + if pred_1: + output = input + 10 + else: + output = to_assign + else: + output = input + 3000 + return output + + generic_execute(f, [input, pred_0, pred_1], [{"device": dev}, {}, {}]) + + +@params(*itertools.product(["cpu", "gpu"], input_gens, num_gens)) +def test_multiple_nests(dev, input, num): + + def f(input, num): + if num == -2: + if num == -1: + if num == 0: + if num == 1: + if num == 2: + if num > 3: + output = input - 100 + else: + output = input + 100 + else: + output = input - 200 + else: + output = input + 400 + else: + output = input - 800 + else: + output = input + 1600 + else: + output = input - 3200 + return output + + generic_execute(f, [input, num], [{"device": dev}, {}]) + + +# Compare pure Split/Merge operators with if statement +def test_against_split_merge(): + test_data_root = get_dali_extra_path() + caffe_db_folder = os.path.join(test_data_root, 'db', 'lmdb') + + bs = 10 + iters = 5 + kwargs = {"batch_size": bs, "num_threads": 4, "device_id": 0, "seed": 42} + + @pipeline_def(**kwargs) + def regular_pipe(): + encoded, _ = fn.readers.caffe(path=caffe_db_folder) + decoded = fn.decoders.image(encoded, device="mixed") + pred = fn.random.coin_flip(dtype=types.DALIDataType.BOOL) + true, false = fn._conditional.split(decoded, predicate=pred) + output_true = fn.rotate(true, angle=30) + output_false = fn.flip(false, horizontal=True) + return fn._conditional.merge(output_true, output_false, predicate=pred) + + @experimental.pipeline_def(enable_conditionals=True, **kwargs) + def conditional_pipe(): + encoded, _ = fn.readers.caffe(path=caffe_db_folder) + decoded = fn.decoders.image(encoded, device="mixed") + pred = fn.random.coin_flip(dtype=types.DALIDataType.BOOL) + if pred: + output = fn.rotate(decoded, angle=30) + else: + output = fn.flip(decoded, horizontal=True) + return output + + pipes = [regular_pipe(), conditional_pipe()] + for pipe in pipes: + pipe.build() + compare_pipelines(*pipes, bs, iters) + + +# Unified return tests - TODO(klecki) + +# Generator tests, remove the random predicate to test the same predicate in both pipelines. + + +@params(*(pred_gens[:-1])) +def test_generators(pred): + test_data_root = get_dali_extra_path() + caffe_db_folder = os.path.join(test_data_root, 'db', 'lmdb') + + bs = 10 + iters = 5 + kwargs = {"batch_size": bs, "num_threads": 4, "device_id": 0, "seed": 42} + + @pipeline_def(**kwargs) + def baseline_pipe(): + encoded, _ = fn.readers.caffe(path=caffe_db_folder) + rand = fn.random.uniform() + predicate = fn.external_source(source=pred, batch=False) + true_encoded, _ = fn._conditional.split(encoded, predicate=predicate) + true_rand, _ = fn._conditional.split(rand, predicate=predicate) + _, false_u8 = fn._conditional.split(np.uint8([0]), predicate=predicate) + _, false_f32 = fn._conditional.split(np.float32(0.), predicate=predicate) + encoded_out = fn._conditional.merge(true_encoded, false_u8, predicate=predicate) + rand_out = fn._conditional.merge(true_rand, false_f32, predicate=predicate) + return encoded_out, rand_out + + @experimental.pipeline_def(enable_conditionals=True, **kwargs) + def conditional_pipe(): + predicate = fn.external_source(source=pred, batch=False) + # Generators work by running in top scope and splitting for particular nesting + if predicate: + encoded_out, _ = fn.readers.caffe(path=caffe_db_folder) + rand_out = fn.random.uniform() + else: + encoded_out = types.Constant(np.uint8([0]), device="cpu") + rand_out = types.Constant(np.float32(0.), device="cpu") + return encoded_out, rand_out + + pipes = [baseline_pipe(), conditional_pipe()] + for pipe in pipes: + pipe.build() + compare_pipelines(*pipes, bs, iters) + + +# Mismatched branches test (uninitialized values) + + +def test_uninitialized(): + bs = 10 + kwargs = { + "batch_size": bs, + "num_threads": 4, + "device_id": 0, + } + + @experimental.pipeline_def(enable_conditionals=True, **kwargs) + def one_branch(): + pred = fn.random.coin_flip(dtype=types.DALIDataType.BOOL) + if pred: + output = fn.random.uniform() + return output + + with assert_raises( + RuntimeError, glob=("Encountered inconsistent outputs out of the `if/else` control flow" + " statement. Variables need to be initialized in every code path" + " (both `if` branches). Variable 'output' must also be initialized" + " in the `else` branch.")): + one_branch() + + @experimental.pipeline_def(enable_conditionals=True, **kwargs) + def one_return(): + pred = fn.random.coin_flip(dtype=types.DALIDataType.BOOL) + if pred: + return fn.random.uniform() + + with assert_raises( + RuntimeError, glob=("Encountered inconsistent outputs out of the `if/else` control flow" + " statement. Variables need to be initialized in every code path" + " (both `if` branches). The `else` branch must also have a return" + " statement.")): + one_return() diff --git a/qa/TL0_python-self-test-core/test_body.sh b/qa/TL0_python-self-test-core/test_body.sh index f8d3dd42f9..c3be65dbdf 100644 --- a/qa/TL0_python-self-test-core/test_body.sh +++ b/qa/TL0_python-self-test-core/test_body.sh @@ -22,6 +22,7 @@ test_py() { test_autograph() { ${python_new_invoke_test} -s autograph + ${python_new_invoke_test} -s conditionals } test_pytorch() {