@@ -48,7 +48,7 @@ static inline bool block_has_l2_inf_mismatch_16(const uint16_t* a, const uint16_
4848 return false;
4949}
5050
51- /* 16× bf16 -> 16× f32: widen to u32, shift <<16, reinterpret as f32 */
51+ /* 16� bf16 -> 16� f32: widen to u32, shift <<16, reinterpret as f32 */
5252static inline __m512 bf16x16_to_f32x16_loadu (const uint16_t * p ) {
5353 // Load 16x u16 (256 bits)
5454 __m256i v16 = _mm256_loadu_si256 ((const __m256i * )p );
@@ -846,6 +846,72 @@ float int8_distance_cosine_avx512(const void* a, const void* b, int n) {
846846 return 1.0f - cosine_similarity ;
847847}
848848
849+ // MARK: - BIT -
850+
851+ // AVX-512 popcount using lookup table (works on all AVX-512 CPUs)
852+ static inline __m512i popcount_avx512 (__m512i v ) {
853+ // Lookup table for popcount of 4-bit values
854+ const __m512i popcount_lut = _mm512_set_epi8 (
855+ 4 , 3 , 3 , 2 , 3 , 2 , 2 , 1 , 3 , 2 , 2 , 1 , 2 , 1 , 1 , 0 ,
856+ 4 , 3 , 3 , 2 , 3 , 2 , 2 , 1 , 3 , 2 , 2 , 1 , 2 , 1 , 1 , 0 ,
857+ 4 , 3 , 3 , 2 , 3 , 2 , 2 , 1 , 3 , 2 , 2 , 1 , 2 , 1 , 1 , 0 ,
858+ 4 , 3 , 3 , 2 , 3 , 2 , 2 , 1 , 3 , 2 , 2 , 1 , 2 , 1 , 1 , 0
859+ );
860+ const __m512i low_mask = _mm512_set1_epi8 (0x0f );
861+
862+ __m512i lo = _mm512_and_si512 (v , low_mask );
863+ __m512i hi = _mm512_and_si512 (_mm512_srli_epi16 (v , 4 ), low_mask );
864+ __m512i cnt_lo = _mm512_shuffle_epi8 (popcount_lut , lo );
865+ __m512i cnt_hi = _mm512_shuffle_epi8 (popcount_lut , hi );
866+ return _mm512_add_epi8 (cnt_lo , cnt_hi );
867+ }
868+
869+ // Hamming distance for 1-bit packed binary vectors
870+ // n = number of dimensions (bits), not bytes
871+ static float bit1_distance_hamming_avx512 (const void * v1 , const void * v2 , int n ) {
872+ const uint8_t * a = (const uint8_t * )v1 ;
873+ const uint8_t * b = (const uint8_t * )v2 ;
874+ int num_bytes = (n + 7 ) / 8 ;
875+
876+ __m512i acc = _mm512_setzero_si512 ();
877+ int i = 0 ;
878+
879+ // Process 64 bytes at a time
880+ for (; i + 64 <= num_bytes ; i += 64 ) {
881+ __m512i va = _mm512_loadu_si512 ((const __m512i * )(a + i ));
882+ __m512i vb = _mm512_loadu_si512 ((const __m512i * )(b + i ));
883+ __m512i xored = _mm512_xor_si512 (va , vb );
884+
885+ #if defined(__AVX512VPOPCNTDQ__ )
886+ // Native popcount (Ice Lake+)
887+ __m512i popcnt = _mm512_popcnt_epi64 (xored );
888+ acc = _mm512_add_epi64 (acc , popcnt );
889+ #else
890+ // Lookup table popcount (Skylake-X compatible)
891+ __m512i popcnt = popcount_avx512 (xored );
892+ // Sum bytes to 64-bit using SAD against zero
893+ acc = _mm512_add_epi64 (acc , _mm512_sad_epu8 (popcnt , _mm512_setzero_si512 ()));
894+ #endif
895+ }
896+
897+ // Horizontal sum
898+ uint64_t distance = _mm512_reduce_add_epi64 (acc );
899+
900+ // Handle remaining bytes with scalar code
901+ for (; i < num_bytes ; i ++ ) {
902+ #if defined(__GNUC__ ) || defined(__clang__ )
903+ distance += __builtin_popcount (a [i ] ^ b [i ]);
904+ #else
905+ uint8_t x = a [i ] ^ b [i ];
906+ x = x - ((x >> 1 ) & 0x55 );
907+ x = (x & 0x33 ) + ((x >> 2 ) & 0x33 );
908+ distance += (x + (x >> 4 )) & 0x0f ;
909+ #endif
910+ }
911+
912+ return (float )distance ;
913+ }
914+
849915#endif
850916
851917// MARK: -
@@ -882,6 +948,8 @@ void init_distance_functions_avx512(void) {
882948 dispatch_distance_table [VECTOR_DISTANCE_L1 ][VECTOR_TYPE_U8 ] = uint8_distance_l1_avx512 ;
883949 dispatch_distance_table [VECTOR_DISTANCE_L1 ][VECTOR_TYPE_I8 ] = int8_distance_l1_avx512 ;
884950
951+ dispatch_distance_table [VECTOR_DISTANCE_HAMMING ][VECTOR_TYPE_BIT ] = bit1_distance_hamming_avx512 ;
952+
885953 distance_backend_name = "AVX512" ;
886954#endif
887- }
955+ }
0 commit comments