Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactoring InferShape #3946

Merged
merged 61 commits into from
Sep 26, 2017
Merged

Conversation

jacquesqiao
Copy link
Member

@jacquesqiao jacquesqiao commented Sep 7, 2017

fix: #4183
design: #4142

support compile time and runtime infershape.

@jacquesqiao jacquesqiao changed the title [wip] Impl vardesc [wip] refactoring InferShape Sep 14, 2017
@jacquesqiao jacquesqiao mentioned this pull request Sep 17, 2017
return GetDim(op_.Output(name));
}

void SetOutputDim(const std::string& name, const DDim& dim) const {
Copy link
Collaborator

@reyoung reyoung Sep 25, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Set methods should change private data members, so they cannot be const

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

if (!var->IsType<LoDTensor>() && !var->IsType<Tensor>()) {
t = var->GetMutable<LoDTensor>();
} else {
t = const_cast<Tensor*>(GetTensorFromVar(scope_.FindVar(name)));
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

const_cast should not be used.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

removed

@@ -413,7 +561,8 @@ class OperatorWithKernel : public OperatorBase {
}

protected:
virtual void InferShape(const InferShapeContext& ctx) const = 0;
virtual void InferShape(const InferShapeContext& ctx) const {}
virtual void InferShape(const InferShapeContextBase& ctx) const {}
Copy link
Collaborator

@reyoung reyoung Sep 25, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why there are two interfaces?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The old one is for compiling and test, has been removed after transform all the old operators.

@@ -195,6 +195,21 @@ std::vector<int64_t> vectorize(const DDim& ddim) {
return result;
}

std::string debug_str(const DDim& ddim) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is debug_str used?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I use it a lot when debug the code

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

removed since we can use cout << ddim to print the debug string.

}

private:
Tensor* GetTensor(const std::string& name, bool allocate) const {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

allocate could be a template parameter.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

if (allocate) {
t = var->GetMutable<LoDTensor>();
} else {
PADDLE_ENFORCE(false, "Variable(%s) should be tensor", name);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

PADDLE_THROW(...);

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

@@ -340,6 +335,77 @@ class ExecutionContext : public InferShapeContext {
const platform::DeviceContext& device_context_;
};

class RunTimeInferShapeContext : public InferShapeContextBase {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

RunTime-> Runtime

runtime is a word

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

@@ -406,7 +474,7 @@ class OperatorWithKernel : public OperatorBase {
}

protected:
virtual void InferShape(const InferShapeContext& ctx) const = 0;
virtual void InferShape(InferShapeContextBase* ctx) const = 0;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This file should be changed a lot. void OperatorBase::InferShape(const Scope& scope) should be removed. void InferShape(InferShapeContextBase* ctx) const should be public.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, This will be done in next pr that will modify all the python related code.

Copy link
Collaborator

@reyoung reyoung left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Excellent

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Implement Compile time & Runtime Infershape
5 participants