3. Intensity distribution#

Hide code cell content
from __future__ import annotations

import logging
import os
from io import BytesIO
from itertools import combinations_with_replacement, product
from urllib.request import urlopen
from zipfile import ZipFile

import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import sympy as sp
from ampform.helicity.naming import natural_sorting
from IPython.display import Markdown
from matplotlib.image import imread
from matplotlib.patches import Rectangle
from tensorwaves.interface import DataSample
from tqdm.auto import tqdm

from polarimetry.data import (
    create_data_transformer,
    generate_meshgrid_sample,
    generate_phasespace_sample,
    generate_sub_meshgrid_sample,
)
from polarimetry.decay import Particle
from polarimetry.function import (
    compute_sub_function,
    integrate_intensity,
    interference_intensity,
    sub_intensity,
)
from polarimetry.io import (
    mute_jax_warnings,
    perform_cached_doit,
    perform_cached_lambdify,
)
from polarimetry.lhcb import load_model
from polarimetry.lhcb.particle import load_particles
from polarimetry.plot import (
    add_watermark,
    get_contour_line,
    stylize_contour,
    use_mpl_latex_fonts,
)

# Silence JAX warnings before any JAX-backed computation is performed.
mute_jax_warnings()
# Load the particle definitions and the model with ID 0 from the YAML files.
particles = load_particles("../data/particle-definitions.yaml")
model = load_model("../data/model-definitions.yaml", particles, model_id=0)

# When executed non-interactively (EXECUTE_NB is set), the progress bars in
# later cells are disabled via NO_TQDM and only errors are logged.
NO_TQDM = "EXECUTE_NB" in os.environ
if NO_TQDM:
    logging.getLogger().setLevel(logging.ERROR)
# Unfold the full symbolic intensity expression. NOTE(review): the helper name
# suggests the result is cached on disk — confirm in polarimetry.io.
unfolded_intensity_expr = perform_cached_doit(model.full_expression)

The complete intensity expression contains 43,198 mathematical operations.

3.1. Definition of free parameters#

# Partition the default parameter values in a single pass: indexed production
# couplings stay symbolic ("free"), every other parameter is "fixed" and its
# default value is substituted straight into the unfolded expression.
free_parameters = {}
fixed_parameters = {}
for symbol, value in model.parameter_defaults.items():
    is_production = isinstance(symbol, sp.Indexed) and "production" in str(symbol)
    destination = free_parameters if is_production else fixed_parameters
    destination[symbol] = value
subs_intensity_expr = unfolded_intensity_expr.xreplace(fixed_parameters)

After substituting the default values of all parameters except the production couplings, the total intensity expression contains 9,516 operations.

3.2. Distribution#

# Turn the partially substituted expression into a JAX function whose only
# adjustable parameters are the free production couplings.
intensity_func = perform_cached_lambdify(
    subs_intensity_expr,
    backend="jax",
    parameters=free_parameters,
)
# Build a Dalitz-plane meshgrid (resolution=1_000) and compute all kinematic
# variables the model needs on it.
transformer = create_data_transformer(model)
grid_sample = transformer(generate_meshgrid_sample(model.decay, resolution=1_000))
X, Y = grid_sample["sigma1"], grid_sample["sigma2"]
Hide code cell source
%config InlineBackend.figure_formats = ['png']
# LaTeX axis labels for the squared invariant masses; s1/s2 are reused later.
s1_label = R"$\sigma_1=m^2\left(K^-\pi^+\right)$ [GeV$^2$]"
s2_label = R"$\sigma_2=m^2\left(pK^-\right)$ [GeV$^2$]"
s3_label = R"$\sigma_3=m^2\left(p\pi^+\right)$ [GeV$^2$]"

plt.rcdefaults()
use_mpl_latex_fonts()
plt.rc("font", size=21)
fig, ax = plt.subplots(dpi=200, figsize=(8.22, 7), tight_layout=True)
ax.set_xlabel(s1_label)
ax.set_ylabel(s2_label)

# Evaluate the intensity on the grid and normalize it so that the grid values
# sum to one. NaN entries are skipped by nansum (presumably grid points outside
# the kinematically allowed region — confirm).
INTENSITIES = intensity_func(grid_sample)
INTENSITY_INTEGRAL = jnp.nansum(INTENSITIES)
NORMALIZED_INTENSITIES = INTENSITIES / INTENSITY_INTEGRAL
np.testing.assert_almost_equal(jnp.nansum(NORMALIZED_INTENSITIES), 1.0)
mesh = ax.pcolormesh(X, Y, NORMALIZED_INTENSITIES)
c_bar = fig.colorbar(mesh, ax=ax, pad=0.02)
c_bar.ax.set_ylabel("Normalized intensity")
add_watermark(ax, 0.7, 0.82, fontsize=24)
fig.savefig("_static/images/intensity-distribution.png")
plt.show()
_images/c208fa5567ebb995fbdaec95e9a2366f9e824e4f908b16e98b70a61439fc8f1c.png

High-resolution image can be downloaded here: intensity-distribution.png

Comparison with Figure 2 from the original LHCb study [1]:

Hide code cell source
def download_lhcb_intensity(
    figure_path: str = "LHCb-PAPER-2022-002-figures/Fig2.png",
    url: str = "https://cds.cern.ch/record/2824328/files/LHCb-PAPER-2022-002-figures.zip?version=1",
) -> None:
    """Fetch Figure 2 of LHCb-PAPER-2022-002 unless it already exists locally.

    The figure archive is downloaded into memory and only ``figure_path`` is
    extracted, relative to the current working directory and keeping its
    directory structure from inside the archive.

    Args:
        figure_path: Path of the figure inside the ZIP archive, which is also
            the local path it is written to. If it already exists, nothing is
            downloaded.
        url: Location of the ZIP archive that contains ``figure_path``.
    """
    if os.path.exists(figure_path):
        return
    # Close the network response and the archive deterministically instead of
    # leaving both open until garbage collection.
    with urlopen(url) as http_response:  # noqa: S310
        payload = http_response.read()
    with ZipFile(BytesIO(payload)) as archive:
        archive.extract(figure_path)


def plot_horizontal_intensity(ax) -> None:
    """Draw the normalized Dalitz-plane intensity with σ2 on the horizontal axis.

    This is the transposed version of the main intensity plot, with fixed axis
    ranges; it is used as the left panel of the comparison with the LHCb figure.

    Args:
        ax: Matplotlib axes to draw on. The colorbar is attached to the figure
            that owns ``ax`` (previously the global ``fig`` was used, which only
            worked because callers happened to rebind it beforehand).
    """
    # Drop the "$\sigma_i=" prefix so the label shows only the invariant mass.
    ax.set_xlabel("$" + s2_label.split("=", 1)[1])
    ax.set_ylabel("$" + s1_label.split("=", 1)[1])
    ax.set_xlim(1.79, 4.95)
    ax.set_ylim(0.18, 2.05)
    add_watermark(ax, 0.7, 0.78, fontsize=18)
    mesh = ax.pcolormesh(
        grid_sample["sigma2"],
        grid_sample["sigma1"],
        NORMALIZED_INTENSITIES,
    )
    c_bar = ax.figure.colorbar(mesh, ax=ax, pad=0.02)
    c_bar.ax.set_ylabel("Normalized intensity")


%config InlineBackend.figure_formats = ['png']
# Make sure the LHCb figure is available locally (skipped if already present).
download_lhcb_intensity()
plt.rcdefaults()
use_mpl_latex_fonts()
plt.rc("font", size=16)
# Left axes: this notebook's (transposed) intensity; right axes: the LHCb
# figure shown as a raster image without axis decorations.
fig = plt.figure(dpi=200, figsize=(11.5, 4))
ax1 = fig.add_axes((0.1, 0.17, 0.40, 0.78))
ax2 = fig.add_axes((0.5, 0.0, 0.5, 1.0))
ax2.axis("off")
plot_horizontal_intensity(ax1)
image = imread("LHCb-PAPER-2022-002-figures/Fig2.png")
ax2.imshow(image)
fig.savefig("_static/images/intensity-distribution-comparison.png")
plt.show()
_images/439ca067398c87a18db90959126a157561bbc2d37a7c78a2082283177a07d9d7.png
Hide code cell content
# Save the left panel of the comparison as a standalone image without showing
# it: interactive mode is off while drawing, switched back on afterwards, and
# the figure is closed.
plt.ioff()
fig, ax = plt.subplots(dpi=200, figsize=(5.8, 4))
plot_horizontal_intensity(ax)
fig.savefig(
    "_static/images/intensity-distribution-comparison-left.png",
    bbox_inches="tight",
)
plt.ion()
plt.close(fig)
Hide code cell source
def set_ylim_to_zero(ax):
    """Pin the lower y-limit of ``ax`` to zero while keeping its current upper limit."""
    upper_limit = ax.get_ylim()[1]
    ax.set_ylim(0, upper_limit)


%config InlineBackend.figure_formats = ['svg']
plt.rcdefaults()
use_mpl_latex_fonts()
plt.rc("font", size=18)
fig, (ax1, ax2) = plt.subplots(
    ncols=2,
    figsize=(12, 5),
    tight_layout=True,
    sharey=True,
)
ax1.set_xlabel(s1_label)
ax2.set_xlabel(s2_label)
ax1.set_ylabel("Normalized intensity")

subsystem_identifiers = ["K", "L", "D"]
subsystem_labels = ["K^{**}", R"\Lambda^{**}", R"\Delta^{**}"]
x, y = X[0], Y[:, 0]
ax1.fill(x, jnp.nansum(NORMALIZED_INTENSITIES, axis=0), alpha=0.3)
ax2.fill(y, jnp.nansum(NORMALIZED_INTENSITIES, axis=1), alpha=0.3)

original_parameters = dict(intensity_func.parameters)
for label, identifier in zip(subsystem_labels, subsystem_identifiers):
    label = f"${label}$"
    sub_intensities = compute_sub_function(intensity_func, grid_sample, [identifier])
    sub_intensities /= INTENSITY_INTEGRAL
    ax1.plot(x, jnp.nansum(sub_intensities, axis=0), label=label)
    ax2.plot(y, jnp.nansum(sub_intensities, axis=1), label=label)
    del sub_intensities
    intensity_func.update_parameters(original_parameters)
set_ylim_to_zero(ax1)
set_ylim_to_zero(ax2)
ax2.legend()
plt.savefig("_images/intensity-distributions-1D.svg")
plt.show()
_images/4073671efb1a86889324c4dbafe4811a7453befcf31299d525e1dd9debcddc54.svg

3.3. Decay rates#

# Phase-space sample (100k events, fixed seed) on which all rates are integrated.
integration_sample = transformer(
    generate_phasespace_sample(model.decay, n_events=100_000, seed=0)
)
I_tot = integrate_intensity(intensity_func(integration_sample))
# Integrated intensity of each subsystem on its own ...
I_K, I_Λ, I_Δ = (
    sub_intensity(intensity_func, integration_sample, non_zero_couplings=[subsystem])
    for subsystem in ("K", "L", "D")
)
# ... and of the interference between each pair of subsystems.
I_ΛΔ, I_KΔ, I_KΛ = (
    interference_intensity(intensity_func, integration_sample, [first], [second])
    for first, second in (("L", "D"), ("K", "D"), ("K", "L"))
)
# Sub-intensities plus pairwise interference terms must add up to the total.
np.testing.assert_allclose(I_tot, I_K + I_Λ + I_Δ + I_ΛΔ + I_KΔ + I_KΛ)
Hide code cell content
def compute_decay_rates(func, integration_sample: DataSample):
    """Compute the symmetric matrix of decay-rate fractions between resonances.

    A diagonal entry ``[i, i]`` is the integrated intensity with only resonance
    ``i`` switched on; an off-diagonal entry ``[i, j]`` is the integrated
    interference term between resonances ``i`` and ``j``. All entries are
    divided by the total integrated intensity of ``func``. Rows and columns
    follow the module-level ``resonances`` list.

    Args:
        func: Parametrized intensity function to integrate. It is used for
            both the sub-intensities and the total intensity (the total used to
            be computed from the global ``intensity_func`` by mistake).
        integration_sample: Data sample over which intensities are integrated.

    Returns:
        NumPy array of shape ``(n_resonances, n_resonances)``.
    """
    decay_rates = np.zeros(shape=(n_resonances, n_resonances))
    # Upper triangle incl. diagonal (i <= j), visited row by row; the lower
    # triangle is filled by symmetry.
    pairs = list(combinations_with_replacement(enumerate(resonances), 2))
    progress_bar = tqdm(
        desc="Calculating rate matrix",
        disable=NO_TQDM,
        total=len(pairs),
    )
    try:
        I_tot = integrate_intensity(func(integration_sample))
        for (i, resonance1), (j, resonance2) in pairs:
            progress_bar.postfix = f"{resonance1.name} × {resonance2.name}"
            res1 = to_regex(resonance1.name)
            res2 = to_regex(resonance2.name)
            if res1 == res2:
                I_sub = sub_intensity(
                    func, integration_sample, non_zero_couplings=[res1]
                )
            else:
                I_sub = interference_intensity(
                    func, integration_sample, [res1], [res2]
                )
            decay_rates[i, j] = I_sub / I_tot
            decay_rates[j, i] = decay_rates[i, j]  # no-op on the diagonal
            progress_bar.update()
    finally:
        progress_bar.close()
    return decay_rates


def to_regex(text: str) -> str:
    """Escape round brackets so a name like ``L(1405)`` matches literally in a regex.

    Only ``(`` and ``)`` are escaped; other regex metacharacters are left as-is.
    """
    bracket_escapes = str.maketrans({"(": r"\(", ")": r"\)"})
    return text.translate(bracket_escapes)


def sort_resonances(resonance: Particle):
    """Sort key: subsystem rank (Λ=1, Δ=2, K=3) from the first letter of the
    name, then natural ordering of the full name."""
    subsystem_rank = {"L": 1, "D": 2, "K": 3}[resonance.name[0]]
    name_key = natural_sorting(resonance.name)
    return subsystem_rank, name_key


# Resonances in descending (subsystem rank, natural name) order: all K** first,
# then Δ**, then Λ**. This order defines the rate-matrix rows and columns.
resonances = [chain.resonance for chain in model.decay.chains]
resonances.sort(key=sort_resonances, reverse=True)
n_resonances = len(resonances)
Hide code cell source
def visualize_decay_rates(decay_rates, title=R"Rate matrix for isobars (\%)"):
    """Render a rate matrix as a color-coded table annotated with percentages.

    The matrix is mirrored left-right before drawing, so the diagonal runs from
    the top-right to the bottom-left corner; only the upper triangle (including
    the diagonal) gets a text annotation. Tick labels come from the
    module-level ``resonances`` list.

    Args:
        decay_rates: Square matrix of rate fractions, e.g. the output of
            ``compute_decay_rates``.
        title: Figure title.

    Returns:
        The newly created Matplotlib figure.
    """
    # Color scale symmetric around zero, so the sign of each entry is visible.
    color_limit = jnp.max(jnp.abs(decay_rates))
    plt.rcdefaults()
    use_mpl_latex_fonts()
    plt.rc("font", size=14)
    plt.rc("axes", titlesize=24)
    plt.rc("xtick", labelsize=16)
    plt.rc("ytick", labelsize=16)
    fig, ax = plt.subplots(figsize=(9, 9))
    ax.set_title(title)
    mirrored_rates = jnp.rot90(decay_rates).T  # equivalent to a left-right flip
    ax.matshow(
        mirrored_rates,
        cmap=plt.cm.coolwarm,
        vmin=-color_limit,
        vmax=+color_limit,
    )

    tick_labels = [f"${p.latex}$" for p in resonances]
    tick_positions = range(n_resonances)
    ax.set_xticks(tick_positions)
    ax.set_xticklabels(tick_labels[::-1], rotation=90)
    ax.set_yticks(tick_positions)
    ax.set_yticklabels(tick_labels)
    for row in range(n_resonances):
        for col in range(row, n_resonances):
            # Column `col` of the original matrix sits at x = n - col - 1
            # after the left-right flip.
            ax.text(
                n_resonances - col - 1,
                row,
                f"{100 * decay_rates[row, col]:.2f}",
                va="center",
                ha="center",
            )
    fig.tight_layout()
    return fig


%config InlineBackend.figure_formats = ['svg']
# Rate matrix integrated over the full phase-space sample; `decay_rates` is also
# checked to sum to one in the next cell.
decay_rates = compute_decay_rates(intensity_func, integration_sample)
fig = visualize_decay_rates(decay_rates)
fig.savefig("_images/rate-matrix.svg")
plt.show()
_images/2748e6f3e7aeb85e81bfb0e5bd71aa2e5496457fdc431d04308ba9ed3cf317a9.svg
Hide code cell content
def compute_sum_over_decay_rates(decay_rate_matrix) -> float:
    """Sum the upper triangle (including the diagonal) of a rate matrix.

    For a symmetric rate matrix this counts every single-resonance rate once and
    every interference term once, which is why the checks in this notebook
    expect the result to be (numerically) one.

    The matrix size is taken from ``decay_rate_matrix`` itself instead of the
    module-level ``resonances`` list, so any square matrix can be summed.

    Args:
        decay_rate_matrix: Square matrix supporting ``[i, j]`` indexing, e.g. a
            NumPy array returned by ``compute_decay_rates``.

    Returns:
        Sum of all entries ``[i, j]`` with ``j >= i`` (``0.0`` for an empty
        matrix).
    """
    size = len(decay_rate_matrix)
    decay_rate_sum = 0.0
    for i in range(size):
        # Row-by-row sweep, same accumulation order as before.
        for j in range(i, size):
            decay_rate_sum += decay_rate_matrix[i, j]
    return decay_rate_sum
# Sanity check: diagonal rates plus each interference term once sum to one.
np.testing.assert_almost_equal(compute_sum_over_decay_rates(decay_rates), 1.0)

3.4. Dominant decays#

Hide code cell source
%config InlineBackend.figure_formats = ['svg']
threshold = 0.5
percentage = int(100 * threshold)
I_tot = intensity_func(grid_sample)

plt.rcdefaults()
use_mpl_latex_fonts()
plt.rc("font", size=18)
fig, ax = plt.subplots(figsize=(9.1, 7), sharey=True, tight_layout=True)
ax.set_ylabel(s2_label)
ax.set_xlabel(s1_label)
fig.suptitle(
    Rf"Regions where the resonance has a decay ratio of $\geq {percentage}$\%",
    y=0.95,
)

phsp_region = jnp.select(
    [I_tot > 0, True],
    (1, 0),
)
contour_set = ax.contour(X, Y, phsp_region, colors="none")
stylize_contour(contour_set, edgecolor="black", linewidth=0.2)

resonances = [c.resonance for c in model.decay.chains]
contour_levels = [i for i, _ in enumerate(resonances, 1)]
colors = [plt.cm.rainbow(x) for x in np.linspace(0, 1, len(resonances))]
linestyles = {
    "K": "dotted",
    "L": "dashed",
    "D": "solid",
}
items = list(zip(contour_levels, resonances, colors))  # tqdm requires len
progress_bar = tqdm(
    desc="Computing dominant region contours",
    disable=NO_TQDM,
    total=len(items),
)
legend_elements = []
for res_id, resonance, color in items:
    progress_bar.postfix = resonance.name
    regex_filter = resonance.name.replace("(", r"\(").replace(")", r"\)")
    I_sub = compute_sub_function(intensity_func, grid_sample, [regex_filter])
    ratio = I_sub / I_tot
    selection = jnp.select(
        [jnp.isnan(ratio), ratio < threshold, True],
        [0, 0, res_id],
    )
    progress_bar.update()
    if jnp.all(selection == 0):
        continue
    contour_set = ax.contour(X, Y, selection, colors="none")
    contour_set.set_clim(vmin=1, vmax=len(model.decay.chains))
    stylize_contour(
        contour_set,
        label=f"${resonance.latex}$",
        edgecolor=color,
        linestyle=linestyles[resonance.name[0]],
    )
    line_collection = get_contour_line(contour_set)
    legend_elements.append(line_collection)
progress_bar.close()


sub_region_x_range = (0.68, 0.72)
sub_region_y_range = (2.52, 2.60)
region_indicator = Rectangle(
    xy=(
        sub_region_x_range[0],
        sub_region_y_range[0],
    ),
    width=sub_region_x_range[1] - sub_region_x_range[0],
    height=sub_region_y_range[1] - sub_region_y_range[0],
    edgecolor="black",
    facecolor="none",
    label="Sub-region",
    linewidth=0.5,
)
ax.add_patch(region_indicator)
legend_elements.append(region_indicator)

leg = plt.legend(
    handles=legend_elements,
    bbox_to_anchor=(1.38, 1),
    framealpha=1,
    loc="upper right",
)
fig.savefig("_images/sub-regions.svg")
plt.show()
_images/952fb8d47b9c89039f28540cd78660d9d575eee0e739620f1be44e4cd90cdc5c.svg
Hide code cell source
%config InlineBackend.figure_formats = ['svg']
# Rate matrix restricted to the rectangular sub-region defined by
# sub_region_x_range / sub_region_y_range, evaluated on a meshgrid with
# resolution=50 over that rectangle. Rows/columns follow whatever the
# module-level `resonances` list holds at this point.
sub_sample = generate_sub_meshgrid_sample(
    model.decay,
    resolution=50,
    x_range=sub_region_x_range,
    y_range=sub_region_y_range,
)
sub_sample = transformer(sub_sample)
sub_decay_rates = compute_decay_rates(intensity_func, sub_sample)
fig = visualize_decay_rates(sub_decay_rates, title="Rate matrix over sub-region")
fig.savefig("_images/rate-matrix-sub-region.svg")
plt.show()
_images/5277e452a0a1f6ee3669499b66b899c0cf16e1bb6f464e42dee72bb53cf2b67b.svg
# The rates are normalized per sample, so they also sum to one over the sub-region.
np.testing.assert_almost_equal(compute_sum_over_decay_rates(sub_decay_rates), 1.0)