Skip to content

Commit

Permalink
Add experimental support for if statements in DALI (#4561)
Browse files Browse the repository at this point in the history
Utilize AutoGraph to capture if statements as a function that gets 
branches as callables passed as arguments.
Trace the code within both branches, detecting the usage of DataNode.
Each pipeline has a stack of previously checked conditions and entered
branches.
Whenever an operator is called within some condition block, its outputs
(DataNodes) are registered as produced in this scope.
Every time a DataNode is used as input to the operator or touched
directly via `ag__.ld` it is looked up in the stack - if was produced 
on the same level, it is used, otherwise necessary split nodes
are inserted.
Outputs of both branches are detected by AutoGraph.
Non-DataNode results are promoted to CPU constants (so they represent
a batch), and results of both branches go into merge node.

Signed-off-by: Krzysztof Lecki <klecki@nvidia.com>
  • Loading branch information
klecki committed Jan 19, 2023
1 parent e84766c commit 75431b4
Show file tree
Hide file tree
Showing 10 changed files with 1,179 additions and 24 deletions.
13 changes: 6 additions & 7 deletions dali/python/nvidia/dali/_autograph/impl/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ def __init__(self, name, operator_overload):
self._extra_locals = None

def get_transformed_name(self, node):
return self._name + super(PyToLib, self).get_transformed_name(node)
return self._name + "__" + super(PyToLib, self).get_transformed_name(node)

def get_extra_locals(self):
if self._extra_locals is None:
Expand Down Expand Up @@ -931,7 +931,7 @@ def to_code(entity, recursive=True, experimental_optional_features=None):

def initialize_autograph(operator_overload=hooks.OperatorBase(),
converter_name="autograph",
filtered_library_modules=["nvidia.dali._autograph"]):
do_not_convert_modules=["nvidia.dali._autograph"]):
"""Initialize the AutoGraph with custom operator overloads.
Parameters
Expand All @@ -943,7 +943,7 @@ def initialize_autograph(operator_overload=hooks.OperatorBase(),
converter_name : str, optional
Name that is used to generated converted function names and as a fake module under which
the AutoGraph is inserted into them, by default "autograph".
filtered_library_modules : list, optional
do_not_convert_modules : list, optional
AutoGraph needs to filter the module that should not be converted. By default it will
only filter out its own functions, provide the list of module that should be ignored.
If the autograph is used under different name (for example included in the source as
Expand All @@ -954,7 +954,6 @@ def initialize_autograph(operator_overload=hooks.OperatorBase(),
raise RuntimeError("AutoGraph already initialized")
_TRANSPILER = PyToLib(converter_name, operator_overload)
# Add the name of the initialized library to know libraries to stop recursive conversion
do_not_convert_rules = tuple(
config.DoNotConvert(name) for name in filtered_library_modules)
config.CONVERSION_RULES = ((config.DoNotConvert(converter_name),) +
do_not_convert_rules + config.CONVERSION_RULES)
do_not_convert_rules = tuple(config.DoNotConvert(name) for name in do_not_convert_modules)
config.CONVERSION_RULES = ((config.DoNotConvert(converter_name),) + do_not_convert_rules +
config.CONVERSION_RULES)
4 changes: 4 additions & 0 deletions dali/python/nvidia/dali/_autograph/operators/variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,15 @@
# ==============================================================================
"""Utilities used to capture Python idioms."""

from nvidia.dali._autograph.utils import hooks


def ld(v):
"""Load variable operator."""
if isinstance(v, Undefined):
return v.read()
if hooks._DISPATCH.detect_overload_ld(v):
return hooks._DISPATCH.ld(v)
return v


Expand Down
8 changes: 7 additions & 1 deletion dali/python/nvidia/dali/_autograph/utils/hooks.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -53,6 +53,12 @@ def detect_overload(self, object):
"""
return False

def detect_overload_ld(self, v):
return self.detect_overload(v)

def ld(self, v):
pass

def detect_overload_if_exp(self, cond):
return self.detect_overload(cond)

Expand Down
Loading

0 comments on commit 75431b4

Please sign in to comment.