Model Building with BoFire

This notebook shows how to set up and analyze models trained with BoFire. It is still a work in progress (WIP).

Imports

import bofire.surrogates.api as surrogates
from bofire.data_models.domain.api import Inputs, Outputs
from bofire.data_models.enum import RegressionMetricsEnum
from bofire.data_models.features.api import ContinuousInput, ContinuousOutput
from bofire.data_models.surrogates.api import SingleTaskGPSurrogate
from bofire.plot.feature_importance import plot_feature_importance_by_feature_plotly
from bofire.surrogates.feature_importance import (
    combine_lengthscale_importances,
    combine_permutation_importances,
    lengthscale_importance_hook,
    permutation_importance_hook,
)

Problem Setup

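For didactic purposes, we sample 50 points from the Himmelblau benchmark function and use them to train a SingleTaskGP. The three inputs x_1, x_2 and x_3 are sampled from [-4, 4]. Only x_1 and x_2 enter Himmelblau's function, so x_3 is an uninformative feature.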

# TODO: replace this after JD's PR is ready.
input_features = Inputs(
    features=[ContinuousInput(key=f"x_{i+1}", bounds=(-4, 4)) for i in range(3)],
)
output_features = Outputs(features=[ContinuousOutput(key="y")])
experiments = input_features.sample(n=50)
experiments.eval("y=((x_1**2 + x_2 - 11)**2+(x_1 + x_2**2 -7)**2)", inplace=True)
experiments["valid_y"] = 1
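The objective that `experiments.eval` writes into the `y` column is Himmelblau's function. The plain NumPy sketch below restates it for reference and shows that x_3 never enters the formula. It replaces `Inputs.sample` with plain uniform sampling, and that substitution is an assumption: BoFire's default sampling strategy may differ.

```python
import numpy as np


def himmelblau(x1, x2):
    """Himmelblau's function: (x1^2 + x2 - 11)^2 + (x1 + x2^2 - 7)^2."""
    return (x1**2 + x2 - 11) ** 2 + (x1 + x2**2 - 7) ** 2


# Uniform sampling in [-4, 4]^3 as a stand-in for Inputs.sample (assumption).
rng = np.random.default_rng(0)
X = rng.uniform(-4.0, 4.0, size=(50, 3))
y = himmelblau(X[:, 0], X[:, 1])  # the third column (x_3) is ignored

# (3, 2) is one of the function's four global minima, with value 0.
print(himmelblau(3.0, 2.0))  # 0.0
```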

Cross Validation

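Run a 5-fold cross-validation of the SingleTaskGP surrogate. The two hooks compute permutation and lengthscale feature importances on each fold. The results are then combined across folds and plotted.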
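The idea behind `folds=5` can be illustrated with a generic shuffled k-fold split in plain NumPy. Each sample lands in exactly one test fold, and the surrogate is trained on the remaining folds. This is not BoFire's implementation. Whether `cross_validate` shuffles, and with which seed, is an assumption here, and `kfold_indices` is a hypothetical helper.

```python
import numpy as np


def kfold_indices(n_samples, folds=5, seed=0):
    """Yield (train_idx, test_idx) pairs for a shuffled k-fold split (hypothetical helper)."""
    rng = np.random.default_rng(seed)
    perm = rng.permutation(n_samples)
    # np.array_split also handles n_samples that are not divisible by folds
    for test_idx in np.array_split(perm, folds):
        train_idx = np.setdiff1d(perm, test_idx)
        yield train_idx, test_idx


splits = list(kfold_indices(50, folds=5))
print([len(test) for _, test in splits])  # [10, 10, 10, 10, 10]
```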

data_model = SingleTaskGPSurrogate(
    inputs=input_features,
    outputs=output_features,
)

model = surrogates.map(data_model=data_model)
train_cv, test_cv, pi = model.cross_validate(
    experiments,
    folds=5,
    hooks={
        "permutation_importance": permutation_importance_hook,
        "lengthscale_importance": lengthscale_importance_hook,
    },
)
combined_importances = {
    m.name: combine_permutation_importances(pi["permutation_importance"], m).describe()
    for m in RegressionMetricsEnum
}
combined_importances["lengthscale"] = combine_lengthscale_importances(
    pi["lengthscale_importance"],
).describe()
plot_feature_importance_by_feature_plotly(
    combined_importances,
    relative=False,
    caption="Permutation Feature Importances",
    show_std=True,
    importance_measure="Permutation Feature Importance",
)
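The `permutation_importance_hook` used above computes permutation feature importances on each fold. The sketch below shows the general technique in plain NumPy: shuffle one feature column and measure how much a metric (here MAE) degrades. Two substitutions keep it simple. An ordinary least-squares model stands in for the GP, and only MAE is used, whereas BoFire evaluates every `RegressionMetricsEnum` metric. It is not BoFire's implementation, and it does not cover the lengthscale-based importances.

```python
import numpy as np


def mae(y_true, y_pred):
    return float(np.mean(np.abs(y_true - y_pred)))


def permutation_importance(predict, X, y, n_repeats=5, seed=0):
    """Mean increase in MAE when each column of X is shuffled independently."""
    rng = np.random.default_rng(seed)
    baseline = mae(y, predict(X))
    importances = np.zeros(X.shape[1])
    for j in range(X.shape[1]):
        scores = []
        for _ in range(n_repeats):
            X_perm = X.copy()
            # Shuffling column j breaks its link to y while keeping its distribution
            X_perm[:, j] = rng.permutation(X_perm[:, j])
            scores.append(mae(y, predict(X_perm)))
        importances[j] = np.mean(scores) - baseline
    return importances


# Toy data: y depends on the first two columns only; the third is irrelevant.
rng = np.random.default_rng(1)
X = rng.uniform(-4, 4, size=(200, 3))
y = 3.0 * X[:, 0] - 2.0 * X[:, 1]

# Stand-in model: least squares with an intercept (not a GP).
A = np.column_stack([X, np.ones(len(X))])
coef, *_ = np.linalg.lstsq(A, y, rcond=None)


def predict(Z):
    return np.column_stack([Z, np.ones(len(Z))]) @ coef


print(permutation_importance(predict, X, y))
```

The irrelevant third column gets an importance of essentially zero, because shuffling it does not change the predictions.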

Analyze the cross validation

Plots will be added in a future PR.

# Performance on test sets
test_cv.get_metrics(combine_folds=True)
        MAE         MSD        R2      MAPE   PEARSON  SPEARMAN        FISHER
0  6.026825  178.461177  0.944999  0.411475  0.973066  0.974454  4.952116e-12
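For orientation, the sketch below uses textbook definitions of most of the metrics reported by `get_metrics`. It assumes that MSD is the mean squared error, that MAPE is reported as a fraction rather than a percentage, and that SPEARMAN is the Pearson correlation of ranks. These definitions are assumptions, not taken from BoFire's source. FISHER is omitted because its exact definition in BoFire is not covered here, and `regression_metrics` is a hypothetical helper.

```python
import numpy as np


def regression_metrics(y_true, y_pred):
    """Textbook regression metrics (assumed definitions, hypothetical helper)."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    resid = y_true - y_pred
    ss_res = np.sum(resid**2)
    ss_tot = np.sum((y_true - y_true.mean()) ** 2)
    # Ranks via double argsort; this sketch does not handle ties
    rank_t = np.argsort(np.argsort(y_true))
    rank_p = np.argsort(np.argsort(y_pred))
    return {
        "MAE": np.mean(np.abs(resid)),
        "MSD": np.mean(resid**2),
        "R2": 1.0 - ss_res / ss_tot,
        "MAPE": np.mean(np.abs(resid) / np.abs(y_true)),
        "PEARSON": np.corrcoef(y_true, y_pred)[0, 1],
        "SPEARMAN": np.corrcoef(rank_t, rank_p)[0, 1],
    }


m = regression_metrics([1.0, 2.0, 3.0, 4.0], [1.5, 2.0, 2.5, 4.0])
print(m)
```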
display(test_cv.get_metrics(combine_folds=False))
display(test_cv.get_metrics(combine_folds=False).describe())
         MAE         MSD        R2      MAPE   PEARSON  SPEARMAN    FISHER
0   7.897970  409.486789  0.815130  0.109937  0.915839  0.830303  0.103175
1  10.155666  361.020360  0.866782  0.108530  0.947811  0.915152  0.103175
2   4.228275   52.937155  0.986410  0.113775  0.995287  1.000000  0.003968
3   5.031610   48.965685  0.973890  0.447791  0.993271  1.000000  0.003968
4   2.820603   19.895896  0.993873  1.277340  0.997456  1.000000  0.003968

             MAE         MSD        R2      MAPE   PEARSON  SPEARMAN    FISHER
count   5.000000    5.000000  5.000000  5.000000  5.000000  5.000000  5.000000
mean    6.026825  178.461177  0.927217  0.411475  0.969933  0.949091  0.043651
std     2.960304  189.979604  0.081150  0.505561  0.036608  0.075891  0.054338
min     2.820603   19.895896  0.815130  0.108530  0.915839  0.830303  0.003968
25%     4.228275   48.965685  0.866782  0.109937  0.947811  0.915152  0.003968
50%     5.031610   52.937155  0.973890  0.113775  0.993271  1.000000  0.003968
75%     7.897970  361.020360  0.986410  0.447791  0.995287  1.000000  0.103175
max    10.155666  409.486789  0.993873  1.277340  0.997456  1.000000  0.103175
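In the outputs above, the combined MAE, MSD and MAPE (6.026825, 178.461177, 0.411475) equal the means of the per-fold values. The combined R2 (0.944999), however, differs from the per-fold mean (0.927217). This pattern is consistent with `combine_folds=True` pooling all test-set predictions and computing each metric once. The sketch below assumes that reading and reproduces the effect with hypothetical per-fold data of equal size. Averages of per-sample errors (such as MAE) agree between the two approaches. R2 does not, because the pooled targets have a larger variance than those within a single fold.

```python
import numpy as np


def r2(y_true, y_pred):
    ss_res = np.sum((y_true - y_pred) ** 2)
    ss_tot = np.sum((y_true - y_true.mean()) ** 2)
    return 1.0 - ss_res / ss_tot


def mae(y_true, y_pred):
    return np.mean(np.abs(y_true - y_pred))


rng = np.random.default_rng(0)
# Hypothetical test-set predictions: 5 folds of 10 points each,
# with each fold covering a different range of target values.
folds = []
for k in range(5):
    y_true = rng.normal(loc=10.0 * k, scale=5.0, size=10)
    y_pred = y_true + rng.normal(scale=2.0, size=10)
    folds.append((y_true, y_pred))

# Analogue of combine_folds=False: one metric value per fold
per_fold_mae = [mae(t, p) for t, p in folds]
per_fold_r2 = [r2(t, p) for t, p in folds]

# Analogue of combine_folds=True: pool all folds, compute once
y_all = np.concatenate([t for t, _ in folds])
p_all = np.concatenate([p for _, p in folds])

print(np.mean(per_fold_mae), mae(y_all, p_all))  # equal up to rounding
print(np.mean(per_fold_r2), r2(y_all, p_all))  # generally different
```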