Skip to content

Commit bb5b91d

Browse files
authored
Merge pull request #77 from bstellato/fix-clarabel-psd-permutation
Fix Clarabel PSD cone permutation for multiple cones
2 parents 44e4ffe + 09fe09e commit bb5b91d

2 files changed

Lines changed: 349 additions & 1 deletion

File tree

src/diffcp/cone_program.py

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,38 @@ def permute_psd_rows(A: sparse.csc_matrix, b: np.ndarray, n: int, row_offset: in
5050

5151
return new_A, new_b
5252

53+
54+
def inverse_permute_psd_solution(y: np.ndarray, s: np.ndarray, n: int, row_offset: int):
    """
    Inverse permutes y and s vectors from Clarabel (upper triangular)
    back to SCS (lower triangular) convention for a PSD cone.

    Only the n * (n + 1) / 2 entries of the PSD block are reordered; no
    value is rescaled, and entries outside the block are copied unchanged.
    The input arrays are not modified.

    Args:
        y (ndarray): Dual variable vector (Clarabel convention).
        s (ndarray): Slack variable vector (Clarabel convention).
        n (int): Size of the PSD constraint matrix (n x n).
        row_offset (int): Row index where the PSD block starts.

    Returns:
        tuple: (new_y, new_s) permuted back to SCS convention.
    """
    triu_rows, triu_cols = np.triu_indices(n)

    # Entry k of the SCS block is lower-triangle element
    # (triu_cols[k], triu_rows[k]), i.e. upper-triangle element
    # (triu_rows[k], triu_cols[k]). Its column-major key in the upper
    # triangle is col * n + row, so sorting by that key lists the SCS
    # positions in Clarabel order: perm[m] is the SCS position of the
    # m-th Clarabel entry.
    triu_multi_index = np.ravel_multi_index((triu_cols, triu_rows), (n, n))
    preshuffle_from_postshuffle_perm = np.argsort(triu_multi_index)
    n_rows = len(preshuffle_from_postshuffle_perm)

    clarabel_block = slice(row_offset, row_offset + n_rows)
    scs_positions = row_offset + preshuffle_from_postshuffle_perm

    new_y = np.copy(y)
    new_s = np.copy(s)

    # Scatter each Clarabel entry into its SCS position. Gathering with the
    # same index array (new = y[perm]) would apply the forward SCS -> Clarabel
    # reorder again; the two only coincide when the permutation is its own
    # inverse, which holds for n <= 3 but not for n >= 4.
    new_y[scs_positions] = y[clarabel_block]
    new_s[scs_positions] = s[clarabel_block]

    return new_y, new_s
83+
84+
5385
def pi(z, cones):
5486
"""Projection onto R^n x K^* x R_+
5587
@@ -539,7 +571,7 @@ def solve_internal(A, b, c, cone_dict, solve_method=None,
539571
for v in cone_dict["s"]:
540572
cones.append(clarabel.PSDTriangleConeT(v))
541573
A, b = permute_psd_rows(A, b, v, start_row)
542-
start_row += v
574+
start_row += v * (v + 1) // 2 # triangular number for vectorized PSD cone
543575
if "ep" in cone_dict:
544576
v = cone_dict["ep"]
545577
cones += [clarabel.ExponentialConeT()] * v
@@ -558,6 +590,22 @@ def solve_internal(A, b, c, cone_dict, solve_method=None,
558590
result["y"] = np.array(solution.z)
559591
result["s"] = np.array(solution.s)
560592

593+
# Permute y and s back from Clarabel (upper triangular) to SCS (lower triangular) convention
594+
if "s" in cone_dict:
595+
start_row = 0
596+
if "z" in cone_dict and cone_dict["z"] > 0:
597+
start_row += cone_dict["z"]
598+
if "f" in cone_dict and cone_dict["f"] > 0:
599+
start_row += cone_dict["f"]
600+
if "l" in cone_dict and cone_dict["l"] > 0:
601+
start_row += cone_dict["l"]
602+
if "q" in cone_dict:
603+
start_row += sum(cone_dict["q"])
604+
for v in cone_dict["s"]:
605+
result["y"], result["s"] = inverse_permute_psd_solution(
606+
result["y"], result["s"], v, start_row)
607+
start_row += v * (v + 1) // 2 # triangular number
608+
561609
CLARABEL2SCS_STATUS_MAP = {
562610
"Solved": "Solved",
563611
"PrimalInfeasible": "Infeasible",

tests/test_clarabel_psd.py

Lines changed: 300 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,300 @@
1+
"""
2+
Tests for Clarabel PSD cone support in diffcp.
3+
4+
Verifies that solutions and derivatives from Clarabel match SCS
5+
for problems with PSD cones.
6+
7+
Note: When testing dual variables (y), we only check that the solution is correct
8+
(feasible and optimal) rather than exact match, since SCS and Clarabel may
9+
converge to different optimal dual solutions in degenerate cases.
10+
"""
11+
import numpy as np
12+
import scipy.sparse as sparse
13+
import pytest
14+
15+
16+
def scs_data_from_cvxpy_problem(problem):
    """Return ``(A, b, c, cone_dims)`` for a CVXPy problem compiled for SCS.

    The first element returned by ``problem.get_problem_data(cp.SCS)`` supplies
    ``A``, ``b`` and ``c``; its ``"dims"`` entry is converted with CVXPy's
    ``scs_conif.dims_to_solver_dict``.
    """
    import cvxpy as cp

    compiled = problem.get_problem_data(cp.SCS)
    data = compiled[0]
    scs_conif = cp.reductions.solvers.conic_solvers.scs_conif
    cone_dims = scs_conif.dims_to_solver_dict(data["dims"])
    return data["A"], data["b"], data["c"], cone_dims
24+
25+
26+
class TestClarabelPSDPermutation:
    """Tests for Clarabel PSD cone permutation fixes.

    Each test builds a small CVXPy problem, extracts its SCS-format data with
    ``scs_data_from_cvxpy_problem`` and solves that data through
    ``diffcp.solve_and_derivative``.
    """

    @staticmethod
    def _solve(problem_data, solve_method, **kwargs):
        """Solve SCS-format ``(A, b, c, cones)`` data through diffcp.

        Returns whatever ``diffcp.solve_and_derivative`` returns, which the
        tests unpack as ``(x, y, s, D, DT)``.
        """
        import diffcp

        A, b, c, cones = problem_data
        return diffcp.solve_and_derivative(
            sparse.csc_matrix(A), b, c,
            cones,
            solve_method=solve_method,
            verbose=False,
            **kwargs,
        )

    @staticmethod
    def _two_psd_cone_problem():
        """Build the problem with two 3x3 PSD constraints shared by two tests."""
        import cvxpy as cp

        A = np.array([
            [1, 2, 3],
            [2, 4, 5],
            [3, 5, 6],
        ])
        B = np.array([
            [7, 8, 9],
            [8, 10, 11],
            [9, 11, 12],
        ])

        X = cp.Variable((3, 3), symmetric=True)
        y = cp.Variable(2)

        constraints = [y[0] * A + y[1] * B >> 0, X >> 0]
        constraints += [
            cp.trace(A @ X) == 1,
            y >= 0,
        ]

        obj = cp.Minimize(cp.trace(X) + np.ones(2) @ y)
        return cp.Problem(obj, constraints)

    def test_multiple_psd_cones_objective_match(self):
        """Test that SCS and Clarabel give same objective for multiple PSD cones."""
        import cvxpy as cp

        prob = self._two_psd_cone_problem()

        # Get CVXPy solution as reference
        cvxpy_obj = prob.solve(solver=cp.CLARABEL)

        # Get SCS-format data
        problem_data = scs_data_from_cvxpy_problem(prob)
        scs_c = problem_data[2]

        # Solve the same data with both solvers through diffcp
        x_scs, _, s_scs, _, _ = self._solve(problem_data, 'SCS')
        x_cla, _, s_cla, _, _ = self._solve(problem_data, 'CLARABEL')

        obj_scs = scs_c @ x_scs
        obj_cla = scs_c @ x_cla

        # Objectives should match CVXPy
        assert np.isclose(obj_scs, cvxpy_obj, atol=1e-4), \
            f"SCS obj {obj_scs} doesn't match CVXPy obj {cvxpy_obj}"
        assert np.isclose(obj_cla, cvxpy_obj, atol=1e-4), \
            f"Clarabel obj {obj_cla} doesn't match CVXPy obj {cvxpy_obj}"

        # Primal solution x should match between solvers
        assert np.allclose(x_scs, x_cla, atol=1e-4), \
            f"x mismatch: SCS={x_scs}, Clarabel={x_cla}"

        # Slack variable s should match (since s = b - Ax and x matches)
        assert np.allclose(s_scs, s_cla, atol=1e-4), \
            "s mismatch between SCS and Clarabel"

        # Note: We do NOT check y here because dual degeneracy can cause
        # SCS and Clarabel to find different optimal dual solutions.

    def test_single_psd_cone(self):
        """Test that SCS and Clarabel match for a single PSD cone."""
        import cvxpy as cp

        n = 3
        C = np.eye(n)

        X = cp.Variable((n, n), symmetric=True)
        constraints = [X >> 0, cp.trace(X) == 1]
        obj = cp.Minimize(cp.trace(C @ X))
        prob = cp.Problem(obj, constraints)

        cvxpy_obj = prob.solve(solver=cp.CLARABEL)

        problem_data = scs_data_from_cvxpy_problem(prob)
        scs_c = problem_data[2]

        x_scs, _, s_scs, _, _ = self._solve(problem_data, 'SCS')
        x_cla, _, s_cla, _, _ = self._solve(problem_data, 'CLARABEL')

        obj_scs = scs_c @ x_scs
        obj_cla = scs_c @ x_cla

        assert np.isclose(obj_scs, cvxpy_obj, atol=1e-4)
        assert np.isclose(obj_cla, cvxpy_obj, atol=1e-4)
        assert np.allclose(x_scs, x_cla, atol=1e-4)
        assert np.allclose(s_scs, s_cla, atol=1e-4)

    def test_mixed_cones(self):
        """Test problem with zero, nonneg, and PSD cones."""
        import cvxpy as cp

        n = 2
        X = cp.Variable((n, n), symmetric=True)
        t = cp.Variable()

        A = np.array([[1, 0.5], [0.5, 2]])

        constraints = [
            X >> 0,
            t >= 0,
            cp.trace(A @ X) == 1,
            t <= 5,
        ]
        obj = cp.Minimize(cp.trace(X) + t)
        prob = cp.Problem(obj, constraints)

        cvxpy_obj = prob.solve(solver=cp.CLARABEL)

        problem_data = scs_data_from_cvxpy_problem(prob)
        scs_c = problem_data[2]

        x_scs, _, _, _, _ = self._solve(problem_data, 'SCS')
        x_cla, _, _, _, _ = self._solve(problem_data, 'CLARABEL')

        obj_scs = scs_c @ x_scs
        obj_cla = scs_c @ x_cla

        assert np.isclose(obj_scs, cvxpy_obj, atol=1e-4)
        assert np.isclose(obj_cla, cvxpy_obj, atol=1e-4)
        assert np.allclose(x_scs, x_cla, atol=1e-4)

    def test_constraint_satisfaction(self):
        """Test that the Clarabel solution satisfies Ax + s = b.

        Only the equality residual is asserted; cone membership of s is not
        checked here.
        """
        import cvxpy as cp

        prob = self._two_psd_cone_problem()
        prob.solve(solver=cp.CLARABEL)

        problem_data = scs_data_from_cvxpy_problem(prob)
        scs_A, scs_b, _, _ = problem_data

        x_cla, _, s_cla, _, _ = self._solve(problem_data, 'CLARABEL')

        # Check Ax + s = b
        residual = sparse.csc_matrix(scs_A) @ x_cla + s_cla - scs_b
        assert np.allclose(residual, 0, atol=1e-5), \
            f"Constraint residual too large: {np.linalg.norm(residual)}"

    def test_derivative_lsqr_mode(self):
        """Test that adjoint derivatives work with Clarabel in lsqr mode."""
        import cvxpy as cp

        # Simple problem with single PSD cone (non-degenerate)
        n = 2
        C = np.array([[1.0, 0.3], [0.3, 2.0]])

        X = cp.Variable((n, n), symmetric=True)
        constraints = [X >> 0, cp.trace(X) == 1]
        obj = cp.Minimize(cp.trace(C @ X))
        prob = cp.Problem(obj, constraints)
        prob.solve(solver=cp.CLARABEL)

        problem_data = scs_data_from_cvxpy_problem(prob)

        x_cla, y_cla, s_cla, _, DT_cla = self._solve(
            problem_data, 'CLARABEL', mode='lsqr')

        # Test adjoint derivative with random perturbations. A local generator
        # keeps the draw reproducible without reseeding NumPy's global state.
        rng = np.random.default_rng(42)
        dx = rng.standard_normal(x_cla.size) * 0.01
        dy = rng.standard_normal(y_cla.size) * 0.01
        ds = rng.standard_normal(s_cla.size) * 0.01

        # Just verify it runs without error and produces finite results
        dA_cla, db_cla, dc_cla = DT_cla(dx, dy, ds)

        assert np.all(np.isfinite(dc_cla)), "dc contains non-finite values"
        assert np.all(np.isfinite(db_cla)), "db contains non-finite values"
        assert np.all(np.isfinite(dA_cla.data)), "dA contains non-finite values"

    def test_psd_permutation_logic(self):
        """Test Clarabel through diffcp against an SCS reference objective.

        Uses a dense random cost matrix so that every entry of the PSD block
        contributes to the objective.
        """
        import cvxpy as cp

        # Fixed seed: an unseeded draw would make this test nondeterministic.
        rng = np.random.default_rng(0)
        n = 3
        C = rng.standard_normal((n, n))
        C = C @ C.T  # Symmetric Gram matrix with nonzero off-diagonal entries

        X = cp.Variable((n, n), symmetric=True)
        constraints = [X >> 0, cp.trace(X) == 1]
        obj = cp.Minimize(cp.trace(C @ X))
        prob = cp.Problem(obj, constraints)

        cvxpy_obj = prob.solve(solver=cp.SCS)

        problem_data = scs_data_from_cvxpy_problem(prob)
        scs_c = problem_data[2]

        x_cla, _, _, _, _ = self._solve(problem_data, 'CLARABEL')

        # Objectives should match
        obj_cla = scs_c @ x_cla
        assert np.isclose(obj_cla, cvxpy_obj, atol=1e-4), \
            f"Clarabel obj {obj_cla} doesn't match CVXPy obj {cvxpy_obj}"
297+
298+
299+
if __name__ == "__main__":
    # Propagate pytest's exit code so that running this file directly reports
    # test failures to the shell instead of always exiting with status 0.
    raise SystemExit(pytest.main([__file__, "-v"]))

0 commit comments

Comments
 (0)