from __future__ import annotations
import jax
from jax import Array
from jax.typing import ArrayLike
import jax.numpy as jnp
import jax_cosmo as jc
from . import cosmology, lss
from .emus import SigmaMEmulator
def _tinker08_parameters(
    z: ArrayLike,
    cosmo: jc.Cosmology,
    delta_c: float = 200.0,
) -> Array:
    """Evaluate the Tinker08 multiplicity-function coefficients.

    The tabulated coefficients are interpolated at the mean-matter
    overdensity obtained from ``delta_c``; the redshift scalings are then
    applied to ``A``, ``a`` and ``b``. The ``c`` coefficient is returned
    as interpolated, without any additional redshift factor.

    Parameters
    ----------
    z : Array
        Redshift
    cosmo : jc.Cosmology
        Underlying cosmology
    delta_c : float
        Overdensity threshold (converted from a critical-density to a
        mean-density definition internally), default 200.0

    Returns
    -------
    Array
        Coefficients [A, a, b, c] of the Tinker08 mass function, stacked
        along the leading axis
    """
    # Table 2 of Tinker et al. 2008. The grid holds the overdensities at
    # which the four coefficient columns below are tabulated.
    delta_grid = jnp.array(
        [200.0, 300.0, 400.0, 600.0, 800.0, 1200.0, 1600.0, 2400.0, 3200.0]
    )
    coefficient_table = (
        # A
        jnp.array(
            [0.186, 0.200, 0.212, 0.218, 0.248, 0.255, 0.260, 0.260, 0.260]
        ),
        # a
        jnp.array([1.47, 1.52, 1.56, 1.61, 1.87, 2.13, 2.30, 2.53, 2.66]),
        # b
        jnp.array([2.57, 2.25, 2.05, 1.87, 1.59, 1.51, 1.46, 1.44, 1.41]),
        # c
        jnp.array([1.19, 1.27, 1.34, 1.45, 1.58, 1.80, 1.97, 2.24, 2.44]),
    )

    redshift = jnp.asarray(z)

    # The table is indexed by overdensity w.r.t. the mean matter density,
    # so the critical-density threshold is converted first.
    delta_m = lss.overdensity_c_to_m(delta_c, redshift, cosmo)

    # Piecewise-linear interpolation directly in delta_m (the abscissa is
    # NOT log-transformed). jnp.interp holds the first/last tabulated value
    # for delta_m outside [200, 3200] instead of raising.
    # NOTE(review): an earlier comment here mentioned "log space"; confirm
    # whether interpolating in log(delta_m) was the intent.
    A_0, a_0, b_0, c_0 = (
        jnp.interp(delta_m, delta_grid, column) for column in coefficient_table
    )

    # Redshift evolution of the z = 0 coefficients.
    one_plus_z = 1.0 + redshift
    log_ratio = jnp.log10(delta_m / 75)
    alpha = 10 ** (-((0.75 / log_ratio) ** 1.2))

    A_z = A_0 * one_plus_z ** (-0.14)
    a_z = a_0 * one_plus_z ** (-0.06)
    b_z = b_0 * one_plus_z ** (-alpha)

    return jnp.array([A_z, a_z, b_z, c_0])
[docs]
def tinker08_mass_function(
    M: ArrayLike,
    z: ArrayLike,
    cosmo: jc.Cosmology,
    delta_c: float = 200.0,
    n_k_int: int = 5000,
    emu: SigmaMEmulator | None = None,
) -> Array:
    """Tinker08 halo mass function :math:`dn/d\\ln M`.

    The mass variance :math:`\\sigma(M)` and its derivative with respect
    to mass are obtained together by automatic differentiation, one mass
    at a time, and combined with the Tinker08 multiplicity function.

    Parameters
    ----------
    M : Array
        Halo mass [h-1 Msun]
    z : Array
        Redshift
    cosmo : jc.Cosmology
        Underlying cosmology
    delta_c : float
        Overdensity threshold, default 200.0
    n_k_int : int
        Number of k-space integration points for :math:`\\sigma(R,z)`,
        default 5000
    emu : SigmaMEmulator, optional
        Trained emulator for :math:`\\sigma(M)`.

    Returns
    -------
    Array
        Mass function [h3 Mpc-3], with size-one axes squeezed out

    See Also
    --------
    halox.emus.SigmaMEmulator
        Emulator for :math:`\\sigma(M,z)`.
    """
    masses = jnp.atleast_1d(M)
    redshift = jnp.asarray(z)

    def sigma_at_mass(mass):
        return lss.sigma_M(mass, redshift, cosmo, n_k_int=n_k_int, emu=emu)

    # value_and_grad yields sigma and dsigma/dM from one differentiated
    # call; vmap applies it to each entry along the leading axis of masses.
    sigma_and_slope = jax.vmap(jax.value_and_grad(sigma_at_mass))
    sigma, dsigma_dM = sigma_and_slope(masses)

    # Multiplicity function f(sigma) with the Tinker08 coefficients.
    A, a, b, c = _tinker08_parameters(redshift, cosmo, delta_c)
    power_law_term = (b / sigma) ** a + 1.0
    exponential_cutoff = jnp.exp(-c / sigma**2)
    f_sigma = A * power_law_term * exponential_cutoff

    # Matter density: Omega_m times the critical density evaluated at 0.0.
    rho_m = cosmo.Omega_m * cosmology.critical_density(0.0, cosmo)

    # d ln(1/sigma) / dM = -(1 / sigma) * dsigma/dM
    dln_inv_sigma_dM = -dsigma_dM / sigma

    # dn/dM first, then multiply by M to convert to dn/dlnM.
    dn_dM = f_sigma * (rho_m / masses) * dln_inv_sigma_dM
    return jnp.squeeze(masses * dn_dM)