Model Building with BoFire

This notebook shows how to set up and analyze models trained with BoFire. It is still a work in progress.

Imports

import bofire.surrogates.api as surrogates
from bofire.data_models.domain.api import Inputs, Outputs
from bofire.data_models.enum import RegressionMetricsEnum
from bofire.data_models.features.api import ContinuousInput, ContinuousOutput
from bofire.data_models.surrogates.api import SingleTaskGPSurrogate
from bofire.plot.feature_importance import plot_feature_importance_by_feature_plotly
from bofire.surrogates.feature_importance import (
    combine_lengthscale_importances,
    combine_permutation_importances,
    lengthscale_importance_hook,
    permutation_importance_hook,
)

Problem Setup

For didactic purposes, we sample data from the Himmelblau benchmark function and use it to train a SingleTaskGP. Note that the third input, x_3, does not enter the output expression, so the feature importance analysis below should flag it as unimportant.

# TODO: replace this after JDs PR is ready.
input_features = Inputs(
    features=[ContinuousInput(key=f"x_{i+1}", bounds=(-4, 4)) for i in range(3)],
)
output_features = Outputs(features=[ContinuousOutput(key="y")])
experiments = input_features.sample(n=50)
experiments.eval("y=((x_1**2 + x_2 - 11)**2+(x_1 + x_2**2 -7)**2)", inplace=True)
experiments["valid_y"] = 1

Cross Validation

Run the cross validation

data_model = SingleTaskGPSurrogate(
    inputs=input_features,
    outputs=output_features,
)

model = surrogates.map(data_model=data_model)
train_cv, test_cv, pi = model.cross_validate(
    experiments,
    folds=5,
    hooks={
        "permutation_importance": permutation_importance_hook,
        "lengthscale_importance": lengthscale_importance_hook,
    },
)
combined_importances = {
    m.name: combine_permutation_importances(pi["permutation_importance"], m).describe()
    for m in RegressionMetricsEnum
}
combined_importances["lengthscale"] = combine_lengthscale_importances(
    pi["lengthscale_importance"],
).describe()
plot_feature_importance_by_feature_plotly(
    combined_importances,
    relative=False,
    caption="Permutation Feature Importances",
    show_std=True,
    importance_measure="Permutation Feature Importance",
)

Analyze the cross validation

Plots will be added in a future PR.
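In the meantime, a minimal parity-plot sketch is shown below. It assumes that each entry of test_cv.results exposes observed and predicted pandas Series; this is an assumption about the CvResults container and may differ between BoFire versions.

# Hedged sketch (not part of the original notebook): parity plot of observed vs.
# predicted values per test fold, assuming test_cv.results entries have
# `observed` and `predicted` pandas Series.
import plotly.graph_objects as go

fig = go.Figure()
for i, fold in enumerate(test_cv.results):
    fig.add_trace(
        go.Scatter(x=fold.observed, y=fold.predicted, mode="markers", name=f"fold {i}"),
    )
# Diagonal reference line spanning the observed range.
lo = min(fold.observed.min() for fold in test_cv.results)
hi = max(fold.observed.max() for fold in test_cv.results)
fig.add_trace(go.Scatter(x=[lo, hi], y=[lo, hi], mode="lines", name="ideal"))
fig.update_layout(xaxis_title="observed y", yaxis_title="predicted y")
fig.show()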

# Performance on test sets
test_cv.get_metrics(combine_folds=True)
        MAE        MSD        R2      MAPE   PEARSON  SPEARMAN        FISHER
0  5.393774  69.683719  0.985567  0.442587  0.992917  0.986843  7.910729e-15
display(test_cv.get_metrics(combine_folds=False))
display(test_cv.get_metrics(combine_folds=False).describe())
        MAE         MSD        R2      MAPE   PEARSON  SPEARMAN    FISHER
0  3.167514   19.542501  0.997057  0.060800  0.998756  0.987879  0.003968
1  4.853868   48.283872  0.989991  1.274604  0.996599  0.963636  0.003968
2  6.569830  131.349639  0.962255  0.137323  0.981612  0.963636  0.003968
3  6.816940   97.743277  0.979839  0.179474  0.992397  0.987879  0.003968
4  5.560719   51.499305  0.987374  0.560736  0.994116  0.987879  0.103175
            MAE         MSD        R2      MAPE   PEARSON  SPEARMAN    FISHER
count  5.000000    5.000000  5.000000  5.000000  5.000000  5.000000  5.000000
mean   5.393774   69.683719  0.983303  0.442587  0.992696  0.978182  0.023810
std    1.473441   44.420695  0.013281  0.503584  0.006651  0.013278  0.044366
min    3.167514   19.542501  0.962255  0.060800  0.981612  0.963636  0.003968
25%    4.853868   48.283872  0.979839  0.137323  0.992397  0.963636  0.003968
50%    5.560719   51.499305  0.987374  0.179474  0.994116  0.987879  0.003968
75%    6.569830   97.743277  0.989991  0.560736  0.996599  0.987879  0.003968
max    6.816940  131.349639  0.997057  1.274604  0.998756  0.987879  0.103175
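
Once the cross-validation metrics look acceptable, the surrogate can be fit on the full dataset and used for predictions. A minimal sketch is shown below; the column names of the prediction DataFrame are an assumption and not confirmed by this notebook.

# Minimal sketch: fit the mapped surrogate on all experiments and predict new points.
model.fit(experiments)
candidates = input_features.sample(n=5)
predictions = model.predict(candidates)
# The returned DataFrame is assumed to hold prediction mean and standard deviation
# columns (e.g. "y_pred", "y_sd"); the exact names may vary between BoFire versions.
display(predictions)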