From d7bd5b0908e2116c93f5ff9489ff1165a2d449e1 Mon Sep 17 00:00:00 2001
From: zhangbo9674
Date: Tue, 28 Sep 2021 13:11:31 +0000
Subject: [PATCH 1/7] add fp16 for clip_by_norm api

---
 python/paddle/fluid/layers/nn.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py
index 75b0392ab6ae4..ceda304b26e89 100755
--- a/python/paddle/fluid/layers/nn.py
+++ b/python/paddle/fluid/layers/nn.py
@@ -12524,7 +12524,7 @@ def clip_by_norm(x, max_norm, name=None):
         return _C_ops.clip_by_norm(x, 'max_norm', max_norm)
 
     helper = LayerHelper("clip_by_norm", **locals())
-    check_variable_and_dtype(x, 'X', ['float32'], 'clip_by_norm')
+    check_variable_and_dtype(x, 'X', ['float32', 'float16'], 'clip_by_norm')
     check_type(max_norm, 'max_norm', (float), 'clip_by_norm')
 
     if name is None:

From c73a74d442b4ff856ef3df3f865b7da276f39797 Mon Sep 17 00:00:00 2001
From: zhangbo9674
Date: Mon, 11 Oct 2021 12:47:52 +0000
Subject: [PATCH 2/7] support ClipByGlobalNorm for fp16 in dygraph

---
 python/paddle/fluid/clip.py | 35 +++++++++++++++++++++++++++++++----
 1 file changed, 31 insertions(+), 4 deletions(-)

diff --git a/python/paddle/fluid/clip.py b/python/paddle/fluid/clip.py
index 5a9ea1a445e2d..dc270c615d2b1 100644
--- a/python/paddle/fluid/clip.py
+++ b/python/paddle/fluid/clip.py
@@ -435,6 +435,8 @@ def __str__(self):
     def _dygraph_clip(self, params_grads):
         params_and_grads = []
         sum_square_list = []
+        sum_square_list_fp16 = []
+        sum_square_list_fp32 = []
         for p, g in params_grads:
             if g is None:
                 continue
@@ -446,13 +448,36 @@ def _dygraph_clip(self, params_grads):
                 merge_grad = layers.get_tensor_from_selected_rows(merge_grad)
 
             sum_square = _squared_l2_norm(merge_grad)
-            sum_square_list.append(sum_square)
+            if sum_square.dtype == core.VarDesc.VarType.FP16:
+                sum_square_list_fp16.append(sum_square)
+            elif sum_square.dtype == core.VarDesc.VarType.FP32:
+                sum_square_list_fp32.append(sum_square)
+            else:
+                sum_square_list.append(sum_square)
 
         # all parameters have been filtered out
-        if len(sum_square_list) == 0:
+        if len(sum_square_list) + len(sum_square_list_fp16) + len(
+                sum_square_list_fp32) == 0:
             return params_grads
 
-        global_norm_var = layers.concat(sum_square_list)
+        sum_dtype = 'float64' if len(sum_square_list) > 0 else "float32"
+        global_norm_var = []
+        if len(sum_square_list_fp16) > 0:
+            global_norm_var_fp16 = layers.concat(sum_square_list_fp16)
+            global_norm_var_fp16 = layers.reduce_sum(global_norm_var_fp16)
+            global_norm_var.append(global_norm_var_fp16.astype(sum_dtype))
+        if len(sum_square_list_fp32) > 0:
+            global_norm_var_fp32 = layers.concat(sum_square_list_fp32)
+            global_norm_var_fp32 = layers.reduce_sum(global_norm_var_fp32)
+            if sum_dtype == 'float32':
+                global_norm_var.append(global_norm_var_fp32)
+            else:
+                global_norm_var.append(global_norm_var_fp32.astype(sum_dtype))
+        if len(sum_square_list) > 0:
+            global_norm_var_fp64 = layers.concat(sum_square_list)
+            global_norm_var_fp64 = layers.reduce_sum(global_norm_var_fp64)
+            global_norm_var.append(global_norm_var_fp64)
+        global_norm_var = layers.concat(global_norm_var)
         global_norm_var = layers.reduce_sum(global_norm_var)
         global_norm_var = layers.sqrt(global_norm_var)
         max_global_norm = layers.fill_constant(
@@ -468,7 +493,9 @@ def _dygraph_clip(self, params_grads):
                 params_and_grads.append((p, g))
                 continue
             # TODO(wangxi): use inplace elementwise_mul
-            new_grad = layers.elementwise_mul(x=g, y=clip_var)
+            clip_input = (clip_var.astype('float16')
+                          if g.dtype == core.VarDesc.VarType.FP16 else clip_var)
+            new_grad = layers.elementwise_mul(x=g, y=clip_input)
             params_and_grads.append((p, new_grad))
 
         return params_and_grads
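Note (patch 2): the squared per-parameter norms are bucketed by dtype, each
bucket is reduced separately, and the FP16 partial sum is upcast before the
buckets are combined, so the cross-bucket accumulation never happens in half
precision. A minimal NumPy sketch of that accumulation scheme (illustrative
only; the function name and the NumPy framing are not from the patch):

    import numpy as np

    def global_grad_norm(grads):
        # Bucket squared L2 norms by dtype, as _dygraph_clip does above.
        buckets = {"float16": [], "float32": [], "float64": []}
        for g in grads:
            buckets[g.dtype.name].append(np.sum(np.square(g)))
        # Combine in float64 only if float64 grads exist, else float32.
        sum_dtype = "float64" if buckets["float64"] else "float32"
        partials = []
        for name in ("float16", "float32", "float64"):
            if buckets[name]:
                # Reduce each bucket in its own dtype, then upcast the
                # partial sum; FP16 is promoted here, before mixing.
                partials.append(np.sum(buckets[name]).astype(sum_dtype))
        return np.sqrt(np.sum(np.asarray(partials)))

    grads = [np.random.randn(5, 5).astype("float16"),
             np.random.randn(5).astype("float32")]
    print(global_grad_norm(grads))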
From 4b7d372690b30a75de4bf2ef464d9a7a393b4e2b Mon Sep 17 00:00:00 2001
From: zhangbo9674
Date: Tue, 12 Oct 2021 03:42:13 +0000
Subject: [PATCH 3/7] add unittest for dygraph clipGlobalNorm

---
 .../tests/unittests/test_gradient_clip.py | 54 +++++++++++++++++++
 1 file changed, 54 insertions(+)

diff --git a/python/paddle/fluid/tests/unittests/test_gradient_clip.py b/python/paddle/fluid/tests/unittests/test_gradient_clip.py
index e2050cf32dbdd..f80f1384756a9 100644
--- a/python/paddle/fluid/tests/unittests/test_gradient_clip.py
+++ b/python/paddle/fluid/tests/unittests/test_gradient_clip.py
@@ -453,5 +453,59 @@ def check_clip_result(self, loss, optimizer):
             "gradient clip by value has wrong results!")
 
 
+class SimpleNet(paddle.nn.Layer):
+    def __init__(self):
+        super(SimpleNet, self).__init__()
+        self.linear = paddle.nn.Linear(5, 5)
+        self.batch_norm = paddle.nn.BatchNorm(5)
+
+    def forward(self, x):
+        x = self.linear(x)
+        x = self.batch_norm(x)
+        return x
+
+
+class TestDygraphGradientClipFP16(unittest.TestCase):
+    def test_gradient_clip(self):
+        with fluid.dygraph.guard():
+            paddle.seed(10)
+            model = SimpleNet()
+            sgd_optimizer = fluid.optimizer.SGD(
+                learning_rate=0.0,
+                parameter_list=model.parameters(),
+                grad_clip=fluid.clip.GradientClipByGlobalNorm(1.0))
+            model, sgd_optimizer = paddle.amp.decorate(
+                models=model, optimizers=sgd_optimizer, level='O2')
+            scaler = paddle.amp.GradScaler(init_loss_scaling=1024)
+            inputs = fluid.layers.uniform_random(
+                [1, 5], min=-10, max=10).astype('float32')
+            with paddle.amp.auto_cast(level='O2'):
+                out = model(fluid.dygraph.to_variable(inputs))
+                loss = fluid.layers.reduce_mean(out)
+            scaler.scale(loss).backward()
+            opt, params_grads = scaler.minimize(sgd_optimizer, loss)
+            clip = fluid.clip.GradientClipByGlobalNorm(clip_norm=0.8)
+            _, grads = zip(*params_grads)
+            params_grads = clip(params_grads)
+            _, grads_clip = zip(*params_grads)
+            global_norm = 0
+            for u in grads:
+                u = u.numpy()
+                global_norm += np.sum(np.power(u, 2))
+            global_norm = np.sqrt(global_norm)
+            global_norm_clip = 0
+            for v in grads_clip:
+                v = v.numpy()
+                global_norm_clip += np.sum(np.power(v, 2))
+            global_norm_clip = np.sqrt(global_norm_clip)
+            a = np.minimum(global_norm, 0.8)
+            b = global_norm_clip
+            self.assertTrue(
+                np.isclose(
+                    a=a, b=b, rtol=1e-3, atol=1e-8),
+                "gradient clip by global norm has wrong results, expected:%f, but received:%f"
+                % (a, b))
+
+
 if __name__ == '__main__':
     unittest.main()
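Note (patch 3): the assertion relies on the global-norm clipping identity:
each gradient is scaled by s = clip_norm / max(global_norm, clip_norm), so
the clipped global norm equals min(global_norm, clip_norm). A standalone
NumPy check of that identity (illustrative, not part of the test suite):

    import numpy as np

    clip_norm = 0.8
    grads = [np.random.randn(5, 5), np.random.randn(5)]
    global_norm = np.sqrt(sum(np.sum(g ** 2) for g in grads))
    scale = clip_norm / max(global_norm, clip_norm)
    clipped_norm = np.sqrt(sum(np.sum((g * scale) ** 2) for g in grads))
    assert np.isclose(clipped_norm, np.minimum(global_norm, clip_norm))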
From ffba47d6356a7b2f991af9060f4a910e55e4d88d Mon Sep 17 00:00:00 2001
From: zhangbo9674
Date: Tue, 12 Oct 2021 05:15:14 +0000
Subject: [PATCH 4/7] refine unittest for dygraph clipGlobalNorm for mac and
 windows

---
 .../tests/unittests/test_gradient_clip.py | 77 ++++++++++---------
 1 file changed, 39 insertions(+), 38 deletions(-)

diff --git a/python/paddle/fluid/tests/unittests/test_gradient_clip.py b/python/paddle/fluid/tests/unittests/test_gradient_clip.py
index f80f1384756a9..04fc8c1eadc07 100644
--- a/python/paddle/fluid/tests/unittests/test_gradient_clip.py
+++ b/python/paddle/fluid/tests/unittests/test_gradient_clip.py
@@ -467,44 +467,45 @@ def forward(self, x):
 
 class TestDygraphGradientClipFP16(unittest.TestCase):
     def test_gradient_clip(self):
-        with fluid.dygraph.guard():
-            paddle.seed(10)
-            model = SimpleNet()
-            sgd_optimizer = fluid.optimizer.SGD(
-                learning_rate=0.0,
-                parameter_list=model.parameters(),
-                grad_clip=fluid.clip.GradientClipByGlobalNorm(1.0))
-            model, sgd_optimizer = paddle.amp.decorate(
-                models=model, optimizers=sgd_optimizer, level='O2')
-            scaler = paddle.amp.GradScaler(init_loss_scaling=1024)
-            inputs = fluid.layers.uniform_random(
-                [1, 5], min=-10, max=10).astype('float32')
-            with paddle.amp.auto_cast(level='O2'):
-                out = model(fluid.dygraph.to_variable(inputs))
-                loss = fluid.layers.reduce_mean(out)
-            scaler.scale(loss).backward()
-            opt, params_grads = scaler.minimize(sgd_optimizer, loss)
-            clip = fluid.clip.GradientClipByGlobalNorm(clip_norm=0.8)
-            _, grads = zip(*params_grads)
-            params_grads = clip(params_grads)
-            _, grads_clip = zip(*params_grads)
-            global_norm = 0
-            for u in grads:
-                u = u.numpy()
-                global_norm += np.sum(np.power(u, 2))
-            global_norm = np.sqrt(global_norm)
-            global_norm_clip = 0
-            for v in grads_clip:
-                v = v.numpy()
-                global_norm_clip += np.sum(np.power(v, 2))
-            global_norm_clip = np.sqrt(global_norm_clip)
-            a = np.minimum(global_norm, 0.8)
-            b = global_norm_clip
-            self.assertTrue(
-                np.isclose(
-                    a=a, b=b, rtol=1e-3, atol=1e-8),
-                "gradient clip by global norm has wrong results, expected:%f, but received:%f"
-                % (a, b))
+        if fluid.core.is_compiled_with_cuda():
+            with fluid.dygraph.guard():
+                paddle.seed(10)
+                model = SimpleNet()
+                sgd_optimizer = fluid.optimizer.SGD(
+                    learning_rate=0.0,
+                    parameter_list=model.parameters(),
+                    grad_clip=fluid.clip.GradientClipByGlobalNorm(1.0))
+                model, sgd_optimizer = paddle.amp.decorate(
+                    models=model, optimizers=sgd_optimizer, level='O2')
+                scaler = paddle.amp.GradScaler(init_loss_scaling=1024)
+                inputs = fluid.layers.uniform_random(
+                    [1, 5], min=-10, max=10).astype('float32')
+                with paddle.amp.auto_cast(level='O2'):
+                    out = model(fluid.dygraph.to_variable(inputs))
+                    loss = fluid.layers.reduce_mean(out)
+                scaler.scale(loss).backward()
+                opt, params_grads = scaler.minimize(sgd_optimizer, loss)
+                clip = fluid.clip.GradientClipByGlobalNorm(clip_norm=0.8)
+                _, grads = zip(*params_grads)
+                params_grads = clip(params_grads)
+                _, grads_clip = zip(*params_grads)
+                global_norm = 0
+                for u in grads:
+                    u = u.numpy()
+                    global_norm += np.sum(np.power(u, 2))
+                global_norm = np.sqrt(global_norm)
+                global_norm_clip = 0
+                for v in grads_clip:
+                    v = v.numpy()
+                    global_norm_clip += np.sum(np.power(v, 2))
+                global_norm_clip = np.sqrt(global_norm_clip)
+                a = np.minimum(global_norm, 0.8)
+                b = global_norm_clip
+                self.assertTrue(
+                    np.isclose(
+                        a=a, b=b, rtol=1e-3, atol=1e-8),
+                    "gradient clip by global norm has wrong results, expected:%f, but received:%f"
+                    % (a, b))
 
 
 if __name__ == '__main__':
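Note (patch 4): wrapping the body in fluid.core.is_compiled_with_cuda()
makes the test a silent no-op on CPU-only builds (the mac and Windows CI
machines in the subject line). An equivalent guard that reports the test as
skipped instead of passed would be a skip decorator; this is an alternative
sketch, not what the patch does:

    import unittest
    import paddle.fluid as fluid

    @unittest.skipIf(not fluid.core.is_compiled_with_cuda(),
                     "AMP level O2 requires a CUDA build")
    class TestDygraphGradientClipFP16(unittest.TestCase):
        def test_gradient_clip(self):
            pass  # body as in the patch above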
From c88c385cecb72cbc0f7aa4e7f037c6495481a806 Mon Sep 17 00:00:00 2001
From: zhangbo9674
Date: Tue, 12 Oct 2021 12:04:28 +0000
Subject: [PATCH 5/7] refine unittest

---
 .../tests/unittests/test_gradient_clip.py | 26 ++++++++++++++-------
 1 file changed, 19 insertions(+), 7 deletions(-)

diff --git a/python/paddle/fluid/tests/unittests/test_gradient_clip.py b/python/paddle/fluid/tests/unittests/test_gradient_clip.py
index 04fc8c1eadc07..743b28902a20e 100644
--- a/python/paddle/fluid/tests/unittests/test_gradient_clip.py
+++ b/python/paddle/fluid/tests/unittests/test_gradient_clip.py
@@ -471,10 +471,8 @@ def test_gradient_clip(self):
             with fluid.dygraph.guard():
                 paddle.seed(10)
                 model = SimpleNet()
-                sgd_optimizer = fluid.optimizer.SGD(
-                    learning_rate=0.0,
-                    parameter_list=model.parameters(),
-                    grad_clip=fluid.clip.GradientClipByGlobalNorm(1.0))
+                sgd_optimizer = paddle.optimizer.SGD(
+                    learning_rate=0.0, parameters=model.parameters())
                 model, sgd_optimizer = paddle.amp.decorate(
                     models=model, optimizers=sgd_optimizer, level='O2')
                 scaler = paddle.amp.GradScaler(init_loss_scaling=1024)
@@ -483,12 +481,25 @@ def test_gradient_clip(self):
                 with paddle.amp.auto_cast(level='O2'):
                     out = model(fluid.dygraph.to_variable(inputs))
                     loss = fluid.layers.reduce_mean(out)
-                scaler.scale(loss).backward()
-                opt, params_grads = scaler.minimize(sgd_optimizer, loss)
-                clip = fluid.clip.GradientClipByGlobalNorm(clip_norm=0.8)
+                scaled = scaler.scale(loss)
+                scaled.backward()
+                scaler.unscale_(sgd_optimizer)
+                # before clip
+                params_grads = []
+                for param in model.parameters():
+                    if param.stop_gradient:
+                        continue
+                    if param._grad_ivar() is not None:
+                        params_grads.append((param, param._grad_ivar()))
                 _, grads = zip(*params_grads)
+                # clip grads
+                clip = fluid.clip.GradientClipByGlobalNorm(clip_norm=0.8)
                 params_grads = clip(params_grads)
                 _, grads_clip = zip(*params_grads)
+                # param update
+                scaler.step(sgd_optimizer)
+                scaler.update()
+
                 global_norm = 0
                 for u in grads:
                     u = u.numpy()
@@ -499,6 +510,7 @@ def test_gradient_clip(self):
                     v = v.numpy()
                     global_norm_clip += np.sum(np.power(v, 2))
                 global_norm_clip = np.sqrt(global_norm_clip)
+
                 a = np.minimum(global_norm, 0.8)
                 b = global_norm_clip
                 self.assertTrue(

From e99f3568594c185ab8a33abf279b2f381a3a46b1 Mon Sep 17 00:00:00 2001
From: zhangbo9674
Date: Tue, 12 Oct 2021 12:56:30 +0000
Subject: [PATCH 6/7] add unittest for fp64

---
 .../tests/unittests/test_gradient_clip.py | 47 +++++++++++++++++++
 1 file changed, 47 insertions(+)

diff --git a/python/paddle/fluid/tests/unittests/test_gradient_clip.py b/python/paddle/fluid/tests/unittests/test_gradient_clip.py
index 743b28902a20e..8c116cb28c0b2 100644
--- a/python/paddle/fluid/tests/unittests/test_gradient_clip.py
+++ b/python/paddle/fluid/tests/unittests/test_gradient_clip.py
@@ -520,5 +520,52 @@ def test_gradient_clip(self):
                     % (a, b))
 
 
+class TestDygraphGradientClipFP64(unittest.TestCase):
+    def test_gradient_clip(self):
+        with fluid.dygraph.guard():
+            inputs = fluid.layers.uniform_random(
+                [16, 5], min=-10, max=10).astype('float64')
+            linear = fluid.dygraph.Linear(5, 5, dtype="float64")
+            out = linear(fluid.dygraph.to_variable(inputs))
+            out = linear(fluid.dygraph.to_variable(inputs))
+            loss = fluid.layers.reduce_mean(out)
+            loss.backward()
+            # before clip
+            params_grads = []
+            for param in linear.parameters():
+                if param.stop_gradient:
+                    continue
+                if param._grad_ivar() is not None:
+                    params_grads.append((param, param._grad_ivar()))
+            _, grads = zip(*params_grads)
+            # clip grads
+            clip = fluid.clip.GradientClipByGlobalNorm(clip_norm=0.1)
+            params_grads = clip(params_grads)
+            _, grads_clip = zip(*params_grads)
+
+            global_norm = 0
+            for u in grads:
+                u = u.numpy()
+                global_norm += np.sum(np.power(u, 2))
+            global_norm = np.sqrt(global_norm)
+
+            global_norm_clip = 0
+            for v in grads_clip:
+                v = v.numpy()
+                print(v)
+                global_norm_clip += np.sum(np.power(v, 2))
+            global_norm_clip = np.sqrt(global_norm_clip)
+            print(global_norm_clip)
+
+            a = np.minimum(global_norm, 0.1)
+            b = global_norm_clip
+
+            self.assertTrue(
+                np.isclose(
+                    a=a, b=b, rtol=1e-6, atol=1e-8),
+                "gradient clip by global norm has wrong results, expected:%f, but received:%f"
+                % (a, b))
+
+
 if __name__ == '__main__':
     unittest.main()
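Note (patches 5 and 6): both tests collect (param, grad) pairs by hand so
that the unscaled gradients can be compared before and after clipping;
param._grad_ivar() is an internal dygraph accessor for the gradient tensor.
The shared loop, extracted as a helper for reference (a sketch; the tests
inline it rather than defining this function):

    def collect_params_grads(layer):
        # Mirror of the loop in the tests above: skip frozen parameters
        # and parameters that received no gradient in this backward pass.
        params_grads = []
        for param in layer.parameters():
            if param.stop_gradient:
                continue
            if param._grad_ivar() is not None:
                params_grads.append((param, param._grad_ivar()))
        return params_grads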
From 8539be7d15b0d75882f76417d819e2a3e673d793 Mon Sep 17 00:00:00 2001
From: zhangbo9674
Date: Tue, 12 Oct 2021 12:59:02 +0000
Subject: [PATCH 7/7] refine unittest for fp64

---
 python/paddle/fluid/tests/unittests/test_gradient_clip.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/python/paddle/fluid/tests/unittests/test_gradient_clip.py b/python/paddle/fluid/tests/unittests/test_gradient_clip.py
index 8c116cb28c0b2..29735f1c89c85 100644
--- a/python/paddle/fluid/tests/unittests/test_gradient_clip.py
+++ b/python/paddle/fluid/tests/unittests/test_gradient_clip.py
@@ -527,7 +527,6 @@ def test_gradient_clip(self):
                 [16, 5], min=-10, max=10).astype('float64')
             linear = fluid.dygraph.Linear(5, 5, dtype="float64")
             out = linear(fluid.dygraph.to_variable(inputs))
-            out = linear(fluid.dygraph.to_variable(inputs))
             loss = fluid.layers.reduce_mean(out)
             loss.backward()
             # before clip
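Note (series summary): together these patches let GradientClipByGlobalNorm
run under AMP in dygraph: FP16 gradients contribute to the global norm, and
the clip scale is cast back to FP16 before the multiply. A minimal
end-to-end usage sketch against the Paddle 2.x API of the time (assumed
setup, CUDA build required as patch 4 notes; not code from this PR):

    import paddle

    model = paddle.nn.Linear(5, 5)
    clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=1.0)
    opt = paddle.optimizer.SGD(learning_rate=0.1,
                               parameters=model.parameters(),
                               grad_clip=clip)
    model, opt = paddle.amp.decorate(models=model, optimizers=opt,
                                     level='O2')
    scaler = paddle.amp.GradScaler(init_loss_scaling=1024)

    x = paddle.uniform([4, 5], min=-1.0, max=1.0)
    with paddle.amp.auto_cast(level='O2'):
        loss = paddle.mean(model(x))
    scaler.scale(loss).backward()
    scaler.step(opt)   # unscales, applies grad_clip, updates parameters
    scaler.update()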