{ "cells": [ { "cell_type": "markdown", "id": "0", "metadata": {}, "source": [ "# Transfer Learning for Contextual Multi-Armed Bandits\n", "\n", "This tutorial shows how to evolve a trained contextual multi-armed bandit (CMAB) without starting from scratch:\n", "\n", "1. **Train** a CMAB with 3 context features and 2 actions\n", "2. **Evolve** it to use 4 features and add a new action, preserving everything learned\n", "3. **Continue training** the evolved model\n", "\n", "The key function is `edit_model_on_the_fly(current_mab, new_mab)`.\n", "It takes `new_mab` as the template (which defines the actions and configuration) and transfers\n", "the learned weights from `current_mab` for the actions the two models share.\n", "When the template has more features, it automatically expands the current model's\n", "weight matrices and fills the new rows from the template's cold-start weights." ] }, { "cell_type": "markdown", "id": "1", "metadata": {}, "source": [ "## Setup" ] }, { "cell_type": "code", "execution_count": null, "id": "2", "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "\n", "from pybandits.cmab import CmabBernoulli\n", "from pybandits.strategy import ClassicBandit\n", "from pybandits.transfer import edit_model_on_the_fly\n", "\n", "np.random.seed(42)" ] }, { "cell_type": "markdown", "id": "3", "metadata": {}, "source": [ "## Step 1: Train a CMAB with 3 features and 2 actions" ] }, { "cell_type": "code", "execution_count": null, "id": "4", "metadata": {}, "outputs": [], "source": [ "N_FEATURES_V1 = 3\n", "ACTIONS_V1 = {\"action_A\", \"action_B\"}\n", "\n", "mab_v1 = CmabBernoulli.cold_start(\n", "    action_ids=ACTIONS_V1,\n", "    n_features=N_FEATURES_V1,\n", "    activation=\"tanh\",\n", "    strategy=ClassicBandit(),\n", "    update_kwargs={\"num_steps\": 200},\n", ")\n", "\n", "print(f\"Actions : {sorted(mab_v1.actions)}\")\n", "print(f\"Features: {N_FEATURES_V1}\")" ] }, { "cell_type": "code", "execution_count": null, "id": "5", "metadata": {}, "outputs": [], "source": [ "# Simulate an initial batch of 
interactions\n", "N_TRAIN = 200\n", "context_v1 = np.random.randn(N_TRAIN, N_FEATURES_V1)\n", "\n", "# Predict an action for each context row\n", "actions, probs, _ = mab_v1.predict(context=context_v1)\n", "\n", "# Simulate Bernoulli rewards: action_A succeeds with probability 0.7, any other action with 0.3\n", "rewards = [int(np.random.rand() < (0.7 if a == \"action_A\" else 0.3)) for a in actions]\n", "\n", "# Update the model\n", "mab_v1.update(actions=actions, rewards=rewards, context=context_v1)\n", "\n", "print(\"Training complete.\")\n", "for aid in sorted(mab_v1.actions):\n", "    act = mab_v1.actions[aid]\n", "    print(f\" {aid}: n_successes={act.n_successes}, n_failures={act.n_failures}\")" ] }, { "cell_type": "markdown", "id": "6", "metadata": {}, "source": [ "## Step 2: Evolve by adding a new action and expanding to 4 features\n", "\n", "Create a cold-start template with the desired final configuration:\n", "- Same `activation` (structural, so it must match to allow weight transfer)\n", "- One extra feature (`n_features=4`)\n", "- The two original actions **plus** a new `action_C`\n", "\n", "`edit_model_on_the_fly` will:\n", "1. Detect the dimension gap (3 → 4) and expand `mab_v1`'s weight matrices\n", "2. 
Merge with the template: `action_A` and `action_B` keep their learned weights; `action_C` starts cold" ] }, { "cell_type": "code", "execution_count": null, "id": "7", "metadata": {}, "outputs": [], "source": [ "N_FEATURES_V2 = 4\n", "ACTIONS_V2 = {\"action_A\", \"action_B\", \"action_C\"}\n", "\n", "template_v2 = CmabBernoulli.cold_start(\n", "    action_ids=ACTIONS_V2,\n", "    n_features=N_FEATURES_V2,\n", "    activation=\"tanh\",  # must match mab_v1\n", "    strategy=ClassicBandit(),\n", "    update_kwargs={\"num_steps\": 200},\n", ")\n", "\n", "mab_v2 = edit_model_on_the_fly(mab_v1, template_v2)\n", "\n", "print(f\"Actions : {sorted(mab_v2.actions)}\")\n", "print(f\"Features: {mab_v2.input_dim}\")\n", "print()\n", "print(\"Learned state preserved for existing actions:\")\n", "for aid in sorted(ACTIONS_V1):  # original actions\n", "    orig = mab_v1.actions[aid]\n", "    evolved = mab_v2.actions[aid]\n", "    assert evolved.n_successes == orig.n_successes\n", "    assert evolved.n_failures == orig.n_failures\n", "    print(f\" {aid}: n_successes={evolved.n_successes}, n_failures={evolved.n_failures} ✓\")\n", "print()\n", "print(\"New action starts cold:\")\n", "new_act = mab_v2.actions[\"action_C\"]\n", "print(f\" action_C: n_successes={new_act.n_successes}, n_failures={new_act.n_failures}\")" ] }, { "cell_type": "markdown", "id": "8", "metadata": {}, "source": [ "## Step 3: Continue training the evolved model" ] }, { "cell_type": "code", "execution_count": null, "id": "9", "metadata": {}, "outputs": [], "source": [ "N_TRAIN_V2 = 200\n", "context_v2 = np.random.randn(N_TRAIN_V2, N_FEATURES_V2)  # 4 features now\n", "\n", "actions_v2, probs_v2, _ = mab_v2.predict(context=context_v2)\n", "\n", "rewards_v2 = [\n", "    int(np.random.rand() < (0.7 if a == \"action_A\" else (0.5 if a == \"action_C\" else 0.3))) for a in actions_v2\n", "]\n", "\n", "mab_v2.update(actions=actions_v2, rewards=rewards_v2, context=context_v2)\n", "\n", "print(\"Continued training complete.\")\n", "for aid in 
sorted(mab_v2.actions):\n", "    act = mab_v2.actions[aid]\n", "    print(f\" {aid}: n_successes={act.n_successes}, n_failures={act.n_failures}\")" ] }, { "cell_type": "markdown", "id": "10", "metadata": {}, "source": [ "## Summary\n", "\n", "| Step | Actions | Features | How |\n", "|------|---------|----------|-----|\n", "| v1 (initial) | A, B | 3 | `cold_start` |\n", "| v1 (trained) | A, B | 3 | `update` |\n", "| v2 (evolved) | A, B, **C** | **4** | `edit_model_on_the_fly` |\n", "| v2 (trained) | A, B, C | 4 | `update` |\n", "\n", "**What `edit_model_on_the_fly` did:**\n", "- Expanded the `action_A` and `action_B` weight matrices from shape `(3, ...)` to `(4, ...)`\n", "- Initialised the new 4th-feature row from the template's cold-start weights\n", "- Copied all existing learned weights, along with the `n_successes` / `n_failures` counts, unchanged\n", "- Added `action_C` fresh from the template\n", "\n", "**Constraints to keep in mind:**\n", "- `activation` and `use_residual_connections` are *structural*: they must be identical in both MABs\n", "- The template must have **at least** as many features as the current model (the feature dimension can expand but cannot shrink)\n", "- `dist_type`, `hidden_dim_list`, `update_kwargs` and `update_method` can all be changed freely" ] } ], "metadata": { "language_info": { "name": "python" } }, "nbformat": 4, "nbformat_minor": 5 }