diff --git a/borrowing_strenght_hierarchical_model.ipynb b/borrowing_strenght_hierarchical_model.ipynb index 9c628db..2932893 100644 --- a/borrowing_strenght_hierarchical_model.ipynb +++ b/borrowing_strenght_hierarchical_model.ipynb @@ -50,49 +50,18 @@ "Fix $m$, $s^2$, $\\tau$ to sensible values, and use an MCMC algorithm of your\n", "choice to sample from the model and produce posterior distribution on\n", "$\\{\\lambda_j\\}^J_{j=1}, \\mu _0, \\sigma _0$ (a sensible choice might be $J = 10$\n", - "hospitals).\n", - "\n", - "____\n", - "\n", - "(b) Determine the marginal posterior distribution for the admission rate for\n", - "each hospital, and compare it with the posterior estimate in a model with no\n", - "pooling (i.e., where each hospital’s rate is inferred exclusively from its\n", - "observed counts, i.e. $\\ln \\lambda_j \\sim \\mathcal{N} (\\mu_0, \\sigma^2_0)$ and\n", - "the same priors as above, and the posterior for hospital j comes exclusively\n", - "from its own data). Check that the posterior means for hospitals with a smaller\n", - "number of records (i.e., $n_j = 12$) exhibit stronger shrinkage towards the\n", - "global mean, thus demonstrating borrowing of strength, and typically have 68%\n", - "HPD credible intervals that are shorter than in the pooling model. You may use a\n", - "violin plot to make this comparison.\n", - "\n", - "___\n", - "\n", - "(c) Consider the prior predictive distribution for possible data within the BHM:\n", - "\n", - "$$ Pr( y_{ij} \\mid priors) = \\int Pr ( y_{ij} \\mid \\lambda_j) Pr(\\lambda_j \\mid \\mu_0, \\sigma_0) Pr (\\mu_0,\\sigma_0) d \\mu_0 d \\sigma_0$$\n", - "\n", - "and simulate from it predictions for the possible counts $y_{ij}$ from the model\n", - "(before you see any data). 
Evaluate the degree of shrinkage by using as metric\n", - "the average shrinkage in the standard deviation of the posterior with respect to\n", - "the no-pooling scenario, i.e.\n", - "\n", - "$$ S = \\frac {1} {J} \\sum_{j=1}^{J}{1 - \\frac{\\text{std}_\\text{BHM}(\\ln \\lambda_j \\mid \\bold y)}{\\text{std}_\\text{no pool}(\\ln \\lambda_j \\mid \\bold y)}}$$\n", - "\n", - "where a smaller S corresponds to higher degree of shrinkage.\n", - "\n", - "Determine which hyperparameter among $(m, s, \\tau)$ has the most effect on the\n", - "degree of shrinkage, and tune each to avoid predictions that are too diffuse\n", - "(i.e., the predictive spread is unreasonably wide) or too narrow (i.e., the\n", - "predictive spread is over-constraining)." + "hospitals)." ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 21, "id": "49307215", "metadata": {}, "outputs": [], "source": [ + "from typing import List, Callable, Tuple\n", + "\n", "import corner\n", "import emcee\n", "import matplotlib.pyplot as plt\n", @@ -145,13 +114,11 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 25, "id": "4a09b1d3", "metadata": {}, "outputs": [], "source": [ - "true_lambdas = []\n", - "\n", "def get_expected_admission_counts(mu: float, sigma: float, size: int):\n", " \"\"\"Get expected number of admissions to a hospital at 'size' time points.\n", " \n", @@ -165,28 +132,31 @@ " mu, sigma\n", " )\n", " admission_rate = np.exp(log_admission_rate)\n", - " true_lambdas.append(admission_rate)\n", " admission_counts = np.random.poisson(\n", " admission_rate, size=size\n", " )\n", - " return admission_counts" + " return admission_counts, admission_rate" ] }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 26, "id": "2f6b2f3c", "metadata": {}, "outputs": [], "source": [ - "def get_admission_count_matrix(mu: float, sigma: float) -> pd.DataFrame:\n", + "def get_admission_count_matrix(\n", + " mu: float, sigma: float\n", + ") -> 
Tuple[pd.DataFrame, List[float]]:\n", " data = {}\n", + " true_lambdas = []\n", " total_array_length = max(reporting_frequency)\n", " for j in range(J): # For each hospital\n", " data_array_length = reporting_frequency[j]\n", - " admission_records = get_expected_admission_counts(\n", + " admission_records, admission_rate = get_expected_admission_counts(\n", " mu, sigma, size=data_array_length\n", " )\n", + " true_lambdas.append(admission_rate)\n", " padding = np.full(total_array_length - data_array_length, 0, dtype=int)\n", " admission_records = np.concatenate([\n", " admission_records,\n", @@ -195,7 +165,7 @@ " data[f\"Hospital {j+1}\"] = admission_records\n", " data = pd.DataFrame(data)\n", " data.index.name = \"Time point\"\n", - " return data.transpose()" + " return data.transpose(), true_lambdas" ] }, { @@ -222,7 +192,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": null, "id": "4b5ee692", "metadata": {}, "outputs": [ @@ -262,7 +232,7 @@ "true_mu_0 = np.random.normal(np.log(m), s)\n", "true_sigma_0 = np.random.exponential(tau)\n", "\n", - "df = get_admission_count_matrix(mu=true_mu_0, sigma=true_sigma_0)\n", + "df, true_lambdas = get_admission_count_matrix(mu=true_mu_0, sigma=true_sigma_0)\n", "\n", "print(df)" ] @@ -289,33 +259,13 @@ { "cell_type": "code", "execution_count": 7, - "id": "35c25880", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "pandas.core.series.Series" - ] - }, - "execution_count": 7, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "j=6\n", - "type(df.iloc[j,:reporting_frequency[j]])" - ] - }, - { - "cell_type": "code", - "execution_count": null, "id": "10a1de3d", "metadata": {}, "outputs": [], "source": [ - "def log_posterior(parameters, m: float, s: float, tau: float, J: int):\n", + "def log_posterior(parameters,\n", + " m: float, s: float, tau: float,\n", + " J: int, df: pd.DataFrame):\n", " # lambda_j is an array of length J with the admission rate for each 
hospital\n", " mu_0, sigma_0, *lambdas = parameters\n", " # Check parameters that have to be strictly positive\n", @@ -359,7 +309,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 8, "id": "1e393c9f", "metadata": {}, "outputs": [ @@ -367,45 +317,58 @@ "name": "stderr", "output_type": "stream", "text": [ - "100%|██████████| 5000/5000 [03:36<00:00, 23.12it/s]\n" + "100%|██████████| 5000/5000 [03:45<00:00, 22.17it/s]\n" ] } ], "source": [ - "# Data: observed_admissions (from previous synthetic data generation)\n", - "# Assuming observed_admissions is a pandas df (J * max(reporting_frequency))\n", + "def get_samples_from_full_bhm(true_mu_0: float, true_sigma_0: float,\n", + " true_lambdas: List[float],\n", + " log_posterior: Callable,\n", + " m: float, s: float, tau: float,\n", + " J: int, df: pd.DataFrame):\n", + " # Data: observed_admissions (from previous synthetic data generation)\n", + " # Assuming observed_admissions is a pandas df (J * max(reporting_frequency))\n", "\n", - "# Number of dimensions and walkers\n", - "ndim = 2 + J # [mu_0, sigma_0, lambda_1, ..., lambda_J]\n", - "num_walkers = 32\n", + " # Number of dimensions and walkers\n", + " ndim = 2 + J # [mu_0, sigma_0, lambda_1, ..., lambda_J]\n", + " num_walkers = 32\n", "\n", - "# Initial guess for parameters\n", - "initial_guess = [\n", - " true_mu_0,\n", - " true_sigma_0,\n", - " *true_lambdas\n", - "]\n", + " # Initial guess for parameters\n", + " initial_guess = [\n", + " true_mu_0,\n", + " true_sigma_0,\n", + " *true_lambdas\n", + " ]\n", "\n", - "# Initialize walkers in a small ball around the initial guess\n", - "pos = initial_guess + 1e-4 * np.random.randn(num_walkers, ndim)\n", + " # Initialize walkers in a small ball around the initial guess\n", + " pos = initial_guess + 1e-4 * np.random.randn(num_walkers, ndim)\n", "\n", - "# Create the sampler\n", - "sampler = emcee.EnsembleSampler(\n", - " num_walkers, ndim, log_posterior,\n", - " args=(m, s, tau, J)\n", - ")\n", + " 
# Create the sampler\n", + " sampler = emcee.EnsembleSampler(\n", + " num_walkers, ndim, log_posterior,\n", + " args=(m, s, tau, J, df)\n", + " )\n", "\n", - "# Run MCMC\n", - "nsteps = 5000\n", - "sampler.run_mcmc(pos, nsteps, progress=True)\n", + " # Run MCMC\n", + " nsteps = 5000\n", + " sampler.run_mcmc(pos, nsteps, progress=True)\n", "\n", - "# Get one sample every 'thin' steps, discarding the burn-in (first 'discard' steps)\n", - "samples = sampler.get_chain(discard=1000, thin=10, flat=True)" + " # Get one sample every 'thin' steps, discarding the burn-in (first 'discard' steps)\n", + " return sampler.get_chain(discard=1000, thin=10, flat=True)\n", + "\n", + "\n", + "samples = get_samples_from_full_bhm(\n", + " true_mu_0=true_mu_0, true_sigma_0=true_sigma_0,\n", + " true_lambdas=true_lambdas,\n", + " log_posterior=log_posterior,\n", + " m=m, s=s, tau=tau, J=J, df=df\n", + ")" ] }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 9, "id": "bcedf091", "metadata": {}, "outputs": [ @@ -439,15 +402,28 @@ }, { "cell_type": "markdown", - "id": "3ce8571e", + "id": "66182918", "metadata": {}, "source": [ - "### (b) Here are the marginal posterior distributions of admission rates for every hospital:" + "____\n", + "\n", + "(b) Determine the marginal posterior distribution for the admission rate for\n", + "each hospital, and compare it with the posterior estimate in a model with no\n", + "pooling (i.e., where each hospital’s rate is inferred exclusively from its\n", + "observed counts, i.e. $\\ln \\lambda_j \\sim \\mathcal{N} (\\mu_0, \\sigma^2_0)$ and\n", + "the same priors as above, and the posterior for hospital j comes exclusively\n", + "from its own data). 
Check that the posterior means for hospitals with a smaller\n", + "number of records (i.e., $n_j = 12$) exhibit stronger shrinkage towards the\n", + "global mean, thus demonstrating borrowing of strength, and typically have 68%\n", + "HPD credible intervals that are shorter than in the no-pooling model. You may use a\n", + "violin plot to make this comparison.\n", + "\n", + "#### With pooling" ] }, { "cell_type": "code", - "execution_count": 23, + "execution_count": 10, "id": "03ee076c", "metadata": {}, "outputs": [ @@ -455,31 +431,31 @@ "name": "stdout", "output_type": "stream", "text": [ - " Reporting frequency 16th percentile 84th percentile \\\n", - "Hospital \n", - "1 12 155.772971 160.241098 \n", - "2 54 157.656116 160.502618 \n", - "3 12 156.272260 160.602153 \n", - "4 12 158.177209 162.292523 \n", - "5 54 157.382752 160.187262 \n", - "6 54 160.169699 163.533840 \n", - "7 12 158.731447 162.830338 \n", - "8 24 159.015641 162.708823 \n", - "9 12 154.559953 159.806722 \n", - "10 12 157.219515 161.435273 \n", + " Reporting frequency log lambda 16th percentile \\\n", + "Hospital \n", + "1 12 5.048400 \n", + "2 54 5.060416 \n", + "3 12 5.051600 \n", + "4 12 5.063716 \n", + "5 54 5.058681 \n", + "6 54 5.076234 \n", + "7 12 5.067214 \n", + "8 24 5.069003 \n", + "9 12 5.040582 \n", + "10 12 5.057643 \n", "\n", - " HPD width (68%) mu_0 sigma_0 \n", - "Hospital \n", - "1 4.468127 5.070992 0.016259 \n", - "2 2.846502 5.070992 0.016259 \n", - "3 4.329893 5.070992 0.016259 \n", - "4 4.115313 5.070992 0.016259 \n", - "5 2.804511 5.070992 0.016259 \n", - "6 3.364141 5.070992 0.016259 \n", - "7 4.098891 5.070992 0.016259 \n", - "8 3.693182 5.070992 0.016259 \n", - "9 5.246769 5.070992 0.016259 \n", - "10 4.215758 5.070992 0.016259 \n" + " log lambda 84th percentile HPD width (68%) mu_0 sigma_0 \n", + "Hospital \n", + "1 5.076680 0.028280 5.070992 0.016259 \n", + "2 5.078310 0.017894 5.070992 0.016259 \n", + "3 5.078930 0.027330 5.070992 0.016259 \n", + "4 5.089400 0.025684 
5.070992 0.016259 \n", + "5 5.076344 0.017663 5.070992 0.016259 \n", + "6 5.097020 0.020786 5.070992 0.016259 \n", + "7 5.092709 0.025495 5.070992 0.016259 \n", + "8 5.091962 0.022960 5.070992 0.016259 \n", + "9 5.073965 0.033383 5.070992 0.016259 \n", + "10 5.084104 0.026461 5.070992 0.016259 \n" ] } ], @@ -489,13 +465,13 @@ "for j in range(J):\n", " mu_0_j = samples[:, 0]\n", " sigma_0_j = samples[:, 1]\n", - " lambda_samples_j = samples[:, 2 + j]\n", + " lambda_samples_j = np.log(samples[:, 2 + j])\n", " lower, upper = np.percentile(lambda_samples_j, [16, 84])\n", " rows.append({\n", " \"Hospital\": j + 1,\n", " \"Reporting frequency\": reporting_frequency[j],\n", - " \"16th percentile\": lower,\n", - " \"84th percentile\": upper,\n", + " \"log lambda 16th percentile\": lower,\n", + " \"log lambda 84th percentile\": upper,\n", " \"HPD width (68%)\": upper - lower,\n", " \"mu_0\": np.mean(mu_0_j),\n", " \"sigma_0\": np.mean(sigma_0_j)\n", @@ -508,7 +484,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 11, "id": "70b92966", "metadata": {}, "outputs": [ @@ -516,9 +492,9 @@ "name": "stdout", "output_type": "stream", "text": [ - "Pearson r = -0.874\n", - "R^2 = 0.764\n", - "p-value = 9.495e-04\n" + "Pearson r = -0.865\n", + "R^2 = 0.749\n", + "p-value = 1.224e-03\n" ] } ], @@ -536,13 +512,13 @@ }, { "cell_type": "code", - "execution_count": 29, + "execution_count": 12, "id": "74dbf9b4", "metadata": {}, "outputs": [ { "data": { - "image/png": "", + "image/png": "", "text/plain": [ "
" ] @@ -568,6 +544,357 @@ ")\n", "plt.show()\n" ] + }, + { + "cell_type": "markdown", + "id": "d586de09", + "metadata": {}, + "source": [ + "#### Without pooling" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "eb404c06", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 3000/3000 [00:21<00:00, 141.02it/s]\n", + "100%|██████████| 3000/3000 [00:19<00:00, 151.97it/s]\n", + "100%|██████████| 3000/3000 [00:17<00:00, 169.06it/s]\n", + "100%|██████████| 3000/3000 [00:18<00:00, 158.46it/s]\n", + "100%|██████████| 3000/3000 [00:19<00:00, 154.37it/s]\n", + "100%|██████████| 3000/3000 [00:17<00:00, 168.49it/s]\n", + "100%|██████████| 3000/3000 [00:17<00:00, 170.60it/s]\n", + "100%|██████████| 3000/3000 [00:17<00:00, 169.70it/s]\n", + "100%|██████████| 3000/3000 [00:17<00:00, 170.22it/s]\n", + "100%|██████████| 3000/3000 [00:17<00:00, 171.07it/s]\n" + ] + } + ], + "source": [ + "mu_fixed = np.log(m)\n", + "sigma_fixed = s\n", + "\n", + "\n", + "def log_posterior_without_pooling(parameters: np.ndarray,\n", + " mu_0: float, sigma_0: float,\n", + " j: int, df: pd.DataFrame):\n", + " lambda_j, *_ = parameters\n", + " if lambda_j <= 0:\n", + " return -np.inf\n", + " # Prior on 'lambda_j' (log-normally distributed, parameters 'mu_0' and 'sigma_0')\n", + " log_prior = (\n", + " -np.log(lambda_j * sigma_0 * np.sqrt(2 * np.pi))\n", + " - (np.log(lambda_j) - mu_0)**2 / (2 * sigma_0**2)\n", + " )\n", + " observed_admissions = df.iloc[j,:reporting_frequency[j]].to_numpy()\n", + " # Likelihood: Poisson\n", + " log_likelihood = np.sum(\n", + " observed_admissions * np.log(lambda_j)\n", + " - lambda_j\n", + " - gammaln(observed_admissions + 1)\n", + " )\n", + "\n", + " return log_prior + log_likelihood\n", + "\n", + "\n", + "def get_samples_without_pooling(true_lambdas: List[float],\n", + " log_posterior_without_pooling: Callable,\n", + " mu_fixed: float, sigma_fixed: float,\n", + " J: int, df: 
pd.DataFrame):\n", + " samples_without_pooling = {}\n", + " for j in range(J):\n", + " # Number of dimensions and walkers\n", + " ndim = 1 # Only lambda_j\n", + " num_walkers = 32\n", + "\n", + " # Initial guess for parameters\n", + " initial_guess = [\n", + " true_lambdas[j]\n", + " ]\n", + "\n", + " # Initialize walkers in a small ball around the initial guess\n", + " pos = initial_guess + 1e-4 * np.random.randn(num_walkers, ndim)\n", + "\n", + " # Create the sampler\n", + " sampler = emcee.EnsembleSampler(\n", + " num_walkers, ndim, log_posterior_without_pooling,\n", + " args=(mu_fixed, sigma_fixed, j, df)\n", + " )\n", + "\n", + " # Run MCMC\n", + " nsteps = 3000\n", + " sampler.run_mcmc(pos, nsteps, progress=True)\n", + "\n", + " # Get one sample every 'thin' steps, discarding the burn-in (first 'discard' steps)\n", + " samples_without_pooling[j] = sampler.get_chain(discard=1000, thin=10, flat=True)\n", + " return samples_without_pooling\n", + "\n", + "\n", + "samples_without_pooling = get_samples_without_pooling(\n", + " true_lambdas=true_lambdas,\n", + " log_posterior_without_pooling=log_posterior_without_pooling,\n", + " mu_fixed=mu_fixed, sigma_fixed=sigma_fixed,\n", + " J=J, df=df\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "aa76a65f", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " Reporting frequency log lambda 16th percentile \\\n", + "Hospital \n", + "1 12 5.017023 \n", + "2 54 5.057103 \n", + "3 12 5.027284 \n", + "4 12 5.057264 \n", + "5 54 5.054388 \n", + "6 54 5.085134 \n", + "7 12 5.070524 \n", + "8 24 5.072232 \n", + "9 12 5.000043 \n", + "10 12 5.046290 \n", + "\n", + " log lambda 84th percentile HPD width (68%) mu_0 sigma_0 \n", + "Hospital \n", + "1 5.062971 0.045948 4.442651 0.4 \n", + "2 5.078878 0.021775 4.442651 0.4 \n", + "3 5.072929 0.045645 4.442651 0.4 \n", + "4 5.103612 0.046349 4.442651 0.4 \n", + "5 5.076058 0.021670 4.442651 0.4 \n", + "6 
5.106324 0.021190 4.442651 0.4 \n", + "7 5.115854 0.045329 4.442651 0.4 \n", + "8 5.104386 0.032154 4.442651 0.4 \n", + "9 5.046979 0.046936 4.442651 0.4 \n", + "10 5.091491 0.045202 4.442651 0.4 \n" + ] + } + ], + "source": [ + "rows = []\n", + "\n", + "for j in range(J):\n", + " swp = samples_without_pooling[j]\n", + " lambda_samples_j = np.log(swp[:, 0])\n", + " lower, upper = np.percentile(lambda_samples_j, [16, 84])\n", + " rows.append({\n", + " \"Hospital\": j + 1,\n", + " \"Reporting frequency\": reporting_frequency[j],\n", + " \"log lambda 16th percentile\": lower,\n", + " \"log lambda 84th percentile\": upper,\n", + " \"HPD width (68%)\": upper - lower,\n", + " \"mu_0\": np.mean(mu_fixed),\n", + " \"sigma_0\": np.mean(sigma_fixed)\n", + " })\n", + "\n", + "temp = pd.DataFrame(rows)\n", + "temp = pd.DataFrame(rows).set_index(\"Hospital\")\n", + "print(temp)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "0117285c", + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "data = []\n", + "positions = []\n", + "groups = []\n", + "\n", + "pos, offset, step = 1.0, 0.35, 1.2\n", + "\n", + "for j in range(J):\n", + " # With pooling\n", + " data.append(np.log(samples[:, 2 + j]))\n", + " positions.append(pos)\n", + " groups.append(\"bhm\")\n", + " # Without pooling\n", + " data.append(np.log(samples_without_pooling[j][:, 0]))\n", + " positions.append(pos + offset)\n", + " groups.append(\"nopool\")\n", + " pos += step\n", + "\n", + "plt.figure(figsize=(14, 6))\n", + "\n", + "vp = plt.violinplot(\n", + " data,\n", + " positions=positions,\n", + " widths=0.3,\n", + " #showmeans=False,\n", + " showmedians=True,\n", + " #showextrema=False\n", + ")\n", + "\n", + "# Color each violin plot according to group (with or without pooling)\n", + "for body, group in zip(vp[\"bodies\"], groups):\n", + " body.set_facecolor(\n", + " {\n", + " \"bhm\": \"orange\",\n", + " \"nopool\": \"darkblue\"\n", + " }[group]\n", + " )\n", + " body.set_alpha(0.7)\n", + "\n", + "vp[\"cmedians\"].set_color(\"white\")\n", + "vp[\"cmedians\"].set_alpha(1.0)\n", + "vp[\"cmedians\"].set_linewidth(2.0)\n", + "\n", + "# Set axes labels and plot title\n", + "plt.xticks(\n", + " [1 + step * j + offset/2 for j in range(J)],\n", + " [f\"Hospital {j+1}\\n(n={reporting_frequency[j]})\" for j in range(J)]\n", + ")\n", + "plt.ylabel(r\"$\\ln \\lambda_j$\")\n", + "plt.title(r\"Distribution of $\\ln \\lambda_y$\"\"\\n\"r\"Full BHM vs w/out pooling\")\n", + "\n", + "from matplotlib.patches import Patch\n", + "plt.legend(\n", + " handles=[\n", + " Patch(color=\"orange\", label=\"With pooling\"),\n", + " Patch(color=\"darkblue\", label=\"No pooling\")\n", + " ],\n", + " loc=\"upper right\"\n", + ")\n", + "\n", + "plt.tight_layout()\n", + "plt.show()\n" + ] + }, + { + "cell_type": "markdown", + "id": "dd561fbc", + "metadata": {}, + "source": [ + "___\n", + "\n", + "(c) Consider the prior predictive 
distribution for possible data within the BHM:\n", + "\n", + "$$ \\Pr( y_{ij} \\mid \\text{priors}) = \\int \\Pr ( y_{ij} \\mid \\lambda_j) \\Pr(\\lambda_j \\mid \\mu_0, \\sigma_0) \\Pr (\\mu_0,\\sigma_0) \\, d \\lambda_j \\, d \\mu_0 \\, d \\sigma_0$$\n", + "\n", + "and simulate from it predictions for the possible counts $y_{ij}$ from the model\n", + "(before you see any data). Evaluate the degree of shrinkage by using as a metric\n", + "the average shrinkage in the standard deviation of the posterior with respect to\n", + "the no-pooling scenario, i.e.\n", + "\n", + "$$ S = \\frac {1} {J} \\sum_{j=1}^{J} \\left( 1 - \\frac{\\text{std}_\\text{BHM}(\\ln \\lambda_j \\mid \\mathbf{y})}{\\text{std}_\\text{no pool}(\\ln \\lambda_j \\mid \\mathbf{y})} \\right)$$\n", + "\n", + "where a larger $S$ corresponds to a higher degree of shrinkage ($S = 0$ means no shrinkage).\n", + "\n", + "Determine which hyperparameter among $(m, s, \\tau)$ has the largest effect on the\n", + "degree of shrinkage, and tune each to avoid predictions that are too diffuse\n", + "(i.e., the predictive spread is unreasonably wide) or too narrow (i.e., the\n", + "predictive spread is over-constraining)." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "3366ea7e", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "The degree of shrinkage S for m = 85, s = 0.4, tau = 0.1 is: 0.281\n" + ] + } + ], + "source": [ + "S = np.mean([\n", + " 1 - (\n", + " np.std(np.log(samples[:, 2 + j]))\n", + " / np.std(np.log(samples_without_pooling[j][:, 0]))\n", + " )\n", + " for j in range(J)\n", + "])\n", + "\n", + "print(\n", + " f\"The degree of shrinkage S for m = {m}, s = {s}, tau = {tau} is: {S:.3f}\"\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9ddb9e81", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 80%|███████▉ | 3977/5000 [03:49<00:56, 17.96it/s]" + ] + } + ], + "source": [ + "m = 85\n", + "s = 0.4\n", + "tau = 0.1\n", + "\n", + "def evaluate_degree_of_shrinkage(m, s, tau):\n", + " true_mu_0 = np.random.normal(np.log(m), s)\n", + " true_sigma_0 = np.random.exponential(tau)\n", + " df, true_lambdas = get_admission_count_matrix(\n", + " mu=true_mu_0, sigma=true_sigma_0\n", + " )\n", + " # Samples from joint posterior (full BHM)\n", + " samples_bhm = get_samples_from_full_bhm(\n", + " true_mu_0=true_mu_0, true_sigma_0=true_sigma_0,\n", + " true_lambdas=true_lambdas,\n", + " log_posterior=log_posterior,\n", + " m=m, s=s, tau=tau,\n", + " J=J, df=df\n", + " )\n", + " # Samples from non-pooling model\n", + " samples_without_pooling = get_samples_without_pooling(\n", + " true_lambdas=true_lambdas,\n", + " log_posterior_without_pooling=log_posterior_without_pooling,\n", + " mu_fixed=true_mu_0, sigma_fixed=true_sigma_0,\n", + " J=J, df=df\n", + " )\n", + " S = np.mean([\n", + " 1 - (\n", + " np.std(np.log(samples_bhm[:, 2 + j]))\n", + " / np.std(np.log(samples_without_pooling[j][:, 0]))\n", + " )\n", + " for j in range(J)\n", + " ])\n", + " print(\n", + " f\"The degree of shrinkage S for m = {m}, s = {s}, tau = {tau} is: 
{S:.3f}\"\n", + " )\n", + "\n", + "\n", + "evaluate_degree_of_shrinkage(m=(2 * m), s=s, tau=tau)\n", + "evaluate_degree_of_shrinkage(m=m, s=(2 * s), tau=tau)\n", + "evaluate_degree_of_shrinkage(m=m, s=s, tau=(2 * tau))" + ] } ], "metadata": {