Skip to content

Commit

Permalink
remove set_value numpy (#41017)
Browse files Browse the repository at this point in the history
* remove set_value numpy

* optimize code

* optimize to_tensor

* use common function

Co-authored-by: root <root@yq01-sys-hic-k8s-v100-box-a225-0186.yq01.baidu.com>
  • Loading branch information
Zjq9409 and root committed Mar 30, 2022
1 parent 95265d5 commit 1042f42
Showing 1 changed file with 11 additions and 11 deletions.
22 changes: 11 additions & 11 deletions python/paddle/fluid/dygraph/varbase_patch_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

import paddle
from .. import framework
from ..framework import convert_np_dtype_to_dtype_
from .. import core
from .. import unique_name
from ..framework import Variable, Parameter, ParamBase, _getitem_impl_, _setitem_impl_, EagerParamBase
Expand Down Expand Up @@ -172,25 +173,24 @@ def set_value(self, value):
else:
self.value().set_string_list(value)
else:
value_np = value
if isinstance(value, base_tensor):
value_np = value.numpy()

self_tensor_np = self.numpy()

assert self_tensor_np.shape == value_np.shape, \
assert self.shape == list(value.shape), \
"Variable Shape not match, Variable [ {} ] need tensor with shape {} but load set tensor with shape {}".format(
self.name, self_tensor_np.shape, value_np.shape)
self.name, self.shape, value.shape)

if isinstance(value, base_tensor):
dtype = value.dtype
else:
dtype = convert_np_dtype_to_dtype_(value.dtype)

assert self_tensor_np.dtype == value_np.dtype, \
assert self.dtype == dtype, \
"Variable dtype not match, Variable [ {} ] need tensor with dtype {} but load tensor with dtype {}".format(
self.name, self_tensor_np.dtype, value_np.dtype)
self.name, self.dtype, dtype)

# NOTE(wuweilong): self could be VarBase or Tensor, the subsequent behavior are defined in different files
# if self is VarBase, method value() return Variable that bindded in imperative.cc, get_tensor() bindded in pybind.cc
# if self is Tensor, method value() return self that defined in this file, get_tensor() defined in eager_method.cc
# this Interface behavior will be unifed in the future.
self.value().get_tensor().set(value_np,
self.value().get_tensor().set(value,
framework._current_expected_place())

@framework.dygraph_only
Expand Down

0 comments on commit 1042f42

Please sign in to comment.