Skip to content

Commit f0d8005

Browse files
committed
perf: SIMD-optimized string comparison and chunked escape rendering
Motivation: String comparison (compareStringsByCodepoint) and long string rendering are hot paths in sort-heavy and render-heavy Jsonnet workloads. The comparison used per-char charAt() virtual dispatch preventing JIT vectorization. Long string rendering used a binary scan (clean→bulk copy, dirty→full reprocess from position 0). Modification: 1. compareStrings: bulk getChars() + tight array loop enabling JIT auto-vectorization (AVX2/SSE). Surrogate check deferred to mismatch point only (O(1) vs O(n)). ThreadLocal buffers on JVM, local alloc on Native, scalar fallback on JS. 2. findFirstEscapeChar: SWAR scan returning position (not boolean). 3. visitLongString: chunked rendering — find escape position, arraycopy clean prefix, escape inline, repeat. Avoids re-processing entire string when only a few chars need escaping. Result: All tests pass across JVM (Scala 3.3.7, 2.13.18) and JS. All benchmark regressions pass. Endian-safe (SWAR operates on independent byte lanes).
1 parent 4d16e17 commit f0d8005

5 files changed

Lines changed: 399 additions & 44 deletions

File tree

sjsonnet/src-js/sjsonnet/CharSWAR.scala

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,4 +33,42 @@ object CharSWAR {
3333
}
3434
false
3535
}
36+
37+
/** Scalar scan returning position of first escape char, or -1 if none. */
38+
def findFirstEscapeChar(arr: Array[Byte], from: Int, to: Int): Int = {
39+
var i = from
40+
while (i < to) {
41+
val b = arr(i) & 0xff
42+
if (b < 32 || b == '"' || b == '\\') return i
43+
i += 1
44+
}
45+
-1
46+
}
47+
48+
/**
49+
* Compare two strings by Unicode codepoint values. Scalar fallback for Scala.js.
50+
* Uses equal-char-skip fast path with deferred surrogate check.
51+
*/
52+
def compareStrings(s1: String, s2: String): Int = {
53+
if (s1 eq s2) return 0
54+
val n1 = s1.length
55+
val n2 = s2.length
56+
val minLen = math.min(n1, n2)
57+
var i = 0
58+
while (i < minLen) {
59+
val c1 = s1.charAt(i)
60+
val c2 = s2.charAt(i)
61+
if (c1 == c2) {
62+
i += 1
63+
} else if (!Character.isSurrogate(c1) && !Character.isSurrogate(c2)) {
64+
return c1 - c2
65+
} else {
66+
val cp1 = Character.codePointAt(s1, i)
67+
val cp2 = Character.codePointAt(s2, i)
68+
if (cp1 != cp2) return Integer.compare(cp1, cp2)
69+
i += Character.charCount(cp1)
70+
}
71+
}
72+
Integer.compare(n1, n2)
73+
}
3674
}

sjsonnet/src-jvm/sjsonnet/CharSWAR.java

Lines changed: 146 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,15 @@
66
import java.nio.charset.StandardCharsets;
77

88
/**
9-
* SWAR (SIMD Within A Register) escape-char scanner for JSON string rendering.
9+
* SWAR (SIMD Within A Register) utilities for JSON string rendering and string comparison.
1010
*
11-
* <p>Detects characters requiring JSON escaping: control chars ({@code < 32}),
12-
* double-quote ({@code '"'}), and backslash ({@code '\\'}).
11+
* <p>Provides:
12+
* <ul>
13+
* <li>Escape-char scanning: detects/locates chars requiring JSON escaping
14+
* (control chars, double-quote, backslash).</li>
15+
* <li>String comparison: codepoint-correct comparison with array-based inner loop
16+
* that the JIT can auto-vectorize to SIMD instructions.</li>
17+
* </ul>
1318
*
1419
* <p>For strings above a threshold length, converts to ISO-8859-1 bytes and
1520
* processes 8 bytes at a time using {@link VarHandle} bulk reads + Hacker's
@@ -138,4 +143,142 @@ private static boolean hasEscapeCharScalar(String s, int len) {
138143
}
139144
return false;
140145
}
146+
147+
// =========================================================================
148+
// findFirstEscapeChar — position-returning SWAR scan for chunked rendering
149+
// =========================================================================
150+
151+
/**
152+
* Find the index of the first byte in {@code arr[from..to)} that needs JSON
153+
* string escaping. Returns {@code -1} if no escape char is found.
154+
*
155+
* <p>Uses SWAR to scan 8 bytes per iteration, then pinpoints the exact byte
156+
* within a matched 8-byte word via scalar fallback.
157+
*/
158+
static int findFirstEscapeChar(byte[] arr, int from, int to) {
159+
int i = from;
160+
int limit = to - 7;
161+
while (i < limit) {
162+
long word = (long) LONG_VIEW.get(arr, i);
163+
if (swarHasMatch(word)) {
164+
// Pinpoint exact byte within the matched 8-byte word
165+
for (int j = i; j < i + 8; j++) {
166+
int b = arr[j] & 0xFF;
167+
if (b < 32 || b == '"' || b == '\\') return j;
168+
}
169+
}
170+
i += 8;
171+
}
172+
// Tail: remaining 0-7 bytes
173+
while (i < to) {
174+
int b = arr[i] & 0xFF;
175+
if (b < 32 || b == '"' || b == '\\') return i;
176+
i++;
177+
}
178+
return -1;
179+
}
180+
181+
// =========================================================================
182+
// compareStrings — JIT-vectorizable codepoint-correct string comparison
183+
// =========================================================================
184+
185+
/** Reusable char buffers for string comparison (one per thread). */
186+
private static final int CMP_BUF_SIZE = 32768;
187+
private static final ThreadLocal<char[]> CMP_BUF1 =
188+
ThreadLocal.withInitial(() -> new char[CMP_BUF_SIZE]);
189+
private static final ThreadLocal<char[]> CMP_BUF2 =
190+
ThreadLocal.withInitial(() -> new char[CMP_BUF_SIZE]);
191+
192+
/** Below this length, scalar charAt comparison is faster than getChars + array loop. */
193+
private static final int CMP_THRESHOLD = 16;
194+
195+
/**
196+
* Compare two strings by Unicode codepoint values. Equivalent to
197+
* {@code Util.compareStringsByCodepoint} but uses bulk {@code getChars} +
198+
* tight array loop so the JIT can auto-vectorize the comparison to SIMD
199+
* instructions (AVX2/SSE on x86, NEON on ARM).
200+
*
201+
* <p>Surrogate checks are deferred to the mismatch point (O(1) instead of
202+
* O(n)), which is correct because equal chars — even surrogates — can be
203+
* skipped without affecting ordering.
204+
*/
205+
static int compareStrings(String s1, String s2) {
206+
if (s1 == s2) return 0;
207+
int n1 = s1.length(), n2 = s2.length();
208+
int minLen = Math.min(n1, n2);
209+
210+
// Short strings or strings exceeding buffer: scalar path
211+
if (minLen < CMP_THRESHOLD || n1 > CMP_BUF_SIZE || n2 > CMP_BUF_SIZE) {
212+
return compareStringsScalar(s1, n1, s2, n2);
213+
}
214+
215+
// Bulk-copy to char arrays — eliminates String.charAt() virtual dispatch,
216+
// enabling the JIT to auto-vectorize the comparison loop.
217+
char[] c1 = CMP_BUF1.get();
218+
char[] c2 = CMP_BUF2.get();
219+
s1.getChars(0, n1, c1, 0);
220+
s2.getChars(0, n2, c2, 0);
221+
222+
// Tight comparison loop — the simple c1[i] != c2[i] pattern is what
223+
// the C2 JIT compiler recognizes and vectorizes.
224+
int i = 0;
225+
while (i < minLen) {
226+
if (c1[i] != c2[i]) {
227+
char a = c1[i], b = c2[i];
228+
if (!Character.isSurrogate(a) && !Character.isSurrogate(b)) {
229+
return a - b;
230+
}
231+
// Back up if we landed on a low surrogate that's part of a pair
232+
int pos = i;
233+
if (pos > 0 && Character.isLowSurrogate(a) && Character.isHighSurrogate(c1[pos - 1])) {
234+
pos--;
235+
}
236+
return compareCodepointsFrom(c1, n1, c2, n2, pos);
237+
}
238+
i++;
239+
}
240+
return Integer.compare(n1, n2);
241+
}
242+
243+
/**
244+
* Scalar codepoint comparison for short strings or overflow.
245+
* Uses the equal-char-skip fast path (no surrogate check on matching chars).
246+
*/
247+
private static int compareStringsScalar(String s1, int n1, String s2, int n2) {
248+
int minLen = Math.min(n1, n2);
249+
int i = 0;
250+
while (i < minLen) {
251+
char c1 = s1.charAt(i);
252+
char c2 = s2.charAt(i);
253+
if (c1 == c2) {
254+
i++;
255+
} else if (!Character.isSurrogate(c1) && !Character.isSurrogate(c2)) {
256+
return c1 - c2;
257+
} else {
258+
int cp1 = Character.codePointAt(s1, i);
259+
int cp2 = Character.codePointAt(s2, i);
260+
if (cp1 != cp2) return Integer.compare(cp1, cp2);
261+
i += Character.charCount(cp1);
262+
}
263+
}
264+
return Integer.compare(n1, n2);
265+
}
266+
267+
/**
268+
* Codepoint-level comparison from a given position in char arrays.
269+
* Used as fallback when a mismatch involves surrogate chars.
270+
*/
271+
private static int compareCodepointsFrom(char[] c1, int n1, char[] c2, int n2, int from) {
272+
int i1 = from, i2 = from;
273+
while (i1 < n1 && i2 < n2) {
274+
int cp1 = Character.codePointAt(c1, i1);
275+
int cp2 = Character.codePointAt(c2, i2);
276+
if (cp1 != cp2) return Integer.compare(cp1, cp2);
277+
i1 += Character.charCount(cp1);
278+
i2 += Character.charCount(cp2);
279+
}
280+
if (i1 < n1) return 1;
281+
if (i2 < n2) return -1;
282+
return 0;
283+
}
141284
}

sjsonnet/src-native/sjsonnet/CharSWAR.scala

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,4 +108,129 @@ object CharSWAR {
108108
}
109109
false
110110
}
111+
112+
// =========================================================================
113+
// findFirstEscapeChar — position-returning SWAR scan for chunked rendering
114+
// =========================================================================
115+
116+
/**
117+
* Find the index of the first byte in `arr(from until to)` that needs JSON string escaping.
118+
* Returns -1 if no escape char is found. Uses SWAR via Intrinsics.loadLong.
119+
*/
120+
def findFirstEscapeChar(arr: Array[Byte], from: Int, to: Int): Int = {
121+
val len = to - from
122+
if (len < 8) return findFirstEscapeCharScalar(arr, from, to)
123+
val barr = arr.asInstanceOf[ByteArray]
124+
var i = from
125+
val limit = to - 7
126+
while (i < limit) {
127+
val word = Intrinsics.loadLong(barr.atRawUnsafe(i))
128+
if (swarHasMatch(word)) {
129+
// Pinpoint exact byte within the matched 8-byte word
130+
var j = i
131+
while (j < i + 8) {
132+
val b = arr(j) & 0xff
133+
if (b < 32 || b == '"' || b == '\\') return j
134+
j += 1
135+
}
136+
}
137+
i += 8
138+
}
139+
// Tail
140+
while (i < to) {
141+
val b = arr(i) & 0xff
142+
if (b < 32 || b == '"' || b == '\\') return i
143+
i += 1
144+
}
145+
-1
146+
}
147+
148+
@inline private def findFirstEscapeCharScalar(arr: Array[Byte], from: Int, to: Int): Int = {
149+
var i = from
150+
while (i < to) {
151+
val b = arr(i) & 0xff
152+
if (b < 32 || b == '"' || b == '\\') return i
153+
i += 1
154+
}
155+
-1
156+
}
157+
158+
// =========================================================================
159+
// compareStrings — codepoint-correct string comparison
160+
// =========================================================================
161+
162+
/**
163+
* Compare two strings by Unicode codepoint values. Uses bulk getChars + tight array loop for
164+
* LLVM auto-vectorization. Surrogate checks deferred to mismatch point only.
165+
*/
166+
def compareStrings(s1: String, s2: String): Int = {
167+
if (s1 eq s2) return 0
168+
val n1 = s1.length
169+
val n2 = s2.length
170+
val minLen = math.min(n1, n2)
171+
172+
if (minLen < 16) return compareStringsScalar(s1, n1, s2, n2)
173+
174+
// Bulk-copy to arrays — enables LLVM auto-vectorization
175+
val c1 = new Array[Char](n1)
176+
val c2 = new Array[Char](n2)
177+
s1.getChars(0, n1, c1, 0)
178+
s2.getChars(0, n2, c2, 0)
179+
180+
var i = 0
181+
while (i < minLen) {
182+
if (c1(i) != c2(i)) {
183+
val a = c1(i)
184+
val b = c2(i)
185+
if (!Character.isSurrogate(a) && !Character.isSurrogate(b)) {
186+
return a - b
187+
}
188+
var pos = i
189+
if (pos > 0 && Character.isLowSurrogate(a) && Character.isHighSurrogate(c1(pos - 1))) {
190+
pos -= 1
191+
}
192+
return compareCodepointsFrom(c1, n1, c2, n2, pos)
193+
}
194+
i += 1
195+
}
196+
Integer.compare(n1, n2)
197+
}
198+
199+
private def compareStringsScalar(s1: String, n1: Int, s2: String, n2: Int): Int = {
200+
val minLen = math.min(n1, n2)
201+
var i = 0
202+
while (i < minLen) {
203+
val c1 = s1.charAt(i)
204+
val c2 = s2.charAt(i)
205+
if (c1 == c2) {
206+
i += 1
207+
} else if (!Character.isSurrogate(c1) && !Character.isSurrogate(c2)) {
208+
return c1 - c2
209+
} else {
210+
val cp1 = Character.codePointAt(s1, i)
211+
val cp2 = Character.codePointAt(s2, i)
212+
if (cp1 != cp2) return Integer.compare(cp1, cp2)
213+
i += Character.charCount(cp1)
214+
}
215+
}
216+
Integer.compare(n1, n2)
217+
}
218+
219+
private def compareCodepointsFrom(
220+
c1: Array[Char],
221+
n1: Int,
222+
c2: Array[Char],
223+
n2: Int,
224+
from: Int): Int = {
225+
var i1 = from
226+
var i2 = from
227+
while (i1 < n1 && i2 < n2) {
228+
val cp1 = Character.codePointAt(c1, i1)
229+
val cp2 = Character.codePointAt(c2, i2)
230+
if (cp1 != cp2) return Integer.compare(cp1, cp2)
231+
i1 += Character.charCount(cp1)
232+
i2 += Character.charCount(cp2)
233+
}
234+
if (i1 < n1) 1 else if (i2 < n2) -1 else 0
235+
}
111236
}

0 commit comments

Comments
 (0)