Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

"derived cudnnDevice context" #6585

Merged
merged 3 commits into from
Dec 14, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions paddle/operators/math/math_function.cu
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,13 @@ void set_constant_with_place<platform::GPUPlace>(
TensorSetConstantGPU(context, tensor, value));
}

template <>
void set_constant_with_place<platform::CudnnPlace>(
const platform::DeviceContext& context, framework::Tensor* tensor,
float value) {
set_constant_with_place<platform::GPUPlace>(context, tensor, value);
}

template struct RowwiseAdd<platform::CUDADeviceContext, float>;
template struct RowwiseAdd<platform::CUDADeviceContext, double>;
template struct ColwiseSum<platform::CUDADeviceContext, float>;
Expand Down
16 changes: 16 additions & 0 deletions paddle/platform/device_context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,22 @@ cudnnHandle_t CUDADeviceContext::cudnn_handle() const { return cudnn_handle_; }

cudaStream_t CUDADeviceContext::stream() const { return stream_; }

CudnnDeviceContext::CudnnDeviceContext(CudnnPlace place)
: CUDADeviceContext(place), place_(place) {
PADDLE_ENFORCE(dynload::cudnnCreate(&cudnn_handle_));
PADDLE_ENFORCE(dynload::cudnnSetStream(cudnn_handle_, stream()));
}

CudnnDeviceContext::~CudnnDeviceContext() {
SetDeviceId(place_.device);
Wait();
PADDLE_ENFORCE(dynload::cudnnDestroy(cudnn_handle_));
}

Place CudnnDeviceContext::GetPlace() const { return CudnnPlace(); }

cudnnHandle_t CudnnDeviceContext::cudnn_handle() const { return cudnn_handle_; }

#endif

} // namespace platform
Expand Down
16 changes: 16 additions & 0 deletions paddle/platform/device_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,22 @@ class CUDADeviceContext : public DeviceContext {
cublasHandle_t cublas_handle_;
};

class CudnnDeviceContext : public CUDADeviceContext {
public:
explicit CudnnDeviceContext(CudnnPlace place);
virtual ~CudnnDeviceContext();

/*! \brief Return place in the device context. */
Place GetPlace() const final;

/*! \brief Return cudnn handle in the device context. */
cudnnHandle_t cudnn_handle() const;

private:
cudnnHandle_t cudnn_handle_;
CudnnPlace place_;
};

#endif

} // namespace platform
Expand Down
16 changes: 16 additions & 0 deletions paddle/platform/device_context_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -46,3 +46,19 @@ TEST(Device, CUDADeviceContext) {
delete device_context;
}
}

TEST(Device, CudnnDeviceContext) {
using paddle::platform::CudnnDeviceContext;
using paddle::platform::CudnnPlace;
if (paddle::platform::dynload::HasCUDNN()) {
int count = paddle::platform::GetCUDADeviceCount();
for (int i = 0; i < count; ++i) {
CudnnDeviceContext* device_context =
new CudnnDeviceContext(CudnnPlace(i));
cudnnHandle_t cudnn_handle = device_context->cudnn_handle();
ASSERT_NE(nullptr, cudnn_handle);
ASSERT_NE(nullptr, device_context->stream());
delete device_context;
}
}
}
7 changes: 6 additions & 1 deletion paddle/platform/place.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,11 @@ struct GPUPlace {
int device;
};

struct CudnnPlace : public GPUPlace {
CudnnPlace() : GPUPlace() {}
explicit CudnnPlace(int d) : GPUPlace(d) {}
};

struct IsGPUPlace : public boost::static_visitor<bool> {
bool operator()(const CPUPlace &) const { return false; }
bool operator()(const GPUPlace &gpu) const { return true; }
Expand All @@ -52,7 +57,7 @@ struct IsGPUPlace : public boost::static_visitor<bool> {
// should be less equal than 2^(NUM_PLACE_TYPE_LIMIT_IN_BIT)
#define NUM_PLACE_TYPE_LIMIT_IN_BIT 4

typedef boost::variant<GPUPlace, CPUPlace> Place;
typedef boost::variant<CudnnPlace, GPUPlace, CPUPlace> Place;

// static check number of place types is less equal than
// 2^(NUM_PLACE_TYPE_LIMIT_IN_BIT)
Expand Down