{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Contextual Multi-Armed Bandit with BNN-Based Quantitative Model" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This notebook demonstrates the usage of a **Bayesian Neural Network (BNN) based quantitative model** for continuous action parameters with the contextual multi-armed bandit (CMAB) implementation in `pybandits`.\n", "\n", "The BNN quantitative model maps context features and a continuous action parameter (e.g., price, dosage, bid) to a reward distribution. A Bayesian Neural Network is used to model the relationship between context, the quantitative parameter, and the expected reward, providing uncertainty estimates that support exploration–exploitation. Unlike discrete-action CMABs, this approach lets you learn and optimize over continuous or fine-grained quantitative choices while conditioning on context.\n", "\n", "**Key aspects:**\n", "- **Context + quantity**: The model takes both contextual features and a scalar (or vector) quantitative parameter as input.\n", "- **Uncertainty-aware**: The BNN yields a posterior over rewards, which the bandit can use for Thompson sampling or similar strategies.\n", "- **Flexible fitting**: Supports variational inference (VI) or other Bayesian update methods for the BNN weights." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "import numpy as np\n", "import pandas as pd\n", "\n", "from pybandits.cmab import CmabBernoulli\n", "from pybandits.quantitative_model import QuantitativeBayesianNeuralNetwork\n", "\n", "%load_ext autoreload\n", "%autoreload 2" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Setup\n", "\n", "First, we'll define actions with quantitative parameters. In this example, we'll use two actions, each with a one-dimensional quantitative parameter (e.g., price point or dosage level) ranging from 0 to 1. 
Unlike the stochastic multi-armed bandit (SMAB) case, the contextual setting also requires defining the context features." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# For reproducibility\n", "np.random.seed(42)\n", "\n", "# Define number of features for the context\n", "n_features = 3\n", "\n", "# Define cold start parameters for the base model\n", "update_method = \"VI\"  # Variational Inference for Bayesian updates\n", "update_kwargs = {\"num_steps\": 20000, \"optimizer_type\": \"adam\", \"optimizer_kwargs\": {\"step_size\": 0.001}}\n", "dist_params_init = {\"mu\": 0, \"sigma\": 10, \"nu\": 5}\n", "\n", "# Define actions with BNN-based quantitative models\n", "actions = {\n", "    \"action_1\": QuantitativeBayesianNeuralNetwork.cold_start(\n", "        dimension=1,\n", "        n_features=n_features,\n", "        base_model_cold_start_kwargs=dict(\n", "            hidden_dim_list=[10],\n", "            update_method=update_method,\n", "            update_kwargs=update_kwargs,\n", "            dist_params_init=dist_params_init,\n", "            activation=\"tanh\",\n", "        ),\n", "    ),\n", "    \"action_2\": QuantitativeBayesianNeuralNetwork.cold_start(\n", "        dimension=1,\n", "        n_features=n_features,\n", "        base_model_cold_start_kwargs=dict(\n", "            hidden_dim_list=[10],\n", "            update_method=update_method,\n", "            update_kwargs=update_kwargs,\n", "            dist_params_init=dist_params_init,\n", "            activation=\"tanh\",\n", "        ),\n", "    ),\n", "}" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "actions[\"action_1\"].bnn.update_kwargs" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now we can initialize the CmabBernoulli bandit with our BNN-based quantitative models:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Initialize the bandit (epsilon=1: fully random exploration for the first batch)\n", "cmab = CmabBernoulli(actions=actions, epsilon=1)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], 
"source": [ "print(cmab.actions[\"action_1\"].bnn.update_kwargs)\n", "print(f\"epsilon: {cmab.epsilon}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Simulate Environment\n", "\n", "Let's create a reward function that depends on the action, its quantitative parameter, and the context. For illustration purposes, we define the following behavior:\n", "\n", "- `action_1` performs better when the first context feature is high and when the quantitative parameter is around 0.25\n", "- `action_2` performs better when the second context feature is high and when the quantitative parameter is around 0.75\n", "\n", "The reward probability follows a bell curve in the quantitative parameter and is scaled by a factor that grows with the relevant context feature." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def reward_function(action, quantity, context):\n", "    \"\"\"Return a sampled Bernoulli reward and its underlying success probability.\"\"\"\n", "    if action == \"action_1\":\n", "        # Bell curve centered at 0.25 for the quantity\n", "        # Influenced by first context feature\n", "        quantity_component = np.exp(-((quantity - 0.25) ** 2) / 0.02)\n", "        context_component = 0.5 + 0.5 * (context[0] / 2)  # First feature has influence\n", "        prob = quantity_component * context_component\n", "    else:  # action_2\n", "        # Bell curve centered at 0.75 for the quantity\n", "        # Influenced by second context feature\n", "        quantity_component = np.exp(-((quantity - 0.75) ** 2) / 0.02)\n", "        context_component = 0.5 + 0.5 * (context[1] / 2)  # Second feature has influence\n", "        prob = quantity_component * context_component\n", "\n", "    # Ensure probability is between 0 and 1\n", "    prob = max(0, min(1, prob))\n", "\n", "    return np.random.binomial(1, prob), prob\n", "\n", "\n", "def get_optimal_reward(context):\n", "    \"\"\"Return the best achievable reward probability for a given context.\"\"\"\n", "    # At the optimal quantity the bell-curve term equals 1, so only the context term remains\n", "    max_prob_action_1 = 0.5 + 0.5 * (context[0] / 2)\n", "    max_prob_action_2 = 0.5 + 0.5 * (context[1] / 2)\n", "    return max(max_prob_action_1, max_prob_action_2)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's visualize our reward 
functions to understand what the bandit needs to learn. We'll plot the reward probability curves of both actions for three different contexts:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "x = np.linspace(0, 1, 100)\n", "\n", "# Plot for three different contexts\n", "contexts = [\n", "    np.array([1.0, 0.0, 0.0]),  # High first feature\n", "    np.array([0.0, 1.0, 0.0]),  # High second feature\n", "    np.array([0.5, 0.5, 0.0]),  # Mixed features\n", "]\n", "\n", "plt.figure(figsize=(16, 5))\n", "for i, context in enumerate(contexts, 1):\n", "    plt.subplot(1, 3, i)\n", "\n", "    y1 = [np.exp(-((xi - 0.25) ** 2) / 0.02) * (0.5 + 0.5 * (context[0] / 2)) for xi in x]\n", "    y2 = [np.exp(-((xi - 0.75) ** 2) / 0.02) * (0.5 + 0.5 * (context[1] / 2)) for xi in x]\n", "\n", "    plt.plot(x, y1, \"b-\", label=\"action_1\")\n", "    plt.plot(x, y2, \"r-\", label=\"action_2\")\n", "    plt.xlabel(\"Quantitative Parameter\")\n", "    plt.ylabel(\"Reward Probability\")\n", "\n", "    if i == 1:\n", "        title = \"Context: High Feature 1\"\n", "    elif i == 2:\n", "        title = \"Context: High Feature 2\"\n", "    else:\n", "        title = \"Context: Mixed Features\"\n", "\n", "    plt.title(title)\n", "    plt.legend()\n", "    plt.grid(True)\n", "\n", "plt.tight_layout()\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Generate Synthetic Context Data\n", "\n", "Let's define the batch sizes for training and preview a sample of synthetic context data:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Batch sizes for the training rounds: one large initial batch followed by smaller ones\n", "batch_sizes = [20000, 300, 300, 300, 300, 300]\n", "\n", "# Generate a small random context sample for preview\n", "context_data_sample = np.random.uniform(0, 1, (5, n_features))\n", "\n", "# Preview the context data\n", "pd.DataFrame(context_data_sample[:5], columns=[f\"Feature {i + 1}\" for i in range(n_features)])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Bandit Training Loop\n", "\n", "Now, let's train our bandit by simulating interactions for several rounds:" ] }, 
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "for batch_idx, batch_size in enumerate(batch_sizes):\n", "    if batch_idx > 0:\n", "        cmab = CmabBernoulli(actions=actions, epsilon=0)  # no exploration\n", "\n", "    # Get context for this round\n", "    current_context = np.random.uniform(0, 1, (batch_size, n_features))\n", "\n", "    # Predict best action\n", "    pred_actions, model_probs, weighted_sums = cmab.predict(context=current_context)\n", "    chosen_actions = [a[0] for a in pred_actions]\n", "    chosen_quantities = [a[1][0] for a in pred_actions]\n", "\n", "    # Observe reward\n", "    rewards_and_probs = [\n", "        reward_function(chosen_action, chosen_quantity, _current_context)\n", "        for chosen_action, chosen_quantity, _current_context in zip(chosen_actions, chosen_quantities, current_context)\n", "    ]\n", "    rewards = [reward_and_prob[0] for reward_and_prob in rewards_and_probs]\n", "    probs = [reward_and_prob[1] for reward_and_prob in rewards_and_probs]\n", "\n", "    optimal_probs = [get_optimal_reward(context) for context in current_context]\n", "\n", "    # Average regret: gap between the best achievable and the obtained reward probability\n", "    regret = np.mean(np.array(optimal_probs) - np.array(probs))\n", "\n", "    # Update bandit\n", "    cmab.update(actions=chosen_actions, rewards=rewards, context=current_context, quantities=chosen_quantities)\n", "\n", "    # Print progress\n", "    print(f\"Completed {batch_idx + 1} batches. Avg regret: {regret}\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "### Plot Reward surface - actual vs. 
predicted" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "grid = np.mgrid[0:1:0.01, 0:1:0.01].astype(float)\n", "grid_2d = grid.reshape(2, -1).T\n", "plt.figure(figsize=(10, 6))\n", "ax = plt.subplot(1, 2, 1)\n", "y_true = np.zeros((100, 100))\n", "reward_prob = np.zeros((100, 100))\n", "for i, quantity in enumerate(np.linspace(0, 1, 100)):\n", "    for j, first_feature in enumerate(np.linspace(0, 1, 100)):\n", "        y_true[i, j], reward_prob[i, j] = reward_function(\n", "            action=\"action_1\", quantity=quantity, context=[first_feature, 0, 0]\n", "        )\n", "\n", "cmap = plt.get_cmap(\"coolwarm\")\n", "# Create the contour plot\n", "contour = ax.contourf(*grid, reward_prob.reshape(100, 100), cmap=cmap)\n", "cbar = plt.colorbar(contour, ax=ax)\n", "ax.set(title=\"action_1\", ylabel=\"First feature\", xlabel=\"Quantitative Parameter\")\n", "\n", "ax = plt.subplot(1, 2, 2)\n", "y_true = np.zeros((100, 100))\n", "reward_prob = np.zeros((100, 100))\n", "for i, quantity in enumerate(np.linspace(0, 1, 100)):\n", "    for j, second_feature in enumerate(np.linspace(0, 1, 100)):\n", "        y_true[i, j], reward_prob[i, j] = reward_function(\n", "            action=\"action_2\", quantity=quantity, context=[0, second_feature, 0]\n", "        )\n", "\n", "cmap = plt.get_cmap(\"coolwarm\")\n", "# Create the contour plot\n", "contour = ax.contourf(*grid, reward_prob.reshape(100, 100), cmap=cmap)\n", "cbar = plt.colorbar(contour, ax=ax)\n", "ax.set(title=\"action_2\", ylabel=\"Second feature\", xlabel=\"Quantitative Parameter\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "grid_2d_action_1 = np.append(grid_2d, np.zeros((grid_2d.shape[0], 2)), axis=1)\n", "\n", "batch_predictions_action_1 = [\n", "    cmab.actions[\"action_1\"].bnn.sample_proba(grid_2d_action_1) for _ in range(500)\n", "]  # predictions are list of tuples of probabilities and corresponding weighted sums\n", "batch_proba_action_1 = np.array(\n", "    [\n", "        
[proba_and_weighted_sum[0] for proba_and_weighted_sum in predictions]\n", "        for predictions in batch_predictions_action_1\n", "    ]\n", ")\n", "\n", "grid_2d_action_2 = np.concatenate(\n", "    [grid_2d[:, 0:1], np.zeros((grid_2d.shape[0], 1)), grid_2d[:, 1:2], np.zeros((grid_2d.shape[0], 1))], axis=1\n", ")\n", "\n", "batch_predictions_action_2 = [\n", "    cmab.actions[\"action_2\"].bnn.sample_proba(grid_2d_action_2) for _ in range(500)\n", "]  # predictions are list of tuples of probabilities and corresponding weighted sums\n", "batch_proba_action_2 = np.array(\n", "    [\n", "        [proba_and_weighted_sum[0] for proba_and_weighted_sum in predictions]\n", "        for predictions in batch_predictions_action_2\n", "    ]\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "plt.figure(figsize=(10, 6))\n", "ax = plt.subplot(1, 2, 1)\n", "cmap = plt.get_cmap(\"coolwarm\")\n", "\n", "# Contour plot of the predicted probability, averaged over the 500 posterior samples\n", "pred_proba_mean = batch_proba_action_1.mean(axis=0)\n", "contour = ax.contourf(*grid, pred_proba_mean.reshape(100, 100), cmap=cmap)\n", "cbar = plt.colorbar(contour, ax=ax)\n", "ax.set(title=\"action_1 (predicted)\", ylabel=\"First feature\", xlabel=\"Quantitative Parameter\")\n", "\n", "ax = plt.subplot(1, 2, 2)\n", "pred_proba_mean = batch_proba_action_2.mean(axis=0)\n", "contour = ax.contourf(*grid, pred_proba_mean.reshape(100, 100), cmap=cmap)\n", "cbar = plt.colorbar(contour, ax=ax)\n", "ax.set(title=\"action_2 (predicted)\", ylabel=\"Second feature\", xlabel=\"Quantitative Parameter\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Testing with Specific Contexts\n", "\n", "Finally, let's test our trained bandit with specific contexts to see if it has learned the optimal policy:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "x = np.linspace(0, 1, 100)\n", "\n", "# Re-plot the true reward curves for the three contexts defined earlier, for reference\n", 
"plt.figure(figsize=(16, 5))\n", "for i, context in enumerate(contexts, 1):\n", "    plt.subplot(1, 3, i)\n", "\n", "    y1 = [np.exp(-((xi - 0.25) ** 2) / 0.02) * (0.5 + 0.5 * (context[0] / 2)) for xi in x]\n", "    y2 = [np.exp(-((xi - 0.75) ** 2) / 0.02) * (0.5 + 0.5 * (context[1] / 2)) for xi in x]\n", "\n", "    plt.plot(x, y1, \"b-\", label=\"action_1\")\n", "    plt.plot(x, y2, \"r-\", label=\"action_2\")\n", "    plt.xlabel(\"Quantitative Parameter\")\n", "    plt.ylabel(\"Reward Probability\")\n", "\n", "    if i == 1:\n", "        title = \"Context: High Feature 1\"\n", "    elif i == 2:\n", "        title = \"Context: High Feature 2\"\n", "    else:\n", "        title = \"Context: Mixed Features\"\n", "\n", "    plt.title(title)\n", "    plt.legend()\n", "    plt.grid(True)\n", "\n", "plt.tight_layout()\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Define test contexts\n", "test_contexts = np.array(\n", "    [\n", "        [1.0, 0.0, 0.0],  # High feature 1, low feature 2\n", "        [0.0, 1.0, 0.0],  # Low feature 1, high feature 2\n", "        [1.0, 1.0, 0.0],  # High feature 1 and 2\n", "        [0.0, 0.0, 0.0],  # Low feature 1 and 2\n", "    ]\n", ")\n", "\n", "# Test predictions\n", "results = []\n", "for i, context in enumerate(test_contexts):\n", "    context_reshaped = context.reshape(1, -1)\n", "    pred_actions, probs, weighted_sums = cmab.predict(context=context_reshaped)\n", "    chosen_action_quantity = pred_actions[0]\n", "    chosen_action = chosen_action_quantity[0]\n", "    chosen_quantities = chosen_action_quantity[1][0]\n", "    chosen_action_probs = probs[0][chosen_action](chosen_quantities)\n", "\n", "    # Look up the true optimal quantity for the chosen action\n", "    # In a real application, you would have a method to test different quantities\n", "    # Here we'll use our knowledge of the true optimal values\n", "    if chosen_action == \"action_1\":\n", "        optimal_quantity = 0.25\n", "    else:\n", "        optimal_quantity = 0.75\n", "\n", "    # Expected reward probability\n", "    expected_reward = 
reward_function(chosen_action, optimal_quantity, context)[1]  # keep the probability, not the sampled reward\n", "\n", "    results.append(\n", "        {\n", "            \"Context\": context,\n", "            \"Chosen Action\": chosen_action,\n", "            \"Chosen Quantity\": chosen_quantities,\n", "            \"Action Probabilities\": chosen_action_probs,\n", "            \"Optimal Quantity\": optimal_quantity,\n", "            \"Expected Reward\": expected_reward,\n", "        }\n", "    )\n", "\n", "# Display results\n", "for i, result in enumerate(results):\n", "    context_type = \"\"\n", "    if i == 0:\n", "        context_type = \"High feature 1, low feature 2\"\n", "    elif i == 1:\n", "        context_type = \"Low feature 1, high feature 2\"\n", "    elif i == 2:\n", "        context_type = \"High feature 1 and 2\"\n", "    elif i == 3:\n", "        context_type = \"Low feature 1 and 2\"\n", "\n", "    print(f\"\\nTest {i + 1}: {context_type}\")\n", "    print(f\"Context: {result['Context']}\")\n", "    print(f\"Chosen Action: {result['Chosen Action']}\")\n", "    print(f\"Chosen Quantity: {result['Chosen Quantity']}\")\n", "    print(f\"Action Probabilities: {result['Action Probabilities']}\")\n", "    print(f\"Optimal Quantity: {result['Optimal Quantity']:.2f}\")\n", "    print(f\"Expected Reward: {result['Expected Reward']}\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "results" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Conclusion\n", "\n", "The CMAB BNN-based quantitative model uses a Bayesian Neural Network to map context and continuous action parameters to reward distributions. This approach enables efficient exploration and exploitation of continuous or fine-grained action parameters while conditioning on context. The BNN provides uncertainty estimates (e.g., via variational inference) that the bandit uses for Thompson sampling or similar strategies, balancing exploration of uncertain regions with exploitation of high predicted rewards.\n", "\n", "This approach is particularly useful when:\n", "1. 
Actions have continuous parameters (e.g., price, dosage, bid) that affect rewards\n", "2. The reward function depends on both context and action parameters\n", "3. The optimal parameter values may vary across different contexts\n", "4. You want a single differentiable model (BNN) over the full parameter space rather than adaptive discretization\n", "\n", "Real-world applications include:\n", "- **Personalized pricing**: Find optimal prices (continuous parameter) based on customer features (context)\n", "- **Content recommendation**: Optimize content parameters (e.g., length, complexity) based on user demographics\n", "- **Medical dosing**: Determine optimal medication dosages based on patient characteristics\n", "- **Ad campaign optimization**: Find best bid values based on ad placement and target audience" ] } ], "metadata": { "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python" } }, "nbformat": 4, "nbformat_minor": 4 }