Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ Dy2Static ] Add closure analysis for control flow and add some unittest #43713

Merged
merged 5 commits into from
Jun 24, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,99 @@ def create_while_nodes(condition_name, body_name, loop_var_names):
return ret


class NameScope:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

下个PR我们可以把类似的公用类、分析函数统一放到一个文件里,这样可以被其他transofmer导入使用


def __init__(self):
""" we don't analyze the read only variable
because they keep the same in control flow.
"""
self.globals = set()
self.nonlocals = set()
self.args = set()
self.w_vars = set() # all vars been stored,
# may be globals or non-locals
def created_vars(self):
return self.w_vars - self.globals - self.nonlocals - self.args

def write_vars(self):
return self.w_vars

def global_vars(self):
return self.globals


class FunctionNameLivenessAnalysis(gast.NodeVisitor):
""" analyze the liveness of a function.

every variables stored in this scope will be collected,
in addition with global/nonlocal information.

1. global variable is stored in node.var_globals.
2. nonlocal variable is stored in node.var_nonlocals.
3. arguments is stored in node.var_args.

For example:

def func(*args, **kargs):
a = 12
global i,j
nonlocal x,y
print(a)
i = k
for m in range(10):
q = 12

After this visitor we have:
# node is the FunctionDef node with name: "func"
node.pd_scope = NameScope(
globals = ['i', 'j'],
nonlocals = ['x', 'y'],
args = ['args', 'kargs'],
wr_vars = ['a', 'i', 'q', 'm']
)
"""

def __init__(self, root_node):
self.funcdef_stack = []
self.visit(root_node)

def _current_funcdef_scope(self):
return self.funcdef_stack[-1].pd_scope

def visit_Name(self, node):
self.generic_visit(node)
write_context = (gast.Store, gast.AugStore, gast.Del)
if isinstance(node.ctx, write_context):
self._current_funcdef_scope().w_vars.add(node.id)

def visit_FunctionDef(self, node):
setattr(node, 'pd_scope', NameScope())
self.funcdef_stack.append(node)
self._current_funcdef_scope().args |= set(
self._get_argument_names(node))
self.generic_visit(node)
self.funcdef_stack.pop()

def visit_Global(self, node):
self._current_funcdef_scope().globals |= set(node.names)

def visit_Nonlocal(self, node):
self._current_funcdef_scope().nonlocals |= set(node.names)

def _get_argument_names(self, node):
""" get all arguments name in the functiondef node.
this node is local to the function and shouldn't
be created.
"""
assert isinstance(
node, gast.FunctionDef), "Input node is not function define node"
names = [a for a in node.args.args]
names.append(node.args.vararg)
names.append(node.args.kwarg)
names = [i.id for i in names if i is not None]
return names


class NameVisitor(gast.NodeVisitor):
'''
Analysis name liveness for loop transformer
Expand All @@ -122,7 +215,6 @@ def __init__(self, root_node):

# List of nodes that have scope of variables.
self.nodes_with_scope = []

self.blacklist_names = {"False", "True", "None"}

# Mapping from gast.While/gast.For to variable nodes
Expand Down Expand Up @@ -244,6 +336,7 @@ def visit_Name(self, node):
type(gast.AugStore()),
type(gast.Del())
}

for loop_node in self.current_loop:
self.in_loop_vars[loop_node].append(node)
if type(node.ctx) in write_context:
Expand All @@ -255,6 +348,7 @@ def visit_Name(self, node):
def visit_FunctionDef(self, node):
self.nodes_with_scope.append(node)
self.blacklist_names.add(node.name)

# The variables in the function are not visible to the outside scope.
before_func_seen_vars = copy.copy(self.current_seen_vars)

Expand Down Expand Up @@ -353,6 +447,9 @@ def _is_call_func_name_node(self, node):
return True
return False

def _is_global_or_nonlocal(self, node):
return False

def _is_ancestor_node(self, ancestor_node, node):
parent_node = self._get_parent_node(node)

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import unittest

import paddle
from paddle.fluid.dygraph.dygraph_to_static.loop_transformer import FunctionNameLivenessAnalysis
from paddle.utils import gast
import inspect
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

import语句注意要清理下,很多import语句是没有在这个单测里用到的



class JudgeVisitor(gast.NodeVisitor):

def __init__(self, ans):
self.ans = ans

def visit_FunctionDef(self, node):
scope = node.pd_scope
expected = self.ans.get(node.name, set())
assert scope.created_vars() == expected, "Not Equals."
self.generic_visit(node)


def test_normal_0(x):

def func():
if True:
i = 1

func()
return i


def test_normal_argument(x):
x = 1

def func():
if True:
print(x)
i = 1

func()
return x


def test_global(x):
global t
t = 10

def func():
if True:
print(x)
i = 1

func()
return x


def test_nonlocal(x, *args, **kargs):
i = 10

def func(*args, **kargs):
nonlocal i
k = 10
if True:
print(x)
i = 1

func(*args, **kargs)
return x


class TestClosureAnalysis(unittest.TestCase):

def setUp(self):
self.init_dygraph_func()

def init_dygraph_func(self):
self.all_dygraph_funcs = [
test_nonlocal, test_global, test_normal_0, test_normal_argument
]
self.answer = [
{
'func': set('k'),
'test_nonlocal': set('i')
},
{
'func': set({'i'}),
},
{
'func': set('i'),
},
{
'func': set('i'),
},
]

def test_main(self):
for ans, func in zip(self.answer, self.all_dygraph_funcs):
test_func = inspect.getsource(func)
gast_root = gast.parse(test_func)
name_visitor = FunctionNameLivenessAnalysis(gast_root)
JudgeVisitor(ans).visit(gast_root)


def TestClosureAnalysis_Attribute_func():
# in this function, only self is a Name, self.current is a Attribute. self is read and self.current.function is store()
i = 0
self.current.function = 12


class TestClosureAnalysis_Attribute(TestClosureAnalysis):

def init_dygraph_func(self):

self.all_dygraph_funcs = [TestClosureAnalysis_Attribute_func]
self.answer = [{"TestClosureAnalysis_Attribute_func": set({'i'})}]


if __name__ == '__main__':
unittest.main()