diff --git a/paddle/fluid/pybind/op_function_common.cc b/paddle/fluid/pybind/op_function_common.cc index 3ad4994a590f7..09c3cea398b2a 100644 --- a/paddle/fluid/pybind/op_function_common.cc +++ b/paddle/fluid/pybind/op_function_common.cc @@ -100,17 +100,15 @@ bool PyObject_CheckFloatOrToFloat(PyObject** obj) { bool PyObject_CheckString(PyObject* obj) { return PyUnicode_Check(obj); } -void CastPyArg2AttrBoolean(PyObject* obj, - paddle::framework::AttributeMap& attrs, // NOLINT - const std::string& key, const std::string& op_type, - ssize_t arg_pos) { +bool CastPyArg2Boolean(PyObject* obj, const std::string& op_type, + ssize_t arg_pos) { if (obj == Py_None) { - attrs[key] = false; // To be compatible with QA integration testing. Some - // test case pass in None. + return false; // To be compatible with QA integration testing. Some + // test case pass in None. } else if (obj == Py_True) { - attrs[key] = true; + return true; } else if (obj == Py_False) { - attrs[key] = false; + return false; } else { PADDLE_THROW(platform::errors::InvalidArgument( "%s(): argument (position %d) must be " @@ -118,62 +116,89 @@ void CastPyArg2AttrBoolean(PyObject* obj, op_type, arg_pos + 1, ((PyTypeObject*)obj->ob_type)->tp_name)); // NOLINT } + + return false; +} + +void CastPyArg2AttrBoolean(PyObject* obj, + paddle::framework::AttributeMap& attrs, // NOLINT + const std::string& key, const std::string& op_type, + ssize_t arg_pos) { + attrs[key] = CastPyArg2Boolean(obj, op_type, arg_pos); +} + +int CastPyArg2Int(PyObject* obj, const std::string& op_type, ssize_t arg_pos) { + if (PyObject_CheckLongOrToLong(&obj)) { + return (int)PyLong_AsLong(obj); // NOLINT + } else { + PADDLE_THROW(platform::errors::InvalidArgument( + "%s(): argument (position %d) must be " + "int, but got %s", + op_type, arg_pos + 1, + ((PyTypeObject*)obj->ob_type)->tp_name)); // NOLINT + } + + return 0; } void CastPyArg2AttrInt(PyObject* obj, paddle::framework::AttributeMap& attrs, // NOLINT const std::string& key, const std::string& op_type, ssize_t arg_pos) { + attrs[key] = CastPyArg2Int(obj, op_type, arg_pos); +} + +int64_t CastPyArg2Long(PyObject* obj, const std::string& op_type, + ssize_t arg_pos) { if (PyObject_CheckLongOrToLong(&obj)) { - attrs[key] = (int)PyLong_AsLong(obj); // NOLINT + return (int64_t)PyLong_AsLong(obj); // NOLINT } else { PADDLE_THROW(platform::errors::InvalidArgument( "%s(): argument (position %d) must be " - "int, but got %s", + "long, but got %s", op_type, arg_pos + 1, ((PyTypeObject*)obj->ob_type)->tp_name)); // NOLINT } + + return 0; } void CastPyArg2AttrLong(PyObject* obj, paddle::framework::AttributeMap& attrs, // NOLINT const std::string& key, const std::string& op_type, ssize_t arg_pos) { - if (PyObject_CheckLongOrToLong(&obj)) { - attrs[key] = (int64_t)PyLong_AsLong(obj); // NOLINT + attrs[key] = CastPyArg2Long(obj, op_type, arg_pos); +} + +float CastPyArg2Float(PyObject* obj, const std::string& op_type, + ssize_t arg_pos) { + if (PyObject_CheckFloatOrToFloat(&obj)) { + return (float)PyFloat_AsDouble(obj); // NOLINT } else { PADDLE_THROW(platform::errors::InvalidArgument( "%s(): argument (position %d) must be " - "long, but got %s", + "float, but got %s", op_type, arg_pos + 1, ((PyTypeObject*)obj->ob_type)->tp_name)); // NOLINT } + + return 0.0; } void CastPyArg2AttrFloat(PyObject* obj, paddle::framework::AttributeMap& attrs, // NOLINT const std::string& key, const std::string& op_type, ssize_t arg_pos) { - if (PyObject_CheckFloatOrToFloat(&obj)) { - attrs[key] = (float)PyFloat_AsDouble(obj); // NOLINT - } else { - PADDLE_THROW(platform::errors::InvalidArgument( - "%s(): argument (position %d) must be " - "float, but got %s", - op_type, arg_pos + 1, - ((PyTypeObject*)obj->ob_type)->tp_name)); // NOLINT - } + attrs[key] = CastPyArg2Float(obj, op_type, arg_pos); } -void CastPyArg2AttrString(PyObject* obj, - paddle::framework::AttributeMap& attrs, // NOLINT - const std::string& key, const std::string& op_type, - ssize_t arg_pos) { +std::string CastPyArg2String(PyObject* obj, const std::string& op_type, + ssize_t arg_pos) { if (PyObject_CheckString(obj)) { Py_ssize_t size; const char* data; data = PyUnicode_AsUTF8AndSize(obj, &size); - attrs[key] = std::string(data, (size_t)size); // NOLINT + return std::string(data, (size_t)size); // NOLINT } else { PADDLE_THROW(platform::errors::InvalidArgument( "%s(): argument (position %d) must be " @@ -181,16 +206,23 @@ void CastPyArg2AttrString(PyObject* obj, op_type, arg_pos + 1, ((PyTypeObject*)obj->ob_type)->tp_name)); // NOLINT } + + return ""; } -void CastPyArg2AttrBooleans(PyObject* obj, - paddle::framework::AttributeMap& attrs, // NOLINT - const std::string& key, const std::string& op_type, - ssize_t arg_pos) { +void CastPyArg2AttrString(PyObject* obj, + paddle::framework::AttributeMap& attrs, // NOLINT + const std::string& key, const std::string& op_type, + ssize_t arg_pos) { + attrs[key] = CastPyArg2String(obj, op_type, arg_pos); +} + +std::vector CastPyArg2Booleans(PyObject* obj, const std::string& op_type, + ssize_t arg_pos) { + std::vector value; if (PyList_Check(obj)) { Py_ssize_t len = PyList_Size(obj); PyObject* item = nullptr; - std::vector value; for (Py_ssize_t i = 0; i < len; i++) { item = PyList_GetItem(obj, i); if (PyObject_CheckBool(&item)) { @@ -204,11 +236,9 @@ void CastPyArg2AttrBooleans(PyObject* obj, i)); } } - attrs[key] = value; } else if (PyTuple_Check(obj)) { Py_ssize_t len = PyTuple_Size(obj); PyObject* item = nullptr; - std::vector value; for (Py_ssize_t i = 0; i < len; i++) { item = PyTuple_GetItem(obj, i); if (PyObject_CheckBool(&item)) { @@ -222,7 +252,6 @@ void CastPyArg2AttrBooleans(PyObject* obj, i)); } } - attrs[key] = value; } else { PADDLE_THROW(platform::errors::InvalidArgument( "%s(): argument (position %d) must be " @@ -230,16 +259,23 @@ void CastPyArg2AttrBooleans(PyObject* obj, op_type, arg_pos + 1, ((PyTypeObject*)obj->ob_type)->tp_name)); // NOLINT } + + return value; } -void CastPyArg2AttrInts(PyObject* obj, - paddle::framework::AttributeMap& attrs, // NOLINT - const std::string& key, const std::string& op_type, - ssize_t arg_pos) { +void CastPyArg2AttrBooleans(PyObject* obj, + paddle::framework::AttributeMap& attrs, // NOLINT + const std::string& key, const std::string& op_type, + ssize_t arg_pos) { + attrs[key] = CastPyArg2Booleans(obj, op_type, arg_pos); +} + +std::vector CastPyArg2Ints(PyObject* obj, const std::string& op_type, + ssize_t arg_pos) { + std::vector value; if (PyList_Check(obj)) { Py_ssize_t len = PyList_Size(obj); PyObject* item = nullptr; - std::vector value; for (Py_ssize_t i = 0; i < len; i++) { item = PyList_GetItem(obj, i); if (PyObject_CheckLongOrToLong(&item)) { @@ -253,11 +289,9 @@ void CastPyArg2AttrInts(PyObject* obj, i)); } } - attrs[key] = value; } else if (PyTuple_Check(obj)) { Py_ssize_t len = PyTuple_Size(obj); PyObject* item = nullptr; - std::vector value; for (Py_ssize_t i = 0; i < len; i++) { item = PyTuple_GetItem(obj, i); if (PyObject_CheckLongOrToLong(&item)) { @@ -271,11 +305,9 @@ void CastPyArg2AttrInts(PyObject* obj, i)); } } - attrs[key] = value; } else if (PySequence_Check(obj)) { Py_ssize_t len = PySequence_Size(obj); PyObject* item = nullptr; - std::vector value; for (Py_ssize_t i = 0; i < len; i++) { item = PySequence_GetItem(obj, i); if (PyObject_CheckLongOrToLong(&item)) { @@ -289,7 +321,6 @@ void CastPyArg2AttrInts(PyObject* obj, i)); } } - attrs[key] = value; } else { PADDLE_THROW(platform::errors::InvalidArgument( "%s(): argument (position %d) must be " @@ -297,16 +328,23 @@ void CastPyArg2AttrInts(PyObject* obj, op_type, arg_pos + 1, ((PyTypeObject*)obj->ob_type)->tp_name)); // NOLINT } + + return value; } -void CastPyArg2AttrLongs(PyObject* obj, - paddle::framework::AttributeMap& attrs, // NOLINT - const std::string& key, const std::string& op_type, - ssize_t arg_pos) { +void CastPyArg2AttrInts(PyObject* obj, + paddle::framework::AttributeMap& attrs, // NOLINT + const std::string& key, const std::string& op_type, + ssize_t arg_pos) { + attrs[key] = CastPyArg2Ints(obj, op_type, arg_pos); +} + +std::vector CastPyArg2Longs(PyObject* obj, const std::string& op_type, + ssize_t arg_pos) { + std::vector value; if (PyList_Check(obj)) { Py_ssize_t len = PyList_Size(obj); PyObject* item = nullptr; - std::vector value; for (Py_ssize_t i = 0; i < len; i++) { item = PyList_GetItem(obj, i); if (PyObject_CheckLongOrToLong(&item)) { @@ -320,11 +358,9 @@ void CastPyArg2AttrLongs(PyObject* obj, i)); } } - attrs[key] = value; } else if (PyTuple_Check(obj)) { Py_ssize_t len = PyTuple_Size(obj); PyObject* item = nullptr; - std::vector value; for (Py_ssize_t i = 0; i < len; i++) { item = PyTuple_GetItem(obj, i); if (PyObject_CheckLongOrToLong(&item)) { @@ -338,11 +374,9 @@ void CastPyArg2AttrLongs(PyObject* obj, i)); } } - attrs[key] = value; } else if (PySequence_Check(obj)) { Py_ssize_t len = PySequence_Size(obj); PyObject* item = nullptr; - std::vector value; for (Py_ssize_t i = 0; i < len; i++) { item = PySequence_GetItem(obj, i); if (PyObject_CheckLongOrToLong(&item)) { @@ -356,7 +390,6 @@ void CastPyArg2AttrLongs(PyObject* obj, i)); } } - attrs[key] = value; } else { PADDLE_THROW(platform::errors::InvalidArgument( "%s(): argument (position %d) must be " @@ -364,16 +397,23 @@ void CastPyArg2AttrLongs(PyObject* obj, op_type, arg_pos + 1, ((PyTypeObject*)obj->ob_type)->tp_name)); // NOLINT } + + return value; } -void CastPyArg2AttrFloats(PyObject* obj, - paddle::framework::AttributeMap& attrs, // NOLINT - const std::string& key, const std::string& op_type, - ssize_t arg_pos) { +void CastPyArg2AttrLongs(PyObject* obj, + paddle::framework::AttributeMap& attrs, // NOLINT + const std::string& key, const std::string& op_type, + ssize_t arg_pos) { + attrs[key] = CastPyArg2Longs(obj, op_type, arg_pos); +} + +std::vector CastPyArg2Floats(PyObject* obj, const std::string& op_type, + ssize_t arg_pos) { + std::vector value; if (PyList_Check(obj)) { Py_ssize_t len = PyList_Size(obj); PyObject* item = nullptr; - std::vector value; for (Py_ssize_t i = 0; i < len; i++) { item = PyList_GetItem(obj, i); if (PyObject_CheckFloatOrToFloat(&item)) { @@ -387,11 +427,9 @@ void CastPyArg2AttrFloats(PyObject* obj, i)); } } - attrs[key] = value; } else if (PyTuple_Check(obj)) { Py_ssize_t len = PyTuple_Size(obj); PyObject* item = nullptr; - std::vector value; for (Py_ssize_t i = 0; i < len; i++) { item = PyTuple_GetItem(obj, i); if (PyObject_CheckFloatOrToFloat(&item)) { @@ -405,11 +443,9 @@ void CastPyArg2AttrFloats(PyObject* obj, i)); } } - attrs[key] = value; } else if (PySequence_Check(obj)) { Py_ssize_t len = PySequence_Size(obj); PyObject* item = nullptr; - std::vector value; for (Py_ssize_t i = 0; i < len; i++) { item = PySequence_GetItem(obj, i); if (PyObject_CheckFloatOrToFloat(&item)) { @@ -423,7 +459,6 @@ void CastPyArg2AttrFloats(PyObject* obj, i)); } } - attrs[key] = value; } else { PADDLE_THROW(platform::errors::InvalidArgument( "%s(): argument (position %d) must be " @@ -431,16 +466,24 @@ void CastPyArg2AttrFloats(PyObject* obj, op_type, arg_pos + 1, ((PyTypeObject*)obj->ob_type)->tp_name)); // NOLINT } + + return value; } -void CastPyArg2AttrFloat64s(PyObject* obj, - paddle::framework::AttributeMap& attrs, // NOLINT - const std::string& key, const std::string& op_type, - ssize_t arg_pos) { +void CastPyArg2AttrFloats(PyObject* obj, + paddle::framework::AttributeMap& attrs, // NOLINT + const std::string& key, const std::string& op_type, + ssize_t arg_pos) { + attrs[key] = CastPyArg2Floats(obj, op_type, arg_pos); +} + +std::vector CastPyArg2Float64s(PyObject* obj, + const std::string& op_type, + ssize_t arg_pos) { + std::vector value; if (PyList_Check(obj)) { Py_ssize_t len = PyList_Size(obj); PyObject* item = nullptr; - std::vector value; for (Py_ssize_t i = 0; i < len; i++) { item = PyList_GetItem(obj, i); if (PyObject_CheckFloatOrToFloat(&item)) { @@ -454,11 +497,9 @@ void CastPyArg2AttrFloat64s(PyObject* obj, i)); } } - attrs[key] = value; } else if (PyTuple_Check(obj)) { Py_ssize_t len = PyTuple_Size(obj); PyObject* item = nullptr; - std::vector value; for (Py_ssize_t i = 0; i < len; i++) { item = PyTuple_GetItem(obj, i); if (PyObject_CheckFloatOrToFloat(&item)) { @@ -472,11 +513,9 @@ void CastPyArg2AttrFloat64s(PyObject* obj, i)); } } - attrs[key] = value; } else if (PySequence_Check(obj)) { Py_ssize_t len = PySequence_Size(obj); PyObject* item = nullptr; - std::vector value; for (Py_ssize_t i = 0; i < len; i++) { item = PySequence_GetItem(obj, i); if (PyObject_CheckFloatOrToFloat(&item)) { @@ -490,7 +529,6 @@ void CastPyArg2AttrFloat64s(PyObject* obj, i)); } } - attrs[key] = value; } else { PADDLE_THROW(platform::errors::InvalidArgument( "%s(): argument (position %d) must be " @@ -498,16 +536,24 @@ void CastPyArg2AttrFloat64s(PyObject* obj, op_type, arg_pos + 1, ((PyTypeObject*)obj->ob_type)->tp_name)); // NOLINT } + + return value; } -void CastPyArg2AttrStrings(PyObject* obj, - paddle::framework::AttributeMap& attrs, // NOLINT - const std::string& key, const std::string& op_type, - ssize_t arg_pos) { +void CastPyArg2AttrFloat64s(PyObject* obj, + paddle::framework::AttributeMap& attrs, // NOLINT + const std::string& key, const std::string& op_type, + ssize_t arg_pos) { + attrs[key] = CastPyArg2Float64s(obj, op_type, arg_pos); +} + +std::vector CastPyArg2Strings(PyObject* obj, + const std::string& op_type, + ssize_t arg_pos) { + std::vector value; if (PyList_Check(obj)) { Py_ssize_t len = PyList_Size(obj); PyObject* item = nullptr; - std::vector value; for (Py_ssize_t i = 0; i < len; i++) { item = PyList_GetItem(obj, i); if (PyObject_CheckString(item)) { @@ -524,11 +570,9 @@ void CastPyArg2AttrStrings(PyObject* obj, i)); } } - attrs[key] = value; } else if (PyTuple_Check(obj)) { Py_ssize_t len = PyTuple_Size(obj); PyObject* item = nullptr; - std::vector value; for (Py_ssize_t i = 0; i < len; i++) { item = PyTuple_GetItem(obj, i); if (PyObject_CheckString(item)) { @@ -545,7 +589,6 @@ void CastPyArg2AttrStrings(PyObject* obj, i)); } } - attrs[key] = value; } else { PADDLE_THROW(platform::errors::InvalidArgument( "%s(): argument (position %d) must be " @@ -553,6 +596,15 @@ void CastPyArg2AttrStrings(PyObject* obj, op_type, arg_pos + 1, ((PyTypeObject*)obj->ob_type)->tp_name)); // NOLINT } + + return value; +} + +void CastPyArg2AttrStrings(PyObject* obj, + paddle::framework::AttributeMap& attrs, // NOLINT + const std::string& key, const std::string& op_type, + ssize_t arg_pos) { + attrs[key] = CastPyArg2Strings(obj, op_type, arg_pos); } void CastPyArg2AttrBlock(PyObject* obj, diff --git a/paddle/fluid/pybind/op_function_common.h b/paddle/fluid/pybind/op_function_common.h index 9dc3a71a6ccf9..7ead985266725 100644 --- a/paddle/fluid/pybind/op_function_common.h +++ b/paddle/fluid/pybind/op_function_common.h @@ -43,6 +43,30 @@ bool PyObject_CheckFloatOrToFloat(PyObject** obj); bool PyObject_CheckString(PyObject* obj); +bool CastPyArg2Boolean(PyObject* obj, const std::string& op_type, + ssize_t arg_pos); +int CastPyArg2Int(PyObject* obj, const std::string& op_type, ssize_t arg_pos); +int64_t CastPyArg2Long(PyObject* obj, const std::string& op_type, + ssize_t arg_pos); +float CastPyArg2Float(PyObject* obj, const std::string& op_type, + ssize_t arg_pos); +std::string CastPyArg2String(PyObject* obj, const std::string& op_type, + ssize_t arg_pos); +std::vector CastPyArg2Booleans(PyObject* obj, const std::string& op_type, + ssize_t arg_pos); +std::vector CastPyArg2Ints(PyObject* obj, const std::string& op_type, + ssize_t arg_pos); +std::vector CastPyArg2Longs(PyObject* obj, const std::string& op_type, + ssize_t arg_pos); +std::vector CastPyArg2Floats(PyObject* obj, const std::string& op_type, + ssize_t arg_pos); +std::vector CastPyArg2Float64s(PyObject* obj, + const std::string& op_type, + ssize_t arg_pos); +std::vector CastPyArg2Strings(PyObject* obj, + const std::string& op_type, + ssize_t arg_pos); + void CastPyArg2AttrBoolean(PyObject* obj, paddle::framework::AttributeMap& attrs, // NOLINT const std::string& key, const std::string& op_type,