Skip to content

Commit

Permalink
[Cherry-Pick] Cherry pick PR41200, PR41474, PR41382 (#41509)
Browse files Browse the repository at this point in the history
* Use `self`as a parameter of _hash_with_id function to avoid error caused by hash_id reuse (#41200)

* Add fill_constant_batch_size YAML and UT (#41474)

* Switch some dy2st UT to eager mode (#41382)

* Switch some dy2st UT to eager mode

* Fix test_lstm and remove test_transformer

* Run test_resnet_v2 in old dy mode
  • Loading branch information
0x45f committed Apr 8, 2022
1 parent ebe72b8 commit ae34db3
Show file tree
Hide file tree
Showing 8 changed files with 111 additions and 3 deletions.
2 changes: 1 addition & 1 deletion python/paddle/fluid/dygraph/varbase_patch_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def _to_static_var(self, to_parameter=False, **kwargs):

# Note: getattr(self, attr, None) will call x.grad = x.gradient(), but gradient() is only available in dygraph.
# It would fail. So, for properties that differ between the dynamic and static graphs, do not use getattr(self, attr, None).
attr_not_need_keys = ['grad', 'T']
attr_not_need_keys = ['grad', 'T', 'place', '_place_str']
if isinstance(self, (ParamBase, EagerParamBase)):
attr_kwargs = self.__dict__.copy()
else:
Expand Down
12 changes: 12 additions & 0 deletions python/paddle/fluid/layers/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -846,6 +846,18 @@ def fill_constant_batch_size_like(input,
input=like, shape=[1], value=0, dtype='int64') #like=[[10, 10]] data=[0]
"""
if in_dygraph_mode():
if not isinstance(dtype, core.VarDesc.VarType):
dtype = convert_np_dtype_to_dtype_(dtype)

place = _current_expected_place()
if force_cpu:
place = core.CPUPlace()
out = _C_ops.final_state_full_batch_size_like(
input, shape, dtype, value, input_dim_idx, output_dim_idx, place)
out.stop_gradient = True
return out

helper = LayerHelper("fill_constant_batch_size_like", **locals())
out = helper.create_variable_for_type_inference(dtype=dtype)
attrs = {
Expand Down
7 changes: 7 additions & 0 deletions python/paddle/fluid/tests/unittests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -596,6 +596,13 @@ foreach(TEST_OP ${TEST_OPS_WITH_GC})
py_test_modules(${TEST_OP} MODULES ${TEST_OP} ENVS ${GC_ENVS})
endforeach()

# Switch some dy2st UT to eager mode
set(TEST_EAGER_OPS test_jit_save_load test_translated_layer)
foreach(TEST_OP ${TEST_EAGER_OPS})
list(REMOVE_ITEM TEST_OPS ${TEST_OP})
py_test_modules(${TEST_OP} MODULES ${TEST_OP} ENVS FLAGS_enable_eager_mode=1)
endforeach()

if ((NOT WITH_GPU) AND (NOT WITH_XPU) AND NOT (WITH_ASCEND OR WITH_ASCEND_CL))
list(REMOVE_ITEM TEST_OPS "test_fleet_graph_execution_meta_optimizer")
list(REMOVE_ITEM TEST_OPS "test_gen_nccl_id_op")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ set(DY2ST_EAGER_TEST_ENVS ${GC_ENVS} FLAGS_enable_eager_mode=1)
set(TEST_EAGER_OPS test_bmn test_break_continue test_ifelse test_loop test_mnist_amp
test_mnist_pure_fp16 test_mobile_net test_program_translator test_ptb_lm test_reinforcement_learning
test_resnet test_resnet_amp test_resnet_pure_fp16 test_se_resnet test_sentiment test_seq2seq
test_tsm test_word2vec test_yolov3)
test_tsm test_word2vec test_yolov3 test_bert test_cycle_gan test_lstm test_simnet)
list(REMOVE_ITEM TEST_OPS test_lac)
# NOTE(Aurelius84): In case of Windows CI, if open ON_INFER, RWLOCK of Scope will
# be removed and will cause some random failed in multi-thread.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

from __future__ import print_function

import os
os.environ["FLAGS_enable_eager_mode"] = "0"
import math
import time
import unittest
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import paddle
import paddle.fluid.core as core
from paddle.static import program_guard, Program
import paddle.compat as cpt
import unittest
import numpy as np
from op_test import OpTest
from paddle.fluid.framework import convert_np_dtype_to_dtype_

paddle.enable_static()


def fill_constant_batch_size_like(input,
shape,
value,
data_type,
input_dim_idx=0,
output_dim_idx=0,
force_cpu=False):
return paddle.fluid.layers.fill_constant_batch_size_like(
input, shape, data_type, value, input_dim_idx, output_dim_idx,
force_cpu)


class TestFillConstatnBatchSizeLike1(OpTest):
# test basic
def setUp(self):
self.op_type = "fill_constant_batch_size_like"
self.python_api = fill_constant_batch_size_like
self.init_data()

input = np.zeros(self.shape)
out = np.full_like(input, self.value, self.dtype)

self.inputs = {'Input': input}
self.outputs = {'Out': out}
self.attrs = {
'shape': self.shape,
'dtype': convert_np_dtype_to_dtype_(self.dtype),
'value': self.value,
'input_dim_idx': self.input_dim_idx,
'output_dim_idx': self.output_dim_idx,
'force_cpu': self.force_cpu
}

def init_data(self):
self.shape = [10, 10]
self.dtype = np.float32
self.value = 100
self.input_dim_idx = 0
self.output_dim_idx = 0
self.force_cpu = False

def test_check_output(self):
self.check_output(check_eager=True)


if __name__ == "__main__":
unittest.main()
2 changes: 1 addition & 1 deletion python/paddle/fluid/tests/unittests/test_run_program_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def get_program_desc(self):
def prepare_attrs(self):
return ('global_block', self.program_desc.block(0), 'start_op_index', 0,
'end_op_index', self.fwd_op_num, 'program_id',
_hash_with_id(self.program_desc))
_hash_with_id(self.program_desc, self))

def get_param_grad_names(self):
grad_names = []
Expand Down
12 changes: 12 additions & 0 deletions python/paddle/utils/code_gen/api.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -687,6 +687,18 @@
data_type : dtype
backend : place

- api : full_batch_size_like
args : (Tensor input, int[] shape, DataType dtype, Scalar value, int input_dim_idx, int output_dim_idx, Place place=CPUPlace())
output: Tensor
infer_meta :
func : FullBatchSizeLikeInferMeta
param : [input, shape, value, dtype, input_dim_idx, output_dim_idx]
kernel :
func : full_batch_size_like
param : [input, shape, value, dtype, input_dim_idx, output_dim_idx]
data_type : dtype
backend : place

- api : full_like
args : (Tensor x, Scalar value, DataType dtype = DataType::UNDEFINED, Place place = {})
output: Tensor
Expand Down

0 comments on commit ae34db3

Please sign in to comment.