{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Stochastic Multi-Armed Bandit with Zooming Model" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This notebook demonstrates the usage of the Zooming model for quantitative action spaces with the stochastic multi-armed bandit (sMAB) implementation in `pybandits`.\n", "\n", "The Zooming model adaptively partitions a continuous action space and fits a model (e.g., Beta) to each segment. This allows efficient exploration and exploitation in continuous or high-cardinality action spaces through an adaptive discretization approach.\n", "\n", "References:\n", "- [Multi-Armed Bandits in Metric Spaces (Kleinberg, Slivkins, and Upfal, 2008)](https://arxiv.org/pdf/0809.4882)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "import numpy as np\n", "\n", "from pybandits.quantitative_model import SmabZoomingModel\n", "from pybandits.smab import SmabBernoulli" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Setup\n", "\n", "First, we'll define actions with quantitative parameters. In this example, we'll use two actions, each with a one-dimensional quantitative parameter (e.g., price point or dosage level) ranging from 0 to 1." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# For reproducibility\n", "np.random.seed(42)\n", "\n", "# Define actions with zooming models\n", "actions = {\"action_1\": SmabZoomingModel.cold_start(dimension=1), \"action_2\": SmabZoomingModel.cold_start(dimension=1)}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now we can initialize the SmabBernoulli bandit with our zooming models:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Initialize the bandit\n", "smab = SmabBernoulli(actions=actions)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Simulate Environment\n", "\n", "Let's create a reward function that depends on both the action and its quantitative parameter. For illustration purposes, we'll define that:\n", "\n", "- `action_1` has optimal parameter values around 0.25\n", "- `action_2` has optimal parameter values around 0.75\n", "\n", "The reward probability follows a bell curve centered around these optimal values." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def reward_function(action, quantity):\n", " if action == \"action_1\":\n", " # Bell curve centered at 0.25\n", " prob = np.exp(-((quantity - 0.25) ** 2) / 0.02)\n", " else: # action_2\n", " # Bell curve centered at 0.75\n", " prob = np.exp(-((quantity - 0.75) ** 2) / 0.02)\n", " return np.random.binomial(1, prob)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's visualize our reward functions to understand what the bandit needs to learn:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "x = np.linspace(0, 1, 100)\n", "y1 = [np.exp(-((xi - 0.25) ** 2) / 0.02) for xi in x]\n", "y2 = [np.exp(-((xi - 0.75) ** 2) / 0.02) for xi in x]\n", "\n", "plt.figure(figsize=(10, 6))\n", "plt.plot(x, y1, \"b-\", label=\"action_1\")\n", "plt.plot(x, y2, \"r-\", label=\"action_2\")\n", "plt.xlabel(\"Quantitative Parameter\")\n", "plt.ylabel(\"Reward Probability\")\n", "plt.title(\"True Reward Functions\")\n", "plt.legend()\n", "plt.grid(True)\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Bandit Training Loop\n", "\n", "Now, let's train our bandit by simulating interactions for several rounds:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "n_rounds = 500\n", "history = []\n", "\n", "for t in range(n_rounds):\n", " # Sample random quantities to be evaluated for each action\n", " # In a real-world scenario, these might be potential price points to test\n", "\n", " # Predict best action (n_samples=1)\n", " pred_actions, probs = smab.predict(n_samples=1)\n", " chosen_action = pred_actions[0][0]\n", " chosen_quantity = pred_actions[0][1][0]\n", "\n", " # Observe reward\n", " reward = reward_function(chosen_action, chosen_quantity)\n", "\n", " # Update bandit\n", " smab.update(actions=[chosen_action], rewards=[reward], quantities=[chosen_quantity])\n", "\n", " # Store history\n", " history.append((t, chosen_action, chosen_quantity, reward))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Analyze Results\n", "\n", "Let's visualize the bandit's learning process and choices over time:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Extract data from history\n", "rounds = [h[0] for h in history]\n", "actions_hist = [h[1] for h in history]\n", "quantities_hist = [h[2] for h in history]\n", "rewards_hist = [h[3] for h in history]\n", "\n", "# Create a scatter plot of chosen quantities over time\n", "plt.figure(figsize=(12, 6))\n", "colors = {\"action_1\": \"blue\", \"action_2\": \"red\"}\n", "for action in np.unique(actions_hist):\n", " mask = [a == action for a in actions_hist]\n", " plt.scatter(\n", " [rounds[i] for i in range(len(mask)) if mask[i]],\n", " [quantities_hist[i] for i in range(len(mask)) if mask[i]],\n", " c=[rewards_hist[i] for i in range(len(mask)) if mask[i]],\n", " cmap=\"coolwarm\",\n", " alpha=0.7,\n", " label=action,\n", " )\n", "\n", "plt.xlabel(\"Round\")\n", "plt.ylabel(\"Quantitative Parameter\")\n", "plt.title(\"Bandit Choices Over Time\")\n", "plt.legend()\n", "plt.colorbar(label=\"Reward\")\n", "plt.grid(True)\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's look at the distribution of choices in the last 100 rounds, when the bandit has had time to learn:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], 
"source": [ "# Analyze the last 100 rounds\n", "last_100_actions = actions_hist[-100:]\n", "last_100_quantities = quantities_hist[-100:]\n", "\n", "plt.figure(figsize=(10, 6))\n", "\n", "# Plot histogram of quantitative parameters for each action\n", "for action in np.unique(actions_hist):\n", " action_quantities = [last_100_quantities[i] for i in range(100) if last_100_actions[i] == action]\n", " if action_quantities:\n", " plt.hist(action_quantities, bins=20, alpha=0.5, label=action)\n", "\n", "plt.xlabel(\"Quantitative Parameter\")\n", "plt.ylabel(\"Frequency\")\n", "plt.title(\"Distribution of Chosen Parameters (Last 100 Rounds)\")\n", "plt.legend()\n", "plt.grid(True)\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Reward Performance Over Time\n", "\n", "Let's analyze how the reward performance improves as the bandit learns:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Calculate cumulative average reward\n", "cumulative_reward = np.cumsum(rewards_hist)\n", "cumulative_avg_reward = [cumulative_reward[i] / (i + 1) for i in range(len(cumulative_reward))]\n", "\n", "# Calculate moving average with window=20\n", "window_size = 20\n", "moving_avg = [np.mean(rewards_hist[max(0, i - window_size) : i + 1]) for i in range(len(rewards_hist))]\n", "\n", "plt.figure(figsize=(12, 6))\n", "plt.plot(cumulative_avg_reward, label=\"Cumulative Average Reward\")\n", "plt.plot(moving_avg, label=f\"Moving Average (window={window_size})\")\n", "plt.xlabel(\"Round\")\n", "plt.ylabel(\"Average Reward\")\n", "plt.title(\"Reward Performance Over Time\")\n", "plt.legend()\n", "plt.grid(True)\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Conclusion\n", "\n", "The Zooming model with sMAB allows efficient exploration and exploitation of continuous action parameters. It adaptively refines the segmentation of the parameter space, concentrating more segments in high-reward regions for finer discretization.\n", "\n", "This approach is particularly useful when:\n", "1. Actions have continuous parameters that affect rewards\n", "2. The reward function is unknown and needs to be learned\n", "3. The optimal parameter values may vary across different actions\n", "\n", "Real-world applications include pricing optimization, dose finding in medicine, or any scenario involving continuous parameter tuning." ] } ], "metadata": { "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python" } }, "nbformat": 4, "nbformat_minor": 4 }