Large-Scale Structure quantities#

This notebook demonstrates the large-scale structure (LSS) calculations available in the halox library. We’ll explore RMS variance calculations, mass-to-radius conversions, and overdensity transformations that form the foundation of halo mass function calculations. For full documentation of the module API, see halox.lss: Large-scale structure calculations.

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
from matplotlib_inline.backend_inline import set_matplotlib_formats

from halox import cosmology, lss

jax.config.update("jax_enable_x64", True)

plt.style.use(["seaborn-v0_8-darkgrid", "petroff10"])
plt.rcParams.update({"xtick.direction": "in", "ytick.direction": "in"})
set_matplotlib_formats("svg")

Setting up the Cosmology#

First, let’s create a cosmology object using the Planck 2018 parameters provided by halox:

# Create a Planck 2018 cosmology
cosmo = cosmology.Planck18()
print(f"Hubble parameter h = {cosmo.h}")
print(f"Matter density Ω_m = {cosmo.Omega_m}")
print(f"Baryon density Ω_b = {cosmo.Omega_b}")
print(f"Cold dark matter density Ω_c = {cosmo.Omega_c}")
print(f"Power spectrum normalization σ_8 = {cosmo.sigma8}")
Hubble parameter h = 0.6766
Matter density Ω_m = 0.30964
Baryon density Ω_b = 0.04897
Cold dark matter density Ω_c = 0.26067
Power spectrum normalization σ_8 = 0.8102

RMS Variance Theory#

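The variance of the linear density field smoothed over spheres of radius R is given below; its square root, σ(R,z), is the RMS quantity computed and plotted throughout this notebook: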

\[\sigma^2(R,z) = \frac{1}{2\pi^2} \int_0^\infty k^2 P(k,z) W^2(kR) dk\]

where:

  • \(P(k,z)\) is the linear matter power spectrum at redshift z

  • \(W(kR)\) is the spherical top-hat window function:

\[W(x) = \frac{3(\sin x - x \cos x)}{x^3}\]
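As a minimal sketch, the top-hat window above can be evaluated with the Python standard library alone. The function name `tophat_window` and the switch-over threshold are illustrative choices, not part of halox. The small-argument branch uses the Taylor series of W(x), because the closed form subtracts two nearly equal terms as x approaches 0.

```python
import math


def tophat_window(x: float) -> float:
    """Fourier-space spherical top-hat: W(x) = 3 (sin x - x cos x) / x^3."""
    if abs(x) < 1e-2:
        # Taylor expansion W(x) = 1 - x^2/10 + x^4/280 - ...
        # avoids the cancellation in sin x - x cos x near x = 0.
        return 1.0 - x**2 / 10.0 + x**4 / 280.0
    return 3.0 * (math.sin(x) - x * math.cos(x)) / x**3


# W -> 1 on scales much larger than R (kR << 1) and oscillates with a
# decaying envelope on scales much smaller than R (kR >> 1).
for x in [0.0, 0.5, 2.0, 10.0]:
    print(f"W({x:4.1f}) = {tophat_window(x):+.4f}")
```

Because W(kR) tends to 1 for kR much smaller than 1 and is strongly suppressed for kR much larger than 1, the integral for σ²(R,z) is dominated by modes with wavelengths larger than roughly R.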

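The variance σ(M,z) is obtained by converting mass M to its corresponding Lagrangian radius, i.e. the radius of the sphere that encloses mass M at the mean comoving matter density \(\bar\rho_m\): \(R = \left(3M / 4\pi\bar\rho_m\right)^{1/3}\).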
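The mass-to-radius conversion can be sketched in a few lines. This is an illustration under stated assumptions rather than the halox implementation: the helper names are hypothetical, the mean matter density is taken as Ω_m times the present-day critical density (about 2.775 × 10¹¹ h² M⊙ Mpc⁻³), and masses and radii are in h⁻¹ M⊙ and h⁻¹ Mpc.

```python
import math

# Present-day critical density in h^2 Msun / Mpc^3 (3 H0^2 / 8 pi G).
RHO_CRIT_0 = 2.775e11


def lagrangian_radius(mass: float, omega_m: float) -> float:
    """Radius [Mpc/h] of the sphere enclosing `mass` [Msun/h] at mean density."""
    rho_m = omega_m * RHO_CRIT_0  # comoving mean matter density
    return (3.0 * mass / (4.0 * math.pi * rho_m)) ** (1.0 / 3.0)


def lagrangian_mass(radius: float, omega_m: float) -> float:
    """Inverse relation: mass [Msun/h] enclosed within `radius` [Mpc/h]."""
    rho_m = omega_m * RHO_CRIT_0
    return 4.0 / 3.0 * math.pi * rho_m * radius**3


# Mass scale corresponding to the sigma_8 smoothing radius, R = 8 Mpc/h,
# for the Planck 2018 matter density printed earlier in this notebook.
m8 = lagrangian_mass(8.0, omega_m=0.30964)
print(f"M(R = 8 Mpc/h) = {m8:.2e} Msun/h")
```

With these assumptions, the σ_8 smoothing scale corresponds to a mass of order 10¹⁴ h⁻¹ M⊙, i.e. galaxy-cluster scale.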

Variance as a Function of Radius#

Let’s first examine σ(R,z) directly as a function of radius:

# Radius range from 0.1 to 100 h^-1 Mpc
R = jnp.logspace(-1, 2, 100)
z = 0.0

# Compute sigma(R) at z=0
sigma_R_z0 = lss.sigma_R(R, z, cosmo)

print(f"σ(R) computed for {len(R)} radius bins at z={z}")
print(f"Radius range: {R.min():.2f} to {R.max():.2f} h^-1 Mpc")
print(f"σ(R) range: {sigma_R_z0.min():.3f} to {sigma_R_z0.max():.3f}")
σ(R) computed for 100 radius bins at z=0.0
Radius range: 0.10 to 100.00 h^-1 Mpc
σ(R) range: 0.066 to 5.139
# Plot sigma(R) at z=0
fig, ax = plt.subplots()
ax.loglog(
    R, sigma_R_z0, linewidth=2, color="C0", label=f"$\\sigma(R, z={z:.1f})$"
)

# Add horizontal line at σ_8 and vertical line at R_8
ax.axhline(
    cosmo.sigma8,
    color="C1",
    linestyle="--",
    alpha=0.7,
    label=f"$\\sigma_8 = {cosmo.sigma8:.3f}$",
)
ax.axvline(8.0, color="C1", linestyle="--", alpha=0.7)

# Annotate R_8
ax.annotate(
    "$\\sigma(R=8 \\; {\\rm Mpc}/h) = \\sigma_8$",
    (8.0, cosmo.sigma8),
    xytext=(20, 20),
    textcoords="offset points",
    arrowprops=dict(arrowstyle="->", color="C1"),
    fontsize=10,
)

ax.set_xlabel(r"Radius $R$ [$h^{-1} \, {\rm Mpc}$]")
ax.set_ylabel(r"RMS Variance $\sigma(R,z)$")
ax.set_title("RMS Variance as a Function of Radius")
ax.set_xlim(0.05, 200)
ax.set_ylim(0.01, 20)
ax.legend()
ax.xaxis.set_ticks_position("both")
ax.yaxis.set_ticks_position("both")
ax.grid(True, alpha=1.0)
../_images/e200be2dc844494269a7ce2d9e1ed9be8830bcdc55fd3819209716ed809e86c4.svg

Variance as a Function of Mass#

# Compute sigma(M) at z=0
M = jnp.logspace(8, 16, 100)
z = 0.0
sigma_M_z0 = lss.sigma_M(M, z, cosmo)

print(f"σ(M) computed for {len(M)} mass bins at z={z}")
print(f"σ(M) range: {sigma_M_z0.min():.3f} to {sigma_M_z0.max():.3f}")
σ(M) computed for 100 mass bins at z=0.0
σ(M) range: 0.267 to 5.721
# Plot sigma(M) at z=0
fig, ax = plt.subplots()
ax.loglog(
    M, sigma_M_z0, linewidth=2, color="C0", label=f"$\\sigma(M, z={z:.1f})$"
)

ax.set_xlabel("Mass $M$ [$h^{-1} \\, M_\\odot$]")
ax.set_ylabel("RMS Variance $\\sigma(M,z)$")
ax.set_title("RMS Variance of Density Fluctuations")
ax.set_xlim(5e7, 2e16)
ax.set_ylim(0.01, 10)
ax.legend()
ax.xaxis.set_ticks_position("both")
ax.yaxis.set_ticks_position("both")
ax.grid(True, alpha=1.0)
../_images/b220cc655a39d3fe3ab72954d383f51bf61d2d84e5dabec4e129576936ede535.svg

Different Cosmologies#

Let’s compare σ(M) for different cosmological parameters. We’ll vary the cold dark matter density Ω_c (and with it the total matter density) and the power spectrum normalization σ_8:

# Create different cosmologies by modifying Planck18 parameters
cosmo_low_Om = cosmology.Planck18(Omega_c=0.20)
cosmo_high_Om = cosmology.Planck18(Omega_c=0.30)
cosmo_low_s8 = cosmology.Planck18(sigma8=0.7)
cosmo_high_s8 = cosmology.Planck18(sigma8=0.9)

cosmologies = [cosmo_low_Om, cosmo_high_Om, cosmo_low_s8, cosmo_high_s8]
labels = [
    r"$\Omega_c = 0.20$",
    r"$\Omega_c = 0.30$",
    r"$\sigma_8 = 0.7$",
    r"$\sigma_8 = 0.9$",
]
colors = ["C1", "C2", "C3", "C4"]

fig, ax = plt.subplots()

# Plot reference Planck18 cosmology
sigma_M_ref = lss.sigma_M(M, z, cosmo)
ax.loglog(
    M,
    sigma_M_ref,
    linewidth=2,
    color="C0",
    label="Planck18 (reference)",
    linestyle="-",
)

# Plot different cosmologies
for cosmo_test, label, color in zip(cosmologies, labels, colors):
    sigma_M_test = lss.sigma_M(M, z, cosmo_test)
    ax.loglog(
        M, sigma_M_test, linewidth=2, color=color, label=label, linestyle="--"
    )

ax.set_xlabel(r"Mass $M$ [$h^{-1} \, M_\odot$]")
ax.set_ylabel(r"RMS Variance $\sigma(M,z)$")
ax.set_title(f"RMS Variance for Different Cosmologies ($z={z:.1f}$)")
ax.set_xlim(5e7, 2e16)
ax.set_ylim(0.01, 10)
ax.legend()
ax.xaxis.set_ticks_position("both")
ax.yaxis.set_ticks_position("both")
ax.grid(True, alpha=1.0)
../_images/5f5d91a3f7fe295bea6196d2ef29f16314f045aef1853f862bdde5824349d65a.svg

Redshift Evolution#

Now let’s examine how σ(M,z) evolves with redshift. The variance decreases with increasing redshift because density perturbations have had less time to grow at earlier epochs; in linear theory, σ scales with the linear growth factor D(z):

# Different redshifts
redshifts = [0.0, 0.5, 1.0, 2.0]
colors = ["C0", "C1", "C2", "C3"]
linestyles = ["-", "--", "-.", ":"]

fig, ax = plt.subplots()

for z_val, color, ls in zip(redshifts, colors, linestyles):
    sigma_M_z = lss.sigma_M(M, z_val, cosmo)
    ax.loglog(
        M,
        sigma_M_z,
        linewidth=2.5,
        color=color,
        linestyle=ls,
        label=f"$z = {z_val:.1f}$",
    )

ax.set_xlabel(r"Mass $M$ [$h^{-1} \, M_\odot$]")
ax.set_ylabel(r"RMS Variance $\sigma(M,z)$")
ax.set_title("RMS Variance Evolution with Redshift")
ax.set_xlim(5e7, 2e16)
ax.set_ylim(0.01, 10)
ax.legend()
ax.xaxis.set_ticks_position("both")
ax.yaxis.set_ticks_position("both")
ax.grid(True, alpha=1.0)
../_images/deef584f84ef7ab57198260a764118e680180d96bb9eb4c594931e53a925cc89.svg
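To make the redshift scaling concrete, here is a minimal sketch of the linear growth factor D(z) for a flat ΛCDM universe with matter and a cosmological constant only. It is not the halox implementation: the function name, the neglect of radiation and neutrinos, and the midpoint-rule integration are all simplifying assumptions. It uses the standard integral solution D(a) ∝ E(a) ∫₀ᵃ da′ / (a′ E(a′))³, normalized so that D(z=0) = 1.

```python
import math


def growth_factor(z: float, omega_m: float = 0.30964, n_steps: int = 2000) -> float:
    """Linear growth factor D(z) with D(0) = 1 (flat LCDM, no radiation)."""
    omega_l = 1.0 - omega_m

    def e_of_a(a: float) -> float:
        # Dimensionless Hubble rate E(a) = H(a) / H0
        return math.sqrt(omega_m / a**3 + omega_l)

    def unnormalized(a_max: float) -> float:
        # E(a) * integral_0^a da' / (a' E(a'))^3, using the midpoint rule
        da = a_max / n_steps
        total = 0.0
        for i in range(n_steps):
            a = (i + 0.5) * da
            total += da / (a * e_of_a(a)) ** 3
        return e_of_a(a_max) * total

    return unnormalized(1.0 / (1.0 + z)) / unnormalized(1.0)


for z_val in [0.0, 0.5, 1.0, 2.0]:
    print(f"z = {z_val:.1f}: D(z) = {growth_factor(z_val):.3f}")
```

Under these assumptions, σ(M,z) ≈ D(z) σ(M,0), so on a log-log plot the curves at different redshifts are vertical shifts of one another.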

Overdensity Conversion#

In halo mass function calculations, we often need to convert between overdensities defined relative to the critical density (Δc) and those defined relative to the mean matter density (Δm):

\[\Delta_m = \Delta_c \frac{\rho_c(z)}{\rho_m(z)}\]

where ρc(z) is the critical density and ρm(z) is the mean matter density at redshift z.

# Common overdensity definitions
delta_c_values = [200.0, 500.0, 2500.0]
redshifts_conv = jnp.linspace(0, 3, 50)

colors = ["C0", "C1", "C2"]
linestyles = ["-", "--", "-."]

fig, ax = plt.subplots()

for delta_c, color, ls in zip(delta_c_values, colors, linestyles):
    delta_m = lss.overdensity_c_to_m(delta_c, redshifts_conv, cosmo)
    ax.plot(
        redshifts_conv,
        delta_m,
        linewidth=2.5,
        color=color,
        linestyle=ls,
        label=f"$\\Delta_c = {delta_c:.0f}$",
    )

ax.set_xlabel("Redshift $z$")
ax.set_ylabel(r"Mean Overdensity $\Delta_m$")
ax.set_title("Critical to Mean Overdensity Conversion")
ax.set_xlim(0, 3)
ax.legend()
ax.xaxis.set_ticks_position("both")
ax.yaxis.set_ticks_position("both")
ax.grid(True, alpha=1.0)

# Show values at z=0 and z=1
print("Overdensity conversions:")
for delta_c in delta_c_values:
    delta_m_z0 = lss.overdensity_c_to_m(delta_c, 0.0, cosmo)
    delta_m_z1 = lss.overdensity_c_to_m(delta_c, 1.0, cosmo)
    print(
        f"Δc = {delta_c:.0f}:",
        f"Δm = {delta_m_z0:.0f} (z=0), {delta_m_z1:.0f} (z=1)",
    )
Overdensity conversions:
Δc = 200: Δm = 646 (z=0), 256 (z=1)
Δc = 500: Δm = 1615 (z=0), 639 (z=1)
Δc = 2500: Δm = 8074 (z=0), 3197 (z=1)
../_images/9e5d4679c55758552bdb9ae6c7dbd8423f640c2174c46fe4612d7c44c68eb450.svg
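The conversion can also be cross-checked by hand, since ρc(z)/ρm(z) = 1/Ω_m(z). The sketch below assumes a flat ΛCDM background with radiation neglected and uses hypothetical helper names; it is not the code behind `lss.overdensity_c_to_m`.

```python
def omega_m_of_z(z: float, omega_m0: float = 0.30964) -> float:
    """Matter density parameter at redshift z (flat LCDM, radiation neglected)."""
    matter = omega_m0 * (1.0 + z) ** 3
    return matter / (matter + 1.0 - omega_m0)


def delta_c_to_m(delta_c: float, z: float, omega_m0: float = 0.30964) -> float:
    """Convert a critical overdensity into a mean-matter overdensity."""
    # rho_c(z) / rho_m(z) = 1 / Omega_m(z)
    return delta_c / omega_m_of_z(z, omega_m0)


for delta_c in [200.0, 500.0, 2500.0]:
    dm_z0 = delta_c_to_m(delta_c, 0.0)
    dm_z1 = delta_c_to_m(delta_c, 1.0)
    print(f"Δc = {delta_c:.0f}: Δm = {dm_z0:.0f} (z=0), {dm_z1:.0f} (z=1)")
```

With the Planck 2018 value Ω_m = 0.30964, this reproduces the rounded values printed above. As Ω_m(z) approaches 1 at high redshift, Δm converges towards Δc, which is why the curves flatten with increasing z.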

Vectorization#

The LSS functions support vectorized operations, allowing efficient computation for multiple masses and redshifts:

# Vectorized computation for multiple redshifts
M_vec = jnp.logspace(10, 16, 50)  # Mass range
z_vec = jnp.array([0.0, 0.5, 1.0, 1.5, 2.0])

# Use vmap to vectorize over redshift
sigma_M_vec = jax.vmap(lss.sigma_M, in_axes=(None, 0, None))(
    M_vec, z_vec, cosmo
)

print(f"Computed σ(M,z) for {len(M_vec)} masses and {len(z_vec)} redshifts")
print(f"Result shape: {sigma_M_vec.shape}")

# Plot the results
fig, ax = plt.subplots()

for i, z_val in enumerate(z_vec):
    ax.loglog(
        M_vec,
        sigma_M_vec[i],
        linewidth=2,
        label=f"$z = {z_val:.1f}$",
        alpha=0.8,
    )

ax.set_xlabel(r"Mass $M$ [$h^{-1} \, M_\odot$]")
ax.set_ylabel(r"RMS Variance $\sigma(M,z)$")
ax.set_title("Vectorized σ(M,z) Calculation")
ax.set_xlim(5e9, 2e16)
ax.set_ylim(0.01, 10)
ax.legend()
ax.xaxis.set_ticks_position("both")
ax.yaxis.set_ticks_position("both")
ax.grid(True, alpha=1.0)
Computed σ(M,z) for 50 masses and 5 redshifts
Result shape: (5, 50)
../_images/568a2c4a58465fa9face6ef45f05b852358f2f37bbc5d8087114151079aee03a.svg
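The `in_axes` argument determines the output layout. The sketch below illustrates this with a hypothetical stand-in function, `toy_sigma`, which is not a halox function and has no physical meaning; only the `jax.vmap` mechanics carry over.

```python
import jax
import jax.numpy as jnp


def toy_sigma(M, z):
    """Hypothetical stand-in for sigma(M, z): power law in M, damped with z."""
    return (M / 1e13) ** -0.3 / (1.0 + z)


M_demo = jnp.logspace(10, 16, 50)
z_demo = jnp.array([0.0, 0.5, 1.0])

# in_axes=(None, 0): pass M unchanged to every call, map over axis 0 of z.
# The mapped axis becomes the leading axis of the output.
sigma_demo = jax.vmap(toy_sigma, in_axes=(None, 0))(M_demo, z_demo)
print(sigma_demo.shape)  # (3, 50)

# For a simple elementwise function, explicit broadcasting gives the same
# array; vmap is useful when the function is written for a scalar redshift.
sigma_bcast = toy_sigma(M_demo[None, :], z_demo[:, None])
```

Row `i` of the result corresponds to `z_demo[i]`, matching the `(5, 50)` shape reported above for five redshifts and fifty masses.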

Performance: JIT Compilation#

Let’s compare the performance with and without JIT compilation:

# Setup for timing
M_timing = jnp.logspace(10, 16, 100)
z_timing = 0.0


def compute_sigma(M, z, cosmo=cosmo):
    return lss.sigma_M(M, z, cosmo)

Without JIT compilation:

%timeit _ = compute_sigma(M_timing, z_timing).block_until_ready()
605 ms ± 9.12 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

With JIT (the function is called once before timing so that compilation is excluded from the measurement):

compute_sigma_jit = jax.jit(compute_sigma)
_ = compute_sigma_jit(M_timing, z_timing).block_until_ready()
%timeit _ = compute_sigma_jit(M_timing, z_timing).block_until_ready()
2.77 ms ± 108 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
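The same pattern (a warm-up call followed by timed calls, each ending in `block_until_ready()`) can be sketched outside of IPython with `time.perf_counter`. The workload `toy_integral` below is a hypothetical stand-in, not a halox function. Timings depend on hardware, so none are quoted here.

```python
import time

import jax
import jax.numpy as jnp


def toy_integral(x):
    """Hypothetical workload: a Riemann-sum integral for each element of x."""
    k = jnp.linspace(1e-3, 10.0, 10_000)
    integrand = jnp.exp(-((k[None, :] * x[:, None]) ** 2))
    return jnp.sum(integrand, axis=1) * (k[1] - k[0])


toy_jit = jax.jit(toy_integral)
x = jnp.linspace(0.1, 1.0, 100)

# First call: tracing + compilation + execution.
t0 = time.perf_counter()
toy_jit(x).block_until_ready()
t_first = time.perf_counter() - t0

# Later calls with the same input shapes and dtypes reuse the compiled code.
# block_until_ready() is needed because JAX dispatches work asynchronously.
t0 = time.perf_counter()
toy_jit(x).block_until_ready()
t_cached = time.perf_counter() - t0

print(f"first call: {t_first * 1e3:.2f} ms, cached call: {t_cached * 1e3:.2f} ms")
```

Calling the jitted function with a new input shape or dtype triggers a fresh compilation, so benchmarks should warm up with inputs of the same shape as those being timed.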