Skip to content

Commit

Permalink
give shape related contructor and reshape warning
Browse files Browse the repository at this point in the history
  • Loading branch information
JiabinYang committed Mar 18, 2021
1 parent fe241fd commit 80db24b
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 4 deletions.
3 changes: 3 additions & 0 deletions paddle/fluid/extension/include/ext_tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,9 @@ class PD_DLL_DECL Tensor {
/// \brief Construct a Tensor on target Place for CustomOp.
/// Generally it's only used for user to create Tensor.
explicit Tensor(const PlaceType& place);
/// \brief Construct a Tensor on target Place with shape for CustomOp.
/// Generally it's only used for user to create Tensor.
Tensor(const PlaceType& place, const std::vector<int64_t>& shape);
/// \brief Reset the shape of the tensor.
/// Generally it's only used for the input tensor.
/// Reshape must be called before calling
Expand Down
20 changes: 19 additions & 1 deletion paddle/fluid/extension/src/ext_tensor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -97,13 +97,31 @@ void GpuCopy(T *src, T *dst, PlaceType src_plc, PlaceType dst_plc,

void Tensor::reshape(const std::vector<int64_t> &shape) {
GET_CASTED_TENSOR
tensor->Resize(framework::make_ddim(shape));
auto new_dim = framework::make_ddim(shape);
if (tensor->numel() != framework::product(new_dim)) {
LOG(WARNING) << "Custom Op: Calling reshape to a new shape which is bigger "
"or smaller"
<< "than original shape will not change your tensor's memory "
"Please call"
<< "paddle::Tensor::mutable_data<T>() after to reallocate "
"your tensor's size."
<< std::endl;
}
tensor->Resize(new_dim);
}

Tensor::Tensor(const PlaceType &place)
: tensor_(std::make_shared<framework::LoDTensor>()),
place_(place),
stream_(StreamWrapper()) {}

Tensor::Tensor(const PlaceType &place, const std::vector<int64_t> &shape)
: tensor_(std::make_shared<framework::LoDTensor>()),
place_(place),
stream_(StreamWrapper()) {
reshape(shape);
}

template <typename T>
T *Tensor::mutable_data(const PlaceType &place) {
place_ = place;
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/framework/custom_tensor_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class CustomTensorUtils {
/// \brief Share data FROM another tensor.
/// Use this to pass tensor from op to op
/// \return void.
static void ShareDataFrom(const void* src, const Tensor& dst);
static void ShareDataFrom(const void* src, const paddle::Tensor& dst);

static framework::proto::VarType::Type ConvertEnumDTypeToInnerDType(
const paddle::DataType& dtype) {
Expand Down
4 changes: 2 additions & 2 deletions python/paddle/fluid/tests/custom_op/custom_relu_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,9 @@ void relu_cpu_backward_kernel(const data_t* grad_out_data,
}

std::vector<paddle::Tensor> relu_cpu_forward(const paddle::Tensor& x) {
auto out = paddle::Tensor(paddle::PlaceType::kCPU);
auto out = paddle::Tensor(paddle::PlaceType::kCPU, x.shape());

out.reshape(x.shape());
// out.reshape(x.shape());
PD_DISPATCH_FLOATING_TYPES(
x.type(), "relu_cpu_forward", ([&] {
relu_cpu_forward_kernel<data_t>(
Expand Down

0 comments on commit 80db24b

Please sign in to comment.