Skip to content

Commit

Permalink
add get inout var ptr for dygraph
Browse files Browse the repository at this point in the history
  • Loading branch information
chenwhql committed Jan 22, 2022
1 parent 09f6f17 commit 5636edb
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 10 deletions.
23 changes: 18 additions & 5 deletions paddle/fluid/eager/legacy/infer_shape_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -222,17 +222,30 @@ class EagerInferShapeContext : public paddle::framework::InferShapeContext {
paddle::framework::DataLayout::kMKLDNN));
}

// TODO(paddle-dev): Can this be template?
std::vector<paddle::framework::InferShapeVarPtr> GetInputVarPtrs(
const std::string& name) const override {
PADDLE_THROW(paddle::platform::errors::PermissionDenied(
"GetInputVarPtrs not support in dygraph runtime context"));
std::vector<paddle::framework::InferShapeVarPtr> res;
auto it = tensor_in_->find(name);
PADDLE_ENFORCE_NE(it, tensor_in_->end(),
paddle::platform::errors::NotFound(
"Can not find [%s] in inputs.", name));
for (auto& tensor : it->second) {
res.emplace_back(tensor->MutableVar());
}
return res;
}

std::vector<paddle::framework::InferShapeVarPtr> GetOutputVarPtrs(
const std::string& name) const override {
PADDLE_THROW(paddle::platform::errors::PermissionDenied(
"GetOutputVarPtrs not support in dygraph runtime context"));
std::vector<paddle::framework::InferShapeVarPtr> res;
auto it = tensor_out_->find(name);
PADDLE_ENFORCE_NE(it, tensor_out_->end(),
paddle::platform::errors::NotFound(
"Can not find [%s] in outputs.", name));
for (auto& tensor : it->second) {
res.emplace_back(tensor->MutableVar());
}
return res;
}

DDim GetInputDim(const std::string& name) const override {
Expand Down
23 changes: 18 additions & 5 deletions paddle/fluid/imperative/infer_shape_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -220,17 +220,30 @@ class DygraphInferShapeContext : public framework::InferShapeContext {
(op_kernel_type_->data_layout_ == framework::DataLayout::kMKLDNN));
}

// TODO(paddle-dev): Can this be template?
std::vector<framework::InferShapeVarPtr> GetInputVarPtrs(
const std::string& name) const override {
PADDLE_THROW(platform::errors::PermissionDenied(
"GetInputVarPtrs not support in dygraph runtime context"));
std::vector<framework::InferShapeVarPtr> res;
auto it = var_base_map_in_->find(name);
PADDLE_ENFORCE_NE(
it, var_base_map_in_->end(),
platform::errors::NotFound("Can not find [%s] in inputs.", name));
for (auto& var : it->second) {
res.emplace_back(var->MutableVar());
}
return res;
}

std::vector<framework::InferShapeVarPtr> GetOutputVarPtrs(
const std::string& name) const override {
PADDLE_THROW(platform::errors::PermissionDenied(
"GetOutputVarPtrs not support in dygraph runtime context"));
std::vector<framework::InferShapeVarPtr> res;
auto it = var_base_map_out_->find(name);
PADDLE_ENFORCE_NE(
it, var_base_map_out_->end(),
platform::errors::NotFound("Can not find [%s] in outputs.", name));
for (auto& var : it->second) {
res.emplace_back(var->MutableVar());
}
return res;
}

DDim GetInputDim(const std::string& name) const override {
Expand Down

1 comment on commit 5636edb

@paddle-bot-old
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Congratulation! Your pull request passed all required CI. You could ask reviewer(s) to approve and merge. 🎉

Please sign in to comment.