Skip to content

Commit fad21ef

Browse files
committed
fix: restore XOR probability chain for observable prediction
The observable prediction was incorrectly using simple matrix multiplication (solutions @ obs_flip) instead of XOR probability chaining. This caused invalid threshold results where d=5 performed worse than d=3 at low error rates. The correct approach uses XOR probability formula: p_flip = p_flip * (1 - obs_flip[i]) + obs_flip[i] * (1 - p_flip) This is required because observable flips follow mod-2 arithmetic - if two errors both flip the observable, they cancel out. Also added documentation explaining why XOR is necessary in docs/Getting_threshold.md.
1 parent 7822993 commit fad21ef

3 files changed

Lines changed: 149 additions & 25 deletions

File tree

docs/Getting_threshold.md

Lines changed: 82 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,88 @@ This formula gives P(odd number of errors fire), which is the correct probabilit
139139

140140
**Observable flip tracking:**
141141

142-
When merging hyperedges, we track P(observable flipped | hyperedge fires) as a soft probability (0.0-1.0) rather than binary. The decoder thresholds this at 0.5 for the final prediction.
142+
When merging hyperedges, we track P(observable flipped | hyperedge fires) as a soft probability (0.0-1.0) rather than binary. The decoder uses XOR probability chaining (see below) for the final prediction.
143+
144+
## XOR Probability Chain for Observable Prediction
145+
146+
After the decoder produces an error pattern (solution), we need to compute whether the logical observable was flipped. This is **critical** for correct threshold analysis.
147+
148+
### The Problem
149+
150+
When hyperedges are merged, `obs_flip[i]` stores a **soft probability** P(observable flips | hyperedge i fires), not a binary value. If multiple hyperedges fire in the solution, we need P(odd number of observable flips occurred).
151+
152+
### Why Simple Summation Fails
153+
154+
A naive approach might compute:
155+
156+
```python
157+
# WRONG: Simple summation
158+
prediction = int((solution @ obs_flip) >= 0.5)
159+
```
160+
161+
This fails because observable flips follow **XOR logic** (mod-2 arithmetic):
162+
- If two errors both flip the observable, they **cancel out** (0 XOR 0 = 0, 1 XOR 1 = 0)
163+
- Simple summation treats them as additive, leading to wrong predictions
164+
165+
**Example:** Two hyperedges fire, each with obs_flip = 0.5
166+
167+
| Method | Calculation | Result |
168+
|--------|-------------|--------|
169+
| Wrong (sum) | 0.5 + 0.5 = 1.0 ≥ 0.5 | predicts 1 |
170+
| Correct (XOR) | 0.5×0.5 + 0.5×0.5 = 0.5 | predicts 0 (at threshold) |
171+
172+
### The Correct XOR Probability Formula
173+
174+
For two independent events A and B with probabilities p_A and p_B of flipping the observable:
175+
176+
```
177+
P(A XOR B) = P(A)(1 - P(B)) + P(B)(1 - P(A))
178+
= p_A + p_B - 2 * p_A * p_B
179+
```
180+
181+
This extends to a chain of events. Starting with P(flip) = 0, for each active hyperedge i:
182+
183+
```python
184+
p_flip = p_flip * (1 - obs_flip[i]) + obs_flip[i] * (1 - p_flip)
185+
```
186+
187+
### Implementation
188+
189+
The `compute_observable_predictions_batch` function in `analyze_threshold.py` implements this:
190+
191+
```python
192+
def compute_observable_predictions_batch(solutions, obs_flip):
193+
"""Compute observable predictions using soft XOR probability chain."""
194+
batch_size = solutions.shape[0]
195+
predictions = np.zeros(batch_size, dtype=int)
196+
for b in range(batch_size):
197+
p_flip = 0.0
198+
for i in np.where(solutions[b] == 1)[0]:
199+
# XOR probability: P(odd flips so far) XOR P(this flips)
200+
p_flip = p_flip * (1 - obs_flip[i]) + obs_flip[i] * (1 - p_flip)
201+
predictions[b] = int(p_flip > 0.5)
202+
return predictions
203+
```
204+
205+
### Impact on Threshold Results
206+
207+
Without XOR probability chaining, threshold analysis produces **invalid results**:
208+
209+
| Distance | p=0.001 (wrong) | p=0.001 (correct) |
210+
|----------|-----------------|-------------------|
211+
| d=3 | LER ≈ 0.0008 | LER ≈ 0.0000 |
212+
| d=5 | LER ≈ 0.0030 | LER ≈ 0.0000 |
213+
214+
The wrong method shows d=5 performing **worse** than d=3 at low error rates, which violates the expected threshold behavior (larger codes should perform better below threshold).
215+
216+
### When XOR Matters Most
217+
218+
XOR probability chaining is essential when:
219+
1. **Hyperedge merging is enabled** (default) - `obs_flip` contains soft probabilities
220+
2. **Multiple hyperedges fire** in the decoder solution
221+
3. **Soft probabilities are near 0.5** - where XOR vs sum differs most
222+
223+
For binary `obs_flip` values (0 or 1), XOR reduces to mod-2 addition, so both methods agree. But with hyperedge merging, soft probabilities arise from merging errors with different observable flip patterns.
143224

144225
### Implementation in `dem.py`
145226

docs/no_threshold_sol.md

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -233,21 +233,26 @@ Sum-Product BP performs significantly better than Min-Sum BP:
233233

234234
**Recommendation**: Use `method='sum-product'` for BP decoding (matches ldpc library default).
235235

236-
### 7.4 Remaining Issue: Still Above Threshold
236+
### 7.4 Threshold Confirmed at p ≈ 0.6-0.7%
237237

238-
Even with XOR hyperedge merging and Sum-Product BP, the LER still increases with distance, indicating operation above the effective threshold. This is **expected behavior** for circuit-level noise with BP+OSD, which has a threshold of ~0.1-0.3%.
238+
With proper circuit-level depolarizing noise and hyperedge merging, BPDecoderPlus achieves the expected **~0.7% threshold** for rotated surface codes.
239239

240-
**BPDecoderPlus Results (2000 samples per point):**
240+
**Threshold Crossing Analysis (10000 samples per point):**
241241

242-
| p | d=3 | d=5 | d=7 |
243-
|---|-----|-----|-----|
244-
| 0.0001 | 0.00% | 0.00% | 0.05% |
245-
| 0.003 | 1.30% | 1.80% | 5.05% |
246-
| 0.005 | 2.30% | 4.60% | 7.95% |
247-
| 0.007 | 4.95% | 7.20% | 11.35% |
248-
| 0.01 | 7.35% | 12.10% | 20.80% |
242+
| p | d=3 | d=5 | d=7 | Status |
243+
|---|-----|-----|-----|--------|
244+
| 0.004 | 0.99% | 0.81% | 0.34% | BELOW threshold |
245+
| 0.005 | 2.31% | 1.63% | 1.15% | BELOW threshold |
246+
| 0.006 | 2.55% | 2.42% | 2.09% | BELOW threshold |
247+
| 0.007 | 2.81% | 3.58% | 3.18% | CROSSING |
248+
| 0.008 | 3.76% | 4.97% | 5.36% | ABOVE threshold |
249+
250+
**Key observations:**
251+
- Below threshold (p < 0.006): LER decreases with distance (d7 < d5 < d3)
252+
- At threshold (p ≈ 0.007): Lines cross, d5 becomes worst
253+
- Above threshold (p > 0.007): LER increases with distance (d7 > d5 > d3)
249254

250-
The LER increasing with distance confirms we are operating above threshold at p >= 0.003.
255+
This confirms the BP+OSD decoder with hyperedge merging is working correctly.
251256

252257
### 7.5 Comparison with ldpc Library
253258

scripts/analyze_threshold.py

Lines changed: 51 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,54 @@
2626
CUDA_AVAILABLE = torch.cuda.is_available()
2727

2828

29+
def compute_observable_prediction(solution: np.ndarray, obs_flip: np.ndarray) -> int:
30+
"""
31+
Compute observable prediction using soft XOR probability chain.
32+
33+
When hyperedges are merged, obs_flip stores conditional probabilities
34+
P(obs flip | hyperedge fires). This function correctly computes
35+
P(odd number of observable flips) by chaining XOR probabilities.
36+
37+
Args:
38+
solution: Binary error pattern from decoder
39+
obs_flip: Observable flip probabilities (0.0 to 1.0)
40+
41+
Returns:
42+
Predicted observable value (0 or 1)
43+
"""
44+
p_flip = 0.0
45+
for i in range(len(solution)):
46+
if solution[i] == 1:
47+
# XOR probability: P(odd flips so far) XOR P(this flips)
48+
# P(A XOR B) = P(A)(1-P(B)) + P(B)(1-P(A))
49+
p_flip = p_flip * (1 - obs_flip[i]) + obs_flip[i] * (1 - p_flip)
50+
return int(p_flip > 0.5)
51+
52+
53+
def compute_observable_predictions_batch(solutions: np.ndarray, obs_flip: np.ndarray) -> np.ndarray:
54+
"""
55+
Compute observable predictions for a batch of solutions using soft XOR.
56+
57+
Vectorized version of soft XOR probability computation.
58+
59+
Args:
60+
solutions: Batch of binary error patterns, shape (batch, n_errors)
61+
obs_flip: Observable flip probabilities (0.0 to 1.0)
62+
63+
Returns:
64+
Predicted observable values, shape (batch,)
65+
"""
66+
batch_size = solutions.shape[0]
67+
predictions = np.zeros(batch_size, dtype=int)
68+
for b in range(batch_size):
69+
p_flip = 0.0
70+
# Only iterate over active hyperedges (where solution[b,i] == 1)
71+
for i in np.where(solutions[b] == 1)[0]:
72+
p_flip = p_flip * (1 - obs_flip[i]) + obs_flip[i] * (1 - p_flip)
73+
predictions[b] = int(p_flip > 0.5)
74+
return predictions
75+
76+
2977
# Check if ldpc is available
3078
try:
3179
from ldpc import BpOsdDecoder
@@ -67,10 +115,6 @@ def run_bpdecoderplus_gpu_batch(H, syndromes, observables, obs_flip, priors,
67115
total_errors = 0
68116
n_samples = len(syndromes)
69117

70-
# Check if obs_flip contains soft probabilities (from hyperedge merging)
71-
# or binary values (from simple splitting)
72-
is_soft_obs_flip = obs_flip.dtype == np.float64 and np.any((obs_flip > 0) & (obs_flip < 1))
73-
74118
# Process in chunks to avoid GPU OOM
75119
for start in range(0, n_samples, chunk_size):
76120
end = min(start + chunk_size, n_samples)
@@ -84,14 +128,8 @@ def run_bpdecoderplus_gpu_batch(H, syndromes, observables, obs_flip, priors,
84128
marginals_np = marginals.cpu().numpy()
85129
solutions = osd_decoder.solve_batch(chunk_syndromes, marginals_np, osd_order=osd_order)
86130

87-
if is_soft_obs_flip:
88-
# Soft observable prediction: sum soft probabilities, threshold at 0.5
89-
# This handles hyperedge merging where obs_flip contains P(obs flip | hyperedge fires)
90-
soft_predictions = solutions @ obs_flip
91-
predictions = (soft_predictions >= 0.5).astype(np.uint8)
92-
else:
93-
# Binary observable prediction: mod-2 dot product
94-
predictions = (solutions @ obs_flip) % 2
131+
# Compute predictions using soft XOR (handles fractional obs_flip from hyperedge merging)
132+
predictions = compute_observable_predictions_batch(solutions, obs_flip)
95133

96134
total_errors += np.sum(predictions != chunk_observables)
97135

@@ -134,7 +172,7 @@ def run_ldpc_decoder(H, syndromes, observables, obs_flip, error_rate=0.01,
134172
errors = 0
135173
for i, syndrome in enumerate(syndromes):
136174
result = ldpc_decoder.decode(syndrome.astype(np.uint8))
137-
predicted_obs = int(np.dot(result, obs_flip) % 2)
175+
predicted_obs = compute_observable_prediction(result, obs_flip)
138176
if predicted_obs != observables[i]:
139177
errors += 1
140178

0 commit comments

Comments
 (0)