diff --git a/paddle/fluid/operators/fused/cudnn_norm_conv_test.cc b/paddle/fluid/operators/fused/cudnn_norm_conv_test.cc index b3792a176fabe..a80f590aa495d 100644 --- a/paddle/fluid/operators/fused/cudnn_norm_conv_test.cc +++ b/paddle/fluid/operators/fused/cudnn_norm_conv_test.cc @@ -405,8 +405,18 @@ TEST(CudnnNormConvFp16, K1S1) { CudnnNormConvolutionTester test( batch_size, height, width, input_channels, output_channels, kernel_size, stride); - test.CheckForward(1e-3, true); - test.CheckBackward(1e-3, true); + platform::CUDADeviceContext *ctx = static_cast( + platform::DeviceContextPool::Instance().Get(platform::CUDAPlace(0))); + + if (ctx->GetComputeCapability() <= 70) { + ASSERT_THROW(test.CheckForward(1e-3, true), + paddle::platform::EnforceNotMet); + ASSERT_THROW(test.CheckBackward(1e-3, true), + paddle::platform::EnforceNotMet); + } else { + ASSERT_NO_THROW(test.CheckForward(1e-3, true)); + ASSERT_NO_THROW(test.CheckBackward(1e-3, true)); + } } // test for fp16, kernel = 3, output_channels = input_channels @@ -421,8 +431,18 @@ TEST(CudnnNormConvFp16, K3S1) { CudnnNormConvolutionTester test( batch_size, height, width, input_channels, output_channels, kernel_size, stride); - test.CheckForward(1e-3, true); - test.CheckBackward(1e-3, true); + platform::CUDADeviceContext *ctx = static_cast( + platform::DeviceContextPool::Instance().Get(platform::CUDAPlace(0))); + + if (ctx->GetComputeCapability() <= 70) { + ASSERT_THROW(test.CheckForward(1e-3, true), + paddle::platform::EnforceNotMet); + ASSERT_THROW(test.CheckBackward(1e-3, true), + paddle::platform::EnforceNotMet); + } else { + ASSERT_NO_THROW(test.CheckForward(1e-3, true)); + ASSERT_NO_THROW(test.CheckBackward(1e-3, true)); + } } // test for fp16, kernel = 1, output_channels = input_channels * 4 @@ -437,8 +457,18 @@ TEST(CudnnNormConvFp16, K1S1O4) { CudnnNormConvolutionTester test( batch_size, height, width, input_channels, output_channels, kernel_size, stride); - test.CheckForward(1e-3, true); - test.CheckBackward(1e-3, true); + platform::CUDADeviceContext *ctx = static_cast( + platform::DeviceContextPool::Instance().Get(platform::CUDAPlace(0))); + + if (ctx->GetComputeCapability() <= 70) { + ASSERT_THROW(test.CheckForward(1e-3, true), + paddle::platform::EnforceNotMet); + ASSERT_THROW(test.CheckBackward(1e-3, true), + paddle::platform::EnforceNotMet); + } else { + ASSERT_NO_THROW(test.CheckForward(1e-3, true)); + ASSERT_NO_THROW(test.CheckBackward(1e-3, true)); + } } // test for fp16, kernel = 1, stride = 2, output_channels = input_channels * 4