{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Interactive visualization\n",
    "\n",
    "```{autolink-concat}\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "jupyter": {
     "source_hidden": true
    },
    "mystnb": {
     "code_prompt_show": "Import python libraries"
    },
    "tags": [
     "hide-cell",
     "scroll-input"
    ]
   },
   "outputs": [],
   "source": [
    "from __future__ import annotations\n",
    "\n",
    "import logging\n",
    "import os\n",
    "from functools import lru_cache\n",
    "from textwrap import dedent\n",
    "from warnings import filterwarnings\n",
    "\n",
    "import jax.numpy as jnp\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import sympy as sp\n",
    "from IPython.display import Markdown, display\n",
    "from ipywidgets import (\n",
    "    HTML,\n",
    "    Button,\n",
    "    FloatSlider,\n",
    "    GridBox,\n",
    "    HBox,\n",
    "    HTMLMath,\n",
    "    Layout,\n",
    "    RadioButtons,\n",
    "    Tab,\n",
    "    interactive_output,\n",
    ")\n",
    "from matplotlib.colors import LogNorm\n",
    "from tensorwaves.interface import DataSample, ParametrizedFunction\n",
    "from tqdm.auto import tqdm\n",
    "from traitlets.utils.bunch import Bunch\n",
    "\n",
    "from polarimetry import formulate_polarimetry\n",
    "from polarimetry.amplitude import simplify_latex_rendering\n",
    "from polarimetry.data import (\n",
    "    compute_dalitz_boundaries,\n",
    "    create_data_transformer,\n",
    "    generate_meshgrid_sample,\n",
    ")\n",
    "from polarimetry.io import (\n",
    "    mute_jax_warnings,\n",
    "    perform_cached_doit,\n",
    "    perform_cached_lambdify,\n",
    ")\n",
    "from polarimetry.lhcb import load_model_builder, load_model_parameters\n",
    "from polarimetry.lhcb.particle import load_particles\n",
    "from polarimetry.plot import use_mpl_latex_fonts\n",
    "\n",
    "# Silence warnings and simplify how SymPy renders the model expressions.\n",
    "filterwarnings(\"ignore\")\n",
    "logging.getLogger(\"polarimetry.function\").setLevel(logging.INFO)\n",
    "mute_jax_warnings()\n",
    "simplify_latex_rendering()\n",
    "\n",
    "# When EXECUTE_NB is set (presumably a non-interactive docs build), progress bars\n",
    "# are disabled, logging is reduced to errors, and the last cell exports a static\n",
    "# image instead of displaying the interactive widgets.\n",
    "NO_TQDM = \"EXECUTE_NB\" in os.environ\n",
    "if NO_TQDM:\n",
    "    logging.getLogger().setLevel(logging.ERROR)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "jupyter": {
     "source_hidden": true
    },
    "mystnb": {
     "code_prompt_show": "Formulate amplitude models"
    },
    "tags": [
     "hide-input"
    ]
   },
   "outputs": [],
   "source": [
    "model_choice = 0\n",
    "model_file = \"../../data/model-definitions.yaml\"\n",
    "PARTICLES = load_particles(\"../../data/particle-definitions.yaml\")\n",
    "BUILDER = load_model_builder(model_file, PARTICLES, model_id=model_choice)\n",
    "imported_parameters = load_model_parameters(\n",
    "    model_file,\n",
    "    BUILDER.decay,\n",
    "    model_id=model_choice,\n",
    "    particle_definitions=PARTICLES,\n",
    ")\n",
    "# One amplitude model per reference sub-system (1, 2, 3), all with the same\n",
    "# imported parameter values.\n",
    "MODELS = {}\n",
    "for ref in (1, 2, 3):\n",
    "    MODELS[ref] = BUILDER.formulate(ref)\n",
    "    MODELS[ref].parameter_defaults.update(imported_parameters)\n",
    "DECAY = MODELS[1].decay\n",
    "# Resonances sorted by the first letter of their name, then by mass.\n",
    "RESONANCES = sorted(\n",
    "    {c.resonance for c in DECAY.chains},\n",
    "    key=lambda p: (p.name[0], p.mass),\n",
    ")\n",
    "del model_choice, model_file, imported_parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "jupyter": {
     "source_hidden": true
    },
    "mystnb": {
     "code_prompt_show": "Definition of free parameters"
    },
    "tags": [
     "hide-input"
    ]
   },
   "outputs": [],
   "source": [
    "def to_polar_coordinates(coupling: sp.Indexed) -> tuple[sp.Symbol, sp.Symbol]:\n",
    "    \"\"\"Create the norm and phase symbols that replace a complex coupling.\n",
    "\n",
    "    The first index of the coupling becomes the superscript and the remaining\n",
    "    indices become the subscript of both new symbols.\n",
    "    \"\"\"\n",
    "    superscript = sp.latex(coupling.indices[0])\n",
    "    subscript = \", \".join(map(sp.latex, coupling.indices[1:]))\n",
    "    suffix = f\"^{{{superscript}}}_{{{subscript}}}\"\n",
    "    norm = sp.Symbol(\"C\" + suffix)\n",
    "    phi = sp.Symbol(R\"\\phi\" + suffix)\n",
    "    return norm, phi\n",
    "\n",
    "\n",
    "# Collect the parameter defaults of all models and rewrite every complex production\n",
    "# coupling H as C * exp(i * phi), so that it is described by two real parameters.\n",
    "PARAMETERS = {}\n",
    "POLAR_SUBSTITUTIONS = {}\n",
    "for model in MODELS.values():\n",
    "    PARAMETERS.update(model.parameter_defaults)\n",
    "    for symbol, value in model.parameter_defaults.items():\n",
    "        if not symbol.name.startswith(R\"\\mathcal{H}\"):\n",
    "            continue\n",
    "        if \"production\" not in symbol.name:\n",
    "            continue\n",
    "        del PARAMETERS[symbol]\n",
    "        norm, phi = to_polar_coordinates(symbol)\n",
    "        PARAMETERS[norm] = np.abs(value)\n",
    "        PARAMETERS[phi] = np.angle(value)\n",
    "        POLAR_SUBSTITUTIONS[symbol] = norm * sp.exp(phi * sp.I)\n",
    "del model\n",
    "\n",
    "# Free: coupling norms and phases, widths (except Sigma), and masses whose name\n",
    "# contains a parenthesis. Everything else is substituted as a fixed value.\n",
    "FREE_PARAMETERS = {\n",
    "    s: value\n",
    "    for s, value in PARAMETERS.items()\n",
    "    if s.name.startswith(\"C\")\n",
    "    or s.name.startswith(R\"\\phi\")\n",
    "    or (s.name.startswith(R\"\\Gamma_\") and \"Sigma\" not in s.name)\n",
    "    or (s.name.startswith(\"m_\") and \"(\" in s.name)\n",
    "}\n",
    "FIXED_PARAMETERS = {s: v for s, v in PARAMETERS.items() if s not in FREE_PARAMETERS}\n",
    "\n",
    "\n",
    "@lru_cache(maxsize=None)\n",
    "def unfold_and_substitute(expr: sp.Expr, reference_subsystem: int = 1) -> sp.Expr:\n",
    "    \"\"\"Unfold an expression and insert amplitudes and fixed parameter values.\n",
    "\n",
    "    The amplitude definitions of the given reference sub-system are inserted,\n",
    "    production couplings are rewritten in polar form, and all non-free\n",
    "    parameters are replaced by their default values.\n",
    "    \"\"\"\n",
    "    expr = perform_cached_doit(expr)\n",
    "    expr = perform_cached_doit(expr.xreplace(MODELS[reference_subsystem].amplitudes))\n",
    "    expr = expr.xreplace(POLAR_SUBSTITUTIONS)\n",
    "    return expr.xreplace(FIXED_PARAMETERS)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "jupyter": {
     "source_hidden": true
    },
    "mystnb": {
     "code_prompt_show": "Formulate expressions and lambdify"
    },
    "tags": [
     "hide-cell"
    ]
   },
   "outputs": [],
   "source": [
    "def create_function(\n",
    "    expr: sp.Expr, reference_subsystem: int = 1\n",
    ") -> ParametrizedFunction:\n",
    "    \"\"\"Lambdify an expression so that only the free parameters remain adjustable.\"\"\"\n",
    "    global progress_bar  # noqa: PLW0602\n",
    "    expr = unfold_and_substitute(expr, reference_subsystem)\n",
    "    func = perform_cached_lambdify(expr, parameters=FREE_PARAMETERS)\n",
    "    progress_bar.update()\n",
    "    return func\n",
    "\n",
    "\n",
    "# 3 intensities + 3 x 3 polarimeter components\n",
    "progress_bar = tqdm(total=12, disable=NO_TQDM)\n",
    "# Each intensity is unfolded with the amplitudes of its own reference sub-system,\n",
    "# just like the polarimeter functions below.\n",
    "INTENSITY_FUNC = {\n",
    "    reference_subsystem: create_function(model.intensity, reference_subsystem)\n",
    "    for reference_subsystem, model in MODELS.items()\n",
    "}\n",
    "POLARIMETRY_FUNCS = {\n",
    "    reference_subsystem: tuple(\n",
    "        create_function(expr, reference_subsystem)\n",
    "        for expr in formulate_polarimetry(BUILDER, reference_subsystem)\n",
    "    )\n",
    "    for reference_subsystem in MODELS\n",
    "}\n",
    "progress_bar.close()\n",
    "del progress_bar"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "jupyter": {
     "source_hidden": 
true
    },
    "mystnb": {
     "code_prompt_show": "Define phase space sample for plotting"
    },
    "tags": [
     "hide-input"
    ]
   },
   "outputs": [],
   "source": [
    "def create_grid(resolution: int) -> DataSample:\n",
    "    \"\"\"Generate a Dalitz-plane mesh grid extended with the variables of all models.\"\"\"\n",
    "    sample = generate_meshgrid_sample(DECAY, resolution)\n",
    "    for model in MODELS.values():\n",
    "        transformer = create_data_transformer(model)\n",
    "        sample.update(transformer(sample))\n",
    "    return sample\n",
    "\n",
    "\n",
    "# Build both grids in one pass, so that each data transformer is created only once.\n",
    "MESH_GRID = generate_meshgrid_sample(DECAY, resolution=200)\n",
    "QUIVER_GRID = generate_meshgrid_sample(DECAY, resolution=35)\n",
    "for model in tqdm(MODELS.values(), disable=NO_TQDM, leave=False):\n",
    "    transformer = create_data_transformer(model)\n",
    "    MESH_GRID.update(transformer(MESH_GRID))\n",
    "    QUIVER_GRID.update(transformer(QUIVER_GRID))\n",
    "\n",
    "# pre-compile: call every function once on the grid it is plotted on\n",
    "for ref in tqdm(MODELS, disable=NO_TQDM, leave=False):\n",
    "    INTENSITY_FUNC[ref](MESH_GRID)\n",
    "    for func in POLARIMETRY_FUNCS[ref]:\n",
    "        func(QUIVER_GRID)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "jupyter": {
     "source_hidden": true
    },
    "mystnb": {
     "code_prompt_show": "Define sliders for the widget"
    },
    "tags": [
     "scroll-input",
     "hide-input"
    ]
   },
   "outputs": [],
   "source": [
    "def create_ui() -> HBox:\n",
    "    \"\"\"Assemble the control panel that is displayed next to the plots.\n",
    "\n",
    "    It contains a reset button, a selector for the reference sub-system, one tab\n",
    "    of mass/width/coupling sliders per resonance, and buttons that set couplings\n",
    "    to zero.\n",
    "    \"\"\"\n",
    "\n",
    "    @temporarily_deactivate_continuous_update\n",
    "    def reset_sliders(click_event: Button | None = None) -> None:\n",
    "        for symbol, value in FREE_PARAMETERS.items():\n",
    "            set_slider(SLIDERS[symbol.name], value)\n",
    "\n",
    "    reset_button = Button(description=\"Reset sliders\", button_style=\"danger\")\n",
    "    reset_button.on_click(reset_sliders)\n",
    "    reset_sliders()\n",
    "\n",
    "    @temporarily_deactivate_continuous_update\n",
    "    def set_reference_subsystem(value: Bunch) -> None:\n",
    "        global REFERENCE_SUBSYSTEM  # noqa: PLW0603\n",
    "        subsystems = {1: \"K\", 2: \"L\", 3: \"D\"}\n",
    "        REFERENCE_SUBSYSTEM = value.new\n",
    "        # Phases of couplings that involve the old or the new sub-system are\n",
    "        # transformed as phi -> -sign(phi) * (pi - |phi|).\n",
    "        for name, slider in SLIDERS.items():\n",
    "            if not name.startswith(R\"\\phi\"):\n",
    "                continue\n",
    "            if subsystems[value.old] in name or subsystems[value.new] in name:\n",
    "                phi = slider.value\n",
    "                set_slider(slider, -np.sign(phi) * (np.pi - abs(phi)))\n",
    "\n",
    "    reference_selector = RadioButtons(\n",
    "        description=\"Reference sub-system\",\n",
    "        options=[\n",
    "            (\"1: K** → π⁺K⁻\", 1),\n",
    "            (\"2: Λ** → pK⁻\", 2),\n",
    "            (\"3: Δ** → pπ⁺\", 3),\n",
    "        ],\n",
    "        layout=Layout(width=\"auto\"),\n",
    "    )\n",
    "    reference_selector.observe(set_reference_subsystem, names=\"value\")\n",
    "\n",
    "    @temporarily_deactivate_continuous_update\n",
    "    def set_coupling_to_zero(filter_pattern: Button | str) -> None:\n",
    "        # Called either from a button (its label selects the couplings) or with a pattern.\n",
    "        if isinstance(filter_pattern, Button):\n",
    "            filter_pattern = from_unicode(filter_pattern.description)\n",
    "        for name, slider in SLIDERS.items():\n",
    "            if not name.startswith(\"C\"):\n",
    "                continue\n",
    "            if filter_pattern not in name:\n",
    "                continue\n",
    "            set_slider(slider, 0)\n",
    "\n",
    "    def set_all_to_zero(action: Button | None = None) -> None:\n",
    "        set_coupling_to_zero(\"D\")\n",
    "        set_coupling_to_zero(\"K\")\n",
    "        set_coupling_to_zero(\"L\")\n",
    "\n",
    "    all_to_zero = Button(\n",
    "        description=\"Set all couplings to zero\",\n",
    "        layout=Layout(width=\"auto\"),\n",
    "        tooltip=\"Set all couplings to zero\",\n",
    "    )\n",
    "    all_to_zero.on_click(set_all_to_zero)\n",
    "    resonance_buttons = []\n",
    "    for p in RESONANCES:\n",
    "        button = Button(\n",
    "            description=to_unicode(p.name),\n",
    "            layout=Layout(width=\"auto\"),\n",
    "            tooltip=f\"Set couplings for {to_unicode(p.name)} to 0\",\n",
    "        )\n",
    "        button.style.button_color = to_html_color(p.name)\n",
    "        button.on_click(set_coupling_to_zero)\n",
    "        resonance_buttons.append(button)\n",
    "    subsystem_buttons = []\n",
    "    for subsystem_id in sorted([\"D\", \"K\", \"L\"]):\n",
    "        button = Button(\n",
    "            description=f\"{to_unicode(subsystem_id)}**\",\n",
    "            tooltip=f\"Set couplings for all {to_unicode(subsystem_id)}** to 0\",\n",
    "        )\n",
    "        button.style.button_color = to_html_color(subsystem_id)\n",
    "        button.on_click(set_coupling_to_zero)\n",
    "        subsystem_buttons.append(button)\n",
    "    # NOTE(review): the (4, 3) reshape assumes exactly 12 resonances.\n",
    "    zero_coupling_panel = GridBox(\n",
    "        [\n",
    "            all_to_zero,\n",
    "            HBox(subsystem_buttons),\n",
    "            GridBox(\n",
    "                np.reshape(resonance_buttons, (4, 3)).T.flatten().tolist(),\n",
    "                layout=Layout(grid_template_columns=4 * \"auto \"),\n",
    "            ),\n",
    "        ]\n",
    "    )\n",
    "\n",
    "    def get_subscript(p):\n",
    "        return f\"{p.name} \\\\to p K^-\" if \"1405\" in p.name else p.name\n",
    "\n",
    "    grouped_sliders = []\n",
    "    for p in RESONANCES:\n",
    "        row = (\n",
    "            HTML(\"\", layout=Layout(width=\"auto\")),\n",
    "            SLIDERS[f\"m_{{{p.name}}}\"],\n",
    "            SLIDERS[Rf\"\\Gamma_{{{get_subscript(p)}}}\"],\n",
    "        )\n",
    "        rows = [row]\n",
    "        for slider_name in SLIDERS:\n",
    "            if p.name not in slider_name:\n",
    "                continue\n",
    "            if not slider_name.startswith(\"C\"):\n",
    "                continue\n",
    "            row = (\n",
    "                HTMLMath(\n",
    "                    f\"${slider_name}$\".replace(\"C\", R\"\\mathcal{H}\"),\n",
    "                    layout=Layout(width=\"auto\"),\n",
    "                ),\n",
    "                SLIDERS[slider_name],\n",
    "                SLIDERS[slider_name.replace(\"C\", R\"\\phi\")],\n",
    "            )\n",
    "            rows.append(row)\n",
    "        rows = np.array(rows)\n",
    "        grouped_sliders.append(\n",
    "            GridBox(\n",
    "                rows.flatten().tolist(),\n",
    "                layout=Layout(grid_template_columns=3 * \"auto \"),\n",
    "            )\n",
    "        )\n",
    "    return HBox(\n",
    "        [\n",
    "            GridBox([reset_button, reference_selector]),\n",
    "            Tab(grouped_sliders, titles=[to_unicode(p.name) for p in RESONANCES]),\n",
    "            zero_coupling_panel,\n",
    "        ]\n",
    "    )\n",
    "\n",
    "\n",
    "def create_slider(symbol: sp.Basic, value: float) -> FloatSlider:\n",
    "    \"\"\"Create a slider for a free parameter, initialized at ``value``.\n",
    "\n",
    "    The range depends on the parameter type: masses span the Dalitz range of\n",
    "    their sub-system, widths run from 0 to max(0.5, 2 * value), coupling norms\n",
    "    span [0, 20] and phases span [-pi, +pi].\n",
    "    \"\"\"\n",
    "    (\n",
    "        (s1_min, s1_max),\n",
    "        (s2_min, s2_max),\n",
    "        (s3_min, s3_max),\n",
    "    ) = compute_dalitz_boundaries(DECAY)\n",
    "    value = float(value)\n",
    "    slider = FloatSlider(\n",
    "        description=Rf\"\\({sp.latex(symbol)}\\)\",\n",
    "        continuous_update=True,\n",
    "        readout_format=\".3f\",\n",
    "        step=1e-3,\n",
    "    )\n",
    "    if symbol.name.startswith(\"m\"):\n",
    "        slider.description = \"mass\"\n",
    "        slider.style.handle_color = \"lightblue\"\n",
    "        if \"K\" in symbol.name:\n",
    "            slider.min = np.sqrt(s1_min)\n",
    "            slider.max = np.sqrt(s1_max)\n",
    "        elif \"L\" in symbol.name:\n",
    "            slider.min = np.sqrt(s2_min)\n",
    "            slider.max = np.sqrt(s2_max)\n",
    "        elif \"D\" in symbol.name:\n",
    "            slider.min = np.sqrt(s3_min)\n",
    "            slider.max = np.sqrt(s3_max)\n",
    "    elif symbol.name.startswith(R\"\\Gamma\"):\n",
    "        slider.description = \"width\"\n",
    "        slider.style.handle_color = \"lightblue\"\n",
    "        slider.min = 0\n",
    "        # Scale with the parameter value: the slider itself is still at its default.\n",
    "        slider.max = max(0.5, 2 * value)\n",
    "    elif symbol.name.startswith(\"C\"):\n",
    "        slider.description = \"r\"\n",
    "        slider.min = 0\n",
    "        slider.max = 20\n",
    "        slider.readout_format = \".1f\"\n",
    "        slider.step = 1e-1\n",
    "    elif symbol.name.startswith(R\"\\phi\"):\n",
    "        slider.description = \"φ\"\n",
    "        slider.min = -np.pi\n",
    "        slider.max = +np.pi\n",
    "        slider.readout_format = \".2f\"\n",
    "        slider.step = 1e-2\n",
    "    # Set the value only after the range is final, so it is not clipped by defaults.\n",
    "    slider.value = value\n",
    "    return slider\n",
    "\n",
    "\n",
    "def set_slider(slider: FloatSlider, value: float) -> None:\n",
    "    \"\"\"Set a slider value, skipping it if nothing changes at the slider's step size.\"\"\"\n",
    "    n_decimals = -round(np.log10(slider.step))\n",
    "    if slider.value != round(value, n_decimals):  # widget performance\n",
    "        slider.value = value\n",
    "\n",
    "\n",
    "def to_html_color(name: str) -> str:\n",
    "    \"\"\"Button color for a K, L or D resonance or sub-system name.\"\"\"\n",
    "    if \"K\" in name:\n",
    "        return \"#FFCCCB\"  # light red\n",
    "    if \"L\" in name:\n",
    "        return \"lightblue\"\n",
    "    if \"D\" in name:\n",
    "        return \"lightgreen\"\n",
    "    raise NotImplementedError\n",
    "\n",
    "\n",
    "def to_unicode(text: str) -> str:\n",
    "    \"\"\"Render the letters L and D as Greek Λ and Δ for display.\"\"\"\n",
    "    text = text.replace(\"L\", \"Λ\")\n",
    "    return text.replace(\"D\", \"Δ\")\n",
    "\n",
    "\n",
    "def from_unicode(text: str) -> str:\n",
    "    \"\"\"Undo to_unicode and strip asterisks, e.g. to turn a button label into a pattern.\"\"\"\n",
    "    text = text.replace(\"Λ\", \"L\")\n",
    "    text = text.replace(\"Δ\", \"D\")\n",
    "    return text.replace(\"*\", \"\")\n",
    "\n",
    "\n",
    "def temporarily_deactivate_continuous_update(func):\n",
    "    \"\"\"Decorator that switches off continuous_update on all sliders while func runs.\n",
    "\n",
    "    Continuous updates are switched back on even if func raises an exception.\n",
    "    \"\"\"\n",
    "\n",
    "    def new_func(*args, **kwargs):\n",
    "        for slider in SLIDERS.values():\n",
    "            slider.continuous_update = False\n",
    "        try:\n",
    "            return func(*args, **kwargs)\n",
    "        finally:\n",
    "            for slider in SLIDERS.values():\n",
    "                slider.continuous_update = True\n",
    "\n",
    "    return new_func\n",
    "\n",
    "\n",
    "REFERENCE_SUBSYSTEM = 1\n",
    "SLIDERS = {s.name: create_slider(s, v) for s, v in FREE_PARAMETERS.items()}\n",
    "UI = create_ui()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "jupyter": {
     "source_hidden": true
    },
    "mystnb": {
     "code_prompt_show": "Create interactive plot"
    },
    "tags": [
     "hide-input",
     "scroll-input"
    ]
   },
   "outputs": [],
   "source": [
    "def create_interactive_plot() -> None:\n",
    "    \"\"\"Plot the intensity and the polarimeter vector field, driven by SLIDERS.\n",
    "\n",
    "    If NO_TQDM is set, the figure is saved as a PNG file and embedded through\n",
    "    Markdown instead of displaying the interactive widgets.\n",
    "    \"\"\"\n",
    "    plt.rcdefaults()\n",
    "    use_mpl_latex_fonts()\n",
    "    plt.rc(\"font\", size=20)\n",
    "    fig, axes = plt.subplots(\n",
    "        figsize=(15, 7.5),\n",
    "        ncols=2,\n",
    "        sharey=True,\n",
    "    )\n",
    "    ax1, ax2 = axes\n",
    "    ax1.set_title(\"Intensity distribution\")\n",
    "    ax2.set_title(\"Polarimeter vector field\")\n",
    "    ax1.set_xlabel(R\"$m^2(K^- \\pi^+)$\")\n",
    "    # The quiver below draws alpha_z horizontally and alpha_x vertically.\n",
    "    ax2.set_xlabel(R\"$m^2(K^- \\pi^+), \\alpha_z$\")\n",
    "    ax1.set_ylabel(R\"$m^2(p K^-), \\alpha_x$\")\n",
    "    for ax in axes:\n",
    "        ax.set_box_aspect(1)\n",
    "    fig.canvas.toolbar_visible = False\n",
    "    fig.canvas.header_visible = False\n",
    "    fig.canvas.footer_visible = False\n",
    "\n",
    "    mesh = None\n",
    "    quiver = None\n",
    "    intensity_bar = None\n",
    "\n",
    "    def plot3(**kwargs):\n",
    "        nonlocal quiver, mesh, intensity_bar\n",
    "        intensity_func = INTENSITY_FUNC[REFERENCE_SUBSYSTEM]\n",
    "        polarimetry_funcs = POLARIMETRY_FUNCS[REFERENCE_SUBSYSTEM]\n",
    "        for func in [intensity_func, *polarimetry_funcs]:\n",
    "            func.update_parameters(kwargs)\n",
    "        intensities = intensity_func(MESH_GRID)\n",
    "        αx, αy, αz = tuple(func(QUIVER_GRID).real for func in polarimetry_funcs)\n",
    "        abs_α = jnp.sqrt(αx**2 + αy**2 + αz**2)\n",
    "        if mesh is None:\n",
    "            mesh = ax1.pcolormesh(\n",
    "                MESH_GRID[\"sigma1\"],\n",
    "                MESH_GRID[\"sigma2\"],\n",
    "                intensities,\n",
    "                cmap=plt.cm.YlOrRd,\n",
    "                norm=LogNorm(),\n",
    "            )\n",
    "            intensity_bar = fig.colorbar(mesh, ax=ax1, pad=0.01, fraction=0.0473)\n",
    "            intensity_bar.ax.set_ylabel(\"normalized intensity (a.u.)\")\n",
    "        else:\n",
    "            mesh.set_array(intensities)\n",
    "            # Only finite values may define the color range (nanmin/nanmax keep inf).\n",
    "            intensity_values = np.asarray(intensities)\n",
    "            finite_values = intensity_values[np.isfinite(intensity_values)]\n",
    "            if finite_values.size:\n",
    "                y_min = max(finite_values.min(), 1e0)\n",
    "                y_max = max(finite_values.max(), 1e2)\n",
    "                mesh.set_clim(y_min, y_max)\n",
    "                intensity_bar.ax.set_ylim(y_min, y_max)\n",
    "        if quiver is None:\n",
    "            quiver = ax2.quiver(\n",
    "                QUIVER_GRID[\"sigma1\"],\n",
    "                QUIVER_GRID[\"sigma2\"],\n",
    "                αz,\n",
    "                αx,\n",
    "                abs_α,\n",
    "                cmap=plt.cm.viridis_r,\n",
    "                clim=(0, 1),\n",
    "            )\n",
    "            c_bar = fig.colorbar(quiver, ax=ax2, pad=0.01, fraction=0.0473)\n",
    "            c_bar.ax.set_ylabel(R\"$\\left|\\vec\\alpha\\right|$\")\n",
    "        else:\n",
    "            quiver.set_UVC(αz, αx, abs_α)\n",
    "        fig.canvas.draw_idle()\n",
    "\n",
    "    output = interactive_output(plot3, controls=SLIDERS)\n",
    "    fig.tight_layout()\n",
    "    if NO_TQDM:\n",
    "        export_file = \"../_static/images/interactive-plot.png\"\n",
    "        fig.savefig(export_file, dpi=200)\n",
    "        src = f\"\"\"\n",
    "        :::{{container}} full-width\n",
    "        ![]({export_file})\n",
    "        :::\n",
    "        \"\"\"\n",
    "        src = dedent(src)\n",
    "        display(Markdown(src))\n",
    "    else:\n",
    "        display(output, UI)\n",
    "\n",
    "\n",
    "%matplotlib widget\n",
    "create_interactive_plot()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    ":::{tip}\n",
    "Run this notebook locally in Jupyter or [online on Binder](https://mybinder.org/v2/gh/ComPWA/polarimetry/0.0.9?urlpath=lab/tree/docs/appendix/widget.ipynb) to modify parameters interactively!\n",
    ":::"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.15"
  },
  "myst": {
   "all_links_external": true
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}