Support BF16 training for sharding and dp #46846

Merged Oct 17, 2022 (18 commits)

Commits
1e70140  Fix bug of reduce_sum op. When input.numel() > INT32_MAX, its result… (GhostScreaming, Sep 14, 2022)
e1f08a2  Merge branch 'reduce_sum' of https://github.com/GhostScreaming/Paddle… (GhostScreaming, Sep 14, 2022)
ff1bfbc  Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (GhostScreaming, Sep 17, 2022)
f4fe24f  support pure bfloat16 (sneaxiy, Sep 21, 2022)
b420a32  support bf16 linear (sneaxiy, Sep 21, 2022)
5b7bc39  Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (GhostScreaming, Sep 26, 2022)
7ff1388  update PR to pass CI (sneaxiy, Sep 27, 2022)
b9a7c14  tiny fix where_grad_kernel.cu (sneaxiy, Sep 27, 2022)
46662c4  Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (sneaxiy, Sep 27, 2022)
9e18791  Merge branch 'fix_bfloat16' of https://github.com/sneaxiy/Paddle into… (GhostScreaming, Sep 27, 2022)
29a9e77  Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (GhostScreaming, Sep 27, 2022)
817d7ee  Support bfloat16 type for reducer and sharding. (GhostScreaming, Sep 28, 2022)
6e15126  Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (GhostScreaming, Sep 28, 2022)
44abf06  Fix some bugs. (GhostScreaming, Sep 28, 2022)
7fe04f2  Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (GhostScreaming, Oct 10, 2022)
384d497  Polish code. (GhostScreaming, Oct 14, 2022)
d012390  Polish code. (GhostScreaming, Oct 16, 2022)
480c732  Add bfloat16 datatype in fill_grad kernels. (GhostScreaming, Oct 17, 2022)
Changes from all commits
8 changes: 8 additions & 0 deletions paddle/fluid/distributed/collective/reducer.cc
@@ -254,6 +254,10 @@ static void ConcatTensorsWithType(
       ConcatTensorsForAllReduce<DeviceContext, double>()(
           context, dense_tensors_, p_dense_contents);
       break;
+    case phi::DataType::BFLOAT16:
+      ConcatTensorsForAllReduce<DeviceContext, platform::bfloat16>()(
+          context, dense_tensors_, p_dense_contents);
+      break;
     default:
       PADDLE_THROW(platform::errors::Unimplemented(
           "Data type (%s) is not supported when it concats tensors for "
@@ -281,6 +285,10 @@ static void SplitTensorsWithType(const DeviceContext &context,
       SplitTensorsForAllReduce<DeviceContext, double>()(
           context, p_dense_contents, p_dense_tensors);
       break;
+    case phi::DataType::BFLOAT16:
+      SplitTensorsForAllReduce<DeviceContext, platform::bfloat16>()(
+          context, p_dense_contents, p_dense_tensors);
+      break;
     default:
       PADDLE_THROW(platform::errors::Unimplemented(
           "Data type (%s) is not supported when it splits tensors for "
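
The two new switch cases let the data-parallel reducer fuse bfloat16 gradients into one contiguous buffer before all-reduce and scatter the result back afterwards. Below is a minimal sketch of a pure-bf16 data-parallel step that would exercise these paths; it assumes a GPU build with bf16 support, a launched distributed environment, and that paddle.amp.decorate accepts dtype='bfloat16' (an assumption here, inferred from this PR's pure-bf16 goal):

import paddle
import paddle.distributed as dist

# Hypothetical smoke test; run under `python -m paddle.distributed.launch`.
dist.init_parallel_env()

model = paddle.nn.Linear(1024, 1024)
opt = paddle.optimizer.AdamW(parameters=model.parameters())

# Assumption: decorate() takes dtype='bfloat16' for pure-bf16 training.
model, opt = paddle.amp.decorate(models=model, optimizers=opt,
                                 level='O2', dtype='bfloat16')
dp_model = paddle.DataParallel(model)

x = paddle.randn([8, 1024]).astype('bfloat16')
loss = dp_model(x).sum()
loss.backward()  # bf16 grads hit the new BFLOAT16 concat/split branches
opt.step()
opt.clear_grad()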
1 change: 1 addition & 0 deletions paddle/phi/kernels/cpu/fill_grad_kernel.cc
@@ -26,4 +26,5 @@ PD_REGISTER_KERNEL(fill_grad,
                    int64_t,
                    int,
                    paddle::platform::float16,
+                   paddle::platform::bfloat16,
                    bool) {}
1 change: 1 addition & 0 deletions paddle/phi/kernels/cpu/fill_kernel.cc
@@ -26,4 +26,5 @@ PD_REGISTER_KERNEL(fill,
                    int64_t,
                    int,
                    paddle::platform::float16,
+                   paddle::platform::bfloat16,
                    bool) {}
1 change: 1 addition & 0 deletions paddle/phi/kernels/gpu/fill_grad_kernel.cu
@@ -27,4 +27,5 @@ PD_REGISTER_KERNEL(fill_grad,
                    int64_t,
                    int,
                    paddle::platform::float16,
+                   paddle::platform::bfloat16,
                    bool) {}
1 change: 1 addition & 0 deletions paddle/phi/kernels/gpu/fill_kernel.cu
@@ -27,4 +27,5 @@ PD_REGISTER_KERNEL(fill,
                    int64_t,
                    int,
                    paddle::platform::float16,
+                   paddle::platform::bfloat16,
                    bool) {}
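
These four one-line registrations add bfloat16 to the CPU and GPU kernels behind the in-place Tensor.fill_ op and its gradient (the subject of this PR's last commit). A quick smoke test one might run against a build containing them; a sketch only, noting that fill_ overwrites its input, so the expected upstream gradient is zero:

import paddle

x = paddle.ones([4], dtype='bfloat16')
x.stop_gradient = False
y = x * 2
y.fill_(0.5)        # forward: dispatches the bf16 `fill` kernel
y.sum().backward()  # backward: dispatches the bf16 `fill_grad` kernel
print(x.grad)       # zeros, since fill_ discards y's dependence on x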
@@ -43,6 +43,7 @@
 alignment = {"gpu": 256, "cpu": 4096}
 align = {
     Type.fp16.value: 2,
+    Type.bf16.value: 2,
     Type.fp32.value: 4,
 }

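In the align table each dtype maps to its element size in bytes (bf16, like fp16, is 2), while alignment gives the allocation boundary per device. A small illustrative helper (not the PR's code) showing how such tables pad a tensor's element count so each slice of a fused buffer starts on a 256-byte GPU boundary:

alignment = {"gpu": 256, "cpu": 4096}
align = {"fp16": 2, "bf16": 2, "fp32": 4}  # bytes per element

def aligned_numel(numel, dtype="bf16", device="gpu"):
    # Pad so numel * element_size is a multiple of the device alignment.
    elem = align[dtype]
    remainder = (numel * elem) % alignment[device]
    if remainder == 0:
        return numel
    return numel + (alignment[device] - remainder) // elem

print(aligned_numel(1000))  # 1024 -> 2048 bytes, a multiple of 256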
@@ -531,6 +531,12 @@ def _rank_buffer_size(self, buffer_max_size, model_size):
                 "====== FP16 GradStorage size: {:.2f}M parameters, Model size {:.2f}M parameters ======"
                 .format(rank_buffer_size[Type.fp16.value] / 2**19,
                         model_size / 2**19))
+        if Type.bf16.value in rank_buffer_size.keys():
+            # BF16 GradStorage and model size
+            logger_.info(
+                "====== BF16 GradStorage size: {:.2f}M parameters, Model size {:.2f}M parameters ======"
+                .format(rank_buffer_size[Type.bf16.value] / 2**19,
+                        model_size / 2**19))
         if Type.fp32.value in rank_buffer_size.keys():
             # FP32 GradStorage and model size
             logger_.info(
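
The BF16 branch mirrors the FP16 one, including the /2**19 scaling, since both dtypes occupy 2 bytes per element. For intuition about the saving this enables, a back-of-the-envelope sketch (not Paddle code) of gradient-storage size by dtype:

BYTES_PER_ELEM = {"fp32": 4, "fp16": 2, "bf16": 2}

def grad_storage_mib(n_params, dtype):
    # Size in MiB of a gradient buffer holding n_params elements.
    return n_params * BYTES_PER_ELEM[dtype] / 2**20

n = 350_000_000  # e.g. a 350M-parameter model
print(grad_storage_mib(n, "fp32"))  # ~1335.1 MiB
print(grad_storage_mib(n, "bf16"))  # ~667.6 MiB: half the fp32 footprint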
@@ -53,6 +53,8 @@ def __init__(self, size, dtype, device, convert_cpu=False):
                 dtype=np.float16) if Type.fp16.value == dtype else np.zeros(
                     size, dtype=np.float32)
             self.buffer = core.eager.Tensor(value=value, place=core.CPUPlace())
+            if dtype == Type.bf16.value:
+                self.buffer = paddle.cast(self.buffer, dtype=paddle.bfloat16)
         else:
             self.buffer = paddle.zeros(size, dtype=dtype)

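NumPy has no native bfloat16, so the CPU path above allocates the buffer as float32 first and then casts the resulting tensor to bfloat16. The same trick as a standalone sketch using the public API (paddle.to_tensor standing in for core.eager.Tensor):

import numpy as np
import paddle

size = [1024]
value = np.zeros(size, dtype=np.float32)  # NumPy cannot allocate bf16
buffer = paddle.to_tensor(value, place=paddle.CPUPlace())
buffer = paddle.cast(buffer, dtype=paddle.bfloat16)
print(buffer.dtype)  # paddle.bfloat16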
@@ -41,6 +41,7 @@ class Type(Enum):
     Type of trainable parameters
     """
     fp16 = paddle.float16
+    bf16 = paddle.bfloat16
     fp32 = paddle.float32

@@ -45,6 +45,7 @@ class Type(Enum):
     Type of trainable parameters
     """
     fp16 = paddle.float16
+    bf16 = paddle.bfloat16
     fp32 = paddle.float32

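The Type enum is extended identically in both sharding utility modules so dtype-keyed tables (alignment, grad-storage sizing) can resolve bf16. A quick illustration of the value-based lookup this enables:

from enum import Enum

import paddle

class Type(Enum):
    """
    Type of trainable parameters
    """
    fp16 = paddle.float16
    bf16 = paddle.bfloat16
    fp32 = paddle.float32

# Enum lookup by value maps a tensor's dtype back to a Type member.
grad = paddle.zeros([2], dtype='bfloat16')
assert Type(grad.dtype) is Type.bf16
print(Type.bf16.value)  # paddle.bfloat16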