Skip to content

Commit

Permalink
Add yaml for flatten_contiguous_range OP (#41345)
Browse files Browse the repository at this point in the history
* Add yaml for flatten_contiguous_range OP

* update

* Fix typos

Co-authored-by: Shixiaowei02 <39303645+Shixiaowei02@users.noreply.github.com>
  • Loading branch information
From00 and Shixiaowei02 committed Apr 3, 2022
1 parent 3152f3f commit c5285cc
Show file tree
Hide file tree
Showing 12 changed files with 33 additions and 86 deletions.
2 changes: 1 addition & 1 deletion paddle/phi/kernels/flatten_grad_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ namespace phi {

template <typename T, typename Context>
void FlattenGradKernel(const Context& dev_ctx,
const DenseTensor& out_grad,
const DenseTensor& xshape,
const DenseTensor& out_grad,
DenseTensor* x_grad) {
auto xshape_dims = xshape.dims();
dev_ctx.Alloc(x_grad, out_grad.dtype());
Expand Down
2 changes: 1 addition & 1 deletion paddle/phi/kernels/flatten_grad_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ namespace phi {

template <typename T, typename Context>
void FlattenGradKernel(const Context& dev_ctx,
const DenseTensor& out_grad,
const DenseTensor& xshape,
const DenseTensor& out_grad,
DenseTensor* x_grad);

} // namespace phi
2 changes: 1 addition & 1 deletion paddle/phi/ops/compat/flatten_sig.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ KernelSignature FlattenOpArgumentMapping(const ArgumentMappingContext& ctx) {
KernelSignature FlattenGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature(
"flatten_grad", {GradVarName("Out"), "XShape"}, {}, {GradVarName("X")});
"flatten_grad", {"XShape", GradVarName("Out")}, {}, {GradVarName("X")});
}

} // namespace phi
Expand Down
1 change: 0 additions & 1 deletion paddle/phi/tests/api/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ cc_test(test_dot_api SRCS test_dot_api.cc DEPS ${COMMON_API_TEST_DEPS})
cc_test(test_matmul_api SRCS test_matmul_api.cc DEPS ${COMMON_API_TEST_DEPS})
cc_test(test_empty_api SRCS test_empty_api.cc DEPS ${COMMON_API_TEST_DEPS})
cc_test(test_fill_api SRCS test_fill_api.cc DEPS ${COMMON_API_TEST_DEPS})
cc_test(test_flatten_api SRCS test_flatten_api.cc DEPS ${COMMON_API_TEST_DEPS})
cc_test(test_elementwise_api SRCS test_elementwise_api.cc DEPS ${COMMON_API_TEST_DEPS})
cc_test(test_cast_api SRCS test_cast_api.cc DEPS ${COMMON_API_TEST_DEPS})
cc_test(test_reshape_api SRCS test_reshape_api.cc DEPS ${COMMON_API_TEST_DEPS})
Expand Down
75 changes: 0 additions & 75 deletions paddle/phi/tests/api/test_flatten_api.cc

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@

class TestFlattenOp(OpTest):
def setUp(self):
self.python_api = paddle.flatten
self.python_out_sig = ["Out"]
self.op_type = "flatten_contiguous_range"
self.start_axis = 0
self.stop_axis = -1
Expand All @@ -35,10 +37,10 @@ def setUp(self):
}

def test_check_output(self):
self.check_output(no_check_set=["XShape"])
self.check_output(no_check_set=["XShape"], check_eager=True)

def test_check_grad(self):
self.check_grad(["X"], "Out")
self.check_grad(["X"], "Out", check_eager=True)

def init_test_case(self):
self.in_shape = (3, 2, 5, 4)
Expand Down
6 changes: 5 additions & 1 deletion python/paddle/tensor/manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -676,7 +676,11 @@ def flatten(x, start_axis=0, stop_axis=-1, name=None):
if start_axis > stop_axis:
raise ValueError("The stop_axis should be larger than stat_axis")

if paddle.in_dynamic_mode():
if in_dygraph_mode():
dy_out, _ = _C_ops.final_state_flatten(x, start_axis, stop_axis)
return dy_out

if _in_legacy_dygraph():
dy_out, _ = _C_ops.flatten_contiguous_range(x, 'start_axis', start_axis,
'stop_axis', stop_axis)
return dy_out
Expand Down
10 changes: 7 additions & 3 deletions python/paddle/utils/code_gen/api.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -547,11 +547,15 @@

- api : flatten
args : (Tensor x, int start_axis, int stop_axis)
output : Tensor
output : Tensor(out), Tensor(xshape)
infer_meta :
func : FlattenInferMeta
func : FlattenWithXShapeInferMeta
kernel :
func : flatten
func : flatten_with_xshape
backend : x
inplace : (x -> out)
view : (x -> out)
backward : flatten_grad

# flip
- api : flip
Expand Down
13 changes: 13 additions & 0 deletions python/paddle/utils/code_gen/backward.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -349,6 +349,19 @@
kernel :
func : expm1_grad

- backward_api : flatten_grad
forward : flatten(Tensor x, int start_axis, int stop_axis) -> Tensor(out), Tensor(xshape)
args : (Tensor xshape, Tensor out_grad)
output : Tensor(x_grad)
infer_meta :
func : KernelWithXShapeInferMeta
param : [xshape]
kernel :
func : flatten_grad
data_type: out_grad
backend: out_grad
layout: out_grad

- backward_api : floor_grad
forward : floor(Tensor x) -> Tensor(out)
args : (Tensor out_grad)
Expand Down
2 changes: 1 addition & 1 deletion tools/infrt/skipped_phi_api.json
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
{
"phi_apis":["conj", "nll_loss"],
"phi_apis":["conj", "nll_loss", "flatten"],
"phi_kernels":["equal_all"]
}

0 comments on commit c5285cc

Please sign in to comment.