forked from PaddlePaddle/Paddle
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add New OP: gumbel_softmax (PaddlePaddle#35506)
* Add New Op: gumbel_softmax * Add New Op: gumbel_softmax * Add New Op: gumbel_softmax (amend) * add __main__ function in unit test * fix bugs when test in windows ci * update en docs * delete reletive error in unit test * delete relative error in unit test * set hard=True in unit test
- Loading branch information
1 parent
57ba97b
commit a690c42
Showing
6 changed files
with
852 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,123 @@ | ||
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. */ | ||
|
||
#include "paddle/fluid/operators/gumbel_softmax_op.h" | ||
#include <string> | ||
#include <unordered_map> | ||
#include "paddle/fluid/operators/common_infer_shape_functions.h" | ||
|
||
namespace paddle { | ||
namespace operators { | ||
class GumbelSoftmaxOp : public framework::OperatorWithKernel { | ||
public: | ||
using framework::OperatorWithKernel::OperatorWithKernel; | ||
|
||
void InferShape(framework::InferShapeContext* ctx) const override { | ||
return UnaryOpUnchangedInferShapeCheckAxis(ctx); | ||
} | ||
|
||
protected: | ||
framework::OpKernelType GetExpectedKernelType( | ||
const framework::ExecutionContext& ctx) const override { | ||
return framework::OpKernelType( | ||
OperatorWithKernel::IndicateVarDataType(ctx, "X"), | ||
ctx.device_context()); | ||
} | ||
}; | ||
|
||
class GumbelSoftmaxOpMaker : public framework::OpProtoAndCheckerMaker { | ||
public: | ||
void Make() override { | ||
AddInput("X", | ||
"(Tensor) An N-D Tensor, N >= 1," | ||
"The first N - 1 dimensions index into a batch of independent " | ||
"distributions " | ||
"and the last dimension represents a vector of probabilities for " | ||
"each class."); | ||
AddOutput("Out", "The sampled tensor with the same shape as X."); | ||
AddAttr<float>("temperature", | ||
"(float, default 1.0) non-negative scalar temperature.") | ||
.SetDefault(1.0); | ||
AddAttr<bool>( | ||
"hard", | ||
"(bool, default false) " | ||
"if True, the returned samples will be discretized as one-hot vectors, " | ||
"but will be differentiated as if it is the soft sample in autograd.") | ||
.SetDefault(false); | ||
AddAttr<int>("axis", | ||
"(int, default -1)" | ||
"The dimension index of Input(x) to perform gumbel_softmax.") | ||
.SetDefault(-1); | ||
AddComment(R"DOC( | ||
GumbelSoftmax Operator. | ||
Samples from the Gumbel-Softmax distribution and optionally discretizes. | ||
)DOC"); | ||
} | ||
}; | ||
|
||
class GumbelSoftmaxGradOp : public framework::OperatorWithKernel { | ||
public: | ||
using framework::OperatorWithKernel::OperatorWithKernel; | ||
|
||
void InferShape(framework::InferShapeContext* ctx) const override { | ||
OP_INOUT_CHECK(ctx->HasInput("Out"), "Input", "Out", "gumbel_softmax_grad"); | ||
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input", | ||
"Out@GRAD", "gumbel_softmax_grad"); | ||
PADDLE_ENFORCE_EQ( | ||
ctx->GetInputDim("Out"), | ||
ctx->GetInputDim(framework::GradVarName("Out")), | ||
platform::errors::InvalidArgument("Input(Out) and its gradients " | ||
"should have the same shape.")); | ||
|
||
ctx->SetOutputDim(framework::GradVarName("X"), | ||
ctx->GetInputDim(framework::GradVarName("Out"))); | ||
} | ||
}; | ||
|
||
template <typename T> | ||
class GumbelSoftmaxGradOpMaker : public framework::SingleGradOpMaker<T> { | ||
public: | ||
using framework::SingleGradOpMaker<T>::SingleGradOpMaker; | ||
|
||
protected: | ||
void Apply(GradOpPtr<T> op) const override { | ||
op->SetType("gumbel_softmax_grad"); | ||
op->SetInput("Out", this->Output("Out")); | ||
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); | ||
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); | ||
op->SetAttrMap(this->Attrs()); | ||
} | ||
}; | ||
|
||
} // namespace operators | ||
} // namespace paddle | ||
|
||
namespace ops = paddle::operators; | ||
|
||
REGISTER_OPERATOR(gumbel_softmax, ops::GumbelSoftmaxOp, | ||
ops::GumbelSoftmaxOpMaker, | ||
ops::GumbelSoftmaxGradOpMaker<paddle::framework::OpDesc>, | ||
ops::GumbelSoftmaxGradOpMaker<paddle::imperative::OpBase>); | ||
REGISTER_OPERATOR(gumbel_softmax_grad, ops::GumbelSoftmaxGradOp); | ||
|
||
REGISTER_OP_CPU_KERNEL( | ||
gumbel_softmax, | ||
ops::GumbelSoftmaxKernel<paddle::platform::CPUDeviceContext, float>, | ||
ops::GumbelSoftmaxKernel<paddle::platform::CPUDeviceContext, double>); | ||
REGISTER_OP_CPU_KERNEL( | ||
gumbel_softmax_grad, | ||
ops::GumbelSoftmaxGradKernel<paddle::platform::CPUDeviceContext, float>, | ||
ops::GumbelSoftmaxGradKernel<paddle::platform::CPUDeviceContext, double>); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,173 @@ | ||
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. */ | ||
#pragma once | ||
|
||
#include "paddle/fluid/framework/op_registry.h" | ||
#include "paddle/fluid/framework/operator.h" | ||
#include "paddle/fluid/operators/gumbel_softmax_op.h" | ||
|
||
#if defined(__NVCC__) || defined(__HIPCC__) | ||
#ifdef __NVCC__ | ||
#include "cub/cub.cuh" | ||
#endif | ||
#ifdef __HIPCC__ | ||
#include <hipcub/hipcub.hpp> | ||
namespace cub = hipcub; | ||
#endif | ||
|
||
#include <thrust/device_vector.h> | ||
#include <thrust/host_vector.h> | ||
#include <thrust/random.h> | ||
#include <thrust/transform.h> | ||
#include "paddle/fluid/framework/generator.h" | ||
#include "paddle/fluid/memory/memcpy.h" | ||
|
||
namespace paddle { | ||
namespace operators { | ||
|
||
template <typename K, typename V> | ||
using KeyValuePair = cub::KeyValuePair<K, V>; | ||
|
||
template <typename T> | ||
struct UniformCUDAGenerator { | ||
T min_, max_; | ||
unsigned int seed_; | ||
unsigned int offset_ = 0; | ||
HOSTDEVICE UniformCUDAGenerator(T min, T max, unsigned int seed) | ||
: min_(min), max_(max), seed_(seed) {} | ||
HOSTDEVICE UniformCUDAGenerator(T min, T max, unsigned int seed, | ||
unsigned int offset) | ||
: min_(min), max_(max), seed_(seed), offset_(offset) {} | ||
|
||
HOSTDEVICE T operator()(const unsigned int n) const { | ||
thrust::minstd_rand rng; | ||
rng.seed(seed_); | ||
thrust::uniform_real_distribution<T> dist(min_, max_); | ||
rng.discard(n + offset_); | ||
return dist(rng); | ||
} | ||
}; | ||
|
||
template <typename T, size_t BlockDim> | ||
__global__ void OneHotCUDAKernel(const int64_t height, const int64_t width, | ||
const int64_t size_out_axis, const T init, | ||
const T* in, T* out) { | ||
typedef cub::BlockReduce<KeyValuePair<int, T>, BlockDim> BlockReduce; | ||
__shared__ typename BlockReduce::TempStorage temp_storage; | ||
|
||
for (int64_t idx = blockIdx.x; idx < height; idx += gridDim.x) { | ||
KeyValuePair<int, T> kv_pair = {-1, init}; | ||
int h = idx / size_out_axis; | ||
int w = idx % size_out_axis; | ||
cub::ArgMax reducer; | ||
for (int k = threadIdx.x; k < width; k += blockDim.x) { | ||
kv_pair = reducer( | ||
{k, in[h * width * size_out_axis + k * size_out_axis + w]}, kv_pair); | ||
} | ||
kv_pair = BlockReduce(temp_storage).Reduce(kv_pair, reducer); | ||
if (threadIdx.x == 0) { | ||
int index = static_cast<int>(kv_pair.key); | ||
out[h * width * size_out_axis + index * size_out_axis + w] = 1; | ||
} | ||
__syncthreads(); | ||
} | ||
} | ||
|
||
template <typename T> | ||
struct OneHotGenerator<platform::CUDADeviceContext, T> { | ||
static void Transform(const platform::CUDADeviceContext& context, | ||
const Tensor& X, Tensor* Out, int axis) { | ||
const int size_to_axis = SizeToAxis(axis, X.dims()); | ||
const int size_from_axis = SizeFromAxis(axis, X.dims()); | ||
const int size_out_axis = SizeOutAxis(axis, X.dims()); | ||
constexpr int thread_size = 512; | ||
int64_t max_grid_dimx = context.GetCUDAMaxGridDimSize().x; | ||
int64_t height = size_to_axis * size_out_axis; | ||
int block_size = height < max_grid_dimx ? height : max_grid_dimx; | ||
|
||
Tensor input_tensor; | ||
input_tensor.mutable_data<T>(Out->dims(), platform::CUDAPlace()); | ||
TensorCopy(*Out, context.GetPlace(), &input_tensor); | ||
math::set_constant(context, Out, 0.0); | ||
OneHotCUDAKernel< | ||
T, thread_size><<<block_size, thread_size, 0, context.stream()>>>( | ||
height, size_from_axis / size_out_axis, size_out_axis, | ||
std::numeric_limits<T>::lowest(), input_tensor.data<T>(), | ||
Out->data<T>()); | ||
} | ||
}; | ||
|
||
template <typename T> | ||
__global__ void AddGumbelNoiseCUDAKernel(const T* input_data, T* output_data, | ||
T* noise, const float temperature, | ||
int64_t n) { | ||
int index = threadIdx.x + blockIdx.x * blockDim.x; | ||
int step = blockDim.x * gridDim.x; | ||
for (int64_t i = index; i < n; i += step) { | ||
T gumbel_noise = -log(-log(noise[i])); | ||
output_data[i] = (gumbel_noise + input_data[i]) / temperature; | ||
} | ||
} | ||
|
||
template <typename T> | ||
struct GumbleNoiseGenerator<platform::CUDADeviceContext, T> { | ||
static void Transform(const platform::CUDADeviceContext& context, | ||
const T* input_data, T* output_data, int size_to_axis, | ||
int size_from_axis, const float temperature) { | ||
Tensor random_tensor; | ||
int64_t size = size_to_axis * size_from_axis; | ||
T* random_data = | ||
random_tensor.mutable_data<T>({size}, platform::CUDAPlace()); | ||
thrust::counting_iterator<unsigned int> index_sequence_begin(0); | ||
const unsigned int seed = std::random_device()(); | ||
|
||
// generate gumbel noise | ||
int device_id = | ||
BOOST_GET_CONST(platform::CUDAPlace, context.GetPlace()).GetDeviceId(); | ||
auto gen_cuda = framework::GetDefaultCUDAGenerator(device_id); | ||
if (gen_cuda->GetIsInitPy()) { | ||
auto seed_offset = gen_cuda->IncrementOffset(1); | ||
int gen_offset = size * seed_offset.second; | ||
thrust::transform( | ||
index_sequence_begin, index_sequence_begin + size, | ||
thrust::device_ptr<T>(random_data), | ||
UniformCUDAGenerator<T>(0.00001, 1, seed_offset.first, gen_offset)); | ||
} else { | ||
thrust::transform(index_sequence_begin, index_sequence_begin + size, | ||
thrust::device_ptr<T>(random_data), | ||
UniformCUDAGenerator<T>(0.00001, 1, seed)); | ||
} | ||
|
||
// add gumbel noise to X | ||
const int thread_size = 512; | ||
int64_t block_size = (size + thread_size) / thread_size; | ||
AddGumbelNoiseCUDAKernel< | ||
T><<<block_size, thread_size, 0, context.stream()>>>( | ||
input_data, output_data, random_data, temperature, size); | ||
} | ||
}; | ||
|
||
#endif | ||
} // namespace operators | ||
} // namespace paddle | ||
|
||
namespace ops = paddle::operators; | ||
namespace plat = paddle::platform; | ||
REGISTER_OP_CUDA_KERNEL( | ||
gumbel_softmax, ops::GumbelSoftmaxKernel<plat::CUDADeviceContext, float>, | ||
ops::GumbelSoftmaxKernel<plat::CUDADeviceContext, double>); | ||
REGISTER_OP_CUDA_KERNEL( | ||
gumbel_softmax_grad, | ||
ops::GumbelSoftmaxGradKernel<plat::CUDADeviceContext, float>, | ||
ops::GumbelSoftmaxGradKernel<plat::CUDADeviceContext, double>); |
Oops, something went wrong.