[AMP] support GPU BF16 amp for dygraph #39029

Merged · 22 commits · Feb 18, 2022
116 changes: 104 additions & 12 deletions paddle/fluid/imperative/amp_auto_cast.cc
@@ -113,26 +113,40 @@ AutoCastGuard::~AutoCastGuard() { tracer_->SetAmpLevel(pre_amp_level_); }
AmpOperators::AmpOperators()
: allow_ops_(new std::unordered_set<std::string>()),
block_ops_(new std::unordered_set<std::string>()),
unsupported_fp16_ops_(new std::unordered_set<std::string>()) {
unsupported_fp16_ops_(new std::unordered_set<std::string>()),
unsupported_bf16_ops_(new std::unordered_set<std::string>()) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
auto unsupported_ops_gpu = std::get<2>(
auto unsupported_ops_gpu_fp16 = std::get<2>(
OpSupportedInfos("GPU", paddle::framework::proto::VarType::FP16));
unsupported_fp16_ops_->insert(unsupported_ops_gpu.begin(),
unsupported_ops_gpu.end());
unsupported_fp16_ops_->insert(unsupported_ops_gpu_fp16.begin(),
unsupported_ops_gpu_fp16.end());
auto unsupported_ops_gpu_bf16 = std::get<2>(
OpSupportedInfos("GPU", paddle::framework::proto::VarType::BF16));
unsupported_bf16_ops_->insert(unsupported_ops_gpu_bf16.begin(),
unsupported_ops_gpu_bf16.end());
// NOTE: GPU/NPU/XPU are compiled separately.
#elif defined(PADDLE_WITH_ASCEND_CL)
auto unsupported_ops_npu = std::get<2>(
auto unsupported_ops_npu_fp16 = std::get<2>(
OpSupportedInfos("NPU", paddle::framework::proto::VarType::FP16));
unsupported_fp16_ops_->insert(unsupported_ops_npu.begin(),
unsupported_ops_npu.end());
unsupported_fp16_ops_->insert(unsupported_ops_npu_fp16.begin(),
unsupported_ops_npu_fp16.end());
auto unsupported_ops_npu_bf16 = std::get<2>(
OpSupportedInfos("NPU", paddle::framework::proto::VarType::BF16));
unsupported_bf16_ops_->insert(unsupported_ops_npu_bf16.begin(),
unsupported_ops_npu_bf16.end());
#elif defined(PADDLE_WITH_XPU)
auto unsupported_ops_xpu = std::get<2>(
auto unsupported_ops_xpu_fp16 = std::get<2>(
OpSupportedInfos("XPU", paddle::framework::proto::VarType::FP16));
unsupported_fp16_ops_->insert(unsupported_ops_xpu.begin(),
unsupported_ops_xpu.end());
unsupported_fp16_ops_->insert(unsupported_ops_xpu_fp16.begin(),
unsupported_ops_xpu_fp16.end());
auto unsupported_ops_xpu_bf16 = std::get<2>(
OpSupportedInfos("XPU", paddle::framework::proto::VarType::BF16));
unsupported_bf16_ops_->insert(unsupported_ops_xpu_bf16.begin(),
unsupported_ops_xpu_bf16.end());
#endif
VLOG(4) << allow_ops_->size() << " " << block_ops_->size() << " "
<< unsupported_fp16_ops_->size();
<< unsupported_fp16_ops_->size() << " "
<< unsupported_bf16_ops_->size();
}

AmpOperators::~AmpOperators() {}
@@ -157,6 +171,11 @@ AmpOperators::GetMutableUnsupportedFp16Ops() {
return unsupported_fp16_ops_;
}

std::shared_ptr<std::unordered_set<std::string>>
AmpOperators::GetMutableUnsupportedBf16Ops() {
return unsupported_bf16_ops_;
}

std::ostream& operator<<(std::ostream& os, AmpOperators& ops) {
os << "allow ops: ";
auto allow_ops = ops.GetMutableAllowOps();
@@ -172,6 +191,11 @@ std::ostream& operator<<(std::ostream& os, AmpOperators& ops) {
auto unsupported_fp16_ops = ops.GetMutableUnsupportedFp16Ops();
std::copy((*unsupported_fp16_ops).begin(), (*unsupported_fp16_ops).end(),
std::ostream_iterator<std::string>(os, " "));
os << "\n";
os << "unsupported bf16 ops: ";
auto unsupported_bf16_ops = ops.GetMutableUnsupportedBf16Ops();
std::copy((*unsupported_bf16_ops).begin(), (*unsupported_bf16_ops).end(),
std::ostream_iterator<std::string>(os, " "));
return os;
}

@@ -188,7 +212,8 @@ inline bool NeedCast(const std::shared_ptr<VarType>& var) {
paddle::platform::is_xpu_place(place)) {
// CudaPinnedPlace is added for varbase created by dataloader
if (data_type == paddle::framework::proto::VarType::FP32 ||
data_type == paddle::framework::proto::VarType::FP16) {
data_type == paddle::framework::proto::VarType::FP16 ||
data_type == paddle::framework::proto::VarType::BF16) {
return true;
}
}
@@ -236,6 +261,16 @@ static inline std::shared_ptr<VarType> CastToFP32(
return var;
}

template <typename VarType>
static inline std::shared_ptr<VarType> CastToBF16(
const std::shared_ptr<VarType>& var) {
auto dst_type = framework::proto::VarType::BF16;
if (NeedCast(var) && (GetDataType<VarType>(var) != dst_type)) {
return CastToType(var, dst_type);
}
return var;
}

template <typename VarType>
static inline framework::proto::VarType::Type GetPromoteType(
const std::string& op_type, const NameVarMap<VarType>& ins) {
@@ -386,5 +421,62 @@ template NameVarMap<VarBase> CastPureFp16Inputs<VarBase>(
const std::string& op_type, const NameVarMap<VarBase>& ins);
template NameVarMap<egr::EagerVariable> CastPureFp16Inputs<egr::EagerVariable>(
const std::string& op_type, const NameVarMap<egr::EagerVariable>& ins);

template <typename VarType>
NameVarMap<VarType> AutoCastBF16Inputs(const std::string& op_type,
const NameVarMap<VarType>& ins) {
NameVarMap<VarType> new_ins(ins);
if (AmpOperators::Instance().GetMutableAllowOps()->count(op_type)) {
for (auto& pair : new_ins) {
VLOG(5) << "Op(" << op_type << "): Cast " << pair.first << " from "
<< GetDtypeStr(*pair.second.cbegin()) << " to bfloat16";
for (auto& var : pair.second) {
var = CastToBF16<VarType>(var);
}
}
return new_ins;
} else {
for (auto& pair : new_ins) {
VLOG(5) << "Op(" << op_type << "): Cast " << pair.first << " from "
<< GetDtypeStr(*pair.second.cbegin()) << " to float";
for (auto& var : pair.second) {
var = CastToFP32<VarType>(var);
}
}
return new_ins;
}
return new_ins;
}
template NameVarMap<VarBase> AutoCastBF16Inputs<VarBase>(
const std::string& op_type, const NameVarMap<VarBase>& ins);
template NameVarMap<egr::EagerVariable> AutoCastBF16Inputs<egr::EagerVariable>(
const std::string& op_type, const NameVarMap<egr::EagerVariable>& ins);

template <typename VarType>
NameVarMap<VarType> CastPureBf16Inputs(const std::string& op_type,
const NameVarMap<VarType>& ins) {
NameVarMap<VarType> new_ins(ins);
auto dst_type = framework::proto::VarType::BF16;
if (AmpOperators::Instance().GetMutableUnsupportedBf16Ops()->count(op_type) ||
AmpOperators::Instance().GetMutableBlockOps()->count(op_type)) {
dst_type = framework::proto::VarType::FP32;
}
for (auto& pair : new_ins) {
VLOG(5) << "Op(" << op_type << "): Cast " << pair.first << " from "
<< GetDtypeStr(*pair.second.cbegin()) << " to "
<< framework::DataTypeToString(dst_type);
for (auto& var : pair.second) {
var = (dst_type == framework::proto::VarType::FP32
? CastToFP32<VarType>(var)
: CastToBF16<VarType>(var));
}
}
return new_ins;
}
template NameVarMap<VarBase> CastPureBf16Inputs<VarBase>(
const std::string& op_type, const NameVarMap<VarBase>& ins);
template NameVarMap<egr::EagerVariable> CastPureBf16Inputs<egr::EagerVariable>(
const std::string& op_type, const NameVarMap<egr::EagerVariable>& ins);

} // namespace imperative
} // namespace paddle
12 changes: 12 additions & 0 deletions paddle/fluid/imperative/amp_auto_cast.h
@@ -56,6 +56,9 @@ class AmpOperators {
std::shared_ptr<std::unordered_set<std::string>>
GetMutableUnsupportedFp16Ops();

std::shared_ptr<std::unordered_set<std::string>>
GetMutableUnsupportedBf16Ops();

private:
AmpOperators(); // forbid calling default constructor

@@ -69,6 +72,9 @@

// The set of ops that have no fp16 CUDA kernel.
std::shared_ptr<std::unordered_set<std::string>> unsupported_fp16_ops_;

// The set of ops that have no bf16 CUDA kernel.
std::shared_ptr<std::unordered_set<std::string>> unsupported_bf16_ops_;
};

std::ostream& operator<<(std::ostream& os, AmpOperators& ops);
@@ -95,6 +101,12 @@ NameVarMap<VarType> AutoCastInputs(const std::string& op_type,
template <typename VarType>
NameVarMap<VarType> CastPureFp16Inputs(const std::string& op_type,
const NameVarMap<VarType>& ins);
template <typename VarType>
NameVarMap<VarType> AutoCastBF16Inputs(const std::string& op_type,
const NameVarMap<VarType>& ins);
template <typename VarType>
NameVarMap<VarType> CastPureBf16Inputs(const std::string& op_type,
const NameVarMap<VarType>& ins);

} // namespace imperative
} // namespace paddle
17 changes: 17 additions & 0 deletions paddle/fluid/imperative/gradient_accumulator.cc
@@ -24,6 +24,7 @@
#include "paddle/fluid/imperative/layer.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h"
#include "paddle/fluid/platform/bfloat16.h"
#include "paddle/fluid/platform/complex.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/float16.h"
@@ -423,6 +424,22 @@ void TensorAdd(const VarType& src, VarType* dst) {
src_tensor, dst_tensor, place);
}
}
if (data_type == framework::proto::VarType::BF16) {
if (platform::is_gpu_place(place)) {
#if defined(PADDLE_WITH_CUDA)
return TensorAddImpl<platform::CUDADeviceContext, platform::bfloat16>(
src_tensor, dst_tensor, place);
#else
PADDLE_THROW(platform::errors::Unimplemented(
"Gradient accumulation of data type (%s) on place (%s) is not "
"supported in imperative mode",
framework::DataTypeToString(data_type), place));
#endif
} else if (platform::is_cpu_place(place)) {
return TensorAddImpl<platform::CPUDeviceContext, platform::bfloat16>(
src_tensor, dst_tensor, place);
}
}
PADDLE_THROW(platform::errors::Unimplemented(
"Gradient accumulation of data type (%s) on place (%s) is not "
"supported in imperative mode",
14 changes: 12 additions & 2 deletions paddle/fluid/imperative/tracer.cc
@@ -35,6 +35,8 @@ thread_local bool Tracer::has_grad_ = true;

thread_local AmpLevel Tracer::amp_level_ = AmpLevel::O0;

thread_local pten::DataType Tracer::amp_dtype_ = pten::DataType::FLOAT32;

static std::shared_ptr<Tracer> g_current_tracer(nullptr);

const std::shared_ptr<Tracer>& GetCurrentTracer() { return g_current_tracer; }
@@ -200,10 +202,18 @@ void Tracer::TraceOp(const std::string& type, const NameVarMap<VarType>& ins,
NameVarMap<VarType> new_ins = ins;
if (amp_level_ == AmpLevel::O1) {
VLOG(5) << "Auto mixed precision run operator: " << type;
new_ins = AutoCastInputs<VarType>(type, ins);
if (amp_dtype_ == pten::DataType::FLOAT16) {
new_ins = AutoCastInputs<VarType>(type, ins);
} else if (amp_dtype_ == pten::DataType::BFLOAT16) {
new_ins = AutoCastBF16Inputs<VarType>(type, ins);
}
} else if (amp_level_ == AmpLevel::O2) {
VLOG(5) << "Pure fp16 run operator: " << type;
new_ins = CastPureFp16Inputs<VarType>(type, ins);
if (amp_dtype_ == pten::DataType::FLOAT16) {
new_ins = CastPureFp16Inputs<VarType>(type, ins);
} else if (amp_dtype_ == pten::DataType::BFLOAT16) {
new_ins = CastPureBf16Inputs<VarType>(type, ins);
}
}

try {
24 changes: 24 additions & 0 deletions paddle/fluid/imperative/tracer.h
@@ -34,6 +34,8 @@ namespace imperative {

enum class AmpLevel;

enum class AmpDtype;

using GarbageCollectorMap =
std::map<platform::Place,
std::unique_ptr<paddle::framework::GarbageCollector>>;
@@ -131,6 +133,27 @@ class Tracer {

AmpLevel GetAmpLevel() const { return amp_level_; }

void SetAmpDtype(std::string amp_dtype) {
VLOG(4) << "set amp_dtype to " << amp_dtype;
if (amp_dtype == "float16") {
amp_dtype_ = pten::DataType::FLOAT16;
} else if (amp_dtype == "bfloat16") {
amp_dtype_ = pten::DataType::BFLOAT16;
} else {
amp_dtype_ = pten::DataType::FLOAT32;
}
}

std::string GetAmpDtype() const {
if (amp_dtype_ == pten::DataType::FLOAT16) {
return std::string("float16");
} else if (amp_dtype_ == pten::DataType::BFLOAT16) {
return std::string("bfloat16");
} else {
return std::string("float32");
}
}

paddle::framework::GarbageCollector* MutableGarbageCollectorIfNotExists(
const platform::Place& place);

@@ -143,6 +166,7 @@
GarbageCollectorMap gcs_;
static thread_local bool has_grad_;
static thread_local AmpLevel amp_level_;
static thread_local pten::DataType amp_dtype_;
};

// To access static variable current_tracer
2 changes: 2 additions & 0 deletions paddle/fluid/pybind/imperative.cc
@@ -2230,6 +2230,8 @@ void BindImperative(py::module *m_ptr) {
&imperative::Tracer::SetEnableProgramDescTracing)
.def_property("_amp_level", &imperative::Tracer::GetAmpLevel,
&imperative::Tracer::SetAmpLevel)
.def_property("_amp_dtype", &imperative::Tracer::GetAmpDtype,
&imperative::Tracer::SetAmpDtype)
.def_property("_has_grad", &imperative::Tracer::HasGrad,
&imperative::Tracer::SetHasGrad)
.def_property(
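For reference, a minimal sketch of what the newly bound _amp_dtype property looks like from the Python side. Obtaining the tracer via paddle.fluid.framework._dygraph_tracer() is an assumption for illustration (that helper is not part of this diff); the default and updated values follow Tracer::GetAmpDtype/SetAmpDtype above.

# A minimal sketch, assuming paddle.fluid.framework._dygraph_tracer() returns
# the active dygraph tracer (helper not shown in this diff).
import paddle
from paddle.fluid import framework

paddle.disable_static()
tracer = framework._dygraph_tracer()

print(tracer._amp_dtype)        # 'float32' by default (Tracer::GetAmpDtype)
tracer._amp_dtype = 'bfloat16'  # dispatched to Tracer::SetAmpDtype via pybind
print(tracer._amp_dtype)        # 'bfloat16'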
2 changes: 2 additions & 0 deletions paddle/pten/kernels/funcs/math_function.cc
@@ -359,6 +359,8 @@ struct ElementwiseAddTo<paddle::platform::CPUDeviceContext, T> {

template struct ElementwiseAddTo<paddle::platform::CPUDeviceContext,
pten::dtype::float16>;
template struct ElementwiseAddTo<paddle::platform::CPUDeviceContext,
pten::dtype::bfloat16>;

} // namespace funcs
} // namespace pten
2 changes: 2 additions & 0 deletions paddle/pten/kernels/funcs/math_function.cu
@@ -381,6 +381,8 @@ struct ElementwiseAddTo<paddle::platform::CUDADeviceContext, T> {

template struct ElementwiseAddTo<paddle::platform::CUDADeviceContext,
pten::dtype::float16>;
template struct ElementwiseAddTo<paddle::platform::CUDADeviceContext,
pten::dtype::bfloat16>;

} // namespace funcs
} // namespace pten
8 changes: 5 additions & 3 deletions python/paddle/amp/auto_cast.py
@@ -21,7 +21,8 @@
def auto_cast(enable=True,
custom_white_list=None,
custom_black_list=None,
level='O1'):
level='O1',
dtype='float16'):
"""
Create a context which enables auto-mixed-precision (AMP) of operators executed in dynamic graph mode.
If enabled, the input data type (float32 or float16) of each operator is decided
@@ -40,7 +41,8 @@ def auto_cast(enable=True,
observed in downstream ops. These ops will not be converted to fp16.
level(str, optional): Auto mixed precision level. Accepted values are "O1" and "O2": O1 represents mixed precision, where the input data type of each operator will be cast according to the white_list and black_list;
O2 represents pure fp16, where all operator parameters and input data will be cast to fp16, except for operators in the black_list, operators without an fp16 kernel, and batch norm. Default is O1 (amp).

dtype(str, optional): Data type of the AMP computation, either 'float16' or 'bfloat16'. Default is 'float16'.

Examples:

.. code-block:: python
@@ -73,7 +75,7 @@ def auto_cast(enable=True,
print(d.dtype) # FP16

"""
return amp_guard(enable, custom_white_list, custom_black_list, level)
return amp_guard(enable, custom_white_list, custom_black_list, level, dtype)


def decorate(models,
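For reference, a minimal usage sketch of the new dtype argument added above, assuming a CUDA device with bfloat16 support; the layer, data shapes, and the expectation that conv2d is on the bf16 white list are illustrative assumptions, not part of this diff.

import paddle

conv = paddle.nn.Conv2D(3, 2, 3, bias_attr=False)
data = paddle.rand([10, 3, 32, 32])

# O1: inputs of white-list ops are auto-cast to bfloat16 (AutoCastBF16Inputs path)
with paddle.amp.auto_cast(level='O1', dtype='bfloat16'):
    out = conv(data)
    print(out.dtype)  # expected: bfloat16 for white-list ops such as conv2d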
15 changes: 13 additions & 2 deletions python/paddle/distributed/fleet/utils/recompute.py
@@ -107,6 +107,15 @@ def forward(ctx, run_function, preserve_rng_state, *args):
else:
raise ValueError("unsupported amp level: {}".format(
tracer._amp_level))

if tracer._amp_dtype == 'float16':
ctx.amp_dtype = 'float16'
elif tracer._amp_dtype in ('bfloat16', 'float32'):
ctx.amp_dtype = 'bfloat16'
else:
raise ValueError("unsupported amp dtype: {}".format(
tracer._amp_dtype))

ctx.amp_white_list, ctx.amp_black_list = tracer._get_amp_op_list()

with paddle.no_grad():
@@ -137,15 +146,17 @@ def backward(ctx, *args):
enable=ctx.is_fw_autocast,
custom_white_list=ctx.amp_white_list,
custom_black_list=ctx.amp_black_list,
level=ctx.amp_level):
level=ctx.amp_level,
dtype=ctx.amp_dtype):
detached_inputs = detach_variable(tuple(inputs))
outputs = ctx.run_function(*detached_inputs)
else:
with paddle.amp.auto_cast(
enable=ctx.is_fw_autocast,
custom_white_list=ctx.amp_white_list,
custom_black_list=ctx.amp_black_list,
level=ctx.amp_level):
level=ctx.amp_level,
dtype=ctx.amp_dtype):
detached_inputs = detach_variable(tuple(inputs))
outputs = ctx.run_function(*detached_inputs)
