Contextual Multi-Armed Bandit

For the contextual multi-armed bandit (cMAB), where user information (context) is available, we implemented a generalisation of the Thompson sampling algorithm (Agrawal and Goyal, 2014) based on PyMC.
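At each step, contextual Thompson sampling draws one coefficient vector per action from that action's posterior and plays the action whose sampled logistic model assigns the highest reward probability to the current context. A minimal NumPy sketch of this selection step (with hypothetical Gaussian posteriors standing in for the Bayesian logistic regression posteriors that pybandits maintains) could look like:

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def select_action(x, posteriors, rng):
    """Thompson sampling step: draw one coefficient vector per action
    from its (approximate) posterior, then play the action whose sampled
    model gives the highest probability of reward for context x."""
    probs = {}
    for action, (mu, sigma) in posteriors.items():
        beta = rng.normal(mu, sigma)       # sample betas ~ posterior
        probs[action] = sigmoid(x @ beta)  # P(r=1 | beta, x)
    return max(probs, key=probs.get)

rng = np.random.default_rng(0)
# toy "posteriors": mean and std of 5 coefficients per action (illustrative)
posteriors = {
    "a1": (np.zeros(5), np.ones(5)),
    "a2": (np.ones(5), np.ones(5)),
}
x = rng.uniform(-1, 1, size=5)
print(select_action(x, posteriors, rng))
```

Because the coefficients are sampled rather than fixed, actions with uncertain posteriors are still explored occasionally, which is what drives the exploration/exploitation trade-off.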


The following notebook contains an example of usage of the class CmabBernoulli, which implements the algorithm above.

[1]:
import numpy as np

from pybandits.cmab import CmabBernoulli
from pybandits.model import BayesianLogisticRegression, StudentT
[2]:
n_samples = 1000
n_features = 5

First, we need to define the input context matrix \(X\) of shape \((n\_samples, n\_features)\) and the mapping of possible actions \(a_i \in A\) to their associated model.

[3]:
# context
X = 2 * np.random.random_sample((n_samples, n_features)) - 1  # random float in the interval (-1, 1)
print("X: context matrix of shape (n_samples, n_features)")
print(X[:10])
X: context matrix of shape (n_samples, n_features)
[[-0.1613841  -0.49794167 -0.33376438  0.94778933 -0.30614152]
 [-0.4496147   0.60884007  0.9585795   0.20389095 -0.48560358]
 [-0.42225609 -0.65629762  0.8063207   0.16104495 -0.33538967]
 [-0.30597128 -0.13308005  0.97772063  0.6193355   0.3049692 ]
 [-0.72070883 -0.05191547 -0.58054673 -0.91069608  0.63320628]
 [-0.93996399 -0.72082382  0.4136631   0.38261984  0.32162887]
 [-0.48978556 -0.77543961 -0.95728434 -0.68990959  0.72228253]
 [-0.35171887  0.06042878 -0.86312631  0.93927125  0.68654729]
 [ 0.69956398 -0.47079873  0.53818873  0.37612796  0.29317551]
 [-0.7198139  -0.86793186  0.67701877 -0.76653332  0.30504903]]
[4]:
# define action model
actions = {
    "a1": BayesianLogisticRegression(alpha=StudentT(mu=1, sigma=2), betas=n_features * [StudentT()]),
    "a2": BayesianLogisticRegression(alpha=StudentT(mu=1, sigma=2), betas=n_features * [StudentT()]),
}

We can now initialise the bandit given the mapping of actions \(a_i\) to their models.

[5]:
# init contextual Multi-Armed Bandit model
cmab = CmabBernoulli(actions=actions)

The predict function below returns the action selected by the bandit at time \(t\): \(a_t = \operatorname{argmax}_k P(r=1|\beta_k, x_t)\). The bandit selects one action for each sample of the context matrix \(X\).

[6]:
# predict action
pred_actions, _, _ = cmab.predict(X)
print("Recommended action: {}".format(pred_actions[:10]))
Recommended action: ['a2', 'a1', 'a2', 'a2', 'a1', 'a2', 'a2', 'a1', 'a2', 'a1']

Now, we observe the rewards and the context from the environment. In this example, both are randomly simulated.

[7]:
# simulate reward from environment
simulated_rewards = np.random.randint(2, size=n_samples)
# simulate context from environment
simulated_context = 2 * np.random.random_sample((n_samples, n_features)) - 1  # random float in the interval (-1, 1)
print("Simulated rewards: {}".format(simulated_rewards[:10]))
Simulated rewards: [0 1 0 1 0 0 0 0 1 0]
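Purely random rewards give the bandit nothing to learn from. A slightly more realistic simulation draws each reward from a Bernoulli whose probability comes from a hypothetical "true" logistic model per action; all names and coefficients below are illustrative, not part of pybandits:

```python
import numpy as np

rng = np.random.default_rng(42)
n_samples, n_features = 1000, 5

# hypothetical "true" coefficients per action (unknown to the bandit)
true_betas = {"a1": rng.normal(size=n_features), "a2": rng.normal(size=n_features)}

X = rng.uniform(-1, 1, size=(n_samples, n_features))
chosen = rng.choice(["a1", "a2"], size=n_samples)  # stand-in for pred_actions

# reward ~ Bernoulli(sigmoid(x . beta_a)) for the chosen action
logits = np.array([X[i] @ true_betas[a] for i, a in enumerate(chosen)])
p = 1.0 / (1.0 + np.exp(-logits))
rewards = rng.binomial(1, p)
print(rewards[:10])
```

With context-dependent rewards like these, repeated predict/update rounds would let the posterior for each action concentrate around its true coefficients.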

Finally, we update the model by providing, for each sample: (i) its context \(x_t\), (ii) the action \(a_t\) selected by the bandit, and (iii) the corresponding reward \(r_t\).

[8]:
# update model
cmab.update(context=X, actions=pred_actions, rewards=simulated_rewards)
Initializing NUTS using adapt_diag...
Sequential sampling (2 chains in 1 job)
NUTS: [alpha, betas]
Sampling 2 chains for 500 tune and 1_000 draw iterations (1_000 + 2_000 draws total) took 3 seconds.
We recommend running at least 4 chains for robust computation of convergence diagnostics
Initializing NUTS using adapt_diag...
Sequential sampling (2 chains in 1 job)
NUTS: [alpha, betas]
Sampling 2 chains for 500 tune and 1_000 draw iterations (1_000 + 2_000 draws total) took 3 seconds.
We recommend running at least 4 chains for robust computation of convergence diagnostics
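As the log above shows, the update refits the Bayesian logistic regression for each action with PyMC's NUTS sampler. To build intuition for what that update targets, here is a crude stand-in: a MAP estimate of the logistic coefficients under a Gaussian prior, fitted by gradient ascent. This is a sketch only, with illustrative data; pybandits samples the full posterior rather than a point estimate:

```python
import numpy as np

def map_update(X, r, beta0, lr=0.5, steps=3000, prior_prec=1.0):
    """Gradient ascent on the log-posterior of logistic regression with a
    N(0, 1/prior_prec) prior -- a point-estimate stand-in for the full
    posterior that pybandits samples with PyMC/NUTS."""
    beta = beta0.copy()
    n = len(r)
    for _ in range(steps):
        p = 1.0 / (1.0 + np.exp(-(X @ beta)))
        grad = X.T @ (r - p) - prior_prec * beta  # d log-posterior / d beta
        beta += lr * grad / n
    return beta

rng = np.random.default_rng(0)
true_beta = np.array([1.0, -2.0, 0.5])  # illustrative ground truth
X = rng.uniform(-1, 1, size=(2000, 3))
p_true = 1.0 / (1.0 + np.exp(-(X @ true_beta)))
r = rng.binomial(1, p_true)
beta_hat = map_update(X, r, np.zeros(3))
print(beta_hat)  # should land near true_beta
```

The full posterior that NUTS samples additionally captures the uncertainty around this point estimate, which is exactly what Thompson sampling exploits when selecting actions.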