
fix gpt2 train loss Nan problem by add a line __syncthreads in BlockR… #33659

Merged 1 commit on Jun 21, 2021
Conversation

@zhiboniu (Contributor) commented Jun 18, 2021

PR types

Bug fixes

PR changes

OPs

Describe

Background:
During GPT-2 training the loss became unstable, failed to converge, and eventually turned into NaN.

Investigation findings:
1) Training was normal on P40 but abnormal on V100.
2) Adding a single log print made training normal; without the log print, training was abnormal.
3) Training was normal with the original sequential addition, but abnormal with BlockReduceSum.

Adding one line of __syncthreads() finally fixed the training issue.
The same synchronization fix was also applied to the other two BlockReduceSum implementations.

On the shared-memory size 32 used in shared[32]:
int wid = threadIdx.x / warpSize;
The maximum blockDim on NVIDIA GPUs is 1024 and warpSize is 32, so this size never exceeds maxBlockDim / warpSize = 32.
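For reference, below is a minimal sketch of the standard two-stage block reduction pattern that this kind of fix applies to. It is written for this discussion, not copied from Paddle's source: the names BlockReduceSum / WarpReduceSum and the exact barrier placement are illustrative assumptions.

```cuda
#include <cuda_runtime.h>

template <typename T>
__device__ T WarpReduceSum(T val) {
  // Reduce within one warp using register shuffles; assumes blockDim.x is a
  // multiple of warpSize so the full-warp mask is valid.
  for (int offset = warpSize / 2; offset > 0; offset >>= 1) {
    val += __shfl_down_sync(0xffffffff, val, offset);
  }
  return val;
}

template <typename T>
__device__ T BlockReduceSum(T val) {
  // One slot per warp: blockDim.x <= 1024 and warpSize == 32,
  // so at most 1024 / 32 = 32 partial sums.
  __shared__ T shared[32];
  int lane = threadIdx.x % warpSize;
  int wid = threadIdx.x / warpSize;

  // Assumed placement: if BlockReduceSum is called more than once in the same
  // kernel with this shared buffer, a barrier here keeps warp leaders from
  // overwriting shared[] while other warps may still be reading the result of
  // the previous reduction.
  __syncthreads();

  val = WarpReduceSum(val);          // stage 1: per-warp partial sums
  if (lane == 0) shared[wid] = val;  // warp leaders publish their sums
  __syncthreads();                   // make all partial sums visible

  int num_warps = (blockDim.x + warpSize - 1) / warpSize;
  val = (lane < num_warps) ? shared[lane] : static_cast<T>(0);
  if (wid == 0) val = WarpReduceSum(val);  // stage 2: reduce the partial sums
  return val;
}
```

One plausible reading of the symptoms above: the race is timing-dependent, so V100 (with independent thread scheduling) exposes it while P40 does not, and an extra log print merely delays the offending warp enough to hide it.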

cherry-pick from:
#33658

@paddle-bot-old

Thanks for your contribution!
Please wait for the result of CI first. See Paddle CI Manual for details.

@ForFishes (Member) left a comment:


LGTM

@XiaoguangHu01 XiaoguangHu01 merged commit cdeffff into PaddlePaddle:release/2.1 Jun 21, 2021