
Commit

Add exception throw for norm_conv when platform is not supported (#40166)

* Add throw for norm_conv when platform is not supported

* fix format
ZzSean committed Mar 8, 2022
1 parent 73583f8 commit 00566ea
Showing 1 changed file with 36 additions and 6 deletions.
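The diff below changes the fp16 tests K1S1, K3S1, and K1S1O4 so that, on GPUs with compute capability at or below 7.0, CheckForward and CheckBackward are expected to throw paddle::platform::EnforceNotMet, while on newer devices they must complete without throwing. A minimal standalone sketch of the guard behavior the tests assert (hypothetical names throughout; this is not the actual cudnn_norm_conv implementation):

// Illustration only: mirrors the condition the updated tests check.
// FakeDeviceContext and CheckNormConvPlatform are stand-in names.
#include <stdexcept>
#include <string>

struct FakeDeviceContext {
  // Stand-in for platform::CUDADeviceContext::GetComputeCapability().
  int GetComputeCapability() const { return compute_capability; }
  int compute_capability;
};

// Compute capability <= 70 is treated as unsupported and must raise an
// exception, matching the ASSERT_THROW branches added in the tests below.
void CheckNormConvPlatform(const FakeDeviceContext &ctx) {
  if (ctx.GetComputeCapability() <= 70) {
    throw std::runtime_error(
        "cudnn_norm_conv is not supported on compute capability " +
        std::to_string(ctx.GetComputeCapability()));
  }
}

int main() {
  FakeDeviceContext volta{70};   // e.g. V100: the tests expect a throw
  FakeDeviceContext ampere{80};  // e.g. A100: the tests expect no throw
  try {
    CheckNormConvPlatform(volta);  // analogous to ASSERT_THROW(...)
  } catch (const std::runtime_error &) {
  }
  CheckNormConvPlatform(ampere);   // analogous to ASSERT_NO_THROW(...)
  return 0;
}

The test-side change simply branches on GetComputeCapability() and wraps the existing CheckForward/CheckBackward calls in ASSERT_THROW or ASSERT_NO_THROW accordingly.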
42 changes: 36 additions & 6 deletions paddle/fluid/operators/fused/cudnn_norm_conv_test.cc
@@ -405,8 +405,18 @@ TEST(CudnnNormConvFp16, K1S1) {
   CudnnNormConvolutionTester<paddle::platform::float16> test(
       batch_size, height, width, input_channels, output_channels, kernel_size,
       stride);
-  test.CheckForward(1e-3, true);
-  test.CheckBackward(1e-3, true);
+  platform::CUDADeviceContext *ctx = static_cast<platform::CUDADeviceContext *>(
+      platform::DeviceContextPool::Instance().Get(platform::CUDAPlace(0)));
+
+  if (ctx->GetComputeCapability() <= 70) {
+    ASSERT_THROW(test.CheckForward(1e-3, true),
+                 paddle::platform::EnforceNotMet);
+    ASSERT_THROW(test.CheckBackward(1e-3, true),
+                 paddle::platform::EnforceNotMet);
+  } else {
+    ASSERT_NO_THROW(test.CheckForward(1e-3, true));
+    ASSERT_NO_THROW(test.CheckBackward(1e-3, true));
+  }
 }
 
 // test for fp16, kernel = 3, output_channels = input_channels
@@ -421,8 +431,18 @@ TEST(CudnnNormConvFp16, K3S1) {
   CudnnNormConvolutionTester<paddle::platform::float16> test(
       batch_size, height, width, input_channels, output_channels, kernel_size,
       stride);
-  test.CheckForward(1e-3, true);
-  test.CheckBackward(1e-3, true);
+  platform::CUDADeviceContext *ctx = static_cast<platform::CUDADeviceContext *>(
+      platform::DeviceContextPool::Instance().Get(platform::CUDAPlace(0)));
+
+  if (ctx->GetComputeCapability() <= 70) {
+    ASSERT_THROW(test.CheckForward(1e-3, true),
+                 paddle::platform::EnforceNotMet);
+    ASSERT_THROW(test.CheckBackward(1e-3, true),
+                 paddle::platform::EnforceNotMet);
+  } else {
+    ASSERT_NO_THROW(test.CheckForward(1e-3, true));
+    ASSERT_NO_THROW(test.CheckBackward(1e-3, true));
+  }
 }
 
 // test for fp16, kernel = 1, output_channels = input_channels * 4
@@ -437,8 +457,18 @@ TEST(CudnnNormConvFp16, K1S1O4) {
   CudnnNormConvolutionTester<paddle::platform::float16> test(
       batch_size, height, width, input_channels, output_channels, kernel_size,
       stride);
-  test.CheckForward(1e-3, true);
-  test.CheckBackward(1e-3, true);
+  platform::CUDADeviceContext *ctx = static_cast<platform::CUDADeviceContext *>(
+      platform::DeviceContextPool::Instance().Get(platform::CUDAPlace(0)));
+
+  if (ctx->GetComputeCapability() <= 70) {
+    ASSERT_THROW(test.CheckForward(1e-3, true),
+                 paddle::platform::EnforceNotMet);
+    ASSERT_THROW(test.CheckBackward(1e-3, true),
+                 paddle::platform::EnforceNotMet);
+  } else {
+    ASSERT_NO_THROW(test.CheckForward(1e-3, true));
+    ASSERT_NO_THROW(test.CheckBackward(1e-3, true));
+  }
 }
 
 // test for fp16, kernel = 1, stride = 2, output_channels = input_channels * 4
