Skip to content

Commit

Permalink
follow comments
Browse files Browse the repository at this point in the history
  • Loading branch information
chengduoZH committed Oct 11, 2017
1 parent 3db3a10 commit c85d777
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 34 deletions.
8 changes: 5 additions & 3 deletions paddle/operators/math/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
if(WITH_GPU)
nv_library(math_function SRCS math_function.cc math_function.cu im2col.cc im2col.cu vol2col.cc vol2col.cu pooling.cc pooling.cu DEPS cblas device_context operator)
nv_library(math_function SRCS math_function.cc math_function.cu im2col.cc im2col.cu pooling.cc pooling.cu DEPS cblas device_context operator)
nv_test(math_function_test SRCS math_function_test.cc DEPS math_function tensor)
nv_library(softmax SRCS softmax.cc softmax.cu DEPS operator)
nv_library(cross_entropy SRCS cross_entropy.cc cross_entropy.cu DEPS operator)
nv_library(vol2col SRCS vol2col.cc vol2col.cu DEPS cblas device_context operator)
else()
cc_library(math_function SRCS math_function.cc im2col.cc vol2col.cc pooling.cc DEPS cblas device_context operator)
cc_library(math_function SRCS math_function.cc im2col.cc pooling.cc DEPS cblas device_context operator)
cc_test(math_function_test SRCS math_function_test.cc DEPS math_function tensor)
cc_library(softmax SRCS softmax.cc DEPS operator)
cc_library(cross_entropy SRCS cross_entropy.cc DEPS operator)
cc_library(vol2col SRCS vol2col.cc DEPS cblas device_context operator)

endif()

cc_test(im2col_test SRCS im2col_test.cc DEPS math_function tensor)
cc_test(vol2col_test SRCS vol2col_test.cc DEPS math_function tensor)
cc_test(vol2col_test SRCS vol2col_test.cc DEPS vol2col tensor)
2 changes: 1 addition & 1 deletion paddle/operators/math/vol2col.cc
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ class Vol2ColFunctor<platform::CPUPlace, T> {
((c * output_depth + d) * output_height + h) * output_width + w;
if (h_pad < 0 || h_pad >= input_height || w_pad < 0 ||
w_pad >= input_width || d_pad < 0 || d_pad >= input_depth) {
col_data[col_idx] = T(0);
col_data[col_idx] = static_cast<T>(0);
} else {
int vol_idx =
((c_in * input_depth + d_pad) * input_height + h_pad) *
Expand Down
40 changes: 10 additions & 30 deletions paddle/operators/math/vol2col_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,12 @@ void testVol2col() {
context =
new paddle::platform::CPUDeviceContext(paddle::platform::CPUPlace());
} else {
#ifndef PADDLE_ONLY_CPU
#ifdef PADDLE_WITH_CUDA
context =
new paddle::platform::CUDADeviceContext(paddle::platform::GPUPlace());
#else
PADDLE_THROW("no GPU support");
#endif // PADDLE_ONLY_CPU
#endif // PADDLE_WITH_CUDA
}

/**
Expand Down Expand Up @@ -89,6 +89,7 @@ void testVol2col() {
vol2col(*context, input, output_cfo, stride, stride, stride, padding, padding,
padding);

float vol_2_col[] = {0, 1, 1, 2, 3, 4, 4, 5, 6, 7, 7, 8, 9, 10, 10, 11};
float* out_cfo_ptr;
if (paddle::platform::is_cpu_place(*place)) {
out_cfo_ptr = output_cfo.data<float>();
Expand All @@ -97,24 +98,12 @@ void testVol2col() {
out_cfo_ptr = output_tmp.data<float>();
}

EXPECT_EQ(out_cfo_ptr[0], 0);
EXPECT_EQ(out_cfo_ptr[1], 1);
EXPECT_EQ(out_cfo_ptr[2], 1);
EXPECT_EQ(out_cfo_ptr[3], 2);
EXPECT_EQ(out_cfo_ptr[4], 3);
EXPECT_EQ(out_cfo_ptr[5], 4);
EXPECT_EQ(out_cfo_ptr[6], 4);
EXPECT_EQ(out_cfo_ptr[7], 5);
EXPECT_EQ(out_cfo_ptr[8], 6);
EXPECT_EQ(out_cfo_ptr[9], 7);
EXPECT_EQ(out_cfo_ptr[10], 7);
EXPECT_EQ(out_cfo_ptr[11], 8);
EXPECT_EQ(out_cfo_ptr[12], 9);
EXPECT_EQ(out_cfo_ptr[13], 10);
EXPECT_EQ(out_cfo_ptr[14], 10);
EXPECT_EQ(out_cfo_ptr[15], 11);
for (int i = 0; i < 16; ++i) {
EXPECT_EQ(out_cfo_ptr[i], vol_2_col[i]);
}

// Col2Vol test
float col_2_vol[] = {0, 2, 2, 3, 8, 5, 6, 14, 8, 9, 20, 11};
memset(input_ptr, 0, 12 * sizeof(float));
if (paddle::platform::is_cpu_place(*place)) {
input = input_tmp;
Expand All @@ -134,18 +123,9 @@ void testVol2col() {
in_cfo_ptr = input_tmp.data<float>();
}

EXPECT_EQ(in_cfo_ptr[0], 0);
EXPECT_EQ(in_cfo_ptr[1], 2);
EXPECT_EQ(in_cfo_ptr[2], 2);
EXPECT_EQ(in_cfo_ptr[3], 3);
EXPECT_EQ(in_cfo_ptr[4], 8);
EXPECT_EQ(in_cfo_ptr[5], 5);
EXPECT_EQ(in_cfo_ptr[6], 6);
EXPECT_EQ(in_cfo_ptr[7], 14);
EXPECT_EQ(in_cfo_ptr[8], 8);
EXPECT_EQ(in_cfo_ptr[9], 9);
EXPECT_EQ(in_cfo_ptr[10], 20);
EXPECT_EQ(in_cfo_ptr[11], 11);
for (int i = 0; i < 12; ++i) {
EXPECT_EQ(in_cfo_ptr[i], col_2_vol[i]);
}
}

TEST(math, vol2col) {
Expand Down

0 comments on commit c85d777

Please sign in to comment.