---
title: TNK Benchmark
jupyter: python3
---
## Imports
```{python}
#| papermill: {duration: 3.151705, end_time: '2024-10-10T20:35:50.148179', exception: false, start_time: '2024-10-10T20:35:46.996474', status: completed}
#| tags: []
import os
import warnings
import matplotlib.pyplot as plt
import pandas as pd
import bofire.strategies.api as strategies
from bofire.benchmarks.api import TNK
from bofire.data_models.api import Domain
from bofire.data_models.strategies.api import MoboStrategy, RandomStrategy
from bofire.runners.api import run
from bofire.utils.multiobjective import compute_hypervolume
# Print only the first occurrence of each matching warning.
warnings.simplefilter("once")
# Raw value of the SMOKE_TEST environment variable (None when unset). Any
# non-empty string -- including "0" -- is truthy and shortens the runs below.
SMOKE_TEST = os.getenv("SMOKE_TEST")
```
## Random Strategy
```{python}
#| papermill: {duration: 0.225999, end_time: '2024-10-10T20:35:50.376983', exception: true, start_time: '2024-10-10T20:35:50.150984', status: failed}
#| tags: []
def sample(domain):
    """Return 10 random candidates drawn from ``domain``.

    Used as the ``initial_sampler`` for both benchmark runs in this notebook,
    so each strategy starts from the same number of random points.
    """
    return strategies.map(data_model=RandomStrategy(domain=domain)).ask(10)
def hypervolume(domain: Domain, experiments: pd.DataFrame) -> float:
    """Hypervolume of the feasible TNK experiments w.r.t. reference point (4, 4).

    Only rows that satisfy both TNK constraints (``c1 >= 0`` and ``c2 <= 0.5``)
    contribute. Rows with a NaN constraint value compare False and are dropped.

    Args:
        domain: Benchmark domain, forwarded to ``compute_hypervolume``.
        experiments: Evaluated experiments. Must contain the ``c1`` and ``c2``
            constraint columns.

    Returns:
        The value computed by ``compute_hypervolume`` on the feasible rows, or
        0.0 when no experiment is feasible yet.
    """
    feasible = experiments.loc[
        (experiments["c1"] >= 0) & (experiments["c2"] <= 0.5)
    ]
    if feasible.empty:
        # An empty set dominates no volume. Early random designs frequently
        # contain no feasible TNK point, so answer directly instead of
        # relying on how compute_hypervolume treats a zero-row frame.
        return 0.0
    return compute_hypervolume(
        domain,
        feasible,
        ref_point={"f1": 4, "f2": 4},
    )
# Baseline: purely random proposals. Uses the same initial sampler and the
# same feasible-hypervolume metric as the MOBO run.
random_results = run(
    TNK(),
    strategy_factory=lambda domain: strategies.map(RandomStrategy(domain=domain)),
    initial_sampler=sample,
    metric=hypervolume,
    n_iterations=1 if SMOKE_TEST else 100,  # single iteration under SMOKE_TEST
    n_procs=1,
    n_runs=1,
)
```
## MOBO Strategy
```{python}
#| papermill: {duration: null, end_time: null, exception: null, start_time: null, status: pending}
#| tags: []
def strategy_factory(domain: Domain):
    """Build a multi-objective Bayesian optimization strategy for ``domain``.

    The reference point (4, 4) on ``f1``/``f2`` matches the one used by the
    ``hypervolume`` metric in this notebook.
    """
    return strategies.map(
        MoboStrategy(domain=domain, ref_point={"f1": 4.0, "f2": 4.0})
    )
# MOBO run under the same protocol (initial sampler, metric, budget) as the
# random baseline.
results = run(
    TNK(),
    strategy_factory=strategy_factory,
    initial_sampler=sample,
    metric=hypervolume,
    n_iterations=1 if SMOKE_TEST else 100,  # single iteration under SMOKE_TEST
    n_procs=1,
    n_runs=1,
)
```
```{python}
#| papermill: {duration: null, end_time: null, exception: null, start_time: null, status: pending}
#| tags: []
if not SMOKE_TEST:
    # Overlay the hypervolume-per-iteration curves of the single run of each
    # strategy. Element [0][1] is presumably the metric trace returned by
    # run() for run 0.
    fig, ax = plt.subplots()
    ax.plot(random_results[0][1], label="random")
    ax.plot(results[0][1], label="MOBO")
    ax.set(xlabel="iteration", ylabel="hypervolume")
    ax.legend()
    plt.show()
```