Skip to content

Commit

Permalink
add new API/OP:paddle.Tensor.exponential_ (#38256)
Browse files Browse the repository at this point in the history
* add new API/OP:paddle.Tensor.exponential_

* fix CI
  • Loading branch information
zhwesky2010 committed Dec 24, 2021
1 parent c396ee6 commit 3318500
Show file tree
Hide file tree
Showing 7 changed files with 684 additions and 3 deletions.
197 changes: 197 additions & 0 deletions paddle/fluid/operators/distribution_helper.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,197 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#ifdef __NVCC__
#include <curand_kernel.h>
#endif
#ifdef __HIPCC__
#include <hiprand_kernel.h>
#endif

#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/device/gpu/gpu_info.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/for_range.h"
#include "paddle/fluid/platform/hostdevice.h"

namespace paddle {
namespace distribution {

using Tensor = framework::Tensor;

template <typename T>
struct exponential_transform {
explicit exponential_transform(T lambda) : lambda_(lambda) {}

HOSTDEVICE inline T operator()(T val) const {
#if defined(__NVCC__) || defined(__HIPCC__)
if (std::is_same<T, double>::value) {
return static_cast<T>(-1.0) / lambda_ * log(val);
} else {
return static_cast<T>(-1.0) / lambda_ * __logf(val);
}
#else
return static_cast<T>(-1.0) / lambda_ * std::log(static_cast<T>(1.0) - val);
#endif
}

private:
T lambda_;
};

#if defined(__NVCC__) || defined(__HIPCC__)
template <typename T>
struct uniform_distribution;

template <typename T>
struct normal_distribution;

#if defined(__NVCC__)
template <>
struct uniform_distribution<float> {
__device__ inline float4 operator()(curandStatePhilox4_32_10_t *state) const {
return curand_uniform4(state);
}
static constexpr int kReturnsCount = 4;
};

template <>
struct uniform_distribution<double> {
__device__ inline double2 operator()(
curandStatePhilox4_32_10_t *state) const {
return curand_uniform2_double(state);
}
static constexpr int kReturnsCount = 2;
};

template <>
struct normal_distribution<float> {
__device__ inline float4 operator()(curandStatePhilox4_32_10_t *state) const {
return curand_normal4(state);
}
static constexpr int kReturnsCount = 4;
};

template <>
struct normal_distribution<double> {
__device__ inline double2 operator()(
curandStatePhilox4_32_10_t *state) const {
return curand_normal2_double(state);
}
static constexpr int kReturnsCount = 2;
};

#else
template <>
struct uniform_distribution<float> {
__device__ inline float4 operator()(
hiprandStatePhilox4_32_10_t *state) const {
return hiprand_uniform4(state);
}
static constexpr int kReturnsCount = 4;
};

template <>
struct uniform_distribution<double> {
__device__ inline double2 operator()(
hiprandStatePhilox4_32_10_t *state) const {
return hiprand_uniform2_double(state);
}
static constexpr int kReturnsCount = 2;
};

template <>
struct normal_distribution<float> {
__device__ inline float4 operator()(
hiprandStatePhilox4_32_10_t *state) const {
return hiprand_normal4(state);
}
static constexpr int kReturnsCount = 4;
};

template <>
struct normal_distribution<double> {
__device__ inline double2 operator()(
hiprandStatePhilox4_32_10_t *state) const {
return hiprand_normal2_double(state);
}
static constexpr int kReturnsCount = 2;
};
#endif

template <typename T, typename DistOp, typename TransformOp>
__global__ void DistributionKernel(size_t size, uint64_t seed, uint64_t offset,
DistOp dist, TransformOp trans,
T *out_data) {
size_t idx = static_cast<size_t>(blockIdx.x * blockDim.x + threadIdx.x);
int32_t returns_count = DistOp::kReturnsCount;
#if defined(__NVCC__)
curandStatePhilox4_32_10_t state;
curand_init(seed, idx, offset, &state);
#else
hiprandStatePhilox4_32_10_t state;
hiprand_init(seed, idx, offset, &state);
#endif
size_t total_thread = gridDim.x * blockDim.x;
for (size_t i = idx; i < size; i += total_thread * returns_count) {
auto random_tuple = dist(&state);
for (size_t j = 0; j < returns_count; j++) {
size_t index = i + j * total_thread;
if (index < size) {
auto random = static_cast<T>((&random_tuple.x)[j]);
out_data[index] = trans(random);
}
}
}
}

template <typename T, typename DistOp, typename TransformOp>
void distribution_and_transform(const platform::CUDADeviceContext &dev_ctx,
Tensor *out, DistOp dist, TransformOp trans) {
T *out_data = out->mutable_data<T>(dev_ctx.GetPlace());
auto size = out->numel();

int64_t device_id =
BOOST_GET_CONST(platform::CUDAPlace, dev_ctx.GetPlace()).GetDeviceId();
auto gen_cuda = framework::GetDefaultCUDAGenerator(device_id);

size_t block_size = 256;
size_t expect_grid_size = (size + block_size - 1) / block_size;
const auto &prop = platform::GetDeviceProperties(device_id);
size_t max_grid_size = (prop.maxThreadsPerMultiProcessor / block_size) *
prop.multiProcessorCount;
size_t grid_size =
expect_grid_size > max_grid_size ? max_grid_size : expect_grid_size;

size_t total_thread = block_size * grid_size;
size_t curand4_loop_times =
(size + 4 * total_thread - 1) / (4 * total_thread);
// 'increment' shoulde be multiple of 4
uint64_t increment = curand4_loop_times * 4;

auto seed_offset = gen_cuda->IncrementOffset(increment);
uint64_t seed = seed_offset.first;
uint64_t offset = seed_offset.second;

DistributionKernel<
T, DistOp, TransformOp><<<grid_size, block_size, 0, dev_ctx.stream()>>>(
size, seed, offset, dist, trans, out_data);
}

#endif

} // namespace distribution
} // namespace paddle
137 changes: 137 additions & 0 deletions paddle/fluid/operators/exponential_op.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/exponential_op.h"

namespace paddle {
namespace operators {

class ExponentialOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;

void InferShape(framework::InferShapeContext *ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "ExponentialOp");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "ExponentialOp");
auto dim = ctx->GetInputDim("X");
ctx->SetOutputDim("Out", dim);
}

protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
}
};

class ExponentialOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddComment(R"DOC(
This operator fills the input tensor with random values sampled from a
exponential distribution.
)DOC");
AddInput("X", "The input tensor.");
AddOutput("Out", "The output tensor of exponential OP.");
AddAttr<float>(
"lambda", "lambd parameter of exponential distribution. [default 1.0].")
.SetDefault(1.0f);
}
};

class ExponentialOpInferVarType
: public framework::PassInDtypeAndVarTypeToOutput {
protected:
std::unordered_map<std::string, std::string> &GetInputOutputWithSameType()
const override {
static std::unordered_map<std::string, std::string> m{{"X", /*->*/ "Out"}};
return m;
}
};

template <typename T>
class ExponentialKernel<platform::CPUDeviceContext, T>
: public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
auto *out = ctx.Output<framework::Tensor>("Out");
T *out_data = out->mutable_data<T>(ctx.GetPlace());

T lambda = static_cast<T>(ctx.Attr<float>("lambda"));
int64_t size = out->numel();

auto gen = framework::DefaultCPUGenerator();
auto engine = gen->GetCPUEngine();

std::uniform_real_distribution<T> uniform(0.0, 1.0);
distribution::exponential_transform<T> trans(lambda);
for (int64_t i = 0; i < size; ++i) {
out_data[i] = trans(uniform(*engine));
}
}
};

class ExponentialGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;

void InferShape(framework::InferShapeContext *ctx) const override {
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
"Out_Grad", "ExponentialGradOp");

auto dout_dim = ctx->GetInputDim(framework::GradVarName("Out"));
ctx->SetOutputDim(framework::GradVarName("X"), dout_dim);
}
};

template <typename T>
class ExponentialGradOpMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

protected:
void Apply(GradOpPtr<T> retv) const override {
retv->SetType("exponential_grad");
retv->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
retv->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
retv->SetAttrMap(this->Attrs());
}
};

} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;

DECLARE_INPLACE_OP_INFERER(ExponentialInferer, {"X", "Out"});
DECLARE_INPLACE_OP_INFERER(ExponentialGradInferer,
{paddle::framework::GradVarName("Out"),
paddle::framework::GradVarName("X")});

REGISTER_OPERATOR(exponential, ops::ExponentialOp, ops::ExponentialOpMaker,
ops::ExponentialOpInferVarType,
ops::ExponentialGradOpMaker<paddle::framework::OpDesc>,
ops::ExponentialGradOpMaker<paddle::imperative::OpBase>,
ExponentialInferer);
REGISTER_OPERATOR(exponential_grad, ops::ExponentialGradOp,
ExponentialGradInferer);

REGISTER_OP_CPU_KERNEL(exponential,
ops::ExponentialKernel<plat::CPUDeviceContext, float>,
ops::ExponentialKernel<plat::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
exponential_grad, ops::ExponentialGradKernel<plat::CPUDeviceContext, float>,
ops::ExponentialGradKernel<plat::CPUDeviceContext, double>);
47 changes: 47 additions & 0 deletions paddle/fluid/operators/exponential_op.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/exponential_op.h"

namespace paddle {
namespace operators {

template <typename T>
class ExponentialKernel<platform::CUDADeviceContext, T>
: public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
framework::Tensor* out = ctx.Output<framework::Tensor>("Out");
auto& dev_cxt = ctx.template device_context<platform::CUDADeviceContext>();
T lambda = static_cast<T>(ctx.Attr<float>("lambda"));

distribution::uniform_distribution<T> dist;
distribution::exponential_transform<T> trans(lambda);
distribution::distribution_and_transform<T>(dev_cxt, out, dist, trans);
}
};

} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;

REGISTER_OP_CUDA_KERNEL(
exponential, ops::ExponentialKernel<plat::CUDADeviceContext, float>,
ops::ExponentialKernel<plat::CUDADeviceContext, double>);
REGISTER_OP_CUDA_KERNEL(
exponential_grad,
ops::ExponentialGradKernel<plat::CUDADeviceContext, float>,
ops::ExponentialGradKernel<plat::CUDADeviceContext, double>);
Loading

0 comments on commit 3318500

Please sign in to comment.