7.10. Interactive visualization

Hide code cell content
Hide code cell source
model_choice = 0  # value passed as `model_id` to both loaders below
model_file = "../../data/model-definitions.yaml"
PARTICLES = load_particles("../../data/particle-definitions.yaml")
BUILDER = load_model_builder(model_file, PARTICLES, model_id=model_choice)
imported_parameters = load_model_parameters(
    model_file,
    BUILDER.decay,
    model_id=model_choice,
    particle_definitions=PARTICLES,
)
# Formulate the model once per reference sub-system (1, 2, 3) and give every
# formulation the same imported parameter values.
MODELS = {}
for ref in (1, 2, 3):
    MODELS[ref] = BUILDER.formulate(ref)
    MODELS[ref].parameter_defaults.update(imported_parameters)
DECAY = MODELS[1].decay
# Unique resonances, ordered by the first letter of their name, then by mass.
RESONANCES = sorted(
    {c.resonance for c in DECAY.chains},
    key=lambda p: (p.name[0], p.mass),
)
del model_choice, model_file, imported_parameters
Hide code cell source
def to_polar_coordinates(coupling: sp.Indexed) -> tuple[sp.Symbol, sp.Symbol]:
    """Return magnitude and phase symbols that stand in for a coupling.

    The first index of ``coupling`` becomes the LaTeX superscript, and the
    remaining indices (comma-separated) become the subscript. The magnitude
    symbol's name is prefixed with ``C`` and the phase symbol's name with the
    LaTeX phi command.
    """
    upper = sp.latex(coupling.indices[0])
    lower = ", ".join(sp.latex(index) for index in coupling.indices[1:])
    indices_latex = f"^{{{upper}}}_{{{lower}}}"
    magnitude = sp.Symbol(f"C{indices_latex}")
    phase = sp.Symbol(R"\phi" + indices_latex)
    return magnitude, phase


# Merge the default parameter values of all models. Production couplings are
# swapped for a magnitude symbol ("C...") and a phase symbol ("\phi...");
# POLAR_SUBSTITUTIONS maps each original coupling to norm * exp(i * phi).
PARAMETERS = {}
POLAR_SUBSTITUTIONS = {}
for model in MODELS.values():
    PARAMETERS.update(model.parameter_defaults)
    for symbol, value in model.parameter_defaults.items():
        is_production_coupling = (
            symbol.name.startswith(R"\mathcal{H}") and "production" in symbol.name
        )
        if is_production_coupling:
            del PARAMETERS[symbol]
            norm, phi = to_polar_coordinates(symbol)
            # value is presumably complex: split it into modulus and argument
            PARAMETERS[norm] = np.abs(value)
            PARAMETERS[phi] = np.angle(value)
            POLAR_SUBSTITUTIONS[symbol] = norm * sp.exp(phi * sp.I)
    del model

# Parameters exposed as sliders:
# - magnitudes ("C...") and phases ("\phi...") of the production couplings,
# - widths ("\Gamma_...") unless the name contains "Sigma",
# - masses ("m_...") whose name contains "(" (resonance masses such as
#   m_{K(892)}; presumably this excludes final-state masses).
FREE_PARAMETERS = {
    symbol: value
    for symbol, value in PARAMETERS.items()
    if symbol.name.startswith(("C", R"\phi"))
    or (symbol.name.startswith(R"\Gamma_") and "Sigma" not in symbol.name)
    or (symbol.name.startswith("m_") and "(" in symbol.name)
}
# Everything else is substituted by its value before lambdifying.
FIXED_PARAMETERS = {
    symbol: value
    for symbol, value in PARAMETERS.items()
    if symbol not in FREE_PARAMETERS
}


@lru_cache(maxsize=None)
def unfold_and_substitute(expr: sp.Expr, reference_subsystem: int = 1) -> sp.Expr:
    """Unfold ``expr`` and substitute everything that is not a slider parameter.

    1. Evaluate ``expr`` with ``perform_cached_doit``.
    2. Replace the amplitude symbols by the definitions of the model for
       ``reference_subsystem`` and evaluate again.
    3. Rewrite production couplings in polar form (``POLAR_SUBSTITUTIONS``).
    4. Insert the values of all ``FIXED_PARAMETERS``.

    Results are memoized per ``(expr, reference_subsystem)``.
    """
    unfolded = perform_cached_doit(expr)
    amplitude_definitions = MODELS[reference_subsystem].amplitudes
    unfolded = perform_cached_doit(unfolded.xreplace(amplitude_definitions))
    in_polar_form = unfolded.xreplace(POLAR_SUBSTITUTIONS)
    return in_polar_form.xreplace(FIXED_PARAMETERS)
Hide code cell content
def create_function(
    expr: sp.Expr, reference_subsystem: int = 1
) -> ParametrizedFunction:
    """Lambdify ``expr`` for ``reference_subsystem`` with free slider parameters.

    The expression is unfolded via ``unfold_and_substitute``. The resulting
    function keeps ``FREE_PARAMETERS`` as adjustable parameters. The
    module-level ``progress_bar`` is advanced by one step per call.
    """
    substituted = unfold_and_substitute(expr, reference_subsystem)
    lambdified = perform_cached_lambdify(substituted, parameters=FREE_PARAMETERS)
    progress_bar.update()
    return lambdified


# Per reference sub-system: one intensity function plus three polarimeter
# components (alpha_x, alpha_y, alpha_z), i.e. 3 * (1 + 3) = 12 functions.
progress_bar = tqdm(total=12, disable=NO_TQDM)
INTENSITY_FUNC = {
    # Pass the reference sub-system explicitly so that each model's intensity
    # is unfolded with that same model's amplitude definitions; without it,
    # create_function falls back to reference sub-system 1 for every model.
    reference_subsystem: create_function(model.intensity, reference_subsystem)
    for reference_subsystem, model in MODELS.items()
}
POLARIMETRY_FUNCS = {
    reference_subsystem: tuple(
        create_function(expr, reference_subsystem)
        for expr in formulate_polarimetry(BUILDER, reference_subsystem)
    )
    for reference_subsystem in MODELS
}
progress_bar.close()
del progress_bar
Hide code cell source
def create_grid(resolution: int) -> DataSample:
    """Create a meshgrid sample of ``DECAY`` extended by every model's transformer.

    The sample from ``generate_meshgrid_sample`` is updated in place with the
    output of the data transformer of each model in ``MODELS`` (in order).

    NOTE(review): not called in this section; MESH_GRID and QUIVER_GRID are
    built with equivalent inline code.
    """
    grid = generate_meshgrid_sample(DECAY, resolution)
    for transform in map(create_data_transformer, MODELS.values()):
        grid.update(transform(grid))
    return grid


MESH_GRID = generate_meshgrid_sample(DECAY, resolution=200)
QUIVER_GRID = generate_meshgrid_sample(DECAY, resolution=35)
# Add to both grids whatever each model's data transformer computes from them.
for model in tqdm(MODELS.values(), disable=NO_TQDM, leave=False):
    transformer = create_data_transformer(model)
    for grid in (MESH_GRID, QUIVER_GRID):
        grid.update(transformer(grid))

# Evaluate every function once on the grid it is plotted on ("pre-compile"),
# presumably so that the first slider interaction does not pay that cost.
for ref in tqdm(MODELS, disable=NO_TQDM, leave=False):
    INTENSITY_FUNC[ref](MESH_GRID)
    for func in POLARIMETRY_FUNCS[ref]:
        func(QUIVER_GRID)
Hide code cell source
def create_ui() -> HBox:
    """Build the control panel that drives the interactive plot.

    The panel has three columns:

    1. a "Reset sliders" button and a radio-button selector for the reference
       sub-system;
    2. one tab per resonance in ``RESONANCES``, holding its mass and width
       sliders and one (magnitude, phase) slider row per coupling;
    3. buttons that set couplings to zero: all of them, those of one
       sub-system (K**, Λ**, Δ**), or those of a single resonance.

    All sliders in ``SLIDERS`` are reset to their ``FREE_PARAMETERS`` values
    once during construction. Changing the selector rebinds the module-level
    ``REFERENCE_SUBSYSTEM``.
    """

    @temporarily_deactivate_continuous_update
    def reset_sliders(click_event: Button | None = None) -> None:
        for symbol, value in FREE_PARAMETERS.items():
            set_slider(SLIDERS[symbol.name], value)

    reset_button = Button(description="Reset sliders", button_style="danger")
    reset_button.on_click(reset_sliders)
    reset_sliders()

    @temporarily_deactivate_continuous_update
    def set_reference_subsystem(value: Bunch) -> None:
        global REFERENCE_SUBSYSTEM  # noqa: PLW0603
        subsystems = {1: "K", 2: "L", 3: "D"}
        REFERENCE_SUBSYSTEM = value.new
        for name, slider in SLIDERS.items():
            if not name.startswith(R"\phi"):
                continue
            if subsystems[value.old] in name or subsystems[value.new] in name:
                # Shift the phase by π (i.e. multiply the coupling by -1) and
                # wrap it back into [-π, π). This also flips φ = 0, which the
                # formula -sign(φ)·(π - |φ|) would leave at 0.
                phi = slider.value
                set_slider(slider, (phi + 2 * np.pi) % (2 * np.pi) - np.pi)

    reference_selector = RadioButtons(
        description="Reference sub-system",
        options=[
            ("1: K** → π⁺K⁻", 1),
            ("2: Λ** → pK⁻", 2),
            ("3: Δ** → pπ⁺", 3),
        ],
        layout=Layout(width="auto"),
    )
    reference_selector.observe(set_reference_subsystem, names="value")

    @temporarily_deactivate_continuous_update
    def set_coupling_to_zero(filter_pattern: Button | str) -> None:
        # Used as a button callback (pattern taken from the button label) and
        # called directly with a plain string pattern.
        if isinstance(filter_pattern, Button):
            filter_pattern = from_unicode(filter_pattern.description)
        for name, slider in SLIDERS.items():
            if not name.startswith("C"):
                continue
            if filter_pattern not in name:
                continue
            set_slider(slider, 0)

    def set_all_to_zero(action: Button | None = None) -> None:
        set_coupling_to_zero("D")
        set_coupling_to_zero("K")
        set_coupling_to_zero("L")

    all_to_zero = Button(
        description="Set all couplings to zero",
        layout=Layout(width="auto"),
        tooltip="Set all couplings to zero",
    )
    all_to_zero.on_click(set_all_to_zero)
    resonance_buttons = []
    for p in RESONANCES:
        button = Button(
            description=to_unicode(p.name),
            layout=Layout(width="auto"),
            tooltip=f"Set couplings for {to_unicode(p.name)} to 0",
        )
        button.style.button_color = to_html_color(p.name)
        button.on_click(set_coupling_to_zero)
        resonance_buttons.append(button)
    subsystem_buttons = []
    for subsystem_id in sorted(["D", "K", "L"]):
        button = Button(
            description=f"{to_unicode(subsystem_id)}**",
            tooltip=f"Set couplings for all {to_unicode(subsystem_id)}** to 0",
        )
        button.style.button_color = to_html_color(subsystem_id)
        button.on_click(set_coupling_to_zero)
        subsystem_buttons.append(button)
    zero_coupling_panel = GridBox(
        [
            all_to_zero,
            HBox(subsystem_buttons),
            GridBox(
                # NOTE: the (4, 3) reshape assumes exactly 12 resonances.
                np.reshape(resonance_buttons, (4, 3)).T.flatten().tolist(),
                layout=Layout(grid_template_columns=4 * "auto "),
            ),
        ]
    )

    def get_subscript(p):
        # The width symbol of the L(1405) carries a decay-channel suffix.
        return f"{p.name} \\to p K^-" if "1405" in p.name else p.name

    grouped_sliders = []
    for p in RESONANCES:
        # First row: empty label cell, mass slider, width slider.
        row = (
            HTML("", layout=Layout(width="auto")),
            SLIDERS[f"m_{{{p.name}}}"],
            SLIDERS[Rf"\Gamma_{{{get_subscript(p)}}}"],
        )
        rows = [row]
        # One row per coupling of this resonance: label, magnitude, phase.
        for slider_name in SLIDERS:
            if p.name not in slider_name:
                continue
            if not slider_name.startswith("C"):
                continue
            # Swap only the leading "C" prefix (added by to_polar_coordinates);
            # str.replace would also rewrite any "C" inside the indices.
            suffix = slider_name[1:]
            row = (
                HTMLMath(
                    "$" + R"\mathcal{H}" + suffix + "$",
                    layout=Layout(width="auto"),
                ),
                SLIDERS[slider_name],
                SLIDERS[R"\phi" + suffix],
            )
            rows.append(row)
        rows = np.array(rows)
        grouped_sliders.append(
            GridBox(
                rows.flatten().tolist(),
                layout=Layout(grid_template_columns=3 * "auto "),
            )
        )
    return HBox(
        [
            GridBox([reset_button, reference_selector]),
            Tab(grouped_sliders, titles=[to_unicode(p.name) for p in RESONANCES]),
            zero_coupling_panel,
        ]
    )


def create_slider(symbol: sp.Basic, value: float) -> FloatSlider:
    """Create a slider widget for one free parameter.

    The parameter type is recognized from the prefix of ``symbol.name``:

    - ``m``: mass. The range is the square root of the Dalitz boundaries of
      the matching sub-system (K -> first, L -> second, D -> third pair
      returned by ``compute_dalitz_boundaries``).
    - Gamma: width. The range runs from 0 to ``max(0.5, 2 * |value|)``.
    - ``C``: coupling magnitude. The range runs from 0 to 20.
    - phi: coupling phase. The range runs from -π to +π.

    Args:
        symbol: Parameter symbol; its name selects the slider type.
        value: Default value of the parameter; it sizes the width range.

    Returns:
        A ``FloatSlider`` with continuous updates enabled. The slider value
        itself is not set here (``reset_sliders`` in ``create_ui`` does that).
    """
    (
        (s1_min, s1_max),
        (s2_min, s2_max),
        (s3_min, s3_max),
    ) = compute_dalitz_boundaries(DECAY)
    slider = FloatSlider(
        description=Rf"\({sp.latex(symbol)}\)",
        continuous_update=True,
        readout_format=".3f",
        step=1e-3,
    )
    if symbol.name.startswith("m"):
        slider.description = "mass"
        slider.style.handle_color = "lightblue"
        if "K" in symbol.name:
            slider.min = np.sqrt(s1_min)
            slider.max = np.sqrt(s1_max)
        elif "L" in symbol.name:
            slider.min = np.sqrt(s2_min)
            slider.max = np.sqrt(s2_max)
        elif "D" in symbol.name:
            slider.min = np.sqrt(s3_min)
            slider.max = np.sqrt(s3_max)
    elif symbol.name.startswith(R"\Gamma"):
        slider.description = "width"
        slider.style.handle_color = "lightblue"
        slider.min = 0
        # slider.value is still the widget default at this point, so size the
        # range from the parameter's default width; abs() guards against a
        # non-real value being passed in.
        slider.max = max(0.5, 2 * abs(value))
    elif symbol.name.startswith("C"):
        slider.description = "r"
        slider.min = 0
        slider.max = 20
        slider.readout_format = ".1f"
        slider.step = 1e-1
    elif symbol.name.startswith(R"\phi"):
        slider.description = "φ"
        slider.min = -np.pi
        slider.max = +np.pi
        slider.readout_format = ".2f"
        slider.step = 1e-2
    return slider


def set_slider(slider: FloatSlider, value: float) -> None:
    n_decimals = -round(np.log10(slider.step))
    if slider.value != round(value, n_decimals):  # widget performance
        slider.value = value


def to_html_color(name: str) -> str:
    """Return the HTML color associated with the sub-system letter in ``name``.

    The letters are checked in the order K, L, D; the first one contained in
    ``name`` wins.

    Raises:
        NotImplementedError: if ``name`` contains none of K, L, or D.
    """
    palette = (
        ("K", "#FFCCCB"),  # light red
        ("L", "lightblue"),
        ("D", "lightgreen"),
    )
    for letter, color in palette:
        if letter in name:
            return color
    raise NotImplementedError


def to_unicode(text: str) -> str:
    """Replace the ASCII letters L and D by the Greek letters Λ and Δ."""
    return text.translate(str.maketrans({"L": "Λ", "D": "Δ"}))


def from_unicode(text: str) -> str:
    """Map Λ and Δ back to L and D and drop all asterisks."""
    return text.translate(str.maketrans({"Λ": "L", "Δ": "D", "*": None}))


def temporarily_deactivate_continuous_update(func):
    """Wrap ``func`` so that all sliders are non-continuous while it runs.

    Before calling ``func``, ``continuous_update`` is set to ``False`` on
    every slider in the module-level ``SLIDERS`` dict. Afterwards it is set
    back to ``True``, also when ``func`` raises, so a failing callback cannot
    leave the sliders stuck in non-continuous mode. This is presumably meant
    to avoid a redraw for every slider that ``func`` changes.
    """
    import functools

    @functools.wraps(func)
    def new_func(*args, **kwargs):
        for slider in SLIDERS.values():
            slider.continuous_update = False
        try:
            return func(*args, **kwargs)
        finally:
            for slider in SLIDERS.values():
                slider.continuous_update = True

    return new_func


# Reference sub-system whose functions plot3 evaluates; rebound by the
# reference sub-system selector created in create_ui.
REFERENCE_SUBSYSTEM = 1
# One slider per free parameter, keyed by symbol name, each created with that
# parameter's own default value.
SLIDERS = {s.name: create_slider(s, v) for s, v in FREE_PARAMETERS.items()}
UI = create_ui()
Hide code cell source
def create_interactive_plot() -> None:
    """Draw the intensity and polarimeter panels and link them to the sliders.

    Left panel: intensity on ``MESH_GRID`` with a logarithmic color scale.
    Right panel: polarimeter vector field on ``QUIVER_GRID``, with arrows
    (alpha_z, alpha_x) colored by the vector length |alpha|. Both panels are
    redrawn whenever a slider in ``SLIDERS`` changes, using the functions of
    the current ``REFERENCE_SUBSYSTEM``.

    When ``NO_TQDM`` is truthy (presumably a non-interactive build), the
    initial figure is saved to ``../_static/images/interactive-plot.png`` and
    shown as a static Markdown image instead of the widgets.
    """
    plt.rcdefaults()
    use_mpl_latex_fonts()
    plt.rc("font", size=20)
    fig, axes = plt.subplots(
        figsize=(15, 7.5),
        ncols=2,
        sharey=True,
    )
    ax1, ax2 = axes
    ax1.set_title("Intensity distribution")
    ax2.set_title("Polarimeter vector field")
    ax1.set_xlabel(R"$m^2(K^- \pi^+)$")
    # The quiver arrows are drawn as (U, V) = (alpha_z, alpha_x), so the
    # horizontal axis of the right panel carries alpha_z and the shared
    # vertical axis carries alpha_x.
    ax2.set_xlabel(R"$m^2(K^- \pi^+), \alpha_z$")
    ax1.set_ylabel(R"$m^2(p K^-), \alpha_x$")
    for ax in axes:
        ax.set_box_aspect(1)
    fig.canvas.toolbar_visible = False
    fig.canvas.header_visible = False
    fig.canvas.footer_visible = False

    mesh = None
    quiver = None
    intensity_bar = None

    def plot3(**kwargs):
        # kwargs maps slider names to their current values.
        nonlocal quiver, mesh, intensity_bar
        intensity_func = INTENSITY_FUNC[REFERENCE_SUBSYSTEM]
        polarimetry_funcs = POLARIMETRY_FUNCS[REFERENCE_SUBSYSTEM]
        for func in [intensity_func, *polarimetry_funcs]:
            func.update_parameters(kwargs)
        intensities = intensity_func(MESH_GRID)
        αx, αy, αz = tuple(func(QUIVER_GRID).real for func in polarimetry_funcs)
        abs_α = jnp.sqrt(αx**2 + αy**2 + αz**2)
        if mesh is None:
            # First call: create the color mesh and its color bar.
            mesh = ax1.pcolormesh(
                MESH_GRID["sigma1"],
                MESH_GRID["sigma2"],
                intensities,
                cmap=plt.cm.YlOrRd,
                norm=LogNorm(),
            )
            intensity_bar = fig.colorbar(mesh, ax=ax1, pad=0.01, fraction=0.0473)
            intensity_bar.ax.set_ylabel("normalized intensity (a.u.)")
        else:
            # Later calls: update the data in place. The color range's lower
            # limit is at least 1 and its upper limit at least 100.
            mesh.set_array(intensities)
            if jnp.isfinite(intensities).any():
                y_min = max(np.nanmin(intensities), 1e0)
                y_max = max(np.nanmax(intensities), 1e2)
                mesh.set_clim(y_min, y_max)
                intensity_bar.ax.set_ylim(y_min, y_max)
        if quiver is None:
            quiver = ax2.quiver(
                QUIVER_GRID["sigma1"],
                QUIVER_GRID["sigma2"],
                αz,
                αx,
                abs_α,
                cmap=plt.cm.viridis_r,
                clim=(0, 1),
            )
            c_bar = fig.colorbar(quiver, ax=ax2, pad=0.01, fraction=0.0473)
            c_bar.ax.set_ylabel(R"$\left|\vec\alpha\right|$")
        else:
            quiver.set_UVC(αz, αx, abs_α)
        fig.canvas.draw_idle()

    output = interactive_output(plot3, controls=SLIDERS)
    fig.tight_layout()
    if NO_TQDM:
        export_file = "../_static/images/interactive-plot.png"
        fig.savefig(export_file, dpi=200)
        src = f"""
        :::{{container}} full-width
        ![]({export_file})
        :::
        """
        src = dedent(src)
        display(Markdown(src))
    else:
        display(output, UI)


# Select Matplotlib's interactive "widget" (ipympl) backend so that plot3 in
# create_interactive_plot can redraw the figure in place as sliders move.
%matplotlib widget
create_interactive_plot()

Tip

Run this notebook locally in Jupyter or online on Binder to modify parameters interactively!