-
Notifications
You must be signed in to change notification settings - Fork 5.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Support different data type between input and output #32823
Merged
Merged
Changes from all commits
Commits
Show all changes
2 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -49,77 +49,81 @@ int GetVectorizedSizeImpl(const T *pointer) { | |
return 1; | ||
} | ||
|
||
template <typename T> | ||
template <typename InT, typename OutT> | ||
int GetVectorizedSize(const std::vector<const framework::Tensor *> &ins, | ||
const std::vector<framework::Tensor *> &outs) { | ||
int vec_size = 4; | ||
for (auto iter = ins.begin(); iter != ins.end(); ++iter) { | ||
vec_size = | ||
std::min<int>(vec_size, GetVectorizedSizeImpl((*iter)->data<T>())); | ||
std::min<int>(vec_size, GetVectorizedSizeImpl((*iter)->data<InT>())); | ||
} | ||
for (auto iter = outs.begin(); iter != outs.end(); ++iter) { | ||
vec_size = | ||
std::min<int>(vec_size, GetVectorizedSizeImpl((*iter)->data<T>())); | ||
std::min<int>(vec_size, GetVectorizedSizeImpl((*iter)->data<OutT>())); | ||
} | ||
return vec_size; | ||
} | ||
|
||
template <ElementwiseType ET, int VecSize, typename T> | ||
template <ElementwiseType ET, int VecSize, typename InT, typename OutT> | ||
struct ElementwiseDataWrapper { | ||
T *out; | ||
const T *in0; | ||
const T *in1; | ||
__device__ ElementwiseDataWrapper(T *out, const T *in0, | ||
const T *in1 = nullptr) | ||
OutT *out; | ||
const InT *in0; | ||
const InT *in1; | ||
__device__ ElementwiseDataWrapper(OutT *out, const InT *in0, | ||
const InT *in1 = nullptr) | ||
: out(out), in0(in0), in1(in1) {} | ||
|
||
using VecType = CudaAlignedVector<T, VecSize>; | ||
using InVecType = CudaAlignedVector<InT, VecSize>; | ||
using OutVecType = CudaAlignedVector<OutT, VecSize>; | ||
|
||
inline __device__ void load_vector(VecType args[], int idx) { | ||
const VecType *x_vec = reinterpret_cast<const VecType *>(in0); | ||
inline __device__ void load_vector(InVecType args[], int idx) { | ||
const InVecType *x_vec = reinterpret_cast<const InVecType *>(in0); | ||
args[0] = x_vec[idx]; | ||
if (ET == ElementwiseType::kBinary) { | ||
const VecType *y_vec = reinterpret_cast<const VecType *>(in1); | ||
const InVecType *y_vec = reinterpret_cast<const InVecType *>(in1); | ||
args[1] = y_vec[idx]; | ||
} | ||
} | ||
|
||
inline __device__ void load_scalar(T args[], int idx) { | ||
inline __device__ void load_scalar(InT args[], int idx) { | ||
args[0] = in0[idx]; | ||
if (ET == ElementwiseType::kBinary) { | ||
args[1] = in1[idx]; | ||
} | ||
} | ||
|
||
inline __device__ void store_vector(VecType res, int idx) { | ||
VecType *out_vec = reinterpret_cast<VecType *>(out); | ||
inline __device__ void store_vector(OutVecType res, int idx) { | ||
OutVecType *out_vec = reinterpret_cast<OutVecType *>(out); | ||
out_vec[idx] = res; | ||
} | ||
|
||
inline __device__ void store_scalar(T res, int idx) { out[idx] = res; } | ||
inline __device__ void store_scalar(OutT res, int idx) { out[idx] = res; } | ||
}; | ||
|
||
template <ElementwiseType ET, int VecSize, typename T, typename Functor> | ||
template <ElementwiseType ET, int VecSize, typename InT, typename OutT, | ||
typename Functor> | ||
__device__ void VectorizedKernelImpl( | ||
ElementwiseDataWrapper<ET, VecSize, T> data, Functor func, int tid) { | ||
using VecType = CudaAlignedVector<T, VecSize>; | ||
VecType ins_vec[ET]; | ||
VecType out_vec; | ||
T *ins_ptr[ET]; | ||
T *out_ptr; | ||
ElementwiseDataWrapper<ET, VecSize, InT, OutT> data, Functor func, | ||
int tid) { | ||
using InVecType = CudaAlignedVector<InT, VecSize>; | ||
using OutVecType = CudaAlignedVector<OutT, VecSize>; | ||
InVecType ins_vec[ET]; | ||
OutVecType out_vec; | ||
InT *ins_ptr[ET]; | ||
OutT *out_ptr; | ||
#pragma unroll | ||
for (int i = 0; i < ET; ++i) { | ||
ins_ptr[i] = reinterpret_cast<T *>(&(ins_vec[i])); | ||
ins_ptr[i] = reinterpret_cast<InT *>(&(ins_vec[i])); | ||
} | ||
out_ptr = reinterpret_cast<T *>(&out_vec); | ||
out_ptr = reinterpret_cast<OutT *>(&out_vec); | ||
|
||
// load | ||
data.load_vector(ins_vec, tid); | ||
|
||
// compute | ||
#pragma unroll | ||
for (int i = 0; i < VecSize; ++i) { | ||
T ins[ET]; | ||
InT ins[ET]; | ||
#pragma unroll | ||
for (int j = 0; j < ET; ++j) { | ||
ins[j] = ins_ptr[j][i]; | ||
|
@@ -131,11 +135,13 @@ __device__ void VectorizedKernelImpl( | |
data.store_vector(out_vec, tid); | ||
} | ||
|
||
template <ElementwiseType ET, int VecSize, typename T, typename Functor> | ||
__device__ void ScalarKernelImpl(ElementwiseDataWrapper<ET, VecSize, T> data, | ||
Functor func, int start, int remain) { | ||
T ins[ET]; | ||
T out; | ||
template <ElementwiseType ET, int VecSize, typename InT, typename OutT, | ||
typename Functor> | ||
__device__ void ScalarKernelImpl( | ||
ElementwiseDataWrapper<ET, VecSize, InT, OutT> data, Functor func, | ||
int start, int remain) { | ||
InT ins[ET]; | ||
OutT out; | ||
|
||
for (int i = 0; i < remain; ++i) { | ||
int idx = start + i; | ||
|
@@ -148,45 +154,47 @@ __device__ void ScalarKernelImpl(ElementwiseDataWrapper<ET, VecSize, T> data, | |
} | ||
} | ||
|
||
template <ElementwiseType ET, int VecSize, typename T, typename Functor> | ||
__global__ void VectorizedKernel(const T *__restrict__ in0, | ||
const T *__restrict__ in1, T *out, int size, | ||
Functor func) { | ||
template <ElementwiseType ET, int VecSize, typename InT, typename OutT, | ||
typename Functor> | ||
__global__ void VectorizedKernel(const InT *__restrict__ in0, | ||
const InT *__restrict__ in1, OutT *out, | ||
int size, Functor func) { | ||
int tid = blockIdx.x * blockDim.x + threadIdx.x; | ||
int remain = size - VecSize * tid; | ||
remain = remain > 0 ? remain : 0; | ||
auto data = ElementwiseDataWrapper<ET, VecSize, T>(out, in0, in1); | ||
auto data = ElementwiseDataWrapper<ET, VecSize, InT, OutT>(out, in0, in1); | ||
if (remain >= VecSize) { | ||
VectorizedKernelImpl(data, func, tid); | ||
} else { | ||
ScalarKernelImpl(data, func, tid * VecSize, remain); | ||
} | ||
} | ||
|
||
template <ElementwiseType ET, typename T, typename Functor> | ||
__global__ void ScalarKernel(const T *__restrict__ in0, | ||
const T *__restrict__ in1, T *out, int size, | ||
template <ElementwiseType ET, typename InT, typename OutT, typename Functor> | ||
__global__ void ScalarKernel(const InT *__restrict__ in0, | ||
const InT *__restrict__ in1, OutT *out, int size, | ||
Functor func) { | ||
auto data = ElementwiseDataWrapper<ET, 1, T>(out, in0, in1); | ||
auto data = ElementwiseDataWrapper<ET, 1, InT, OutT>(out, in0, in1); | ||
int tid = blockIdx.x * blockDim.x + threadIdx.x; | ||
int remain = tid < size ? 1 : 0; | ||
ScalarKernelImpl(data, func, tid, remain); | ||
} | ||
|
||
template <ElementwiseType ET, typename T, typename Functor> | ||
template <ElementwiseType ET, typename InT, typename OutT, typename Functor> | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 可以设置默认值 |
||
void LaunchElementwiseCudaKernel( | ||
const platform::CUDADeviceContext &ctx, | ||
const std::vector<const framework::Tensor *> &ins, | ||
std::vector<framework::Tensor *> *outs, Functor func) { | ||
// calculate the max vec_size for all ins and outs | ||
auto size = ins[0]->numel(); | ||
int vec_size = GetVectorizedSize<T>(ins, *outs); | ||
int vec_size = GetVectorizedSize<InT, OutT>(ins, *outs); | ||
int block_size = ELEMENTWISE_BLOCK_SIZE; | ||
int grid_size = | ||
((size + vec_size - 1) / vec_size + block_size - 1) / block_size; | ||
const T *in0 = ins[0]->data<T>(); | ||
const T *in1 = (ET == ElementwiseType::kBinary) ? ins[1]->data<T>() : nullptr; | ||
T *out = (*outs)[0]->data<T>(); | ||
const InT *in0 = ins[0]->data<InT>(); | ||
const InT *in1 = | ||
(ET == ElementwiseType::kBinary) ? ins[1]->data<InT>() : nullptr; | ||
OutT *out = (*outs)[0]->data<OutT>(); | ||
// cuda kernel | ||
auto stream = ctx.stream(); | ||
switch (vec_size) { | ||
|
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个循环的写法其实可以简写成如下: