From c1d78603ca6d818aec775733521e04db9c145716 Mon Sep 17 00:00:00 2001
From: zyt1024 <42999008+zyt1024@users.noreply.github.com>
Date: Thu, 28 Dec 2023 17:14:26 +0800
Subject: [PATCH] 【Complex op】add complex support for assign_value  (#59536)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* support_complex_for_assign_value
* add test complex test for test_program_converter
* add complex test for assign_value xpu
* solve conflict
* fix timeout
* fix CE infer bug
* fix program convert bug
* fix program convert bug for assign_value

---------

Co-authored-by: zyt1024 <1522064645@qq.com>
---
 paddle/fluid/framework/op_version_proto.cc    |   1 +
 paddle/fluid/framework/program_converter.cc   |  85 +++++++-
 .../ir_adaptor/translator/op_translator.cc    |  17 +-
 .../ops_signature/assign_value_sig.cc         |  26 +--
 .../pir/dialect/operator/ir/op_attribute.cc   |   4 +
 .../pir/dialect/operator/ir/op_attribute.h    |   4 +-
 .../fluid/pir/dialect/operator/utils/utils.h  |   6 +
 paddle/fluid/pybind/op_function_common.cc     |  21 +-
 paddle/phi/api/yaml/op_version.yaml           |  15 ++
 paddle/phi/api/yaml/static_ops.yaml           |   2 +-
 paddle/phi/kernels/assign_kernel.cc           |  12 +-
 paddle/pir/core/builder.h                     |   5 +
 paddle/pir/core/builtin_attribute.cc          |  10 +
 paddle/pir/core/builtin_attribute.h           |  23 +++
 paddle/pir/core/builtin_attribute_storage.h   |  40 ++++
 paddle/pir/core/builtin_dialect.cc            |   4 +-
 paddle/pir/core/ir_printer.cc                 |   4 +
 python/paddle/nn/initializer/Bilinear.py      |   2 +-
 python/paddle/nn/initializer/assign.py        |   6 +-
 python/paddle/nn/initializer/dirac.py         |   4 +-
 python/paddle/tensor/creation.py              |  58 ++----
 .../test_program_translator.py                |  14 +-
 test/ir/inference/CMakeLists.txt              |   4 +-
 test/ir/inference/test_mul_gru_fuse_pass.py   |   2 +-
 test/ir/inference/test_mul_lstm_fuse_pass.py  |   2 +-
 .../inference/test_seq_concat_fc_fuse_pass.py |   4 +-
 test/legacy_test/test_assign_value_op.py      | 101 +++++++--
 test/legacy_test/test_initializer.py          |   3 +-
 test/legacy_test/test_initializer_nn.py       |   4 +-
 test/legacy_test/test_program_converter.py    | 193 ++++++++++++++++++
 test/xpu/test_assign_value_op_xpu.py          |  61 +++++-
 31 files changed, 614 insertions(+), 123 deletions(-)

diff --git a/paddle/fluid/framework/op_version_proto.cc b/paddle/fluid/framework/op_version_proto.cc
index 2a93e755b085b..8be9323098c97 100644
--- a/paddle/fluid/framework/op_version_proto.cc
+++ b/paddle/fluid/framework/op_version_proto.cc
@@ -21,6 +21,7 @@ namespace pb {
 const std::unordered_map<std::string, uint32_t>& GetLegacyOpVersions() {
   static std::unordered_map<std::string, uint32_t> op_versions = {
       {"not_equal", 1},
+      {"assign_value", 0},
       {"fake_channel_wise_dequantize_max_abs", 2},
       {"yolo_box", 1},
       {"data_norm", 1},
diff --git a/paddle/fluid/framework/program_converter.cc b/paddle/fluid/framework/program_converter.cc
index fc60a0abf676e..82739e788bba3 100644
--- a/paddle/fluid/framework/program_converter.cc
+++ b/paddle/fluid/framework/program_converter.cc
@@ -117,6 +117,41 @@ void ConvertSetValueOp(OpDesc* op) {
   }
 }
 
+void ConvertAssignValueOp(OpDesc* op) {
+  std::vector<paddle::experimental::Scalar> values = PADDLE_GET_CONST(
+      std::vector<paddle::experimental::Scalar>, op->GetAttr("values", false));
+  op->RemoveAttr("values");
+  op->SetAttr("bool_values", std::vector<int>());
+  op->SetAttr("fp32_values", std::vector<float>());
+  op->SetAttr("int32_values", std::vector<int>());
+  op->SetAttr("int64_values", std::vector<int64_t>());
+
+  phi::DataType dtype = phi::DataType::FLOAT32;
+  if (values.size()) {
+    dtype = values.at(0).dtype();
+  }
+
+  switch (dtype) {
+    case phi::DataType::BOOL:
+      op->SetAttr("bool_values", 
ExtractPlainVector(values)); + break; + case phi::DataType::FLOAT32: + op->SetAttr("fp32_values", ExtractPlainVector(values)); + break; + case phi::DataType::FLOAT64: + op->SetAttr("fp32_values", ExtractPlainVector(values)); + break; + case phi::DataType::INT32: + op->SetAttr("int32_values", ExtractPlainVector(values)); + break; + case phi::DataType::INT64: + op->SetAttr("int64_values", ExtractPlainVector(values)); + break; + default: + PD_THROW("Invalid data type `", dtype, "`."); + } +} + void ConvertProgram(ProgramDesc* program) { PADDLE_ENFORCE_NOT_NULL( program, @@ -144,6 +179,9 @@ void ConvertProgram(ProgramDesc* program) { if (op_type == "set_value" || op_type == "set_value_grad") { ConvertSetValueOp(op); } + if (op_type == "assign_value") { + ConvertAssignValueOp(op); + } } } } @@ -204,6 +242,45 @@ void ConvertSetValueOp(OpDesc* op) { op->SetAttr("values", values); } +void ConvertAssignValueOp(OpDesc* op) { + VLOG(3) << "convert old assign value op to new"; + std::vector values; + + if (op->HasAttr("bool_values")) { + std::vector bool_values = + PADDLE_GET_CONST(std::vector, op->GetAttr("bool_values", false)); + if (bool_values.size()) { + values = WrapAsScalars(bool_values); + } + op->RemoveAttr("bool_values"); + } + if (op->HasAttr("fp32_values")) { + std::vector fp32_values = + PADDLE_GET_CONST(std::vector, op->GetAttr("fp32_values", false)); + if (fp32_values.size()) { + values = WrapAsScalars(fp32_values); + } + op->RemoveAttr("fp32_values"); + } + if (op->HasAttr("int32_values")) { + std::vector int32_values = + PADDLE_GET_CONST(std::vector, op->GetAttr("int32_values", false)); + if (int32_values.size()) { + values = WrapAsScalars(int32_values); + } + op->RemoveAttr("int32_values"); + } + if (op->HasAttr("int64_values")) { + std::vector int64_values = PADDLE_GET_CONST( + std::vector, op->GetAttr("int64_values", false)); + if (int64_values.size()) { + values = WrapAsScalars(int64_values); + } + op->RemoveAttr("int64_values"); + } + op->SetAttr("values", values); +} + void ConvertProgram(ProgramDesc* program) { PADDLE_ENFORCE_NOT_NULL( program, @@ -214,6 +291,7 @@ void ConvertProgram(ProgramDesc* program) { const std::unordered_map& legacy_op_versions = legacy_op_results.second; + VLOG(3) << "is_legacy_program : " << is_legacy_program; if (!is_legacy_program) return; VLOG(3) << "Updating Program Version and OpVersionMap"; @@ -232,10 +310,15 @@ void ConvertProgram(ProgramDesc* program) { for (size_t j = 0; j < num_ops; j++) { OpDesc* op = block->Op(static_cast(j)); const std::string op_type = op->Type(); + + if (op_type == "assign_value") { + VLOG(3) << "Converting program from old to new, op_type=" << op_type; + ConvertAssignValueOp(op); + } if (!legacy_op_versions.count(op_type)) { continue; } - + VLOG(3) << "Converting program from old to new, op_type=" << op_type; if (op_type == "set_value" || op_type == "set_value_grad") { ConvertSetValueOp(op); } diff --git a/paddle/fluid/ir_adaptor/translator/op_translator.cc b/paddle/fluid/ir_adaptor/translator/op_translator.cc index 626073d143e3e..c64004c7191dd 100644 --- a/paddle/fluid/ir_adaptor/translator/op_translator.cc +++ b/paddle/fluid/ir_adaptor/translator/op_translator.cc @@ -972,19 +972,20 @@ struct AssignValueOpTranscriber : public OpTranscriber { ctx, phi::Place(phi::AllocationType::UNDEFINED)); attribute_map["place"] = attr_place; - int dtype = paddle::get(op_desc.GetAttr("dtype")); - - if (dtype == /*BOOL*/ 0) { + if (op_desc.HasAttr("bool_values")) { legacy_attr = op_desc.GetAttr("bool_values"); - } else if (dtype == 
/*INT32*/ 2) { - legacy_attr = op_desc.GetAttr("int32_values"); - } else if (dtype == /*FP32*/ 5) { + } else if (op_desc.HasAttr("fp32_values")) { legacy_attr = op_desc.GetAttr("fp32_values"); - } else if (dtype == /*INT64*/ 3) { + } else if (op_desc.HasAttr("int32_values")) { + legacy_attr = op_desc.GetAttr("int32_values"); + } else if (op_desc.HasAttr("int64_values")) { legacy_attr = op_desc.GetAttr("int64_values"); + } else if (op_desc.HasAttr("values")) { + legacy_attr = op_desc.GetAttr("values"); } else { IR_THROW( - "Op assign_value should have attribute `**_values` but not find"); + "Op assign_value should have attribute `**_values` or `values` but " + "not find"); } pir::Attribute attr_values = attribute_translator( diff --git a/paddle/fluid/operators/ops_signature/assign_value_sig.cc b/paddle/fluid/operators/ops_signature/assign_value_sig.cc index 977c2260e59b9..ae14c5a9d7879 100644 --- a/paddle/fluid/operators/ops_signature/assign_value_sig.cc +++ b/paddle/fluid/operators/ops_signature/assign_value_sig.cc @@ -18,30 +18,8 @@ namespace phi { KernelSignature AssignValueOpArgumentMapping( const ArgumentMappingContext& ctx) { - // Here we must use `dtype` attr to determine which attr to use, we can't - // judge by whether the attr is empty, some unittests will failed - int dtype = paddle::any_cast(ctx.Attr("dtype")); - // heer we can't depend on the fluid proto::VarType, so we use the dtype enum - // value directly, If the enum value is updated, the code also needs to be - // updated here, but the probability of updating the enum value is very low - if (dtype == /*BOOL*/ 0) { - return KernelSignature( - "assign_value", {}, {"shape", "dtype", "bool_values"}, {"Out"}); - } else if (dtype == /*INT32*/ 2) { - return KernelSignature( - "assign_value", {}, {"shape", "dtype", "int32_values"}, {"Out"}); - } else if (dtype == /*FP32*/ 5) { - return KernelSignature( - "assign_value", {}, {"shape", "dtype", "fp32_values"}, {"Out"}); - } else if (dtype == /*FP64*/ 6) { - return KernelSignature( - "assign_value", {}, {"shape", "dtype", "fp64_values"}, {"Out"}); - } else if (dtype == /*INT64*/ 3) { - return KernelSignature( - "assign_value", {}, {"shape", "dtype", "int64_values"}, {"Out"}); - } else { - return KernelSignature("unregistered", {}, {}, {}); - } + return KernelSignature( + "assign_value", {}, {"shape", "dtype", "values"}, {"Out"}); } } // namespace phi diff --git a/paddle/fluid/pir/dialect/operator/ir/op_attribute.cc b/paddle/fluid/pir/dialect/operator/ir/op_attribute.cc index 3134214cf9029..10ae5a77d9f4a 100644 --- a/paddle/fluid/pir/dialect/operator/ir/op_attribute.cc +++ b/paddle/fluid/pir/dialect/operator/ir/op_attribute.cc @@ -43,6 +43,10 @@ phi::Scalar ScalarAttribute::data() { return phi::Scalar(dyn_cast().data()); } else if (isa()) { return phi::Scalar(dyn_cast().AsString()); + } else if (isa()) { + return phi::Scalar(dyn_cast().data()); + } else if (isa()) { + return phi::Scalar(dyn_cast().data()); } else { PADDLE_THROW(phi::errors::Unimplemented( "Unsupported ir attribute when casting it into " diff --git a/paddle/fluid/pir/dialect/operator/ir/op_attribute.h b/paddle/fluid/pir/dialect/operator/ir/op_attribute.h index 0b0973a5205c8..f58803fa20002 100644 --- a/paddle/fluid/pir/dialect/operator/ir/op_attribute.h +++ b/paddle/fluid/pir/dialect/operator/ir/op_attribute.h @@ -50,7 +50,9 @@ class ScalarAttribute : public pir::Attribute { (val.type_id() == pir::Int32Attribute::type_id()) || (val.type_id() == pir::IndexAttribute::type_id()) || (val.type_id() == 
pir::Int64Attribute::type_id()) || - (val.type_id() == pir::StrAttribute::type_id()); + (val.type_id() == pir::StrAttribute::type_id()) || + (val.type_id() == pir::Complex64Attribute::type_id()) || + (val.type_id() == pir::Complex128Attribute::type_id()); } static pir::Attribute get(pir::IrContext *ctx, phi::Scalar scalar) { diff --git a/paddle/fluid/pir/dialect/operator/utils/utils.h b/paddle/fluid/pir/dialect/operator/utils/utils.h index 0e14077bb8559..7a8a5083a3dae 100644 --- a/paddle/fluid/pir/dialect/operator/utils/utils.h +++ b/paddle/fluid/pir/dialect/operator/utils/utils.h @@ -120,6 +120,12 @@ static inline pir::Attribute TransToIrAttribute(phi::Scalar scalar, return pir::Int64Attribute::get(ctx, scalar.to()); case phi::DataType::BOOL: return pir::BoolAttribute::get(ctx, scalar.to()); + case phi::DataType::COMPLEX64: + return pir::Complex64Attribute::get( + ctx, scalar.to>()); + case phi::DataType::COMPLEX128: + return pir::Complex128Attribute::get( + ctx, scalar.to>()); default: PADDLE_THROW(phi::errors::Unimplemented( "Unsupported phi data type `%s` when casting it into " diff --git a/paddle/fluid/pybind/op_function_common.cc b/paddle/fluid/pybind/op_function_common.cc index 489b25f35867c..0555724a49cfa 100644 --- a/paddle/fluid/pybind/op_function_common.cc +++ b/paddle/fluid/pybind/op_function_common.cc @@ -77,7 +77,7 @@ bool PyObject_CheckLongOrToLong(PyObject** obj) { } if (std::string(((PyTypeObject*)(*obj)->ob_type)->tp_name) // NOLINT - .find("numpy") != std::string::npos) { + .find("numpy.int") != std::string::npos) { auto to = PyNumber_Long(*obj); if (to) { *obj = to; @@ -95,8 +95,12 @@ bool PyObject_CheckFloatOrToFloat(PyObject** obj) { (((TensorObject*)(*obj))->tensor.numel() == 1))) { // NOLINT return true; } - if (std::string(((PyTypeObject*)(*obj)->ob_type)->tp_name) // NOLINT - .find("numpy") != std::string::npos) { + auto type_name = + std::string(reinterpret_cast((*obj)->ob_type)->tp_name); + VLOG(4) << "type_name: " << type_name; + + if (type_name.find("numpy") != std::string::npos && + type_name.find("numpy.complex") == std::string::npos) { auto to = PyNumber_Float(*obj); if (to) { *obj = to; @@ -107,11 +111,15 @@ bool PyObject_CheckFloatOrToFloat(PyObject** obj) { } bool PyObject_CheckComplexOrToComplex(PyObject** obj) { - if (PyComplex_Check(*obj) || PyLong_Check(*obj) || PyFloat_Check(*obj) || + if (PyComplex_Check(*obj) || PyObject_TypeCheck(*obj, g_vartype_pytype) || // NOLINT PyObject_TypeCheck(*obj, p_tensor_type)) { // NOLINT return true; } + if (std::string(((PyTypeObject*)(*obj)->ob_type)->tp_name) // NOLINT + .find("numpy.complex") != std::string::npos) { + return true; + } // consider numpy cfloat & numpy cdouble? 
return false; } @@ -242,10 +250,15 @@ double CastPyArg2Double(PyObject* obj, phi::dtype::complex CastPyArg2Complex(PyObject* obj, const std::string& op_type, ssize_t arg_pos) { + PyTypeObject* type = obj->ob_type; + auto type_name = std::string(type->tp_name); if (PyComplex_Check(obj)) { double real = PyComplex_RealAsDouble(obj); double imag = PyComplex_ImagAsDouble(obj); return phi::dtype::complex(real, imag); // NOLINT + } else if (type_name == "numpy.complex64") { + Py_complex v = PyComplex_AsCComplex(obj); + return phi::dtype::complex(v.real, v.imag); } else { PADDLE_THROW(platform::errors::InvalidArgument( "%s(): argument (position %d) must be " diff --git a/paddle/phi/api/yaml/op_version.yaml b/paddle/phi/api/yaml/op_version.yaml index 7c9618f52b17b..2bd09abd311ae 100644 --- a/paddle/phi/api/yaml/op_version.yaml +++ b/paddle/phi/api/yaml/op_version.yaml @@ -55,6 +55,21 @@ - delete_attr : atol comment : The attribute 'atol' is deleted. The reason why it is deleted is that attributes do not support a float64 value and it is changed to a tensor. +- op : assign_value + version : + - checkpoint : Upgrade assign_value, remove plain attributes in favor of generic attribute. + action : + - add_attr : values + comment : replace generic types with scalar. + default : std::vector() + - delete_attr : bool_values + comment : remove plain attributes. + - delete_attr : fp32_values + comment : remove plain attributes. + - delete_attr : int32_values + comment : remove plain attributes. + - delete_attr : int64_values + comment : remove plain attributes. - op : auc version : diff --git a/paddle/phi/api/yaml/static_ops.yaml b/paddle/phi/api/yaml/static_ops.yaml index 5fe9ea4260d40..6ff2bfe427122 100755 --- a/paddle/phi/api/yaml/static_ops.yaml +++ b/paddle/phi/api/yaml/static_ops.yaml @@ -90,7 +90,7 @@ backward : assign_grad - op : assign_value - args : (int[] shape, DataType dtype, int[] bool_values = {}, float[] fp32_values = {}, double[] fp64_values = {}, int[] int32_values = {}, int64_t[] int64_values = {}) + args : (int[] shape, DataType dtype, Scalar[] values = {}) output : Tensor(out) infer_meta : func : AssignValueInferMeta diff --git a/paddle/phi/kernels/assign_kernel.cc b/paddle/phi/kernels/assign_kernel.cc index b4504f83818d7..f54dfec2f6ad2 100644 --- a/paddle/phi/kernels/assign_kernel.cc +++ b/paddle/phi/kernels/assign_kernel.cc @@ -137,7 +137,9 @@ PD_REGISTER_KERNEL(assign_value, float, double, int8_t, - int64_t) {} + int64_t, + phi::dtype::complex, + phi::dtype::complex) {} #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) PD_REGISTER_KERNEL_FOR_ALL_DTYPE(assign, @@ -165,7 +167,9 @@ PD_REGISTER_KERNEL(assign_value, float, double, int8_t, - int64_t) {} + int64_t, + phi::dtype::complex, + phi::dtype::complex) {} #endif #ifdef PADDLE_WITH_XPU @@ -193,5 +197,7 @@ PD_REGISTER_KERNEL(assign_value, int, float, double, - int64_t) {} + int64_t, + phi::dtype::complex, + phi::dtype::complex) {} #endif diff --git a/paddle/pir/core/builder.h b/paddle/pir/core/builder.h index c5e3472bb070a..158d82f3fbcbe 100644 --- a/paddle/pir/core/builder.h +++ b/paddle/pir/core/builder.h @@ -16,6 +16,7 @@ #include +#include "paddle/phi/common/complex.h" #include "paddle/pir/core/block.h" #include "paddle/pir/core/ir_context.h" #include "paddle/pir/core/operation.h" @@ -44,6 +45,8 @@ class Int64Attribute; class ArrayAttribute; class PointerAttribute; class TensorNameAttribute; +class Complex64Attribute; +class Complex128Attribute; using InsertionPoint = std::pair; /// @@ -150,6 +153,8 @@ class Builder { IR_API 
ArrayAttribute array_attr(const std::vector &value); IR_API PointerAttribute pointer_attr(void *value); IR_API TensorNameAttribute tensor_name_attr(const std::string &value); + IR_API Complex64Attribute complex64_attr(phi::dtype::complex value); + IR_API Complex128Attribute complex128_attr(phi::dtype::complex value); private: Operation *Insert(Operation *op); diff --git a/paddle/pir/core/builtin_attribute.cc b/paddle/pir/core/builtin_attribute.cc index a817fb48c55fc..32136371d5780 100644 --- a/paddle/pir/core/builtin_attribute.cc +++ b/paddle/pir/core/builtin_attribute.cc @@ -32,6 +32,14 @@ void* PointerAttribute::data() const { return storage()->data(); } Type TypeAttribute::data() const { return storage()->data(); } +phi::dtype::complex Complex64Attribute::data() const { + return storage()->data(); +} + +phi::dtype::complex Complex128Attribute::data() const { + return storage()->data(); +} + bool StrAttribute::operator<(const StrAttribute& right) const { return storage() < right.storage(); } @@ -109,3 +117,5 @@ IR_DEFINE_EXPLICIT_TYPE_ID(pir::ArrayAttribute) IR_DEFINE_EXPLICIT_TYPE_ID(pir::PointerAttribute) IR_DEFINE_EXPLICIT_TYPE_ID(pir::TypeAttribute) IR_DEFINE_EXPLICIT_TYPE_ID(pir::TensorNameAttribute) +IR_DEFINE_EXPLICIT_TYPE_ID(pir::Complex64Attribute) +IR_DEFINE_EXPLICIT_TYPE_ID(pir::Complex128Attribute) diff --git a/paddle/pir/core/builtin_attribute.h b/paddle/pir/core/builtin_attribute.h index a1751a8c248b8..59345c9e1b4f6 100644 --- a/paddle/pir/core/builtin_attribute.h +++ b/paddle/pir/core/builtin_attribute.h @@ -14,6 +14,7 @@ #pragma once +#include "paddle/phi/common/complex.h" #include "paddle/pir/core/attribute.h" #include "paddle/pir/core/builtin_attribute_storage.h" #include "paddle/pir/core/utils.h" @@ -28,6 +29,26 @@ class IR_API BoolAttribute : public Attribute { bool data() const; }; +class IR_API Complex64Attribute : public Attribute { + public: + using Attribute::Attribute; + + DECLARE_ATTRIBUTE_UTILITY_FUNCTOR(Complex64Attribute, + Complex64AttributeStorage); + + phi::dtype::complex data() const; +}; + +class IR_API Complex128Attribute : public Attribute { + public: + using Attribute::Attribute; + + DECLARE_ATTRIBUTE_UTILITY_FUNCTOR(Complex128Attribute, + Complex128AttributeStorage); + + phi::dtype::complex data() const; +}; + class IR_API FloatAttribute : public Attribute { public: using Attribute::Attribute; @@ -157,3 +178,5 @@ IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::ArrayAttribute) IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::PointerAttribute) IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::TypeAttribute) IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::TensorNameAttribute) +IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::Complex64Attribute) +IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::Complex128Attribute) diff --git a/paddle/pir/core/builtin_attribute_storage.h b/paddle/pir/core/builtin_attribute_storage.h index 533b0a4ad03e9..9e66fb6b010c9 100644 --- a/paddle/pir/core/builtin_attribute_storage.h +++ b/paddle/pir/core/builtin_attribute_storage.h @@ -19,6 +19,7 @@ #include #include "paddle/common/enforce.h" +#include "paddle/phi/common/complex.h" #include "paddle/pir/core/attribute.h" #include "paddle/pir/core/attribute_base.h" #include "paddle/pir/core/type.h" @@ -149,4 +150,43 @@ struct ArrayAttributeStorage : public AttributeStorage { const size_t size_; }; +struct Complex64AttributeStorage : public AttributeStorage { + using ParamKey = phi::dtype::complex; + explicit Complex64AttributeStorage(const ParamKey &key) { data_ = key; } + static Complex64AttributeStorage *Construct(const ParamKey 
&key) { + return new Complex64AttributeStorage(key); + } + static std::size_t HashValue(const ParamKey &key) { + std::stringstream complex_str; + complex_str << key.real << "+" << key.imag << "i"; + return std::hash{}(complex_str.str()); + } + + bool operator==(ParamKey key) const { return data_ == key; } + + phi::dtype::complex data() const { return data_; } + + private: + phi::dtype::complex data_; +}; + +struct Complex128AttributeStorage : public AttributeStorage { + using ParamKey = phi::dtype::complex; + explicit Complex128AttributeStorage(const ParamKey &key) { data_ = key; } + static Complex128AttributeStorage *Construct(const ParamKey &key) { + return new Complex128AttributeStorage(key); + } + static std::size_t HashValue(const ParamKey &key) { + std::stringstream complex_str; + complex_str << key.real << "+" << key.imag << "i"; + return std::hash{}(complex_str.str()); + } + + bool operator==(ParamKey key) const { return data_ == key; } + + phi::dtype::complex data() const { return data_; } + + private: + phi::dtype::complex data_; +}; } // namespace pir diff --git a/paddle/pir/core/builtin_dialect.cc b/paddle/pir/core/builtin_dialect.cc index 4bba7185384a3..91835c3029dc7 100644 --- a/paddle/pir/core/builtin_dialect.cc +++ b/paddle/pir/core/builtin_dialect.cc @@ -50,7 +50,9 @@ void BuiltinDialect::initialize() { Int64Attribute, ArrayAttribute, TypeAttribute, - TensorNameAttribute>(); + TensorNameAttribute, + Complex64Attribute, + Complex128Attribute>(); RegisterOps()) { os << "(Pointer)" << p.data(); + } else if (auto p = attr.dyn_cast()) { + os << "(Complex64)" << p.data(); + } else if (auto p = attr.dyn_cast()) { + os << "(Complex128)" << p.data(); } else if (auto arr = attr.dyn_cast()) { const auto& vec = arr.AsVector(); os << "["; diff --git a/python/paddle/nn/initializer/Bilinear.py b/python/paddle/nn/initializer/Bilinear.py index cfb18dac02c2a..1da82cbeee970 100644 --- a/python/paddle/nn/initializer/Bilinear.py +++ b/python/paddle/nn/initializer/Bilinear.py @@ -148,7 +148,7 @@ def forward(self, var, block=None): out_var = var if out_dtype in (core.VarDesc.VarType.FP32, core.DataType.FLOAT32): - value_name = "fp32_values" + value_name = "values" values = [float(v) for v in weight.flat] else: raise TypeError("Unsupported dtype %s", var.dtype) diff --git a/python/paddle/nn/initializer/assign.py b/python/paddle/nn/initializer/assign.py index 9274ff5275df0..3988f9f14859d 100644 --- a/python/paddle/nn/initializer/assign.py +++ b/python/paddle/nn/initializer/assign.py @@ -89,13 +89,13 @@ def forward(self, var, block=None): np_value = self._value if out_dtype in (core.VarDesc.VarType.FP32, core.DataType.FLOAT32): - value_name = "fp32_values" + value_name = "values" values = [float(v) for v in np_value.flat] elif out_dtype in (core.VarDesc.VarType.FP64, core.DataType.FLOAT64): - value_name = "fp64_values" + value_name = "values" values = [float(v) for v in np_value.flat] elif out_dtype in (core.VarDesc.VarType.INT32, core.DataType.INT32): - value_name = "int32_values" + value_name = "values" values = [int(v) for v in np_value.flat] elif out_dtype in ( core.VarDesc.VarType.INT8, diff --git a/python/paddle/nn/initializer/dirac.py b/python/paddle/nn/initializer/dirac.py index 7da5cd15b54f7..4aea131684f21 100644 --- a/python/paddle/nn/initializer/dirac.py +++ b/python/paddle/nn/initializer/dirac.py @@ -255,7 +255,7 @@ def __call__(self, var, block=None): attrs={ 'dtype': VarDesc.VarType.INT64, 'shape': [len(idx_list)], - 'int64_values': idx_list, + 'values': idx_list, }, 
stop_gradient=True, ) @@ -298,7 +298,7 @@ def __call__(self, var, block=None): attrs={ 'dtype': VarDesc.VarType.FP32, 'shape': [len(value_list)], - 'fp32_values': value_list, + 'values': value_list, }, stop_gradient=True, ) diff --git a/python/paddle/tensor/creation.py b/python/paddle/tensor/creation.py index 1fb067edcbb6e..5fbf1f0fbc468 100644 --- a/python/paddle/tensor/creation.py +++ b/python/paddle/tensor/creation.py @@ -16,7 +16,6 @@ import math import re -import warnings import numpy as np @@ -2361,6 +2360,8 @@ def assign(x, output=None): 'uint8', 'int8', 'bool', + 'complex64', + 'complex128', ], 'assign', '(When the type of input in assign is Variable.)', @@ -2408,44 +2409,23 @@ def convert_scalar(x): ) dtype = convert_np_dtype_to_dtype_(input.dtype) - if dtype == core.VarDesc.VarType.FP64: - # Setting FP64 numpy data is not supported in Paddle, so we - # use FP32 here - warnings.warn( - "paddle.assign doesn't support float64 input now due " - "to current platform protobuf data limitation, we convert " - "it to float32" - ) - dtype = core.VarDesc.VarType.FP32 - - if dtype == core.DataType.FLOAT64: - # Setting FP64 numpy data is not supported in Paddle, so we - # use FP32 here - warnings.warn( - "paddle.assign doesn't support float64 input now due " - "to current platform protobuf data limitation, we convert " - "it to float32" - ) - dtype = core.DataType.FLOAT32 - - if dtype in [core.VarDesc.VarType.BOOL, core.DataType.BOOL]: - value_name = "bool_values" - values = [int(v) for v in input.flat] - elif dtype in [core.VarDesc.VarType.FP32, core.DataType.FLOAT32]: - value_name = "fp32_values" - values = [float(v) for v in input.flat] - elif dtype in [core.VarDesc.VarType.INT32, core.DataType.INT32]: - value_name = "int32_values" - values = [int(v) for v in input.flat] - elif dtype in [core.VarDesc.VarType.INT64, core.DataType.INT64]: - value_name = "int64_values" - values = [int(v) for v in input.flat] - else: - raise TypeError( - "When the type of 'input' in assign is numpy.ndarray, " - "the data type of 'input' must be bool, float32, int32 or int64, but " - "received %s." % convert_dtype(dtype) - ) + check_dtype( + dtype, + 'input', + [ + 'float32', + 'float64', + 'int32', + 'int64', + 'bool', + 'complex64', + 'complex128', + ], + 'assign', + '(When the type of input in assign is numpy array.)', + ) + value_name = "values" + values = input.ravel().tolist() if input.size > 1024 * 1024: raise ValueError( "The size of input is too big. 
Please consider " diff --git a/test/dygraph_to_static/test_program_translator.py b/test/dygraph_to_static/test_program_translator.py index d384c7ad649d9..d6addfe3400bc 100644 --- a/test/dygraph_to_static/test_program_translator.py +++ b/test/dygraph_to_static/test_program_translator.py @@ -314,14 +314,24 @@ def test_ifelse_early_return1(self): answer = np.zeros([2, 2]) + 1 static_func = paddle.jit.to_static(dyfunc_with_if_else_early_return1) out = static_func() - np.testing.assert_allclose(answer, out[0].numpy(), rtol=1e-05) + if isinstance(out, paddle.Tensor): + np.testing.assert_allclose( + paddle.to_tensor(answer), out, rtol=1e-05 + ) + elif isinstance(out, tuple): + np.testing.assert_allclose(answer, out[0].numpy(), rtol=1e-05) @disable_test_case((ToStaticMode.AST, IrMode.PT)) def test_ifelse_early_return2(self): answer = np.zeros([2, 2]) + 3 static_func = paddle.jit.to_static(dyfunc_with_if_else_early_return2) out = static_func() - np.testing.assert_allclose(answer, out[0].numpy(), rtol=1e-05) + if isinstance(out, paddle.Tensor): + np.testing.assert_allclose( + paddle.to_tensor(answer), out, rtol=1e-05 + ) + elif isinstance(out, tuple): + np.testing.assert_allclose(answer, out[0].numpy(), rtol=1e-05) class TestRemoveCommentInDy2St(Dy2StTestBase): diff --git a/test/ir/inference/CMakeLists.txt b/test/ir/inference/CMakeLists.txt index 020b84b4fd32a..185ca22f897f6 100755 --- a/test/ir/inference/CMakeLists.txt +++ b/test/ir/inference/CMakeLists.txt @@ -168,8 +168,8 @@ if(NOT WITH_MKLDNN set_tests_properties(${target} PROPERTIES LABELS "RUN_TYPE=INFER") endforeach() - set_tests_properties(test_mul_lstm_fuse_pass PROPERTIES TIMEOUT 300) - set_tests_properties(test_mul_gru_fuse_pass PROPERTIES TIMEOUT 300) + set_tests_properties(test_mul_lstm_fuse_pass PROPERTIES TIMEOUT 1000) + set_tests_properties(test_mul_gru_fuse_pass PROPERTIES TIMEOUT 600) endif() if(WITH_GPU AND TENSORRT_FOUND) diff --git a/test/ir/inference/test_mul_gru_fuse_pass.py b/test/ir/inference/test_mul_gru_fuse_pass.py index 91c8058c54ec5..0ccbe46724608 100644 --- a/test/ir/inference/test_mul_gru_fuse_pass.py +++ b/test/ir/inference/test_mul_gru_fuse_pass.py @@ -134,7 +134,7 @@ def sample_predictor_configs(self, program_config): def test(self): self.run_and_statis( - quant=False, max_duration=300, passes=["mul_gru_fuse_pass"] + quant=False, max_duration=600, passes=["mul_gru_fuse_pass"] ) diff --git a/test/ir/inference/test_mul_lstm_fuse_pass.py b/test/ir/inference/test_mul_lstm_fuse_pass.py index f6304404c3694..fec34311604ee 100644 --- a/test/ir/inference/test_mul_lstm_fuse_pass.py +++ b/test/ir/inference/test_mul_lstm_fuse_pass.py @@ -120,7 +120,7 @@ def sample_predictor_configs(self, program_config): def test(self): self.run_and_statis( - quant=False, max_duration=300, passes=["mul_lstm_fuse_pass"] + quant=False, max_duration=1000, passes=["mul_lstm_fuse_pass"] ) diff --git a/test/ir/inference/test_seq_concat_fc_fuse_pass.py b/test/ir/inference/test_seq_concat_fc_fuse_pass.py index 4f1a0cbb7af83..68e446c5a6469 100644 --- a/test/ir/inference/test_seq_concat_fc_fuse_pass.py +++ b/test/ir/inference/test_seq_concat_fc_fuse_pass.py @@ -140,7 +140,9 @@ def teller1(program_config, predictor_config): ) def test(self): - self.run_and_statis(quant=False, passes=["seq_concat_fc_fuse_pass"]) + self.run_and_statis( + quant=False, passes=["seq_concat_fc_fuse_pass"], max_duration=1000 + ) if __name__ == "__main__": diff --git a/test/legacy_test/test_assign_value_op.py b/test/legacy_test/test_assign_value_op.py index 
6ff4282d9fc55..10ff186e2e966 100644 --- a/test/legacy_test/test_assign_value_op.py +++ b/test/legacy_test/test_assign_value_op.py @@ -22,24 +22,24 @@ from paddle.base import framework -def assign_value_wrapper( - shape=[], dtype=base.core.VarDesc.VarType.FP32, values=0.0 -): - if paddle.framework.in_dynamic_mode(): - tensor = paddle.Tensor() - else: - np_type = paddle.base.data_feeder._PADDLE_DTYPE_2_NUMPY_DTYPE[dtype] - tensor = paddle.zeros(list(shape), np_type) - dtype = paddle.pir.core.convert_np_dtype_to_dtype_(np_type) - return paddle._C_ops.assign_value_( - tensor, shape, dtype, values, framework._current_expected_place() - ) +def wrap_assign_value_wrapper(dtype=base.core.VarDesc.VarType.FP32): + def assign_value_wrapper(shape=[], dtype=dtype, values=0.0): + if paddle.framework.in_dynamic_mode(): + tensor = paddle.Tensor() + else: + np_type = paddle.base.data_feeder._PADDLE_DTYPE_2_NUMPY_DTYPE[dtype] + tensor = paddle.zeros(list(shape), np_type) + dtype = paddle.pir.core.convert_np_dtype_to_dtype_(np_type) + return paddle._C_ops.assign_value_( + tensor, shape, dtype, values, framework._current_expected_place() + ) + + return assign_value_wrapper class TestAssignValueOp(op_test.OpTest): def setUp(self): self.op_type = "assign_value" - self.python_api = assign_value_wrapper self.inputs = {} self.attrs = {} self.init_data() @@ -47,11 +47,12 @@ def setUp(self): self.attrs["dtype"] = framework.convert_np_dtype_to_dtype_( self.value.dtype ) + self.python_api = wrap_assign_value_wrapper(self.attrs["dtype"]) self.outputs = {"Out": self.value} def init_data(self): self.value = np.random.random(size=(2, 5)).astype(np.float32) - self.attrs["fp32_values"] = [float(v) for v in self.value.flat] + self.attrs["values"] = [float(v) for v in self.value.flat] def test_forward(self): self.check_output(check_cinn=True, check_pir=True) @@ -60,13 +61,13 @@ def test_forward(self): class TestAssignValueOp2(TestAssignValueOp): def init_data(self): self.value = np.random.random(size=(2, 5)).astype(np.int32) - self.attrs["int32_values"] = [int(v) for v in self.value.flat] + self.attrs["values"] = [int(v) for v in self.value.flat] class TestAssignValueOp3(TestAssignValueOp): def init_data(self): self.value = np.random.random(size=(2, 5)).astype(np.int64) - self.attrs["int64_values"] = [int(v) for v in self.value.flat] + self.attrs["values"] = [int(v) for v in self.value.flat] class TestAssignValueOp4(TestAssignValueOp): @@ -74,7 +75,29 @@ def init_data(self): self.value = np.random.choice(a=[False, True], size=(2, 5)).astype( np.bool_ ) - self.attrs["bool_values"] = [int(v) for v in self.value.flat] + self.attrs["values"] = [int(v) for v in self.value.flat] + + +class TestAssignValueOp5(TestAssignValueOp): + def init_data(self): + self.value = np.random.random(size=(2, 5)).astype(np.float64) + self.attrs["values"] = [float(v) for v in self.value.flat] + + +class TestAssignValueOp6(TestAssignValueOp): + def init_data(self): + self.value = ( + np.random.random(size=(2, 5)) + 1j * np.random.random(size=(2, 5)) + ).astype(np.complex64) + self.attrs["values"] = list(self.value.flat) + + +class TestAssignValueOp7(TestAssignValueOp): + def init_data(self): + self.value = ( + np.random.random(size=(2, 5)) + 1j * np.random.random(size=(2, 5)) + ).astype(np.complex128) + self.attrs["values"] = list(self.value.flat) class TestAssignApi(unittest.TestCase): @@ -97,8 +120,7 @@ def test_assign(self): with op_test.paddle_static_guard(): main_program = base.Program() with base.program_guard(main_program): - x = 
paddle.tensor.create_tensor(dtype=self.dtype) - paddle.assign(self.value, output=x) + x = paddle.assign(self.value) exe = base.Executor(self.place) [fetched_x] = exe.run(main_program, feed={}, fetch_list=[x]) @@ -145,5 +167,46 @@ def init_dtype(self): self.dtype = "bool" +class TestAssignApi5(TestAssignApi): + def init_dtype(self): + self.dtype = "float64" + + +class TestAssignApi6(TestAssignApi): + def setUp(self): + with op_test.paddle_static_guard(): + self.init_dtype() + self.value = ( + np.random.random(size=(2, 5)) + + 1j * (np.random.random(size=(2, 5))) + ).astype(np.complex64) + self.place = ( + base.CUDAPlace(0) + if base.is_compiled_with_cuda() + else base.CPUPlace() + ) + + def init_dtype(self): + self.dtype = "complex64" + + +class TestAssignApi7(TestAssignApi): + def setUp(self): + with op_test.paddle_static_guard(): + self.init_dtype() + self.value = ( + np.random.random(size=(2, 5)) + + 1j * (np.random.random(size=(2, 5))) + ).astype(np.complex128) + self.place = ( + base.CUDAPlace(0) + if base.is_compiled_with_cuda() + else base.CPUPlace() + ) + + def init_dtype(self): + self.dtype = "complex128" + + if __name__ == '__main__': unittest.main() diff --git a/test/legacy_test/test_initializer.py b/test/legacy_test/test_initializer.py index 5170207284459..ac612d2b2bee3 100644 --- a/test/legacy_test/test_initializer.py +++ b/test/legacy_test/test_initializer.py @@ -1354,7 +1354,8 @@ def test_numpy_array_initializer(self, dtype="float32"): self.assertEqual(len(block.ops), num_ops) init_op = block.ops[0] self.assertEqual(init_op.type, 'assign_value') - assert (init_op.attr('fp32_values') == np_array).all() + values = framework.extract_plain_list(init_op.attr('values')) + assert values == np_array.ravel().tolist() return block def test_numpy_array_initializer_fp16(self): diff --git a/test/legacy_test/test_initializer_nn.py b/test/legacy_test/test_initializer_nn.py index 95c64ac648290..1d9d8b08cf16d 100644 --- a/test/legacy_test/test_initializer_nn.py +++ b/test/legacy_test/test_initializer_nn.py @@ -664,8 +664,8 @@ def test_assign_initializer(self, dtype="float32"): self.assertEqual(len(block.ops), num_ops) init_op = block.ops[0] self.assertEqual(init_op.type, 'assign_value') - assert (init_op.attr('fp32_values') == np_array).all() - + values = framework.extract_plain_list(init_op.attr('values')) + assert values == np_array.ravel().tolist() paddle.disable_static() return block diff --git a/test/legacy_test/test_program_converter.py b/test/legacy_test/test_program_converter.py index 3894ca930ee0f..3ba1e7f33ad57 100644 --- a/test/legacy_test/test_program_converter.py +++ b/test/legacy_test/test_program_converter.py @@ -301,3 +301,196 @@ def test_complex128(self): legacy_program_bytes = mp._get_desc().serialize_to_string( legacy_format=True ) + + +class TestAssignValue(unittest.TestCase): + def setUp(self): + paddle.enable_static() + + def _test_for_new_program_format(self, program_bytes): + restored_prog_as_is = framework_pb2.ProgramDesc.FromString( + program_bytes + ) + for block in restored_prog_as_is.blocks: + for op in block.ops: + if op.type in ("assign_value"): + attr_names = [attr.name for attr in op.attrs] + self.assertTrue("values" in attr_names) + self.assertFalse("bool_values" in attr_names) + self.assertFalse("int32_values" in attr_names) + self.assertFalse("int64_values" in attr_names) + self.assertFalse("fp32_values" in attr_names) + + def _test_for_legacy_program_format(self, program_bytes): + restored_prog_as_is = framework_pb2.ProgramDesc.FromString( + program_bytes 
+ ) + for block in restored_prog_as_is.blocks: + for op in block.ops: + if op.type in ("set_value", "set_value_grad"): + attr_names = [attr.name for attr in op.attrs] + self.assertFalse("values" in attr_names) + self.assertTrue("bool_values" in attr_names) + self.assertTrue("int32_values" in attr_names) + self.assertTrue("int64_values" in attr_names) + self.assertTrue("fp32_values" in attr_names) + + def _test_equivalence( + self, + new_program_bytes, + legacy_program_bytes, + fetch_list, + expected_outputs, + ): + normal_program = paddle.static.io.deserialize_program(new_program_bytes) + converted_back_program = paddle.static.io.deserialize_program( + legacy_program_bytes + ) + exe = paddle.static.Executor(paddle.CPUPlace()) + out = exe.run(normal_program, fetch_list=fetch_list) + np.testing.assert_allclose(out[0], expected_outputs[0]) + out = exe.run(converted_back_program, fetch_list=fetch_list) + np.testing.assert_allclose(out[0], expected_outputs[0]) + + def test_int32(self): + mp = paddle.static.Program() + sp = paddle.static.Program() + with paddle.static.program_guard(mp, sp): + x = np.array([[1, 1], [3, 4], [1, 3]]).astype(np.int32) + out = paddle.assign(x) + + normal_program_bytes = mp._get_desc().serialize_to_string() + legacy_program_bytes = mp._get_desc().serialize_to_string( + legacy_format=True + ) + self.assertNotEqual(normal_program_bytes, legacy_program_bytes) + self._test_for_new_program_format(normal_program_bytes) + self._test_for_legacy_program_format(legacy_program_bytes) + self._test_equivalence( + normal_program_bytes, + legacy_program_bytes, + fetch_list=[out.name], + expected_outputs=[x], + ) + + def test_int64(self): + mp = paddle.static.Program() + sp = paddle.static.Program() + with paddle.static.program_guard(mp, sp): + x = np.array([[1, 1], [3, 4], [1, 3]]).astype(np.int64) + out = paddle.assign(x) + + normal_program_bytes = mp._get_desc().serialize_to_string() + legacy_program_bytes = mp._get_desc().serialize_to_string( + legacy_format=True + ) + + self.assertNotEqual(normal_program_bytes, legacy_program_bytes) + self._test_for_new_program_format(normal_program_bytes) + self._test_for_legacy_program_format(legacy_program_bytes) + self._test_equivalence( + normal_program_bytes, + legacy_program_bytes, + fetch_list=[out.name], + expected_outputs=[x], + ) + + def test_float32(self): + mp = paddle.static.Program() + sp = paddle.static.Program() + with paddle.static.program_guard(mp, sp): + x = np.random.random(size=(2, 5)).astype(np.float32) + out = paddle.assign(x) + + normal_program_bytes = mp._get_desc().serialize_to_string() + legacy_program_bytes = mp._get_desc().serialize_to_string( + legacy_format=True + ) + + self.assertNotEqual(normal_program_bytes, legacy_program_bytes) + self._test_for_new_program_format(normal_program_bytes) + self._test_for_legacy_program_format(legacy_program_bytes) + self._test_equivalence( + normal_program_bytes, + legacy_program_bytes, + fetch_list=[out.name], + expected_outputs=[x], + ) + + def test_float64(self): + mp = paddle.static.Program() + sp = paddle.static.Program() + with paddle.static.program_guard(mp, sp): + x = np.random.random(size=(2, 5)).astype(np.float64) + out = paddle.assign(x) + + normal_program_bytes = mp._get_desc().serialize_to_string() + legacy_program_bytes = mp._get_desc().serialize_to_string( + legacy_format=True + ) + + self.assertNotEqual(normal_program_bytes, legacy_program_bytes) + self._test_for_new_program_format(normal_program_bytes) + self._test_for_legacy_program_format(legacy_program_bytes) 
+ self._test_equivalence( + normal_program_bytes, + legacy_program_bytes, + fetch_list=[out.name], + expected_outputs=[x], + ) + + def test_bool(self): + mp = paddle.static.Program() + sp = paddle.static.Program() + with paddle.static.program_guard(mp, sp): + x = np.random.choice(a=[False, True], size=(2, 5)).astype(np.bool_) + out = paddle.assign(x) + + normal_program_bytes = mp._get_desc().serialize_to_string() + legacy_program_bytes = mp._get_desc().serialize_to_string( + legacy_format=True + ) + + self.assertNotEqual(normal_program_bytes, legacy_program_bytes) + self._test_for_new_program_format(normal_program_bytes) + self._test_for_legacy_program_format(legacy_program_bytes) + self._test_equivalence( + normal_program_bytes, + legacy_program_bytes, + fetch_list=[out.name], + expected_outputs=[x], + ) + + def test_complex64(self): + mp = paddle.static.Program() + sp = paddle.static.Program() + with paddle.static.program_guard(mp, sp): + x = ( + np.random.random(size=(2, 5)) + + 1j * np.random.random(size=(2, 5)) + ).astype(np.complex64) + out = paddle.assign(x) + + with self.assertRaisesRegex(RuntimeError, "Invalid data type"): + legacy_program_bytes = mp._get_desc().serialize_to_string( + legacy_format=True + ) + + def test_complex128(self): + mp = paddle.static.Program() + sp = paddle.static.Program() + with paddle.static.program_guard(mp, sp): + x = ( + np.random.random(size=(2, 5)) + + 1j * np.random.random(size=(2, 5)) + ).astype(np.complex128) + out = paddle.assign(x) + + with self.assertRaisesRegex(RuntimeError, "Invalid data type"): + legacy_program_bytes = mp._get_desc().serialize_to_string( + legacy_format=True + ) + + +if __name__ == '__main__': + unittest.main() diff --git a/test/xpu/test_assign_value_op_xpu.py b/test/xpu/test_assign_value_op_xpu.py index f6d2d2ec96ae3..e4414cdaafc05 100644 --- a/test/xpu/test_assign_value_op_xpu.py +++ b/test/xpu/test_assign_value_op_xpu.py @@ -53,7 +53,7 @@ def setUp(self): def init_data(self): self.value = np.random.random(size=(2, 5)).astype(np.float32) - self.attrs["fp32_values"] = [float(v) for v in self.value.flat] + self.attrs["values"] = [float(v) for v in self.value.flat] def test_forward(self): self.check_output_with_place(self.place) @@ -61,19 +61,40 @@ def test_forward(self): class TestAssignValueOp2(TestAssignValueOp): def init_data(self): self.value = np.random.random(size=(2, 5)).astype(np.int32) - self.attrs["int32_values"] = [int(v) for v in self.value.flat] + self.attrs["values"] = [int(v) for v in self.value.flat] class TestAssignValueOp3(TestAssignValueOp): def init_data(self): self.value = np.random.random(size=(2, 5)).astype(np.int64) - self.attrs["int64_values"] = [int(v) for v in self.value.flat] + self.attrs["values"] = [int(v) for v in self.value.flat] class TestAssignValueOp4(TestAssignValueOp): def init_data(self): self.value = np.random.choice(a=[False, True], size=(2, 5)).astype( np.bool_ ) - self.attrs["bool_values"] = [int(v) for v in self.value.flat] + self.attrs["values"] = [int(v) for v in self.value.flat] + + class TestAssignValueOp5(TestAssignValueOp): + def init_data(self): + self.value = np.random.random(size=(2, 5)).astype(np.float64) + self.attrs["values"] = [float(v) for v in self.value.flat] + + class TestAssignValueOp6(TestAssignValueOp): + def init_data(self): + self.value = ( + np.random.random(size=(2, 5)) + + 1j * np.random.random(size=(2, 5)) + ).astype(np.complex64) + self.attrs["values"] = list(self.value.flat) + + class TestAssignValueOp7(TestAssignValueOp): + def init_data(self): + 
self.value = ( + np.random.random(size=(2, 5)) + + 1j * np.random.random(size=(2, 5)) + ).astype(np.complex128) + self.attrs["values"] = list(self.value.flat) class TestAssignApi(unittest.TestCase): @@ -90,8 +111,7 @@ def init_dtype(self): def test_assign(self): main_program = base.Program() with base.program_guard(main_program): - x = paddle.tensor.create_tensor(dtype=self.dtype) - paddle.assign(self.value, output=x) + x = paddle.assign(self.value) exe = base.Executor(self.place) [fetched_x] = exe.run(main_program, feed={}, fetch_list=[x]) @@ -121,6 +141,35 @@ def init_dtype(self): self.dtype = "bool" +class TestAssignApi5(TestAssignApi): + def init_dtype(self): + self.dtype = "float64" + + +class TestAssignApi6(TestAssignApi): + def setUp(self): + self.init_dtype() + self.value = ( + np.random.random(size=(2, 5)) + 1j * (np.random.random(size=(2, 5))) + ).astype(np.complex64) + self.place = base.XPUPlace(0) + + def init_dtype(self): + self.dtype = "complex64" + + +class TestAssignApi7(TestAssignApi): + def setUp(self): + self.init_dtype() + self.value = ( + np.random.random(size=(2, 5)) + 1j * (np.random.random(size=(2, 5))) + ).astype(np.complex128) + self.place = base.XPUPlace(0) + + def init_dtype(self): + self.dtype = "complex128" + + support_types = get_xpu_op_support_types('assign_value') for stype in support_types: create_test_class(globals(), XPUTestAssignValueOp, stype)
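
For reference, a minimal usage sketch of the behavior this patch enables, mirroring the new
tests above: paddle.assign now lowers numpy inputs of every supported dtype, including
complex64 and complex128, onto the single generic "values" attribute of assign_value.
This is a sketch rather than part of the patch; shapes and values are illustrative only.

    import numpy as np
    import paddle

    paddle.enable_static()

    main = paddle.static.Program()
    startup = paddle.static.Program()
    with paddle.static.program_guard(main, startup):
        # complex input now goes through assign_value's generic "values" attribute
        x = (
            np.random.random(size=(2, 5)) + 1j * np.random.random(size=(2, 5))
        ).astype(np.complex64)
        out = paddle.assign(x)

    exe = paddle.static.Executor(paddle.CPUPlace())
    (fetched,) = exe.run(main, fetch_list=[out])
    np.testing.assert_allclose(fetched, x)

The same program serialized with legacy_format=True is expected to fail for complex inputs,
since the legacy plain attributes (bool_values / fp32_values / int32_values / int64_values)
have no complex counterpart, which is what TestAssignValue.test_complex64/test_complex128
in test_program_converter.py assert.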