Model Building with BoFire

This notebook shows how to set up and analyze models trained with BoFire. It is still work in progress (WIP).

Imports

import bofire.surrogates.api as surrogates
from bofire.data_models.domain.api import Inputs, Outputs
from bofire.data_models.enum import RegressionMetricsEnum
from bofire.data_models.features.api import ContinuousInput, ContinuousOutput
from bofire.data_models.surrogates.api import SingleTaskGPSurrogate
from bofire.plot.feature_importance import plot_feature_importance_by_feature_plotly
from bofire.surrogates.feature_importance import (
    combine_lengthscale_importances,
    combine_permutation_importances,
    lengthscale_importance_hook,
    permutation_importance_hook,
)

Problem Setup

For didactic purposes, we sample data from the Himmelblau benchmark function and use the samples to train a SingleTaskGP. Note that only x_1 and x_2 enter the Himmelblau function; the third input x_3 is an uninformative dummy feature, which the feature importance analysis below should flag as unimportant.

# TODO: replace this after JDs PR is ready.
# Three continuous inputs; the Himmelblau function below only depends on
# x_1 and x_2, so x_3 acts as an uninformative dummy feature.
input_features = Inputs(
    features=[ContinuousInput(key=f"x_{i+1}", bounds=(-4, 4)) for i in range(3)],
)
output_features = Outputs(features=[ContinuousOutput(key="y")])
experiments = input_features.sample(n=50)
experiments.eval("y=((x_1**2 + x_2 - 11)**2 + (x_1 + x_2**2 - 7)**2)", inplace=True)
experiments["valid_y"] = 1

Cross Validation

Run the cross validation

data_model = SingleTaskGPSurrogate(
    inputs=input_features,
    outputs=output_features,
)

model = surrogates.map(data_model=data_model)
# 5-fold cross validation; the hooks compute per-fold feature importances,
# once permutation based and once via the GP kernel lengthscales.
train_cv, test_cv, pi = model.cross_validate(
    experiments,
    folds=5,
    hooks={
        "permutation_importance": permutation_importance_hook,
        "lengthscale_importance": lengthscale_importance_hook,
    },
)
# Aggregate the per-fold importances into mean/std summaries per metric.
combined_importances = {
    m.name: combine_permutation_importances(pi["permutation_importance"], m).describe()
    for m in RegressionMetricsEnum
}
combined_importances["lengthscale"] = combine_lengthscale_importances(
    pi["lengthscale_importance"],
).describe()
plot_feature_importance_by_feature_plotly(
    combined_importances,
    relative=False,
    caption="Permutation Feature Importances",
    show_std=True,
    importance_measure="Permutation Feature Importance",
)
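
After cross validation, the surrogate can be refitted on the full dataset and used for prediction at new input combinations. A minimal sketch, assuming the mapped surrogate exposes `fit` and `predict` as in other BoFire surrogate examples, and that the prediction DataFrame contains mean/uncertainty columns such as `y_pred` and `y_sd`:

# Sketch: refit on all experiments, then predict at a few unseen candidates.
model.fit(experiments)

candidates = input_features.sample(n=5)
predictions = model.predict(candidates)  # assumed columns, e.g. y_pred, y_sd
display(predictions)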

Analyze the cross validation

Plots will be added in a future PR.

# Performance on test sets
test_cv.get_metrics(combine_folds=True)
MAE MSD R2 MAPE PEARSON SPEARMAN FISHER
0 7.673787 164.666878 0.960158 0.147664 0.980352 0.977815 7.169177e-10
display(test_cv.get_metrics(combine_folds=False))
display(test_cv.get_metrics(combine_folds=False).describe())
MAE MSD R2 MAPE PEARSON SPEARMAN FISHER
0 5.298843 62.012141 0.987763 0.230372 0.995865 0.951515 0.003968
1 6.778411 113.683326 0.978833 0.181690 0.989596 0.975758 0.103175
2 11.769986 436.793806 0.919216 0.146917 0.972384 1.000000 0.003968
3 7.578324 81.631827 0.904415 0.084559 0.966079 0.842424 0.103175
4 6.943373 129.213291 0.966394 0.094783 0.988850 1.000000 0.003968
MAE MSD R2 MAPE PEARSON SPEARMAN FISHER
count 5.000000 5.000000 5.000000 5.000000 5.000000 5.000000 5.000000
mean 7.673787 164.666878 0.951324 0.147664 0.982555 0.953939 0.043651
std 2.437392 154.387628 0.037226 0.060781 0.012662 0.065499 0.054338
min 5.298843 62.012141 0.904415 0.084559 0.966079 0.842424 0.003968
25% 6.778411 81.631827 0.919216 0.094783 0.972384 0.951515 0.003968
50% 6.943373 113.683326 0.966394 0.146917 0.988850 0.975758 0.003968
75% 7.578324 129.213291 0.978833 0.181690 0.989596 1.000000 0.103175
max 11.769986 436.793806 0.987763 0.230372 0.995865 1.000000 0.103175
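
The same metrics are available for the training folds via `train_cv`. Comparing them with the test metrics above gives a rough indication of overfitting: training metrics that are much better than the test metrics suggest the GP is fitting noise rather than the underlying function.

# Performance on the training folds, for comparison with the test metrics above.
display(train_cv.get_metrics(combine_folds=False))
display(train_cv.get_metrics(combine_folds=False).describe())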