Skip to content

Commit 68d2c99

Browse files
committed
test fanoval importance
1 parent fa9f750 commit 68d2c99

1 file changed

Lines changed: 60 additions & 0 deletions

File tree

tests/core/custom_test_function/test_custom_test_function.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -350,3 +350,63 @@ def test_repr_contains_info(self, sphere_func):
350350
assert "CustomTestFunction" in repr_str
351351
assert "n_dim=2" in repr_str
352352
assert "n_evaluations=1" in repr_str
353+
354+
355+
@pytest.fixture
356+
def asymmetric_func_with_data():
357+
"""Function where x dominates: f = x^2 + 0.01*y. 100 evaluations."""
358+
func = CustomTestFunction(
359+
objective_fn=lambda p: p["x"] ** 2 + 0.01 * p["y"],
360+
search_space={"x": (-5, 5), "y": (-5, 5)},
361+
)
362+
rng = np.random.default_rng(42)
363+
for _ in range(100):
364+
func({"x": float(rng.uniform(-5, 5)), "y": float(rng.uniform(-5, 5))})
365+
return func
366+
367+
368+
class TestFanovaImportance:
369+
"""Test fANOVA-based parameter importance analysis."""
370+
371+
def test_returns_correct_structure(self, sphere_func_with_data):
372+
"""Returns dict with all param names, values >= 0, summing to 1."""
373+
importance = sphere_func_with_data.analysis.parameter_importance(method="fanova")
374+
assert set(importance.keys()) == {"x", "y"}
375+
assert all(v >= 0 for v in importance.values())
376+
assert sum(importance.values()) == pytest.approx(1.0, abs=1e-10)
377+
378+
def test_identifies_dominant_parameter(self, asymmetric_func_with_data):
379+
"""For f = x^2 + 0.01*y, fANOVA should assign x > 80% importance."""
380+
importance = asymmetric_func_with_data.analysis.parameter_importance(method="fanova")
381+
assert importance["x"] > 0.8
382+
383+
def test_symmetric_function_equal_importance(self):
384+
"""For f = x^2 + y^2, both parameters should get roughly equal weight."""
385+
func = CustomTestFunction(
386+
objective_fn=lambda p: p["x"] ** 2 + p["y"] ** 2,
387+
search_space={"x": (-5, 5), "y": (-5, 5)},
388+
)
389+
rng = np.random.default_rng(42)
390+
for _ in range(100):
391+
func({"x": float(rng.uniform(-5, 5)), "y": float(rng.uniform(-5, 5))})
392+
393+
importance = func.analysis.parameter_importance(method="fanova")
394+
assert abs(importance["x"] - importance["y"]) < 0.25
395+
396+
def test_requires_minimum_30_evaluations(self):
397+
"""fANOVA should raise ValueError with fewer than 30 data points."""
398+
func = CustomTestFunction(
399+
objective_fn=lambda p: p["x"] ** 2,
400+
search_space={"x": (-5, 5)},
401+
)
402+
for i in range(20):
403+
func({"x": float(i) - 10.0})
404+
405+
with pytest.raises(ValueError, match="requires at least 30"):
406+
func.analysis.parameter_importance(method="fanova")
407+
408+
def test_outperforms_variance_on_nonlinear(self, asymmetric_func_with_data):
409+
"""fANOVA should capture x^2 better than linear correlation does."""
410+
var_imp = asymmetric_func_with_data.analysis.parameter_importance(method="variance")
411+
fanova_imp = asymmetric_func_with_data.analysis.parameter_importance(method="fanova")
412+
assert fanova_imp["x"] > var_imp["x"]

0 commit comments

Comments
 (0)