Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DCU] fix compile error, test=develop #61872

Merged
merged 2 commits into from
Feb 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 4 additions & 6 deletions cmake/generic.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -763,12 +763,6 @@ function(hip_library TARGET_NAME)
cmake_parse_arguments(hip_library "${options}" "${oneValueArgs}"
"${multiValueArgs}" ${ARGN})
if(hip_library_SRCS)
# FindHIP.cmake defines hip_add_library; HIP_SOURCE_PROPERTY_FORMAT is required if no .cu files found
if(NOT (${CMAKE_CURRENT_SOURCE_DIR} MATCHES ".*/operators"
OR ${CMAKE_CURRENT_SOURCE_DIR} MATCHES ".*/phi/kernels"))
set_source_files_properties(${hip_library_SRCS}
PROPERTIES HIP_SOURCE_PROPERTY_FORMAT 1)
endif()
if(hip_library_SHARED OR hip_library_shared) # build *.so
hip_add_library(${TARGET_NAME} SHARED ${hip_library_SRCS})
else()
Expand All @@ -782,6 +776,10 @@ function(hip_library TARGET_NAME)
endif()
# cpplint code style
foreach(source_file ${hip_library_SRCS})
if(NOT ${source_file} MATCHES "\\.cu$")
set_source_files_properties(${source_file}
PROPERTIES HIP_SOURCE_PROPERTY_FORMAT 1)
endif()
string(REGEX REPLACE "\\.[^.]*$" "" source ${source_file})
if(EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${source}.h)
list(APPEND hip_library_HEADERS
Expand Down
7 changes: 5 additions & 2 deletions paddle/phi/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -134,8 +134,11 @@ if(WITH_GPU)
DEPS ${PHI_DEPS})

elseif(WITH_ROCM)
hip_add_library(phi ${PHI_BUILD_TYPE} ${PHI_SRCS})
target_link_libraries(phi ${PHI_DEPS})
hip_library(
phi ${PHI_BUILD_TYPE}
SRCS ${PHI_SRCS}
DEPS ${PHI_DEPS})

elseif(WITH_XPU_KP)
xpu_library(
phi ${PHI_BUILD_TYPE}
Expand Down
2 changes: 1 addition & 1 deletion paddle/phi/core/visit_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -355,7 +355,7 @@ namespace phi {
"`"); \
} \
}()
#if defined(PADDLE_WITH_XPU)
#if defined(PADDLE_WITH_XPU) || defined(PADDLE_WITH_HIP)
#define PD_VISIT_ALL_TYPES(TYPE, NAME, ...) \
[&] { \
const auto& __dtype__ = TYPE; \
Expand Down
26 changes: 26 additions & 0 deletions paddle/phi/kernels/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,32 @@ if(NOT WITH_CUDNN_FRONTEND)
"fusion/gpu/max_pool2d_v2_kernel.cu")
endif()

# Note(qili93): remove kernels not supported on DCU yet
if(WITH_ROCM)
list(
REMOVE_ITEM
kernel_cu
"gpu/affine_grid_grad_kernel.cu"
"gpu/apply_per_channel_scale_kernel.cu"
"gpu/cholesky_solve_kernel.cu"
"gpu/eigh_kernel.cu"
"gpu/eigvalsh_kernel.cu"
"gpu/lstsq_kernel.cu"
"gpu/lu_kernel.cu"
"gpu/matrix_rank_kernel.cu"
"gpu/matrix_rank_tol_kernel.cu"
"gpu/multiclass_nms3_kernel.cu"
"gpu/put_along_axis_grad_kernel.cu"
"gpu/put_along_axis_kernel.cu"
"gpu/qr_kernel.cu"
"gpu/svd_kernel.cu"
"gpudnn/mha_cudnn_frontend.cu"
"fusion/gpu/block_multi_head_attention_kernel.cu"
"fusion/gpu/fused_bn_add_activation_grad_kernel.cu"
"fusion/gpu/fused_bn_add_activation_kernel.cu"
"fusion/gpu/fusion_transpose_flatten_concat_kernel.cu")
endif()

set(cc_search_pattern
"*.cc"
"cpu/*.cc"
Expand Down
5 changes: 5 additions & 0 deletions paddle/phi/kernels/funcs/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,9 @@ if(WITH_GPU OR WITH_ROCM)
"*.cu")
endif()

# Note(qili93): remove kernels not supported on DCU yet
if(WITH_ROCM)
list(REMOVE_ITEM func_cu_srcs "weight_only_gemv.cu")
endif()

collect_srcs(kernels_srcs SRCS ${func_cc_srcs} ${func_cu_srcs})
9 changes: 9 additions & 0 deletions paddle/phi/kernels/gpu/binomial_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,13 @@ __device__ int64_t btrs(
const T m = std::floor((n + 1) * p);

while (1) {
#ifdef __NVCC__
U = static_cast<T>(curand_uniform(&state)) - 0.5;
V = static_cast<T>(curand_uniform(&state));
#elif __HIPCC__
U = static_cast<T>(hiprand_uniform(&state)) - 0.5;
V = static_cast<T>(hiprand_uniform(&state));
#endif

us = 0.5 - std::abs(U);
k = static_cast<int64_t>(std::floor((2 * a / us + b) * U + c));
Expand Down Expand Up @@ -118,7 +123,11 @@ __device__ int64_t binomial_inversion(
#endif

while (1) {
#ifdef __NVCC__
unif = static_cast<T>(curand_uniform(&state));
#elif __HIPCC__
unif = static_cast<T>(hiprand_uniform(&state));
#endif
T geom = std::ceil(std::log(unif) / logprob);
geom_sum += geom;
if (geom_sum > n) {
Expand Down
2 changes: 1 addition & 1 deletion paddle/phi/kernels/gpu/top_p_sampling_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,7 @@ __device__ __forceinline__ void BlockReduce(Pair<T> shared_max[],
if (*beam >= MaxLength) break;
} else {
#ifdef PADDLE_WITH_HIP
uint64 mask = 0;
unsigned mask = 0u;
mask = __ballot(true);
if (tid_max / WARP_SIZE == wid) {
if (__shfl_down(*beam, tid_max % WARP_SIZE, WARP_SIZE) == MaxLength)
Expand Down
10 changes: 10 additions & 0 deletions paddle/phi/kernels/gpu/triangular_solve_grad_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/triangular_solve_grad_kernel_impl.h"

#ifdef PADDLE_WITH_CUDA
PD_REGISTER_KERNEL(triangular_solve_grad,
GPU,
ALL_LAYOUT,
Expand All @@ -23,3 +24,12 @@ PD_REGISTER_KERNEL(triangular_solve_grad,
double,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
#else // PADDLE_WITH_HIP
// blas_impl.hip.h not support CUBlas<T>::TRSM for complex
PD_REGISTER_KERNEL(triangular_solve_grad,
GPU,
ALL_LAYOUT,
phi::TriangularSolveGradKernel,
float,
double) {}
#endif
10 changes: 10 additions & 0 deletions paddle/phi/kernels/gpu/triangular_solve_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ void TriangularSolveKernel(const Context& dev_ctx,

} // namespace phi

#ifdef PADDLE_WITH_CUDA
PD_REGISTER_KERNEL(triangular_solve,
GPU,
ALL_LAYOUT,
Expand All @@ -131,3 +132,12 @@ PD_REGISTER_KERNEL(triangular_solve,
double,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
#else // PADDLE_WITH_HIP
// blas_impl.hip.h not support CUBlas<T>::TRSM for complex
PD_REGISTER_KERNEL(triangular_solve,
GPU,
ALL_LAYOUT,
phi::TriangularSolveKernel,
float,
double) {}
#endif
5 changes: 5 additions & 0 deletions paddle/phi/kernels/gpu/unique_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,12 @@
#include <iostream>
#include <vector>

#ifdef PADDLE_WITH_CUDA
#include "cub/cub.cuh"
#else
#include <hipcub/hipcub.hpp>
namespace cub = hipcub;
#endif
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/memory_utils.h"
#include "paddle/phi/core/kernel_registry.h"
Expand Down
12 changes: 10 additions & 2 deletions paddle/phi/kernels/gpudnn/pool_gpudnn.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,21 @@ class CudnnIndexType;
template <>
class CudnnIndexType<int> {
public:
static const cudnnDataType_t type = CUDNN_DATA_INT32;
#ifdef PADDLE_WITH_CUDA
static const dnnDataType_t type = CUDNN_DATA_INT32;
#else
static const dnnDataType_t type = miopenInt32;
#endif
};

template <>
class CudnnIndexType<int8_t> {
public:
static const cudnnDataType_t type = CUDNN_DATA_INT8;
#ifdef PADDLE_WITH_CUDA
static const dnnDataType_t type = CUDNN_DATA_INT8;
#else
static const dnnDataType_t type = miopenInt8;
#endif
};

inline GPUDNNDataLayout GetLayoutFromStr(std::string data_format) {
Expand Down