Skip to content

Commit

Permalink
isolates friends of storage, test=develop (#38977)
Browse files Browse the repository at this point in the history
  • Loading branch information
Shixiaowei02 committed Jan 15, 2022
1 parent 35d2b71 commit d13c779
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 44 deletions.
45 changes: 9 additions & 36 deletions paddle/pten/api/lib/utils/tensor_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -325,9 +325,7 @@ void SharesStorageBase(pten::DenseTensor* src, paddle::framework::Tensor* dst) {
platform::errors::InvalidArgument(
"The destination Tensor is nullptr when move allocation."));
dst->Resize(src->dims());
auto* storage = static_cast<SharedStorage*>(
pten::CompatibleDenseTensorUtils::UnsafeGetMutableStorage(src));
dst->ResetHolderWithType(storage->GetAllocation(),
dst->ResetHolderWithType(src->Holder(),
pten::TransToProtoVarType(src->dtype()));
dst->set_offset(src->meta().offset);
}
Expand All @@ -345,19 +343,7 @@ void ReMakePtenDenseTensorBase(const paddle::framework::Tensor& src,
meta->dtype = pten::TransToPtenDataType(src.type());
meta->layout = src.layout();
meta->offset = src.offset();

auto* shared_storage = static_cast<SharedStorage*>(
pten::CompatibleDenseTensorUtils::UnsafeGetMutableStorage(dst));
PADDLE_ENFORCE_NOT_NULL(
shared_storage,
platform::errors::NotFound(
"Target DenseTensor's shared storage is nullptr."));

PADDLE_ENFORCE_EQ(src.IsInitialized(),
true,
paddle::platform::errors::InvalidArgument(
"Source Tensor is not initialized."));
shared_storage->ResetAllocation(src.Holder());
dst->ResetHolder(src.Holder());
}

void ReMakePtenDenseTensor(const paddle::framework::Tensor& src,
Expand All @@ -378,19 +364,12 @@ void ReMakePtenDenseTensorByArgDefBase(const paddle::framework::Tensor& src,
meta->layout = src.layout();
meta->offset = src.offset();

auto* shared_storage = static_cast<SharedStorage*>(
pten::CompatibleDenseTensorUtils::UnsafeGetMutableStorage(dst));
PADDLE_ENFORCE_NOT_NULL(
shared_storage,
platform::errors::NotFound(
"Target DenseTensor's shared storage is nullptr."));

if (src.IsInitialized() &&
src.place() == pten::TransToFluidPlace(arg_def.backend)) {
shared_storage->ResetAllocation(src.Holder());
dst->ResetHolder(src.Holder());
} else {
shared_storage->ResetAllocationPlace(
pten::TransToFluidPlace(arg_def.backend));
// This does not affect the correctness, and will be modified immediately.
// dst->mutable_data(pten::TransToFluidPlace(arg_def.backend));
}
}

Expand Down Expand Up @@ -481,14 +460,10 @@ void MakeVariableFromPtenTensor(pten::DenseTensor* src,
tensor->Resize(src->dims());
SetLoD(tensor->mutable_lod(), src->lod());

// here dynamic_cast is slow
auto* storage = static_cast<SharedStorage*>(
pten::CompatibleDenseTensorUtils::UnsafeGetMutableStorage(src));

if (!tensor->IsInitialized() ||
(tensor->IsInitialized() &&
!IsSameAllocation(tensor->Holder(), storage->GetAllocation()))) {
tensor->ResetHolderWithType(std::move(storage->GetAllocation()), dtype);
!IsSameAllocation(tensor->Holder(), src->Holder()))) {
tensor->ResetHolderWithType(std::move(src->Holder()), dtype);
} else {
// Even the pten tensor and Variable have the same Alloctation (both have
// the same pointer address, same size and same place)
Expand All @@ -502,10 +477,8 @@ void MakeVariableFromPtenTensor(pten::DenseTensor* src,
auto dtype = pten::TransToProtoVarType(src->dtype());

if (!tensor->value().IsInitialized()) {
auto storage = dynamic_cast<SharedStorage*>(
pten::CompatibleDenseTensorUtils::UnsafeGetMutableStorage(src));
tensor->mutable_value()->ResetHolderWithType(
std::move(storage->GetAllocation()), dtype);
tensor->mutable_value()->ResetHolderWithType(std::move(src->Holder()),
dtype);
}
} else {
PADDLE_THROW(platform::errors::Unimplemented(
Expand Down
9 changes: 1 addition & 8 deletions paddle/pten/core/compat_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,21 +31,14 @@ namespace pten {

class CompatibleDenseTensorUtils {
public:
static Storage* UnsafeGetMutableStorage(DenseTensor* tensor) {
return tensor->storage_.get();
}

static DenseTensorMeta* GetMutableMeta(DenseTensor* tensor) {
return &(tensor->meta_);
}

// only can deal with SharedStorage now
static void ClearStorage(DenseTensor* tensor) {
// use static_cast to improve performance, replace by dynamic_cast later
if (tensor->storage_ != nullptr) {
static_cast<paddle::experimental::SharedStorage*>(tensor->storage_.get())
->Reset();
}
tensor->MoveMemoryHolder();
}

static DenseTensor Slice(const DenseTensor& tensor,
Expand Down

0 comments on commit d13c779

Please sign in to comment.