-
Notifications
You must be signed in to change notification settings - Fork 5.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Refactoring InferShape #3946
Refactoring InferShape #3946
Conversation
…into impl-vardesc
paddle/framework/operator.h
Outdated
return GetDim(op_.Output(name)); | ||
} | ||
|
||
void SetOutputDim(const std::string& name, const DDim& dim) const { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Set methods should change private data members, so they cannot be const
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
paddle/framework/operator.h
Outdated
if (!var->IsType<LoDTensor>() && !var->IsType<Tensor>()) { | ||
t = var->GetMutable<LoDTensor>(); | ||
} else { | ||
t = const_cast<Tensor*>(GetTensorFromVar(scope_.FindVar(name))); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
const_cast should not be used.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
removed
paddle/framework/operator.h
Outdated
@@ -413,7 +561,8 @@ class OperatorWithKernel : public OperatorBase { | |||
} | |||
|
|||
protected: | |||
virtual void InferShape(const InferShapeContext& ctx) const = 0; | |||
virtual void InferShape(const InferShapeContext& ctx) const {} | |||
virtual void InferShape(const InferShapeContextBase& ctx) const {} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why there are two interfaces?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The old one is for compiling and test, has been removed after transform all the old operators.
paddle/framework/ddim.cc
Outdated
@@ -195,6 +195,21 @@ std::vector<int64_t> vectorize(const DDim& ddim) { | |||
return result; | |||
} | |||
|
|||
std::string debug_str(const DDim& ddim) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is debug_str
used?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I use it a lot when debug the code
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
removed since we can use cout << ddim to print the debug string.
paddle/framework/operator.h
Outdated
} | ||
|
||
private: | ||
Tensor* GetTensor(const std::string& name, bool allocate) const { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
allocate could be a template parameter.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
paddle/framework/operator.h
Outdated
if (allocate) { | ||
t = var->GetMutable<LoDTensor>(); | ||
} else { | ||
PADDLE_ENFORCE(false, "Variable(%s) should be tensor", name); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
PADDLE_THROW(...)
;
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
paddle/framework/operator.h
Outdated
@@ -340,6 +335,77 @@ class ExecutionContext : public InferShapeContext { | |||
const platform::DeviceContext& device_context_; | |||
}; | |||
|
|||
class RunTimeInferShapeContext : public InferShapeContextBase { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
RunTime-> Runtime
runtime is a word
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ok
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
@@ -406,7 +474,7 @@ class OperatorWithKernel : public OperatorBase { | |||
} | |||
|
|||
protected: | |||
virtual void InferShape(const InferShapeContext& ctx) const = 0; | |||
virtual void InferShape(InferShapeContextBase* ctx) const = 0; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This file should be changed a lot. void OperatorBase::InferShape(const Scope& scope)
should be removed. void InferShape(InferShapeContextBase* ctx) const
should be public.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, This will be done in next pr that will modify all the python related code.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Excellent
fix: #4183
design: #4142
support compile time and runtime infershape.