From 63b99531f47a49004157ee385189965defbf1f2e Mon Sep 17 00:00:00 2001 From: zhangminxu01 Date: Wed, 10 Jan 2024 02:56:54 +0800 Subject: [PATCH] add xpu assign_value --- paddle/fluid/operators/assign_value_op_xpu.cc | 23 +++++++++++++++++++ .../fluid/platform/device/xpu/xpu2_op_list.h | 6 ++++- 2 files changed, 28 insertions(+), 1 deletion(-) create mode 100644 paddle/fluid/operators/assign_value_op_xpu.cc diff --git a/paddle/fluid/operators/assign_value_op_xpu.cc b/paddle/fluid/operators/assign_value_op_xpu.cc new file mode 100644 index 0000000000000..eaad217ec4f58 --- /dev/null +++ b/paddle/fluid/operators/assign_value_op_xpu.cc @@ -0,0 +1,23 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/operators/assign_value_op.h" + +namespace ops = paddle::operators; + +REGISTER_OP_XPU_KERNEL(assign_value, + ops::AssignValueKernel, + ops::AssignValueKernel, + ops::AssignValueKernel, + ops::AssignValueKernel); diff --git a/paddle/fluid/platform/device/xpu/xpu2_op_list.h b/paddle/fluid/platform/device/xpu/xpu2_op_list.h index d438f0e8d2a2d..7e8ce988503fe 100644 --- a/paddle/fluid/platform/device/xpu/xpu2_op_list.h +++ b/paddle/fluid/platform/device/xpu/xpu2_op_list.h @@ -47,7 +47,11 @@ XPUOpMap& get_kl2_ops() { pOpKernelType(vartype::INT64, XPUPlace()), pOpKernelType(vartype::BOOL, XPUPlace())})}, {"assign_value", - XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), + pOpKernelType(vartype::FP64, XPUPlace()), + pOpKernelType(vartype::INT32, XPUPlace()), + pOpKernelType(vartype::INT64, XPUPlace()), + pOpKernelType(vartype::BOOL, XPUPlace())})}, {"batch_norm_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"batch_norm", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},