# Cross-check with LHCb data

```{autolink-concat}
```

In [None]:
from __future__ import annotations

import json
import logging
import os
from functools import lru_cache
from textwrap import dedent

import numpy as np
import sympy as sp
from IPython.display import Markdown, Math
from tqdm.auto import tqdm

from polarimetry.amplitude import AmplitudeModel, simplify_latex_rendering
from polarimetry.data import create_data_transformer
from polarimetry.io import (
 as_latex,
 display_latex,
 mute_jax_warnings,
 perform_cached_doit,
 perform_cached_lambdify,
)
from polarimetry.lhcb import (
 get_conversion_factor,
 get_conversion_factor_ls,
 load_model,
 load_model_builder,
 parameter_key_to_symbol,
)
from polarimetry.lhcb.particle import load_particles


@lru_cache(maxsize=None)
def load_model_cached(model_id: int | str) -> AmplitudeModel:
    """Load amplitude model *model_id* from ``MODEL_FILE``, memoised per ID.

    ``MODEL_FILE`` and ``PARTICLES`` are module-level globals; they only need to
    exist by the time this function is first called.
    """
    model = load_model(MODEL_FILE, PARTICLES, model_id)
    return model


# Notebook-wide setup via project helpers: mute JAX warnings and simplify the
# LaTeX rendering of expressions.
mute_jax_warnings()
simplify_latex_rendering()
# If the EXECUTE_NB environment variable is set (presumably during automated
# notebook execution), progress bars are disabled (NO_TQDM is passed as
# `disable=` to tqdm below) and the root logger only reports errors.
NO_TQDM = "EXECUTE_NB" in os.environ
if NO_TQDM:
    logging.getLogger().setLevel(logging.ERROR)

# Data paths are relative to the directory the notebook runs in.
MODEL_FILE = "../data/model-definitions.yaml"
PARTICLES = load_particles("../data/particle-definitions.yaml")
# Model 0 serves as the default model for the lineshape and amplitude checks.
DEFAULT_MODEL = load_model_cached(model_id=0)

In [None]:
# Reference values from the LHCb cross-check file. The file is decoded as UTF-8
# explicitly (the JSON standard encoding) so the result does not depend on the
# platform's default text encoding.
with open("../data/crosscheck.json", encoding="utf-8") as stream:
    crosscheck_data = json.load(stream)

## Lineshape comparison

We compute a few lineshapes at the following point in phase space and compare them with the values from {cite}`LHCb-PAPER-2022-002`:

In [None]:
# Symbols sigma1, sigma2, sigma3. Judging by the cross-check keys they are
# paired with (m2kpi, m2pk, m2ppi), these are squared invariant masses.
σ1, σ2, σ3 = sp.symbols("sigma1:4", nonnegative=True)
lineshape_vars = crosscheck_data["mainvars"]
# Substitution map: the phase-space point from the cross-check file plus the
# default values of all parameters of the default model (σ3 is not needed here).
lineshape_subs = {
    σ1: lineshape_vars["m2kpi"],
    σ2: lineshape_vars["m2pk"],
    **DEFAULT_MODEL.parameter_defaults,
}
# Last expression of the cell: displays the phase-space point.
lineshape_vars

The lineshapes are computed for the following decay chains:

In [None]:
# Look up the decay chains of three resonances in the default model.
K892_chain = DEFAULT_MODEL.decay.find_chain("K(892)")
L1405_chain = DEFAULT_MODEL.decay.find_chain("L(1405)")
L1690_chain = DEFAULT_MODEL.decay.find_chain("L(1690)")
# Render the three chains as LaTeX (cell output).
Math(as_latex([K892_chain, L1405_chain, L1690_chain]))

In [None]:
# Display the reference lineshape values from the cross-check file.
crosscheck_data["lineshapes"]

In [None]:
def build_dynamics(c):
    """Return the evaluated dynamics (lineshape) expression for decay chain *c*.

    Uses the module-level ``builder``, which is assigned right after this
    definition, to pick the dynamics builder registered for *c*. It then takes
    the first element of that builder's result and calls ``doit()`` on it.
    """
    make_dynamics = builder.dynamics_choices.get_builder(c)
    lineshape_expr = make_dynamics(c)[0]
    return lineshape_expr.doit()


# Model builder for model 0; build_dynamics() reads this global.
builder = load_model_builder(MODEL_FILE, PARTICLES, model_id=0)
# Evaluate each lineshape numerically at the cross-check phase-space point.
K892_bw_val = build_dynamics(K892_chain).xreplace(lineshape_subs).n()
L1405_bw_val = build_dynamics(L1405_chain).xreplace(lineshape_subs).n()
L1690_bw_val = build_dynamics(L1690_chain).xreplace(lineshape_subs).n()
display_latex([K892_bw_val, L1405_bw_val, L1690_bw_val])

In [None]:
lineshape_decimals = 13
np.testing.assert_array_almost_equal(
 np.array(list(map(complex, crosscheck_data["lineshapes"].values()))),
 np.array(list(map(complex, [K892_bw_val, L1405_bw_val, L1690_bw_val]))),
 decimal=lineshape_decimals,
)
src = f"""
:::{{tip}}
These values are **equal up to {lineshape_decimals} decimals**.
:::
"""
Markdown(src)

## Amplitude comparison

The amplitudes of each decay chain are evaluated for each combination of outer-state helicities at the following point in phase space:

In [None]:
# Phase-space point for the amplitude comparison, taken from the cross-check file.
amplitude_vars = dict(crosscheck_data["chainvars"])
transformer = create_data_transformer(DEFAULT_MODEL)
# Inputs keyed by the symbol names "sigma1", "sigma2", "sigma3".
input_data = {
    str(σ1): amplitude_vars["m2kpi"],
    str(σ2): amplitude_vars["m2pk"],
    str(σ3): amplitude_vars["m2ppi"],
}
# Apply the model's data transformer (presumably deriving the remaining
# kinematic variables from these invariants) and convert each value to float.
input_data = {k: float(v) for k, v in transformer(input_data).items()}
display_latex({sp.Symbol(k): v for k, v in input_data.items()})

In [None]:
@lru_cache(maxsize=None)
def create_amplitude_functions(
    model_id: int | str,
) -> dict[tuple[sp.Rational, sp.Rational], sp.Expr]:
    """Lambdify the aligned amplitude of each (λ_Λc, λ_p) helicity combination.

    Every model parameter that is not a production coupling is substituted by
    its default value. The production couplings are kept as adjustable
    parameters of the returned functions (``create_comparison_table`` switches
    them on and off via ``update_parameters``).

    NOTE(review): despite the return annotation, the dict values are the
    objects returned by ``perform_cached_lambdify``, not SymPy expressions.

    Args:
        model_id: Model index or model title.

    Returns:
        Mapping from ``(λ_Λc, λ_p)`` to a numpy-backend lambdified amplitude.
    """
    # Reuse the file's cached model loader instead of re-parsing the YAML.
    model = load_model_cached(model_id)
    production_couplings = get_production_couplings(model_id)
    fixed_parameters = {
        symbol: value
        for symbol, value in model.parameter_defaults.items()
        if symbol not in production_couplings
    }
    exprs = formulate_amplitude_expressions(model_id)
    return {
        key: perform_cached_lambdify(
            expr.xreplace(fixed_parameters),
            parameters=production_couplings,
            backend="numpy",
        )
        # This loop lambdifies, so label the progress bar accordingly.
        for key, expr in tqdm(exprs.items(), desc="Lambdifying", disable=NO_TQDM)
    }


@lru_cache(maxsize=None)
def formulate_amplitude_expressions(
    model_id: int | str,
) -> dict[tuple[sp.Rational, sp.Rational], sp.Expr]:
    """Formulate the aligned amplitude expression for each (λ_Λc, λ_p) pair.

    Both helicities run over ``-1/2`` and ``+1/2``. The two remaining helicity
    arguments of ``formulate_aligned_amplitude`` are fixed to 0 (presumably the
    helicities of the spin-0 final-state particles -- TODO confirm).

    Each expression is evaluated with ``doit()`` and its amplitude symbols are
    substituted with ``model.amplitudes``. The result is then passed through
    the cached ``perform_cached_doit``.

    Args:
        model_id: Model index or model title.

    Returns:
        Mapping from ``(λ_Λc, λ_p)`` to the evaluated SymPy expression.
    """
    builder = load_model_builder(MODEL_FILE, PARTICLES, model_id)
    half = sp.Rational(1, 2)
    exprs = {
        (λ_Λc, λ_p): builder.formulate_aligned_amplitude(λ_Λc, λ_p, 0, 0)[0]
        for λ_Λc in [-half, +half]
        for λ_p in [-half, +half]
    }
    # Reuse the file's cached model loader instead of re-parsing the YAML.
    model = load_model_cached(model_id)
    return {
        k: perform_cached_doit(expr.doit().xreplace(model.amplitudes))
        # This loop runs doit, so label the progress bar accordingly.
        for k, expr in tqdm(exprs.items(), desc="Performing doit", disable=NO_TQDM)
    }


@lru_cache(maxsize=None)
def get_production_couplings(model_id: int | str) -> dict[sp.Indexed, complex]:
    """Return the production couplings of model *model_id* with their defaults.

    A parameter counts as a production coupling when it is an ``sp.Indexed``
    symbol whose string form contains ``"production"``.
    """
    defaults = load_model(MODEL_FILE, PARTICLES, model_id).parameter_defaults
    return {
        par: val
        for par, val in defaults.items()
        if isinstance(par, sp.Indexed) and "production" in str(par)
    }

In [None]:
def plusminus_to_helicity(plusminus: str) -> sp.Rational:
    """Convert a sign character ``"+"``/``"-"`` to the helicity ``+1/2``/``-1/2``.

    Raises:
        NotImplementedError: If *plusminus* is neither ``"+"`` nor ``"-"``.
    """
    if plusminus not in ("+", "-"):
        raise NotImplementedError(plusminus)
    half = sp.Rational(1, 2)
    return half if plusminus == "+" else -half


def create_comparison_table(
    model_id: int | str, decimals: int | None = None
) -> Markdown:
    """Tabulate computed amplitudes against the LHCb cross-check values.

    For every cross-check identifier starting with ``"Ar"``, the matching
    production coupling is set to 1 and all other production couplings to 0.
    The lambdified amplitude of each helicity combination is then evaluated
    on the module-level ``input_data`` and multiplied by the conversion factor
    from ``determine_conversion_factor``.

    Args:
        model_id: Model index or model title.
        decimals: If given, every computed amplitude must agree with the
            expected one up to this many decimals (checked with
            ``np.testing.assert_array_almost_equal``). A tip box stating this
            precision is prepended to the table.

    Returns:
        A Markdown table with computed value, expected value and relative
        difference per amplitude and helicity combination. Relative
        differences of 1e-6 or more (or NaN) are shown in bold. The difference
        cell is empty when the expected value is zero.

    Raises:
        AssertionError: If *decimals* is given and a value deviates.
    """
    min_ls = not is_ls_model(model_id)
    amplitude_funcs = create_amplitude_functions(model_id)
    real_amp_crosscheck = {
        k: v
        for k, v in get_amplitude_crosscheck_data(model_id).items()
        if k.startswith("Ar")
    }
    production_couplings = get_production_couplings(model_id)
    couplings_to_zero = {str(symbol): 0 for symbol in production_couplings}

    src = ""
    if decimals is not None:
        src += dedent(f"""
        :::{{tip}}
        Computed amplitudes are equal to LHCb amplitudes up to **{decimals} decimals**.
        :::
        """)
    src += dedent("""
    | | Computed | Expected | Difference |
    | ---:| --------:| --------:| ----------:|
    """)
    for amp_identifier, entry in real_amp_crosscheck.items():
        coupling = parameter_key_to_symbol(
            amp_identifier.replace("Ar", "A"),
            min_ls,
            particle_definitions=PARTICLES,
        )
        src += f"| **`{amp_identifier}`** | ${sp.latex(coupling)}$ |\n"
        for matrix_key, expected in entry.items():
            matrix_suffix = matrix_key[1:]  # ++, +-, -+, --
            λ_Λc, λ_p = map(plusminus_to_helicity, matrix_suffix)
            # The functions are looked up with the opposite sign of λ_p.
            func = amplitude_funcs[(λ_Λc, -λ_p)]
            # Isolate this coupling: switch all production couplings off,
            # then switch the current one on.
            func.update_parameters(couplings_to_zero)
            func.update_parameters({str(coupling): 1})
            computed = complex(func(input_data))
            computed *= determine_conversion_factor(coupling, λ_p, min_ls)
            expected = complex(expected)
            if abs(expected) == 0.0:
                # The relative difference is undefined for a zero reference.
                diff = ""
            else:
                rel_diff = abs(computed - expected) / abs(expected)
                diff = f"{rel_diff:.2e}"
                # Previously both branches produced the same text; emphasise
                # significant (or NaN) deviations so they stand out.
                if not rel_diff < 1e-6:
                    diff = f"**{diff}**"
            src += (
                f"| `{matrix_key}` | {computed:>.6f} | {expected:>.6f} | {diff} |\n"
            )
            if decimals is not None:
                np.testing.assert_array_almost_equal(
                    computed,
                    expected,
                    decimal=decimals,
                    err_msg=f" {amp_identifier} {matrix_key}",
                )
    return Markdown(src)


def determine_conversion_factor(
    coupling: sp.Indexed, λ_p: sp.Rational, min_ls: bool
) -> int:
    """Return the factor applied to a computed amplitude before comparing it.

    The resonance is taken from the first index of *coupling*. With
    ``min_ls=True`` the factor comes from ``get_conversion_factor``. Otherwise
    the coupling's ``L`` and ``S`` indices are passed to
    ``get_conversion_factor_ls``. In both cases the result is multiplied by
    the amplitude sign ``(-1)**(1/2 + λ_p)``.
    """
    resonance = PARTICLES[str(coupling.indices[0])]
    if not min_ls:
        _, L, S = coupling.indices
        base_factor = get_conversion_factor_ls(resonance, L, S)
    else:
        base_factor = get_conversion_factor(resonance)
    # Additional sign flip for the amplitude: -1 for λ_p = +1/2, +1 for λ_p = -1/2.
    sign = int((-1) ** (sp.Rational(1, 2) + λ_p))
    return base_factor * sign


def is_ls_model(model_id: int | str) -> bool:
 if isinstance(model_id, int):
 return model_id == 17
 return "LS couplings" in model_id


def get_amplitude_crosscheck_data(model_id: int | str) -> dict[str, complex]:
    """Return the section of the cross-check data that belongs to *model_id*.

    The LS-coupling model uses ``crosscheck_data["chains_LS"]``; every other
    model uses ``crosscheck_data["chains"]``.

    NOTE(review): ``create_comparison_table`` calls ``.items()`` on the values,
    so they appear to be nested dicts -- the return annotation may be inexact.
    """
    section = "chains_LS" if is_ls_model(model_id) else "chains"
    return crosscheck_data[section]

### Default model

In [None]:
# Compare model 0 (the default model) and require agreement to 13 decimals.
create_comparison_table(model_id=0, decimals=13)

### LS-model

In [None]:
# The LS-coupling model is selected by its title; is_ls_model() recognises it
# through the "LS couplings" substring and switches to the "chains_LS" data.
create_comparison_table(
    "Alternative amplitude model obtained using LS couplings",
    decimals=13,
)