from __future__ import annotations
import sys
from typing import TYPE_CHECKING
import jax.numpy as jnp
import sympy as sp
from ampform.kinematics.phasespace import compute_third_mandelstam, is_within_phasespace
from tensorwaves.data import IntensityDistributionGenerator, NumpyDomainGenerator
from tensorwaves.data.rng import NumpyUniformRNG
from tensorwaves.data.transform import SympyDataTransformer
from tensorwaves.function.sympy import create_function
if TYPE_CHECKING:
from tensorwaves.function import PositionalArgumentFunction
from tensorwaves.interface import DataSample
from polarimetry.amplitude import AmplitudeModel
from polarimetry.decay import ThreeBodyDecay
if sys.version_info < (3, 8):
from typing_extensions import Literal
else:
from typing import Literal
[docs]def create_phase_space_filter(
decay: ThreeBodyDecay,
x_mandelstam: Literal[1, 2, 3] = 1,
y_mandelstam: Literal[1, 2, 3] = 2,
outside_value=sp.nan,
) -> PositionalArgumentFunction:
m0, m1, m2, m3 = create_mass_symbol_mapping(decay).values()
sigma_x = sp.Symbol(f"sigma{x_mandelstam}", nonnegative=True)
sigma_y = sp.Symbol(f"sigma{y_mandelstam}", nonnegative=True)
in_phsp_expr = is_within_phasespace(sigma_x, sigma_y, m0, m1, m2, m3, outside_value)
return create_function(in_phsp_expr.doit(), backend="jax", use_cse=True)
[docs]def generate_meshgrid_sample(
decay: ThreeBodyDecay,
resolution: int,
x_mandelstam: Literal[1, 2, 3] = 1,
y_mandelstam: Literal[1, 2, 3] = 2,
) -> DataSample:
"""Generate a `numpy.meshgrid` sample for plotting with `matplotlib.pyplot`."""
boundaries = compute_dalitz_boundaries(decay)
return generate_sub_meshgrid_sample(
decay,
resolution,
x_range=boundaries[x_mandelstam - 1],
y_range=boundaries[y_mandelstam - 1],
x_mandelstam=x_mandelstam,
y_mandelstam=y_mandelstam,
)
[docs]def generate_sub_meshgrid_sample(
decay: ThreeBodyDecay,
resolution: int,
x_range: tuple[float, float],
y_range: tuple[float, float],
x_mandelstam: Literal[1, 2, 3] = 1,
y_mandelstam: Literal[1, 2, 3] = 2,
) -> DataSample:
sigma_x, sigma_y = jnp.meshgrid(
jnp.linspace(*x_range, num=resolution),
jnp.linspace(*y_range, num=resolution),
)
phsp = {
f"sigma{x_mandelstam}": sigma_x,
f"sigma{y_mandelstam}": sigma_y,
}
z_mandelstam = __get_third_mandelstam_index(x_mandelstam, y_mandelstam)
compute_sigma_z = __create_compute_compute_sigma_z(
decay, x_mandelstam, y_mandelstam
)
phsp[f"sigma{z_mandelstam}"] = compute_sigma_z(phsp)
return phsp
[docs]def generate_phasespace_sample(
decay: ThreeBodyDecay, n_events: int, seed: int | None = None
) -> DataSample:
r"""Generate a uniform distribution over Dalitz variables :math:`\sigma_{1,2,3}`."""
boundaries = compute_dalitz_boundaries(decay)
domain_generator = NumpyDomainGenerator(
boundaries={
"sigma1": boundaries[0],
"sigma2": boundaries[1],
}
)
phsp_filter = create_phase_space_filter(decay, outside_value=0)
phsp_generator = IntensityDistributionGenerator(domain_generator, phsp_filter)
rng = NumpyUniformRNG(seed)
phsp = phsp_generator.generate(n_events, rng)
compute_sigma_z = __create_compute_compute_sigma_z(decay)
phsp["sigma3"] = compute_sigma_z(phsp)
return phsp
[docs]def compute_dalitz_boundaries(
decay: ThreeBodyDecay,
) -> tuple[tuple[float, float], tuple[float, float], tuple[float, float]]:
m0, m1, m2, m3 = create_mass_symbol_mapping(decay).values()
return (
((m2 + m3) ** 2, (m0 - m1) ** 2),
((m3 + m1) ** 2, (m0 - m2) ** 2),
((m1 + m2) ** 2, (m0 - m3) ** 2),
)
def __create_compute_compute_sigma_z(
decay: ThreeBodyDecay,
x_mandelstam: Literal[1, 2, 3] = 1,
y_mandelstam: Literal[1, 2, 3] = 2,
) -> PositionalArgumentFunction:
m0, m1, m2, m3 = create_mass_symbol_mapping(decay).values()
sigma_x = sp.Symbol(f"sigma{x_mandelstam}", nonnegative=True)
sigma_y = sp.Symbol(f"sigma{y_mandelstam}", nonnegative=True)
sigma_k = compute_third_mandelstam(sigma_x, sigma_y, m0, m1, m2, m3)
return create_function(sigma_k, backend="jax", use_cse=True)
[docs]def create_mass_symbol_mapping(decay: ThreeBodyDecay) -> dict[sp.Symbol, float]:
return {
sp.Symbol(f"m{i}"): decay.states[i].mass
for i in sorted(decay.states) # ensure that dict keys are sorted by state ID
}
def __get_third_mandelstam_index(
x_mandelstam: Literal[1, 2, 3], y_mandelstam: Literal[1, 2, 3]
):
if x_mandelstam == y_mandelstam:
msg = "x_mandelstam and y_mandelstam must be different"
raise ValueError(msg)
return next(iter({1, 2, 3} - {x_mandelstam, y_mandelstam}))