Skip to content

Commit

Permalink
[Cherry-pick] Add fp16 dtype support for set_value op (#46906)
Browse files Browse the repository at this point in the history
Fix set_value failure when source tensor is fp16 Dtype and destiny value is a number
(dev PR link:#46801)
  • Loading branch information
Courtesy-Xs committed Oct 13, 2022
1 parent 0280c0b commit 100a075
Show file tree
Hide file tree
Showing 6 changed files with 55 additions and 5 deletions.
5 changes: 4 additions & 1 deletion paddle/fluid/operators/set_value_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,8 @@ class SetValueMaker : public framework::OpProtoAndCheckerMaker {
framework::proto::VarType::INT32,
framework::proto::VarType::INT64,
framework::proto::VarType::FP32,
framework::proto::VarType::FP64})
framework::proto::VarType::FP64,
framework::proto::VarType::FP16})
.SetDefault(framework::proto::VarType::FP32);
AddAttr<std::vector<int64_t>>(
"axes", "(list<int64_t>) Axes that `starts` and `ends` apply to.");
Expand Down Expand Up @@ -135,6 +136,8 @@ class SetValueMaker : public framework::OpProtoAndCheckerMaker {
.SetDefault({});
AddAttr<std::vector<double>>("fp64_values", "Store the float64 values.")
.SetDefault({});
AddAttr<std::vector<float>>("fp16_values", "Store the float16 values.")
.SetDefault({});

AddAttr<std::vector<int64_t>>("shape", "(vector<int64_t>) Shape of values.")
.SetDefault({});
Expand Down
6 changes: 5 additions & 1 deletion paddle/fluid/pybind/eager_method.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1151,11 +1151,15 @@ static PyObject* tensor_method__setitem_eager_tensor(TensorObject* self,
} else if (self->tensor.dtype() ==
paddle::experimental::DataType::BOOL) {
attrs["bool_values"] = std::vector<int>{value_obj_tmp.cast<bool>()};
} else if (self->tensor.dtype() ==
paddle::experimental::DataType::FLOAT16) {
attrs["fp16_values"] =
std::vector<float>{value_obj_tmp.cast<float>()};
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"When assign a value to a paddle.Tensor, "
"the data type of the paddle.Tensor must be bool, "
"float32, int32 or int64, "
"float32, int32, int64 or float16, "
"please check the type of tensor."));
}
attrs["shape"] = std::vector<int64_t>{1};
Expand Down
6 changes: 5 additions & 1 deletion paddle/fluid/pybind/imperative.cc
Original file line number Diff line number Diff line change
Expand Up @@ -964,11 +964,15 @@ void BindImperative(py::module *m_ptr) {
framework::proto::VarType::BOOL) {
attrs["bool_values"] =
std::vector<int>{value_obj.cast<bool>()};
} else if (self->DataType() ==
framework::proto::VarType::FP16) {
attrs["fp16_values"] =
std::vector<float>{value_obj.cast<float>()};
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"When assign a value to a paddle.Tensor, "
"the data type of the paddle.Tensor must be bool, "
"float32, int32 or int64, "
"float32, int32, int64 or float16, "
"please check the type of tensor."));
}
attrs["shape"] = std::vector<int64_t>{1};
Expand Down
15 changes: 15 additions & 0 deletions paddle/phi/ops/compat/set_value_sig.cc
Original file line number Diff line number Diff line change
Expand Up @@ -724,6 +724,21 @@ KernelSignature SetValueOpArgumentMapping(const ArgumentMappingContext& ctx) {
"shape",
"bool_values"},
{"Out"});
} else if (ctx.HasAttr("fp16_values") &&
!paddle::any_cast<std::vector<float>>(
ctx.Attr("fp16_values"))
.empty()) {
return KernelSignature("set_value",
{"Input"},
{"starts",
"ends",
"steps",
"axes",
"decrease_axes",
"none_axes",
"shape",
"fp16_values"},
{"Out"});
}
}
}
Expand Down
23 changes: 22 additions & 1 deletion python/paddle/fluid/tests/unittests/test_set_value_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -576,6 +576,28 @@ def set_dtype(self):
create_test_value_int64(TestSetValueItemSlice4)


def create_test_value_fp16(parent):

class TestValueInt(parent):

def set_value(self):
self.value = 3.7

def set_dtype(self):
self.dtype = "float16"

cls_name = "{0}_{1}".format(parent.__name__, "Valuefp16")
TestValueInt.__name__ = cls_name
globals()[cls_name] = TestValueInt


create_test_value_fp16(TestSetValueItemInt)
create_test_value_fp16(TestSetValueItemSlice)
create_test_value_fp16(TestSetValueItemSlice2)
create_test_value_fp16(TestSetValueItemSlice3)
create_test_value_fp16(TestSetValueItemSlice4)


def create_test_value_fp32(parent):

class TestValueInt(parent):
Expand Down Expand Up @@ -1015,7 +1037,6 @@ def test_error(self):
paddle.enable_static()
with paddle.static.program_guard(self.program):
self._value_type_error()
self._dtype_error()
self._step_error()
self._bool_list_error()
self._bool_tensor_error()
Expand Down
5 changes: 4 additions & 1 deletion python/paddle/fluid/variable_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -730,10 +730,13 @@ def _setitem_impl_(var, item, value):
elif dtype == core.VarDesc.VarType.INT64:
value_name = "int64_values"
values = [int(v) for v in value.flat]
elif dtype == core.VarDesc.VarType.FP16:
value_name = "fp16_values"
values = [float(v) for v in value.flat]
else:
raise TypeError(
"When assign a numpy.ndarray, integer or float to a paddle.Tensor, "
"the data type of the paddle.Tensor must be bool, float32, int32 or int64, but "
"the data type of the paddle.Tensor must be bool, float32, int32, int64 or float16, but "
"received %s." % convert_dtype(dtype))
attrs[value_name] = values
attrs["shape"] = shape
Expand Down

0 comments on commit 100a075

Please sign in to comment.