diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index fad402cc980e5..64d3b1d3d2960 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -3128,7 +3128,7 @@ def attr_type(self, name): Returns: core.AttrType: the attribute type. """ - return self.desc.attr_type(name) + return self.desc.attr_type(name, True) def _set_attr(self, name, val): """ @@ -3280,6 +3280,41 @@ def _blocks_attr_ids(self, name): return self.desc._blocks_attr_ids(name) + def _var_attr(self, name): + """ + Get the Variable attribute by name. + + Args: + name(str): the attribute name. + + Returns: + Variable: the Variable attribute. + """ + attr_type = self.desc.attr_type(name, True) + assert attr_type == core.AttrType.VAR, "Required type attr({}) is Variable, but received {}".format( + name, attr_type) + attr_var_name = self.desc.attr(name, True).name() + return self.block._var_recursive(attr_var_name) + + def _vars_attr(self, name): + """ + Get the Variables attribute by name. + + Args: + name(str): the attribute name. + + Returns: + Variables: the Variables attribute. + """ + attr_type = self.desc.attr_type(name, True) + assert attr_type == core.AttrType.VARS, "Required type attr({}) is list[Variable], but received {}".format( + name, attr_type) + attr_vars = [ + self.block._var_recursive(var.name()) + for var in self.desc.attr(name, True) + ] + return attr_vars + def all_attrs(self): """ Get the attribute dict. @@ -3290,16 +3325,17 @@ def all_attrs(self): attr_names = self.attr_names attr_map = {} for n in attr_names: - attr_type = self.desc.attr_type(n) + attr_type = self.desc.attr_type(n, True) if attr_type == core.AttrType.BLOCK: attr_map[n] = self._block_attr(n) - continue - - if attr_type == core.AttrType.BLOCKS: + elif attr_type == core.AttrType.BLOCKS: attr_map[n] = self._blocks_attr(n) - continue - - attr_map[n] = self.attr(n) + elif attr_type == core.AttrType.VAR: + attr_map[n] = self._var_attr(n) + elif attr_type == core.AttrType.VARS: + attr_map[n] = self._vars_attr(n) + else: + attr_map[n] = self.attr(n) return attr_map diff --git a/python/paddle/fluid/tests/unittests/test_attribute_var.py b/python/paddle/fluid/tests/unittests/test_attribute_var.py index cabbfb826b53b..6e8e3c6675087 100644 --- a/python/paddle/fluid/tests/unittests/test_attribute_var.py +++ b/python/paddle/fluid/tests/unittests/test_attribute_var.py @@ -96,6 +96,10 @@ def test_static(self): infer_out = self.infer_prog() self.assertEqual(infer_out.shape, (10, 10)) + self.assertEqual( + main_prog.block(0).ops[4].all_attrs()['dropout_prob'].name, + p.name) + class TestTileTensorList(UnittestBase): diff --git a/python/paddle/fluid/tests/unittests/test_reverse_op.py b/python/paddle/fluid/tests/unittests/test_reverse_op.py index e2260082fc968..f090cf1c8de11 100644 --- a/python/paddle/fluid/tests/unittests/test_reverse_op.py +++ b/python/paddle/fluid/tests/unittests/test_reverse_op.py @@ -258,6 +258,12 @@ def call_func(self, x): # axes is a List[Variable] axes = [paddle.assign([0]), paddle.assign([2])] out = paddle.fluid.layers.reverse(x, axes) + + # check attrs + axis_attrs = paddle.static.default_main_program().block( + 0).ops[-1].all_attrs()["axis"] + self.assertTrue(axis_attrs[0].name, axes[0].name) + self.assertTrue(axis_attrs[1].name, axes[1].name) return out