Skip to content

Commit 6abfc23

Browse files
committed
Added support for binary (1bit) vectors and SIMD-Optimized Hamming distance
1 parent bc1cae5 commit 6abfc23

9 files changed

Lines changed: 492 additions & 69 deletions

src/distance-avx2.c

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -949,6 +949,47 @@ float int8_distance_cosine_avx2 (const void *a, const void *b, int n) {
949949
return 1.0f - cosine_similarity;
950950
}
951951

952+
// MARK: - BIT -
953+
954+
// lookup table for popcount of 4-bit values
955+
static const __m256i popcount_lut = _mm256_setr_epi8(0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3, 3, 4, 0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3, 3, 4);
956+
957+
static inline __m256i popcount_avx2(__m256i v) {
958+
__m256i low_mask = _mm256_set1_epi8(0x0f);
959+
__m256i lo = _mm256_and_si256(v, low_mask);
960+
__m256i hi = _mm256_and_si256(_mm256_srli_epi16(v, 4), low_mask);
961+
__m256i cnt_lo = _mm256_shuffle_epi8(popcount_lut, lo);
962+
__m256i cnt_hi = _mm256_shuffle_epi8(popcount_lut, hi);
963+
return _mm256_add_epi8(cnt_lo, cnt_hi);
964+
}
965+
966+
float bit1_distance_hamming_avx2 (const void *v1, const void *v2, int n) {
967+
const uint8_t *a = (const uint8_t *)v1;
968+
const uint8_t *b = (const uint8_t *)v2;
969+
__m256i acc = _mm256_setzero_si256();
970+
int i = 0;
971+
972+
// Process 32 bytes at a time
973+
for (; i + 32 <= n; i += 32) {
974+
__m256i va = _mm256_loadu_si256((const __m256i *)(a + i));
975+
__m256i vb = _mm256_loadu_si256((const __m256i *)(b + i));
976+
__m256i xored = _mm256_xor_si256(va, vb);
977+
__m256i popcnt = popcount_avx2(xored);
978+
acc = _mm256_add_epi64(acc, _mm256_sad_epu8(popcnt, _mm256_setzero_si256()));
979+
}
980+
981+
// Horizontal sum
982+
__m128i sum128 = _mm_add_epi64(_mm256_extracti128_si256(acc, 0), _mm256_extracti128_si256(acc, 1));
983+
int distance = _mm_extract_epi64(sum128, 0) + _mm_extract_epi64(sum128, 1);
984+
985+
// Handle remainder with scalar
986+
for (; i < n; i++) {
987+
distance += __builtin_popcount(a[i] ^ b[i]);
988+
}
989+
990+
return (float)distance;
991+
}
992+
952993
#endif
953994

954995
// MARK: -
@@ -985,6 +1026,8 @@ void init_distance_functions_avx2 (void) {
9851026
dispatch_distance_table[VECTOR_DISTANCE_L1][VECTOR_TYPE_U8] = uint8_distance_l1_avx2;
9861027
dispatch_distance_table[VECTOR_DISTANCE_L1][VECTOR_TYPE_I8] = int8_distance_l1_avx2;
9871028

1029+
dispatch_distance_table[VECTOR_DISTANCE_HAMMING][VECTOR_TYPE_BIT] = bit1_distance_hamming_avx2;
1030+
9881031
distance_backend_name = "AVX2";
9891032
#endif
9901033
}

src/distance-avx512.c

Lines changed: 70 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ static inline bool block_has_l2_inf_mismatch_16(const uint16_t* a, const uint16_
4848
return false;
4949
}
5050

51-
/* 16×bf16 -> 16×f32: widen to u32, shift <<16, reinterpret as f32 */
51+
/* 16bf16 -> 16f32: widen to u32, shift <<16, reinterpret as f32 */
5252
static inline __m512 bf16x16_to_f32x16_loadu(const uint16_t* p) {
5353
// Load 16x u16 (256 bits)
5454
__m256i v16 = _mm256_loadu_si256((const __m256i*)p);
@@ -846,6 +846,72 @@ float int8_distance_cosine_avx512(const void* a, const void* b, int n) {
846846
return 1.0f - cosine_similarity;
847847
}
848848

849+
// MARK: - BIT -
850+
851+
// AVX-512 popcount using lookup table (works on all AVX-512 CPUs)
852+
static inline __m512i popcount_avx512(__m512i v) {
853+
// Lookup table for popcount of 4-bit values
854+
const __m512i popcount_lut = _mm512_set_epi8(
855+
4, 3, 3, 2, 3, 2, 2, 1, 3, 2, 2, 1, 2, 1, 1, 0,
856+
4, 3, 3, 2, 3, 2, 2, 1, 3, 2, 2, 1, 2, 1, 1, 0,
857+
4, 3, 3, 2, 3, 2, 2, 1, 3, 2, 2, 1, 2, 1, 1, 0,
858+
4, 3, 3, 2, 3, 2, 2, 1, 3, 2, 2, 1, 2, 1, 1, 0
859+
);
860+
const __m512i low_mask = _mm512_set1_epi8(0x0f);
861+
862+
__m512i lo = _mm512_and_si512(v, low_mask);
863+
__m512i hi = _mm512_and_si512(_mm512_srli_epi16(v, 4), low_mask);
864+
__m512i cnt_lo = _mm512_shuffle_epi8(popcount_lut, lo);
865+
__m512i cnt_hi = _mm512_shuffle_epi8(popcount_lut, hi);
866+
return _mm512_add_epi8(cnt_lo, cnt_hi);
867+
}
868+
869+
// Hamming distance for 1-bit packed binary vectors
870+
// n = number of dimensions (bits), not bytes
871+
static float bit1_distance_hamming_avx512(const void *v1, const void *v2, int n) {
872+
const uint8_t *a = (const uint8_t *)v1;
873+
const uint8_t *b = (const uint8_t *)v2;
874+
int num_bytes = (n + 7) / 8;
875+
876+
__m512i acc = _mm512_setzero_si512();
877+
int i = 0;
878+
879+
// Process 64 bytes at a time
880+
for (; i + 64 <= num_bytes; i += 64) {
881+
__m512i va = _mm512_loadu_si512((const __m512i *)(a + i));
882+
__m512i vb = _mm512_loadu_si512((const __m512i *)(b + i));
883+
__m512i xored = _mm512_xor_si512(va, vb);
884+
885+
#if defined(__AVX512VPOPCNTDQ__)
886+
// Native popcount (Ice Lake+)
887+
__m512i popcnt = _mm512_popcnt_epi64(xored);
888+
acc = _mm512_add_epi64(acc, popcnt);
889+
#else
890+
// Lookup table popcount (Skylake-X compatible)
891+
__m512i popcnt = popcount_avx512(xored);
892+
// Sum bytes to 64-bit using SAD against zero
893+
acc = _mm512_add_epi64(acc, _mm512_sad_epu8(popcnt, _mm512_setzero_si512()));
894+
#endif
895+
}
896+
897+
// Horizontal sum
898+
uint64_t distance = _mm512_reduce_add_epi64(acc);
899+
900+
// Handle remaining bytes with scalar code
901+
for (; i < num_bytes; i++) {
902+
#if defined(__GNUC__) || defined(__clang__)
903+
distance += __builtin_popcount(a[i] ^ b[i]);
904+
#else
905+
uint8_t x = a[i] ^ b[i];
906+
x = x - ((x >> 1) & 0x55);
907+
x = (x & 0x33) + ((x >> 2) & 0x33);
908+
distance += (x + (x >> 4)) & 0x0f;
909+
#endif
910+
}
911+
912+
return (float)distance;
913+
}
914+
849915
#endif
850916

851917
// MARK: -
@@ -882,6 +948,8 @@ void init_distance_functions_avx512(void) {
882948
dispatch_distance_table[VECTOR_DISTANCE_L1][VECTOR_TYPE_U8] = uint8_distance_l1_avx512;
883949
dispatch_distance_table[VECTOR_DISTANCE_L1][VECTOR_TYPE_I8] = int8_distance_l1_avx512;
884950

951+
dispatch_distance_table[VECTOR_DISTANCE_HAMMING][VECTOR_TYPE_BIT] = bit1_distance_hamming_avx512;
952+
885953
distance_backend_name = "AVX512";
886954
#endif
887-
}
955+
}

src/distance-avx512.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,6 @@
1010

1111
#include <stdio.h>
1212

13-
void init_distance_functions_avx512(void);
13+
void init_distance_functions_avx512 (void);
1414

1515
#endif

src/distance-cpu.c

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -693,6 +693,42 @@ float int8_distance_l1_cpu (const void *v1, const void *v2, int n) {
693693
return sum;
694694
}
695695

696+
// MARK: - BIT -
697+
698+
static inline int popcount64(uint64_t x) {
699+
#if defined(__GNUC__) || defined(__clang__)
700+
return __builtin_popcountll(x);
701+
#else
702+
// fallback: bit manipulation
703+
x = x - ((x >> 1) & 0x5555555555555555ULL);
704+
x = (x & 0x3333333333333333ULL) + ((x >> 2) & 0x3333333333333333ULL);
705+
x = (x + (x >> 4)) & 0x0f0f0f0f0f0f0f0fULL;
706+
return (x * 0x0101010101010101ULL) >> 56;
707+
#endif
708+
}
709+
710+
float bit1_distance_hamming_cpu (const void *v1, const void *v2, int n) {
711+
const uint8_t *a = (const uint8_t *)v1;
712+
const uint8_t *b = (const uint8_t *)v2;
713+
714+
int distance = 0;
715+
int i = 0;
716+
717+
// process 8 bytes at a time
718+
for (; i + 8 <= n; i += 8) {
719+
uint64_t xa = *(const uint64_t *)(a + i);
720+
uint64_t xb = *(const uint64_t *)(b + i);
721+
distance += popcount64(xa ^ xb);
722+
}
723+
724+
// handle remainder
725+
for (; i < n; i++) {
726+
distance += popcount64(a[i] ^ b[i]);
727+
}
728+
729+
return (float)distance;
730+
}
731+
696732
// MARK: - ENTRYPOINT -
697733

698734
#if defined(__x86_64__) || defined(_M_X64) || defined(__i386__) || defined(_M_IX86)
@@ -845,6 +881,9 @@ void init_cpu_functions (void) {
845881
[VECTOR_TYPE_BF16] = bfloat16_distance_l1_cpu,
846882
[VECTOR_TYPE_U8] = uint8_distance_l1_cpu,
847883
[VECTOR_TYPE_I8] = int8_distance_l1_cpu,
884+
},
885+
[VECTOR_DISTANCE_HAMMING] = {
886+
[VECTOR_TYPE_BIT] = bit1_distance_hamming_cpu
848887
}
849888
};
850889

src/distance-cpu.h

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,14 +38,16 @@ typedef enum {
3838
VECTOR_TYPE_F16,
3939
VECTOR_TYPE_BF16,
4040
VECTOR_TYPE_U8,
41-
VECTOR_TYPE_I8
41+
VECTOR_TYPE_I8,
42+
VECTOR_TYPE_BIT
4243
} vector_type;
43-
#define VECTOR_TYPE_MAX 6
44+
#define VECTOR_TYPE_MAX 7
4445

4546
typedef enum {
4647
VECTOR_QUANT_AUTO = 0,
4748
VECTOR_QUANT_U8BIT = 1,
48-
VECTOR_QUANT_S8BIT = 2
49+
VECTOR_QUANT_S8BIT = 2,
50+
VECTOR_QUANT_1BIT = 3
4951
} vector_qtype;
5052

5153
typedef enum {
@@ -54,8 +56,9 @@ typedef enum {
5456
VECTOR_DISTANCE_COSINE,
5557
VECTOR_DISTANCE_DOT,
5658
VECTOR_DISTANCE_L1,
59+
VECTOR_DISTANCE_HAMMING
5760
} vector_distance;
58-
#define VECTOR_DISTANCE_MAX 6
61+
#define VECTOR_DISTANCE_MAX 7
5962

6063
typedef float (*distance_function_t)(const void *v1, const void *v2, int n);
6164

src/distance-neon.c

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1230,6 +1230,38 @@ float int8_distance_l1_neon(const void *v1, const void *v2, int n) {
12301230

12311231
return (float)final;
12321232
}
1233+
1234+
// MARK: - BIT -
1235+
1236+
float bit1_distance_hamming_neon (const void *v1, const void *v2, int n) {
1237+
const uint8_t *a = (const uint8_t *)v1;
1238+
const uint8_t *b = (const uint8_t *)v2;
1239+
uint64x2_t acc = vdupq_n_u64(0);
1240+
int i = 0;
1241+
1242+
// Process 16 bytes at a time
1243+
for (; i + 16 <= n; i += 16) {
1244+
uint8x16_t va = vld1q_u8(a + i);
1245+
uint8x16_t vb = vld1q_u8(b + i);
1246+
uint8x16_t xored = veorq_u8(va, vb);
1247+
1248+
// vcntq_u8: popcount per byte
1249+
uint8x16_t popcnt = vcntq_u8(xored);
1250+
1251+
// Sum bytes to 64-bit accumulators
1252+
acc = vpadalq_u32(acc, vpaddlq_u16(vpaddlq_u8(popcnt)));
1253+
}
1254+
1255+
int distance = (int)(vgetq_lane_u64(acc, 0) + vgetq_lane_u64(acc, 1));
1256+
1257+
// Handle remainder
1258+
for (; i < n; i++) {
1259+
distance += __builtin_popcount(a[i] ^ b[i]);
1260+
}
1261+
1262+
return (float)distance;
1263+
}
1264+
12331265
#endif
12341266

12351267
// MARK: -
@@ -1266,6 +1298,8 @@ void init_distance_functions_neon (void) {
12661298
dispatch_distance_table[VECTOR_DISTANCE_L1][VECTOR_TYPE_U8] = uint8_distance_l1_neon;
12671299
dispatch_distance_table[VECTOR_DISTANCE_L1][VECTOR_TYPE_I8] = int8_distance_l1_neon;
12681300

1301+
dispatch_distance_table[VECTOR_DISTANCE_HAMMING][VECTOR_TYPE_BIT] = bit1_distance_hamming_neon;
1302+
12691303
distance_backend_name = "NEON";
12701304
#endif
12711305
}

src/distance-sse2.c

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1006,6 +1006,67 @@ float int8_distance_cosine_sse2 (const void *v1, const void *v2, int n) {
10061006
return 1.0f - cosine_sim;
10071007
}
10081008

1009+
// MARK: - BIT -
1010+
1011+
static inline __m128i popcount_sse2 (__m128i v) {
1012+
// Classic parallel bit count algorithm vectorized for SSE2
1013+
1014+
const __m128i mask1 = _mm_set1_epi8(0x55); // 01010101
1015+
const __m128i mask2 = _mm_set1_epi8(0x33); // 00110011
1016+
const __m128i mask4 = _mm_set1_epi8(0x0f); // 00001111
1017+
1018+
// x = x - ((x >> 1) & 0x55555555)
1019+
__m128i t = _mm_and_si128(_mm_srli_epi16(v, 1), mask1);
1020+
v = _mm_sub_epi8(v, t);
1021+
1022+
// x = (x & 0x33333333) + ((x >> 2) & 0x33333333)
1023+
t = _mm_and_si128(_mm_srli_epi16(v, 2), mask2);
1024+
v = _mm_add_epi8(_mm_and_si128(v, mask2), t);
1025+
1026+
// x = (x + (x >> 4)) & 0x0f0f0f0f
1027+
t = _mm_srli_epi16(v, 4);
1028+
v = _mm_and_si128(_mm_add_epi8(v, t), mask4);
1029+
1030+
// Now each byte contains popcount for that byte (0-8)
1031+
return v;
1032+
}
1033+
1034+
float bit1_distance_hamming_sse2 (const void *v1, const void *v2, int n) {
1035+
const uint8_t *a = (const uint8_t *)v1;
1036+
const uint8_t *b = (const uint8_t *)v2;
1037+
__m128i acc = _mm_setzero_si128();
1038+
int i = 0;
1039+
1040+
// Process 16 bytes at a time
1041+
for (; i + 16 <= n; i += 16) {
1042+
__m128i va = _mm_loadu_si128((const __m128i *)(a + i));
1043+
__m128i vb = _mm_loadu_si128((const __m128i *)(b + i));
1044+
__m128i xored = _mm_xor_si128(va, vb);
1045+
__m128i popcnt = popcount_sse2(xored);
1046+
1047+
// Sum bytes using SAD (sum of absolute differences against zero)
1048+
// This sums all 16 bytes into two 64-bit values
1049+
acc = _mm_add_epi64(acc, _mm_sad_epu8(popcnt, _mm_setzero_si128()));
1050+
}
1051+
1052+
// Horizontal sum of the two 64-bit accumulators
1053+
int distance = _mm_cvtsi128_si64(acc) + _mm_cvtsi128_si64(_mm_srli_si128(acc, 8));
1054+
1055+
// Handle remainder with scalar code
1056+
for (; i < n; i++) {
1057+
#if defined(__GNUC__) || defined(__clang__)
1058+
distance += __builtin_popcount(a[i] ^ b[i]);
1059+
#else
1060+
uint8_t x = a[i] ^ b[i];
1061+
x = x - ((x >> 1) & 0x55);
1062+
x = (x & 0x33) + ((x >> 2) & 0x33);
1063+
distance += (x + (x >> 4)) & 0x0f;
1064+
#endif
1065+
}
1066+
1067+
return (float)distance;
1068+
}
1069+
10091070
#endif
10101071

10111072
// MARK: -
@@ -1042,6 +1103,8 @@ void init_distance_functions_sse2 (void) {
10421103
dispatch_distance_table[VECTOR_DISTANCE_L1][VECTOR_TYPE_U8] = uint8_distance_l1_sse2;
10431104
dispatch_distance_table[VECTOR_DISTANCE_L1][VECTOR_TYPE_I8] = int8_distance_l1_sse2;
10441105

1106+
dispatch_distance_table[VECTOR_DISTANCE_HAMMING][VECTOR_TYPE_BIT] = bit1_distance_hamming_sse2;
1107+
10451108
distance_backend_name = "SSE2";
10461109
#endif
10471110
}

0 commit comments

Comments
 (0)