Skip to content

Commit

Permalink
add get inout var ptr for dygraph (#39134)
Browse files Browse the repository at this point in the history
  • Loading branch information
chenwhql committed Jan 22, 2022
1 parent 7ac2f80 commit ec24bc9
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 10 deletions.
23 changes: 18 additions & 5 deletions paddle/fluid/eager/legacy/infer_shape_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -222,17 +222,30 @@ class EagerInferShapeContext : public paddle::framework::InferShapeContext {
paddle::framework::DataLayout::kMKLDNN));
}

// TODO(paddle-dev): Can this be template?
std::vector<paddle::framework::InferShapeVarPtr> GetInputVarPtrs(
const std::string& name) const override {
PADDLE_THROW(paddle::platform::errors::PermissionDenied(
"GetInputVarPtrs not support in dygraph runtime context"));
std::vector<paddle::framework::InferShapeVarPtr> res;
auto it = tensor_in_->find(name);
PADDLE_ENFORCE_NE(it, tensor_in_->end(),
paddle::platform::errors::NotFound(
"Can not find [%s] in inputs.", name));
for (auto& tensor : it->second) {
res.emplace_back(tensor->MutableVar());
}
return res;
}

std::vector<paddle::framework::InferShapeVarPtr> GetOutputVarPtrs(
const std::string& name) const override {
PADDLE_THROW(paddle::platform::errors::PermissionDenied(
"GetOutputVarPtrs not support in dygraph runtime context"));
std::vector<paddle::framework::InferShapeVarPtr> res;
auto it = tensor_out_->find(name);
PADDLE_ENFORCE_NE(it, tensor_out_->end(),
paddle::platform::errors::NotFound(
"Can not find [%s] in outputs.", name));
for (auto& tensor : it->second) {
res.emplace_back(tensor->MutableVar());
}
return res;
}

DDim GetInputDim(const std::string& name) const override {
Expand Down
23 changes: 18 additions & 5 deletions paddle/fluid/imperative/infer_shape_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -220,17 +220,30 @@ class DygraphInferShapeContext : public framework::InferShapeContext {
(op_kernel_type_->data_layout_ == framework::DataLayout::kMKLDNN));
}

// TODO(paddle-dev): Can this be template?
std::vector<framework::InferShapeVarPtr> GetInputVarPtrs(
const std::string& name) const override {
PADDLE_THROW(platform::errors::PermissionDenied(
"GetInputVarPtrs not support in dygraph runtime context"));
std::vector<framework::InferShapeVarPtr> res;
auto it = var_base_map_in_->find(name);
PADDLE_ENFORCE_NE(
it, var_base_map_in_->end(),
platform::errors::NotFound("Can not find [%s] in inputs.", name));
for (auto& var : it->second) {
res.emplace_back(var->MutableVar());
}
return res;
}

std::vector<framework::InferShapeVarPtr> GetOutputVarPtrs(
const std::string& name) const override {
PADDLE_THROW(platform::errors::PermissionDenied(
"GetOutputVarPtrs not support in dygraph runtime context"));
std::vector<framework::InferShapeVarPtr> res;
auto it = var_base_map_out_->find(name);
PADDLE_ENFORCE_NE(
it, var_base_map_out_->end(),
platform::errors::NotFound("Can not find [%s] in outputs.", name));
for (auto& var : it->second) {
res.emplace_back(var->MutableVar());
}
return res;
}

DDim GetInputDim(const std::string& name) const override {
Expand Down

0 comments on commit ec24bc9

Please sign in to comment.