[AMP] support GPU BF16 amp for dygraph #39029
Conversation
Thanks for your contribution!
Sorry to inform you that 2a0bd30's CIs have passed for more than 7 days. To prevent PR conflicts, you need to re-run all CIs manually.
enum class AmpDtype {
  D0 = 0,  // float32
  D1,      // float16
  D2,      // bfloat16
};
use pten::dtype directly
Done, tks.
@@ -861,7 +861,7 @@ def train(layer, loader, loss_fn, opt):
             feed={feed_target_names[0]: tensor_img},
             fetch_list=fetch_targets)

-        self.assertTrue(np.allclose(pred.numpy(), results, atol=1.e-5))
+        self.assertTrue(np.allclose(pred.numpy(), results, atol=1.e-2))
why change precision?
Restored the original tolerance of 1.e-5.
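(Side note, my own arithmetic rather than anything stated in this thread: bfloat16 keeps only 8 significand bits, so values near 1 are spaced about 2**-7 apart, which is presumably why the bf16 test path wanted a tolerance near 1.e-2 rather than 1.e-5.)

# bfloat16: 1 sign bit, 8 exponent bits, 7 explicit mantissa bits.
# Machine epsilon (spacing of values just above 1.0) is 2**-7.
bf16_eps = 2.0 ** -7
print(bf16_eps)  # 0.0078125 -- the same order as the 1.e-2 tolerance above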
prop = paddle.device.cuda.get_device_capability()
cuda_version = paddle.version.cuda()
if cuda_version is not None:
    cuda_maj_decide = int(cuda_version.split('.')[0]) >= 11
what's the meaning of 'maj_decide'?
Changed to cuda_version_check.
@@ -163,6 +185,7 @@ def check_optimizers(optimizers):
 @signature_safe_contextmanager
 @dygraph_only
 def amp_guard(enable=True,
+              dtype='float16',
Would it be better to make 'dtype' the last parameter?
Done, tks.
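Presumably the signature after this change looks something like the sketch below (dtype moved to the end per the suggestion; the middle parameter names are assumptions based on the public auto_cast API, not taken from this diff):

def amp_guard(enable=True,
              custom_white_list=None,
              custom_black_list=None,
              level='O1',
              dtype='float16'):
    ...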
paddle/fluid/imperative/tracer.cc
Outdated
      if (amp_dtype_ == AmpDtype::D1) {
        new_ins = AutoCastInputs<VarType>(type, ins);
      }
    } else if (amp_level_ == AmpLevel::O2) {
      VLOG(5) << "Pure fp16 run operator: " << type;
-     new_ins = CastPureFp16Inputs<VarType>(type, ins);
+     if (amp_dtype_ == AmpDtype::D1) {
+       new_ins = CastPureFp16Inputs<VarType>(type, ins);
+     } else if (amp_dtype_ == AmpDtype::D2) {
+       new_ins = CastPureBf16Inputs<VarType>(type, ins);
+     }
Why does bf16 only support O2?
O1 is now supported: ops in the white list will use bf16, the others will use fp32. A conceptual sketch of that rule follows.
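A minimal sketch of the O1 decision rule the reply describes (the white list contents below are hypothetical, not the PR's actual list):

WHITE_LIST = {'matmul_v2', 'conv2d'}  # hypothetical contents

def o1_cast_dtype(op_type, amp_dtype='bfloat16'):
    # Under O1, only white-listed ops run in the low-precision dtype;
    # everything else stays in float32.
    return amp_dtype if op_type in WHITE_LIST else 'float32'

assert o1_cast_dtype('matmul_v2') == 'bfloat16'
assert o1_cast_dtype('softmax') == 'float32'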
python/paddle/amp/auto_cast.py
Outdated
@@ -19,6 +19,7 @@

 def auto_cast(enable=True,
+              dtype='float16',
Same comment as above for this parameter.
Done, tks.
python/paddle/fluid/data_feeder.py
Outdated
@@ -67,6 +67,9 @@ def convert_dtype(dtype):
     # however, jointly supporting python2 and python3, (as well as python4 maybe)
     # may still be a long-lasting problem.
     return str(dtype)
+    # NOTE(zhangbo): Now numpy not support bfloat, and paddle use uint16 to carry the data of bfloat16, and there binary is consistent.
... does not support ..., and paddle uses uint16 to represent ..., and their binaries are consistent.
Done, tks.
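The uint16 trick is easy to demonstrate with plain numpy (my illustration, not code from the PR): bfloat16 is the top 16 bits of an IEEE float32, so its bit pattern fits exactly in a uint16. (Truncation is used here for simplicity; a real converter would round to nearest even.)

import numpy as np

x = np.array([3.14159], dtype=np.float32)
bits = x.view(np.uint32)
bf16 = (bits >> 16).astype(np.uint16)            # bfloat16 bit pattern (truncated)
back = (bf16.astype(np.uint32) << 16).view(np.float32)
print(back)                                      # [3.140625], bf16 precision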
if dtype == 'float16':
    amp_dtype = "float16"
elif dtype == 'bfloat16':
    amp_dtype = "bfloat16"
amp_dtype = dtype would be enough here.
Done, tks.
LGTM
LGTM
LGTM
PR types
New features
PR changes
APIs
Describe
Support bf16 mixed precision training for GPU dygraph mode.
paddle.amp.auto_cast() gains a new dtype parameter, defaulting to 'float16', with allowed values 'float16' and 'bfloat16'.
Example of mixed precision training with bfloat16:
paddle.amp.auto_cast(enable=True, custom_white_list={}, custom_black_list={}, level='O2', dtype='bfloat16')
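For completeness, a minimal dygraph training step using the new argument (my sketch, assuming a Paddle build that includes this PR and an Ampere-class GPU with CUDA 11+; O1 is shown because it needs no extra model decoration):

import paddle

model = paddle.nn.Linear(4, 2)
opt = paddle.optimizer.SGD(learning_rate=0.01,
                           parameters=model.parameters())

x = paddle.rand([8, 4])
with paddle.amp.auto_cast(enable=True, level='O1', dtype='bfloat16'):
    loss = model(x).mean()   # white-listed ops run in bf16 under O1
loss.backward()
opt.step()
opt.clear_grad()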