# Using JAX automatic differentiation
Because all functions in halox are implemented in terms of JAX primitive expressions (or rely on jax-cosmo functions, which are also fully JAX-friendly), they can be automatically differentiated with respect to their inputs.
This notebook gives a few practical examples of how this can be used, and will be extended with new examples as new features are added to halox.
import matplotlib.pyplot as plt
import jax
import jax.numpy as jnp
import halox
import jax_cosmo as jc
from matplotlib_inline.backend_inline import set_matplotlib_formats
# Enable 64-bit floats in JAX (32-bit is JAX's default). This should be set
# before any JAX arrays are created.
jax.config.update("jax_enable_x64", True)

# Figure styling: darkgrid base with the petroff10 colour cycle, inward-pointing
# ticks on both axes, and SVG output for the inline backend.
plt.style.use(["seaborn-v0_8-darkgrid", "petroff10"])
plt.rcParams["xtick.direction"] = "in"
plt.rcParams["ytick.direction"] = "in"
set_matplotlib_formats("svg")
/Users/fkeruzore/Software/halox/.venv/lib/python3.12/site-packages/jax_cosmo/__init__.py:2: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
from pkg_resources import DistributionNotFound
## Differentiating with respect to redshift
cosmo = halox.cosmology.Planck18()


def H(z):
    """Hubble parameter H(z) [km s^-1 Mpc^-1] for the module-level `cosmo`."""
    return halox.cosmology.hubble_parameter(z, cosmo)


# value_and_grad gives (H(z), dH/dz) for a scalar z; vmap maps it over the grid.
H_and_grad = jax.value_and_grad(H)
zs = jnp.linspace(0, 2, 16)
H_values_and_grads = jax.vmap(H_and_grad)(zs)
Hs, dHdzs = H_values_and_grads
# Two stacked panels over the same redshift grid: H(z) on top, dH/dz below.
# sharex=True links the x axes and lets matplotlib hide the top panel's tick
# labels, instead of blanking them by hand on an unlinked axis.
fig, axs = plt.subplots(2, 1, sharex=True)
axs[0].plot(zs, Hs, lw=2)
axs[0].set_ylabel(r"$H(z)$ [${\rm km} \, {\rm s}^{-1} \, {\rm Mpc}^{-1}$]")
axs[1].plot(zs, dHdzs, lw=2)
# The units string starts with a space so it is not glued to the math
# expression (same convention as the other axis labels in this notebook).
axs[1].set_ylabel(
    r"$\left( {\rm d}H / {\rm d}z \right) (z)$"
    + r" [${\rm km} \, {\rm s}^{-1} \, {\rm Mpc}^{-1}$]"
)
axs[1].set_xlabel("$z$")
for ax in axs:
    ax.xaxis.set_ticks_position("both")
    ax.yaxis.set_ticks_position("both")
    ax.grid(True, alpha=1.0)
fig.align_labels(axs)
fig.tight_layout()
def rho_c(z):
    """Critical density rho_c(z) for the module-level `cosmo`."""
    return halox.cosmology.critical_density(z, cosmo)


# value_and_grad gives (rho_c(z), d rho_c / dz) for a scalar z; vmap maps it
# over a grid of 16 redshifts in [0, 2].
rho_and_grad = jax.value_and_grad(rho_c)
zs = jnp.linspace(0, 2, 16)
rho_values_and_grads = jax.vmap(rho_and_grad)(zs)
rhos, drhodzs = rho_values_and_grads
# Two stacked panels over the same redshift grid: rho_c(z) on top, its
# derivative below. sharex=True links the x axes and hides the top panel's
# tick labels, instead of blanking them by hand on an unlinked axis.
fig, axs = plt.subplots(2, 1, sharex=True)
axs[0].plot(zs, rhos, lw=2)
axs[0].set_ylabel(r"$\rho_c(z)$ [$h^2 \, M_\odot \, {\rm Mpc}^{-3}$]")
axs[1].plot(zs, drhodzs, lw=2)
# The units string starts with a space so it is not glued to the math
# expression (same convention as the other axis labels in this notebook).
axs[1].set_ylabel(
    r"$\left( {\rm d}\rho_c / {\rm d}z \right) (z)$"
    + r" [$h^2 \, M_\odot \, {\rm Mpc}^{-3}$]"
)
axs[1].set_xlabel("$z$")
for ax in axs:
    ax.xaxis.set_ticks_position("both")
    ax.yaxis.set_ticks_position("both")
    ax.grid(True, alpha=1.0)
fig.align_labels(axs)
fig.tight_layout()
## Differentiating with respect to cosmology
def H_1(w0):
    """H(z=1) [km s^-1 Mpc^-1] for a jax-cosmo Planck15 cosmology with the given w0."""
    planck15_w0 = jc.Planck15(w0=w0)
    return halox.cosmology.hubble_parameter(1.0, planck15_w0)


# H(z=1) and dH(z=1)/dw0 at three values of w0, in one vectorised call.
H_and_grad = jax.value_and_grad(H_1)
w0s = jnp.array([-1.2, -1.1, -1.0])
Hs, dHdws = jax.vmap(H_and_grad)(w0s)

for line in (
    f"w_0 = {w0s}",
    f"H(z=1) = {Hs} [km s-1 Mpc-1]",
    f"dH_1/dw_0 = {dHdws} [km s-1 Mpc-1]",
):
    print(line)
w_0 = [-1.2 -1.1 -1. ]
H(z=1) = [115.6922973 117.76798759 120.27427285] [km s-1 Mpc-1]
dH_1/dw_0 = [18.84102128 22.78718246 27.4697143 ] [km s-1 Mpc-1]
def hmf_1e14_0p5(sigma8):
    """Tinker08 mass function at M200c = 1e14 h^-1 Msun and z = 0.5, versus sigma8.

    The cosmology is jax-cosmo's Planck15 with only ``sigma8`` overridden.
    """
    planck15_s8 = jc.Planck15(sigma8=sigma8)
    return halox.hmf.tinker08_mass_function(1e14, 0.5, planck15_s8, delta_c=200.0)


# Mass function and its sigma8-derivative at three values of sigma8.
hmf_and_grad = jax.value_and_grad(hmf_1e14_0p5)
sigma8s = jnp.array([0.8, 0.85, 0.9])
hmfs, dhmfdsigma8s = jax.vmap(hmf_and_grad)(sigma8s)

point = "M200c=1e14 h-1 Msun, z=0.5"
print(f"sigma_8 = {sigma8s}")
print(f"dN/dlnM({point}) = {hmfs} [h3 Mpc-3]")
print(f"d(dN/dlnM({point}))/dsigma_8 = {dhmfdsigma8s} [h3 Mpc-3]")
sigma_8 = [0.8 0.85 0.9 ]
dN/dlnM(M200c=1e14 h-1 Msun, z=0.5) = [1.62898095e-05 2.03306704e-05 2.43123380e-05] [h3 Mpc-3]
d(dN/dlnM(M200c=1e14 h-1 Msun, z=0.5))/dsigma_8 = [8.04095582e-05 8.06890205e-05 7.81862040e-05] [h3 Mpc-3]
## Autodiff vs. finite differences
The examples above use JAX’s automatic differentiation, which computes the derivatives of the implemented functions exactly (up to floating-point round-off). An alternative is to approximate derivatives with central finite differences:
\[\frac{\partial f}{\partial x} \approx \frac{f(x + \epsilon) - f(x - \epsilon)}{2\epsilon}\]
This is simple, but it introduces truncation error when \(\epsilon\) is large and floating-point round-off error when \(\epsilon\) is small. Below we compare both approaches on the Tinker08 halo mass function \(dn/d\ln M\).
cosmo = halox.cosmology.Planck18()


# --- Derivative with respect to mass ---
def hmf_of_M(M):
    """Tinker08 dn/dlnM at z = 0.5 (delta_c = 200) versus mass M [h^-1 Msun]."""
    return halox.hmf.tinker08_mass_function(M, 0.5, cosmo, delta_c=200.0)


dhmf_dM_auto = jax.vmap(jax.grad(hmf_of_M))
Ms = jnp.geomspace(1e14, 1e15, 16)
# block_until_ready waits for JAX's asynchronous dispatch to finish.
dhmf_dM_ad = dhmf_dM_auto(Ms).block_until_ready()

# Central difference with a fixed absolute step: eps_M is 10% of M at the
# low-mass end of the grid and 1% at the high-mass end.
eps_M = 1e13
dhmf_dM_fd = jax.vmap(
    lambda M: (hmf_of_M(M + eps_M) - hmf_of_M(M - eps_M)) / (2 * eps_M)
)(Ms).block_until_ready()


# --- Derivative with respect to redshift ---
def hmf_of_z(z):
    """Tinker08 dn/dlnM at M = 1e14 h^-1 Msun (delta_c = 200) versus redshift z."""
    return halox.hmf.tinker08_mass_function(1e14, z, cosmo, delta_c=200.0)


dhmf_dz_auto = jax.vmap(jax.grad(hmf_of_z))
zs = jnp.linspace(0.0, 2.0, 16)
dhmf_dz_ad = dhmf_dz_auto(zs).block_until_ready()

eps_z = 1e-2


def _fd_dhmf_dz(z):
    """Finite-difference d(dn/dlnM)/dz that never evaluates the HMF at z < 0.

    The grid starts at z = 0, where a central stencil would need the mass
    function at z = -eps_z, an unphysical negative redshift. The lower stencil
    point is therefore clamped at z = 0, which gives a forward difference at
    z = 0. For z >= eps_z this is exactly the usual central difference
    (f(z + eps) - f(z - eps)) / (2 eps).
    """
    step_lo = jnp.minimum(eps_z, z)  # distance down to the lower stencil point
    return (hmf_of_z(z + eps_z) - hmf_of_z(z - step_lo)) / (eps_z + step_lo)


dhmf_dz_fd = jax.vmap(_fd_dhmf_dz)(zs).block_until_ready()
fig, axs = plt.subplots(2, 1)
ax_M, ax_z = axs

# Autodiff (solid) vs finite differences (dashed), one panel per variable.
for ax, grid, d_ad, d_fd in (
    (ax_M, Ms, dhmf_dM_ad, dhmf_dM_fd),
    (ax_z, zs, dhmf_dz_ad, dhmf_dz_fd),
):
    ax.plot(grid, d_ad, lw=2, label="Autodiff")
    ax.plot(grid, d_fd, ls="--", lw=2, label="Finite differences")

# Top panel: derivative with respect to mass, at fixed redshift.
ax_M.set_xscale("log")
ax_M.set_xlabel(r"$M$ [$h^{-1} \, M_\odot$]")
ax_M.set_ylabel(
    r"$\partial (dn/d\ln M) / \partial M$"
    r" [$h^4 \, M_\odot^{-1} \, {\rm Mpc}^{-3}$]"
)
ax_M.set_title(r"$z = 0.5$")

# Bottom panel: derivative with respect to redshift, at fixed mass.
ax_z.set_xlabel(r"$z$")
ax_z.set_ylabel(
    r"$\partial (dn/d\ln M) / \partial z$"
    r" [$h^3 \, {\rm Mpc}^{-3}$]"
)
ax_z.set_title(r"$M = 10^{14} \, h^{-1} \, M_\odot$")

for ax in axs:
    ax.legend()
    ax.xaxis.set_ticks_position("both")
    ax.yaxis.set_ticks_position("both")
    ax.grid(True, alpha=1.0)
fig.align_labels(axs)
fig.tight_layout()
def _central_difference(f, eps):
    """Return x -> (f(x + eps) - f(x - eps)) / (2 * eps), a central-difference
    estimate of df/dx with step ``eps``."""
    return lambda x: (f(x + eps) - f(x - eps)) / (2 * eps)


# --- Derivative with respect to Omega_c ---
def hmf_of_Omega_c(Omega_c):
    """Tinker08 dn/dlnM at M = 1e14 h^-1 Msun, z = 0.5, versus Omega_c (jax-cosmo Planck15)."""
    return halox.hmf.tinker08_mass_function(
        1e14, 0.5, jc.Planck15(Omega_c=Omega_c), delta_c=200.0
    )


Omega_cs = jnp.linspace(0.20, 0.25, 16)
dhmf_dOc_auto = jax.vmap(jax.grad(hmf_of_Omega_c))
dhmf_dOc_ad = dhmf_dOc_auto(Omega_cs)
eps_Oc = 1e-2
dhmf_dOc_fd = jax.vmap(_central_difference(hmf_of_Omega_c, eps_Oc))(Omega_cs)


# --- Derivative with respect to sigma_8 ---
def hmf_of_sigma8(sigma8):
    """Tinker08 dn/dlnM at M = 1e14 h^-1 Msun, z = 0.5, versus sigma8 (jax-cosmo Planck15)."""
    return halox.hmf.tinker08_mass_function(
        1e14, 0.5, jc.Planck15(sigma8=sigma8), delta_c=200.0
    )


sigma8s = jnp.linspace(0.7, 1.0, 16)
dhmf_ds8_auto = jax.vmap(jax.grad(hmf_of_sigma8))
dhmf_ds8_ad = dhmf_ds8_auto(sigma8s)
eps_s8 = 1e-2
dhmf_ds8_fd = jax.vmap(_central_difference(hmf_of_sigma8, eps_s8))(sigma8s)
fig, axs = plt.subplots(2, 1)
ax_Oc, ax_s8 = axs

# Both panels are evaluated at the same halo mass and redshift.
fixed_point_title = r"$M = 10^{14} \, h^{-1} \, M_\odot, \; z = 0.5$"

# Autodiff (solid) vs finite differences (dashed), one panel per parameter.
for ax, grid, d_ad, d_fd in (
    (ax_Oc, Omega_cs, dhmf_dOc_ad, dhmf_dOc_fd),
    (ax_s8, sigma8s, dhmf_ds8_ad, dhmf_ds8_fd),
):
    ax.plot(grid, d_ad, lw=2, label="Autodiff")
    ax.plot(grid, d_fd, ls="--", lw=2, label="Finite differences")

ax_Oc.set_xlabel(r"$\Omega_c$")
ax_Oc.set_ylabel(
    r"$\partial (dn/d\ln M) / \partial \Omega_c$"
    r" [$h^3 \, {\rm Mpc}^{-3}$]"
)

ax_s8.set_xlabel(r"$\sigma_8$")
ax_s8.set_ylabel(
    r"$\partial (dn/d\ln M) / \partial \sigma_8$"
    r" [$h^3 \, {\rm Mpc}^{-3}$]"
)

for ax in axs:
    ax.set_title(fixed_point_title)
    ax.legend()
    ax.xaxis.set_ticks_position("both")
    ax.yaxis.set_ticks_position("both")
    ax.grid(True, alpha=1.0)
fig.align_labels(axs)
fig.tight_layout()