Skip to content

Commit

Permalink
isolates friends of storage, test=develop
Browse files Browse the repository at this point in the history
  • Loading branch information
Shixiaowei02 committed Jan 15, 2022
1 parent 35d2b71 commit 31b8f30
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 44 deletions.
45 changes: 9 additions & 36 deletions paddle/pten/api/lib/utils/tensor_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -325,9 +325,7 @@ void SharesStorageBase(pten::DenseTensor* src, paddle::framework::Tensor* dst) {
platform::errors::InvalidArgument(
"The destination Tensor is nullptr when move allocation."));
dst->Resize(src->dims());
auto* storage = static_cast<SharedStorage*>(
pten::CompatibleDenseTensorUtils::UnsafeGetMutableStorage(src));
dst->ResetHolderWithType(storage->GetAllocation(),
dst->ResetHolderWithType(src->Holder(),
pten::TransToProtoVarType(src->dtype()));
dst->set_offset(src->meta().offset);
}
Expand All @@ -345,19 +343,7 @@ void ReMakePtenDenseTensorBase(const paddle::framework::Tensor& src,
meta->dtype = pten::TransToPtenDataType(src.type());
meta->layout = src.layout();
meta->offset = src.offset();

auto* shared_storage = static_cast<SharedStorage*>(
pten::CompatibleDenseTensorUtils::UnsafeGetMutableStorage(dst));
PADDLE_ENFORCE_NOT_NULL(
shared_storage,
platform::errors::NotFound(
"Target DenseTensor's shared storage is nullptr."));

PADDLE_ENFORCE_EQ(src.IsInitialized(),
true,
paddle::platform::errors::InvalidArgument(
"Source Tensor is not initialized."));
shared_storage->ResetAllocation(src.Holder());
dst->ResetHolder(src.Holder());
}

void ReMakePtenDenseTensor(const paddle::framework::Tensor& src,
Expand All @@ -378,19 +364,12 @@ void ReMakePtenDenseTensorByArgDefBase(const paddle::framework::Tensor& src,
meta->layout = src.layout();
meta->offset = src.offset();

auto* shared_storage = static_cast<SharedStorage*>(
pten::CompatibleDenseTensorUtils::UnsafeGetMutableStorage(dst));
PADDLE_ENFORCE_NOT_NULL(
shared_storage,
platform::errors::NotFound(
"Target DenseTensor's shared storage is nullptr."));

if (src.IsInitialized() &&
src.place() == pten::TransToFluidPlace(arg_def.backend)) {
shared_storage->ResetAllocation(src.Holder());
dst->ResetHolder(src.Holder());
} else {
shared_storage->ResetAllocationPlace(
pten::TransToFluidPlace(arg_def.backend));
// This does not affect the correctness, and will be modified immediately.
// dst->mutable_data(pten::TransToFluidPlace(arg_def.backend));
}
}

Expand Down Expand Up @@ -481,14 +460,10 @@ void MakeVariableFromPtenTensor(pten::DenseTensor* src,
tensor->Resize(src->dims());
SetLoD(tensor->mutable_lod(), src->lod());

// here dynamic_cast is slow
auto* storage = static_cast<SharedStorage*>(
pten::CompatibleDenseTensorUtils::UnsafeGetMutableStorage(src));

if (!tensor->IsInitialized() ||
(tensor->IsInitialized() &&
!IsSameAllocation(tensor->Holder(), storage->GetAllocation()))) {
tensor->ResetHolderWithType(std::move(storage->GetAllocation()), dtype);
!IsSameAllocation(tensor->Holder(), src->Holder()))) {
tensor->ResetHolderWithType(std::move(src->Holder()), dtype);
} else {
// Even the pten tensor and Variable have the same Alloctation (both have
// the same pointer address, same size and same place)
Expand All @@ -502,10 +477,8 @@ void MakeVariableFromPtenTensor(pten::DenseTensor* src,
auto dtype = pten::TransToProtoVarType(src->dtype());

if (!tensor->value().IsInitialized()) {
auto storage = dynamic_cast<SharedStorage*>(
pten::CompatibleDenseTensorUtils::UnsafeGetMutableStorage(src));
tensor->mutable_value()->ResetHolderWithType(
std::move(storage->GetAllocation()), dtype);
tensor->mutable_value()->ResetHolderWithType(std::move(src->Holder()),
dtype);
}
} else {
PADDLE_THROW(platform::errors::Unimplemented(
Expand Down
9 changes: 1 addition & 8 deletions paddle/pten/core/compat_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,21 +31,14 @@ namespace pten {

class CompatibleDenseTensorUtils {
public:
static Storage* UnsafeGetMutableStorage(DenseTensor* tensor) {
return tensor->storage_.get();
}

static DenseTensorMeta* GetMutableMeta(DenseTensor* tensor) {
return &(tensor->meta_);
}

// only can deal with SharedStorage now
static void ClearStorage(DenseTensor* tensor) {
// use static_cast to improve performance, replace by dynamic_cast later
if (tensor->storage_ != nullptr) {
static_cast<paddle::experimental::SharedStorage*>(tensor->storage_.get())
->Reset();
}
tensor->MoveMemoryHolder();
}

static DenseTensor Slice(const DenseTensor& tensor,
Expand Down

1 comment on commit 31b8f30

@paddle-bot-old
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Congratulation! Your pull request passed all required CI. You could ask reviewer(s) to approve and merge. 🎉

Please sign in to comment.