Skip to content

Commit f14e7fe

Browse files
committed
feat: add get_sampleset method to GaussianProcess and update bindings
1 parent 1d94292 commit f14e7fe

5 files changed

Lines changed: 78 additions & 14 deletions

File tree

examples/PythonTutorial.ipynb

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -210,24 +210,18 @@
210210
},
211211
{
212212
"cell_type": "code",
213-
"execution_count": null,
214-
"id": "0783f49b",
213+
"execution_count": 9,
214+
"id": "926ea540",
215215
"metadata": {},
216216
"outputs": [],
217-
"source": []
218-
},
219-
{
220-
"cell_type": "code",
221-
"execution_count": null,
222-
"id": "99e7a214",
223-
"metadata": {},
224-
"outputs": [],
225-
"source": []
217+
"source": [
218+
"model_json = gp.to_json()"
219+
]
226220
},
227221
{
228222
"cell_type": "code",
229223
"execution_count": null,
230-
"id": "650d9295",
224+
"id": "cdb01041",
231225
"metadata": {},
232226
"outputs": [],
233227
"source": []

include/gp.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,8 @@ namespace libgp {
8686

8787
/** Clear sample set and free memory. */
8888
void clear_sampleset();
89+
90+
Eigen::MatrixXd get_sampleset();
8991

9092
/** Get reference on currently used covariance function. */
9193
CovarianceFunction & covf();

libgp/gaussian_process.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,15 @@ class GaussianProcess(libgp_cpp.GaussianProcess):
1010
This class provides methods for training and predicting with a Gaussian Process model.
1111
"""
1212

13+
def __init__(self, input_dim: int, covariance_function: str) -> None:
14+
"""Initialize the Gaussian Process model.
15+
16+
Parameters:
17+
- input_dim: Number of input dimensions.
18+
- covariance_function: Covariance function to use (e.g., 'RBF', 'Matern').
19+
"""
20+
super().__init__(input_dim, covariance_function)
21+
1322
def add_pattern(self, x: np.array, y: float) -> None:
1423
"""Add a single training pattern to the Gaussian Process model.
1524
@@ -87,6 +96,15 @@ def clear_sampleset(self) -> None:
8796
"""Clear the training set."""
8897
super().clear_sampleset()
8998

99+
def get_sampleset(self) -> tuple:
100+
"""Get the training set.
101+
102+
Returns:
103+
- A tuple containing the input features and target values.
104+
"""
105+
data = super().get_sampleset()
106+
return data[:, :-1], data[:, -1]
107+
90108
def get_log_likelihood(self) -> float:
91109
"""Get the log likelihood of the current model.
92110
@@ -134,3 +152,41 @@ def get_param_dim(self) -> int:
134152
- The number of hyperparameters as an integer.
135153
"""
136154
return super().get_param_dim()
155+
156+
def to_json(self) -> dict:
157+
"""Convert the Gaussian Process model to a JSON-compatible dictionary.
158+
159+
Returns:
160+
- A dictionary representation of the model.
161+
"""
162+
x, y = self.get_sampleset()
163+
return {
164+
"type": "GaussianProcess",
165+
"covariance_function": self.get_covariance_function(),
166+
"loghyper": self.get_loghyper().tolist(),
167+
"input_dim": self.get_input_dim(),
168+
"sampleset_size": self.get_sampleset_size(),
169+
"sampleset_x": x.tolist(),
170+
"sampleset_y": y.tolist()
171+
}
172+
173+
@classmethod
174+
def from_json(cls, json_data: dict) -> "GaussianProcess":
175+
"""Create a Gaussian Process model from a JSON-compatible dictionary.
176+
177+
Parameters:
178+
- json_data: A dictionary containing the model parameters.
179+
180+
Returns:
181+
- An instance of the GaussianProcess class.
182+
"""
183+
input_dim = json_data["input_dim"]
184+
covariance_function = json_data["covariance_function"]
185+
gp = cls(input_dim, covariance_function)
186+
gp.set_loghyper(np.array(json_data["loghyper"]))
187+
gp.add_patterns(np.array(json_data["sampleset_x"]), np.array(json_data["sampleset_y"]))
188+
return gp
189+
190+
def __repr__(self) -> str:
191+
"""Return a string representation of the Gaussian Process model."""
192+
return f"GaussianProcess(input_dim={self.get_input_dim()}, covariance_function='{self.get_covariance_function()}')"

src/bindings.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,7 @@ PYBIND11_MODULE(libgp_cpp, m) {
1212
m.doc() = "Python bindings for libgp - Gaussian Process Regression Library";
1313

1414
py::class_<libgp::GaussianProcess>(m, "GaussianProcess")
15-
.def(py::init<>())
1615
.def(py::init<size_t, const std::string&>())
17-
.def(py::init<const char*>())
1816
.def("add_pattern", [](libgp::GaussianProcess& self, py::array_t<double> x, double y) {
1917
py::buffer_info buf = x.request();
2018
if (buf.ndim != 1)
@@ -46,6 +44,7 @@ PYBIND11_MODULE(libgp_cpp, m) {
4644
.def("set_y", &libgp::GaussianProcess::set_y)
4745
.def("get_sampleset_size", &libgp::GaussianProcess::get_sampleset_size)
4846
.def("clear_sampleset", &libgp::GaussianProcess::clear_sampleset)
47+
.def("get_sampleset", &libgp::GaussianProcess::get_sampleset)
4948
.def("get_log_likelihood", &libgp::GaussianProcess::log_likelihood)
5049
.def("get_log_likelihood_gradient", &libgp::GaussianProcess::log_likelihood_gradient)
5150
.def("get_input_dim", &libgp::GaussianProcess::get_input_dim)
@@ -65,6 +64,9 @@ PYBIND11_MODULE(libgp_cpp, m) {
6564
})
6665
.def("get_param_dim", [](libgp::GaussianProcess& self) {
6766
return self.covf().get_param_dim();
67+
})
68+
.def("get_covariance_function", [](libgp::GaussianProcess& self) {
69+
return self.covf().to_string();
6870
});
6971

7072
py::class_<libgp::CovFactory>(m, "CovFactory")

src/gp.cc

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -265,6 +265,16 @@ namespace libgp {
265265
sampleset->clear();
266266
}
267267

268+
Eigen::MatrixXd GaussianProcess::get_sampleset()
269+
{
270+
Eigen::MatrixXd samples(sampleset->size(), input_dim + 1);
271+
for (size_t i=0; i<sampleset->size(); ++i) {
272+
samples.row(i).head(input_dim) = sampleset->x(i);
273+
samples(i, input_dim) = sampleset->y(i);
274+
}
275+
return samples;
276+
}
277+
268278
void GaussianProcess::write(const char * filename)
269279
{
270280
// output

0 commit comments

Comments
 (0)