Skip to content

Commit 2181358

Browse files
committed
fix bug with Ps=None
1 parent bb5b91d commit 2181358

4 files changed

Lines changed: 90 additions & 4 deletions

File tree

.github/workflows/build.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ jobs:
5151
uses: pypa/cibuildwheel@v2.22.0
5252
env:
5353
CIBW_SKIP: "pp* *-win32 *_i686"
54+
CIBW_MANYLINUX_X86_64_IMAGE: manylinux_2_28
5455

5556
- uses: actions/upload-artifact@v4
5657
with:

src/diffcp/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
__version__ = "1.1.5"
1+
__version__ = "1.1.6"
22

33
from diffcp.cone_program import solve_and_derivative, \
44
solve_and_derivative_batch, \

src/diffcp/cone_program.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -197,11 +197,11 @@ def DT_batch(dxs, dys, dss, return_dP=False, **kwargs):
197197
return dAs, dbs, dcs
198198
else:
199199

200-
def D_batch(dAs, dbs, dcs, dPs=None, **kwargs):
200+
def D_batch(dAs, dbs, dcs, **kwargs):
201201
pool = ThreadPool(processes=n_jobs_backward)
202202

203203
def Di(i):
204-
return Ds[i](dAs[i], dbs[i], dcs[i], dPs[i], **kwargs)
204+
return Ds[i](dAs[i], dbs[i], dcs[i], **kwargs)
205205
results = pool.map(Di, range(batch_size))
206206
pool.close()
207207
dxs = [r[0] for r in results]
@@ -256,6 +256,8 @@ def solve_only_batch(As, bs, cs, cone_dicts, n_jobs_forward=-1,
256256
batch_size = len(As)
257257
if warm_starts is None:
258258
warm_starts = [None] * batch_size
259+
if Ps is None:
260+
Ps = [None] * batch_size
259261
if n_jobs_forward == -1:
260262
n_jobs_forward = mp.cpu_count()
261263
n_jobs_forward = min(batch_size, n_jobs_forward)
@@ -265,7 +267,7 @@ def solve_only_batch(As, bs, cs, cone_dicts, n_jobs_forward=-1,
265267
xs, ys, ss = [], [], []
266268
for i in range(batch_size):
267269
x, y, s = solve_only(As[i], bs[i], cs[i], cone_dicts[i],
268-
warm_starts[i], Ps[i], **kwargs)
270+
warm_start=warm_starts[i], P=Ps[i], **kwargs)
269271
xs += [x]
270272
ys += [y]
271273
ss += [s]

tests/test_clarabel.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,3 +138,86 @@ def test_psdcone():
138138

139139
assert np.abs(np.trace(sol) - 1.0) < 1e-6
140140
assert (np.linalg.eigvals(sol) >= -1e-6).all()
141+
142+
143+
def test_solve_only_batch():
144+
"""Test solve_only_batch with Ps=None (default)."""
145+
np.random.seed(0)
146+
m = 20
147+
n = 10
148+
batch_size = 5
149+
150+
As, bs, cs, cone_dicts = [], [], [], []
151+
for _ in range(batch_size):
152+
A, b, c, cone_dims = utils.least_squares_eq_scs_data(m, n)
153+
As.append(A)
154+
bs.append(b)
155+
cs.append(c)
156+
cone_dicts.append(cone_dims)
157+
158+
# Test serial path (n_jobs_forward=1) with Ps=None
159+
xs, ys, ss = cone_prog.solve_only_batch(
160+
As, bs, cs, cone_dicts, n_jobs_forward=1, solve_method='Clarabel')
161+
assert len(xs) == batch_size
162+
assert len(ys) == batch_size
163+
assert len(ss) == batch_size
164+
165+
# Verify solutions satisfy optimality conditions
166+
for i in range(batch_size):
167+
np.testing.assert_allclose(As[i] @ xs[i] + ss[i], bs[i], atol=1e-7)
168+
169+
# Test parallel path (n_jobs_forward=-1) with Ps=None
170+
xs_par, ys_par, ss_par = cone_prog.solve_only_batch(
171+
As, bs, cs, cone_dicts, n_jobs_forward=-1, solve_method='Clarabel')
172+
assert len(xs_par) == batch_size
173+
174+
# Verify parallel solutions also satisfy optimality conditions
175+
for i in range(batch_size):
176+
np.testing.assert_allclose(As[i] @ xs_par[i] + ss_par[i], bs[i], atol=1e-7)
177+
178+
179+
def test_derivative_batch_parallel():
180+
"""Test that parallel D_batch works correctly."""
181+
np.random.seed(0)
182+
m = 20
183+
n = 10
184+
batch_size = 5
185+
186+
As, bs, cs, cone_dicts = [], [], [], []
187+
for _ in range(batch_size):
188+
A, b, c, cone_dims = utils.least_squares_eq_scs_data(m, n)
189+
As.append(A)
190+
bs.append(b)
191+
cs.append(c)
192+
cone_dicts.append(cone_dims)
193+
194+
# Solve with serial backward pass
195+
xs_ser, ys_ser, ss_ser, D_ser, DT_ser = cone_prog.solve_and_derivative_batch(
196+
As, bs, cs, cone_dicts, n_jobs_forward=1, n_jobs_backward=1, solve_method='Clarabel')
197+
198+
# Solve with parallel backward pass
199+
xs_par, ys_par, ss_par, D_par, DT_par = cone_prog.solve_and_derivative_batch(
200+
As, bs, cs, cone_dicts, n_jobs_forward=-1, n_jobs_backward=-1, solve_method='Clarabel')
201+
202+
# Create perturbations
203+
dAs = [utils.get_random_like(A, lambda n: np.random.normal(0, 1e-6, size=n)) for A in As]
204+
dbs = [np.random.normal(0, 1e-6, size=b.size) for b in bs]
205+
dcs = [np.random.normal(0, 1e-6, size=c.size) for c in cs]
206+
207+
# Test D_batch (forward derivative)
208+
dxs_ser, dys_ser, dss_ser = D_ser(dAs, dbs, dcs)
209+
dxs_par, dys_par, dss_par = D_par(dAs, dbs, dcs)
210+
211+
for i in range(batch_size):
212+
np.testing.assert_allclose(dxs_ser[i], dxs_par[i], rtol=1e-5, atol=1e-10)
213+
np.testing.assert_allclose(dys_ser[i], dys_par[i], rtol=1e-5, atol=1e-10)
214+
np.testing.assert_allclose(dss_ser[i], dss_par[i], rtol=1e-5, atol=1e-10)
215+
216+
# Test DT_batch (adjoint derivative)
217+
dAs_ser, dbs_ser, dcs_ser = DT_ser(xs_ser, ys_ser, ss_ser)
218+
dAs_par, dbs_par, dcs_par = DT_par(xs_par, ys_par, ss_par)
219+
220+
for i in range(batch_size):
221+
np.testing.assert_allclose(dAs_ser[i].todense(), dAs_par[i].todense(), rtol=1e-5, atol=1e-10)
222+
np.testing.assert_allclose(dbs_ser[i], dbs_par[i], rtol=1e-5, atol=1e-10)
223+
np.testing.assert_allclose(dcs_ser[i], dcs_par[i], rtol=1e-5, atol=1e-10)

0 commit comments

Comments
 (0)