
[AMP] support GPU BF16 amp for dygraph #39029

Merged
merged 22 commits into PaddlePaddle:develop from dev/bf16_a100 on Feb 18, 2022

Conversation

@zhangbo9674 (Contributor) commented Jan 18, 2022

PR types

New features

PR changes

APIs

Describe

GPU dygraph now supports bf16 mixed-precision training.
A new parameter dtype has been added to the paddle.amp.auto_cast() API; it defaults to float16 and can be set to either float16 or bfloat16.
Example of mixed-precision training with bfloat16:
paddle.amp.auto_cast(enable=True, custom_white_list={}, custom_black_list={}, level='O2', dtype='bfloat16')
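
For reference, a minimal dygraph training loop using this new dtype argument could look like the sketch below. The model, data, and optimizer are illustrative only and not part of this PR; for O2 runs the parameters can additionally be decorated to bf16, which is omitted here.

import paddle

model = paddle.nn.Linear(4, 4)
optimizer = paddle.optimizer.SGD(learning_rate=0.01, parameters=model.parameters())
data = paddle.rand([8, 4])
label = paddle.rand([8, 4])

for step in range(10):
    # Ops inside the context follow the bf16 AMP casting rules of the chosen level.
    with paddle.amp.auto_cast(enable=True, custom_white_list={}, custom_black_list={},
                              level='O2', dtype='bfloat16'):
        out = model(data)
        loss = paddle.nn.functional.mse_loss(out, label)
    loss.backward()
    optimizer.step()
    optimizer.clear_grad()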

@paddle-bot-old

Thanks for your contribution!
Please wait for the CI results first. See the Paddle CI Manual for details.

@paddle-bot-old

Sorry to inform you that the CIs for 2a0bd30 passed more than 7 days ago. To prevent PR conflicts, you need to re-run all CIs manually.

Comment on lines 36 to 40
enum class AmpDtype {
  D0 = 0,  // float32
  D1,      // float16
  D2,      // bfloat16
};
Contributor:

use pten::dtype directly

Contributor Author:

Done, tks.

@@ -861,7 +861,7 @@ def train(layer, loader, loss_fn, opt):
             feed={feed_target_names[0]: tensor_img},
             fetch_list=fetch_targets)

-        self.assertTrue(np.allclose(pred.numpy(), results, atol=1.e-5))
+        self.assertTrue(np.allclose(pred.numpy(), results, atol=1.e-2))
Contributor:

why change precision?

Contributor Author:

Restored the original tolerance of 1.e-5.

prop = paddle.device.cuda.get_device_capability()
cuda_version = paddle.version.cuda()
if cuda_version is not None:
    cuda_maj_decide = int(cuda_version.split('.')[0]) >= 11
Contributor:

what's the meaning of 'maj_decide'?

Contributor Author:

Changed to cuda_version_check.
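
For illustration, the renamed check could be folded into a small helper like the sketch below; the helper name and the Ampere/CUDA 11 thresholds are assumptions here, not code from this PR.

import paddle

def gpu_bf16_supported():
    # bf16 kernels generally need an Ampere-or-newer GPU (compute capability 8.x)
    # together with a CUDA 11+ build.
    if not paddle.is_compiled_with_cuda():
        return False
    major, _minor = paddle.device.cuda.get_device_capability()
    cuda_version = paddle.version.cuda()
    cuda_version_check = cuda_version is not None and int(cuda_version.split('.')[0]) >= 11
    return major >= 8 and cuda_version_check

print(gpu_bf16_supported())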

@@ -163,6 +185,7 @@ def check_optimizers(optimizers):
 @signature_safe_contextmanager
 @dygraph_only
 def amp_guard(enable=True,
+              dtype='float16',
Contributor:

Is it better to make 'dtype' the last parameter?

Contributor Author:

Done, tks.

Comment on lines 193 to 202
      if (amp_dtype_ == AmpDtype::D1) {
        new_ins = AutoCastInputs<VarType>(type, ins);
      }
    } else if (amp_level_ == AmpLevel::O2) {
      VLOG(5) << "Pure fp16 run operator: " << type;
      new_ins = CastPureFp16Inputs<VarType>(type, ins);
      if (amp_dtype_ == AmpDtype::D1) {
        new_ins = CastPureFp16Inputs<VarType>(type, ins);
      } else if (amp_dtype_ == AmpDtype::D2) {
        new_ins = CastPureBf16Inputs<VarType>(type, ins);
      }
Contributor:

Why does bf16 only support O2?

Contributor Author:

O1 is also supported: ops in the white list use bf16, the others use fp32.
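
To illustrate the O1 behaviour described here, a hedged sketch (assuming matmul is in the bf16 white list and elementwise add is not; the actual lists are defined in the framework):

import paddle

x = paddle.rand([4, 4])
y = paddle.rand([4, 4])

with paddle.amp.auto_cast(level='O1', dtype='bfloat16'):
    z = paddle.matmul(x, y)  # white-list op: expected to run in bf16
    w = x + y                # not in the white list: expected to stay fp32
    print(z.dtype, w.dtype)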

@@ -19,6 +19,7 @@


 def auto_cast(enable=True,
+              dtype='float16',
Contributor:

Same comment for this parameter.

Contributor Author:

Done, tks.

@@ -67,6 +67,9 @@ def convert_dtype(dtype):
     # however, jointly supporting python2 and python3, (as well as python4 maybe)
     # may still be a long-lasting problem.
     return str(dtype)
+    # NOTE(zhangbo): Now numpy not support bfloat, and paddle use uint16 to carry the data of bfloat16, and there binary is consistent.
Contributor:

... does not support ..., and paddle uses uint16 to represent ..., and their binary representations are consistent.

Contributor Author:

Done, tks.
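
A quick way to see the uint16 carrier described in this note is the sketch below; it assumes the behaviour the note states, and the bit pattern 0x3f80 is the bf16 encoding of 1.0.

import paddle

x = paddle.cast(paddle.ones([2, 2]), 'bfloat16')
arr = x.numpy()
print(arr.dtype)            # expected: uint16, since numpy has no bfloat16
print(hex(int(arr[0, 0])))  # expected: 0x3f80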

Comment on lines 290 to 293
if dtype == 'float16':
    amp_dtype = "float16"
elif dtype == 'bfloat16':
    amp_dtype = "bfloat16"
Contributor:

amp_dtype = dtype is ok

Contributor Author:

Done, tks.

@zhiqiu (Contributor) left a comment

LGTM

@ForFishes (Member) left a comment

LGTM

@TCChenlong (Contributor) left a comment

LGTM

@zhangbo9674 zhangbo9674 merged commit 7d6d384 into PaddlePaddle:develop Feb 18, 2022
@zhangbo9674 zhangbo9674 deleted the dev/bf16_a100 branch March 2, 2023 02:59