
Add AVX2 and AVX512 optimization for wavelet transform (#1552)
Encoder: performance gain ~0.1%
Decoder: performance gain ~2.5%
tszumski committed Sep 6, 2024
1 parent 606304d commit e0e0c80
Showing 2 changed files with 360 additions and 7 deletions.
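For context, every SIMD path in this patch implements the same two lifting steps of the inverse 5/3 (reversible) wavelet transform, writing even and odd output samples interleaved. A minimal scalar sketch of those steps, taken from the patch's own scalar tail loop (the helper name and the omitted boundary handling are the editor's, not code from this commit):

#include <stdint.h>

/* Illustrative sketch: the 5/3 inverse lifting steps the SIMD paths vectorize.
   lf = low-pass (even) band, hf = high-pass (odd) band; the caller has
   already produced out[0], and boundary samples are handled separately. */
static void idwt53_h_sketch(const int32_t* lf, const int32_t* hf,
                            int32_t* out, int32_t n_pairs)
{
    for (int32_t k = 1; k <= n_pairs; k++) {
        out[2 * k]     = lf[k] - ((hf[k - 1] + hf[k] + 2) >> 2);           /* even */
        out[2 * k - 1] = hf[k - 1] + ((out[2 * k - 2] + out[2 * k]) >> 1); /* odd  */
    }
}

The only value carried between iterations is the previous even sample, which is why each vectorized loop below tracks a prev_even scalar across batches.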
224 changes: 217 additions & 7 deletions src/lib/openjp2/dwt.c
@@ -52,7 +52,7 @@
#ifdef __SSSE3__
#include <tmmintrin.h>
#endif
#ifdef __AVX2__
#if (defined(__AVX2__) || defined(__AVX512F__))
#include <immintrin.h>
#endif

@@ -66,7 +66,10 @@
#define OPJ_WS(i) v->mem[(i)*2]
#define OPJ_WD(i) v->mem[(1+(i)*2)]

#ifdef __AVX2__
#if defined(__AVX512F__)
/** Number of int32 values in a AVX512 register */
#define VREG_INT_COUNT 16
#elif defined(__AVX2__)
/** Number of int32 values in a AVX2 register */
#define VREG_INT_COUNT 8
#else
@@ -331,6 +334,51 @@ static void opj_dwt_decode_1(const opj_dwt_t *v)

#endif /* STANDARD_SLOW_VERSION */

#if defined(__AVX512F__)
static int32_t loop_short_sse(int32_t len, const int32_t** lf_ptr,
const int32_t** hf_ptr, int32_t** out_ptr,
int32_t* prev_even)
{
int32_t next_even;
__m128i odd, even_m1, unpack1, unpack2;
const int32_t batch = (len - 2) / 8;
const __m128i two = _mm_set1_epi32(2);

for (int32_t i = 0; i < batch; i++) {
const __m128i lf_ = _mm_loadu_si128((__m128i*)(*lf_ptr + 1));
const __m128i hf1_ = _mm_loadu_si128((__m128i*)(*hf_ptr));
const __m128i hf2_ = _mm_loadu_si128((__m128i*)(*hf_ptr + 1));

__m128i even = _mm_add_epi32(hf1_, hf2_);
even = _mm_add_epi32(even, two);
even = _mm_srai_epi32(even, 2);
even = _mm_sub_epi32(lf_, even);

next_even = _mm_extract_epi32(even, 3); /* last even sample, carried into the next batch */
even_m1 = _mm_bslli_si128(even, 4); /* shift 32-bit lanes up by one position */
even_m1 = _mm_insert_epi32(even_m1, *prev_even, 0); /* carried even sample into lane 0 */

// odd = hf + ((out[0] + out[2]) >> 1): average of the neighbouring even outputs
odd = _mm_add_epi32(even_m1, even);
odd = _mm_srai_epi32(odd, 1);
odd = _mm_add_epi32(odd, hf1_);

unpack1 = _mm_unpacklo_epi32(even_m1, odd);
unpack2 = _mm_unpackhi_epi32(even_m1, odd);

_mm_storeu_si128((__m128i*)(*out_ptr + 0), unpack1);
_mm_storeu_si128((__m128i*)(*out_ptr + 4), unpack2);

*prev_even = next_even;

*out_ptr += 8;
*lf_ptr += 4;
*hf_ptr += 4;
}
return batch;
}
#endif
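
The lane-carry trick above is the heart of this helper: _mm_bslli_si128 shifts every 32-bit lane up by one position and _mm_insert_epi32 then injects the even sample carried over from the previous batch into lane 0. A self-contained check of those semantics (editor's illustration, requires SSE4.1; not part of the commit):

#include <smmintrin.h> /* SSE4.1: _mm_insert_epi32 */
#include <stdint.h>
#include <stdio.h>

int main(void)
{
    const __m128i v = _mm_setr_epi32(10, 20, 30, 40);
    __m128i s = _mm_bslli_si128(v, 4); /* lanes become {0, 10, 20, 30} */
    s = _mm_insert_epi32(s, 99, 0);    /* inject the carry: {99, 10, 20, 30} */
    int32_t out[4];
    _mm_storeu_si128((__m128i*)out, s);
    printf("%d %d %d %d\n", out[0], out[1], out[2], out[3]); /* 99 10 20 30 */
    return 0;
}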

#if !defined(STANDARD_SLOW_VERSION)
static void opj_idwt53_h_cas0(OPJ_INT32* tmp,
const OPJ_INT32 sn,
@@ -363,6 +411,145 @@ static void opj_idwt53_h_cas0(OPJ_INT32* tmp,
if (!(len & 1)) { /* if len is even */
tmp[len - 1] = in_odd[(len - 1) / 2] + tmp[len - 2];
}
#else
#if defined(__AVX512F__)
OPJ_INT32* out_ptr = tmp;
int32_t prev_even = in_even[0] - ((in_odd[0] + 1) >> 1);

const __m512i permutevar_mask = _mm512_setr_epi32(
0x10, 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b,
0x0c, 0x0d, 0x0e); /* index 0x10 selects lane 0 of the second source; 0x00-0x0e shift lanes up by one */
const __m512i store1_perm = _mm512_setr_epi64(0x00, 0x01, 0x08, 0x09, 0x02,
0x03, 0x0a, 0x0b);
const __m512i store2_perm = _mm512_setr_epi64(0x04, 0x05, 0x0c, 0x0d, 0x06,
0x07, 0x0e, 0x0f);

const __m512i two = _mm512_set1_epi32(2);

int32_t simd_batch_512 = (len - 2) / 32;
int32_t leftover;

for (i = 0; i < simd_batch_512; i++) {
const __m512i lf_avx2 = _mm512_loadu_si512((__m512i*)(in_even + 1));
const __m512i hf1_avx2 = _mm512_loadu_si512((__m512i*)(in_odd));
const __m512i hf2_avx2 = _mm512_loadu_si512((__m512i*)(in_odd + 1));
int32_t next_even;
__m512i duplicate, even_m1, odd, unpack1, unpack2, store1, store2;

__m512i even = _mm512_add_epi32(hf1_avx2, hf2_avx2);
even = _mm512_add_epi32(even, two);
even = _mm512_srai_epi32(even, 2);
even = _mm512_sub_epi32(lf_avx2, even);

next_even = _mm_extract_epi32(_mm512_extracti32x4_epi32(even, 3), 3); /* last even sample, carried into the next batch */

duplicate = _mm512_set1_epi32(prev_even);
even_m1 = _mm512_permutex2var_epi32(even, permutevar_mask, duplicate); /* lanes shifted up by one, prev_even in lane 0 */

// odd = hf + ((out[0] + out[2]) >> 1): average of the neighbouring even outputs
odd = _mm512_add_epi32(even_m1, even);
odd = _mm512_srai_epi32(odd, 1);
odd = _mm512_add_epi32(odd, hf1_avx2);

unpack1 = _mm512_unpacklo_epi32(even_m1, odd); /* interleaves within each 128-bit lane */
unpack2 = _mm512_unpackhi_epi32(even_m1, odd);

/* unpacklo/unpackhi operate per 128-bit lane, so reorder the 64-bit pairs
   to make the two stores fully sequential */
store1 = _mm512_permutex2var_epi64(unpack1, store1_perm, unpack2);
store2 = _mm512_permutex2var_epi64(unpack1, store2_perm, unpack2);

_mm512_storeu_si512(out_ptr, store1);
_mm512_storeu_si512(out_ptr + 16, store2);

prev_even = next_even;

out_ptr += 32;
in_even += 16;
in_odd += 16;
}

leftover = len - simd_batch_512 * 32;
if (leftover > 8) {
leftover -= 8 * loop_short_sse(leftover, &in_even, &in_odd, &out_ptr,
&prev_even);
}
out_ptr[0] = prev_even;

for (j = 1; j < (leftover - 2); j += 2) {
out_ptr[2] = in_even[1] - ((in_odd[0] + in_odd[1] + 2) >> 2);
out_ptr[1] = in_odd[0] + ((out_ptr[0] + out_ptr[2]) >> 1);
in_even++;
in_odd++;
out_ptr += 2;
}

if (len & 1) {
out_ptr[2] = in_even[1] - ((in_odd[0] + 1) >> 1);
out_ptr[1] = in_odd[0] + ((out_ptr[0] + out_ptr[2]) >> 1);
} else { //!(len & 1)
out_ptr[1] = in_odd[0] + out_ptr[0];
}
#elif defined(__AVX2__)
OPJ_INT32* out_ptr = tmp;
int32_t prev_even = in_even[0] - ((in_odd[0] + 1) >> 1);

const __m256i reg_permutevar_mask_move_right = _mm256_setr_epi32(0x00, 0x00,
0x01, 0x02, 0x03, 0x04, 0x05, 0x06);
const __m256i two = _mm256_set1_epi32(2);

int32_t simd_batch = (len - 2) / 16;
int32_t next_even;
__m256i even_m1, odd, unpack1_avx2, unpack2_avx2;

for (i = 0; i < simd_batch; i++) {
const __m256i lf_avx2 = _mm256_loadu_si256((__m256i*)(in_even + 1));
const __m256i hf1_avx2 = _mm256_loadu_si256((__m256i*)(in_odd));
const __m256i hf2_avx2 = _mm256_loadu_si256((__m256i*)(in_odd + 1));

__m256i even = _mm256_add_epi32(hf1_avx2, hf2_avx2);
even = _mm256_add_epi32(even, two);
even = _mm256_srai_epi32(even, 2);
even = _mm256_sub_epi32(lf_avx2, even);

next_even = _mm_extract_epi32(_mm256_extracti128_si256(even, 1), 3); /* last even sample, carried into the next batch */
even_m1 = _mm256_permutevar8x32_epi32(even, reg_permutevar_mask_move_right); /* shift lanes up by one */
even_m1 = _mm256_blend_epi32(even_m1, _mm256_set1_epi32(prev_even), (1 << 0)); /* prev_even into lane 0 */

// odd = hf + ((out[0] + out[2]) >> 1): average of the neighbouring even outputs
odd = _mm256_add_epi32(even_m1, even);
odd = _mm256_srai_epi32(odd, 1);
odd = _mm256_add_epi32(odd, hf1_avx2);

unpack1_avx2 = _mm256_unpacklo_epi32(even_m1, odd);
unpack2_avx2 = _mm256_unpackhi_epi32(even_m1, odd);

_mm_storeu_si128((__m128i*)(out_ptr + 0), _mm256_castsi256_si128(unpack1_avx2));
_mm_storeu_si128((__m128i*)(out_ptr + 4), _mm256_castsi256_si128(unpack2_avx2));
_mm_storeu_si128((__m128i*)(out_ptr + 8), _mm256_extracti128_si256(unpack1_avx2,
0x1));
_mm_storeu_si128((__m128i*)(out_ptr + 12),
_mm256_extracti128_si256(unpack2_avx2, 0x1));

prev_even = next_even;

out_ptr += 16;
in_even += 8;
in_odd += 8;
}
out_ptr[0] = prev_even;
for (j = simd_batch * 16 + 1; j < (len - 2); j += 2) {
out_ptr[2] = in_even[1] - ((in_odd[0] + in_odd[1] + 2) >> 2);
out_ptr[1] = in_odd[0] + ((out_ptr[0] + out_ptr[2]) >> 1);
in_even++;
in_odd++;
out_ptr += 2;
}

if (len & 1) {
out_ptr[2] = in_even[1] - ((in_odd[0] + 1) >> 1);
out_ptr[1] = in_odd[0] + ((out_ptr[0] + out_ptr[2]) >> 1);
} else { //!(len & 1)
out_ptr[1] = in_odd[0] + out_ptr[0];
}
#else
OPJ_INT32 d1c, d1n, s1n, s0c, s0n;

@@ -397,7 +584,8 @@ static void opj_idwt53_h_cas0(OPJ_INT32* tmp,
} else {
tmp[len - 1] = d1n + s0n;
}
#endif
#endif /*(__AVX512F__ || __AVX2__)*/
#endif /*TWO_PASS_VERSION*/
memcpy(tiledp, tmp, (OPJ_UINT32)len * sizeof(OPJ_INT32));
}
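
In the AVX512 path above, the shift-and-insert pair used by the SSE helper is fused into a single _mm512_permutex2var_epi32: index values 0x00-0x0f select lanes of the first source, while 0x10 selects lane 0 of the second source (the broadcast prev_even). A standalone check of that mask (editor's illustration, requires AVX512F; not part of the commit):

#include <immintrin.h>
#include <stdint.h>
#include <stdio.h>

int main(void)
{
    const __m512i idx = _mm512_setr_epi32(
        0x10, 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06,
        0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e);
    int32_t src[16], out[16];
    for (int i = 0; i < 16; i++) {
        src[i] = i + 1; /* 1 .. 16 */
    }
    const __m512i a     = _mm512_loadu_si512(src);
    const __m512i carry = _mm512_set1_epi32(-99); /* stands in for prev_even */
    const __m512i r     = _mm512_permutex2var_epi32(a, idx, carry);
    _mm512_storeu_si512(out, r);
    for (int i = 0; i < 16; i++) {
        printf("%d ", out[i]); /* prints: -99 1 2 3 ... 15 */
    }
    printf("\n");
    return 0;
}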

@@ -511,10 +699,20 @@ static void opj_idwt53_h(const opj_dwt_t *dwt,
#endif
}

#if (defined(__SSE2__) || defined(__AVX2__)) && !defined(STANDARD_SLOW_VERSION)
#if (defined(__SSE2__) || defined(__AVX2__) || defined(__AVX512F__)) && !defined(STANDARD_SLOW_VERSION)

/* Convenience macros to improve the readability of the formulas */
#if __AVX2__
#if defined(__AVX512F__)
#define VREG __m512i
#define LOAD_CST(x) _mm512_set1_epi32(x)
#define LOAD(x) _mm512_loadu_si512((const VREG*)(x))
#define LOADU(x) _mm512_loadu_si512((const VREG*)(x))
#define STORE(x,y) _mm512_storeu_si512((VREG*)(x),(y))
#define STOREU(x,y) _mm512_storeu_si512((VREG*)(x),(y))
#define ADD(x,y) _mm512_add_epi32((x),(y))
#define SUB(x,y) _mm512_sub_epi32((x),(y))
#define SAR(x,y) _mm512_srai_epi32((x),(y))
#elif defined(__AVX2__)
#define VREG __m256i
#define LOAD_CST(x) _mm256_set1_epi32(x)
#define LOAD(x) _mm256_load_si256((const VREG*)(x))
@@ -576,18 +774,24 @@ static void opj_idwt53_v_cas0_mcols_SSE2_OR_AVX2(
const VREG two = LOAD_CST(2);

assert(len > 1);
#if __AVX2__
#if defined(__AVX512F__)
assert(PARALLEL_COLS_53 == 32);
assert(VREG_INT_COUNT == 16);
#elif defined(__AVX2__)
assert(PARALLEL_COLS_53 == 16);
assert(VREG_INT_COUNT == 8);
#else
assert(PARALLEL_COLS_53 == 8);
assert(VREG_INT_COUNT == 4);
#endif

// For AVX512, aligned load/store are mapped to their unaligned equivalents
#if !defined(__AVX512F__)
/* Note: loads of input even/odd values must be done in an unaligned */
/* fashion. But stores in tmp can be done with aligned stores, since */
/* the temporary buffer is properly aligned */
assert((OPJ_SIZE_T)tmp % (sizeof(OPJ_INT32) * VREG_INT_COUNT) == 0);
#endif

s1n_0 = LOADU(in_even + 0);
s1n_1 = LOADU(in_even + VREG_INT_COUNT);
@@ -678,18 +882,24 @@ static void opj_idwt53_v_cas1_mcols_SSE2_OR_AVX2(
const OPJ_INT32* in_odd = &tiledp_col[0];

assert(len > 2);
#if __AVX2__
#if defined(__AVX512F__)
assert(PARALLEL_COLS_53 == 32);
assert(VREG_INT_COUNT == 16);
#elif defined(__AVX2__)
assert(PARALLEL_COLS_53 == 16);
assert(VREG_INT_COUNT == 8);
#else
assert(PARALLEL_COLS_53 == 8);
assert(VREG_INT_COUNT == 4);
#endif

// For AVX512, aligned load/store are mapped to their unaligned equivalents
#if !defined(__AVX512F__)
/* Note: loads of input even/odd values must be done in an unaligned */
/* fashion. But stores in tmp can be done with aligned stores, since */
/* the temporary buffer is properly aligned */
assert((OPJ_SIZE_T)tmp % (sizeof(OPJ_INT32) * VREG_INT_COUNT) == 0);
#endif

s1_0 = LOADU(in_even + stride);
/* in_odd[0] - ((in_even[0] + s1 + 2) >> 2); */