support broadcast case for int64 index
sneaxiy committed Sep 7, 2022
1 parent 55c1543 commit b7b0100
Showing 5 changed files with 302 additions and 12 deletions.
11 changes: 6 additions & 5 deletions paddle/fluid/platform/device/gpu/cuda/cuda_helper.h
@@ -70,11 +70,12 @@ namespace platform {
*
*/

#define CUDA_KERNEL_LOOP_TYPE(i, num, index_type) \
int64_t __index__ = \
static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x; \
for (index_type i = __index__; __index__ < (num); \
__index__ += blockDim.x * gridDim.x, i = __index__)
#define CUDA_KERNEL_LOOP_TYPE(i, num, index_type) \
int64_t __index__ = \
static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x; \
int64_t __stride__ = static_cast<int64_t>(blockDim.x) * gridDim.x; \
for (index_type i = __index__; __index__ < (num); \
__index__ += __stride__, i = __index__)
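The revised macro computes the grid stride once, widened to int64_t before the multiplication, instead of recomputing blockDim.x * gridDim.x in 32-bit arithmetic in the loop header. A minimal sketch of the same grid-stride pattern in plain CUDA (illustrative kernel and names, not part of this commit):

__global__ void ScaleKernel(const float *x, float *y, int64_t n, float alpha) {
  // Index and stride are widened to int64_t before the multiply, so very
  // large launches cannot wrap around in 32-bit arithmetic.
  int64_t idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  int64_t stride = static_cast<int64_t>(blockDim.x) * gridDim.x;
  for (int64_t i = idx; i < n; i += stride) {
    y[i] = alpha * x[i];
  }
}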

class CublasHandleHolder {
public:
3 changes: 2 additions & 1 deletion paddle/fluid/platform/device/gpu/rocm/rocm_helper.h
@@ -70,8 +70,9 @@ namespace platform {
#define CUDA_KERNEL_LOOP_TYPE(i, num, index_type) \
int64_t __index__ = \
static_cast<int64_t>(hipBlockIdx_x) * hipBlockDim_x + hipThreadIdx_x; \
int64_t __stride__ = static_cast<int64_t>(hipBlockDim_x) * hipGridDim_x; \
for (index_type i = __index__; __index__ < (num); \
__index__ += hipBlockDim_x * hipGridDim_x, i = __index__)
__index__ += __stride__, i = __index__)

class CublasHandleHolder {
public:
11 changes: 6 additions & 5 deletions paddle/phi/backends/gpu/cuda/cuda_helper.h
@@ -62,11 +62,12 @@ namespace gpu {
*
*/

#define CUDA_KERNEL_LOOP_TYPE(i, num, index_type) \
int64_t __index__ = \
static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x; \
for (index_type i = __index__; __index__ < (num); \
__index__ += blockDim.x * gridDim.x, i = __index__)
#define CUDA_KERNEL_LOOP_TYPE(i, num, index_type) \
int64_t __index__ = \
static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x; \
int64_t __stride__ = static_cast<int64_t>(blockDim.x) * gridDim.x; \
for (index_type i = __index__; __index__ < (num); \
__index__ += __stride__, i = __index__)

} // namespace gpu
} // namespace backends
3 changes: 2 additions & 1 deletion paddle/phi/backends/gpu/rocm/rocm_helper.h
@@ -65,8 +65,9 @@ namespace gpu {
#define CUDA_KERNEL_LOOP_TYPE(i, num, index_type) \
int64_t __index__ = \
static_cast<int64_t>(hipBlockIdx_x) * hipBlockDim_x + hipThreadIdx_x; \
int64_t __stride__ = static_cast<int64_t>(hipBlockDim_x) * hipGridDim_x; \
for (index_type i = __index__; __index__ < (num); \
__index__ += hipBlockDim_x * hipGridDim_x, i = __index__)
__index__ += __stride__, i = __index__)

} // namespace gpu
} // namespace backends
286 changes: 286 additions & 0 deletions paddle/phi/kernels/funcs/broadcast_function.h
@@ -224,6 +224,7 @@ struct DimensionsTransform {
};

template <typename InT, typename OutT, int NumOuts = 1>

int GetVecsize(const std::vector<const DenseTensor *> &ins,
std::vector<DenseTensor *> *outs) {
int in_vec_size = 4;
@@ -468,6 +469,233 @@ void LaunchBroadcastKernel(
func);
}

#ifndef PADDLE_WITH_XPU_KP
HOSTDEVICE static int64_t ConvertSrcIdxToDstIdx(
int64_t src_idx,
const phi::Array<int64_t, phi::DDim::kMaxRank + 1> &src_strides,
const phi::Array<int64_t, phi::DDim::kMaxRank + 1> &dst_strides,
int rank) {
int64_t dst_idx = 0;
int64_t old_src_idx = src_idx;
for (int k = 0; k < rank; ++k) {
auto local_idx = src_idx / src_strides[k + 1];
src_idx -= local_idx * src_strides[k + 1];

if (dst_strides[k] != dst_strides[k + 1]) {
dst_idx += local_idx * dst_strides[k + 1];
}
}
return dst_idx;
}
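ConvertSrcIdxToDstIdx decomposes a linear index into per-dimension coordinates using one stride array and re-linearizes them with the other; where two adjacent entries of the destination strides are equal, that dimension has size 1 in the destination (a broadcast dimension) and the coordinate is dropped. In the kernel below, the "src" strides are those of the broadcast output z and the "dst" strides those of the input being read. A host-side analogue with a worked example (hypothetical names, std::vector standing in for phi::Array):

#include <cstdint>
#include <vector>

// strides have rank + 1 entries: strides[rank] == 1, strides[0] == numel.
int64_t MapOutIdxToInIdx(int64_t out_idx,
                         const std::vector<int64_t> &out_strides,
                         const std::vector<int64_t> &in_strides,
                         int rank) {
  int64_t in_idx = 0;
  for (int k = 0; k < rank; ++k) {
    int64_t coord = out_idx / out_strides[k + 1];  // coordinate along dim k
    out_idx -= coord * out_strides[k + 1];
    if (in_strides[k] != in_strides[k + 1]) {  // dim k is not broadcast away
      in_idx += coord * in_strides[k + 1];
    }
  }
  return in_idx;
}

// Worked example: output shape [3, 4] -> out_strides {12, 4, 1};
// input shape [3, 1]                  -> in_strides  {3, 1, 1}.
// Output index 7 is element (1, 3); MapOutIdxToInIdx(7, ...) returns 1,
// i.e. that output element reads the input at (1, 0).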

template <typename T, int VecSize>
HOSTDEVICE static void ReadVecDataWithInt64Index(
const T *in,
int64_t idx,
const phi::Array<int64_t, phi::DDim::kMaxRank + 1> &src_strides,
const phi::Array<int64_t, phi::DDim::kMaxRank + 1> &dst_strides,
int rank,
phi::AlignedVector<T, VecSize> *out) {
if (src_strides[0] == dst_strides[0]) {
phi::Load<T, VecSize>(in + idx, out);
} else {
#pragma unroll
for (int i = 0; i < VecSize; ++i) {
(*out)[i] =
in[ConvertSrcIdxToDstIdx(idx + i, src_strides, dst_strides, rank)];
}
}
}

template <typename InT, typename OutT, typename Functor, int VecSize>
__global__ void BinaryBroadcastKernelWithInt64Index(
const InT *x,
const InT *y,
OutT *z,
phi::Array<int64_t, phi::DDim::kMaxRank + 1> x_strides,
phi::Array<int64_t, phi::DDim::kMaxRank + 1> y_strides,
phi::Array<int64_t, phi::DDim::kMaxRank + 1> z_strides,
int rank,
Functor functor) {
int64_t numel = z_strides[0];
int64_t idx =
(static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x) * VecSize;
int64_t stride = static_cast<int64_t>(blockDim.x) * gridDim.x * VecSize;
int64_t limit = numel - VecSize;

for (; idx <= limit; idx += stride) {
phi::AlignedVector<InT, VecSize> x_vec, y_vec;
phi::AlignedVector<OutT, VecSize> z_vec;
ReadVecDataWithInt64Index<InT, VecSize>(
x, idx, z_strides, x_strides, rank, &x_vec);
ReadVecDataWithInt64Index<InT, VecSize>(
y, idx, z_strides, y_strides, rank, &y_vec);
#pragma unroll
for (int i = 0; i < VecSize; ++i) {
z_vec[i] = functor(x_vec[i], y_vec[i]);
}
phi::Store<OutT, VecSize>(z_vec, z + idx);
}

for (; idx < numel; ++idx) {
z[idx] = functor(x[idx], y[idx]);
}
}
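The kernel consumes VecSize elements per thread per iteration while a whole vector still fits (idx <= numel - VecSize), then finishes the remainder one element at a time. A single-threaded sketch of that body-plus-tail structure, ignoring the grid stride (illustrative names, not from the patch):

#include <cstdint>

template <int VecSize>
void AddWithTail(const float *x, const float *y, float *z, int64_t n) {
  int64_t i = 0;
  // Main loop: whole groups of VecSize elements (stands in for the
  // AlignedVector load, unrolled functor, and store in the kernel).
  for (; i + VecSize <= n; i += VecSize) {
    for (int j = 0; j < VecSize; ++j) {
      z[i + j] = x[i + j] + y[i + j];
    }
  }
  // Scalar tail: e.g. elements 8 and 9 when n == 10 and VecSize == 4.
  for (; i < n; ++i) {
    z[i] = x[i] + y[i];
  }
}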

template <typename InT,
typename OutT,
typename Functor,
int Arity,
int NumOuts,
int VecSize>
struct LaunchBroadcastKernelWithInt64IndexHelper {
static void Run(const KPDevice &ctx,
const std::vector<const DenseTensor *> &ins,
std::vector<DenseTensor *> *outs,
int axis,
Functor func) {
PADDLE_THROW(phi::errors::PermissionDenied(
"Unreachable code branch. This may be a bug."));
}
};

template <typename InT, typename OutT, typename Functor, int VecSize>
struct LaunchBroadcastKernelWithInt64IndexHelper<InT,
OutT,
Functor,
/*Arity=*/2,
/*NumOuts=*/1,
VecSize> {
static void Run(const KPDevice &ctx,
const std::vector<const DenseTensor *> &ins,
std::vector<DenseTensor *> *outs,
int axis,
Functor func) {
const auto *x = ins[0], *y = ins[1];
auto *z = (*outs)[0];
const auto *x_data = x->data<InT>();
const auto *y_data = y->data<InT>();
auto *z_data = ctx.template Alloc<OutT>(z);

phi::Array<int64_t, phi::DDim::kMaxRank> x_out_dims, y_out_dims,
broadcast_out_dims;
int rank;
CalculateBroadcastDims(x->dims(),
y->dims(),
axis,
&x_out_dims,
&y_out_dims,
&broadcast_out_dims,
&rank);

auto x_strides = ShapeToStride(x_out_dims, rank);
auto y_strides = ShapeToStride(y_out_dims, rank);
auto z_strides = ShapeToStride(broadcast_out_dims, rank);
int64_t numel = z_strides[0];
auto gpu_config =
phi::backends::gpu::GetGpuLaunchConfig1D(ctx, numel, VecSize);

BinaryBroadcastKernelWithInt64Index<InT, OutT, Functor, VecSize>
<<<gpu_config.block_per_grid,
gpu_config.thread_per_block,
0,
ctx.stream()>>>(x_data,
y_data,
z_data,
x_strides,
y_strides,
z_strides,
rank,
func);
}

private:
static void CalculateBroadcastDims(
const phi::DDim &x_dims,
const phi::DDim &y_dims,
int axis,
phi::Array<int64_t, phi::DDim::kMaxRank> *x_out_dims,
phi::Array<int64_t, phi::DDim::kMaxRank> *y_out_dims,
phi::Array<int64_t, phi::DDim::kMaxRank> *broadcast_out_dims,
int *length) {
int nx = x_dims.size(), ny = y_dims.size();
PADDLE_ENFORCE_GE(
axis, 0, phi::errors::InvalidArgument("Invalid axis value: %d", axis));
if (nx == ny) {
*length = nx;
for (int i = 0; i < nx; ++i) {
if (x_dims[i] != y_dims[i]) {
PADDLE_ENFORCE_EQ(
x_dims[i] == 1 || y_dims[i] == 1,
true,
phi::errors::InvalidArgument("Cannot broadcast input shape where "
"x_dims[%d] = %d, y_dims[%d] = %d.",
i,
x_dims[i],
i,
y_dims[i]));
}
(*broadcast_out_dims)[i] = std::max(x_dims[i], y_dims[i]);
(*x_out_dims)[i] = x_dims[i];
(*y_out_dims)[i] = y_dims[i];
}
} else if (nx > ny) {
*length = nx;
for (int i = nx - axis; i < ny; ++i) {
PADDLE_ENFORCE_EQ(
y_dims[i],
1,
phi::errors::InvalidArgument(
"The trailing Y.shape[%d] should be 1 but got %d.",
i,
y_dims[i]));
}

for (int i = 0; i < nx; ++i) {
if (i >= axis && i - axis < ny) {
if (x_dims[i] != y_dims[i - axis]) {
PADDLE_ENFORCE_EQ(x_dims[i] == 1 || y_dims[i - axis] == 1,
true,
phi::errors::InvalidArgument(
"Cannot broadcast input shape where "
"x_dims[%d] = %d, y_dims[%d] = %d.",
i,
x_dims[i],
i - axis,
y_dims[i - axis]));
}
(*broadcast_out_dims)[i] = std::max(x_dims[i], y_dims[i - axis]);
(*x_out_dims)[i] = x_dims[i];
(*y_out_dims)[i] = y_dims[i - axis];
} else {
(*broadcast_out_dims)[i] = x_dims[i];
(*x_out_dims)[i] = x_dims[i];
(*y_out_dims)[i] = 1;
}
}
} else {
CalculateBroadcastDims(y_dims,
x_dims,
axis,
y_out_dims,
x_out_dims,
broadcast_out_dims,
length);
}
}
}

static phi::Array<int64_t, phi::DDim::kMaxRank + 1> ShapeToStride(
const phi::Array<int64_t, phi::DDim::kMaxRank> &arr, int rank) {
phi::Array<int64_t, phi::DDim::kMaxRank + 1> strides;
strides[rank] = 1;
for (int i = rank - 1; i >= 0; --i) {
strides[i] = strides[i + 1] * arr[i];
}
return strides;
}
};
#endif
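In the nx > ny branch above, y's dimensions are anchored at position axis inside x's shape and padded with 1s elsewhere; for example, x of shape [2, 3, 4, 5] with y of shape [3, 4] and axis = 1 gives y_out_dims = [1, 3, 4, 1]. ShapeToStride then builds row-major strides with one extra leading entry, so strides[rank] == 1 and strides[0] is the total element count, which is what the kernel reads as numel from z_strides[0]; equal adjacent entries mark the size-1 dimensions that ConvertSrcIdxToDstIdx skips. A small sketch of that stride construction (illustrative, std::vector standing in for phi::Array):

#include <cstdint>
#include <vector>

std::vector<int64_t> InclusiveStrides(const std::vector<int64_t> &dims) {
  int rank = static_cast<int>(dims.size());
  std::vector<int64_t> strides(rank + 1);
  strides[rank] = 1;
  for (int i = rank - 1; i >= 0; --i) {
    strides[i] = strides[i + 1] * dims[i];
  }
  return strides;
}

// InclusiveStrides({2, 3, 4, 5}) == {120, 60, 20, 5, 1}   (broadcast output)
// InclusiveStrides({1, 3, 4, 1}) == {12, 12, 4, 1, 1}     (y padded for axis = 1)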

template <ElementwiseType ET,
typename InT,
typename OutT,
@@ -509,6 +737,64 @@ void BroadcastKernelForDifferentVecSize(
outs->size(),
NumOuts));

#ifndef PADDLE_WITH_XPU_KP
constexpr bool kEnabledInt64IndexKernel = (NumOuts == 1 && kArity == 2);
bool use_int64_index_kernel =
kEnabledInt64IndexKernel &&
(*outs)[0]->numel() >= std::numeric_limits<int32_t>::max();
use_int64_index_kernel = kEnabledInt64IndexKernel;
if (use_int64_index_kernel) {
int vec_size = GetVecsize<InT, OutT, NumOuts>(ins, outs);
switch (vec_size) {
case VecSizeL: {
LaunchBroadcastKernelWithInt64IndexHelper<InT,
OutT,
Functor,
kArity,
NumOuts,
VecSizeL>::Run(ctx,
ins,
outs,
axis,
func);
break;
}
case VecSizeM: {
LaunchBroadcastKernelWithInt64IndexHelper<InT,
OutT,
Functor,
kArity,
NumOuts,
VecSizeM>::Run(ctx,
ins,
outs,
axis,
func);
break;
}
case VecSizeS: {
LaunchBroadcastKernelWithInt64IndexHelper<InT,
OutT,
Functor,
kArity,
NumOuts,
VecSizeS>::Run(ctx,
ins,
outs,
axis,
func);
break;
}
default: {
PADDLE_THROW(phi::errors::Unimplemented(
"Unsupported vectorized size: %d!", vec_size));
break;
}
}
return;
}
#endif

// mergedim and get vec_size
const auto merge_dims = DimensionsTransform(ins, (*outs)[0]->dims(), axis);
phi::Array<kps::details::BroadcastConfig, kArity> configs;
