diff --git a/python/paddle/incubate/nn/layer/fused_transformer.py b/python/paddle/incubate/nn/layer/fused_transformer.py index 547f2aded1f00..bfef66d3ae51b 100644 --- a/python/paddle/incubate/nn/layer/fused_transformer.py +++ b/python/paddle/incubate/nn/layer/fused_transformer.py @@ -1220,8 +1220,7 @@ def __init__( ) self.normalize_before = normalize_before - self._dtype = "float16" - #self._dtype = self._helper.get_default_dtype() + self._dtype = self._helper.get_default_dtype() self._epsilon = epsilon self._trans_qkvw = trans_qkvw self._ring_id = ring_id @@ -3477,4 +3476,4 @@ def shard_tensor(dst_tensor, parent_tensor, pos): self.shared_weights2.append(shared_weight2) self.shared_scales2.append(shared_scale2) - self.shared_biases2.append(shared_bias2) \ No newline at end of file + self.shared_biases2.append(shared_bias2)