Optimize sparse convolution (#43576)
zhangkaihuo committed Jul 26, 2022
1 parent 22342d5 commit 9841b30
Showing 27 changed files with 1,474 additions and 345 deletions.
18 changes: 9 additions & 9 deletions paddle/phi/api/yaml/sparse_api.yaml
@@ -80,14 +80,14 @@
data_type : x
backward : cast_grad

- api : conv3d
args : (Tensor x, Tensor kernel, int[] paddings, int[] dilations, int[] strides, int groups, bool subm)
output : Tensor(out), Tensor(rulebook)
- api : conv3d_coo
args : (Tensor x, Tensor kernel, int[] paddings, int[] dilations, int[] strides, int groups, bool subm, str key)
output : Tensor(out), Tensor(rulebook), Tensor(counter)
kernel :
func : conv3d_coo{sparse_coo, dense -> sparse_coo, dense}
func : conv3d_coo{sparse_coo, dense -> sparse_coo, dense, dense}
layout : x
intermediate : rulebook
backward : conv3d_grad
intermediate: rulebook, counter
backward : conv3d_coo_grad

- api : coo_to_dense
args : (Tensor x)
@@ -352,11 +352,11 @@

- api: maxpool
args : (Tensor x, int[] kernel_sizes, int[] paddings, int[] dilations, int[] strides)
output : Tensor(out), Tensor(rulebook)
output : Tensor(out), Tensor(rulebook), Tensor(counter)
kernel :
func : maxpool_coo{sparse_coo -> sparse_coo, dense}
func : maxpool_coo{sparse_coo -> sparse_coo, dense, dense}
layout : x
intermediate : rulebook
intermediate : rulebook, counter
backward : maxpool_grad

- api: mv
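For reference, the updated conv3d_coo entry maps to a kernel with the layout {sparse_coo, dense -> sparse_coo, dense, dense}: the inputs gain a string key and the outputs gain a counter tensor alongside the rulebook. A minimal C++ declaration consistent with this YAML entry (parameter order and exact types are an assumption, not copied from the commit) could look like:

    // Sketch only: a kernel declaration implied by the conv3d_coo YAML entry above.
    template <typename T, typename Context>
    void Conv3dCooKernel(const Context& dev_ctx,
                         const SparseCooTensor& x,        // sparse_coo input
                         const DenseTensor& kernel,       // dense filter weights
                         const std::vector<int>& paddings,
                         const std::vector<int>& dilations,
                         const std::vector<int>& strides,
                         const int groups,
                         const bool subm,
                         const std::string& key,          // new: cache key for the rulebook
                         SparseCooTensor* out,            // sparse_coo output
                         DenseTensor* rulebook,           // intermediate
                         DenseTensor* counter);           // new intermediate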
14 changes: 7 additions & 7 deletions paddle/phi/api/yaml/sparse_bw_api.yaml
@@ -81,12 +81,12 @@
cast_csr_grad {sparse_csr, sparse_csr -> sparse_csr}
data_type : out_grad

- backward_api : conv3d_grad
forward : conv3d (Tensor x, Tensor kernel, int[] paddings, int[] dilations, int[] strides, int groups, bool subm) -> Tensor(out@SparseCooTensor), Tensor(rulebook@DenseTensor)
args : (Tensor x, Tensor kernel, Tensor rulebook, Tensor out_grad, int[] paddings, int[] dilations, int[] strides, int groups, bool subm)
- backward_api : conv3d_coo_grad
forward : conv3d_coo (Tensor x, Tensor kernel, int[] paddings, int[] dilations, int[] strides, int groups, bool subm, str key) -> Tensor(out), Tensor(rulebook), Tensor(counter)
args : (Tensor x, Tensor kernel, Tensor out, Tensor rulebook, Tensor counter, Tensor out_grad, int[] paddings, int[] dilations, int[] strides, int groups, bool subm, str key)
output : Tensor(x_grad), Tensor(kernel_grad)
kernel :
func : conv3d_coo_grad{sparse_coo, dense, dense, sparse_coo -> sparse_coo, dense}
func : conv3d_coo_grad{sparse_coo, dense, sparse_coo, dense, dense, sparse_coo -> sparse_coo, dense}

- backward_api : coo_to_dense_grad
forward : coo_to_dense(Tensor x) -> Tensor(out)
@@ -164,11 +164,11 @@
matmul_coo_coo_grad {sparse_coo, sparse_coo, sparse_coo -> sparse_coo, sparse_coo}

- backward_api : maxpool_grad
forward : maxpool(Tensor x, int[] kernel_sizes, int[] paddings, int[] dilations, int[] strides) -> Tensor(out), Tensor(rulebook)
args : (Tensor x, Tensor rulebook, Tensor out, Tensor out_grad, int[] kernel_sizes)
forward : maxpool(Tensor x, int[] kernel_sizes, int[] paddings, int[] dilations, int[] strides) -> Tensor(out), Tensor(rulebook), Tensor(counter)
args : (Tensor x, Tensor rulebook, Tensor counter, Tensor out, Tensor out_grad, int[] kernel_sizes)
output : Tensor(x_grad)
kernel :
func : maxpool_coo_grad {sparse_coo, dense, sparse_coo, sparse_coo -> sparse_coo}
func : maxpool_coo_grad {sparse_coo, dense, dense, sparse_coo, sparse_coo -> sparse_coo}

- backward_api : multiply_grad
forward : multiply(Tensor x, Tensor y) -> Tensor(out)
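The backward entry follows the same pattern: compared with the old conv3d_grad, it additionally consumes out and counter and forwards the key. A hedged sketch of a declaration consistent with the YAML above (again not copied from the commit):

    // Sketch only: a grad-kernel declaration implied by the conv3d_coo_grad entry above.
    template <typename T, typename Context>
    void Conv3dCooGradKernel(const Context& dev_ctx,
                             const SparseCooTensor& x,
                             const DenseTensor& kernel,
                             const SparseCooTensor& out,
                             const DenseTensor& rulebook,
                             const DenseTensor& counter,
                             const SparseCooTensor& out_grad,
                             const std::vector<int>& paddings,
                             const std::vector<int>& dilations,
                             const std::vector<int>& strides,
                             const int groups,
                             const bool subm,
                             const std::string& key,
                             SparseCooTensor* x_grad,
                             DenseTensor* kernel_grad);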
50 changes: 50 additions & 0 deletions paddle/phi/core/sparse_coo_tensor.h
@@ -156,6 +156,48 @@ class SparseCooTensor : public TensorBase,
/// \brief get the dense dim
int32_t dense_dim() const;

/// \brief query table according to key
const std::pair<DenseTensor, DenseTensor>* IndicesPairs(
const std::string& key) const {
if (indices_dict_ == nullptr) {
return nullptr;
}
const auto& iter = indices_dict_->find(key);
if (iter == indices_dict_->end()) {
return nullptr;
}
return &iter->second;
}

/// \brief save (key, indices_pairs)
void SaveIndicesPairs(
const std::string& key,
const std::pair<DenseTensor, DenseTensor>& indices_pairs) {
if (indices_dict_ == nullptr) {
indices_dict_ = std::make_shared<
std::map<std::string, std::pair<DenseTensor, DenseTensor>>>();
}
auto ret = indices_dict_->insert({key, indices_pairs});
if (ret.second == false) {
ret.first->second = indices_pairs;
}
}

/// \brief get indices_dict_
const std::shared_ptr<
std::map<std::string, std::pair<DenseTensor, DenseTensor>>>&
GetIndicesDict() const {
return indices_dict_;
}

/// \brief set indices_dict_
void SetIndicesDict(
const std::shared_ptr<
std::map<std::string, std::pair<DenseTensor, DenseTensor>>>&
indices_dict) {
indices_dict_ = indices_dict;
}

private:
// save the indices of non zero elements in original dense tensor
DenseTensor non_zero_indices_;
@@ -165,6 +207,14 @@ class SparseCooTensor : public TensorBase,
bool coalesced_ = false;
// save the number of non zero elements in each batch
DDim dims_;

// For submanifold conv: SubmConv generates a rulebook and a counter,
// which can be reused by subsequent SubmConv ops with the same key.
// Refer to sparse/gpu/convolution_kernel.cu.
std::shared_ptr<std::map<std::string, std::pair<DenseTensor, DenseTensor>>>
indices_dict_ = nullptr;

/* --------------------------- */
/* example: non zero element is scalar */
/* --------------------------- */
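The new members above turn SparseCooTensor into a carrier for a per-key (rulebook, counter) cache. A short usage sketch follows; the key string and tensor arguments are assumptions for illustration:

    // Sketch: saving and querying the cached (rulebook, counter) pair on a
    // SparseCooTensor, using only the methods declared above.
    #include <string>
    #include <utility>
    #include "paddle/phi/core/sparse_coo_tensor.h"

    void CacheRulebookExample(phi::SparseCooTensor* x,
                              const phi::DenseTensor& rulebook,
                              const phi::DenseTensor& counter) {
      const std::string key = "subm_conv3d_0";  // hypothetical key
      // The first SubmConv with this key stores its rulebook/counter on the tensor.
      x->SaveIndicesPairs(key, std::make_pair(rulebook, counter));

      // A later SubmConv with the same key can reuse them and skip rulebook
      // construction entirely.
      if (const auto* pairs = x->IndicesPairs(key)) {
        const phi::DenseTensor& cached_rulebook = pairs->first;
        const phi::DenseTensor& cached_counter = pairs->second;
        (void)cached_rulebook;
        (void)cached_counter;
      }
    }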
83 changes: 83 additions & 0 deletions paddle/phi/kernels/funcs/sparse/convolution.h
@@ -15,6 +15,7 @@ limitations under the License. */
#pragma once

#include "paddle/phi/core/ddim.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"

namespace phi {
@@ -188,6 +189,88 @@ inline void PrefixSum(const T* counter, T* offsets, const int n) {
offsets[n] = offset;
}

template <typename IntT>
inline const IntT* GetRulebookPtr(const SparseCooTensor& coo,
const DenseTensor& rulebook,
const std::string& key,
int* rulebook_len) {
if (!key.empty()) {
const auto* indices_pairs = coo.IndicesPairs(key);
if (indices_pairs != nullptr) {
const DenseTensor& tmp_rulebook = indices_pairs->first;
*rulebook_len = tmp_rulebook.dims()[1];
return tmp_rulebook.data<IntT>();
}
}
*rulebook_len = rulebook.dims()[1];
return rulebook.data<IntT>();
}

inline const int* GetCounterPtr(const SparseCooTensor& coo,
const DenseTensor& counter,
const std::string& key) {
if (!key.empty()) {
const auto* indices_pairs = coo.IndicesPairs(key);
if (indices_pairs != nullptr) {
return indices_pairs->second.data<int>();
}
}
return counter.data<int>();
}

template <typename T, typename IntT, typename Context>
inline const IntT* PrepareSubm(const Context& dev_ctx,
const SparseCooTensor& x,
const std::string& key,
const DDim& out_dims,
SparseCooTensor* out,
int* counter,
int* offsets,
int* rulebook_len,
bool* need_product_rulebook) {
const auto* indices_pairs = x.IndicesPairs(key);
if (indices_pairs != nullptr) {
*need_product_rulebook = false;
const DenseTensor& rulebook = indices_pairs->first;
const int counter_size = indices_pairs->second.numel();
memcpy(
counter, indices_pairs->second.data<int>(), counter_size * sizeof(int));
out->SetIndicesDict(x.GetIndicesDict());

*rulebook_len = rulebook.dims()[1];

DenseTensor out_indices =
phi::EmptyLike<IntT>(dev_ctx, x.non_zero_indices());
DenseTensor out_values = phi::EmptyLike<T>(dev_ctx, x.non_zero_elements());
phi::Copy(
dev_ctx, x.non_zero_indices(), dev_ctx.GetPlace(), false, &out_indices);
out->SetMember(out_indices, out_values, out_dims, false);
PrefixSum<int>(counter, offsets, counter_size);
return rulebook.data<IntT>();
}
return nullptr;
}

template <typename Context>
inline void SaveToTable(const Context& dev_ctx,
const SparseCooTensor& x,
const std::string& key,
const DenseTensor& in_rulebook,
const DenseTensor& h_counter,
SparseCooTensor* out,
DenseTensor* out_rulebook,
DenseTensor* counter) {
out->SetIndicesDict(x.GetIndicesDict());
if (!key.empty()) {
out->SaveIndicesPairs(key, std::make_pair(in_rulebook, h_counter));
} else {
*out_rulebook = in_rulebook;
counter->Resize({h_counter.numel()});
int* counter_ptr = dev_ctx.template HostAlloc<int>(counter);
memcpy(counter_ptr, h_counter.data<int>(), h_counter.numel() * sizeof(int));
}
}

} // namespace sparse
} // namespace funcs
} // namespace phi
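GetRulebookPtr and GetCounterPtr give kernels a single way to obtain the rulebook and counter, whether they come from the pair cached under key or from the tensors passed explicitly. A hedged sketch of how a kernel might read them (the wrapper function itself is illustrative, not part of the commit):

    // Sketch: reading the rulebook and counter with the helpers above, preferring
    // the cached copy on `x` when a non-empty key is given.
    template <typename IntT>
    void ReadRulebookExample(const phi::SparseCooTensor& x,
                             const phi::DenseTensor& rulebook,
                             const phi::DenseTensor& counter,
                             const std::string& key) {
      int rulebook_len = 0;
      const IntT* rulebook_ptr =
          phi::funcs::sparse::GetRulebookPtr<IntT>(x, rulebook, key, &rulebook_len);
      const int* counter_ptr = phi::funcs::sparse::GetCounterPtr(x, counter, key);
      // rulebook_len now holds rulebook.dims()[1]; counter_ptr is assumed to hold
      // one count per kernel offset, based on how the counter is used above.
      (void)rulebook_ptr;
      (void)counter_ptr;
    }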
122 changes: 110 additions & 12 deletions paddle/phi/kernels/funcs/sparse/scatter.cu.h
@@ -13,6 +13,11 @@ See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/kernels/funcs/aligned_vector.h"

#define VecBytes 16

namespace phi {
namespace funcs {
@@ -28,33 +33,126 @@ namespace sparse {
* channels: the output channel size
* out: the outputs
**/
template <typename T>
template <typename T, int VecSize>
__global__ void ScatterKernel(const T* input,
const int* unique_value,
const int* out_index,
const int non_zero_num,
const int rulebook_len,
const int channels,
T* out,
const bool subm = false) {
T* out) {
int tid = threadIdx.x + blockIdx.x * blockDim.x;
for (int i = tid; i < non_zero_num * channels; i += gridDim.x * blockDim.x) {
int indices_i = i / channels;
int channels_i = i - indices_i * channels;
const int vec_channels = channels / VecSize;
using LoadT = phi::AlignedVector<T, VecSize>;
using StoreT = phi::AlignedVector<T, VecSize>;
for (int i = tid; i < non_zero_num * vec_channels;
i += gridDim.x * blockDim.x) {
int indices_i = i / vec_channels;
int channels_i = i - indices_i * vec_channels;

int start = unique_value[indices_i];
int end = indices_i == non_zero_num - 1 ? rulebook_len
: unique_value[indices_i + 1];
// max(end-start) = kernel_size
T sum = static_cast<T>(0);
if (subm) {
sum = out[indices_i * channels + channels_i];
}
StoreT sums = {static_cast<T>(0)};
for (int j = start; j < end; j++) {
const int out_feature_i = out_index[j];
sum += input[out_feature_i * channels + channels_i];
LoadT vec_in;
phi::Load<T, VecSize>(
input + out_feature_i * channels + channels_i * VecSize, &vec_in);
#pragma unroll
for (int k = 0; k < VecSize; k++) {
sums[k] += vec_in[k];
}
}
out[indices_i * channels + channels_i] = sum;
phi::Store<T, VecSize>(sums,
out + indices_i * channels + channels_i * VecSize);
}
}
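The rewritten ScatterKernel accumulates whole vectors of channels per iteration instead of single scalars: with the 16-byte VecBytes defined at the top of this file, that is 4 floats or 2 doubles per load/store. The old subm flag is gone; ScatterKernelV2 below instead starts its accumulator from the existing output values. Callers pick the vector width the same way the ScatterV2 dispatcher below does; a minimal illustration (the helper name is assumed):

    // Illustration only: choose the vector width used to instantiate the kernels.
    template <typename T>
    inline int ChooseVecSize(int channels) {
      const int vec_size = VecBytes / sizeof(T);  // 4 for float, 2 for double
      return (channels % vec_size == 0) ? vec_size : 1;
    }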

// The scatter indices have been grouped in advance:
// index_counts records the number of indices in each group,
// index_groups stores the indices belonging to each group.
template <typename T, int VecSize>
__global__ void ScatterKernelV2(const T* input,
const int* index_counts,
const int* index_groups,
const int non_zero_num,
const int kernel_size,
const int channels,
const int buffer_counts,
T* out) {
int tid = threadIdx.x + blockIdx.x * blockDim.x;
const int vec_channels = channels / VecSize;
using LoadT = phi::AlignedVector<T, VecSize>;
using StoreT = phi::AlignedVector<T, VecSize>;
for (int i = tid; i < non_zero_num * vec_channels;
i += gridDim.x * blockDim.x) {
int indices_i = i / vec_channels;
int channels_i = i - indices_i * vec_channels;

StoreT sums = {static_cast<T>(0)};
phi::Load<T, VecSize>(out + indices_i * channels + channels_i * VecSize,
&sums);
for (int it = 0; it < buffer_counts; it++) {
int len = index_counts[indices_i + it * non_zero_num];
const int group_offset = it * kernel_size * non_zero_num;
for (int j = 0; j < len; j++) {
const int out_feature_i =
index_groups[indices_i * kernel_size + j + group_offset];
LoadT vec_in;
phi::Load<T, VecSize>(
input + out_feature_i * channels + channels_i * VecSize, &vec_in);
#pragma unroll
for (int k = 0; k < VecSize; k++) {
sums[k] += vec_in[k];
}
}
}
phi::Store<T, VecSize>(sums,
out + indices_i * channels + channels_i * VecSize);
}
}

template <typename T>
void ScatterV2(const GPUContext& dev_ctx,
const T* input,
const int* index_counts,
const int* index_groups,
const int non_zero_num,
const int kernel_size,
const int channels,
const int buffer_counts,
T* output) {
const int VecSize = VecBytes / sizeof(T);
if (channels % VecSize == 0) {
auto config = phi::backends::gpu::GetGpuLaunchConfig1D(
dev_ctx, non_zero_num * channels / VecSize, 1);
ScatterKernelV2<T, VecSize><<<config.block_per_grid.x,
config.thread_per_block.x,
0,
dev_ctx.stream()>>>(input,
index_counts,
index_groups,
non_zero_num,
kernel_size,
channels,
buffer_counts,
output);
} else {
auto config = phi::backends::gpu::GetGpuLaunchConfig1D(
dev_ctx, non_zero_num * channels, 1);
ScatterKernelV2<T, 1><<<config.block_per_grid.x,
config.thread_per_block.x,
0,
dev_ctx.stream()>>>(input,
index_counts,
index_groups,
non_zero_num,
kernel_size,
channels,
buffer_counts,
output);
}
}
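ScatterV2 wraps the launch configuration and vector-width choice, so call sites only pass raw pointers plus the grouped-index metadata. A hedged call-site sketch follows; all buffer names are assumptions, and in the conv kernels the input would be the per-rulebook-entry results being scattered back into the output values:

    // Sketch of a call site for ScatterV2; buffer names are hypothetical.
    template <typename T>
    void ScatterBackExample(const phi::GPUContext& dev_ctx,
                            const T* gathered,            // [rulebook_len, channels]
                            const int* out_index_counts,  // per-output-row counts
                            const int* out_index_groups,  // grouped rulebook positions
                            int out_nnz,
                            int kernel_size,
                            int channels,
                            T* out_values) {
      phi::funcs::sparse::ScatterV2<T>(dev_ctx, gathered, out_index_counts,
                                       out_index_groups, out_nnz, kernel_size,
                                       channels, /*buffer_counts=*/1, out_values);
    }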

(Diff truncated; the rest of scatter.cu.h and the remaining changed files in this commit are not shown here.)
