diff --git a/faiss/impl/ScalarQuantizer.cpp b/faiss/impl/ScalarQuantizer.cpp index 528843f606..7314bdde32 100644 --- a/faiss/impl/ScalarQuantizer.cpp +++ b/faiss/impl/ScalarQuantizer.cpp @@ -44,7 +44,10 @@ namespace faiss { * that hides the template mess. ********************************************************************/ -#ifdef __AVX2__ +#if defined(__AVX512F__) && defined(__F16C__) +#define USE_AVX512_F16C +#endif +#if defined(__AVX2__) #ifdef __F16C__ #define USE_F16C #else @@ -79,6 +82,18 @@ struct Codec8bit { return (code[i] + 0.5f) / 255.0f; } +#ifdef __AVX512F__ + static FAISS_ALWAYS_INLINE __m512 + decode_16_components(const uint8_t* code, int i) { + const __m128i c16 = _mm_loadu_si128((__m128i*)(code + i)); + const __m512i i32 = _mm512_cvtepu8_epi32(c16); + const __m512 f16 = _mm512_cvtepi32_ps(i32); + const __m512 half_one_255 = _mm512_set1_ps(0.5f / 255.f); + const __m512 one_255 = _mm512_set1_ps(1.f / 255.f); + return _mm512_fmadd_ps(f16, one_255, half_one_255); + } +#endif + #ifdef __AVX2__ static FAISS_ALWAYS_INLINE __m256 decode_8_components(const uint8_t* code, int i) { @@ -121,6 +136,27 @@ struct Codec4bit { return (((code[i / 2] >> ((i & 1) << 2)) & 0xf) + 0.5f) / 15.0f; } +#ifdef __AVX512F__ + static FAISS_ALWAYS_INLINE __m512 + decode_16_components(const uint8_t* code, int i) { + uint64_t c8 = *(uint64_t*)(code + (i >> 1)); + uint64_t mask = 0x0f0f0f0f0f0f0f0f; + uint64_t c8ev = c8 & mask; + uint64_t c8od = (c8 >> 4) & mask; + + __m128i c16 = + _mm_unpacklo_epi8(_mm_set1_epi64x(c8ev), _mm_set1_epi64x(c8od)); + __m256i c8lo = _mm256_cvtepu8_epi32(c16); + __m256i c8hi = _mm256_cvtepu8_epi32(_mm_srli_si128(c16, 8)); + __m512i i16 = _mm512_castsi256_si512(c8lo); + i16 = _mm512_inserti32x8(i16, c8hi, 1); + __m512 f16 = _mm512_cvtepi32_ps(i16); + const __m512 half_one_255 = _mm512_set1_ps(0.5f / 15.f); + const __m512 one_255 = _mm512_set1_ps(1.f / 15.f); + return _mm512_fmadd_ps(f16, one_255, half_one_255); + } +#endif + #ifdef __AVX2__ static FAISS_ALWAYS_INLINE __m256 decode_8_components(const uint8_t* code, int i) { @@ -207,6 +243,57 @@ struct Codec6bit { return (bits + 0.5f) / 63.0f; } +#ifdef __AVX512F__ + + static FAISS_ALWAYS_INLINE __m512 + decode_16_components(const uint8_t* code, int i) { + // pure AVX512 implementation (not necessarily the fastest). 
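+ // (widens each 6-bit code to 32 bits, converts to float and maps it into (0, 1))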
+ // see: + // https://github.com/zilliztech/knowhere/blob/main/thirdparty/faiss/faiss/impl/ScalarQuantizerCodec_avx512.h + + // clang-format off + + // 16 components, 16x6 bit=12 bytes + const __m128i bit_6v = + _mm_maskz_loadu_epi8(0b0000111111111111, code + (i >> 2) * 3); + const __m256i bit_6v_256 = _mm256_broadcast_i32x4(bit_6v); + + // 00 01 02 03 04 05 06 07 08 09 0A 0B 0C 0D 0E 0F + // 00 01 02 03 + const __m256i shuffle_mask = _mm256_setr_epi16( + 0xFF00, 0x0100, 0x0201, 0xFF02, + 0xFF03, 0x0403, 0x0504, 0xFF05, + 0xFF06, 0x0706, 0x0807, 0xFF08, + 0xFF09, 0x0A09, 0x0B0A, 0xFF0B); + const __m256i shuffled = _mm256_shuffle_epi8(bit_6v_256, shuffle_mask); + + // 0: xxxxxxxx xx543210 + // 1: xxxx5432 10xxxxxx + // 2: xxxxxx54 3210xxxx + // 3: xxxxxxxx 543210xx + const __m256i shift_right_v = _mm256_setr_epi16( + 0x0U, 0x6U, 0x4U, 0x2U, + 0x0U, 0x6U, 0x4U, 0x2U, + 0x0U, 0x6U, 0x4U, 0x2U, + 0x0U, 0x6U, 0x4U, 0x2U); + __m256i shuffled_shifted = _mm256_srlv_epi16(shuffled, shift_right_v); + + // remove unneeded bits + shuffled_shifted = + _mm256_and_si256(shuffled_shifted, _mm256_set1_epi16(0x003F)); + + // scale + const __m512 f8 = + _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(shuffled_shifted)); + const __m512 half_one_255 = _mm512_set1_ps(0.5f / 63.f); + const __m512 one_255 = _mm512_set1_ps(1.f / 63.f); + return _mm512_fmadd_ps(f8, one_255, half_one_255); + + // clang-format on + } + +#endif + #ifdef __AVX2__ /* Load 6 bytes that represent 8 6-bit values, return them as a @@ -316,6 +403,23 @@ struct QuantizerTemplate : ScalarQuantizer::SQuantizer { } }; +#ifdef __AVX512F__ + +template +struct QuantizerTemplate : QuantizerTemplate { + QuantizerTemplate(size_t d, const std::vector& trained) + : QuantizerTemplate(d, trained) {} + + FAISS_ALWAYS_INLINE __m512 + reconstruct_16_components(const uint8_t* code, int i) const { + __m512 xi = Codec::decode_16_components(code, i); + return _mm512_fmadd_ps( + xi, _mm512_set1_ps(this->vdiff), _mm512_set1_ps(this->vmin)); + } +}; + +#endif + #ifdef __AVX2__ template @@ -394,6 +498,26 @@ struct QuantizerTemplate : ScalarQuantizer::SQuantizer { } }; +#ifdef __AVX512F__ + +template +struct QuantizerTemplate + : QuantizerTemplate { + QuantizerTemplate(size_t d, const std::vector& trained) + : QuantizerTemplate(d, trained) {} + + FAISS_ALWAYS_INLINE __m512 + reconstruct_16_components(const uint8_t* code, int i) const { + __m512 xi = Codec::decode_16_components(code, i); + return _mm512_fmadd_ps( + xi, + _mm512_loadu_ps(this->vdiff + i), + _mm512_loadu_ps(this->vmin + i)); + } +}; + +#endif + #ifdef __AVX2__ template @@ -465,7 +589,23 @@ struct QuantizerFP16<1> : ScalarQuantizer::SQuantizer { } }; -#ifdef USE_F16C +#if defined(USE_AVX512_F16C) + +template <> +struct QuantizerFP16<16> : QuantizerFP16<1> { + QuantizerFP16(size_t d, const std::vector& trained) + : QuantizerFP16<1>(d, trained) {} + + FAISS_ALWAYS_INLINE __m512 + reconstruct_16_components(const uint8_t* code, int i) const { + __m256i codei = _mm256_loadu_si256((const __m256i*)(code + 2 * i)); + return _mm512_cvtph_ps(codei); + } +}; + +#endif + +#if defined(USE_F16C) template <> struct QuantizerFP16<8> : QuantizerFP16<1> { @@ -528,6 +668,23 @@ struct QuantizerBF16<1> : ScalarQuantizer::SQuantizer { } }; +#ifdef __AVX512F__ + +template <> +struct QuantizerBF16<16> : QuantizerBF16<1> { + QuantizerBF16(size_t d, const std::vector& trained) + : QuantizerBF16<1>(d, trained) {} + FAISS_ALWAYS_INLINE __m512 + reconstruct_16_components(const uint8_t* code, int i) const { + __m256i code_256i = 
_mm256_loadu_si256((const __m256i*)(code + 2 * i)); + __m512i code_512i = _mm512_cvtepu16_epi32(code_256i); + code_512i = _mm512_slli_epi32(code_512i, 16); + return _mm512_castsi512_ps(code_512i); + } +}; + +#endif + #ifdef __AVX2__ template <> @@ -595,6 +752,23 @@ struct Quantizer8bitDirect<1> : ScalarQuantizer::SQuantizer { } }; +#ifdef __AVX512F__ + +template <> +struct Quantizer8bitDirect<16> : Quantizer8bitDirect<1> { + Quantizer8bitDirect(size_t d, const std::vector& trained) + : Quantizer8bitDirect<1>(d, trained) {} + + FAISS_ALWAYS_INLINE __m512 + reconstruct_16_components(const uint8_t* code, int i) const { + __m128i x16 = _mm_loadu_si128((__m128i*)(code + i)); // 16 * int8 + __m512i y16 = _mm512_cvtepu8_epi32(x16); // 16 * int32 + return _mm512_cvtepi32_ps(y16); // 16 * float32 + } +}; + +#endif + #ifdef __AVX2__ template <> @@ -665,6 +839,25 @@ struct Quantizer8bitDirectSigned<1> : ScalarQuantizer::SQuantizer { } }; +#ifdef __AVX512F__ + +template <> +struct Quantizer8bitDirectSigned<16> : Quantizer8bitDirectSigned<1> { + Quantizer8bitDirectSigned(size_t d, const std::vector& trained) + : Quantizer8bitDirectSigned<1>(d, trained) {} + + FAISS_ALWAYS_INLINE __m512 + reconstruct_16_components(const uint8_t* code, int i) const { + __m128i x16 = _mm_loadu_si128((__m128i*)(code + i)); // 16 * int8 + __m512i y16 = _mm512_cvtepu8_epi32(x16); // 16 * int32 + __m512i c16 = _mm512_set1_epi32(128); + __m512i z16 = _mm512_sub_epi32(y16, c16); // subtract 128 from all lanes + return _mm512_cvtepi32_ps(z16); // 16 * float32 + } +}; + +#endif + #ifdef __AVX2__ template <> @@ -955,7 +1148,45 @@ struct SimilarityL2<1> { } }; +#ifdef __AVX512F__ + +template <> +struct SimilarityL2<16> { + static constexpr int simdwidth = 16; + static constexpr MetricType metric_type = METRIC_L2; + + const float *y, *yi; + + explicit SimilarityL2(const float* y) : y(y) {} + __m512 accu16; + + FAISS_ALWAYS_INLINE void begin_16() { + accu16 = _mm512_setzero_ps(); + yi = y; + } + + FAISS_ALWAYS_INLINE void add_16_components(__m512 x) { + __m512 yiv = _mm512_loadu_ps(yi); + yi += 16; + __m512 tmp = _mm512_sub_ps(yiv, x); + accu16 = _mm512_fmadd_ps(tmp, tmp, accu16); + } + + FAISS_ALWAYS_INLINE void add_16_components_2(__m512 x, __m512 y_2) { + __m512 tmp = _mm512_sub_ps(y_2, x); + accu16 = _mm512_fmadd_ps(tmp, tmp, accu16); + } + + FAISS_ALWAYS_INLINE float result_16() { + // performs better than dividing into _mm256 and adding + return _mm512_reduce_add_ps(accu16); + } +}; + +#endif + #ifdef __AVX2__ + template <> struct SimilarityL2<8> { static constexpr int simdwidth = 8; @@ -1078,6 +1309,44 @@ struct SimilarityIP<1> { } }; +#ifdef __AVX512F__ + +template <> +struct SimilarityIP<16> { + static constexpr int simdwidth = 16; + static constexpr MetricType metric_type = METRIC_INNER_PRODUCT; + + const float *y, *yi; + + float accu; + + explicit SimilarityIP(const float* y) : y(y) {} + + __m512 accu16; + + FAISS_ALWAYS_INLINE void begin_16() { + accu16 = _mm512_setzero_ps(); + yi = y; + } + + FAISS_ALWAYS_INLINE void add_16_components(__m512 x) { + __m512 yiv = _mm512_loadu_ps(yi); + yi += 16; + accu16 = _mm512_fmadd_ps(yiv, x, accu16); + } + + FAISS_ALWAYS_INLINE void add_16_components_2(__m512 x1, __m512 x2) { + accu16 = _mm512_fmadd_ps(x1, x2, accu16); + } + + FAISS_ALWAYS_INLINE float result_16() { + // performs better than dividing into _mm256 and adding + return _mm512_reduce_add_ps(accu16); + } +}; + +#endif + #ifdef __AVX2__ template <> @@ -1220,7 +1489,56 @@ struct DCTemplate : SQDistanceComputer { } }; -#ifdef 
USE_F16C +#ifdef USE_AVX512_F16C + +template +struct DCTemplate + : SQDistanceComputer { // Update to handle 16 lanes + using Sim = Similarity; + + Quantizer quant; + + DCTemplate(size_t d, const std::vector& trained) + : quant(d, trained) {} + + float compute_distance(const float* x, const uint8_t* code) const { + Similarity sim(x); + sim.begin_16(); + for (size_t i = 0; i < quant.d; i += 16) { + __m512 xi = quant.reconstruct_16_components(code, i); + sim.add_16_components(xi); + } + return sim.result_16(); + } + + float compute_code_distance(const uint8_t* code1, const uint8_t* code2) + const { + Similarity sim(nullptr); + sim.begin_16(); + for (size_t i = 0; i < quant.d; i += 16) { + __m512 x1 = quant.reconstruct_16_components(code1, i); + __m512 x2 = quant.reconstruct_16_components(code2, i); + sim.add_16_components_2(x1, x2); + } + return sim.result_16(); + } + + void set_query(const float* x) final { + q = x; + } + + float symmetric_dis(idx_t i, idx_t j) override { + return compute_code_distance( + codes + i * code_size, codes + j * code_size); + } + + float query_to_code(const uint8_t* code) const final { + return compute_distance(q, code); + } +}; +#endif + +#if defined(USE_F16C) template struct DCTemplate : SQDistanceComputer { @@ -1367,6 +1685,61 @@ struct DistanceComputerByte : SQDistanceComputer { } }; +#ifdef __AVX512F__ + +template +struct DistanceComputerByte : SQDistanceComputer { + using Sim = Similarity; + + int d; + std::vector tmp; + + DistanceComputerByte(int d, const std::vector&) : d(d), tmp(d) {} + + int compute_code_distance(const uint8_t* code1, const uint8_t* code2) + const { + __m512i accu = _mm512_setzero_si512(); + for (int i = 0; i < d; i += 32) { // Process 32 bytes at a time + __m512i c1 = _mm512_cvtepu8_epi16( + _mm256_loadu_si256((__m256i*)(code1 + i))); + __m512i c2 = _mm512_cvtepu8_epi16( + _mm256_loadu_si256((__m256i*)(code2 + i))); + __m512i prod32; + if (Sim::metric_type == METRIC_INNER_PRODUCT) { + prod32 = _mm512_madd_epi16(c1, c2); + } else { + __m512i diff = _mm512_sub_epi16(c1, c2); + prod32 = _mm512_madd_epi16(diff, diff); + } + accu = _mm512_add_epi32(accu, prod32); + } + // Horizontally add elements of accu + return _mm512_reduce_add_epi32(accu); + } + + void set_query(const float* x) final { + for (int i = 0; i < d; i++) { + tmp[i] = int(x[i]); + } + } + + int compute_distance(const float* x, const uint8_t* code) { + set_query(x); + return compute_code_distance(tmp.data(), code); + } + + float symmetric_dis(idx_t i, idx_t j) override { + return compute_code_distance( + codes + i * code_size, codes + j * code_size); + } + + float query_to_code(const uint8_t* code) const final { + return compute_code_distance(tmp.data(), code); + } +}; + +#endif + #ifdef __AVX2__ template @@ -1531,9 +1904,17 @@ SQDistanceComputer* select_distance_computer( d, trained); case ScalarQuantizer::QT_8bit_direct: - if (d % 16 == 0) { +#if defined(__AVX512F__) + if (d % 32 == 0) { return new DistanceComputerByte(d, trained); - } else { + } else +#endif +#if defined(__AVX2__) + if (d % 16 == 0) { + return new DistanceComputerByte(d, trained); + } else +#endif + { return new DCTemplate< Quantizer8bitDirect, Sim, @@ -1632,8 +2013,13 @@ void ScalarQuantizer::train(size_t n, const float* x) { } ScalarQuantizer::SQuantizer* ScalarQuantizer::select_quantizer() const { +#if defined(USE_AVX512_F16C) + if (d % 16 == 0) { + return select_quantizer_1<16>(qtype, d, trained); + } else +#endif #if defined(USE_F16C) || defined(__aarch64__) - if (d % 8 == 0) { + if (d % 8 == 0) { 
return select_quantizer_1<8>(qtype, d, trained); } else #endif @@ -1663,8 +2049,19 @@ void ScalarQuantizer::decode(const uint8_t* codes, float* x, size_t n) const { SQDistanceComputer* ScalarQuantizer::get_distance_computer( MetricType metric) const { FAISS_THROW_IF_NOT(metric == METRIC_L2 || metric == METRIC_INNER_PRODUCT); +#if defined(USE_AVX512_F16C) + if (d % 16 == 0) { + if (metric == METRIC_L2) { + return select_distance_computer>( + qtype, d, trained); + } else { + return select_distance_computer>( + qtype, d, trained); + } + } else +#endif #if defined(USE_F16C) || defined(__aarch64__) - if (d % 8 == 0) { + if (d % 8 == 0) { if (metric == METRIC_L2) { return select_distance_computer>(qtype, d, trained); } else { @@ -1961,11 +2358,21 @@ InvertedListScanner* sel1_InvertedListScanner( Similarity, SIMDWIDTH>>(sq, quantizer, store_pairs, sel, r); case ScalarQuantizer::QT_8bit_direct: - if (sq->d % 16 == 0) { +#if defined(__AVX512F__) + if (sq->d % 32 == 0) { return sel2_InvertedListScanner< DistanceComputerByte>( sq, quantizer, store_pairs, sel, r); - } else { + } else +#endif +#if defined(__AVX2__) + if (sq->d % 16 == 0) { + return sel2_InvertedListScanner< + DistanceComputerByte>( + sq, quantizer, store_pairs, sel, r); + } else +#endif + { return sel2_InvertedListScanner, Similarity, @@ -2009,8 +2416,14 @@ InvertedListScanner* ScalarQuantizer::select_InvertedListScanner( bool store_pairs, const IDSelector* sel, bool by_residual) const { +#if defined(USE_AVX512_F16C) + if (d % 16 == 0) { + return sel0_InvertedListScanner<16>( + mt, this, quantizer, store_pairs, sel, by_residual); + } else +#endif #if defined(USE_F16C) || defined(__aarch64__) - if (d % 8 == 0) { + if (d % 8 == 0) { return sel0_InvertedListScanner<8>( mt, this, quantizer, store_pairs, sel, by_residual); } else diff --git a/faiss/utils/distances_simd.cpp b/faiss/utils/distances_simd.cpp index 323859f43b..7cebd2ae33 100644 --- a/faiss/utils/distances_simd.cpp +++ b/faiss/utils/distances_simd.cpp @@ -23,7 +23,9 @@ #include #endif -#ifdef __AVX2__ +#if defined(__AVX512F__) +#include +#elif defined(__AVX2__) #include #endif @@ -346,6 +348,14 @@ inline float horizontal_sum(const __m256 v) { } #endif +#ifdef __AVX512F__ +/// helper function for AVX512 +inline float horizontal_sum(const __m512 v) { + // performs better than adding the high and low parts + return _mm512_reduce_add_ps(v); +} +#endif + /// Function that does a component-wise operation between x and y /// to compute L2 distances. ElementOp can then be used in the fvec_op_ny /// functions below @@ -366,6 +376,13 @@ struct ElementOpL2 { return _mm256_mul_ps(tmp, tmp); } #endif + +#ifdef __AVX512F__ + static __m512 op(__m512 x, __m512 y) { + __m512 tmp = _mm512_sub_ps(x, y); + return _mm512_mul_ps(tmp, tmp); + } +#endif }; /// Function that does a component-wise operation between x and y @@ -384,6 +401,12 @@ struct ElementOpIP { return _mm256_mul_ps(x, y); } #endif + +#ifdef __AVX512F__ + static __m512 op(__m512 x, __m512 y) { + return _mm512_mul_ps(x, y); + } +#endif }; template @@ -426,7 +449,130 @@ void fvec_op_ny_D2(float* dis, const float* x, const float* y, size_t ny) { } } -#ifdef __AVX2__ +#if defined(__AVX512F__) + +template <> +void fvec_op_ny_D2( + float* dis, + const float* x, + const float* y, + size_t ny) { + const size_t ny16 = ny / 16; + size_t i = 0; + + if (ny16 > 0) { + // process 16 D2-vectors per loop. 
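+ // the two query components are broadcast once; each iteration transposes
+ // 16 interleaved (y0, y1) pairs into two registers and accumulates the
+ // 16 dot products with FMA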
+ _mm_prefetch((const char*)y, _MM_HINT_T0); + _mm_prefetch((const char*)(y + 32), _MM_HINT_T0); + + const __m512 m0 = _mm512_set1_ps(x[0]); + const __m512 m1 = _mm512_set1_ps(x[1]); + + for (i = 0; i < ny16 * 16; i += 16) { + _mm_prefetch((const char*)(y + 64), _MM_HINT_T0); + + // load 16x2 matrix and transpose it in registers. + // the typical bottleneck is memory access, so + // let's trade instructions for the bandwidth. + + __m512 v0; + __m512 v1; + + transpose_16x2( + _mm512_loadu_ps(y + 0 * 16), + _mm512_loadu_ps(y + 1 * 16), + v0, + v1); + + // compute distances (dot product) + __m512 distances = _mm512_mul_ps(m0, v0); + distances = _mm512_fmadd_ps(m1, v1, distances); + + // store + _mm512_storeu_ps(dis + i, distances); + + y += 32; // move to the next set of 16x2 elements + } + } + + if (i < ny) { + // process leftovers + float x0 = x[0]; + float x1 = x[1]; + + for (; i < ny; i++) { + float distance = x0 * y[0] + x1 * y[1]; + y += 2; + dis[i] = distance; + } + } +} + +template <> +void fvec_op_ny_D2( + float* dis, + const float* x, + const float* y, + size_t ny) { + const size_t ny16 = ny / 16; + size_t i = 0; + + if (ny16 > 0) { + // process 16 D2-vectors per loop. + _mm_prefetch((const char*)y, _MM_HINT_T0); + _mm_prefetch((const char*)(y + 32), _MM_HINT_T0); + + const __m512 m0 = _mm512_set1_ps(x[0]); + const __m512 m1 = _mm512_set1_ps(x[1]); + + for (i = 0; i < ny16 * 16; i += 16) { + _mm_prefetch((const char*)(y + 64), _MM_HINT_T0); + + // load 16x2 matrix and transpose it in registers. + // the typical bottleneck is memory access, so + // let's trade instructions for the bandwidth. + + __m512 v0; + __m512 v1; + + transpose_16x2( + _mm512_loadu_ps(y + 0 * 16), + _mm512_loadu_ps(y + 1 * 16), + v0, + v1); + + // compute differences + const __m512 d0 = _mm512_sub_ps(m0, v0); + const __m512 d1 = _mm512_sub_ps(m1, v1); + + // compute squares of differences + __m512 distances = _mm512_mul_ps(d0, d0); + distances = _mm512_fmadd_ps(d1, d1, distances); + + // store + _mm512_storeu_ps(dis + i, distances); + + y += 32; // move to the next set of 16x2 elements + } + } + + if (i < ny) { + // process leftovers + float x0 = x[0]; + float x1 = x[1]; + + for (; i < ny; i++) { + float sub0 = x0 - y[0]; + float sub1 = x1 - y[1]; + float distance = sub0 * sub0 + sub1 * sub1; + + y += 2; + dis[i] = distance; + } + } +} + +#elif defined(__AVX2__) template <> void fvec_op_ny_D2( @@ -562,7 +708,137 @@ void fvec_op_ny_D4(float* dis, const float* x, const float* y, size_t ny) { } } -#ifdef __AVX2__ +#if defined(__AVX512F__) + +template <> +void fvec_op_ny_D4( + float* dis, + const float* x, + const float* y, + size_t ny) { + const size_t ny16 = ny / 16; + size_t i = 0; + + if (ny16 > 0) { + // process 16 D4-vectors per loop. + const __m512 m0 = _mm512_set1_ps(x[0]); + const __m512 m1 = _mm512_set1_ps(x[1]); + const __m512 m2 = _mm512_set1_ps(x[2]); + const __m512 m3 = _mm512_set1_ps(x[3]); + + for (i = 0; i < ny16 * 16; i += 16) { + // load 16x4 matrix and transpose it in registers. + // the typical bottleneck is memory access, so + // let's trade instructions for the bandwidth. 
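+ // after the transpose, v0..v3 each hold one component of 16 consecutive
+ // y vectors, so each component updates all 16 dot products with a single instruction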
+ + __m512 v0; + __m512 v1; + __m512 v2; + __m512 v3; + + transpose_16x4( + _mm512_loadu_ps(y + 0 * 16), + _mm512_loadu_ps(y + 1 * 16), + _mm512_loadu_ps(y + 2 * 16), + _mm512_loadu_ps(y + 3 * 16), + v0, + v1, + v2, + v3); + + // compute distances + __m512 distances = _mm512_mul_ps(m0, v0); + distances = _mm512_fmadd_ps(m1, v1, distances); + distances = _mm512_fmadd_ps(m2, v2, distances); + distances = _mm512_fmadd_ps(m3, v3, distances); + + // store + _mm512_storeu_ps(dis + i, distances); + + y += 64; // move to the next set of 16x4 elements + } + } + + if (i < ny) { + // process leftovers + __m128 x0 = _mm_loadu_ps(x); + + for (; i < ny; i++) { + __m128 accu = ElementOpIP::op(x0, _mm_loadu_ps(y)); + y += 4; + dis[i] = horizontal_sum(accu); + } + } +} + +template <> +void fvec_op_ny_D4( + float* dis, + const float* x, + const float* y, + size_t ny) { + const size_t ny16 = ny / 16; + size_t i = 0; + + if (ny16 > 0) { + // process 16 D4-vectors per loop. + const __m512 m0 = _mm512_set1_ps(x[0]); + const __m512 m1 = _mm512_set1_ps(x[1]); + const __m512 m2 = _mm512_set1_ps(x[2]); + const __m512 m3 = _mm512_set1_ps(x[3]); + + for (i = 0; i < ny16 * 16; i += 16) { + // load 16x4 matrix and transpose it in registers. + // the typical bottleneck is memory access, so + // let's trade instructions for the bandwidth. + + __m512 v0; + __m512 v1; + __m512 v2; + __m512 v3; + + transpose_16x4( + _mm512_loadu_ps(y + 0 * 16), + _mm512_loadu_ps(y + 1 * 16), + _mm512_loadu_ps(y + 2 * 16), + _mm512_loadu_ps(y + 3 * 16), + v0, + v1, + v2, + v3); + + // compute differences + const __m512 d0 = _mm512_sub_ps(m0, v0); + const __m512 d1 = _mm512_sub_ps(m1, v1); + const __m512 d2 = _mm512_sub_ps(m2, v2); + const __m512 d3 = _mm512_sub_ps(m3, v3); + + // compute squares of differences + __m512 distances = _mm512_mul_ps(d0, d0); + distances = _mm512_fmadd_ps(d1, d1, distances); + distances = _mm512_fmadd_ps(d2, d2, distances); + distances = _mm512_fmadd_ps(d3, d3, distances); + + // store + _mm512_storeu_ps(dis + i, distances); + + y += 64; // move to the next set of 16x4 elements + } + } + + if (i < ny) { + // process leftovers + __m128 x0 = _mm_loadu_ps(x); + + for (; i < ny; i++) { + __m128 accu = ElementOpL2::op(x0, _mm_loadu_ps(y)); + y += 4; + dis[i] = horizontal_sum(accu); + } + } +} + +#elif defined(__AVX2__) template <> void fvec_op_ny_D4( @@ -710,7 +986,181 @@ void fvec_op_ny_D8(float* dis, const float* x, const float* y, size_t ny) { } } -#ifdef __AVX2__ +#if defined(__AVX512F__) + +template <> +void fvec_op_ny_D8( + float* dis, + const float* x, + const float* y, + size_t ny) { + const size_t ny16 = ny / 16; + size_t i = 0; + + if (ny16 > 0) { + // process 16 D16-vectors per loop. + const __m512 m0 = _mm512_set1_ps(x[0]); + const __m512 m1 = _mm512_set1_ps(x[1]); + const __m512 m2 = _mm512_set1_ps(x[2]); + const __m512 m3 = _mm512_set1_ps(x[3]); + const __m512 m4 = _mm512_set1_ps(x[4]); + const __m512 m5 = _mm512_set1_ps(x[5]); + const __m512 m6 = _mm512_set1_ps(x[6]); + const __m512 m7 = _mm512_set1_ps(x[7]); + + for (i = 0; i < ny16 * 16; i += 16) { + // load 16x8 matrix and transpose it in registers. + // the typical bottleneck is memory access, so + // let's trade instructions for the bandwidth. 
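+ // v0..v7 each hold one of the 8 components for 16 consecutive y vectors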
+ + __m512 v0; + __m512 v1; + __m512 v2; + __m512 v3; + __m512 v4; + __m512 v5; + __m512 v6; + __m512 v7; + + transpose_16x8( + _mm512_loadu_ps(y + 0 * 16), + _mm512_loadu_ps(y + 1 * 16), + _mm512_loadu_ps(y + 2 * 16), + _mm512_loadu_ps(y + 3 * 16), + _mm512_loadu_ps(y + 4 * 16), + _mm512_loadu_ps(y + 5 * 16), + _mm512_loadu_ps(y + 6 * 16), + _mm512_loadu_ps(y + 7 * 16), + v0, + v1, + v2, + v3, + v4, + v5, + v6, + v7); + + // compute distances + __m512 distances = _mm512_mul_ps(m0, v0); + distances = _mm512_fmadd_ps(m1, v1, distances); + distances = _mm512_fmadd_ps(m2, v2, distances); + distances = _mm512_fmadd_ps(m3, v3, distances); + distances = _mm512_fmadd_ps(m4, v4, distances); + distances = _mm512_fmadd_ps(m5, v5, distances); + distances = _mm512_fmadd_ps(m6, v6, distances); + distances = _mm512_fmadd_ps(m7, v7, distances); + + // store + _mm512_storeu_ps(dis + i, distances); + + y += 128; // 16 floats * 8 rows + } + } + + if (i < ny) { + // process leftovers + __m256 x0 = _mm256_loadu_ps(x); + + for (; i < ny; i++) { + __m256 accu = ElementOpIP::op(x0, _mm256_loadu_ps(y)); + y += 8; + dis[i] = horizontal_sum(accu); + } + } +} + +template <> +void fvec_op_ny_D8( + float* dis, + const float* x, + const float* y, + size_t ny) { + const size_t ny16 = ny / 16; + size_t i = 0; + + if (ny16 > 0) { + // process 16 D16-vectors per loop. + const __m512 m0 = _mm512_set1_ps(x[0]); + const __m512 m1 = _mm512_set1_ps(x[1]); + const __m512 m2 = _mm512_set1_ps(x[2]); + const __m512 m3 = _mm512_set1_ps(x[3]); + const __m512 m4 = _mm512_set1_ps(x[4]); + const __m512 m5 = _mm512_set1_ps(x[5]); + const __m512 m6 = _mm512_set1_ps(x[6]); + const __m512 m7 = _mm512_set1_ps(x[7]); + + for (i = 0; i < ny16 * 16; i += 16) { + // load 16x8 matrix and transpose it in registers. + // the typical bottleneck is memory access, so + // let's trade instructions for the bandwidth. 
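+ // per component: subtract, square and accumulate, giving 16 squared
+ // L2 distances per iteration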
+ + __m512 v0; + __m512 v1; + __m512 v2; + __m512 v3; + __m512 v4; + __m512 v5; + __m512 v6; + __m512 v7; + + transpose_16x8( + _mm512_loadu_ps(y + 0 * 16), + _mm512_loadu_ps(y + 1 * 16), + _mm512_loadu_ps(y + 2 * 16), + _mm512_loadu_ps(y + 3 * 16), + _mm512_loadu_ps(y + 4 * 16), + _mm512_loadu_ps(y + 5 * 16), + _mm512_loadu_ps(y + 6 * 16), + _mm512_loadu_ps(y + 7 * 16), + v0, + v1, + v2, + v3, + v4, + v5, + v6, + v7); + + // compute differences + const __m512 d0 = _mm512_sub_ps(m0, v0); + const __m512 d1 = _mm512_sub_ps(m1, v1); + const __m512 d2 = _mm512_sub_ps(m2, v2); + const __m512 d3 = _mm512_sub_ps(m3, v3); + const __m512 d4 = _mm512_sub_ps(m4, v4); + const __m512 d5 = _mm512_sub_ps(m5, v5); + const __m512 d6 = _mm512_sub_ps(m6, v6); + const __m512 d7 = _mm512_sub_ps(m7, v7); + + // compute squares of differences + __m512 distances = _mm512_mul_ps(d0, d0); + distances = _mm512_fmadd_ps(d1, d1, distances); + distances = _mm512_fmadd_ps(d2, d2, distances); + distances = _mm512_fmadd_ps(d3, d3, distances); + distances = _mm512_fmadd_ps(d4, d4, distances); + distances = _mm512_fmadd_ps(d5, d5, distances); + distances = _mm512_fmadd_ps(d6, d6, distances); + distances = _mm512_fmadd_ps(d7, d7, distances); + + // store + _mm512_storeu_ps(dis + i, distances); + + y += 128; // 16 floats * 8 rows + } + } + + if (i < ny) { + // process leftovers + __m256 x0 = _mm256_loadu_ps(x); + + for (; i < ny; i++) { + __m256 accu = ElementOpL2::op(x0, _mm256_loadu_ps(y)); + y += 8; + dis[i] = horizontal_sum(accu); + } + } +} + +#elif defined(__AVX2__) template <> void fvec_op_ny_D8( @@ -955,7 +1405,83 @@ void fvec_inner_products_ny( #undef DISPATCH } -#ifdef __AVX2__ +#if defined(__AVX512F__) + +template +void fvec_L2sqr_ny_y_transposed_D( + float* distances, + const float* x, + const float* y, + const float* y_sqlen, + const size_t d_offset, + size_t ny) { + // current index being processed + size_t i = 0; + + // squared length of x + float x_sqlen = 0; + for (size_t j = 0; j < DIM; j++) { + x_sqlen += x[j] * x[j]; + } + + // process 16 vectors per loop + const size_t ny16 = ny / 16; + + if (ny16 > 0) { + // m[i] = (2 * x[i], ... 2 * x[i]) + __m512 m[DIM]; + for (size_t j = 0; j < DIM; j++) { + m[j] = _mm512_set1_ps(x[j]); + m[j] = _mm512_add_ps(m[j], m[j]); // m[j] = 2 * x[j] + } + + __m512 x_sqlen_ymm = _mm512_set1_ps(x_sqlen); + + for (; i < ny16 * 16; i += 16) { + // Load vectors for 16 dimensions + __m512 v[DIM]; + for (size_t j = 0; j < DIM; j++) { + v[j] = _mm512_loadu_ps(y + j * d_offset); + } + + // Compute dot products + __m512 dp = _mm512_fnmadd_ps(m[0], v[0], x_sqlen_ymm); + for (size_t j = 1; j < DIM; j++) { + dp = _mm512_fnmadd_ps(m[j], v[j], dp); + } + + // Compute y^2 - (2 * x, y) + x^2 + __m512 distances_v = _mm512_add_ps(_mm512_loadu_ps(y_sqlen), dp); + + _mm512_storeu_ps(distances + i, distances_v); + + // Scroll y and y_sqlen forward + y += 16; + y_sqlen += 16; + } + } + + if (i < ny) { + // Process leftovers + for (; i < ny; i++) { + float dp = 0; + for (size_t j = 0; j < DIM; j++) { + dp += x[j] * y[j * d_offset]; + } + + // Compute y^2 - 2 * (x, y), which is sufficient for looking for the + // lowest distance. 
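+ // here x_sqlen is added back so the stored value is the exact squared distance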
+ const float distance = y_sqlen[0] - 2 * dp + x_sqlen; + distances[i] = distance; + + y += 1; + y_sqlen += 1; + } + } +} + +#elif defined(__AVX2__) + template void fvec_L2sqr_ny_y_transposed_D( float* distances, @@ -1031,6 +1557,7 @@ void fvec_L2sqr_ny_y_transposed_D( } } } + #endif void fvec_L2sqr_ny_transposed( @@ -1065,7 +1592,316 @@ void fvec_L2sqr_ny_transposed( #endif } -#ifdef __AVX2__ +#if defined(__AVX512F__) + +size_t fvec_L2sqr_ny_nearest_D2( + float* distances_tmp_buffer, + const float* x, + const float* y, + size_t ny) { + // this implementation does not use distances_tmp_buffer. + + size_t i = 0; + float current_min_distance = HUGE_VALF; + size_t current_min_index = 0; + + const size_t ny16 = ny / 16; + if (ny16 > 0) { + _mm_prefetch((const char*)y, _MM_HINT_T0); + _mm_prefetch((const char*)(y + 32), _MM_HINT_T0); + + __m512 min_distances = _mm512_set1_ps(HUGE_VALF); + __m512i min_indices = _mm512_set1_epi32(0); + + __m512i current_indices = _mm512_setr_epi32( + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15); + const __m512i indices_increment = _mm512_set1_epi32(16); + + const __m512 m0 = _mm512_set1_ps(x[0]); + const __m512 m1 = _mm512_set1_ps(x[1]); + + for (; i < ny16 * 16; i += 16) { + _mm_prefetch((const char*)(y + 64), _MM_HINT_T0); + + __m512 v0; + __m512 v1; + + transpose_16x2( + _mm512_loadu_ps(y + 0 * 16), + _mm512_loadu_ps(y + 1 * 16), + v0, + v1); + + const __m512 d0 = _mm512_sub_ps(m0, v0); + const __m512 d1 = _mm512_sub_ps(m1, v1); + + __m512 distances = _mm512_mul_ps(d0, d0); + distances = _mm512_fmadd_ps(d1, d1, distances); + + __mmask16 comparison = + _mm512_cmp_ps_mask(distances, min_distances, _CMP_LT_OS); + + min_distances = _mm512_min_ps(distances, min_distances); + min_indices = _mm512_mask_blend_epi32( + comparison, min_indices, current_indices); + + current_indices = + _mm512_add_epi32(current_indices, indices_increment); + + y += 32; + } + + alignas(64) float min_distances_scalar[16]; + alignas(64) uint32_t min_indices_scalar[16]; + _mm512_store_ps(min_distances_scalar, min_distances); + _mm512_store_epi32(min_indices_scalar, min_indices); + + for (size_t j = 0; j < 16; j++) { + if (current_min_distance > min_distances_scalar[j]) { + current_min_distance = min_distances_scalar[j]; + current_min_index = min_indices_scalar[j]; + } + } + } + + if (i < ny) { + float x0 = x[0]; + float x1 = x[1]; + + for (; i < ny; i++) { + float sub0 = x0 - y[0]; + float sub1 = x1 - y[1]; + float distance = sub0 * sub0 + sub1 * sub1; + + y += 2; + + if (current_min_distance > distance) { + current_min_distance = distance; + current_min_index = i; + } + } + } + + return current_min_index; +} + +size_t fvec_L2sqr_ny_nearest_D4( + float* distances_tmp_buffer, + const float* x, + const float* y, + size_t ny) { + // this implementation does not use distances_tmp_buffer. 
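+ // the running minimum distance and its index are tracked in registers
+ // instead of materializing all distances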
+ + size_t i = 0; + float current_min_distance = HUGE_VALF; + size_t current_min_index = 0; + + const size_t ny16 = ny / 16; + + if (ny16 > 0) { + __m512 min_distances = _mm512_set1_ps(HUGE_VALF); + __m512i min_indices = _mm512_set1_epi32(0); + + __m512i current_indices = _mm512_setr_epi32( + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15); + const __m512i indices_increment = _mm512_set1_epi32(16); + + const __m512 m0 = _mm512_set1_ps(x[0]); + const __m512 m1 = _mm512_set1_ps(x[1]); + const __m512 m2 = _mm512_set1_ps(x[2]); + const __m512 m3 = _mm512_set1_ps(x[3]); + + for (; i < ny16 * 16; i += 16) { + __m512 v0; + __m512 v1; + __m512 v2; + __m512 v3; + + transpose_16x4( + _mm512_loadu_ps(y + 0 * 16), + _mm512_loadu_ps(y + 1 * 16), + _mm512_loadu_ps(y + 2 * 16), + _mm512_loadu_ps(y + 3 * 16), + v0, + v1, + v2, + v3); + + const __m512 d0 = _mm512_sub_ps(m0, v0); + const __m512 d1 = _mm512_sub_ps(m1, v1); + const __m512 d2 = _mm512_sub_ps(m2, v2); + const __m512 d3 = _mm512_sub_ps(m3, v3); + + __m512 distances = _mm512_mul_ps(d0, d0); + distances = _mm512_fmadd_ps(d1, d1, distances); + distances = _mm512_fmadd_ps(d2, d2, distances); + distances = _mm512_fmadd_ps(d3, d3, distances); + + __mmask16 comparison = + _mm512_cmp_ps_mask(distances, min_distances, _CMP_LT_OS); + + min_distances = _mm512_min_ps(distances, min_distances); + min_indices = _mm512_mask_blend_epi32( + comparison, min_indices, current_indices); + + current_indices = + _mm512_add_epi32(current_indices, indices_increment); + + y += 64; + } + + alignas(64) float min_distances_scalar[16]; + alignas(64) uint32_t min_indices_scalar[16]; + _mm512_store_ps(min_distances_scalar, min_distances); + _mm512_store_epi32(min_indices_scalar, min_indices); + + for (size_t j = 0; j < 16; j++) { + if (current_min_distance > min_distances_scalar[j]) { + current_min_distance = min_distances_scalar[j]; + current_min_index = min_indices_scalar[j]; + } + } + } + + if (i < ny) { + __m128 x0 = _mm_loadu_ps(x); + + for (; i < ny; i++) { + __m128 accu = ElementOpL2::op(x0, _mm_loadu_ps(y)); + y += 4; + const float distance = horizontal_sum(accu); + + if (current_min_distance > distance) { + current_min_distance = distance; + current_min_index = i; + } + } + } + + return current_min_index; +} + +size_t fvec_L2sqr_ny_nearest_D8( + float* distances_tmp_buffer, + const float* x, + const float* y, + size_t ny) { + // this implementation does not use distances_tmp_buffer. 
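+ // 16 per-lane minima are maintained and merged into a single result after the loop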
+ + size_t i = 0; + float current_min_distance = HUGE_VALF; + size_t current_min_index = 0; + + const size_t ny16 = ny / 16; + if (ny16 > 0) { + __m512 min_distances = _mm512_set1_ps(HUGE_VALF); + __m512i min_indices = _mm512_set1_epi32(0); + + __m512i current_indices = _mm512_setr_epi32( + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15); + const __m512i indices_increment = _mm512_set1_epi32(16); + + const __m512 m0 = _mm512_set1_ps(x[0]); + const __m512 m1 = _mm512_set1_ps(x[1]); + const __m512 m2 = _mm512_set1_ps(x[2]); + const __m512 m3 = _mm512_set1_ps(x[3]); + + const __m512 m4 = _mm512_set1_ps(x[4]); + const __m512 m5 = _mm512_set1_ps(x[5]); + const __m512 m6 = _mm512_set1_ps(x[6]); + const __m512 m7 = _mm512_set1_ps(x[7]); + + for (; i < ny16 * 16; i += 16) { + __m512 v0; + __m512 v1; + __m512 v2; + __m512 v3; + __m512 v4; + __m512 v5; + __m512 v6; + __m512 v7; + + transpose_16x8( + _mm512_loadu_ps(y + 0 * 16), + _mm512_loadu_ps(y + 1 * 16), + _mm512_loadu_ps(y + 2 * 16), + _mm512_loadu_ps(y + 3 * 16), + _mm512_loadu_ps(y + 4 * 16), + _mm512_loadu_ps(y + 5 * 16), + _mm512_loadu_ps(y + 6 * 16), + _mm512_loadu_ps(y + 7 * 16), + v0, + v1, + v2, + v3, + v4, + v5, + v6, + v7); + + const __m512 d0 = _mm512_sub_ps(m0, v0); + const __m512 d1 = _mm512_sub_ps(m1, v1); + const __m512 d2 = _mm512_sub_ps(m2, v2); + const __m512 d3 = _mm512_sub_ps(m3, v3); + const __m512 d4 = _mm512_sub_ps(m4, v4); + const __m512 d5 = _mm512_sub_ps(m5, v5); + const __m512 d6 = _mm512_sub_ps(m6, v6); + const __m512 d7 = _mm512_sub_ps(m7, v7); + + __m512 distances = _mm512_mul_ps(d0, d0); + distances = _mm512_fmadd_ps(d1, d1, distances); + distances = _mm512_fmadd_ps(d2, d2, distances); + distances = _mm512_fmadd_ps(d3, d3, distances); + distances = _mm512_fmadd_ps(d4, d4, distances); + distances = _mm512_fmadd_ps(d5, d5, distances); + distances = _mm512_fmadd_ps(d6, d6, distances); + distances = _mm512_fmadd_ps(d7, d7, distances); + + __mmask16 comparison = + _mm512_cmp_ps_mask(distances, min_distances, _CMP_LT_OS); + + min_distances = _mm512_min_ps(distances, min_distances); + min_indices = _mm512_mask_blend_epi32( + comparison, min_indices, current_indices); + + current_indices = + _mm512_add_epi32(current_indices, indices_increment); + + y += 128; + } + + alignas(64) float min_distances_scalar[16]; + alignas(64) uint32_t min_indices_scalar[16]; + _mm512_store_ps(min_distances_scalar, min_distances); + _mm512_store_epi32(min_indices_scalar, min_indices); + + for (size_t j = 0; j < 16; j++) { + if (current_min_distance > min_distances_scalar[j]) { + current_min_distance = min_distances_scalar[j]; + current_min_index = min_indices_scalar[j]; + } + } + } + + if (i < ny) { + __m256 x0 = _mm256_loadu_ps(x); + + for (; i < ny; i++) { + __m256 accu = ElementOpL2::op(x0, _mm256_loadu_ps(y)); + y += 8; + const float distance = horizontal_sum(accu); + + if (current_min_distance > distance) { + current_min_distance = distance; + current_min_index = i; + } + } + } + + return current_min_index; +} + +#elif defined(__AVX2__) size_t fvec_L2sqr_ny_nearest_D2( float* distances_tmp_buffer, @@ -1476,7 +2312,123 @@ size_t fvec_L2sqr_ny_nearest( #undef DISPATCH } -#ifdef __AVX2__ +#if defined(__AVX512F__) + +template +size_t fvec_L2sqr_ny_nearest_y_transposed_D( + float* distances_tmp_buffer, + const float* x, + const float* y, + const float* y_sqlen, + const size_t d_offset, + size_t ny) { + // This implementation does not use distances_tmp_buffer. 
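+ // y is stored transposed (stride d_offset between components), so the 16
+ // candidates of each iteration are loaded with contiguous reads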
+ + // Current index being processed + size_t i = 0; + + // Min distance and the index of the closest vector so far + float current_min_distance = HUGE_VALF; + size_t current_min_index = 0; + + // Process 16 vectors per loop + const size_t ny16 = ny / 16; + + if (ny16 > 0) { + // Track min distance and the closest vector independently + // for each of 16 AVX-512 components. + __m512 min_distances = _mm512_set1_ps(HUGE_VALF); + __m512i min_indices = _mm512_set1_epi32(0); + + __m512i current_indices = _mm512_setr_epi32( + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15); + const __m512i indices_increment = _mm512_set1_epi32(16); + + // m[i] = (2 * x[i], ... 2 * x[i]) + __m512 m[DIM]; + for (size_t j = 0; j < DIM; j++) { + m[j] = _mm512_set1_ps(x[j]); + m[j] = _mm512_add_ps(m[j], m[j]); + } + + for (; i < ny16 * 16; i += 16) { + // Compute dot products + const __m512 v0 = _mm512_loadu_ps(y + 0 * d_offset); + __m512 dp = _mm512_mul_ps(m[0], v0); + for (size_t j = 1; j < DIM; j++) { + const __m512 vj = _mm512_loadu_ps(y + j * d_offset); + dp = _mm512_fmadd_ps(m[j], vj, dp); + } + + // Compute y^2 - (2 * x, y), which is sufficient for looking for the + // lowest distance. + // x^2 is the constant that can be avoided. + const __m512 distances = + _mm512_sub_ps(_mm512_loadu_ps(y_sqlen), dp); + + // Compare the new distances to the min distances + __mmask16 comparison = + _mm512_cmp_ps_mask(min_distances, distances, _CMP_LT_OS); + + // Update min distances and indices with closest vectors if needed + min_distances = + _mm512_mask_blend_ps(comparison, distances, min_distances); + min_indices = _mm512_castps_si512(_mm512_mask_blend_ps( + comparison, + _mm512_castsi512_ps(current_indices), + _mm512_castsi512_ps(min_indices))); + + // Update current indices values. Basically, +16 to each of the 16 + // AVX-512 components. + current_indices = + _mm512_add_epi32(current_indices, indices_increment); + + // Scroll y and y_sqlen forward. + y += 16; + y_sqlen += 16; + } + + // Dump values and find the minimum distance / minimum index + float min_distances_scalar[16]; + uint32_t min_indices_scalar[16]; + _mm512_storeu_ps(min_distances_scalar, min_distances); + _mm512_storeu_si512((__m512i*)(min_indices_scalar), min_indices); + + for (size_t j = 0; j < 16; j++) { + if (current_min_distance > min_distances_scalar[j]) { + current_min_distance = min_distances_scalar[j]; + current_min_index = min_indices_scalar[j]; + } + } + } + + if (i < ny) { + // Process leftovers + for (; i < ny; i++) { + float dp = 0; + for (size_t j = 0; j < DIM; j++) { + dp += x[j] * y[j * d_offset]; + } + + // Compute y^2 - 2 * (x, y), which is sufficient for looking for the + // lowest distance. 
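+ // x^2 is omitted here as well, since it is the same for every candidate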
+ const float distance = y_sqlen[0] - 2 * dp; + + if (current_min_distance > distance) { + current_min_distance = distance; + current_min_index = i; + } + + y += 1; + y_sqlen += 1; + } + } + + return current_min_index; +} + +#elif defined(__AVX2__) + template size_t fvec_L2sqr_ny_nearest_y_transposed_D( float* distances_tmp_buffer, @@ -1592,6 +2544,7 @@ size_t fvec_L2sqr_ny_nearest_y_transposed_D( return current_min_index; } + #endif size_t fvec_L2sqr_ny_nearest_y_transposed( @@ -1858,7 +2811,39 @@ void fvec_inner_products_ny( c[i] = a[i] + bf * b[i]; } -#ifdef __AVX2__ +#if defined(__AVX512F__) + +static inline void fvec_madd_avx512( + const size_t n, + const float* __restrict a, + const float bf, + const float* __restrict b, + float* __restrict c) { + const size_t n16 = n / 16; + const size_t n_for_masking = n % 16; + + const __m512 bfmm = _mm512_set1_ps(bf); + + size_t idx = 0; + for (idx = 0; idx < n16 * 16; idx += 16) { + const __m512 ax = _mm512_loadu_ps(a + idx); + const __m512 bx = _mm512_loadu_ps(b + idx); + const __m512 abmul = _mm512_fmadd_ps(bfmm, bx, ax); + _mm512_storeu_ps(c + idx, abmul); + } + + if (n_for_masking > 0) { + const __mmask16 mask = (1 << n_for_masking) - 1; + + const __m512 ax = _mm512_maskz_loadu_ps(mask, a + idx); + const __m512 bx = _mm512_maskz_loadu_ps(mask, b + idx); + const __m512 abmul = _mm512_fmadd_ps(bfmm, bx, ax); + _mm512_mask_storeu_ps(c + idx, mask, abmul); + } +} + +#elif defined(__AVX2__) + static inline void fvec_madd_avx2( const size_t n, const float* __restrict a, @@ -1911,6 +2896,7 @@ static inline void fvec_madd_avx2( _mm256_maskstore_ps(c + idx, mask, abmul); } } + #endif #ifdef __SSE3__ @@ -1936,7 +2922,9 @@ static inline void fvec_madd_avx2( } void fvec_madd(size_t n, const float* a, float bf, const float* b, float* c) { -#ifdef __AVX2__ +#ifdef __AVX512F__ + fvec_madd_avx512(n, a, bf, b, c); +#elif __AVX2__ fvec_madd_avx2(n, a, bf, b, c); #else if ((n & 3) == 0 && ((((long)a) | ((long)b) | ((long)c)) & 15) == 0) diff --git a/faiss/utils/transpose/transpose-avx512-inl.h b/faiss/utils/transpose/transpose-avx512-inl.h new file mode 100644 index 0000000000..d8c41af91c --- /dev/null +++ b/faiss/utils/transpose/transpose-avx512-inl.h @@ -0,0 +1,176 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +// This file contains transposing kernels for AVX512 for // tiny float/int32 +// matrices, such as 16x2. 
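+// Each kernel transposes a register tile so that every output register holds
+// one dimension of 16 consecutive vectors, which lets the callers in
+// distances_simd.cpp process 16 database vectors per loop iteration.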
+ +#ifdef __AVX512F__ + +#include + +namespace faiss { + +// 16x2 -> 2x16 +inline void transpose_16x2( + const __m512 i0, + const __m512 i1, + __m512& o0, + __m512& o1) { + // assume we have the following input: + // i0: 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 + // i1: 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 + + // 0 1 2 3 8 9 10 11 16 17 18 19 24 25 26 27 + const __m512 r0 = _mm512_shuffle_f32x4(i0, i1, _MM_SHUFFLE(2, 0, 2, 0)); + // 4 5 6 7 12 13 14 15 20 21 22 23 28 29 30 31 + const __m512 r1 = _mm512_shuffle_f32x4(i0, i1, _MM_SHUFFLE(3, 1, 3, 1)); + + // 0 2 4 6 8 10 12 14 16 18 20 22 24 26 28 30 + o0 = _mm512_shuffle_ps(r0, r1, _MM_SHUFFLE(2, 0, 2, 0)); + // 1 3 5 7 9 11 13 15 17 19 21 23 25 27 29 31 + o1 = _mm512_shuffle_ps(r0, r1, _MM_SHUFFLE(3, 1, 3, 1)); +} + +// 16x4 -> 4x16 +inline void transpose_16x4( + const __m512 i0, + const __m512 i1, + const __m512 i2, + const __m512 i3, + __m512& o0, + __m512& o1, + __m512& o2, + __m512& o3) { + // assume that we have the following input: + // i0: 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 + // i1: 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 + // i2: 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 + // i3: 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 + + // 0 1 2 3 8 9 10 11 16 17 18 19 24 25 26 27 + const __m512 r0 = _mm512_shuffle_f32x4(i0, i1, _MM_SHUFFLE(2, 0, 2, 0)); + // 4 5 6 7 12 13 14 15 20 21 22 23 28 29 30 31 + const __m512 r1 = _mm512_shuffle_f32x4(i0, i1, _MM_SHUFFLE(3, 1, 3, 1)); + // 32 33 34 35 40 41 42 43 48 49 50 51 56 57 58 59 + const __m512 r2 = _mm512_shuffle_f32x4(i2, i3, _MM_SHUFFLE(2, 0, 2, 0)); + // 52 53 54 55 60 61 62 63 52 53 54 55 60 61 62 63 + const __m512 r3 = _mm512_shuffle_f32x4(i2, i3, _MM_SHUFFLE(3, 1, 3, 1)); + + // 0 2 4 6 8 10 12 14 16 18 20 22 24 26 28 30 + const __m512 t0 = _mm512_shuffle_ps(r0, r1, _MM_SHUFFLE(2, 0, 2, 0)); + // 1 3 5 7 9 11 13 15 17 19 21 23 25 27 29 31 + const __m512 t1 = _mm512_shuffle_ps(r0, r1, _MM_SHUFFLE(3, 1, 3, 1)); + // 32 34 52 54 40 42 60 62 48 50 52 54 56 58 60 62 + const __m512 t2 = _mm512_shuffle_ps(r2, r3, _MM_SHUFFLE(2, 0, 2, 0)); + // 33 35 53 55 41 43 61 63 49 51 53 55 57 59 61 63 + const __m512 t3 = _mm512_shuffle_ps(r2, r3, _MM_SHUFFLE(3, 1, 3, 1)); + + const __m512i idx0 = _mm512_set_epi32( + 30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0); + const __m512i idx1 = _mm512_set_epi32( + 31, 29, 27, 25, 23, 21, 19, 17, 15, 13, 11, 9, 7, 5, 3, 1); + + // 0 4 8 12 16 20 24 28 32 52 40 60 48 52 56 60 + o0 = _mm512_permutex2var_ps(t0, idx0, t2); + // 1 5 9 13 17 21 25 29 33 53 41 61 49 53 57 61 + o1 = _mm512_permutex2var_ps(t1, idx0, t3); + // 2 6 10 14 18 22 26 30 34 54 42 62 50 54 58 62 + o2 = _mm512_permutex2var_ps(t0, idx1, t2); + // 3 7 11 15 19 23 27 31 35 55 43 63 51 55 59 63 + o3 = _mm512_permutex2var_ps(t1, idx1, t3); +} + +// 16x8 -> 8x16 transpose +inline void transpose_16x8( + const __m512 i0, + const __m512 i1, + const __m512 i2, + const __m512 i3, + const __m512 i4, + const __m512 i5, + const __m512 i6, + const __m512 i7, + __m512& o0, + __m512& o1, + __m512& o2, + __m512& o3, + __m512& o4, + __m512& o5, + __m512& o6, + __m512& o7) { + // assume that we have the following input: + // i0: 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 + // i1: 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 + // i2: 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 + // i3: 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 + // i4: 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 + // i5: 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 + // i6: 96 97 98 99 100 101 102 
103 104 105 106 107 108 109 110 111 + // i7: 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 + + // 0 16 1 17 4 20 5 21 8 24 9 25 12 28 13 29 + const __m512 r0 = _mm512_unpacklo_ps(i0, i1); + // 2 18 3 19 6 22 7 23 10 26 11 27 14 30 15 31 + const __m512 r1 = _mm512_unpackhi_ps(i0, i1); + // 32 48 33 49 36 52 37 53 40 56 41 57 44 60 45 61 + const __m512 r2 = _mm512_unpacklo_ps(i2, i3); + // 34 50 35 51 38 54 39 55 42 58 43 59 46 62 47 63 + const __m512 r3 = _mm512_unpackhi_ps(i2, i3); + // 64 80 65 81 68 84 69 85 72 88 73 89 76 92 77 93 + const __m512 r4 = _mm512_unpacklo_ps(i4, i5); + // 66 82 67 83 70 86 71 87 74 90 75 91 78 94 79 95 + const __m512 r5 = _mm512_unpackhi_ps(i4, i5); + // 96 112 97 113 100 116 101 117 104 120 105 121 108 124 109 125 + const __m512 r6 = _mm512_unpacklo_ps(i6, i7); + // 98 114 99 115 102 118 103 119 106 122 107 123 110 126 111 127 + const __m512 r7 = _mm512_unpackhi_ps(i6, i7); + + // 0 16 32 48 4 20 36 52 8 24 40 56 12 28 44 60 + const __m512 t0 = _mm512_shuffle_ps(r0, r2, _MM_SHUFFLE(1, 0, 1, 0)); + // 1 17 33 49 5 21 37 53 9 25 41 57 13 29 45 61 + const __m512 t1 = _mm512_shuffle_ps(r0, r2, _MM_SHUFFLE(3, 2, 3, 2)); + // 2 18 34 50 6 22 38 54 10 26 42 58 14 30 46 62 + const __m512 t2 = _mm512_shuffle_ps(r1, r3, _MM_SHUFFLE(1, 0, 1, 0)); + // 3 19 35 51 7 23 39 55 11 27 43 59 15 31 47 63 + const __m512 t3 = _mm512_shuffle_ps(r1, r3, _MM_SHUFFLE(3, 2, 3, 2)); + // 64 80 96 112 68 84 100 116 72 88 104 120 76 92 108 124 + const __m512 t4 = _mm512_shuffle_ps(r4, r6, _MM_SHUFFLE(1, 0, 1, 0)); + // 65 81 97 113 69 85 101 117 73 89 105 121 77 93 109 125 + const __m512 t5 = _mm512_shuffle_ps(r4, r6, _MM_SHUFFLE(3, 2, 3, 2)); + // 66 82 98 114 70 86 102 118 74 90 106 122 78 94 110 126 + const __m512 t6 = _mm512_shuffle_ps(r5, r7, _MM_SHUFFLE(1, 0, 1, 0)); + // 67 83 99 115 71 87 103 119 75 91 107 123 79 95 111 127 + const __m512 t7 = _mm512_shuffle_ps(r5, r7, _MM_SHUFFLE(3, 2, 3, 2)); + + const __m512i idx0 = _mm512_set_epi32( + 27, 19, 26, 18, 25, 17, 24, 16, 11, 3, 10, 2, 9, 1, 8, 0); + const __m512i idx1 = _mm512_set_epi32( + 31, 23, 30, 22, 29, 21, 28, 20, 15, 7, 14, 6, 13, 5, 12, 4); + + // 0 8 16 24 32 40 48 56 64 72 80 88 96 104 112 120 + o0 = _mm512_permutex2var_ps(t0, idx0, t4); + // 1 9 17 25 33 41 49 57 65 73 81 89 97 105 113 121 + o1 = _mm512_permutex2var_ps(t1, idx0, t5); + // 2 10 18 26 34 42 50 58 66 74 82 90 98 106 114 122 + o2 = _mm512_permutex2var_ps(t2, idx0, t6); + // 3 11 19 27 35 43 51 59 67 75 83 91 99 107 115 123 + o3 = _mm512_permutex2var_ps(t3, idx0, t7); + // 4 12 20 28 36 44 52 60 68 76 84 92 100 108 116 124 + o4 = _mm512_permutex2var_ps(t0, idx1, t4); + // 5 13 21 29 37 45 53 61 69 77 85 93 101 109 117 125 + o5 = _mm512_permutex2var_ps(t1, idx1, t5); + // 6 14 22 30 38 46 54 62 70 78 86 94 102 110 118 126 + o6 = _mm512_permutex2var_ps(t2, idx1, t6); + // 7 15 23 31 39 47 55 63 71 79 87 95 103 111 119 127 + o7 = _mm512_permutex2var_ps(t3, idx1, t7); +} + +} // namespace faiss + +#endif