Transfer Learning for Contextual Multi-Armed Bandits

This tutorial shows how to evolve a trained CMAB without starting from scratch:

  1. Train a CMAB with 3 context features and 2 actions

  2. Evolve it to use 4 features and add a new action — preserving everything learned

  3. Continue training the evolved model

The key function is edit_model_on_the_fly(current_mab, new_mab).
It takes new_mab as the template (defines actions and config) and transfers learned weights from current_mab for overlapping actions.
When the template has more features, it automatically expands the current model’s weight matrices and fills the new rows from the template’s cold-start weights.
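Conceptually, the feature expansion works row-wise on each action's weight matrix. The sketch below shows the idea in plain NumPy; `expand_weights` is an illustrative helper, not part of the pybandits API:

```python
import numpy as np

def expand_weights(learned, template, n_old):
    """Illustrative sketch: grow a (n_features, n_out) weight matrix.

    Rows 0..n_old-1 keep the learned values; the rows for the new
    features are taken from the template's cold-start weights.
    """
    expanded = template.copy()      # cold-start matrix, shape (n_new, n_out)
    expanded[:n_old, :] = learned   # overwrite the overlap with learned rows
    return expanded

learned = np.full((3, 2), 0.5)      # trained weights for 3 features
template = np.zeros((4, 2))         # cold-start weights for 4 features
w = expand_weights(learned, template, n_old=3)
print(w.shape)  # (4, 2): first 3 rows learned, last row cold-start
```

The learned rows are copied verbatim, so nothing trained is lost; only the new feature's row starts from the template.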

Setup

[1]:
import numpy as np

from pybandits.cmab import CmabBernoulli
from pybandits.strategy import ClassicBandit
from pybandits.transfer import edit_model_on_the_fly

np.random.seed(42)

Step 1: Train a CMAB with 3 features and 2 actions

[2]:
N_FEATURES_V1 = 3
ACTIONS_V1 = {"action_A", "action_B"}

mab_v1 = CmabBernoulli.cold_start(
    action_ids=ACTIONS_V1,
    n_features=N_FEATURES_V1,
    activation="tanh",
    strategy=ClassicBandit(),
    update_kwargs={"num_steps": 200},
)

print(f"Actions : {sorted(mab_v1.actions)}")
print(f"Features: {N_FEATURES_V1}")
Actions : ['action_A', 'action_B']
Features: 3
[3]:
# Simulate an initial batch of interactions
N_TRAIN = 200
context_v1 = np.random.randn(N_TRAIN, N_FEATURES_V1)

# Predict
actions, probs, _ = mab_v1.predict(context=context_v1)

# Simulate rewards: action_A has higher reward probability
rewards = [int(np.random.rand() < (0.7 if a == "action_A" else 0.3)) for a in actions]

# Update the model
mab_v1.update(actions=actions, rewards=rewards, context=context_v1)

print("Training complete.")
for aid in sorted(mab_v1.actions):
    act = mab_v1.actions[aid]
    print(f"  {aid}: n_successes={act.n_successes}, n_failures={act.n_failures}")

Training complete.
  action_A: n_successes=72, n_failures=28
  action_B: n_successes=30, n_failures=74
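The `n_successes` / `n_failures` counters behave like the parameters of a per-action Beta posterior: each observed reward increments one of them. A minimal, pybandits-independent sketch of that bookkeeping (starting the counters at 1/1 is an assumption here, i.e. a uniform Beta(1, 1) prior, which matches the cold-start output shown for action_C later in this tutorial):

```python
# Hypothetical tally mirroring the per-action counters printed above.
# Counters start at 1/1 (assumed Beta(1, 1) prior).
counts = {a: {"n_successes": 1, "n_failures": 1} for a in ("action_A", "action_B")}

observed = [("action_A", 1), ("action_A", 0), ("action_B", 1)]
for action, reward in observed:
    key = "n_successes" if reward == 1 else "n_failures"
    counts[action][key] += 1

# Posterior mean reward estimate per action under the Beta model
means = {a: c["n_successes"] / (c["n_successes"] + c["n_failures"])
         for a, c in counts.items()}
print(counts["action_A"])  # {'n_successes': 2, 'n_failures': 2}
```

These counters are exactly the learned state that the transfer step below must preserve for the original actions.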

Step 2: Evolve — add a new action and expand to 4 features

Create a cold-start template with the desired final configuration:

  • Same activation (structural — must match to allow weight transfer)

  • One extra feature (n_features=4)

  • The two original actions plus a new action_C

edit_model_on_the_fly will:

  1. Detect the dimension gap (3 → 4) and expand mab_v1’s weight matrices

  2. Merge with the template: action_A and action_B keep their learned weights; action_C starts cold
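The merge in step 2 is, at its core, a dictionary merge keyed by action id: template actions that already exist in the current model keep their learned state, the rest start cold. A simplified sketch (`merge_actions` is illustrative, not the library's internals):

```python
def merge_actions(current, template):
    """For each action in the template, keep the current model's
    learned state when the action already exists; otherwise use the
    template's cold-start state."""
    return {aid: current.get(aid, cold) for aid, cold in template.items()}

current = {"action_A": {"n_successes": 72, "n_failures": 28},
           "action_B": {"n_successes": 30, "n_failures": 74}}
template = {a: {"n_successes": 1, "n_failures": 1}
            for a in ("action_A", "action_B", "action_C")}

merged = merge_actions(current, template)
print(sorted(merged))                     # ['action_A', 'action_B', 'action_C']
print(merged["action_A"]["n_successes"])  # 72 (learned state kept)
print(merged["action_C"]["n_successes"])  # 1  (cold start)
```

Note that the template drives the final action set: an action absent from the template would be dropped, which is why the template must list every action you want to keep.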

[4]:
N_FEATURES_V2 = 4
ACTIONS_V2 = {"action_A", "action_B", "action_C"}

template_v2 = CmabBernoulli.cold_start(
    action_ids=ACTIONS_V2,
    n_features=N_FEATURES_V2,
    activation="tanh",  # must match mab_v1
    strategy=ClassicBandit(),
    update_kwargs={"num_steps": 200},
)

mab_v2 = edit_model_on_the_fly(mab_v1, template_v2)

print(f"Actions : {sorted(mab_v2.actions)}")
print(f"Features: {mab_v2.input_dim}")
print()
print("Learned state preserved for existing actions:")
for aid in sorted(ACTIONS_V1):  # original actions
    orig = mab_v1.actions[aid]
    evolved = mab_v2.actions[aid]
    assert evolved.n_successes == orig.n_successes
    assert evolved.n_failures == orig.n_failures
    print(f"  {aid}: n_successes={evolved.n_successes}, n_failures={evolved.n_failures}  ✓")
print()
print("New action starts cold:")
new_act = mab_v2.actions["action_C"]
print(f"  action_C: n_successes={new_act.n_successes}, n_failures={new_act.n_failures}")
2026-03-29 18:50:32.676 | INFO     | pybandits.transfer:_expand_with_template_weights:543 - Expanding current CMAB from 3 to 4 features using template's weights for 1 new feature(s)
2026-03-29 18:50:32.680 | INFO     | pybandits.transfer:_merge_mabs:365 - Merged CmabBernoulli: used mab2 as template with 3 action(s), transferred learned state from mab1 for 2 overlapping action(s).
2026-03-29 18:50:32.683 | INFO     | pybandits.transfer:edit_model_on_the_fly:735 - Updated MAB using new_mab as template. Final MAB has 3 action(s) with new_mab's configuration and current_mab's learned state for overlapping actions.
Actions : ['action_A', 'action_B', 'action_C']
Features: 4

Learned state preserved for existing actions:
  action_A: n_successes=72, n_failures=28  ✓
  action_B: n_successes=30, n_failures=74  ✓

New action starts cold:
  action_C: n_successes=1, n_failures=1

Step 3: Continue training the evolved model

[5]:
N_TRAIN_V2 = 200
context_v2 = np.random.randn(N_TRAIN_V2, N_FEATURES_V2)  # 4 features now

actions_v2, probs_v2, _ = mab_v2.predict(context=context_v2)

rewards_v2 = [
    int(np.random.rand() < (0.7 if a == "action_A" else (0.5 if a == "action_C" else 0.3))) for a in actions_v2
]

mab_v2.update(actions=actions_v2, rewards=rewards_v2, context=context_v2)

print("Continued training complete.")
for aid in sorted(mab_v2.actions):
    act = mab_v2.actions[aid]
    print(f"  {aid}: n_successes={act.n_successes}, n_failures={act.n_failures}")

Continued training complete.
  action_A: n_successes=113, n_failures=51
  action_B: n_successes=45, n_failures=102
  action_C: n_successes=37, n_failures=58

Summary

Step          | Actions | Features | How
------------- | ------- | -------- | ---------------------
v1 (initial)  | A, B    | 3        | cold_start
v1 (trained)  | A, B    | 3        | update
v2 (evolved)  | A, B, C | 4        | edit_model_on_the_fly
v2 (trained)  | A, B, C | 4        | update

What ``edit_model_on_the_fly`` did:

  • Expanded action_A and action_B weight matrices from shape (3, ...) to (4, ...)

  • The new 4th-feature row is initialised from the template’s cold-start weights

  • All existing learned weights (and n_successes / n_failures counts) were copied unchanged

  • action_C was added fresh from the template

Constraints to keep in mind:

  • activation and use_residual_connections are structural — they must be identical in both MABs

  • The template must have at least as many features as the current model (the transfer can expand, but cannot shrink)

  • dist_type, hidden_dim_list, update_kwargs and update_method are all freely changeable
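Those constraints can be enforced with a small guard before attempting a transfer. The sketch below is defensive pseudocode over plain dicts; the attribute names mirror the constructor arguments used in this tutorial but the validation itself is an assumption, not the library's own:

```python
def check_transfer_compat(current, template):
    """Raise if the structural constraints for weight transfer are violated."""
    if current["activation"] != template["activation"]:
        raise ValueError("activation is structural and must match")
    if current.get("use_residual_connections") != template.get("use_residual_connections"):
        raise ValueError("use_residual_connections is structural and must match")
    if template["n_features"] < current["n_features"]:
        raise ValueError("template cannot have fewer features (expand only, no shrink)")

# Matches the v1 -> v2 evolution above: same activation, 3 -> 4 features
check_transfer_compat({"activation": "tanh", "n_features": 3},
                      {"activation": "tanh", "n_features": 4})
print("compatible")
```

Running such a check up front turns a silent mis-transfer into an immediate, explicit error.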