from __future__ import annotations
from jax import Array
from jax.typing import ArrayLike
import jax.numpy as jnp
import jax_cosmo as jc
from . import cosmology
from .emus import SigmaMEmulator
# jax-cosmo power spectra differ from colossus at the 0.3% level, which
# results in %-level discrepancies in HMF predictions. This fudge factor
# solves that.
_jax_cosmo_pk_corr = 1.0 / 1.0030
def mass_to_lagrangian_radius(M: ArrayLike, cosmo: jc.Cosmology) -> Array:
    """Convert mass to Lagrangian radius.

    Returns the radius of the sphere that encloses mass ``M`` when filled
    uniformly at the z=0 mean matter density,
    :math:`R = [3 M / (4 \\pi \\bar{\\rho}_{m,0})]^{1/3}`.

    Parameters
    ----------
    M : Array
        Mass [h-1 Msun]
    cosmo : jc.Cosmology
        Underlying cosmology

    Returns
    -------
    Array
        Lagrangian radius [h-1 Mpc]
    """
    mass = jnp.asarray(M)
    # Mean matter density today: Omega_m times the critical density at z=0.
    mean_density_0 = cosmo.Omega_m * cosmology.critical_density(0.0, cosmo)
    one_third = 1.0 / 3.0
    return (3.0 * mass / (4.0 * jnp.pi * mean_density_0)) ** one_third
def overdensity_c_to_m(delta_c: float, z: float, cosmo: jc.Cosmology) -> float:
    """Convert critical overdensity to mean overdensity.

    Re-expresses an overdensity defined relative to the critical density at
    redshift ``z`` relative to the mean matter density at the same redshift:
    :math:`\\Delta_m = \\Delta_c \\, \\rho_c(z) / \\bar{\\rho}_m(z)`.

    Parameters
    ----------
    delta_c : float
        Overdensity with respect to critical density
    z : float
        Redshift
    cosmo : jc.Cosmology
        Underlying cosmology

    Returns
    -------
    float
        Overdensity with respect to mean matter density
    """
    # The mean matter density scales as (1 + z)^3 from its z=0 value.
    expansion_cubed = (1 + z) ** 3
    rho_mean = (
        cosmo.Omega_m * cosmology.critical_density(0.0, cosmo) * expansion_cubed
    )
    rho_crit = cosmology.critical_density(z, cosmo)
    return delta_c * rho_crit / rho_mean
def overdensity_m_to_c(delta_m: float, z: float, cosmo: jc.Cosmology) -> float:
    """Convert mean overdensity to critical overdensity.

    Re-expresses an overdensity defined relative to the mean matter density
    at redshift ``z`` relative to the critical density at the same redshift:
    :math:`\\Delta_c = \\Delta_m \\, \\bar{\\rho}_m(z) / \\rho_c(z)`.

    Parameters
    ----------
    delta_m : float
        Overdensity with respect to mean matter density
    z : float
        Redshift
    cosmo : jc.Cosmology
        Underlying cosmology

    Returns
    -------
    float
        Overdensity with respect to critical density
    """
    # The mean matter density scales as (1 + z)^3 from its z=0 value.
    expansion_cubed = (1 + z) ** 3
    rho_mean = (
        cosmo.Omega_m * cosmology.critical_density(0.0, cosmo) * expansion_cubed
    )
    rho_crit = cosmology.critical_density(z, cosmo)
    return delta_m * rho_mean / rho_crit
def sigma_R(
    R: ArrayLike,
    z: ArrayLike,
    cosmo: jc.Cosmology,
    k_min: float = 1e-5,
    k_max: float = 1e2,
    n_k_int: int = 5000,
) -> Array:
    """Compute RMS variance of density fluctuations in spheres
    of radius R at redshift z.

    Evaluates
    :math:`\\sigma^2(R,z) = \\frac{1}{2\\pi^2} \\int k^2 P(k,z) W^2(kR)\\,dk`
    with the trapezoidal rule on ``n_k_int`` log-spaced wavenumbers between
    ``k_min`` and ``k_max``. Here :math:`W` is the Fourier-space spherical
    top-hat window. The linear power spectrum comes from
    ``jc.power.linear_matter_power``, rescaled by ``_jax_cosmo_pk_corr``
    for consistency with colossus.

    Parameters
    ----------
    R : Array
        Radius [h-1 Mpc]
    z : Array
        Redshift
    cosmo : jc.Cosmology
        Underlying cosmology. Its ``_workspace`` cache is cleared by this
        call.
    k_min : float
        Minimum k for integration [h Mpc-1], default 1e-5
    k_max : float
        Maximum k for integration [h Mpc-1], default 1e2
    n_k_int : int
        Number of k-space integration points for :math:`\\sigma(R,z)`,
        default 5000

    Returns
    -------
    Array
        RMS variance :math:`\\sigma(R,z)`, with singleton axes squeezed out
    """
    R = jnp.asarray(R)
    z = jnp.asarray(z)

    # Following is needed to be able to JIT this function for different values
    # of n_k_int. We need to ensure n_k_int is a concrete Python int (required
    # for static array shape) and to clear jax_cosmo's workspace cache to avoid
    # tracer leaks across different JITs
    n_k_int = int(n_k_int)
    cosmo._workspace.clear()

    # Log-spaced integration grid [h Mpc-1]
    k = jnp.logspace(jnp.log10(k_min), jnp.log10(k_max), n_k_int)

    # Linear power spectrum at scale factor a = 1 / (1 + z).
    # NOTE(review): jax-cosmo combines `a` and `k` elementwise, so a
    # non-scalar z presumably has to broadcast against the k grid. Confirm
    # before relying on arrays of redshifts here.
    a = 1.0 / (1.0 + z)
    pk = jc.power.linear_matter_power(cosmo, k, a=a)
    pk *= _jax_cosmo_pk_corr  # consistency with colossus

    # Spherical top-hat window W(x) = 3 (sin x - x cos x) / x^3.
    # The trailing axis of kR is the k grid; the leading axes follow R.
    kR = k * R[..., None]
    small = kR < 1e-3
    # "Double where": substitute a harmless value into the exact expression
    # wherever the series branch is selected. jnp.where still differentiates
    # the unselected branch. At tiny or zero kR (e.g. R = 0) the exact form
    # divides by a vanishing kR**3, and its inf/NaN derivative would leak
    # into gradients as NaN. Forward values are unchanged.
    kR_safe = jnp.where(small, 1.0, kR)
    W_exact = (
        3.0 * (jnp.sin(kR_safe) - kR_safe * jnp.cos(kR_safe)) / kR_safe**3
    )
    W = jnp.where(
        small,
        1.0 - kR**2 / 10.0,  # Taylor expansion of W for small kR
        W_exact,
    )

    # Integrate: sigma^2 = (1/2pi^2) * int k^2 P(k) W^2(kR) dk
    integrand = k**2 * pk * W**2
    sigma2 = jnp.trapezoid(integrand, k, axis=-1) / (2 * jnp.pi**2)
    return jnp.squeeze(jnp.sqrt(sigma2))
def sigma_M(
    M: ArrayLike,
    z: ArrayLike,
    cosmo: jc.Cosmology,
    k_min: float = 1e-5,
    k_max: float = 1e2,
    n_k_int: int = 5000,
    emu: SigmaMEmulator | None = None,
) -> Array:
    """Compute RMS variance of density fluctuations within the
    Lagrangian radius of a halo with mass M at redshift z.

    Two evaluation routes are available:

    * ``emu is None``: ``M`` is converted with
      :func:`mass_to_lagrangian_radius` and passed to :func:`sigma_R`,
      which integrates the linear power spectrum.
    * ``emu`` given: the result is ``emu(M, z, cosmo)``. ``k_min``,
      ``k_max`` and ``n_k_int`` are then ignored.

    Parameters
    ----------
    M : Array
        Mass [h-1 Msun]
    z : Array
        Redshift
    cosmo : jc.Cosmology
        Underlying cosmology
    k_min : float
        Minimum k for integration [h Mpc-1], default 1e-5
    k_max : float
        Maximum k for integration [h Mpc-1], default 1e2
    n_k_int : int
        Number of k-space integration points for :math:`\\sigma(R,z)`,
        default 5000
    emu : SigmaMEmulator, optional
        Trained emulator for :math:`\\sigma(M)`. If provided, the
        emulator is used instead of the analytical integral.

    Returns
    -------
    Array
        RMS variance :math:`\\sigma(M,z)`

    See Also
    --------
    halox.emus.SigmaMEmulator
        Emulator for :math:`\\sigma(M,z)`.
    """
    mass = jnp.asarray(M)
    redshift = jnp.asarray(z)

    if emu is None:
        # Analytical route: mass -> Lagrangian radius -> k-space integral.
        radius = mass_to_lagrangian_radius(mass, cosmo)
        return sigma_R(
            radius, redshift, cosmo, k_min=k_min, k_max=k_max, n_k_int=n_k_int
        )

    # Emulated route: integration settings play no role here.
    return emu(mass, redshift, cosmo)
def peak_height(
    M: ArrayLike,
    z: ArrayLike,
    cosmo: jc.Cosmology,
    n_k_int: int = 5000,
    k_min: float = 1e-5,
    k_max: float = 1e2,
    delta_sc: float = 1.68647,
    emu: SigmaMEmulator | None = None,
) -> Array:
    """Peak height :math:`\\nu = \\delta_{sc} / \\sigma(M, z)`.

    :math:`\\sigma(M, z)` is obtained from :func:`sigma_M`, so it comes from
    the emulator when ``emu`` is given and from the k-space integral
    otherwise. Note that the positional order of ``n_k_int``, ``k_min`` and
    ``k_max`` differs from :func:`sigma_M`; passing them by keyword avoids
    confusion.

    Parameters
    ----------
    M : Array
        Mass [h-1 Msun]
    z : Array
        Redshift
    cosmo : jc.Cosmology
        Underlying cosmology
    n_k_int : int
        Number of k-space integration points for :math:`\\sigma(R,z)`,
        default 5000
    k_min : float
        Minimum k for integration [h Mpc-1], default 1e-5
    k_max : float
        Maximum k for integration [h Mpc-1], default 1e2
    delta_sc : float
        Spherical collapse overdensity, default 1.68647
    emu : SigmaMEmulator, optional
        Trained emulator for :math:`\\sigma(M)`.

    Returns
    -------
    Array
        Peak height :math:`\\nu`

    See Also
    --------
    halox.emus.SigmaMEmulator
        Emulator for :math:`\\sigma(M,z)`.
    """
    integration = dict(k_min=k_min, k_max=k_max, n_k_int=n_k_int)
    rms = sigma_M(M, z, cosmo, emu=emu, **integration)
    nu = delta_sc / rms
    return nu