Skip to content

Commit

Permalink
refine python interface for bfloat16
Browse files Browse the repository at this point in the history
  • Loading branch information
zhangbo9674 committed Jan 19, 2022
1 parent ab98946 commit 2a0bd30
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 3 deletions.
3 changes: 3 additions & 0 deletions python/paddle/fluid/data_feeder.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,9 @@ def convert_dtype(dtype):
# however, jointly supporting python2 and python3, (as well as python4 maybe)
# may still be a long-lasting problem.
return str(dtype)
# NOTE(zhangbo): Now numpy not support bfloat, and paddle use uint16 to carry the data of bfloat16, and there binary is consistent.
if dtype in ['bfloat16']:
return 'uint16'

raise TypeError(
"dtype must be any of [bool, float16, uint16, float32, float64, int8, int16, "
Expand Down
2 changes: 1 addition & 1 deletion python/paddle/fluid/dygraph/amp/auto_cast.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ def amp_guard(enable=True,
level = level.upper()
if not (level in ['O0', 'O1', 'O2']):
raise ValueError(
"level should be O0, O1 or O2. O0 represents fp32 train mode, O1 represents AMP train mode, O2 represents pure fp16 train mode."
"level should be O0, O1 or O2. O0 represents fp32 train mode, O1 represents AMP train mode, O2 represents pure fp16/bf16 train mode."
)

# check amp_dtype: float16 or bfloat16
Expand Down
4 changes: 3 additions & 1 deletion python/paddle/fluid/layers/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ def cast(x, dtype):
x(Tensor): An input N-D Tensor with data type bool, float16,
float32, float64, int32, int64, uint8.
dtype(np.dtype|core.VarDesc.VarType|str): Data type of the output:
bool, float16, float32, float64, int8, int32, int64, uint8.
bool, float16, float32, float64, int8, int32, int64, uint8, bfloat16.
Returns:
Tensor: A Tensor with the same shape as input's.
Expand All @@ -245,6 +245,8 @@ def cast(x, dtype):
"""
if in_dygraph_mode():
if not isinstance(dtype, core.VarDesc.VarType):
if dtype == 'bfloat16':
dtype = 'uint16'
dtype = convert_np_dtype_to_dtype_(dtype)
out = _C_ops.cast(x, 'in_dtype', x.dtype, 'out_dtype', dtype)
return out
Expand Down
7 changes: 6 additions & 1 deletion python/paddle/tensor/to_string.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ def to_string(var, prefix='Tensor'):
data = _format_tensor(
np_var, summary, indent=indent, max_width=max_width, signed=signed)

return _template.format(
_template = _template.format(
prefix=prefix,
shape=var.shape,
dtype=convert_dtype(var.dtype),
Expand All @@ -256,6 +256,11 @@ def to_string(var, prefix='Tensor'):
indent=' ' * indent,
data=data)

if var.dtype == paddle.fluid.core.VarDesc.VarType.BF16:
_template = _template + "\nthis data real dtype is bfloat16, you can use paddle.cast to cast it to float32 and get it's float data."

return _template


def eager_tensor_to_string(tensor, prefix='Tensor'):
indent = len(prefix) + 1
Expand Down

0 comments on commit 2a0bd30

Please sign in to comment.