diff --git a/faiss/gpu/impl/L2Norm.cu b/faiss/gpu/impl/L2Norm.cu
index e3b77c8857..e0db8e2b69 100644
--- a/faiss/gpu/impl/L2Norm.cu
+++ b/faiss/gpu/impl/L2Norm.cu
@@ -31,12 +31,7 @@ namespace gpu {
 // T: the type we are doing the math in (e.g., float, half)
 // TVec: the potentially vectorized type we are loading in (e.g.,
 // float4, half2)
-template <
-        typename T,
-        typename TVec,
-        int RowTileSize,
-        bool NormLoop,
-        bool NormSquared>
+template <typename T, typename TVec, int RowTileSize, bool NormSquared>
 __global__ void l2NormRowMajor(
         Tensor<TVec, 2, true> input,
         Tensor<float, 1, true> output) {
@@ -56,19 +51,13 @@ __global__ void l2NormRowMajor(
     if (lastRowTile) {
         // We are handling the very end of the input matrix rows
         for (idx_t row = 0; row < input.getSize(0) - rowStart; ++row) {
-            if (NormLoop) {
-                rowNorm[0] = 0;
-
-                for (idx_t col = threadIdx.x; col < input.getSize(1);
-                     col += blockDim.x) {
-                    TVec val = input[rowStart + row][col];
-                    val = Math<TVec>::mul(val, val);
-                    rowNorm[0] = rowNorm[0] + Math<TVec>::reduceAdd(val);
-                }
-            } else {
-                TVec val = input[rowStart + row][threadIdx.x];
+            rowNorm[0] = 0;
+
+            for (idx_t col = threadIdx.x; col < input.getSize(1);
+                 col += blockDim.x) {
+                TVec val = input[rowStart + row][col];
                 val = Math<TVec>::mul(val, val);
-                rowNorm[0] = Math<TVec>::reduceAdd(val);
+                rowNorm[0] = rowNorm[0] + Math<TVec>::reduceAdd(val);
             }
 
             rowNorm[0] = warpReduceAllSum(rowNorm[0]);
@@ -79,42 +68,18 @@ __global__ void l2NormRowMajor(
     } else {
         // We are guaranteed that all RowTileSize rows are available in
         // [rowStart, rowStart + RowTileSize)
-
-        if (NormLoop) {
-            // A single block of threads is not big enough to span each
-            // vector
-            TVec tmp[RowTileSize];
-
-#pragma unroll
-            for (int row = 0; row < RowTileSize; ++row) {
-                rowNorm[row] = 0;
-            }
-
-            for (idx_t col = threadIdx.x; col < input.getSize(1);
-                 col += blockDim.x) {
-#pragma unroll
-                for (int row = 0; row < RowTileSize; ++row) {
-                    tmp[row] = input[rowStart + row][col];
-                }
-
-#pragma unroll
-                for (int row = 0; row < RowTileSize; ++row) {
-                    tmp[row] = Math<TVec>::mul(tmp[row], tmp[row]);
-                }
+        TVec tmp[RowTileSize];
 
 #pragma unroll
-                for (int row = 0; row < RowTileSize; ++row) {
-                    rowNorm[row] =
-                            rowNorm[row] + Math<TVec>::reduceAdd(tmp[row]);
-                }
-            }
-        } else {
-            TVec tmp[RowTileSize];
+        for (int row = 0; row < RowTileSize; ++row) {
+            rowNorm[row] = 0;
+        }
 
-            // A block of threads is the exact size of the vector
+        for (idx_t col = threadIdx.x; col < input.getSize(1);
+             col += blockDim.x) {
 #pragma unroll
             for (int row = 0; row < RowTileSize; ++row) {
-                tmp[row] = input[rowStart + row][threadIdx.x];
+                tmp[row] = input[rowStart + row][col];
             }
 
 #pragma unroll
@@ -124,7 +89,7 @@ __global__ void l2NormRowMajor(
 
 #pragma unroll
             for (int row = 0; row < RowTileSize; ++row) {
-                rowNorm[row] = Math<TVec>::reduceAdd(tmp[row]);
+                rowNorm[row] = rowNorm[row] + Math<TVec>::reduceAdd(tmp[row]);
             }
         }
     }
@@ -161,7 +126,7 @@ __global__ void l2NormRowMajor(
         if (laneId == 0) {
 #pragma unroll
             for (int row = 0; row < RowTileSize; ++row) {
-                int outCol = rowStart + row;
+                idx_t outCol = rowStart + row;
 
                 if (lastRowTile) {
                     if (outCol < output.getSize(0)) {
@@ -218,25 +183,15 @@ void runL2Norm(
     idx_t maxThreads = (idx_t)getMaxThreadsCurrentDevice();
     constexpr int rowTileSize = 8;
 
-#define RUN_L2_ROW_MAJOR(TYPE_T, TYPE_TVEC, INPUT)                           \
-    do {                                                                     \
-        if (normLoop) {                                                      \
-            if (normSquared) {                                               \
-                l2NormRowMajor<TYPE_T, TYPE_TVEC, rowTileSize, true, true>   \
-                        <<<grid, block, smem, stream>>>(INPUT, output);      \
-            } else {                                                         \
-                l2NormRowMajor<TYPE_T, TYPE_TVEC, rowTileSize, true, false>  \
-                        <<<grid, block, smem, stream>>>(INPUT, output);      \
-            }                                                                \
-        } else {                                                             \
-            if (normSquared) {                                               \
-                l2NormRowMajor<TYPE_T, TYPE_TVEC, rowTileSize, false, true>  \
-                        <<<grid, block, smem, stream>>>(INPUT, output);      \
-            } else {                                                         \
-                l2NormRowMajor<TYPE_T, TYPE_TVEC, rowTileSize, false, false> \
-                        <<<grid, block, smem, stream>>>(INPUT, output);      \
-            }                                                                \
-        }                                                                    \
+#define RUN_L2_ROW_MAJOR(TYPE_T, TYPE_TVEC, INPUT)                  \
+    do {                                                            \
+        if (normSquared) {                                          \
+            l2NormRowMajor<TYPE_T, TYPE_TVEC, rowTileSize, true>    \
+                    <<<grid, block, smem, stream>>>(INPUT, output); \
+        } else {                                                    \
+            l2NormRowMajor<TYPE_T, TYPE_TVEC, rowTileSize, false>   \
+                    <<<grid, block, smem, stream>>>(INPUT, output); \
+        }                                                           \
     } while (0)
 
     if (inputRowMajor) {
@@ -247,10 +202,11 @@ void runL2Norm(
         if (input.template canCastResize<TVec>()) {
             // Can load using the vectorized type
             auto inputV = input.template castResize<TVec>();
 
-            auto dim = inputV.getSize(1);
-            bool normLoop = dim > maxThreads;
-            auto numThreads = std::min(dim, maxThreads);
+            auto dim = inputV.getSize(1);
+
+            // We must always have full warps present
+            auto numThreads =
+                    std::min(utils::roundUp(dim, kWarpSize), maxThreads);
 
             auto grid = dim3(utils::divUp(inputV.getSize(0), rowTileSize));
             auto block = dim3(numThreads);
@@ -261,10 +217,11 @@ void runL2Norm(
             RUN_L2_ROW_MAJOR(T, TVec, inputV);
         } else {
             // Can't load using the vectorized type
-            auto dim = input.getSize(1);
-            bool normLoop = dim > maxThreads;
-            auto numThreads = std::min(dim, maxThreads);
+            auto dim = input.getSize(1);
+
+            // We must always have full warps present
+            auto numThreads =
+                    std::min(utils::roundUp(dim, kWarpSize), maxThreads);
 
             auto grid = dim3(utils::divUp(input.getSize(0), rowTileSize));
             auto block = dim3(numThreads);
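For context on the two ideas this patch combines: the deleted `NormLoop=false` specialization assumed exactly one thread per column, while the surviving path strides each thread across the row, so a single kernel now handles any `dim`; and because `warpReduceAllSum` is a shuffle-based full-warp reduction, the launcher now rounds the block size up to a multiple of `kWarpSize` instead of capping it at `dim`, so no partial warp ever enters the reduction. The sketch below is a minimal standalone illustration of that pattern in plain CUDA, not FAISS code: `rowL2Norm`, `warpReduceSum`, and `launchRowL2Norm` are hypothetical names, it reduces one row per block rather than tiling `rowTileSize = 8` rows, and it omits FAISS's `Tensor`/`Math<TVec>` vectorized-load machinery.

```cuda
#include <cuda_runtime.h>
#include <algorithm>

constexpr int kWarpSize = 32;

// Shuffle-based sum across one warp. Every lane must be active, which the
// launcher guarantees by rounding blockDim.x up to a warp multiple.
__device__ float warpReduceSum(float v) {
    for (int offset = kWarpSize / 2; offset > 0; offset /= 2) {
        v += __shfl_down_sync(0xffffffffu, v, offset);
    }
    return v;
}

template <bool NormSquared>
__global__ void rowL2Norm(const float* in, float* out, int dim) {
    __shared__ float warpSums[32]; // per-warp partials; blockDim.x <= 1024

    const float* row = in + (size_t)blockIdx.x * dim;

    // Strided column loop: one code path for any dim, which is what lets
    // the NormLoop specialization be dropped. Threads whose first column
    // lies past dim simply contribute 0.
    float sum = 0.0f;
    for (int col = threadIdx.x; col < dim; col += blockDim.x) {
        float v = row[col];
        sum += v * v;
    }

    sum = warpReduceSum(sum);

    int lane = threadIdx.x % kWarpSize;
    int warp = threadIdx.x / kWarpSize;
    if (lane == 0) {
        warpSums[warp] = sum;
    }
    __syncthreads();

    // The first warp folds the per-warp partials into the final norm.
    if (warp == 0) {
        int numWarps = blockDim.x / kWarpSize; // exact: block is warp-aligned
        sum = lane < numWarps ? warpSums[lane] : 0.0f;
        sum = warpReduceSum(sum);
        if (lane == 0) {
            out[blockIdx.x] = NormSquared ? sum : sqrtf(sum);
        }
    }
}

void launchRowL2Norm(
        const float* in,
        float* out,
        int numRows,
        int dim,
        bool normSquared,
        cudaStream_t stream) {
    // Mirror of the patch's thread-count change: round up to a full warp
    // (so the shuffle reduction sees no inactive lanes), then cap at the
    // device limit rather than at dim. Assumes dim >= 1.
    int threads =
            std::min(((dim + kWarpSize - 1) / kWarpSize) * kWarpSize, 1024);
    if (normSquared) {
        rowL2Norm<true><<<numRows, threads, 0, stream>>>(in, out, dim);
    } else {
        rowL2Norm<false><<<numRows, threads, 0, stream>>>(in, out, dim);
    }
}
```

The correctness point is the rounding: `__shfl_down_sync` with a full mask is undefined if some lanes of the warp never launched, so the old `std::min(dim, maxThreads)` block size would hand `warpReduceAllSum` a partial trailing warp whenever `dim` is not a multiple of 32. Rounding up instead costs only a few extra lanes whose strided loops contribute zero.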