"""Source code for halox.halo.nfw: Navarro-Frenk-White (NFW) halo profiles."""

from jax import Array
import jax
from jax.typing import ArrayLike
import jax.numpy as jnp
import jax_cosmo as jc
from jaxopt import LBFGSB

from ..cosmology import G
from .. import cosmology


class NFWHalo:
    """
    Properties of a dark matter halo following a Navarro-Frenk-White
    density profile.

    Parameters
    ----------
    m_delta: float
        Mass at overdensity `delta` [h-1 Msun]
    c_delta: float
        Concentration at overdensity `delta`
    z: float
        Redshift
    cosmo: jc.Cosmology
        Underlying cosmology
    delta: float
        Density contrast in units of critical density at redshift z,
        defaults to 200.
    """

    def __init__(
        self,
        m_delta: ArrayLike,
        c_delta: ArrayLike,
        z: ArrayLike,
        cosmo: jc.Cosmology,
        delta: float = 200.0,
    ):
        self.m_delta = jnp.asarray(m_delta)
        self.c_delta = jnp.asarray(c_delta)
        self.z = jnp.asarray(z)
        self.delta = delta
        self.cosmo = cosmo

        # The mean density inside r_delta is delta times rho_crit(z).
        mean_rho = delta * cosmology.critical_density(self.z, cosmo)
        self.r_delta = (3 * self.m_delta / (4 * jnp.pi * mean_rho)) ** (1 / 3)
        self.Rs = self.r_delta / self.c_delta
        # Normalisation chosen so that M(<r_delta) == m_delta.
        self.rho0 = self.m_delta / (
            4 * jnp.pi * self.Rs**3 * self._mu(self.c_delta)
        )

    @staticmethod
    def _mu(x: ArrayLike) -> Array:
        """Dimensionless NFW mass function :math:`\\ln(1+x) - x/(1+x)`."""
        x = jnp.asarray(x)
        return jnp.log1p(x) - x / (1 + x)

    def density(self, r: ArrayLike) -> Array:
        """NFW density profile :math:`\\rho(r)`.

        Parameters
        ----------
        r : Array [h-1 Mpc]
            Radius

        Returns
        -------
        Array [h2 Msun Mpc-3]
            Density at radius `r`
        """
        r = jnp.asarray(r)
        return self.rho0 / (r / self.Rs * (1 + r / self.Rs) ** 2)

    def enclosed_mass(self, r: ArrayLike) -> Array:
        """Enclosed mass profile :math:`M(<r)`.

        Parameters
        ----------
        r : Array [h-1 Mpc]
            Radius

        Returns
        -------
        Array [h-1 Msun]
            Enclosed mass at radius `r`
        """
        r = jnp.asarray(r)
        prefact = 4 * jnp.pi * self.rho0 * self.Rs**3
        return prefact * self._mu(r / self.Rs)

    def potential(self, r: ArrayLike) -> Array:
        """Potential profile :math:`\\phi(r)`.

        :math:`\\phi(r) = -4 \\pi G \\rho_0 R_s^2 \\ln(1+x)/x` with
        :math:`x = r/R_s`. At ``r = 0`` the finite central value
        :math:`-4 \\pi G \\rho_0 R_s^2` is returned instead of NaN.

        Parameters
        ----------
        r : Array [h-1 Mpc]
            Radius

        Returns
        -------
        Array [km2 s-2]
            Potential at radius `r`
        """
        r = jnp.asarray(r)
        x = r / self.Rs
        nonzero = x != 0
        # Keep the unused branch finite so gradients at r = 0 are not NaN.
        x_safe = jnp.where(nonzero, x, 1.0)
        log_ratio = jnp.where(nonzero, jnp.log1p(x_safe) / x_safe, 1.0)
        return -4 * jnp.pi * G * self.rho0 * self.Rs**2 * log_ratio

    def circular_velocity(self, r: ArrayLike) -> Array:
        """Circular velocity profile :math:`v_c(r)`.

        The circular velocity is related to the enclosed mass by:
        :math:`v_c^2(r) = GM(<r)/r`

        Parameters
        ----------
        r : Array [h-1 Mpc]
            Radius

        Returns
        -------
        Array [km s-1]
            Circular velocity at radius `r`
        """
        r = jnp.asarray(r)
        m_enc = self.enclosed_mass(r)
        return jnp.sqrt(G * m_enc / r)

    def velocity_dispersion(self, r: ArrayLike) -> Array:
        """Radial velocity dispersion profile :math:`\\sigma_r(r)`.

        Solves the Jeans equation assuming isotropic orbits:
        :math:`\\sigma_r^2(r) = \\frac{1}{\\rho(r)} \\int_r^{\\infty}
        \\rho(s) \\frac{GM(<s)}{s^2} ds`

        For an NFW profile, with :math:`x = r/R_s` and
        :math:`\\mu(y) = \\ln(1+y) - y/(1+y)`, this becomes
        :math:`\\sigma_r^2 = 4 \\pi G \\rho_0 R_s^2 \\, x (1+x)^2
        \\int_x^{\\infty} \\frac{\\mu(y)}{y^3 (1+y)^2} dy`,
        whose integral has the closed form given by Lokas & Mamon (2001)
        for the isotropic case (it involves the dilogarithm
        :math:`\\mathrm{Li}_2(-x)`).

        The closed form is a sum of terms much larger than the result when
        ``x >> 1``, so precision degrades at large radii in single
        precision; enable 64-bit mode (``jax_enable_x64``) if accurate
        values well beyond the scale radius are needed.

        Parameters
        ----------
        r : Array [h-1 Mpc]
            Radius

        Returns
        -------
        Array [km s-1]
            Radial velocity dispersion at radius `r` (0 at ``r = 0``)
        """
        from jax.scipy.special import spence

        r = jnp.asarray(r)
        x = r / self.Rs
        positive = x > 0
        # Evaluate the closed form at a harmless point where x == 0 so the
        # log / reciprocal terms stay finite (also for gradients).
        xs = jnp.where(positive, x, 1.0)
        one_plus_x = 1.0 + xs
        log1p_x = jnp.log1p(xs)
        # scipy-convention spence(z) = Li2(1 - z), hence Li2(-x) = spence(1 + x).
        li2_minus_x = spence(one_plus_x)

        # Twice the Jeans integral  int_x^inf mu(y) / (y^3 (1+y)^2) dy.
        bracket = (
            jnp.pi**2
            - jnp.log(xs)
            - 1.0 / xs
            - 1.0 / one_plus_x**2
            - 6.0 / one_plus_x
            + (1.0 + 1.0 / xs**2 - 4.0 / xs - 2.0 / one_plus_x) * log1p_x
            + 3.0 * log1p_x**2
            + 6.0 * li2_minus_x
        )
        sigma_r2 = (
            2.0 * jnp.pi * G * self.rho0 * self.Rs**2
            * xs * one_plus_x**2 * bracket
        )
        # sigma_r^2 ~ x ln(1/x) -> 0 at the centre.
        sigma_r2 = jnp.where(positive, sigma_r2, 0.0)
        return jnp.sqrt(sigma_r2)

    def surface_density(self, r: ArrayLike) -> Array:
        """Projected surface density profile :math:`\\Sigma(r)`.

        The projected surface density is obtained by integrating
        the 3D density profile along the line of sight:
        :math:`\\Sigma(r) = 2 \\int_r^{\\infty} \\frac{\\rho(s) s ds}
        {\\sqrt{s^2 - r^2}}`

        For NFW halos, this has an analytical solution
        (Bartelmann 1996), with separate expressions for
        :math:`x < 1`, :math:`x > 1` and :math:`x = 1`.

        Parameters
        ----------
        r : Array [h-1 Mpc]
            Projected radius

        Returns
        -------
        Array [h Msun Mpc-2]
            Surface density at projected radius `r`
        """
        r = jnp.asarray(r)
        x = r / self.Rs
        prefact = 2 * self.rho0 * self.Rs

        inner = x < 1.0
        outer = x > 1.0
        # jnp.where evaluates every branch. Feed each branch only values
        # inside its own domain so the discarded branch cannot inject NaNs
        # into gradients.
        x_in = jnp.where(inner, x, 0.5)
        x_out = jnp.where(outer, x, 2.0)

        f_inner = (
            1
            - 2
            * jnp.arctanh(jnp.sqrt((1 - x_in) / (1 + x_in)))
            / jnp.sqrt(1 - x_in**2)
        ) / (x_in**2 - 1)
        f_outer = (
            1
            - 2
            * jnp.arctan(jnp.sqrt((x_out - 1) / (1 + x_out)))
            / jnp.sqrt(x_out**2 - 1)
        ) / (x_out**2 - 1)

        # x == 1 is the analytic limit 1/3.
        f = jnp.where(inner, f_inner, jnp.where(outer, f_outer, 1.0 / 3.0))
        return prefact * f

    def to_delta(self, delta_new: float) -> tuple[Array, Array, Array]:
        """Convert halo properties to a different overdensity definition.

        Finds the radius at which the mean enclosed density equals
        ``delta_new`` times the critical density at the halo redshift.

        Parameters
        ----------
        delta_new : float
            New density contrast in units of critical density at redshift z

        Returns
        -------
        Array [h-1 Msun]
            Mass at new overdensity
        Array [h-1 Mpc]
            Radius at new overdensity
        Array
            Concentration at new overdensity
        """
        # Target density for the new overdensity definition
        rho_c = cosmology.critical_density(self.z, self.cosmo)
        target_density = delta_new * rho_c

        def lsq(r_new):
            m_enc = self.enclosed_mass(r_new[0])
            mean_density = m_enc / (4.0 * jnp.pi * r_new[0] ** 3 / 3.0)
            # Dividing by target_density makes the objective dimensionless,
            # which keeps its scale sensible for the optimizer.
            return ((mean_density - target_density) / target_density) ** 2

        # Initial guess from the r ~ delta^(-1/3) scaling at fixed mass
        r0 = jnp.array([self.r_delta * (self.delta / delta_new) ** (1 / 3)])

        # Search window around the current r_delta
        lower = jnp.array([0.01 * self.r_delta])
        upper = jnp.array([10.0 * self.r_delta])
        bounds = (lower, upper)

        optimizer = LBFGSB(fun=lsq, tol=1e-12)
        result = optimizer.run(r0, bounds=bounds)
        r_new = result.params[0]

        # New mass and concentration (the scale radius is unchanged)
        m_new = self.enclosed_mass(r_new)
        c_new = r_new / self.Rs

        return m_new, r_new, c_new
def delta_delta(
    M: ArrayLike,
    c: ArrayLike,
    z: ArrayLike,
    cosmo: jc.Cosmology,
    delta_old: float,
    delta_new: float,
) -> tuple[Array, Array, Array]:
    """Convert between overdensity definitions assuming an NFW profile.

    ``M``, ``c`` and ``z`` are broadcast against each other, so e.g. an
    array of masses can be combined with a single redshift. Size-1
    dimensions are squeezed from the outputs, as before.

    Parameters
    ----------
    M : Array
        Halo mass at ``delta_old`` overdensity [h-1 Msun]
    c : Array
        Concentration at ``delta_old`` overdensity
    z : Array
        Redshift
    cosmo : jc.Cosmology
        Underlying cosmology
    delta_old : float
        Input overdensity in units of critical density at redshift z
    delta_new : float
        Output overdensity in units of critical density at redshift z

    Returns
    -------
    Array [h-1 Msun]
        Halo mass at ``delta_new`` overdensity
    Array [h-1 Mpc]
        Halo radius at ``delta_new`` overdensity
    Array
        Concentration at ``delta_new`` overdensity
    """
    M, c, z = jnp.broadcast_arrays(
        jnp.atleast_1d(M), jnp.atleast_1d(c), jnp.atleast_1d(z)
    )
    shape = M.shape

    def single_halo(Mi, ci, zi):
        halo = NFWHalo(Mi, ci, zi, cosmo, delta_old)
        return halo.to_delta(delta_new)

    # One bounded optimization per halo, vectorized over a flat halo index.
    # Simple rather than fastest; a batched root-finder could replace it.
    M_new, R_new, c_new = jax.vmap(single_halo)(
        M.ravel(), c.ravel(), z.ravel()
    )
    return (
        jnp.squeeze(M_new.reshape(shape)),
        jnp.squeeze(R_new.reshape(shape)),
        jnp.squeeze(c_new.reshape(shape)),
    )