From 1eac657263677ac68133ecf4433568a559a7c5c8 Mon Sep 17 00:00:00 2001
From: ming1753
Date: Fri, 22 Jul 2022 04:17:25 +0000
Subject: [PATCH 1/4] (modified) fc support fp16

---
 paddle/phi/kernels/funcs/fc_functor.cu | 61 ++++++++------------------
 1 file changed, 18 insertions(+), 43 deletions(-)

diff --git a/paddle/phi/kernels/funcs/fc_functor.cu b/paddle/phi/kernels/funcs/fc_functor.cu
index 1f2db5583295a..d0bd7567c7d5c 100644
--- a/paddle/phi/kernels/funcs/fc_functor.cu
+++ b/paddle/phi/kernels/funcs/fc_functor.cu
@@ -36,6 +36,24 @@ struct FcTypeTraits<double> {
   typedef double4 Type;
 };
 
+#if defined(PADDLE_WITH_CUDA)
+#include <cuda_fp16.h>
+
+template <>
+struct FcTypeTraits<float16> {
+  typedef half2 Type;
+};
+#else
+struct float16_4 {
+  float16 x, y, z, w;
+};
+
+template <>
+struct FcTypeTraits<float16> {
+  typedef float16_4 Type;
+};
+#endif
+
 template <typename T, bool DoRelu>
 __global__ void bias_relu_v4(const int num, const T* bias, T* data, int K) {
   int tid = blockIdx.x * blockDim.x + threadIdx.x;
@@ -109,14 +127,6 @@ void AddReluKernel(
 }
 
 #if defined(PADDLE_WITH_CUDA)
-
-#include <cuda_fp16.h>
-
-template <>
-struct FcTypeTraits<float16> {
-  typedef half2 Type;
-};
-
 template <bool DoRelu>
 __global__ void bias_relu_v2(const int num,
                              const half2* bias,
@@ -200,46 +210,11 @@ void AddReluKernel(cudaStream_t stream,
 }
 
 #else
-
-struct float16_4 {
-  float16 x, y, z, w;
-};
-template <>
-struct FcTypeTraits<float16> {
-  typedef float16_4 Type;
-};
-
-template <bool DoRelu>
-__global__ void bias_relu_v4(const int num,
-                             const float16_4* bias,
-                             float16_4* data,
-                             int K) {
-  int tid = blockIdx.x * blockDim.x + threadIdx.x;
-  if (tid < num) {
-    int bias_idx = tid % K;
-    const float16_4 bias_ptr = bias[bias_idx];
-    const float16_4 in_ptr = data[tid];
-    float16_4 packed_val;
-    packed_val.x = in_ptr.x + bias_ptr.x;
-    packed_val.y = in_ptr.y + bias_ptr.y;
-    packed_val.z = in_ptr.z + bias_ptr.z;
-    packed_val.w = in_ptr.w + bias_ptr.w;
-    if (DoRelu) {
-      packed_val.x = fmaxf(0.f, packed_val.x);
-      packed_val.y = fmaxf(0.f, packed_val.y);
-      packed_val.z = fmaxf(0.f, packed_val.z);
-      packed_val.w = fmaxf(0.f, packed_val.w);
-    }
-    data[tid] = packed_val;
-  }
-}
-
 template <bool DoRelu, int BlockDim>
 __global__ void InplaceAddReluKernel(const int N,
                                      const float16* bias,
                                      float16* data) {
   int offset = blockIdx.x * N;
-
   for (int i = threadIdx.x; i < N; i += BlockDim) {
     float16 temp;
     temp = data[offset + i] + bias[i];

From 437d6654b9d2b41d60dfe0cd0672e60a6d3565b3 Mon Sep 17 00:00:00 2001
From: ming1753
Date: Fri, 22 Jul 2022 10:38:50 +0000
Subject: [PATCH 2/4] __CUDA_ARCH__ version

---
 paddle/phi/kernels/funcs/fc_functor.cu | 36 ++++++++++++--------------
 1 file changed, 17 insertions(+), 19 deletions(-)

diff --git a/paddle/phi/kernels/funcs/fc_functor.cu b/paddle/phi/kernels/funcs/fc_functor.cu
index d0bd7567c7d5c..5b2e456ef786c 100644
--- a/paddle/phi/kernels/funcs/fc_functor.cu
+++ b/paddle/phi/kernels/funcs/fc_functor.cu
@@ -36,24 +36,6 @@ struct FcTypeTraits<double> {
   typedef double4 Type;
 };
 
-#if defined(PADDLE_WITH_CUDA)
-#include <cuda_fp16.h>
-
-template <>
-struct FcTypeTraits<float16> {
-  typedef half2 Type;
-};
-#else
-struct float16_4 {
-  float16 x, y, z, w;
-};
-
-template <>
-struct FcTypeTraits<float16> {
-  typedef float16_4 Type;
-};
-#endif
-
 template <typename T, bool DoRelu>
 __global__ void bias_relu_v4(const int num, const T* bias, T* data, int K) {
   int tid = blockIdx.x * blockDim.x + threadIdx.x;
@@ -126,7 +108,14 @@ void AddReluKernel(
   }
 }
 
-#if defined(PADDLE_WITH_CUDA)
+#if defined(PADDLE_WITH_CUDA) && __CUDA_ARCH__ >= 530
+#include <cuda_fp16.h>
+
+template <>
+struct FcTypeTraits<float16> {
+  typedef half2 Type;
+};
+
 template <bool DoRelu>
 __global__ void bias_relu_v2(const int num,
                              const half2* bias,
@@ -210,6 +199,15 @@ void AddReluKernel(cudaStream_t stream,
 }
 
 #else
+struct float16_4 {
+  float16 x, y, z, w;
+};
+
+template <>
+struct FcTypeTraits<float16> {
+  typedef float16_4 Type;
+};
+
 template <bool DoRelu, int BlockDim>
 __global__ void InplaceAddReluKernel(const int N,
                                      const float16* bias,

From a8582e01d1b690ad872006ae623faa3d272b0a9a Mon Sep 17 00:00:00 2001
From: ming1753
Date: Fri, 22 Jul 2022 12:15:52 +0000
Subject: [PATCH 3/4] delete half

---
 paddle/phi/kernels/funcs/fc_functor.cu | 95 +++++++-------------------
 1 file changed, 23 insertions(+), 72 deletions(-)

diff --git a/paddle/phi/kernels/funcs/fc_functor.cu b/paddle/phi/kernels/funcs/fc_functor.cu
index 8650de254d32d..334e41e0c71a5 100644
--- a/paddle/phi/kernels/funcs/fc_functor.cu
+++ b/paddle/phi/kernels/funcs/fc_functor.cu
@@ -36,6 +36,14 @@ struct FcTypeTraits<double> {
   typedef double4 Type;
 };
 
+struct float16_4 {
+  float16 x, y, z, w;
+};
+template <>
+struct FcTypeTraits<float16> {
+  typedef float16_4 Type;
+};
+
 template <typename T, bool DoRelu>
 __global__ void bias_relu_v4(const int num, const T* bias, T* data, int K) {
   int tid = blockIdx.x * blockDim.x + threadIdx.x;
@@ -109,92 +117,36 @@ void AddReluKernel(
 }
 
 template <bool DoRelu>
-__global__ void bias_relu_v2(const int num,
-                             const half2* bias,
-                             half2* data,
+__global__ void bias_relu_v4(const int num,
+                             const float16_4* bias,
+                             float16_4* data,
                              int K) {
   int tid = blockIdx.x * blockDim.x + threadIdx.x;
   if (tid < num) {
     int bias_idx = tid % K;
-    const half2 bias_ptr = bias[bias_idx];
-    const half2 in_ptr = data[tid];
-    half2 packed_val = __hadd2(bias_ptr, in_ptr);
+    const float16_4 bias_ptr = bias[bias_idx];
+    const float16_4 in_ptr = data[tid];
+    float16_4 packed_val;
+    packed_val.x = in_ptr.x + bias_ptr.x;
+    packed_val.y = in_ptr.y + bias_ptr.y;
+    packed_val.z = in_ptr.z + bias_ptr.z;
+    packed_val.w = in_ptr.w + bias_ptr.w;
     if (DoRelu) {
-#if __CUDA_ARCH__ >= 800
-      packed_val = __hmax2(__half2(0, 0), packed_val);
-#else
-      packed_val = __hmul2(__hgt2(__half2(0, 0), packed_val), packed_val);
-#endif
+      packed_val.x = fmaxf(0.f, packed_val.x);
+      packed_val.y = fmaxf(0.f, packed_val.y);
+      packed_val.z = fmaxf(0.f, packed_val.z);
+      packed_val.w = fmaxf(0.f, packed_val.w);
     }
     data[tid] = packed_val;
   }
 }
 
-template <bool DoRelu, int BlockDim>
-__global__ void InplaceAddReluKernel(const int N,
-                                     const half* bias,
-                                     half* data) {
-  int offset = blockIdx.x * N;
-  for (int i = threadIdx.x; i < N; i += BlockDim) {
-    half temp;
-#if defined(__HIPCC__) || __CUDA_ARCH__ >= 350
-    temp = __ldg(data + offset + i) + __ldg(bias + i);
-#else
-    temp = data[offset + i] + bias[i];
-#endif
-    if (DoRelu) {
-#if __CUDA_ARCH__ >= 800
-      data[offset + i] = __hmax(0, temp);
-#else
-      data[offset + i] = __hmul(__hgt(temp, 0), temp);
-#endif
-    } else {
-      data[offset + i] = temp;
-    }
-  }
-}
-
-template <>
-void AddReluKernel(cudaStream_t stream,
-                   const int M,
-                   const int N,
-                   float16* Y,
-                   const float16* B,
-                   bool relu) {
-  if (N % 2 == 0) {
-    const int threads = 256;
-    const int num = M * N / 2;
-    const int blocks = (num + threads - 1) / threads;
-    typedef typename FcTypeTraits<float16>::Type trans_type;
-    auto* bias_ptr_v2 = reinterpret_cast<const trans_type*>(B);
-    auto* data_ptr_v2 = reinterpret_cast<trans_type*>(Y);
-    if (relu) {
-      bias_relu_v2<true><<<blocks, threads, 0, stream>>>(
-          num, bias_ptr_v2, data_ptr_v2, N / 2);
-    } else {
-      bias_relu_v2<false><<<blocks, threads, 0, stream>>>(
-          num, bias_ptr_v2, data_ptr_v2, N / 2);
-    }
-  } else {
-    const int threads = 256;
-    const int blocks = M;
-    auto* halfB = reinterpret_cast<const half*>(B);
-    auto* halfY = reinterpret_cast<half*>(Y);
-    if (relu) {
-      InplaceAddReluKernel<true, threads>
-          <<<blocks, threads, 0, stream>>>(N, halfB, halfY);
-    } else {
-      InplaceAddReluKernel<false, threads>
-          <<<blocks, threads, 0, stream>>>(N, halfB, halfY);
-    }
-  }
-}
-
 template <bool DoRelu, int BlockDim>
 __global__ void InplaceAddReluKernel(const int N,
                                      const float16* bias,
                                      float16* data) {
   int offset = blockIdx.x * N;
+
   for (int i = threadIdx.x; i < N; i += BlockDim) {
     float16 temp;
     temp = data[offset + i] + bias[i];
@@ -240,7 +192,6 @@ void AddReluKernel(gpuStream_t stream,
     }
   }
 }
-#endif
 
 template <typename DeviceContext, typename T>
 void FCFunctor<DeviceContext, T>::operator()(const DeviceContext& context,

From c747777a765a842c449374c6d6f0468357d537d2 Mon Sep 17 00:00:00 2001
From: ming1753
Date: Fri, 22 Jul 2022 12:22:26 +0000
Subject: [PATCH 4/4] delete half

---
 paddle/phi/kernels/funcs/fc_functor.cu | 26 +-------------------------
 1 file changed, 1 insertion(+), 25 deletions(-)

diff --git a/paddle/phi/kernels/funcs/fc_functor.cu b/paddle/phi/kernels/funcs/fc_functor.cu
index 334e41e0c71a5..d50bec2f635e7 100644
--- a/paddle/phi/kernels/funcs/fc_functor.cu
+++ b/paddle/phi/kernels/funcs/fc_functor.cu
@@ -39,6 +39,7 @@ struct FcTypeTraits<double> {
 struct float16_4 {
   float16 x, y, z, w;
 };
+
 template <>
 struct FcTypeTraits<float16> {
   typedef float16_4 Type;
@@ -116,31 +117,6 @@ void AddReluKernel(
   }
 }
 
-template <bool DoRelu>
-__global__ void bias_relu_v4(const int num,
-                             const float16_4* bias,
-                             float16_4* data,
-                             int K) {
-  int tid = blockIdx.x * blockDim.x + threadIdx.x;
-  if (tid < num) {
-    int bias_idx = tid % K;
-    const float16_4 bias_ptr = bias[bias_idx];
-    const float16_4 in_ptr = data[tid];
-    float16_4 packed_val;
-    packed_val.x = in_ptr.x + bias_ptr.x;
-    packed_val.y = in_ptr.y + bias_ptr.y;
-    packed_val.z = in_ptr.z + bias_ptr.z;
-    packed_val.w = in_ptr.w + bias_ptr.w;
-    if (DoRelu) {
-      packed_val.x = fmaxf(0.f, packed_val.x);
-      packed_val.y = fmaxf(0.f, packed_val.y);
-      packed_val.z = fmaxf(0.f, packed_val.z);
-      packed_val.w = fmaxf(0.f, packed_val.w);
-    }
-    data[tid] = packed_val;
-  }
-}
-
 template <bool DoRelu, int BlockDim>
 __global__ void InplaceAddReluKernel(const int N,
                                      const float16* bias,
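
For reference only, and not part of the patches above: a minimal, self-contained sketch of the packed half2 bias + ReLU epilogue that patch 2 guards behind __CUDA_ARCH__ >= 530. The kernel name bias_relu_half2_demo, the per-lane ReLU emulation, and the num/k parameter convention are illustrative assumptions rather than code taken from this PR.

#include <cuda_fp16.h>

// Adds a per-column bias to M x N row-major float16 data viewed as half2
// pairs and optionally applies ReLU in place. Here num = M * N / 2 and
// k = N / 2 (N assumed even), mirroring the N % 2 == 0 branch shown in the
// deleted AddReluKernel specialization above.
template <bool DoRelu>
__global__ void bias_relu_half2_demo(const int num,
                                     const half2* bias,
                                     half2* data,
                                     int k) {
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  if (tid < num) {
    half2 v = __hadd2(data[tid], bias[tid % k]);  // packed add of two halves
    if (DoRelu) {
      // Per-lane ReLU: __hgt2 yields 1.0 where v > 0 and 0.0 elsewhere, so
      // the multiply keeps positive lanes and zeroes the rest. On SM80+ a
      // single __hmax2 against zero achieves the same result.
      half2 zero = __float2half2_rn(0.f);
      v = __hmul2(__hgt2(v, zero), v);
    }
    data[tid] = v;
  }
}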