Added support for BF16 datatype for all oneDNN activation kernels (#40721)

* added missing BF16 activations

* added softplus bf16

* minor change

* disabled tests for GPU
jakpiase committed Mar 23, 2022
1 parent 292011e commit 8e67629
Showing 4 changed files with 183 additions and 35 deletions.
28 changes: 26 additions & 2 deletions paddle/fluid/operators/abs_op.cc
@@ -30,6 +30,21 @@ namespace operators {
class AbsOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;

framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto input_data_type =
framework::OperatorWithKernel::IndicateVarDataType(ctx, "X");

#ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
return framework::OpKernelType(input_data_type, ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
}
#endif
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
};

class AbsOpMaker : public framework::OpProtoAndCheckerMaker {
@@ -72,8 +87,17 @@ class AbsGradOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto dtype = OperatorWithKernel::IndicateVarDataType(ctx, "X");
return framework::OpKernelType(dtype, ctx.GetPlace());
auto input_data_type =
framework::OperatorWithKernel::IndicateVarDataType(ctx, "X");

#ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
return framework::OpKernelType(input_data_type, ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
}
#endif
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
};

43 changes: 16 additions & 27 deletions paddle/fluid/operators/mkldnn/activation_mkldnn_op.cc
@@ -315,15 +315,7 @@ using ExpMKLDNNGradUseOutFunctor = MKLDNNActivationGradUseOutFunc<

namespace ops = paddle::operators;

#define REGISTER_ACTIVATION_MKLDNN_KERNEL(act_type, functor, grad_functor) \
REGISTER_OP_KERNEL(act_type, MKLDNN, ::paddle::platform::CPUPlace, \
ops::MKLDNNActivationKernel<ops::functor<float>>); \
REGISTER_OP_KERNEL( \
act_type##_grad, MKLDNN, ::paddle::platform::CPUPlace, \
ops::MKLDNNActivationGradKernel<ops::grad_functor<float>>);

#define REGISTER_ACTIVATION_MKLDNN_BF16_KERNEL(act_type, functor, \
grad_functor) \
#define REGISTER_ACTIVATION_MKLDNN_KERNEL(act_type, functor, grad_functor) \
REGISTER_OP_KERNEL( \
act_type, MKLDNN, ::paddle::platform::CPUPlace, \
ops::MKLDNNActivationKernel<ops::functor<float>>, \
@@ -339,30 +331,27 @@ namespace ops = paddle::operators;
ops::MKLDNNActivationKernel<ops::functor<float>>);

#define FOR_EACH_MKLDNN_KERNEL_FUNCTOR(__macro) \
__macro(relu6, Relu6MKLDNNFunctor, Relu6MKLDNNGradFunctor); \
__macro(leaky_relu, ReluMKLDNNFunctor, ReluMKLDNNGradFunctor); \
__macro(swish, SwishMKLDNNFunctor, SwishMKLDNNGradFunctor); \
__macro(hard_swish, HardSwishMKLDNNFunctor, HardSwishMKLDNNGradFunctor); \
__macro(tanh, TanhMKLDNNFunctor, TanhMKLDNNGradUseOutFunctor); \
__macro(abs, AbsMKLDNNFunctor, AbsMKLDNNGradFunctor); \
__macro(elu, EluMKLDNNFunctor, EluMKLDNNGradUseOutFunctor); \
__macro(exp, ExpMKLDNNFunctor, ExpMKLDNNGradUseOutFunctor);
__macro(exp, ExpMKLDNNFunctor, ExpMKLDNNGradUseOutFunctor); \
__macro(gelu, GeluMKLDNNFunctor, GeluMKLDNNGradFunctor); \
__macro(hard_swish, HardSwishMKLDNNFunctor, HardSwishMKLDNNGradFunctor); \
__macro(leaky_relu, ReluMKLDNNFunctor, ReluMKLDNNGradFunctor); \
__macro(mish, MishMKLDNNFunctor, MishMKLDNNGradFunctor); \
__macro(relu, ReluMKLDNNFunctor, ReluMKLDNNGradFunctor); \
__macro(relu6, Relu6MKLDNNFunctor, Relu6MKLDNNGradFunctor); \
__macro(sigmoid, SigmoidMKLDNNFunctor, SigmoidMKLDNNGradUseOutFunctor); \
__macro(sqrt, SqrtMKLDNNFunctor, SqrtMKLDNNGradUseOutFunctor); \
__macro(swish, SwishMKLDNNFunctor, SwishMKLDNNGradFunctor); \
__macro(tanh, TanhMKLDNNFunctor, TanhMKLDNNGradUseOutFunctor);

FOR_EACH_MKLDNN_KERNEL_FUNCTOR(REGISTER_ACTIVATION_MKLDNN_KERNEL);
REGISTER_ACTIVATION_MKLDNN_KERNEL_FWD_ONLY(round, RoundMKLDNNFunctor);

REGISTER_ACTIVATION_MKLDNN_BF16_KERNEL(relu, ReluMKLDNNFunctor,
ReluMKLDNNGradFunctor);
REGISTER_ACTIVATION_MKLDNN_BF16_KERNEL(gelu, GeluMKLDNNFunctor,
GeluMKLDNNGradFunctor);
REGISTER_ACTIVATION_MKLDNN_BF16_KERNEL(sigmoid, SigmoidMKLDNNFunctor,
SigmoidMKLDNNGradUseOutFunctor);
REGISTER_ACTIVATION_MKLDNN_BF16_KERNEL(sqrt, SqrtMKLDNNFunctor,
SqrtMKLDNNGradUseOutFunctor);
REGISTER_ACTIVATION_MKLDNN_BF16_KERNEL(mish, MishMKLDNNFunctor,
MishMKLDNNGradFunctor);
REGISTER_ACTIVATION_MKLDNN_KERNEL_FWD_ONLY(round, RoundMKLDNNFunctor);

namespace ops = paddle::operators;
REGISTER_OP_KERNEL(
softplus, MKLDNN, paddle::platform::CPUPlace,
ops::MKLDNNActivationKernel<ops::SoftplusMKLDNNFunctor<float>>);
ops::MKLDNNActivationKernel<ops::SoftplusMKLDNNFunctor<float>>,
ops::MKLDNNActivationKernel<
ops::SoftplusMKLDNNFunctor<paddle::platform::bfloat16>>);
@@ -50,11 +50,11 @@ def setUp(self):
self.dtype = np.uint16
self.init_data()
self.config()
self.set_attrs()
self.out = self.op_forward(self.x)

self.inputs = {'X': convert_float_to_uint16(self.x)}
self.outputs = {'Out': self.out}
self.set_attrs()

def calculate_grads(self):
self.dx = self.op_grad(self.out, self.x)
@@ -162,5 +162,110 @@ def op_grad(self, dout, x):
return dout * ((np.exp(x) * omega) / delta**2)


class TestMKLDNNRelu6BF16Op(MKLDNNBF16ActivationOp, TestActivation):
def config(self):
self.op_type = "relu6"

def op_forward(self, x):
return np.clip(x, 0, 6)

def op_grad(self, dout, x):
return np.where((x > 0) & (x <= 6), dout, 0)


class TestMKLDNNLeakyReluBF16Op(MKLDNNBF16ActivationOp, TestActivation):
def config(self):
self.op_type = "leaky_relu"

def op_forward(self, x):
return np.where(x > 0, x, self.alpha * x)

def op_grad(self, dout, x):
return np.where(x > 0, dout, self.alpha * dout)

def set_attrs(self):
self.alpha = 0.2
self.attrs = {"use_mkldnn": True, "alpha": self.alpha}


class TestMKLDNNSwishBF16Op(MKLDNNBF16ActivationOp, TestActivation):
def config(self):
self.op_type = "swish"

def expit(self, val):
return 1 / (1 + np.exp(-self.beta * val))

def op_forward(self, x):
return x * self.expit(x)

def op_grad(self, dout, x):
return dout * self.expit(x) * (1 + self.beta * x * (1 - self.expit(x)))

def set_attrs(self):
self.beta = 0.2
self.attrs = {"use_mkldnn": True, "beta": self.beta}


class TestMKLDNNHardSwishBF16Op(MKLDNNBF16ActivationOp, TestActivation):
def config(self):
self.op_type = "hard_swish"

def op_forward(self, x):
result = np.where(x < -3, 0, x)
return np.where(result > 3, result, result * (result + 3) / 6)

def op_grad(self, dout, x):
result = np.where(x < -3, 0, x)
return np.where(result > 3, dout, dout * (2 * x + 3) / 6)


class TestMKLDNNTanhBF16Op(MKLDNNBF16ActivationOp, TestActivation):
def config(self):
self.op_type = "tanh"

def op_forward(self, x):
return np.tanh(x)

def op_grad(self, dout, x):
return dout * (1 - np.tanh(x)**2)


class TestMKLDNNAbsBF16Op(MKLDNNBF16ActivationOp, TestActivation):
def config(self):
self.op_type = "abs"

def op_forward(self, x):
return np.absolute(x)

def op_grad(self, dout, x):
return dout * np.sign(x)


class TestMKLDNNEluBF16Op(MKLDNNBF16ActivationOp, TestActivation):
def config(self):
self.op_type = "elu"

def op_forward(self, x):
return np.where(x > 0, x, self.alpha * (np.exp(x) - 1))

def op_grad(self, dout, x):
return np.where(x > 0, dout, dout * self.alpha * np.exp(x))

def set_attrs(self):
self.alpha = 0.2
self.attrs = {"use_mkldnn": True, "alpha": self.alpha}


class TestMKLDNNExpBF16Op(MKLDNNBF16ActivationOp, TestActivation):
def config(self):
self.op_type = "exp"

def op_forward(self, x):
return np.exp(x)

def op_grad(self, dout, x):
return dout * np.exp(x)
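
The op_grad methods above hard-code the analytic derivative of each activation, with dout simply scaled by f'(x). As a standalone sanity check (not part of the test suite), any of these formulas can be verified against a central finite difference; a sketch for the swish variant with the beta used in TestMKLDNNSwishBF16Op:

import numpy as np

def swish(x, beta=0.2):
    # swish(x) = x * sigmoid(beta * x)
    return x / (1.0 + np.exp(-beta * x))

def swish_grad(x, beta=0.2):
    # analytic derivative, same formula as TestMKLDNNSwishBF16Op.op_grad
    s = 1.0 / (1.0 + np.exp(-beta * x))
    return s * (1.0 + beta * x * (1.0 - s))

x = np.linspace(-2.0, 2.0, 9)
eps = 1e-6
numeric = (swish(x + eps) - swish(x - eps)) / (2.0 * eps)
assert np.allclose(numeric, swish_grad(x), atol=1e-6)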


if __name__ == '__main__':
unittest.main()
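
The BF16 tests above set self.dtype = np.uint16 and wrap their inputs with convert_float_to_uint16 because NumPy has no native bfloat16 dtype: a bfloat16 value is just the upper 16 bits of an IEEE 754 float32, stored as uint16. A minimal sketch of that representation (an illustrative stand-in that truncates rather than rounds, not Paddle's actual helper):

import numpy as np

def float32_to_bf16_bits(x):
    # Reinterpret float32 values as uint32 and keep the upper 16 bits,
    # which is exactly the bfloat16 bit pattern (truncation, no rounding).
    x = np.asarray(x, dtype=np.float32)
    return (x.view(np.uint32) >> 16).astype(np.uint16)

def bf16_bits_to_float32(bits):
    # Shift the 16-bit pattern back into the high half of a float32.
    bits = np.asarray(bits, dtype=np.uint16)
    return (bits.astype(np.uint32) << 16).view(np.float32)

x = np.array([0.5, -1.25, 3.0], dtype=np.float32)
print(bf16_bits_to_float32(float32_to_bf16_bits(x)))  # recovers x exactly for these values

This is why the expected outputs in these tests can still be computed with ordinary float32 NumPy math while the registered kernel itself runs in bfloat16.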
@@ -16,7 +16,7 @@

import unittest
import numpy as np
from paddle.fluid.tests.unittests.op_test import OpTest, OpTestTool
from paddle.fluid.tests.unittests.op_test import OpTest, OpTestTool, convert_float_to_uint16
import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
@@ -30,23 +30,32 @@ def ref_softplus(x, beta, threshold):
return out


@OpTestTool.skip_if(not (isinstance(_current_expected_place(), core.CPUPlace)),
"GPU is not supported")
@OpTestTool.skip_if_not_cpu_bf16()
class TestSoftplusOneDNNOp(OpTest):
def setUp(self):
self.op_type = "softplus"
self.beta = 1
self.threshold = 20
self.config()
self.set_dtype()
self.attrs = {'use_mkldnn': True, 'beta': self.beta}
self.inputs = {'X': np.random.random(self.x_shape).astype(np.float32)}
self.x = np.random.random(self.x_shape)
self.out = ref_softplus(self.x, self.beta, self.threshold)

if self.dtype != np.float32:
self.x = convert_float_to_uint16(self.x)

self.inputs = {'X': self.x}
self.outputs = {'Out': self.out}

def config(self):
self.x_shape = (10, 10)

def set_dtype(self):
self.dtype = np.float32

def test_check_output(self):
self.check_output()

@@ -73,6 +82,27 @@ def config(self):
self.beta = 0.4


class TestSoftplusBF16OneDNNOp(TestSoftplusOneDNNOp):
def set_dtype(self):
self.dtype = np.uint16


class TestSoftplus4DBF16OneDNNOp(TestSoftplus4DOneDNNOp):
def set_dtype(self):
self.dtype = np.uint16


class TestSoftplus6DBF16OneDNNOp(TestSoftplus6DOneDNNOp):
def set_dtype(self):
self.dtype = np.uint16


class TestSoftplus3DExtendedFunctorBF16OneDNNOp(
TestSoftplus3DExtendedFunctorOneDNNOp):
def set_dtype(self):
self.dtype = np.uint16


if __name__ == "__main__":
paddle.enable_static()
unittest.main()
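
The body of ref_softplus is collapsed in the hunk above; given the beta and threshold values set in setUp, a plausible reference implementation is the numerically guarded softplus below (a sketch under that assumption, not the exact helper from the file):

import numpy as np

def ref_softplus(x, beta=1, threshold=20):
    # softplus(x) = (1 / beta) * log(1 + exp(beta * x)); once beta * x
    # exceeds the threshold the function is treated as linear (out = x),
    # which also avoids overflow in exp() for large inputs.
    x_beta = beta * x
    return np.where(x_beta <= threshold, np.log1p(np.exp(x_beta)) / beta, x)

With inputs drawn from np.random.random (values in [0, 1)), beta * x stays far below the threshold of 20 used here, so these tests only exercise the smooth branch of the curve.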
