
[Relay/TOPI][Op] Add variance and layer norm op #3700

Merged
merged 11 commits into from
Aug 7, 2019

Conversation

icemelon
Member

@icemelon icemelon commented Aug 3, 2019

Thanks for contributing to TVM! Please refer to guideline https://docs.tvm.ai/contribute/ for useful information and tips. After the pull request is submitted, please request code reviews from Reviewers.

@junrushao1994 @kevinthesun Could you help review?

Member

@junrushao junrushao left a comment


LGTM

@icemelon
Member Author

icemelon commented Aug 6, 2019

@kevinthesun @masahi please help take a look.

The axis that should be normalized, typically the axis of the channels.

epsilon : double, optional, default=1e-5
Small float added to variance to avoid diving by zero.

dividing
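The `epsilon` parameter discussed above guards the division in layer normalization. As a point of reference for the review, here is a minimal NumPy sketch of the math this op computes — the actual implementation in this PR lives in Relay/TOPI, and the `layer_norm` function and defaults below are purely illustrative:

```python
import numpy as np

def layer_norm(data, gamma, beta, axis=-1, epsilon=1e-5):
    # Normalize over `axis` using the mean and variance of `data`,
    # then apply the learned scale (gamma) and shift (beta).
    # `epsilon` keeps the denominator away from zero.
    mean = data.mean(axis=axis, keepdims=True)
    var = data.var(axis=axis, keepdims=True)
    return (data - mean) / np.sqrt(var + epsilon) * gamma + beta

x = np.random.randn(2, 4).astype("float32")
out = layer_norm(x, np.ones(4, "float32"), np.zeros(4, "float32"))
```

After normalization, each row of `out` has approximately zero mean and unit variance.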

The input data

axis : None or int or tuple of int
Axis or axes along which a mean operation is performed.

mean -> variance
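For context on the variance docstring being reviewed here, a NumPy sketch of the reduction with the `axis` semantics described (None, int, or tuple of int); this mirrors NumPy's population variance (`ddof=0`) and is not the Relay/TOPI code from the PR:

```python
import numpy as np

def variance(data, axis=None, keepdims=False):
    # Variance along the given axis (or over all elements when axis=None):
    # the mean of squared deviations from the mean.
    mean = np.mean(data, axis=axis, keepdims=True)
    return np.mean((data - mean) ** 2, axis=axis, keepdims=keepdims)
```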


axis : None or int or tuple of int
Axis or axes along which a mean operation is performed.
The default, axis=None, will find the indices of minimum element all of the elements of

all of -> of all

The input data

axis : None or int or tuple of int
Axis or axes along which a mean operation is performed.

mean -> standard deviation
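The standard-deviation op under review is the square root of the variance reduction above. A hedged NumPy sketch of that relationship (again illustrative, not the PR's implementation):

```python
import numpy as np

def std(data, axis=None, keepdims=False):
    # Standard deviation: square root of the population variance
    # along the given axis (or over all elements when axis=None).
    mean = np.mean(data, axis=axis, keepdims=True)
    var = np.mean((data - mean) ** 2, axis=axis, keepdims=keepdims)
    return np.sqrt(var)
```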


axis : None or int or tuple of int
Axis or axes along which a mean operation is performed.
The default, axis=None, will find the indices of minimum element all of the elements of

of all


axis : None or int or tuple of int
Axis or axes along which a mean operation is performed.
The default, axis=None, will find the indices of minimum element all of the elements of

of all

@masahi masahi left a comment


There are some typo issues but otherwise looks good.

@icemelon
Member Author

icemelon commented Aug 7, 2019

Thanks @masahi . The docs are updated.

@masahi masahi merged commit 6b6e388 into apache:master Aug 7, 2019
@masahi
Member

masahi commented Aug 7, 2019

Thanks @icemelon9 @junrushao1994, this is merged.

@icemelon icemelon deleted the layer-norm branch August 7, 2019 17:52
wweic pushed a commit to wweic/tvm that referenced this pull request Aug 9, 2019
* Add LayerNorm op

* update

* fix

* Add mean_std and mean_variance

* add std and update doc

* add license

* x

* lint

* x

* fix

* fix doc
wweic pushed a commit to neo-ai/tvm that referenced this pull request Sep 6, 2019