Other tutorials
This notebook gathers small, self-contained tutorials demonstrating useful features of halox that do not warrant their own dedicated notebook.
from halox import hmf, cosmology, lss, nfw
import jax
import jax.numpy as jnp
import jax_cosmo as jc
Vectorizing over cosmologies
Most functions in halox depend on cosmological parameters, which they take as input in the form of a jax_cosmo.Cosmology object.
In some workflows, it is useful to rely on JAX transformations (e.g., jax.vmap or jax.lax.map) to evaluate a computation over many cosmologies without writing Python loops.
To do so, you can use the halox.cosmology.stack_cosmologies() function.
First, we define a list of individual cosmologies as jax_cosmo.Cosmology objects.
cosmo_p18 = cosmology.Planck18()
cosmo_p15 = jc.Planck15()
cosmo_p18_high_sigma8 = cosmology.Planck18(sigma8=0.9)
cosmo_0703 = jc.Cosmology(
    Omega_b=0.05,
    Omega_c=0.25,
    h=0.7,
    sigma8=0.8,
    n_s=0.9665,
    Omega_k=0,
    w0=-1,
    wa=0,
)
cosmo_list = [cosmo_p18, cosmo_p15, cosmo_p18_high_sigma8, cosmo_0703]
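Unfortunately, JAX transformations cannot iterate over the elements of a plain Python list: jax.vmap treats cosmo_list as a pytree and tries to map along axis 0 of each of its leaves, which are scalar cosmological parameters. Therefore, cosmo_list cannot be passed directly to vmap.
For example, trying to compute the halo mass function on an array of masses for each cosmology in the list fails: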
m, z = jnp.logspace(13, 14, 8), 0.0
try:
    hmfs = jax.vmap(hmf.tinker08_mass_function, in_axes=[None, None, 0])(
        m, z, cosmo_list
    )
    print(hmfs.shape)  # should be (4, 8) (four cosmologies, eight masses)
except ValueError as e:
    print("Error:", e)
Error: vmap was requested to map its argument along axis 0, which implies that its rank should be at least 1, but is only 0 (its shape is ())
Instead, we can stack them, creating a single jax_cosmo.Cosmology object containing arrays of cosmological parameters:
cosmo_stack = cosmology.stack_cosmologies(cosmo_list)
cosmo_stack
Cosmological parameters:
h: [0.6766 0.6774 0.6766 0.7 ]
Omega_b: [0.04897 0.0486 0.04897 0.05 ]
Omega_c: [0.26067 0.2589 0.26067 0.25 ]
Omega_k: [0. 0. 0. 0.]
w0: [-1. -1. -1. -1.]
wa: [0. 0. 0. 0.]
n: [0.9665 0.9667 0.9665 0.9665]
sigma8: [0.8102 0.8159 0.9 0.8 ]
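Conceptually, a stacked cosmology is the same pytree as a single cosmology, with each scalar leaf replaced by an array holding one entry per cosmology. The following is a minimal sketch of that idea in plain JAX, using a hypothetical ToyCosmology NamedTuple as a stand-in for jax_cosmo.Cosmology; it is not halox's implementation of stack_cosmologies, only an illustration of leaf-by-leaf stacking with jax.tree_util.tree_map.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


# Hypothetical stand-in for jax_cosmo.Cosmology: NamedTuples are pytrees,
# so JAX can flatten them into their leaves (here, scalar parameters).
class ToyCosmology(NamedTuple):
    h: float
    Omega_m: float
    sigma8: float


def stack_pytrees(trees):
    """Stack a list of identically structured pytrees leaf by leaf.

    Illustrative sketch only, not halox's cosmology.stack_cosmologies.
    """
    return jax.tree_util.tree_map(lambda *leaves: jnp.stack(leaves), *trees)


cosmos = [
    ToyCosmology(h=0.6766, Omega_m=0.3096, sigma8=0.8102),
    ToyCosmology(h=0.6774, Omega_m=0.3075, sigma8=0.8159),
    ToyCosmology(h=0.7, Omega_m=0.3, sigma8=0.8),
]
stacked = stack_pytrees(cosmos)
print(stacked.sigma8.shape)  # (3,)
```

The result is still a ToyCosmology, so code that reads attributes such as cosmo.sigma8 keeps working; only the leaves gained a leading axis that vmap can map over.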
try:
    hmfs = jax.vmap(hmf.tinker08_mass_function, in_axes=[None, None, 0])(
        m, z, cosmo_stack
    )
    print(hmfs.shape)  # should be (4, 8) (four cosmologies, eight masses)
except ValueError as e:
    print("Error:", e)
(4, 8)
This can be used to vectorize over pre-defined cosmologies for any cosmology-dependent function in halox.
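Besides jax.vmap, a stacked pytree can also be consumed by jax.lax.map, which evaluates the function one cosmology at a time inside a compiled loop; this typically uses less memory than vmap, at the cost of parallelism. The sketch below compares both on a hypothetical toy_abundance function and ToyCosmology NamedTuple (illustrative stand-ins, not halox functions); with halox, the same pattern would apply to a stacked jax_cosmo.Cosmology.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ToyCosmology(NamedTuple):
    # Hypothetical stand-in for a cosmology pytree
    Omega_m: float
    sigma8: float


def toy_abundance(m, cosmo):
    # Hypothetical, purely illustrative function of mass and cosmology
    # (not a real mass function)
    return cosmo.Omega_m * cosmo.sigma8**2 * (m / 1e14) ** -0.9


cosmos = [ToyCosmology(0.31, 0.81), ToyCosmology(0.31, 0.90), ToyCosmology(0.30, 0.80)]
# Stack leaf by leaf so every parameter gets a leading "cosmology" axis
stacked = jax.tree_util.tree_map(lambda *xs: jnp.stack(xs), *cosmos)
m = jnp.logspace(13, 14, 8)

# vmap: all cosmologies evaluated at once (vectorized)
res_vmap = jax.vmap(toy_abundance, in_axes=(None, 0))(m, stacked)

# lax.map: one cosmology per iteration of a compiled loop
res_map = jax.lax.map(lambda c: toy_abundance(m, c), stacked)

print(res_vmap.shape, res_map.shape)  # (3, 8) (3, 8)
```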
Converting between mean and critical overdensities
Halo masses are commonly defined with respect to a spherical overdensity threshold \(\Delta\), which can itself be expressed relative to either the critical density of the universe (\(\Delta_c\), e.g. \(M_{200c}\)) or the mean matter density (\(\Delta_m\), e.g. \(M_{200m}\)). Since the mean matter density is \(\bar\rho_m(z) = \Omega_m(z)\,\rho_c(z)\), the two are related through the matter density parameter at redshift \(z\):
\[\Delta_c = \Omega_m(z)\,\Delta_m.\]
The functions lss.overdensity_c_to_m and lss.overdensity_m_to_c perform these conversions for a given cosmology and redshift.
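To make the relation \(\Delta_c = \Omega_m(z)\,\Delta_m\) concrete, here is a minimal hand-rolled version. It assumes a flat ΛCDM background that neglects radiation, so that \(\Omega_m(z) = \Omega_{m,0}(1+z)^3 / [\Omega_{m,0}(1+z)^3 + 1 - \Omega_{m,0}]\), and takes \(\Omega_{m,0} = \Omega_b + \Omega_c\) from the Planck18 parameters printed earlier. The helper names are hypothetical; halox's own functions may compute \(\Omega_m(z)\) differently.

```python
def omega_m_z(z, omega_m0):
    """Matter density parameter at redshift z.

    Assumes flat LCDM without radiation:
    E^2(z) = Omega_m0 (1+z)^3 + (1 - Omega_m0).
    """
    a3 = (1.0 + z) ** 3
    return omega_m0 * a3 / (omega_m0 * a3 + 1.0 - omega_m0)


def delta_c_to_m(delta_c, z, omega_m0):
    # rho_m = Omega_m(z) rho_c  =>  Delta_m = Delta_c / Omega_m(z)
    return delta_c / omega_m_z(z, omega_m0)


def delta_m_to_c(delta_m, z, omega_m0):
    # Inverse conversion: Delta_c = Omega_m(z) Delta_m
    return delta_m * omega_m_z(z, omega_m0)


# Omega_m0 = Omega_b + Omega_c for the Planck18 parameters printed above
omega_m0 = 0.04897 + 0.26067
print(f"{delta_c_to_m(200.0, 0.0, omega_m0):.2f}")  # 645.91
print(f"{delta_m_to_c(200.0, 0.0, omega_m0):.2f}")  # 61.93
```

At \(z = 0\), these values agree with the halox output shown below. The functions use only arithmetic, so they also accept jax.numpy arrays of redshifts.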
cosmo = cosmology.Planck18()
# 200c to m
delta_m = lss.overdensity_c_to_m(200.0, 0.0, cosmo)
print(f"Δ_c = 200 at z=0 corresponds to Δ_m = {delta_m:.2f}")
# 200m to c
delta_c = lss.overdensity_m_to_c(200.0, 0.0, cosmo)
print(f"Δ_m = 200 at z=0 corresponds to Δ_c = {delta_c:.2f}")
Δ_c = 200 at z=0 corresponds to Δ_m = 645.91
Δ_m = 200 at z=0 corresponds to Δ_c = 61.93
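This is useful for switching between mass definitions.
halox works explicitly with overdensities defined relative to the critical density, but these conversions make it possible to handle mean-density definitions as well.
For example, to convert an existing halo defined at \(\Delta_c = 200\) to a \(200m\) mass definition: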
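m200c, c200c, z = 3e14, 5.0, 0.1
halo_200c = nfw.NFWHalo(m200c, c200c, z, delta=200.0, cosmo=cosmo)
# halox overdensities are relative to the critical density:
# express Δ_m = 200 as the equivalent Δ_c before calling to_delta
delta_c = lss.overdensity_m_to_c(200.0, z, cosmo)
m200m, r200m, c200m = halo_200c.to_delta(delta_c)
print(f"At {z=}: {m200c=:.2e} -> {m200m=:.2e}")
Because \(\Delta_m = 200\) corresponds to a lower critical overdensity (\(\Delta_c \approx 75\) at \(z = 0.1\) for Planck18), the resulting \(M_{200m}\) is larger than \(M_{200c}\).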
Conversely, to instantiate a halo whose mass and concentration are defined at \(\Delta_m = 200\), pass the equivalent critical overdensity:
m200m, c200m, z = 1e14, 4.0, 0.1
delta_c = lss.overdensity_m_to_c(200.0, z, cosmo)
halo_200m = nfw.NFWHalo(m200m, c200m, z, delta=delta_c, cosmo=cosmo)
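For intuition about what a change of overdensity involves for an NFW halo, here is a minimal sketch of the standard conversion. The enclosed mass scales as \(\mu(r/r_s)\) with \(\mu(y) = \ln(1+y) - y/(1+y)\), so with \(x = r_2/r_1\) the new radius satisfies \([\mu(c_1 x)/\mu(c_1)]/x^3 = \Delta_2/\Delta_1\), solved here by bisection in \(\log x\). The helpers mu and convert_nfw_overdensity are hypothetical names; this is not necessarily how NFWHalo.to_delta is implemented in halox, and both overdensities must refer to the same reference density (e.g. both critical).

```python
import jax
import jax.numpy as jnp


def mu(y):
    # Dimensionless NFW enclosed mass: M(<r) is proportional to mu(r / r_s)
    return jnp.log1p(y) - y / (1.0 + y)


def convert_nfw_overdensity(m1, c1, delta1, delta2, n_iter=60):
    """Convert an NFW halo (m1, c1) at overdensity delta1 to delta2.

    Both overdensities must use the same reference density.
    Returns (m2, c2). Sketch only: bisection on log(x), x = r2 / r1.
    """
    target = delta2 / delta1

    def mean_density_ratio(x):
        # Mean enclosed density at x * r1, relative to that at r1
        return mu(c1 * x) / mu(c1) / x**3

    def body(_, bounds):
        lo, hi = bounds
        mid = 0.5 * (lo + hi)
        too_dense = mean_density_ratio(jnp.exp(mid)) > target
        # The ratio decreases outward: if still too dense, move outward
        return jnp.where(too_dense, mid, lo), jnp.where(too_dense, hi, mid)

    init = (jnp.log(jnp.float32(1e-3)), jnp.log(jnp.float32(1e3)))
    lo, hi = jax.lax.fori_loop(0, n_iter, body, init)
    x = jnp.exp(0.5 * (lo + hi))
    return m1 * mu(c1 * x) / mu(c1), c1 * x


# 200c -> a lower critical threshold, roughly the 200m equivalent at z = 0.1
m_new, c_new = convert_nfw_overdensity(3e14, 5.0, 200.0, 75.0)
```

Lowering the threshold moves the boundary outward, so both the mass and the concentration increase; raising it has the opposite effect.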