Faiss GPU CUDA 12 fix: warp synchronous behavior
Summary:
This diff fixes the bug associated with moving Faiss GPU to CUDA 12.

The following tests were succeeding in CUDA 11.x but failed in CUDA 12:

```
  ✗ faiss/gpu/test:test_gpu_basics_py - test_input_types (faiss.gpu.test.test_gpu_basics.TestKnn)
  ✗ faiss/gpu/test:test_gpu_basics_py - test_dist (faiss.gpu.test.test_gpu_basics.TestAllPairwiseDistance)
  ✗ faiss/gpu/test:test_gpu_index_ivfpq - TestGpuIndexIVFPQ.Add_L2
  ✗ faiss/gpu/test:test_gpu_basics_py - test_input_types_tiling (faiss.gpu.test.test_gpu_basics.TestKnn)
  ✗ faiss/gpu/test:test_gpu_index_ivfpq - TestGpuIndexIVFPQ.Add_IP
  ✗ faiss/gpu/test:test_gpu_index_ivfpq - TestGpuIndexIVFPQ.Float16Coarse
  ✗ faiss/gpu/test:test_gpu_index_ivfpq - TestGpuIndexIVFPQ.LargeBatch
```

It took a long while to track down, but the issue presented itself when a vector dimensionality not divisible by 32 was used in cases where we needed to calculate the L2 norm of vectors, which occurs in brute-force L2 distance computation as well as in certain L2 IVFPQ operations. It appeared here because some of the failing tests use 33 as the dimensionality of their vectors.

The issue is that the number of threads given to the L2 norm kernel was effectively `min(dims, 1024)`, where 1024 is the standard maximum number of CUDA threads per CTA on all current devices. Whenever that result was not a multiple of 32 (the warp size), the kernel was launched with a partial final warp (the missing lanes have no side effects, but they also never execute the kernel).
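
As an illustrative sketch (simplified host-side arithmetic; `dims` here is a stand-in, not the exact Faiss code):

```cpp
// Old sizing (sketch): one thread per vector dimension, capped at the
// usual 1024-threads-per-CTA limit.
int numThreads = std::min(dims, 1024);

// dims = 33 -> 33 threads: warp 0 has all 32 lanes active, but warp 1
//              is a partial warp with a single lane (threadIdx.x == 32).
// dims = 64 -> 64 threads: two full warps, so the bug stays hidden.
```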

The change in CUDA 12 appears to be in the compiler's behavior for warp-synchronous shuffle instructions (such as `__shfl_up_sync`). In the case of the partial warp, we were passing `0xffffffff` as the active lane mask, implying that all 32 lanes were present for the warp. For dims = 33, we would have 1 full warp with all lanes present and 1 partial warp with only 1 active thread, so `0xffffffff` is a lie in that case. Prior to CUDA 12, it appeared that these shuffle instructions passed around benign values (0? or would it stall?) for lanes not present, so the result was still calculated correctly. With the change to CUDA 12, the compiler and/or device firmware (or something) interprets the mask differently, and the warp lanes not present provide garbage. The shuffle instructions were used to perform in-warp reductions (e.g., summing a bunch of floating point numbers), namely those needed to sum up the L2 vector norm value. So for dims = 32 or dims = 64 (and, bizarrely, dims = 40 and some other choices) it still worked, but for dims = 33 it was adding in garbage, producing erroneous results.
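
For reference, the reduction pattern at issue looks roughly like this (a generic warp-sum sketch using `__shfl_down_sync`; Faiss's actual `warpReduceAllSum` may differ in detail):

```cuda
// Warp-wide sum via shuffles. The 0xffffffff mask promises the hardware
// that all 32 lanes of the warp execute every shuffle.
__device__ float warpReduceSum(float val) {
    for (int offset = 16; offset > 0; offset /= 2) {
        // If the warp is partial, lanes named in the mask never execute
        // this instruction, and the values shuffled in from them are
        // undefined: benign under CUDA 11.x, garbage under CUDA 12.
        val += __shfl_down_sync(0xffffffff, val, offset);
    }
    return val; // lane 0 now holds the sum across the warp
}
```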

This diff makes two changes. First, it removes the `NormLoop` specialization from `runL2Norm` (which statically avoided the for loop over dimensions when the threadblock was exactly sized to the number of dimensions present); we now always use the general-purpose looping fallback. Second, we now always provide a whole number of warps when running the L2 norm kernel, avoiding the issue of the warp-synchronous instructions not having a full warp present.
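
A minimal sketch of why this is safe, reusing the `warpReduceSum` sketch above (a standalone illustration, not the Faiss kernel):

```cuda
__global__ void l2NormSketch(const float* vec, int dim, float* out) {
    float sum = 0.0f;
    // Padding lanes (threadIdx.x >= dim) skip this guarded loop, so
    // their partial sum stays 0 -- but they still execute the
    // warp-synchronous reduction below, making 0xffffffff truthful.
    for (int col = threadIdx.x; col < dim; col += blockDim.x) {
        sum += vec[col] * vec[col];
    }
    sum = warpReduceSum(sum); // lane 0 of each warp holds its warp's sum
    if ((threadIdx.x % 32) == 0) {
        atomicAdd(out, sum); // combine per-warp sums; *out starts at 0
    }
}

// Host side, mirroring utils::roundUp(dim, kWarpSize) from the diff:
// int numThreads = std::min(((dim + 31) / 32) * 32, 1024);
// l2NormSketch<<<1, numThreads>>>(vec, dim, out);
```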

This bug has been present since the code was written in 2016: it was technically wrong all along, but it only surfaced as a bug/problem with the change to CUDA 12.

tl;dr: if you use any kind of `_sync` instruction involving warp sync, always have a whole number of warps present, k thx.

Reviewed By: mdouze

Differential Revision: D51335172

fbshipit-source-id: 97da88a8dcbe6b4d8963083abc01d5d2121478bf
Jeff Johnson authored and facebook-github-bot committed Nov 15, 2023
1 parent 0c2243c commit 09c7aac
Showing 1 changed file with 33 additions and 76 deletions.

faiss/gpu/impl/L2Norm.cu:
```diff
--- a/faiss/gpu/impl/L2Norm.cu
+++ b/faiss/gpu/impl/L2Norm.cu
@@ -31,12 +31,7 @@ namespace gpu {
 // T: the type we are doing the math in (e.g., float, half)
 // TVec: the potentially vectorized type we are loading in (e.g.,
 // float4, half2)
-template <
-        typename T,
-        typename TVec,
-        int RowTileSize,
-        bool NormLoop,
-        bool NormSquared>
+template <typename T, typename TVec, int RowTileSize, bool NormSquared>
 __global__ void l2NormRowMajor(
         Tensor<TVec, 2, true> input,
         Tensor<float, 1, true> output) {
@@ -56,19 +51,13 @@ __global__ void l2NormRowMajor(
     if (lastRowTile) {
         // We are handling the very end of the input matrix rows
         for (idx_t row = 0; row < input.getSize(0) - rowStart; ++row) {
-            if (NormLoop) {
-                rowNorm[0] = 0;
-
-                for (idx_t col = threadIdx.x; col < input.getSize(1);
-                     col += blockDim.x) {
-                    TVec val = input[rowStart + row][col];
-                    val = Math<TVec>::mul(val, val);
-                    rowNorm[0] = rowNorm[0] + Math<TVec>::reduceAdd(val);
-                }
-            } else {
-                TVec val = input[rowStart + row][threadIdx.x];
+            rowNorm[0] = 0;
+
+            for (idx_t col = threadIdx.x; col < input.getSize(1);
+                 col += blockDim.x) {
+                TVec val = input[rowStart + row][col];
                 val = Math<TVec>::mul(val, val);
-                rowNorm[0] = Math<TVec>::reduceAdd(val);
+                rowNorm[0] = rowNorm[0] + Math<TVec>::reduceAdd(val);
             }
 
             rowNorm[0] = warpReduceAllSum(rowNorm[0]);
@@ -79,42 +68,18 @@ __global__ void l2NormRowMajor(
     } else {
         // We are guaranteed that all RowTileSize rows are available in
         // [rowStart, rowStart + RowTileSize)
-
-        if (NormLoop) {
-            // A single block of threads is not big enough to span each
-            // vector
-            TVec tmp[RowTileSize];
-
-#pragma unroll
-            for (int row = 0; row < RowTileSize; ++row) {
-                rowNorm[row] = 0;
-            }
-
-            for (idx_t col = threadIdx.x; col < input.getSize(1);
-                 col += blockDim.x) {
-#pragma unroll
-                for (int row = 0; row < RowTileSize; ++row) {
-                    tmp[row] = input[rowStart + row][col];
-                }
-
-#pragma unroll
-                for (int row = 0; row < RowTileSize; ++row) {
-                    tmp[row] = Math<TVec>::mul(tmp[row], tmp[row]);
-                }
+        TVec tmp[RowTileSize];
 
 #pragma unroll
-                for (int row = 0; row < RowTileSize; ++row) {
-                    rowNorm[row] =
-                            rowNorm[row] + Math<TVec>::reduceAdd(tmp[row]);
-                }
-            }
-        } else {
-            TVec tmp[RowTileSize];
+        for (int row = 0; row < RowTileSize; ++row) {
+            rowNorm[row] = 0;
+        }
 
-            // A block of threads is the exact size of the vector
+        for (idx_t col = threadIdx.x; col < input.getSize(1);
+             col += blockDim.x) {
 #pragma unroll
             for (int row = 0; row < RowTileSize; ++row) {
-                tmp[row] = input[rowStart + row][threadIdx.x];
+                tmp[row] = input[rowStart + row][col];
             }
 
 #pragma unroll
@@ -124,7 +89,7 @@ __global__ void l2NormRowMajor(
 
 #pragma unroll
             for (int row = 0; row < RowTileSize; ++row) {
-                rowNorm[row] = Math<TVec>::reduceAdd(tmp[row]);
+                rowNorm[row] = rowNorm[row] + Math<TVec>::reduceAdd(tmp[row]);
             }
         }
 
@@ -161,7 +126,7 @@ __global__ void l2NormRowMajor(
     if (laneId == 0) {
 #pragma unroll
         for (int row = 0; row < RowTileSize; ++row) {
-            int outCol = rowStart + row;
+            idx_t outCol = rowStart + row;
 
             if (lastRowTile) {
                 if (outCol < output.getSize(0)) {
@@ -218,25 +183,15 @@ void runL2Norm(
     idx_t maxThreads = (idx_t)getMaxThreadsCurrentDevice();
     constexpr int rowTileSize = 8;
 
-#define RUN_L2_ROW_MAJOR(TYPE_T, TYPE_TVEC, INPUT)                          \
-    do {                                                                    \
-        if (normLoop) {                                                     \
-            if (normSquared) {                                              \
-                l2NormRowMajor<TYPE_T, TYPE_TVEC, rowTileSize, true, true>  \
-                        <<<grid, block, smem, stream>>>(INPUT, output);     \
-            } else {                                                        \
-                l2NormRowMajor<TYPE_T, TYPE_TVEC, rowTileSize, true, false> \
-                        <<<grid, block, smem, stream>>>(INPUT, output);     \
-            }                                                               \
-        } else {                                                            \
-            if (normSquared) {                                              \
-                l2NormRowMajor<TYPE_T, TYPE_TVEC, rowTileSize, false, true> \
-                        <<<grid, block, smem, stream>>>(INPUT, output);     \
-            } else {                                                        \
-                l2NormRowMajor<TYPE_T, TYPE_TVEC, rowTileSize, false, false>\
-                        <<<grid, block, smem, stream>>>(INPUT, output);     \
-            }                                                               \
-        }                                                                   \
+#define RUN_L2_ROW_MAJOR(TYPE_T, TYPE_TVEC, INPUT)                  \
+    do {                                                            \
+        if (normSquared) {                                          \
+            l2NormRowMajor<TYPE_T, TYPE_TVEC, rowTileSize, true>    \
+                    <<<grid, block, smem, stream>>>(INPUT, output); \
+        } else {                                                    \
+            l2NormRowMajor<TYPE_T, TYPE_TVEC, rowTileSize, false>   \
+                    <<<grid, block, smem, stream>>>(INPUT, output); \
+        }                                                           \
     } while (0)
 
     if (inputRowMajor) {
@@ -247,10 +202,11 @@ void runL2Norm(
         if (input.template canCastResize<TVec>()) {
             // Can load using the vectorized type
            auto inputV = input.template castResize<TVec>();
 
             auto dim = inputV.getSize(1);
-            bool normLoop = dim > maxThreads;
-            auto numThreads = std::min(dim, maxThreads);
+            // We must always have full warps present
+            auto numThreads =
+                    std::min(utils::roundUp(dim, kWarpSize), maxThreads);
 
             auto grid = dim3(utils::divUp(inputV.getSize(0), rowTileSize));
             auto block = dim3(numThreads);
@@ -261,10 +217,11 @@ void runL2Norm(
             RUN_L2_ROW_MAJOR(T, TVec, inputV);
         } else {
             // Can't load using the vectorized type
 
             auto dim = input.getSize(1);
-            bool normLoop = dim > maxThreads;
-            auto numThreads = std::min(dim, maxThreads);
+            // We must always have full warps present
+            auto numThreads =
+                    std::min(utils::roundUp(dim, kWarpSize), maxThreads);
 
             auto grid = dim3(utils::divUp(input.getSize(0), rowTileSize));
             auto block = dim3(numThreads);
```
