20  Cosmological Simulations II: The PM Algorithm

20.1 Recap: The Equations of Motion

In the previous lecture, we started from the collisionless Boltzmann (Vlasov) equation coupled to the Poisson equation (the Vlasov–Poisson system), which governs the evolution of collisionless dark matter in comoving coordinates \(\mathbf{x}\) (with \(\mathbf{r} = a(t)\mathbf{x}\)). Since the 6D phase space is too large to simulate directly, we Monte-Carlo sample the distribution function with \(N\) particles in a periodic box of comoving side \(L\), each carrying a mass \[ m_p = \Omega_m \rho_{\text{crit},0} \frac{L^3}{N} \] where \(\rho_{\text{crit},0} = 3H_0^2/(8\pi G)\) is the present-day critical density, so that \(\Omega_m \rho_{\text{crit},0}\) is the mean comoving matter density.

We then derived the equations of motion for each particle. Starting from the full Lagrangian \(L = \frac{1}{2}m|\dot{a}\mathbf{x} + a\dot{\mathbf{x}}|^2 - m\varphi\), we decomposed the gravitational potential as \(\varphi = \varphi_{\text{bg}} + \Phi\) where \(\varphi_{\text{bg}} = \frac{2\pi}{3}G\bar{\rho}a^2|\mathbf{x}|^2\). The background and cross terms cancel exactly (using the Friedmann equation \(\ddot{a}/a = -4\pi G \bar{\rho}/3\)), leaving the peculiar Lagrangian: \[ \mathcal{L}_{\text{pec}} = \frac{1}{2} m a^2 \dot{\mathbf{x}}^2 - m \Phi(\mathbf{x}, t) \] where \(\Phi\) is the peculiar potential satisfying the Poisson equation: \[ \nabla^2 \Phi = 4\pi G \bar{\rho} a^2 \delta \]

The conjugate momentum is \(\mathbf{p} = ma^2\dot{\mathbf{x}}\), and the Euler–Lagrange equation gives \[ \ddot{\mathbf{x}} + 2\frac{\dot{a}}{a}\dot{\mathbf{x}} = -\frac{1}{a^2}\nabla_x \Phi \] with the familiar Hubble drag term \(2H\dot{\mathbf{x}}\).

We also saw the total derivative trick: adding \(d/dt(\frac{1}{2}ma\dot{a}|\mathbf{x}|^2)\) to \(\mathcal{L}_{\text{pec}}\) gives \(\mathcal{L}'\) with conjugate momentum \(\mathbf{p}' = ma\dot{\mathbf{r}}\) (proportional to the total velocity), and the simpler equations \(d\mathbf{x}/dt = \mathbf{p}'/(ma^2)\), \(d\mathbf{p}'/dt = -m\nabla_x\Phi\) — no Hubble drag, but the momentum now includes the Hubble flow.

The task is now to integrate these equations efficiently and stably over cosmological timescales — which brings us to symplectic methods.

20.2 Symplectic Integration

Why Symplectic?

The equations of motion derived from \(\mathcal{L}_{\text{pec}}\) are Hamiltonian, and Hamiltonian systems have a fundamental geometric property: they preserve the symplectic form on phase space. Loosely speaking, this means that the flow of trajectories preserves phase-space volume (Liouville’s theorem) and the area of any canonical \((x_i, p_i)\) pair.

A generic ODE integrator (e.g., Runge–Kutta) does not respect this structure. Over many timesteps, such integrators introduce systematic drift in conserved quantities: the energy error accumulates step after step instead of staying bounded. For cosmological simulations, where we may take \(10^3\) to \(10^4\) timesteps, this secular drift is unacceptable.

Symplectic integrators are constructed to exactly preserve the symplectic structure at each step. The key consequence: while the energy at each step has an \(\mathcal{O}(\Delta t^n)\) error (where \(n\) is the order of the integrator), this error is bounded — it oscillates rather than drifts. There is no secular growth.

Important: No Conserved Energy in an Expanding Universe

In cosmology, the Hamiltonian is explicitly time-dependent (through \(a(t)\)), so there is no conserved energy. Symplecticity is still valuable, however: it preserves phase-space structure, prevents artificial dissipation or heating of particle orbits, and maintains the correct topology of trajectories. The composition of symplectomorphisms is itself a symplectomorphism — this is the mathematical property that makes the whole approach work.

For testing purposes, there exists a Layzer–Irvine equation that predicts \(dE/dt\) given the expansion history — one can integrate this alongside the simulation as a consistency check. We will not derive it here.

The Splitting Method

Consider a Hamiltonian of the form \[ H = A(\mathbf{p}) + B(\mathbf{x}) \] Each piece alone generates a trivially solvable flow:

  • \(H = A(\mathbf{p})\): momenta are constant, positions evolve as \(\dot{\mathbf{x}} = \partial A / \partial \mathbf{p}\)
  • \(H = B(\mathbf{x})\): positions are constant, momenta evolve as \(\dot{\mathbf{p}} = -\partial B / \partial \mathbf{x}\)

The idea of symplectic splitting is to approximate the full evolution by alternating these sub-steps. Each sub-step is an exact symplectomorphism (it solves a Hamiltonian system exactly), and the composition of symplectomorphisms is a symplectomorphism. So the overall integrator is automatically symplectic, regardless of the step size.

Note: Splitting and the Baker–Campbell–Hausdorff Formula

The splitting approximation can be understood through the BCH formula. For operators \(\hat{A}\) and \(\hat{B}\), \[ e^{\hat{A}\Delta t} \, e^{\hat{B}\Delta t} = e^{(\hat{A}+\hat{B})\Delta t + \frac{1}{2}[\hat{A},\hat{B}]\Delta t^2 + \cdots} \] The leading error is \(\mathcal{O}(\Delta t^2)\) per step, giving a first-order integrator. By symmetrizing (a half-step of one operator, a full step of the other, and another half-step of the first), the \(\Delta t^2\) commutator term cancels, as do all higher even powers of \(\Delta t\) in the exponent. The per-step error drops to \(\mathcal{O}(\Delta t^3)\), yielding second-order accuracy.

KDK Leapfrog

The most common choice for cosmological simulations is the Kick-Drift-Kick (KDK) leapfrog:

  1. Kick (half step): update momenta by a half timestep using the force
  2. Drift (full step): update positions by a full timestep using the (now updated) momenta
  3. Kick (half step): update momenta by another half timestep using the force at the new positions

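This is a second-order symplectic integrator: the local error in positions and momenta is \(\mathcal{O}(\Delta t^3)\) per step, so the accumulated (global) error over a fixed time interval is \(\mathcal{O}(\Delta t^2)\).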
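To make the bounded-versus-drifting energy error concrete, here is a minimal sketch that applies the KDK scheme to a toy problem. The harmonic oscillator, the unit mass, the step size, and the forward-Euler comparison are all illustrative choices, not part of the cosmological setup.

```python
import numpy as np

def kdk_step(x, p, dt, force):
    """One Kick-Drift-Kick step for H = p^2/2 + V(x) (unit mass)."""
    p = p + 0.5 * dt * force(x)   # half kick with the force at the old position
    x = x + dt * p                # full drift with the half-step momentum
    p = p + 0.5 * dt * force(x)   # half kick with the force at the new position
    return x, p

def euler_step(x, p, dt, force):
    """Forward Euler step, used here as a non-symplectic reference."""
    return x + dt * p, p + dt * force(x)

def energy(x, p):
    return 0.5 * p**2 + 0.5 * x**2

def force(x):
    return -x                     # harmonic oscillator, V(x) = x^2 / 2

dt, nsteps = 0.1, 1000
state_kdk = (1.0, 0.0)
state_euler = (1.0, 0.0)
e0 = energy(*state_kdk)
err_kdk = np.empty(nsteps)
err_euler = np.empty(nsteps)
for i in range(nsteps):
    state_kdk = kdk_step(*state_kdk, dt, force)
    state_euler = euler_step(*state_euler, dt, force)
    err_kdk[i] = abs(energy(*state_kdk) - e0)
    err_euler[i] = abs(energy(*state_euler) - e0)

print(f"max |dE| leapfrog: {err_kdk.max():.2e}")
print(f"final |dE| Euler:  {err_euler[-1]:.2e}")
```

For this toy problem the leapfrog energy error oscillates at the level of \(\Delta t^2\), while the Euler error grows without bound as more steps are taken.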

Note: Why KDK Instead of DKD?

The alternative Drift-Kick-Drift (DKD) ordering is equally valid and has the same formal accuracy. KDK is preferred in practice because after each full step, the positions and momenta are synchronized at the same time — making it easy to output snapshots and diagnostics. In DKD, positions lead momenta by a half step, requiring an extra synchronization step for output.

20.3 Equations of Motion in \(\ln a\)

The peculiar Lagrangian uses coordinate time \(t\), which leads to explicit time-dependence through \(a(t)\) and introduces Hubble drag terms in the equation of motion. We can eliminate this by choosing a better time variable.

Changing the Time Variable

Define \[ \eta \equiv \ln a \] so that \(d\eta = da/a = (da/dt)/a \, dt = H \, dt\), or equivalently, \(dt = d\eta / H\). This is a monotonic reparametrization of time (as long as \(H > 0\), i.e., the universe is expanding).

We now re-derive the Lagrangian, conjugate momentum, and Hamiltonian in this new time variable.

The \(\eta\)-Lagrangian

The action is invariant under reparametrization of time. Starting from \[ S = \int \left[ \frac{1}{2} m a^2 \dot{\mathbf{x}}^2 - m\Phi \right] dt \] we substitute \(dt = d\eta/H\) and \(\dot{\mathbf{x}} = (d\mathbf{x}/d\eta)(d\eta/dt) = H \mathbf{x}'\) where primes denote \(d/d\eta\): \[ S = \int \left[ \frac{1}{2} m a^2 H^2 \mathbf{x}'^2 - m\Phi \right] \frac{d\eta}{H} = \int \left[ \frac{1}{2} m a^2 H \, \mathbf{x}'^2 - \frac{m\Phi}{H} \right] d\eta \]

Reading off the \(\eta\)-Lagrangian: \[ \tilde{\mathcal{L}} = \frac{1}{2} m a^2 H \, \mathbf{x}'^2 - \frac{m\Phi}{H} \]

Conjugate Momentum

The conjugate momentum with respect to \(\eta\) is \[ \mathbf{p} = \frac{\partial \tilde{\mathcal{L}}}{\partial \mathbf{x}'} = m a^2 H \, \mathbf{x}' \] Since \(\mathbf{x}' = d\mathbf{x}/d\eta = \dot{\mathbf{x}}/H\), this gives \[ \mathbf{p} = m a^2 \dot{\mathbf{x}} \] This is the same conjugate momentum as in the original (\(t\)-time) Lagrangian — it is proportional to the peculiar velocity \(\mathbf{v}_{\text{pec}} = a \dot{\mathbf{x}}\): \[ \mathbf{p} = m a \, \mathbf{v}_{\text{pec}} \]

Note: Connection to Redshift-Space Distortions

This momentum is convenient for computing redshift-space positions. If \(\hat{\mathbf{n}}\) is the line-of-sight direction, the redshift-space coordinate is \[ \mathbf{s} = \mathbf{x} + \hat{\mathbf{n}} \frac{p_\parallel}{m a^2 H} = \mathbf{x} + \hat{\mathbf{n}} \frac{v_{\text{pec},\parallel}}{a H} \] Having \(\mathbf{p} = m a^2 \dot{\mathbf{x}}\) directly available from the integrator makes this trivial to evaluate.
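The mapping above can be sketched in a few lines. The function name, the array layout (one row per particle), the plane-parallel line of sight, and the optional periodic wrap are assumptions made for illustration.

```python
import numpy as np

def redshift_space_positions(x, p, a, H, m=1.0, los=(0.0, 0.0, 1.0), boxsize=None):
    """Map comoving positions x (N, 3) to redshift space using p = m a^2 dx/dt.

    Assumes a single, fixed line-of-sight direction (plane-parallel limit).
    """
    n_hat = np.asarray(los, dtype=float)
    n_hat = n_hat / np.linalg.norm(n_hat)
    p_par = p @ n_hat                                   # line-of-sight momentum
    s = x + np.outer(p_par / (m * a**2 * H), n_hat)     # shift along n_hat only
    if boxsize is not None:
        s = np.mod(s, boxsize)                          # periodic wrap
    return s

x = np.array([[10.0, 20.0, 30.0]])
p = np.array([[0.5, 0.0, 2.0]])
s = redshift_space_positions(x, p, a=0.5, H=1.0)
print(s)
```

Only the component of the momentum along the line of sight enters; the transverse coordinates are unchanged.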

The Hamiltonian

To construct the Hamiltonian, we perform the Legendre transform: \[ \tilde{H} = \mathbf{p} \cdot \mathbf{x}' - \tilde{\mathcal{L}} \] First, express \(\mathbf{x}'\) in terms of \(\mathbf{p}\): \[ \mathbf{x}' = \frac{\mathbf{p}}{m a^2 H} \] Then: \[ \mathbf{p} \cdot \mathbf{x}' = \frac{\mathbf{p}^2}{m a^2 H} \] \[ \tilde{\mathcal{L}} = \frac{1}{2} \frac{\mathbf{p}^2}{m a^2 H} - \frac{m\Phi}{H} \] Therefore: \[ \tilde{H} = \frac{\mathbf{p}^2}{m a^2 H} - \frac{1}{2}\frac{\mathbf{p}^2}{m a^2 H} + \frac{m\Phi}{H} \] \[ \boxed{\tilde{H} = \underbrace{\frac{\mathbf{p}^2}{2 m a^2 H}}_{\displaystyle A(\mathbf{p}, \eta)} + \underbrace{\frac{m\Phi(\mathbf{x})}{H}}_{\displaystyle B(\mathbf{x}, \eta)}} \]

This is the key result: the Hamiltonian is separable into a kinetic term \(A\) depending only on \(\mathbf{p}\) (and \(\eta\)) and a potential term \(B\) depending only on \(\mathbf{x}\) (and \(\eta\)). This separability is exactly what we need for the symplectic splitting to work.

Important: Separability is the Key

The explicit time-dependence through \(H(\eta)\) and \(a(\eta)\) does not spoil the splitting — what matters is that for a fixed value of \(\eta\), the Hamiltonian separates into a function of \(\mathbf{p}\) alone plus a function of \(\mathbf{x}\) alone. The time-dependence of the coefficients simply means that the drift and kick factors vary from step to step.

Hamilton’s Equations

From \(\tilde{H}\), Hamilton’s equations in the \(\eta\) variable are: \[ \frac{d\mathbf{x}}{d\eta} = \frac{\partial \tilde{H}}{\partial \mathbf{p}} = \frac{\mathbf{p}}{m a^2 H} \] \[ \frac{d\mathbf{p}}{d\eta} = -\frac{\partial \tilde{H}}{\partial \mathbf{x}} = -\frac{m \nabla\Phi}{H} \]

These equations have no Hubble drag term — the friction that appeared in the coordinate-time formulation has been absorbed into the choice of time variable and conjugate momentum.

Drift and Kick Factors

In the KDK scheme, the drift and kick sub-steps require integrating Hamilton’s equations with either \(B = 0\) (drift) or \(A = 0\) (kick) over an interval in \(\eta\). Since \(A\) and \(B\) have explicit \(\eta\)-dependence through \(a\) and \(H\), these sub-steps involve integrals.

Drift step (from \(a_1\) to \(a_2\), with \(\mathbf{p}\) fixed): \[ \Delta \mathbf{x} = \frac{\mathbf{p}}{m} \int_{a_1}^{a_2} \frac{d\eta}{a^2 H} = \frac{\mathbf{p}}{m} \int_{a_1}^{a_2} \frac{da}{a^3 H(a)} \] where we used \(d\eta = da/a\). We define the drift factor: \[ \boxed{D(a_1, a_2) \equiv \int_{a_1}^{a_2} \frac{da}{a^3 H(a)}} \]

Kick step (from \(a_1\) to \(a_2\), with \(\mathbf{x}\) fixed): \[ \Delta \mathbf{p} = -m \nabla\Phi \int_{a_1}^{a_2} \frac{d\eta}{H} = -m \nabla\Phi \int_{a_1}^{a_2} \frac{da}{a H(a)} \] We define the kick factor: \[ \boxed{K(a_1, a_2) \equiv \int_{a_1}^{a_2} \frac{da}{a H(a)}} \]

Note: Why Integrals over \(a\) Rather Than \(\eta\)?

We converted from \(d\eta\) to \(da/a\) for practical reasons: the Hubble parameter \(H(a)\) is a known function of \(a\) (from the Friedmann equation), making the integrands explicit. In the \(\eta\) variable, one would need \(H\) as a function of \(\eta = \ln a\), which is the same information in a slightly less convenient form.

Einstein–de Sitter Case

For an Einstein–de Sitter (EdS) universe (\(\Omega_m = 1\), \(\Omega_\Lambda = 0\)), the Friedmann equation gives \[ H(a) = H_0 \, a^{-3/2} \]

The drift and kick factors then have closed-form expressions.

Drift factor: \[ D(a_1, a_2) = \int_{a_1}^{a_2} \frac{da}{a^3 \cdot H_0 a^{-3/2}} = \frac{1}{H_0} \int_{a_1}^{a_2} \frac{da}{a^{3/2}} = \frac{1}{H_0} \left[ -\frac{2}{\sqrt{a}} \right]_{a_1}^{a_2} \] \[ D_{\text{EdS}}(a_1, a_2) = \frac{2}{H_0} \left( \frac{1}{\sqrt{a_1}} - \frac{1}{\sqrt{a_2}} \right) \]

Kick factor: \[ K(a_1, a_2) = \int_{a_1}^{a_2} \frac{da}{a \cdot H_0 a^{-3/2}} = \frac{1}{H_0} \int_{a_1}^{a_2} a^{1/2} \, da = \frac{1}{H_0} \left[ \frac{2}{3} a^{3/2} \right]_{a_1}^{a_2} \] \[ K_{\text{EdS}}(a_1, a_2) = \frac{2}{3 H_0} \left( a_2^{3/2} - a_1^{3/2} \right) \]
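As a sanity check on the closed forms, the sketch below compares them with a direct numerical evaluation of the defining integrals. The choice of units (\(H_0 = 1\)), the Gauss-Legendre quadrature, and the function names are assumptions for illustration.

```python
import numpy as np

H0 = 1.0  # units in which H0 = 1 (illustrative)

def hubble_eds(a):
    return H0 * a**-1.5

def gauss_legendre(f, a1, a2, n=32):
    """Integrate f over [a1, a2] with n-point Gauss-Legendre quadrature."""
    nodes, weights = np.polynomial.legendre.leggauss(n)
    a = 0.5 * (a2 - a1) * nodes + 0.5 * (a2 + a1)
    return 0.5 * (a2 - a1) * np.sum(weights * f(a))

def drift_factor(a1, a2, hubble):
    return gauss_legendre(lambda a: 1.0 / (a**3 * hubble(a)), a1, a2)

def kick_factor(a1, a2, hubble):
    return gauss_legendre(lambda a: 1.0 / (a * hubble(a)), a1, a2)

def drift_eds(a1, a2):
    return 2.0 / H0 * (1.0 / np.sqrt(a1) - 1.0 / np.sqrt(a2))

def kick_eds(a1, a2):
    return 2.0 / (3.0 * H0) * (a2**1.5 - a1**1.5)

a1, a2 = 0.10, 0.12
D_num, D_exact = drift_factor(a1, a2, hubble_eds), drift_eds(a1, a2)
K_num, K_exact = kick_factor(a1, a2, hubble_eds), kick_eds(a1, a2)
print(f"drift: quadrature {D_num:.10f}, closed form {D_exact:.10f}")
print(f"kick:  quadrature {K_num:.10f}, closed form {K_exact:.10f}")
```

The same `drift_factor` and `kick_factor` accept any `hubble(a)` callable, so they can be reused unchanged once the expansion history is no longer EdS.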

General \(\Lambda\)CDM

For a general flat \(\Lambda\)CDM cosmology: \[ H(a) = H_0 \sqrt{\Omega_m \, a^{-3} + \Omega_\Lambda} \]

The drift and kick integrals become: \[ D(a_1, a_2) = \int_{a_1}^{a_2} \frac{da}{a^3 H_0 \sqrt{\Omega_m a^{-3} + \Omega_\Lambda}} \] \[ K(a_1, a_2) = \int_{a_1}^{a_2} \frac{da}{a H_0 \sqrt{\Omega_m a^{-3} + \Omega_\Lambda}} \]

These generally do not have closed-form solutions and must be evaluated numerically.

Note: Production Codes

In practice, production \(N\)-body codes precompute \(D(a)\) and \(K(a)\) as lookup tables at startup (evaluated from some initial \(a_i\) to a grid of \(a\) values using numerical quadrature). During the simulation, the drift and kick factors for any interval \([a_1, a_2]\) are obtained by table lookup and subtraction: \(D(a_1, a_2) = D(a_2) - D(a_1)\).
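A minimal sketch of such a lookup table follows. It is not taken from any particular production code: the grid uniform in \(\ln a\), the cumulative trapezoid rule, the linear interpolation, and the parameter values are all illustrative assumptions.

```python
import numpy as np

H0, OMEGA_M, OMEGA_L = 1.0, 0.3, 0.7    # illustrative values, units with H0 = 1

def hubble_lcdm(a):
    return H0 * np.sqrt(OMEGA_M * a**-3 + OMEGA_L)

def build_tables(hubble, a_init=0.01, a_final=1.0, n=4096):
    """Tabulate D(a_init, a) and K(a_init, a) on a grid uniform in eta = ln a."""
    eta = np.linspace(np.log(a_init), np.log(a_final), n)
    a = np.exp(eta)
    deta = np.diff(eta)

    def cumulative_trapezoid(y):
        return np.concatenate(([0.0], np.cumsum(0.5 * (y[1:] + y[:-1]) * deta)))

    D = cumulative_trapezoid(1.0 / (a**2 * hubble(a)))   # dD/d(eta) = 1 / (a^2 H)
    K = cumulative_trapezoid(1.0 / hubble(a))            # dK/d(eta) = 1 / H
    return eta, D, K

def factor(table, eta_tab, a1, a2):
    """Factor over [a1, a2] by table lookup (linear interpolation) and subtraction."""
    return np.interp(np.log(a2), eta_tab, table) - np.interp(np.log(a1), eta_tab, table)

eta_tab, D_tab, K_tab = build_tables(hubble_lcdm)
D_step = factor(D_tab, eta_tab, 0.50, 0.51)
K_step = factor(K_tab, eta_tab, 0.50, 0.51)
print(f"D(0.50, 0.51) = {D_step:.6f}, K(0.50, 0.51) = {K_step:.6f}")
```

Passing an EdS Hubble function to `build_tables` gives a quick check of the table against the closed-form expressions from the previous subsection.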

20.4 Comparison of Formulations

The three natural Lagrangian formulations for cosmological \(N\)-body simulations differ in their choice of time variable and conjugate momentum:

| | \(\mathcal{L}_{\text{pec}}\) | \(\mathcal{L}'\) (total deriv.) | \(\tilde{\mathcal{L}}\) (\(\ln a\) time) |
| --- | --- | --- | --- |
| Time variable | \(t\) | \(t\) | \(\eta = \ln a\) |
| Conjugate momentum | \(ma^2\dot{\mathbf{x}}\) | \(ma\dot{\mathbf{r}}\) | \(ma^2\dot{\mathbf{x}}\) |
| Physical meaning of \(p\) | \(\propto \mathbf{v}_{\text{pec}}\) | \(\propto\) total velocity | \(\propto \mathbf{v}_{\text{pec}}\) |
| Hubble drag in EOM | Yes | No | No |
| RSD convenience | Direct | Requires subtracting the Hubble flow \(H\mathbf{r}\) | Direct |
| Hamiltonian separable? | Yes | Yes | Yes |
| Common use | Textbook derivations | Analytical work | Simulation codes |

The \(\tilde{\mathcal{L}}\) formulation with \(\eta = \ln a\) is preferred for simulations: the absence of Hubble drag simplifies the integrator, the conjugate momentum directly gives the peculiar velocity, and \(\eta\) provides a natural logarithmic time-stepping (equal steps in \(\eta\) correspond to equal expansion factors, giving finer resolution at early times when structures are forming rapidly).

20.5 The KDK Step

Bringing it all together, a single KDK leapfrog step advancing from \(a_n\) to \(a_{n+1}\) (with midpoint \(a_{n+1/2}\)) proceeds as:

Step 1 — Half Kick: \[ \mathbf{p}_{n+1/2} = \mathbf{p}_n - m \nabla\Phi(\mathbf{x}_n) \, K(a_n, a_{n+1/2}) \]

Step 2 — Full Drift: \[ \mathbf{x}_{n+1} = \mathbf{x}_{n} + \frac{\mathbf{p}_{n+1/2}}{m} \, D(a_n, a_{n+1}) \]

Step 3 — Half Kick:

Compute \(\nabla\Phi(\mathbf{x}_{n+1})\) from the particle positions at \(\mathbf{x}_{n+1}\) (this is where the Poisson solver enters — to be discussed in later sections), then: \[ \mathbf{p}_{n+1} = \mathbf{p}_{n+1/2} - m \nabla\Phi(\mathbf{x}_{n+1}) \, K(a_{n+1/2}, a_{n+1}) \]

Important: The Expensive Step

The force evaluation (computing \(\nabla\Phi\) from the particle distribution) is by far the most expensive part of each timestep. The KDK ordering means we need exactly one force evaluation per step — the force at the end of one step provides the first half-kick of the next step. This is a crucial practical advantage. How we compute this force efficiently using the Particle-Mesh method (CIC interpolation + FFT-based Poisson solver) is the subject of the next sections.
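The three sub-steps can be sketched as a single function that reuses the force from the previous step. Several things here are assumptions for illustration: the EdS closed-form factors in units with \(H_0 = 1\), the geometric-mean midpoint \(a_{n+1/2} = \sqrt{a_n a_{n+1}}\) (the midpoint in \(\ln a\)), and the hypothetical callback `grad_phi(x, a)`, which stands in for the PM force calculation.

```python
import numpy as np

H0 = 1.0  # units in which H0 = 1 (illustrative)

def drift_eds(a1, a2):
    return 2.0 / H0 * (1.0 / np.sqrt(a1) - 1.0 / np.sqrt(a2))

def kick_eds(a1, a2):
    return 2.0 / (3.0 * H0) * (a2**1.5 - a1**1.5)

def kdk_step(x, p, grad_n, a_n, a_np1, grad_phi, m=1.0):
    """Advance (x, p) from a_n to a_np1.

    grad_n is grad(Phi) at x_n, carried over from the previous step, so only
    one new force evaluation (the grad_phi call below) happens per step.
    """
    a_half = np.sqrt(a_n * a_np1)                    # midpoint in ln a (assumed)
    p = p - m * grad_n * kick_eds(a_n, a_half)       # half kick
    x = x + p / m * drift_eds(a_n, a_np1)            # full drift
    grad_np1 = grad_phi(x, a_np1)                    # the expensive step
    p = p - m * grad_np1 * kick_eds(a_half, a_np1)   # half kick
    return x, p, grad_np1

# Free-streaming check: with no force, p is constant and x moves by (p/m) D.
def no_force(x, a):
    return np.zeros_like(x)

x0 = np.array([0.1, 0.5])
p0 = np.array([1.0, -2.0])
x1, p1, g1 = kdk_step(x0, p0, no_force(x0, 0.5), 0.5, 0.6, no_force)
print(x1, p1)
```

Returning the new gradient alongside the state is what lets the caller pass it straight back in as `grad_n` on the next step.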

20.6 Setting Up

Before diving into the Particle-Mesh algorithm, let us set up our computational environment. We will use JAX for the simulation code (to be developed in later sections), but for now we only need NumPy and Matplotlib.

Code
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from matplotlib.colors import Normalize

20.7 The Particle-Mesh Method: Mass Assignment

The KDK integrator needs the gravitational force \(\nabla\Phi\) at every particle position, which requires solving the Poisson equation \[ \nabla^2 \Phi = 4\pi G \bar{\rho} a^2 \delta \] The Particle-Mesh (PM) method solves this on a regular grid using FFTs, which gives \(\mathcal{O}(N_g \log N_g)\) scaling (where \(N_g\) is the number of grid cells) rather than the \(\mathcal{O}(N^2)\) cost of direct particle-particle summation.

The PM pipeline has three steps:

  1. Mass assignment: deposit particle masses onto the grid to get \(\delta(\mathbf{x}_\text{grid})\)
  2. Poisson solve: use FFTs to get \(\Phi(\mathbf{x}_\text{grid})\) from \(\delta(\mathbf{x}_\text{grid})\)
  3. Force interpolation: interpolate \(\nabla\Phi\) from the grid back to particle positions

Steps 1 and 3 require a scheme for transferring quantities between particles (at arbitrary positions) and grid points (at fixed locations). The choice of scheme controls the accuracy, smoothness, and isotropy of the forces.

Nearest Grid Point (NGP)

The simplest approach: assign each particle’s entire mass to the single nearest grid point.

For a 1D grid with spacing \(h\) and grid points at \(x_j = jh\), a particle at position \(x_p\) contributes to the grid as: \[ \rho_j \mathrel{+}= m_p \cdot \begin{cases} 1 & \text{if } |x_p - x_j| < h/2 \\ 0 & \text{otherwise} \end{cases} \]

This is just a zeroth-order (piecewise constant) interpolation. It is fast but produces a very noisy density field — the density jumps discontinuously every time a particle crosses a cell boundary, and these discontinuities propagate into discontinuous forces.
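A minimal 1D sketch of NGP assignment, assuming a periodic grid with points at \(x_j = jh\) as in the formula above (the function name is hypothetical):

```python
import numpy as np

def ngp_deposit(xp, mp, ngrid, h):
    """Deposit masses mp at positions xp onto a periodic 1D grid, x_j = j*h."""
    j = np.rint(xp / h).astype(int) % ngrid      # index of the nearest grid point
    return np.bincount(j, weights=mp, minlength=ngrid)

xp = np.array([1.2, 2.8, 3.1, 5.7])
rho = ngp_deposit(xp, np.ones_like(xp), ngrid=8, h=1.0)
print(rho)

# The discontinuity: a tiny displacement across a cell boundary moves all the mass.
rho_before = ngp_deposit(np.array([2.49]), np.array([1.0]), ngrid=8, h=1.0)
rho_after = ngp_deposit(np.array([2.51]), np.array([1.0]), ngrid=8, h=1.0)
print(rho_before, rho_after)
```

Note that `np.rint` rounds exact half-integers to the nearest even integer, which only matters for particles sitting exactly on a cell boundary.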

Cloud-in-Cell (CIC)

Cloud-in-Cell is a first-order (piecewise linear) scheme that spreads each particle’s mass across the \(2^d\) nearest grid points (where \(d\) is the number of dimensions), weighted by the overlap volume.

In 1D, a particle at \(x_p\) contributes to the two bracketing grid points \(x_j\) and \(x_{j+1}\) with weights: \[ W(x_p, x_j) = \begin{cases} 1 - |x_p - x_j|/h & \text{if } |x_p - x_j| < h \\ 0 & \text{otherwise} \end{cases} \]

The weight is just a triangle (tent) function — linearly interpolating between the two nearest grid points. In \(d\) dimensions, the weight factorizes: \[ W(\mathbf{x}_p, \mathbf{x}_j) = \prod_{i=1}^{d} W(x_{p,i}, x_{j,i}) \] so in 3D each particle contributes to \(2^3 = 8\) grid points.
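The 1D tent-function weights translate directly into code. The sketch below assumes a periodic grid with points at \(x_j = jh\); the function name is hypothetical, and a 3D version would apply the same split along each axis and multiply the weights.

```python
import numpy as np

def cic_deposit(xp, mp, ngrid, h):
    """CIC deposit of masses mp at positions xp onto a periodic 1D grid, x_j = j*h."""
    s = xp / h
    j = np.floor(s).astype(int)        # left bracketing grid point
    frac = s - j                       # distance from it, in units of h
    rho = np.zeros(ngrid)
    # np.add.at accumulates correctly when several particles hit the same cell
    np.add.at(rho, j % ngrid, mp * (1.0 - frac))
    np.add.at(rho, (j + 1) % ngrid, mp * frac)
    return rho

xp = np.array([1.2, 2.8, 3.1, 5.7])
mp = np.ones_like(xp)
rho = cic_deposit(xp, mp, ngrid=8, h=1.0)
print(rho, rho.sum())
```

Because the two weights of each particle sum to one, the total deposited mass equals the total particle mass regardless of where the particles sit.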

Important: CIC for Both Assignment and Interpolation

A crucial requirement: the same scheme must be used for mass assignment (particles → grid) and force interpolation (grid → particles). Using different schemes breaks Newton’s third law — a particle would not feel the same force it exerts — leading to self-forces and momentum non-conservation. This is sometimes called the “transpose” requirement: the interpolation operator must be the transpose of the assignment operator.
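The transpose relationship can be made explicit by writing the CIC weights as a matrix. The sketch below builds a dense matrix purely for illustration (real codes never form it); the function name and the 1D periodic grid with points at \(x_j = jh\) are assumptions.

```python
import numpy as np

def cic_matrix(xp, ngrid, h):
    """Dense (n_particles, ngrid) matrix of periodic 1D CIC weights W[p, j]."""
    s = xp / h
    j = np.floor(s).astype(int)
    frac = s - j
    W = np.zeros((len(xp), ngrid))
    rows = np.arange(len(xp))
    W[rows, j % ngrid] += 1.0 - frac
    W[rows, (j + 1) % ngrid] += frac
    return W

rng = np.random.default_rng(1)
ngrid, h = 8, 1.0
xp = rng.uniform(0.0, ngrid * h, size=5)
mp = rng.uniform(0.5, 1.5, size=5)
g = rng.standard_normal(ngrid)          # an arbitrary grid field

W = cic_matrix(xp, ngrid, h)
rho = W.T @ mp                          # assignment: particles -> grid
g_at_particles = W @ g                  # interpolation: grid -> particles

# Adjoint identity: sum_j g_j rho_j == sum_p m_p g(x_p)
print(g @ rho, mp @ g_at_particles)
```

Assignment applies \(W^T\) and interpolation applies \(W\), so the two printed sums agree by construction; mixing schemes would replace one of the two matrices and break this identity.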

Visualizing CIC in 1D

To build intuition, let us visualize how CIC works. We place a few particles on a 1D grid and show how each particle’s mass is split between neighboring grid points.

Code
fig, axes = plt.subplots(2, 1, figsize=(10, 6), sharex=True,
                         gridspec_kw={'height_ratios': [2, 1]})

# Grid setup
ngrid = 8
h = 1.0  # cell spacing
grid_x = np.arange(ngrid) * h + h/2  # cell centers

# A few example particles at arbitrary positions
particles = np.array([1.2, 2.8, 3.1, 5.7])
colors_p = plt.cm.Set1(np.linspace(0, 0.4, len(particles)))

# --- Top panel: individual particle contributions ---
ax = axes[0]

# Draw grid lines
for gx in grid_x:
    ax.axvline(gx, color='0.85', lw=0.5, zorder=0)

# Cell boundaries
for i in range(ngrid + 1):
    ax.axvline(i * h, color='0.7', lw=0.5, ls='--', zorder=0)

# For each particle, show the CIC weights
bar_width = 0.15
total_density = np.zeros(ngrid)

for ip, xp in enumerate(particles):
    weights = np.zeros(ngrid)
    for j in range(ngrid):
        dx = abs(xp - grid_x[j]) / h
        if dx < 1.0:
            weights[j] = 1.0 - dx
    total_density += weights

    offset = (ip - len(particles)/2 + 0.5) * bar_width
    ax.bar(grid_x + offset, weights, width=bar_width, alpha=0.7,
           color=colors_p[ip], label=f'$x_p = {xp}$', zorder=2)

    # Draw the particle position
    ax.plot(xp, -0.06, marker='^', ms=10, color=colors_p[ip],
            clip_on=False, zorder=5)

# Draw grid points
ax.plot(grid_x, np.zeros(ngrid), 'ko', ms=6, zorder=3)

ax.set_ylabel('CIC weight $W(x_p, x_j)$')
ax.set_ylim(-0.02, 1.05)
ax.legend(loc='upper right', fontsize=9, ncol=2)
ax.set_title('Individual particle contributions')

# --- Bottom panel: total deposited density ---
ax = axes[1]

for gx in grid_x:
    ax.axvline(gx, color='0.85', lw=0.5, zorder=0)
for i in range(ngrid + 1):
    ax.axvline(i * h, color='0.7', lw=0.5, ls='--', zorder=0)

ax.bar(grid_x, total_density, width=0.6, color='0.4', alpha=0.7, zorder=2)
ax.plot(grid_x, np.zeros(ngrid), 'ko', ms=6, zorder=3)

# Show particles
for ip, xp in enumerate(particles):
    ax.plot(xp, -0.08, marker='^', ms=10, color=colors_p[ip],
            clip_on=False, zorder=5)

ax.set_xlabel('$x$')
ax.set_ylabel('Total density')
ax.set_title('Sum over all particles')

plt.tight_layout()
plt.show()
Figure 20.1: CIC mass assignment in 1D. Particles (colored triangles) deposit mass onto grid points (black dots) using linear weights. In the top panel, the bars in matching colors show the weight each particle assigns to each grid point; in the bottom panel, the gray bars show the total deposited density. The triangle/tent function ensures that the deposited density varies continuously as particles move.

Visualizing CIC in 2D

In 2D, each particle deposits mass onto the four corners of its enclosing cell. The weight at each corner is the product of the 1D weights in each direction — geometrically, this is the area of the rectangle diagonally opposite to that corner.

Code
fig, ax = plt.subplots(1, 1, figsize=(6, 6))

# Grid
ngrid = 5
h = 1.0
for i in range(ngrid + 1):
    ax.axhline(i * h, color='0.8', lw=0.5)
    ax.axvline(i * h, color='0.8', lw=0.5)

# Grid points (cell centers)
for i in range(ngrid):
    for j in range(ngrid):
        ax.plot(i * h + h/2, j * h + h/2, 'ko', ms=5, zorder=3)

# Particle position
xp, yp = 2.3, 2.7

# Enclosing cell corners (grid point indices)
# Grid points at cell centers: x_j = j*h + h/2
# Find the two nearest in each direction
jx = int((xp - h/2) / h)  # left grid point index
jy = int((yp - h/2) / h)  # lower grid point index

corners = [(jx, jy), (jx+1, jy), (jx, jy+1), (jx+1, jy+1)]
corner_coords = [(c[0]*h + h/2, c[1]*h + h/2) for c in corners]

# CIC weights
dx = (xp - corner_coords[0][0]) / h
dy = (yp - corner_coords[0][1]) / h

weights = [
    (1 - dx) * (1 - dy),  # lower-left
    dx * (1 - dy),         # lower-right
    (1 - dx) * dy,         # upper-left
    dx * dy,               # upper-right
]

# Shaded patches illustrating the "opposite area" interpretation.
# The four grid points span the square [x0, x1] x [y0, y1]. The particle splits
# it into four rectangles, and each corner's weight is the area (in units of
# h^2) of the rectangle diagonally opposite to that corner.
x0, y0 = corner_coords[0]  # lower-left grid point
x1, y1 = corner_coords[3]  # upper-right grid point
patch_specs = [
    (xp, yp, x1 - xp, y1 - yp),  # lower-left corner: upper-right rectangle
    (x0, yp, xp - x0, y1 - yp),  # lower-right corner: upper-left rectangle
    (xp, y0, x1 - xp, yp - y0),  # upper-left corner: lower-right rectangle
    (x0, y0, xp - x0, yp - y0),  # upper-right corner: lower-left rectangle
]

cmap = plt.cm.Blues
norm = Normalize(vmin=0, vmax=max(weights) * 1.5)
labels_pos = ['lower-left', 'lower-right', 'upper-left', 'upper-right']

for i, ((cx, cy), w, ps) in enumerate(zip(corner_coords, weights, patch_specs)):
    # Shaded rectangle (opposite area)
    rect = patches.Rectangle((ps[0], ps[1]), ps[2], ps[3],
                              facecolor=cmap(norm(w)), edgecolor='steelblue',
                              lw=1.5, alpha=0.5, zorder=1)
    ax.add_patch(rect)

    # Weight label at grid point
    ax.annotate(f'$W = {w:.2f}$', (cx, cy),
                textcoords='offset points', xytext=(12, 8),
                fontsize=11, fontweight='bold', color='steelblue',
                bbox=dict(boxstyle='round,pad=0.2', fc='white', ec='steelblue', alpha=0.8))

    # Highlight grid point
    ax.plot(cx, cy, 'o', ms=10, color='steelblue', zorder=4)

# Particle
ax.plot(xp, yp, 'r*', ms=18, zorder=5)
ax.annotate(f'particle\n$({xp}, {yp})$', (xp, yp),
            textcoords='offset points', xytext=(-50, -30),
            fontsize=10, color='red',
            arrowprops=dict(arrowstyle='->', color='red'))

# Dashed lines from particle to cell edges
ax.axhline(yp, color='red', ls=':', lw=0.8, alpha=0.5)
ax.axvline(xp, color='red', ls=':', lw=0.8, alpha=0.5)

# Dimension annotations
y_bottom = corner_coords[0][1]
ax.annotate('', xy=(xp, y_bottom - 0.35), xytext=(corner_coords[0][0], y_bottom - 0.35),
            arrowprops=dict(arrowstyle='<->', color='0.3', lw=1.5))
ax.text((xp + corner_coords[0][0])/2, y_bottom - 0.45, f'$dx={dx:.1f}$',
        ha='center', fontsize=10, color='0.3')

ax.annotate('', xy=(corner_coords[1][0], y_bottom - 0.35), xytext=(xp, y_bottom - 0.35),
            arrowprops=dict(arrowstyle='<->', color='0.3', lw=1.5))
ax.text((xp + corner_coords[1][0])/2, y_bottom - 0.45, f'$1-dx={1-dx:.1f}$',
        ha='center', fontsize=10, color='0.3')

ax.set_xlim(0.5, 4.5)
ax.set_ylim(1.0, 5.0)
ax.set_aspect('equal')
ax.set_xlabel('$x$')
ax.set_ylabel('$y$')

plt.tight_layout()
plt.show()
Figure 20.2: CIC mass assignment in 2D for a single particle (red star). The particle sits inside a grid cell and deposits mass to the four surrounding grid points. Each weight equals the area of the shaded rectangle diagonally opposite to that grid point — this is the product of the 1D tent-function weights in x and y.

Why CIC?

CIC is the standard choice for PM codes because it strikes a good balance:

| Scheme | Order | Grid points per particle | Density continuity | Force continuity |
| --- | --- | --- | --- | --- |
| NGP | 0 | \(1\) | Discontinuous | Discontinuous |
| CIC | 1 | \(2^d\) | Continuous (\(C^0\)) | Continuous (\(C^0\)), piecewise linear |
| TSC | 2 | \(3^d\) | \(C^1\) | \(C^1\) |

Higher-order schemes (TSC = Triangular Shaped Cloud, and beyond) give smoother forces but spread each particle over more grid points, increasing the cost of the assignment and interpolation steps. CIC is the most common choice in production PM and TreePM codes: the remaining kinks in the force (discontinuities in its first derivative) sit at the grid scale and are subdominant to the grid-scale force resolution limit inherent in the PM method anyway.

Note: CIC as a Convolution

There is a clean mathematical way to think about the full NGP → CIC → TSC hierarchy. NGP assigns mass using a top-hat window of width \(h\) (the “zeroth-order B-spline”). CIC convolves this top-hat with itself, producing the triangle/tent function — the first-order B-spline. TSC convolves once more, producing the second-order B-spline (a piecewise quadratic). Each successive convolution gains one order of smoothness. In Fourier space, convolution becomes multiplication, so the CIC window function is \(\tilde{W}(k) = \mathrm{sinc}^2(kh/2)\), compared to \(\mathrm{sinc}(kh/2)\) for NGP — the CIC window suppresses small-scale power more aggressively, which is the price of smoother forces.
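The Fourier-space window can be checked numerically. The sketch below integrates the unit-normalized tent function against \(\cos(kx)\) on a fine grid and compares the result with \(\mathrm{sinc}^2(kh/2)\); the sampling resolution and the range of \(k\) are arbitrary illustrative choices. Note that `np.sinc(u)` is the normalized \(\sin(\pi u)/(\pi u)\), hence the factor of \(2\pi\) in the argument.

```python
import numpy as np

h = 1.0
x = np.linspace(-h, h, 4001)               # fine sampling of the cloud shape
dx = x[1] - x[0]
tent = (1.0 - np.abs(x) / h) / h           # CIC cloud, normalized to unit integral

k = np.linspace(0.1, np.pi / h, 50)        # up to the Nyquist wavenumber
# The tent vanishes at both ends, so this plain sum equals the trapezoid rule.
W_cic_numeric = np.array([np.sum(tent * np.cos(kk * x)) * dx for kk in k])

W_ngp = np.sinc(k * h / (2.0 * np.pi))     # sinc(kh/2) in the unnormalized sense
W_cic = W_ngp**2

print(np.abs(W_cic_numeric - W_cic).max())
print(W_ngp[-1], W_cic[-1])                # windows at the Nyquist wavenumber
```

Since \(|\mathrm{sinc}| \le 1\), squaring it can only lower the window, which is the stronger small-scale suppression mentioned above.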

Force Interpolation

Once we have solved for \(\Phi\) on the grid (next section), we need the force \(\mathbf{F} = -m\nabla\Phi\) at each particle position. As emphasized above, we must use the same CIC weights to interpolate the force back to the particles: \[ \mathbf{F}(\mathbf{x}_p) = -m \sum_j W(\mathbf{x}_p, \mathbf{x}_j) \, \nabla\Phi(\mathbf{x}_j) \] where the sum runs over the \(2^d\) grid points surrounding particle \(p\), and the gradient \(\nabla\Phi\) on the grid is computed by finite differences (or, alternatively, by multiplying by \(i\mathbf{k}\) in Fourier space before transforming back; more on this in the next section).
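A 1D sketch of this step, using a toy potential whose gradient is known analytically. The sinusoidal \(\Phi\), the second-order centered difference, the grid size, and the function name are all assumptions for illustration.

```python
import numpy as np

def cic_interpolate(field, xp, h):
    """Interpolate a periodic 1D grid field (points at x_j = j*h) to positions xp."""
    ngrid = field.size
    s = xp / h
    j = np.floor(s).astype(int)
    frac = s - j
    return (1.0 - frac) * field[j % ngrid] + frac * field[(j + 1) % ngrid]

L, ngrid = 1.0, 64
h = L / ngrid
xg = np.arange(ngrid) * h
phi = np.sin(2.0 * np.pi * xg / L)                            # toy potential
grad_phi = (np.roll(phi, -1) - np.roll(phi, 1)) / (2.0 * h)   # centered difference

m = 1.0
xp = np.array([0.013, 0.37, 0.81])
F = -m * cic_interpolate(grad_phi, xp, h)
F_exact = -m * (2.0 * np.pi / L) * np.cos(2.0 * np.pi * xp / L)
print(F)
print(F_exact)
```

The small residual between the two printed arrays comes from the finite-difference gradient and the linear interpolation, both of which are second-order accurate in \(h\).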

This completes the assignment/interpolation infrastructure. We now turn to the middle step of the PM pipeline: solving the Poisson equation on the grid using FFTs.

20.8 The Poisson Solver

The Idea

We need to solve \[ \nabla^2 \Phi = 4\pi G \bar{\rho} a^2 \delta \] on a periodic grid. Periodicity is the key simplification: the eigenfunctions of \(\nabla^2\) on a periodic domain are plane waves \(e^{i\mathbf{k}\cdot\mathbf{x}}\), so the Poisson equation becomes algebraic in Fourier space.

DFT Conventions

Before proceeding, we need to be precise about our Fourier transform conventions. On a grid of \(N_g\) points in each dimension with spacing \(h = L/N_g\), the discrete Fourier transform (DFT) and its inverse are:

\[ \hat{f}_\mathbf{k} = \sum_{\mathbf{n}} f_\mathbf{n} \, e^{-2\pi i \, \mathbf{k}\cdot\mathbf{n}/N_g} \] \[ f_\mathbf{n} = \frac{1}{N_g^d} \sum_{\mathbf{k}} \hat{f}_\mathbf{k} \, e^{+2\pi i \, \mathbf{k}\cdot\mathbf{n}/N_g} \]

where \(\mathbf{n}\) indexes grid points and \(\mathbf{k}\) indexes wavenumber bins. This is the convention used by NumPy’s np.fft.fftn and np.fft.ifftn (and JAX’s jnp.fft.fftn), with the \(1/N_g^d\) normalization on the inverse transform.

From Continuous to Discrete Fourier Transforms

In physics we work with the continuous Fourier transform pair (in 1D for clarity): \[ \hat{f}(k) = \int_{-\infty}^{\infty} f(x) \, e^{-ikx} \, dx \qquad f(x) = \int_{-\infty}^{\infty} \hat{f}(k) \, \frac{dk}{2\pi} \, e^{+ikx} \] On a periodic box of length \(L\) with \(N_g\) grid points at positions \(x_n = nh\) (where \(h = L/N_g\)), the integral becomes a Riemann sum: \[ \hat{f}(k) \approx \sum_{n=0}^{N_g-1} f(x_n) \, e^{-ik x_n} \, h \] The allowed wavenumbers are \(k = 2\pi m / L\) for integer \(m\). Substituting \(x_n = nh = nL/N_g\) and \(k = 2\pi m/L\): \[ \hat{f}(k_m) \approx h \sum_{n=0}^{N_g-1} f_n \, e^{-2\pi i \, mn/N_g} = h \, \hat{f}_m^{\text{DFT}} \] So the DFT output differs from the continuous transform by a factor of \(h\) (the grid spacing, which is the quadrature weight of the Riemann sum).

Similarly, for the inverse transform, the integral over \(dk/(2\pi)\) becomes a sum over discrete modes with spacing \(\Delta k = 2\pi/L\): \[ f(x_n) \approx \sum_{m} \hat{f}(k_m) \, e^{+ik_m x_n} \, \frac{\Delta k}{2\pi} = \frac{1}{L} \sum_m \hat{f}(k_m) \, e^{+2\pi i \, mn/N_g} \] Substituting \(\hat{f}(k_m) = h \, \hat{f}_m^{\text{DFT}}\): \[ f_n = \frac{h}{L} \sum_m \hat{f}_m^{\text{DFT}} \, e^{+2\pi i \, mn/N_g} = \frac{1}{N_g} \sum_m \hat{f}_m^{\text{DFT}} \, e^{+2\pi i \, mn/N_g} \] which is exactly the DFT inverse transform convention above.

In \(d\) dimensions, the volume element \(d^d x \to h^d\) gives:

\[ \boxed{\hat{f}(\mathbf{k}) = h^d \, \hat{f}_\mathbf{k}^{\text{DFT}}} \qquad \boxed{f(\mathbf{x}_\mathbf{n}) = h^{-d} \, f_\mathbf{n}^{\text{IDFT}}} \]

where \(h = L/N_g\) is the grid spacing. Here \(\hat{f}^{\text{DFT}}\) denotes the output of fftn, and \(f^{\text{IDFT}}\) denotes the output of ifftn when fed \(\hat{f}(\mathbf{k})\) (the continuous transform values). The forward relation says: multiply the fftn output by the volume element \(h^d\) to get the continuous transform. The inverse says: if you feed continuous Fourier coefficients into ifftn, multiply the result by \(h^{-d}\) to get the field values. The two factors are inverses of each other, as they must be.
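These two relations can be verified in 1D with a function whose continuous transform is known: for a Gaussian \(f(x) = e^{-x^2/2\sigma^2}\), \(\hat{f}(k) = \sigma\sqrt{2\pi}\, e^{-k^2\sigma^2/2}\). The box size, grid size, and \(\sigma\) below are illustrative choices, picked so that the Gaussian is negligible at the box edge and beyond the Nyquist wavenumber.

```python
import numpy as np

L, Ng, sigma = 20.0, 64, 1.0
h = L / Ng
x = L * np.fft.fftfreq(Ng)                  # grid positions wrapped into [-L/2, L/2)
k = 2.0 * np.pi * np.fft.fftfreq(Ng, d=h)   # physical angular wavenumbers

f = np.exp(-x**2 / (2.0 * sigma**2))
f_hat_exact = sigma * np.sqrt(2.0 * np.pi) * np.exp(-k**2 * sigma**2 / 2.0)

# Forward: continuous transform = h * (fft output)
f_hat_from_dft = h * np.fft.fft(f)

# Inverse: field values = (ifft of continuous coefficients) / h
f_from_idft = np.fft.ifft(f_hat_exact).real / h

print(np.abs(f_hat_from_dft - f_hat_exact).max())
print(np.abs(f_from_idft - f).max())
```

Both printed residuals should be at the level of floating-point round-off for these parameters.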

Important: When Do These Factors Matter?

If your equation is entirely in Fourier space (like the Poisson equation: divide \(\hat{\delta}\) by \(\hat{L}\) to get \(\hat{\Phi}\), then differentiate), the \(h^d\) factors cancel between the forward and inverse transforms and you can work directly with the DFT output — no extra factors needed.

The factors matter when you need to relate DFT output to a physical quantity with specific units. For example, the power spectrum in a box of volume \(V = L^d\): \[ P(k) = \frac{|\hat{\delta}(\mathbf{k})|^2}{V} = \frac{h^{2d}}{L^d} \, |\hat{\delta}_\mathbf{k}^{\text{DFT}}|^2 \] or Parseval’s theorem relating a real-space integral to a Fourier-space sum. We will encounter this when setting up initial conditions from a power spectrum.

Important: Physical vs Grid Wavenumbers

The DFT indices \(\mathbf{k}\) run over integers. The corresponding physical wavenumbers are \[ \mathbf{k}_{\text{phys}} = \frac{2\pi \mathbf{k}}{L} \] The Nyquist wavenumber is \(k_{\text{Ny}} = \pi N_g / L = \pi / h\); modes above it are aliased onto lower wavenumbers. In NumPy, after np.fft.fftn, the wavenumber indices are ordered as \(k = 0, 1, \ldots, N_g/2 - 1, -N_g/2, \ldots, -1\) (for even \(N_g\)). Note that np.fft.fftfreq(Ng, d=h) returns frequencies in cycles per unit length, so the physical angular wavenumbers are 2 * np.pi * np.fft.fftfreq(Ng, d=h).
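A short sketch of how the integer indices and physical wavenumbers come out of `np.fft.fftfreq` for an even grid size. The values of `Ng` and `L` are arbitrary illustrative choices.

```python
import numpy as np

Ng, L = 8, 100.0
h = L / Ng

# Integer mode indices in NumPy's storage order
idx = np.fft.fftfreq(Ng, d=1.0 / Ng)
print(idx)

# Physical angular wavenumbers: fftfreq returns cycles per unit length,
# so a factor of 2*pi is needed.
k_phys = 2.0 * np.pi * np.fft.fftfreq(Ng, d=h)
k_nyquist = np.pi / h
print(k_phys)
print(np.abs(k_phys).max(), k_nyquist)

# 3D wavenumber magnitude on the full grid, via broadcasting
kx, ky, kz = np.meshgrid(k_phys, k_phys, k_phys, indexing="ij")
k_mag = np.sqrt(kx**2 + ky**2 + kz**2)
```

With `indexing="ij"` the three arrays line up with the axes of an `np.fft.fftn` output of shape `(Ng, Ng, Ng)`.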

Solving Poisson in Fourier Space

Taking the DFT of the Poisson equation: \[ \widehat{\nabla^2 \Phi}_\mathbf{k} = (4\pi G \bar{\rho} a^2) \, \hat{\delta}_\mathbf{k} \]

The left-hand side requires the Fourier-space representation of \(\nabla^2\). On a continuous domain, \(\nabla^2 e^{i\mathbf{k}\cdot\mathbf{x}} = -k^2 e^{i\mathbf{k}\cdot\mathbf{x}}\), so one might use \(-k^2\). However, on a discrete grid we should use the Laplacian that is consistent with the finite-difference stencil we use for the gradient. This matters for maintaining consistency between the potential and the force.

Continuous (spectral) Laplacian: \[ \hat{L}(\mathbf{k}) = -k_{\text{phys}}^2 = -\left(\frac{2\pi}{L}\right)^2 |\mathbf{k}|^2 \]

Discrete (finite-difference) Laplacian:

For a second-order centered finite difference, \(\nabla^2 f \approx (f_{j+1} - 2f_j + f_{j-1})/h^2\) in each dimension. The Fourier-space representation of this operator is: \[ \hat{L}(\mathbf{k}) = -\sum_{i=1}^{d} \frac{4}{h^2} \sin^2\!\left(\frac{\pi k_i}{N_g}\right) \]

At low \(k\) (long wavelengths), \(\sin(\pi k/N_g) \approx \pi k/N_g\) and the two agree. They differ near the Nyquist frequency, where the discrete version correctly captures the grid-scale behavior.
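
The following one-dimensional sketch (plain NumPy, assumed grid parameters) checks that a grid plane wave is an eigenvector of the three-point stencil with the eigenvalue quoted above, and compares the discrete and continuous symbols at the lowest mode and at the Nyquist mode.

```python
import numpy as np

Ng, L = 32, 1.0                      # assumed illustrative grid
h = L / Ng
x = h * np.arange(Ng)

def lap_discrete(m):
    """Fourier symbol of the three-point stencil for integer mode m."""
    return -(4.0 / h**2) * np.sin(np.pi * m / Ng) ** 2

def lap_continuous(m):
    """Spectral symbol -k_phys^2 for integer mode m."""
    return -(2.0 * np.pi * m / L) ** 2

# Apply the stencil to a plane wave and compare with eigenvalue * wave.
m = 5
f = np.cos(2.0 * np.pi * m * x / L)
stencil = (np.roll(f, -1) - 2.0 * f + np.roll(f, 1)) / h**2
eig_err = np.max(np.abs(stencil - lap_discrete(m) * f))

ratio_low = lap_discrete(1) / lap_continuous(1)                   # close to 1
ratio_nyquist = lap_discrete(Ng // 2) / lap_continuous(Ng // 2)   # 4 / pi^2
print(eig_err, ratio_low, ratio_nyquist)
```

At the Nyquist mode the discrete symbol is smaller in magnitude than \(k^2\) by a factor \(4/\pi^2 \approx 0.41\), which is the size of the difference between the two choices at the grid scale.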

Note: Which Laplacian to Use?

Either choice works. The discrete Laplacian is more self-consistent if you also compute gradients by finite differences in real space. The continuous (spectral) Laplacian is simpler and gives slightly better long-wavelength accuracy. Many production PM codes use the discrete version for strict consistency; for a pedagogical code the difference is minor. We will use the discrete Laplacian for consistency.

The solution is then: \[ \hat{\Phi}_\mathbf{k} = \frac{4\pi G \bar{\rho} a^2}{\hat{L}(\mathbf{k})} \, \hat{\delta}_\mathbf{k} \qquad (\mathbf{k} \neq 0) \]

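The \(\mathbf{k} = 0\) mode is the mean of the potential. It is an arbitrary additive constant that exerts no force, and the source vanishes there as well, since \(\hat{\delta}_{\mathbf{k}=0} = 0\) by construction (the overdensity has zero mean in a periodic box). We set \(\hat{\Phi}_{\mathbf{k}=0} = 0\).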
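
A minimal one-dimensional version of this solve, in plain NumPy with an assumed zero-mean source standing in for the right-hand side of the Poisson equation, is sketched below. It divides by the discrete Laplacian, zeroes the \(k = 0\) mode, and verifies that applying the three-point stencil to the result recovers the source.

```python
import numpy as np

Ng, L = 64, 1.0                      # assumed illustrative grid
h = L / Ng
x = h * (np.arange(Ng) + 0.5)

# Zero-mean stand-in for the Poisson source (prefactor * delta).
source = np.cos(2.0 * np.pi * 3 * x / L) + 0.5 * np.sin(2.0 * np.pi * 7 * x / L)

m = np.fft.fftfreq(Ng, d=1.0 / Ng)                   # integer mode numbers
lap = -(4.0 / h**2) * np.sin(np.pi * m / Ng) ** 2    # discrete Laplacian symbol
lap[0] = 1.0                                         # placeholder to avoid 0/0

phi_hat = np.fft.fft(source) / lap
phi_hat[0] = 0.0                                     # fix the arbitrary constant
phi = np.fft.ifft(phi_hat).real

# Applying the three-point stencil to phi should return the source.
residual = (np.roll(phi, -1) - 2.0 * phi + np.roll(phi, 1)) / h**2 - source
max_residual = np.max(np.abs(residual))
print(max_residual, phi.mean())
```

The residual is at round-off level because the solve inverts exactly the operator that the stencil applies; with the spectral Laplacian the same check would show a small mismatch.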

Computing the Force

We need \(\nabla\Phi\) on the grid to kick the particles. There are two approaches:

Option A: Finite differences in real space. Transform \(\hat{\Phi}_\mathbf{k}\) back to real space to get \(\Phi(\mathbf{x}_j)\), then compute the gradient with centered differences: \[ (\nabla_i \Phi)_j = \frac{\Phi_{j+1} - \Phi_{j-1}}{2h} \] This requires one inverse FFT (to get \(\Phi\)) plus a simple stencil operation.

Option B: Differentiation in Fourier space. Multiply \(\hat{\Phi}_\mathbf{k}\) by the Fourier-space gradient operator, then inverse-transform each component separately: \[ \widehat{(\nabla_i \Phi)}_\mathbf{k} = \hat{D}_i(\mathbf{k}) \, \hat{\Phi}_\mathbf{k} \] where \(\hat{D}_i\) is the Fourier representation of the gradient in direction \(i\).

For consistency with the discrete Laplacian, we use the discrete gradient: \[ \hat{D}_i(k_i) = \frac{i}{h} \sin\!\left(\frac{2\pi k_i}{N_g}\right) \] which is the Fourier transform of the centered-difference operator \((f_{j+1} - f_{j-1})/(2h)\).
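
To confirm that this \(\hat{D}_i\) really is the centered difference in disguise, the sketch below (plain NumPy, one dimension, arbitrary random field, assumed grid parameters) applies the operator in Fourier space and compares it with the real-space stencil.

```python
import numpy as np

Ng, L = 32, 1.0                      # assumed illustrative grid
h = L / Ng
rng = np.random.default_rng(1)
f = rng.standard_normal(Ng)          # arbitrary periodic grid field

m = np.fft.fftfreq(Ng, d=1.0 / Ng)                   # integer mode numbers
D = 1j / h * np.sin(2.0 * np.pi * m / Ng)            # discrete gradient symbol

grad_fourier = np.fft.ifft(D * np.fft.fft(f)).real   # Option B
grad_stencil = (np.roll(f, -1) - np.roll(f, 1)) / (2.0 * h)   # Option A
grad_err = np.max(np.abs(grad_fourier - grad_stencil))
print(grad_err)
```

With the discrete operator, Options A and B therefore give the same force up to round-off.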

Note: Combining the Solve and Gradient

In practice, we can combine the Poisson solve and the gradient into a single Fourier-space operation. Define the Green’s function with gradient: \[ \hat{G}_i(\mathbf{k}) = \frac{\hat{D}_i(\mathbf{k})}{\hat{L}(\mathbf{k})} \] Then the force field in direction \(i\) is: \[ \widehat{F_i}(\mathbf{k}) = -m \cdot (4\pi G \bar{\rho} a^2) \, \hat{G}_i(\mathbf{k}) \, \hat{\delta}_\mathbf{k} \] This requires one forward FFT (to get \(\hat{\delta}\)) and \(d\) inverse FFTs (one per force component) — no intermediate real-space potential is needed.

The Full PM Force Pipeline

Putting it all together, the force calculation at each timestep proceeds as:

  1. CIC deposit: assign particle masses to grid → \(\rho(\mathbf{x}_j)\)
  2. Compute overdensity: \(\delta_j = \rho_j / \bar{\rho} - 1\)
  3. Forward FFT: \(\delta_j \to \hat{\delta}_\mathbf{k}\)
  4. Green’s function multiply: \(\hat{F}_{i,\mathbf{k}} = -m \cdot (4\pi G \bar{\rho} a^2) \, \hat{G}_i(\mathbf{k}) \, \hat{\delta}_\mathbf{k}\) for each direction \(i\)
  5. Inverse FFT: \(\hat{F}_{i,\mathbf{k}} \to F_i(\mathbf{x}_j)\) for each direction
  6. CIC interpolate: interpolate \(F_i(\mathbf{x}_j)\) to particle positions

Steps 3–5 cost \(\mathcal{O}(N_g^d \log N_g)\) via the FFT. Steps 1 and 6 cost \(\mathcal{O}(N)\) where \(N\) is the number of particles. The total cost per timestep is \(\mathcal{O}(N_g^d \log N_g + N)\).

Important: PM Force Resolution

The PM method resolves forces only down to the grid scale \(\sim h = L/N_g\). Structures smaller than \(\sim 2h\) are smoothed out. This is adequate for following the large-scale evolution of the density field, but not for resolving internal halo structure. Production codes address this with hybrid methods:

  • P\(^3\)M: PM for long-range forces + direct particle-particle (PP) for short-range corrections
  • TreePM: PM for long-range + tree algorithm for short-range
  • AMR-PM: Adaptive mesh refinement in high-density regions

For our purposes, a pure PM code is sufficient — we will test it against the Zel’dovich approximation, which operates well above the grid scale.

20.9 Implementation

We now implement the PM \(N\)-body code in JAX. We will work in 3D throughout, in a periodic box of comoving side \(L\) with \(N_g^3\) grid cells and \(N_p^3\) particles.

Code Units

Before writing any code, let us set up a unit system that simplifies the equations. We make three choices:

| Quantity | Code value | Physical value |
|---|---|---|
| Hubble constant \(H_0\) | 1 | \(H_0\) |
| Box size \(L\) | 1 | \(L_{\text{phys}}\) (comoving Mpc/\(h\)) |
| Particle mass \(m\) | 1 | \(m_p = \Omega_m \bar{\rho} L^3/N\) |

With \(H_0 = 1\), time is measured in units of \(1/H_0\) and the Friedmann equation becomes \(H(a) = a^{-3/2}\) (for EdS). With \(L = 1\), comoving positions run from 0 to 1 and wavenumbers are \(k = 2\pi n\) for integer \(n\). With \(m = 1\), the density deposited onto the grid is simply a particle count per cell.

The key derived relation: the mean matter density dilutes as \(\bar{\rho}(a) = \bar{\rho}_0 a^{-3}\) with \(4\pi G \bar{\rho}_0 = \frac{3}{2}\Omega_m H_0^2\), so the Poisson source \(4\pi G \bar{\rho} a^2 \delta\) becomes \(\frac{3}{2}\Omega_m \delta / a\) in code units.

Converting to physical units. A code velocity \(v_{\text{code}}\) (which has dimensions of \(L/t_{\text{code}} = L \cdot H_0\)) maps to a physical velocity as: \[ v_{\text{phys}} = v_{\text{code}} \times H_0 L_{\text{phys}} \] Similarly, code positions map as \(x_{\text{phys}} = x_{\text{code}} \times L_{\text{phys}}\), and code times as \(t_{\text{phys}} = t_{\text{code}} / H_0\). The conjugate momentum \(p = a^2 \dot{x}\) in code units becomes \(p_{\text{phys}} = p_{\text{code}} \times m_p H_0 L_{\text{phys}}\).
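
As a worked example of these conversions, the sketch below uses hypothetical run parameters (a 100 Mpc/\(h\) box and \(H_0 = 100\,h\) km/s/Mpc, neither taken from the text) and a hypothetical helper to_physical. With these choices the velocity unit \(H_0 L_{\text{phys}}\) is \(10^4\) km/s, and the factors of \(h\) cancel.

```python
# Hypothetical run parameters, chosen only for illustration.
L_phys = 100.0        # comoving box size in Mpc/h
H0 = 100.0            # Hubble constant in h km/s/Mpc

velocity_unit = H0 * L_phys      # km/s; the factors of h cancel

def to_physical(x_code, v_code):
    """Map a code position and code velocity to (Mpc/h, km/s)."""
    return x_code * L_phys, v_code * velocity_unit

x_phys, v_phys = to_physical(0.25, 0.5)
print(x_phys, v_phys)   # 25.0 5000.0
```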

Note: Why These Units?

This is the standard approach in \(N\)-body codes: factor out the overall scales (\(H_0\), \(L\), \(m_p\)) so the code only deals with dimensionless numbers of order unity. The physics is entirely in the cosmological functions (\(H(a)\), \(D(a_1,a_2)\), \(K(a_1,a_2)\)) and the Poisson prefactor \(\frac{3}{2}\Omega_m/a\). To run a “different” simulation (different box size, particle number, or \(H_0\)), you only change the conversion factors at output time — the code itself is identical.

Periodic Boundary Conditions

As discussed in the previous lecture, we simulate a cubic box of comoving side \(L\) with periodic boundary conditions: a particle leaving one face re-enters from the opposite face, and every particle sees an infinite periodic tiling of the box in all directions.

This is both a physical assumption and a computational convenience:

  • Physical: on scales much smaller than the horizon, the universe is (statistically) homogeneous. A periodic box is the natural way to enforce this — there are no boundaries, no edge effects, and the mean density is exactly \(\bar{\rho}\) by construction.
  • Computational: periodicity means the density field is a periodic function, so we can use the discrete Fourier transform (FFT) to solve the Poisson equation. This is what makes the PM method fast.

Periodicity also means we can only resolve modes with wavelengths \(\lambda \leq L\), i.e., wavenumbers \(k \geq 2\pi/L\). There is no information about fluctuations on scales larger than the box — these are implicitly assumed to vanish. This sets a minimum wavenumber (“fundamental mode”) \(k_f = 2\pi/L\).

In the code, periodicity enters in two places: (1) particle positions are wrapped via pos % L after every drift step, and (2) the CIC stencil uses modular arithmetic (idx + 1) % Ng so that particles near the boundary correctly deposit mass across the periodic boundary.
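
The two periodic operations can be illustrated with scalar arithmetic. This is a small sketch with assumed values of L and Ng; min_image is a hypothetical helper name for the shortest periodic separation, which is also what one needs when measuring displacements across the boundary.

```python
L, Ng = 1.0, 8           # assumed illustrative values

def wrap(x):
    """Map any coordinate into [0, L)."""
    return x % L

def min_image(dx):
    """Shortest periodic separation, in [-L/2, L/2]."""
    return dx - L * round(dx / L)

print(wrap(1.25))            # 0.25
print(wrap(-0.25))           # 0.75 (% returns a non-negative result for positive L)
print((Ng - 1 + 1) % Ng)     # 0: the upper CIC neighbour of the last cell is cell 0
print(min_image(0.875))      # -0.125
```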

Imports and Setup

Code
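import jax
import jax.numpy as jnp
import numpy as np                 # used by the plotting cells
import matplotlib.pyplot as plt    # used by the plotting cells
from jax import jit
from functools import partial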

Cosmological Functions (Einstein–de Sitter)

For an EdS universe (\(\Omega_m = 1\)), the Hubble parameter and drift/kick factors have simple closed-form expressions (derived in Section 1). Using these avoids any numerical integration.

Code
def hubble(a):
    """H(a) / H_0 for Einstein-de Sitter."""
    return a**(-1.5)

def drift_factor(a1, a2):
    """D(a1, a2) = 2 (1/sqrt(a1) - 1/sqrt(a2)) for EdS (H_0 = 1)."""
    return 2.0 * (1.0 / jnp.sqrt(a1) - 1.0 / jnp.sqrt(a2))

def kick_factor(a1, a2):
    """K(a1, a2) = (2/3) (a2^{3/2} - a1^{3/2}) for EdS (H_0 = 1)."""
    return (2.0 / 3.0) * (a2**1.5 - a1**1.5)
Note: Generalizing Beyond EdS

For a general \(\Lambda\)CDM cosmology, replace these three functions with numerical quadrature (e.g., jnp.trapezoid over a fine grid in \(a\)) or precomputed lookup tables. The rest of the code is unchanged — this is one of the advantages of the \(\ln a\) formulation.
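
A possible quadrature version is sketched below in plain Python. It assumes a flat matter plus \(\Lambda\) background with radiation neglected, and it assumes the integral definitions \(D = \int dt/a^2 = \int da/(a^3 E)\) and \(K = \int dt = \int da/(aE)\) with \(E = H/H_0\), which reproduce the EdS closed forms above. The function names and the simple trapezoid rule are illustrative choices, not the lecture's implementation.

```python
import math

def E(a, Omega_m, Omega_L):
    """H(a)/H_0 for flat matter + Lambda (radiation neglected)."""
    return math.sqrt(Omega_m * a**-3 + Omega_L)

def _trapz(func, a1, a2, n=2000):
    """Composite trapezoid rule for func on [a1, a2] with n panels."""
    da = (a2 - a1) / n
    total = 0.5 * (func(a1) + func(a2))
    for i in range(1, n):
        total += func(a1 + i * da)
    return total * da

def drift_factor_lcdm(a1, a2, Omega_m, Omega_L):
    """D = integral of dt / a^2 = integral of da / (a^3 E)."""
    return _trapz(lambda a: 1.0 / (a**3 * E(a, Omega_m, Omega_L)), a1, a2)

def kick_factor_lcdm(a1, a2, Omega_m, Omega_L):
    """K = integral of dt = integral of da / (a E)."""
    return _trapz(lambda a: 1.0 / (a * E(a, Omega_m, Omega_L)), a1, a2)

# With Omega_m = 1, Omega_L = 0 these should reproduce the EdS closed forms.
D_num = drift_factor_lcdm(0.1, 0.2, 1.0, 0.0)
K_num = kick_factor_lcdm(0.1, 0.2, 1.0, 0.0)
D_eds = 2.0 * (1.0 / math.sqrt(0.1) - 1.0 / math.sqrt(0.2))
K_eds = (2.0 / 3.0) * (0.2**1.5 - 0.1**1.5)
print(D_num - D_eds, K_num - K_eds)   # both differences should be small
```

Checking the EdS limit first is a cheap way to catch a wrong power of \(a\) in the integrands before switching to \(\Omega_\Lambda \neq 0\).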

CIC Mass Assignment

We now implement the CIC deposit. For each particle, we find the two nearest grid points in each dimension and distribute mass with linear weights.

The key subtlety is periodicity: particles near the box boundary wrap around, so their CIC stencil may span the periodic boundary. We handle this with modular arithmetic on the grid indices.

Code
def cic_deposit(pos, Ng, L):
    """
    CIC mass assignment: deposit unit-mass particles onto an Ng^3 grid.

    Parameters
    ----------
    pos : array, shape (N, 3)
        Particle positions in [0, L).
    Ng : int
        Grid size per dimension.
    L : float
        Box size.

    Returns
    -------
    rho : array, shape (Ng, Ng, Ng)
        Density field (particle count per cell).
    """
    h = L / Ng

    # Wrap positions into [0, L)
    pos = pos % L

    # Grid index of lower-left corner and fractional offset
    cell = (pos / h - 0.5)          # shifted so grid points are at centers
    idx = jnp.floor(cell).astype(int)
    dx = cell - idx                 # fractional distance, in [0, 1)

    # CIC weights: 1 - dx for the lower grid point, dx for the upper
    wx = jnp.stack([1.0 - dx, dx], axis=-1)  # shape (N, 3, 2)

    # Build the 2^3 = 8 contributions
    rho = jnp.zeros((Ng, Ng, Ng))
    for ii in range(2):
        for jj in range(2):
            for kk in range(2):
                weight = wx[:, 0, ii] * wx[:, 1, jj] * wx[:, 2, kk]
                ix = (idx[:, 0] + ii) % Ng
                iy = (idx[:, 1] + jj) % Ng
                iz = (idx[:, 2] + kk) % Ng
                rho = rho.at[ix, iy, iz].add(weight)

    return rho
Note: CIC and JAX Tracing

The triple loop over the 8 CIC corners is unrolled at trace time by JAX, producing 8 .at[].add() operations in the computation graph. This is fine because 8 is small and fixed. The actual particle loop is vectorized (all \(N\) particles are processed simultaneously via array operations). When we jit the simulation loop below, pm_force and these CIC functions are traced and compiled as part of the larger computation graph.
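
To see the weights concretely, here is a one-dimensional analogue of the deposit in plain NumPy, with the same cell-centered convention. The helper name cic_deposit_1d and the two particle positions are illustrative; np.add.at plays the role of .at[].add() in that it accumulates repeated indices.

```python
import numpy as np

def cic_deposit_1d(pos, Ng, L):
    """1D CIC deposit of unit-mass particles; grid points sit at cell centers."""
    h = L / Ng
    cell = (pos % L) / h - 0.5
    idx = np.floor(cell).astype(int)
    dx = cell - idx
    rho = np.zeros(Ng)
    # np.add.at accumulates repeated indices, like rho.at[...].add(...) in JAX.
    np.add.at(rho, idx % Ng, 1.0 - dx)
    np.add.at(rho, (idx + 1) % Ng, dx)
    return rho

Ng, L = 4, 1.0                      # assumed illustrative grid
pos = np.array([0.30, 0.05])        # second particle straddles the periodic boundary
rho = cic_deposit_1d(pos, Ng, L)
print(rho)          # approximately [1.0, 0.7, 0.0, 0.3]
print(rho.sum())    # total mass = number of particles
```

The particle at \(x = 0.05\) has lower index \(-1\), which wraps to the last cell; its mass is split 0.3 to cell 3 and 0.7 to cell 0.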

CIC Force Interpolation

The transpose operation: interpolate a grid-based field to particle positions using the same CIC weights.

Code
def cic_interp(field, pos, Ng, L):
    """
    CIC interpolation: read a grid field at particle positions.

    Parameters
    ----------
    field : array, shape (Ng, Ng, Ng)
        Field defined on the grid.
    pos : array, shape (N, 3)
        Particle positions in [0, L).

    Returns
    -------
    values : array, shape (N,)
        Interpolated field values at particle positions.
    """
    h = L / Ng
    pos = pos % L

    cell = (pos / h - 0.5)
    idx = jnp.floor(cell).astype(int)
    dx = cell - idx

    wx = jnp.stack([1.0 - dx, dx], axis=-1)

    values = jnp.zeros(pos.shape[0])
    for ii in range(2):
        for jj in range(2):
            for kk in range(2):
                weight = wx[:, 0, ii] * wx[:, 1, jj] * wx[:, 2, kk]
                ix = (idx[:, 0] + ii) % Ng
                iy = (idx[:, 1] + jj) % Ng
                iz = (idx[:, 2] + kk) % Ng
                values = values + weight * field[ix, iy, iz]

    return values

Green’s Function

We precompute the Fourier-space Green’s function that combines the Poisson solve and gradient in one step, as described in Section 3. This only needs to be computed once for a given grid.

Code
def make_green_function(Ng, L):
    """
    Precompute the Green's function for the PM force calculation.

    Returns the gradient of the inverse Laplacian in Fourier space:
    G_i(k) = D_i(k) / L(k), with discrete operators.

    Returns
    -------
    green_x, green_y, green_z : arrays, shape (Ng, Ng, Ng)
        Fourier-space Green's function for each force component.
    """
    h = L / Ng

    # Wavenumber indices
    k = jnp.fft.fftfreq(Ng, d=1.0/Ng)  # integer mode numbers 0..Ng/2-1, -Ng/2..-1
    kx, ky, kz = jnp.meshgrid(k, k, k, indexing='ij')

    # Discrete Laplacian: -sum_i (4/h^2) sin^2(pi k_i / Ng)
    laplacian = -(4.0 / h**2) * (
        jnp.sin(jnp.pi * kx / Ng)**2 +
        jnp.sin(jnp.pi * ky / Ng)**2 +
        jnp.sin(jnp.pi * kz / Ng)**2
    )

    # Avoid division by zero at k=0
    laplacian = laplacian.at[0, 0, 0].set(1.0)

    # Discrete gradient: (i/h) sin(2 pi k_i / Ng)
    grad_x = 1j / h * jnp.sin(2.0 * jnp.pi * kx / Ng)
    grad_y = 1j / h * jnp.sin(2.0 * jnp.pi * ky / Ng)
    grad_z = 1j / h * jnp.sin(2.0 * jnp.pi * kz / Ng)

    # Green's function: G_i = D_i / L
    green_x = grad_x / laplacian
    green_y = grad_y / laplacian
    green_z = grad_z / laplacian

    # Zero the k=0 mode (no mean force)
    green_x = green_x.at[0, 0, 0].set(0.0)
    green_y = green_y.at[0, 0, 0].set(0.0)
    green_z = green_z.at[0, 0, 0].set(0.0)

    return green_x, green_y, green_z

PM Force Calculation

Now we assemble the full force pipeline: CIC deposit → FFT → Green’s function → IFFT → CIC interpolate.

Code
def pm_force(pos, a, green_x, green_y, green_z, Ng, L, Omega_m):
    """
    Compute -grad(Phi) at each particle position via the PM method.

    The kick step is: dp = -grad(Phi) * K(a1, a2).
    This function returns -grad(Phi), including the Poisson prefactor
    (3/2) Omega_m / a.  The caller multiplies by the kick factor.

    Parameters
    ----------
    pos : array, shape (N, 3)
        Particle positions in [0, L).
    a : float
        Scale factor.
    green_x, green_y, green_z : arrays
        Precomputed Green's functions.
    Ng : int
        Grid size per dimension.
    L : float
        Box size.
    Omega_m : float
        Matter density parameter.

    Returns
    -------
    force : array, shape (N, 3)
        -grad(Phi) at each particle position.
    """
    # 1. CIC deposit
    rho = cic_deposit(pos, Ng, L)

    # 2. Overdensity: delta = rho / rho_bar - 1, where rho_bar = N / Ng^3
    rho_bar = pos.shape[0] / Ng**3
    delta = rho / rho_bar - 1.0

    # 3. Forward FFT
    delta_hat = jnp.fft.fftn(delta)

    # 4. Poisson prefactor: (3/2) Omega_m / a  (in H_0=1 units)
    prefactor = 1.5 * Omega_m / a

    # 5. Force in Fourier space: -prefactor * G_i * delta_hat
    fx_hat = -prefactor * green_x * delta_hat
    fy_hat = -prefactor * green_y * delta_hat
    fz_hat = -prefactor * green_z * delta_hat

    # 6. Inverse FFT (force fields are real)
    fx = jnp.fft.ifftn(fx_hat).real
    fy = jnp.fft.ifftn(fy_hat).real
    fz = jnp.fft.ifftn(fz_hat).real

    # 7. CIC interpolate to particle positions
    ax = cic_interp(fx, pos, Ng, L)
    ay = cic_interp(fy, pos, Ng, L)
    az = cic_interp(fz, pos, Ng, L)

    return jnp.stack([ax, ay, az], axis=-1)
Note: Units Check

Let us verify the units. The Poisson equation in our \(H_0 = 1\) units is: \[ \hat{L}(\mathbf{k}) \hat{\Phi} = \frac{3}{2}\Omega_m \frac{1}{a} \hat{\delta} \] The Green’s function \(\hat{G}_i = \hat{D}_i / \hat{L}\) has dimensions of length (since \(\hat{D}_i \sim 1/h\) and \(\hat{L} \sim 1/h^2\)). So \(\widehat{\nabla_i\Phi} = (3\Omega_m / 2a) \, \hat{G}_i \, \hat{\delta}\) has dimensions of \(H_0^2 \times \text{length}\), which is an acceleration. This is what the kick equation \(\Delta p = -m \nabla\Phi \cdot K(a_1, a_2)\) requires, since \(K\) has dimensions of time.

The KDK Integrator

Finally, we assemble the KDK leapfrog step. Each step advances the system from \(a_n\) to \(a_{n+1}\).

Code
def kdk_step(pos, mom, a_start, a_end, green_x, green_y, green_z,
             Ng, L, Omega_m):
    """
    One KDK leapfrog step from a_start to a_end.

    Parameters
    ----------
    pos : array, shape (N, 3)
        Comoving positions.
    mom : array, shape (N, 3)
        Conjugate momenta (p = m a^2 dx/dt; m=1 here).
    a_start, a_end : float
        Scale factor at start and end of step.
    green_x, green_y, green_z : arrays
        Precomputed Green's functions.

    Returns
    -------
    pos_new, mom_new : arrays
        Updated positions and momenta.
    """
    a_mid = 0.5 * (a_start + a_end)

    # Half kick
    K1 = kick_factor(a_start, a_mid)
    force = pm_force(pos, a_start, green_x, green_y, green_z, Ng, L, Omega_m)
    mom = mom + force * K1

    # Full drift
    D = drift_factor(a_start, a_end)
    pos = pos + mom * D
    pos = pos % L   # periodic wrapping

    # Half kick
    K2 = kick_factor(a_mid, a_end)
    force = pm_force(pos, a_end, green_x, green_y, green_z, Ng, L, Omega_m)
    mom = mom + force * K2

    return pos, mom
Note: Force Evaluation Timing

Notice that the force in the first half-kick is evaluated at \(a_{\text{start}}\) and the force in the second half-kick at \(a_{\text{end}}\). Strictly, the kick integral \(K(a_1, a_2)\) assumes the force is constant over that interval, so the “correct” scale factor to evaluate at is somewhere in between. Using the endpoint values is the standard approximation — the error is absorbed into the \(\mathcal{O}(\Delta a^2)\) truncation error of the leapfrog scheme.

In a more careful implementation, the second half-kick of step \(n\) and the first half-kick of step \(n+1\) share the same force evaluation (since both use the force at \(a_{n+1}\)). We have not optimized for this here, computing the force twice. We will fix this in the simulation loop below.

The Simulation Loop

We now write the main simulation loop that advances the system from an initial scale factor \(a_i\) to a final scale factor \(a_f\) over \(N_{\text{steps}}\) equal steps in \(\ln a\).

Code
@partial(jit, static_argnums=(4, 8))
def run_simulation(pos_init, mom_init, a_start, a_end, n_steps,
                   green_x, green_y, green_z, Ng, L, Omega_m):
    """
    Run PM N-body simulation from a_start to a_end.

    Uses equally-spaced steps in ln(a), and avoids redundant force
    evaluations by merging half-kicks. The entire simulation is
    JIT-compiled using jax.lax.fori_loop.

    Parameters
    ----------
    n_steps : int (static)
        Number of timesteps.
    Ng : int (static)
        Grid size per dimension.

    Returns
    -------
    pos, mom : arrays
        Final positions and momenta.
    """
    # Steps equally spaced in ln(a)
    a_values = jnp.exp(jnp.linspace(jnp.log(a_start), jnp.log(a_end), n_steps + 1))
    a_mid = 0.5 * (a_values[:-1] + a_values[1:])  # midpoints, length n_steps

    # Precompute all drift factors: D(a_n, a_{n+1})
    D_arr = jax.vmap(drift_factor)(a_values[:-1], a_values[1:])

    # Precompute all kick factors for the merged scheme.
    # Step i kicks from a_mid[i] to a_mid[i+1],
    # except the last step kicks from a_mid[-1] to a_values[-1].
    kick_ends = jnp.concatenate([a_mid[1:], a_values[-1:]])
    K_arr = jax.vmap(kick_factor)(a_mid, kick_ends)

    # Initial half-kick
    K_init = kick_factor(a_values[0], a_mid[0])
    force = pm_force(pos_init, a_values[0], green_x, green_y, green_z, Ng, L, Omega_m)
    mom_init = mom_init + force * K_init

    def body(i, carry):
        pos, mom = carry
        # Full drift
        pos = pos + mom * D_arr[i]
        pos = pos % L
        # Force at a_{n+1}, then merged kick
        force = pm_force(pos, a_values[i + 1], green_x, green_y, green_z, Ng, L, Omega_m)
        mom = mom + force * K_arr[i]
        return pos, mom

    pos, mom = jax.lax.fori_loop(0, n_steps, body, (pos_init, mom_init))
    return pos, mom
Important: JIT-Compiled Simulation Loop

The simulation is compiled end-to-end using @jit and jax.lax.fori_loop. The Python for loop is replaced by XLA’s loop primitive, which avoids Python overhead and enables hardware acceleration.

The merged-kick optimization is implemented by precomputing the kick factors: loop iteration \(i\) evaluates the force at \(a_{i+1}\) and applies the merged kick \(K(a_{i+1/2}, a_{i+3/2})\), except the final iteration, which closes with \(K(a_{N-1/2}, a_N)\). This roughly halves the number of force evaluations compared to a naive KDK implementation (\(N+1\) instead of \(2N\) for \(N\) steps).

The n_steps and Ng arguments are marked static — changing them triggers recompilation, but within a run the loop executes as a single compiled computation. The first call incurs a compilation cost; subsequent calls with the same n_steps and Ng reuse the compiled code.
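
The merged-kick schedule can be checked in isolation. The sketch below (plain NumPy, using the EdS kick factor from the text and a deliberately small, assumed number of steps) rebuilds the kick intervals and confirms that they tile \([a_{\text{start}}, a_{\text{end}}]\) without gaps or overlaps, so the kicks sum to \(K(a_{\text{start}}, a_{\text{end}})\).

```python
import numpy as np

def kick_factor(a1, a2):
    """EdS kick factor from the text (H_0 = 1)."""
    return (2.0 / 3.0) * (a2**1.5 - a1**1.5)

a_start, a_end, n_steps = 0.1, 1.5, 5      # small n_steps for readability
a_values = np.exp(np.linspace(np.log(a_start), np.log(a_end), n_steps + 1))
a_mid = 0.5 * (a_values[:-1] + a_values[1:])

# Same schedule as in the simulation loop: midpoint to next midpoint,
# with the last kick closing at a_end.
kick_ends = np.concatenate([a_mid[1:], a_values[-1:]])
K_init = kick_factor(a_values[0], a_mid[0])
K_arr = kick_factor(a_mid, kick_ends)

total_kick = K_init + K_arr.sum()
n_force_merged, n_force_naive = n_steps + 1, 2 * n_steps
print(total_kick, kick_factor(a_start, a_end), n_force_merged, n_force_naive)
```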

Utility: Uniform Particle Grid

Code
def make_uniform_grid(Np, L):
    """Create a uniform grid of Np^3 particles in a box of side L."""
    x = jnp.linspace(0, L, Np, endpoint=False) + L / (2 * Np)
    grid = jnp.meshgrid(x, x, x, indexing='ij')
    pos = jnp.stack([g.ravel() for g in grid], axis=-1)
    return pos

20.10 Test: Single Sine Wave vs Zel’dovich

We now test our PM code against the Zel’dovich approximation, the first-order Lagrangian solution of the equations of motion, which for a one-dimensional perturbation is exact until shell crossing. This is the standard first test for any \(N\)-body code.

The Zel’dovich Approximation

The Zel’dovich approximation describes the displacement of particles from a uniform grid under a growing-mode perturbation. In Lagrangian coordinates \(\mathbf{q}\) (the initial grid positions), the comoving position at time \(a\) is: \[ \mathbf{x}(\mathbf{q}, a) = \mathbf{q} + D_+(a) \, \boldsymbol{\Psi}(\mathbf{q}) \] where \(D_+(a)\) is the linear growth factor and \(\boldsymbol{\Psi}(\mathbf{q})\) is the displacement field, related to the initial overdensity by \(\nabla \cdot \boldsymbol{\Psi} = -\delta_0(\mathbf{q})\).

The velocity (conjugate momentum) follows from differentiating: \[ \mathbf{p} = m a^2 \dot{\mathbf{x}} = m a^2 \dot{D}_+(a) \, \boldsymbol{\Psi}(\mathbf{q}) \] where \(\dot{D}_+ = dD_+/dt\).

For an Einstein–de Sitter universe, the growing mode is simply \(D_+(a) = a\), and using \(H = H_0 a^{-3/2}\) we get \(\dot{D}_+ = \dot{a} = H a = H_0 a^{-1/2}\), so: \[ \mathbf{p} = m a^2 H_0 a^{-1/2} \boldsymbol{\Psi} = m H_0 a^{3/2} \boldsymbol{\Psi} \]

In our code units (\(H_0 = 1\), \(m = 1\)): \[ \mathbf{x}(\mathbf{q}, a) = \mathbf{q} + a \, \boldsymbol{\Psi}(\mathbf{q}) \qquad \mathbf{p}(\mathbf{q}, a) = a^{3/2} \, \boldsymbol{\Psi}(\mathbf{q}) \]
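
These expressions allow a check of the integrator that does not involve the mesh at all. Along the Zel’dovich trajectory in EdS, \(dp/dt = \frac{3}{2}a^{1/2}\dot{a}\,\boldsymbol{\Psi} = \frac{3}{2}\boldsymbol{\Psi}\), so each particle feels a constant force. The sketch below (plain Python; the value of psi and the helper kdk_zeldovich are illustrative assumptions) feeds that force to a KDK loop with the EdS drift and kick factors and compares against \(x - q = a\Psi\), \(p = a^{3/2}\Psi\).

```python
import math

def drift_factor(a1, a2):
    """EdS drift factor (H_0 = 1)."""
    return 2.0 * (1.0 / math.sqrt(a1) - 1.0 / math.sqrt(a2))

def kick_factor(a1, a2):
    """EdS kick factor (H_0 = 1)."""
    return (2.0 / 3.0) * (a2**1.5 - a1**1.5)

def kdk_zeldovich(psi, a_init, a_final, n_steps):
    """Integrate one particle's displacement with KDK, using force = 1.5 * psi."""
    x, p = a_init * psi, a_init**1.5 * psi      # displacement x - q and momentum
    force = 1.5 * psi                            # constant along the Zel'dovich orbit
    dlna = (math.log(a_final) - math.log(a_init)) / n_steps
    for i in range(n_steps):
        a1 = a_init * math.exp(i * dlna)
        a2 = a_init * math.exp((i + 1) * dlna)
        a_mid = 0.5 * (a1 + a2)
        p += force * kick_factor(a1, a_mid)      # half kick
        x += p * drift_factor(a1, a2)            # full drift
        p += force * kick_factor(a_mid, a2)      # half kick
    return x, p

psi, a_init, a_final = -0.04, 0.1, 1.5           # assumed illustrative values
err = {}
for n in (50, 100):
    x, p = kdk_zeldovich(psi, a_init, a_final, n)
    err[n] = (abs(x - a_final * psi), abs(p - a_final**1.5 * psi))
ratio = err[50][0] / err[100][0]
print(err[50], err[100], ratio)
```

The momentum is reproduced to round-off because the kicks telescope, while the position error should drop by roughly a factor of four when the number of steps doubles, as expected for a second-order scheme.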

Single Plane-Wave Perturbation

The simplest test case is a single sinusoidal perturbation along one axis. Take the displacement field: \[ \boldsymbol{\Psi}(\mathbf{q}) = -\frac{A}{k} \sin(k q_x) \, \hat{\mathbf{x}} \] where \(k = 2\pi n / L\) for some integer mode number \(n\), and \(A\) is the amplitude of the initial overdensity: \[ \delta_0(\mathbf{q}) = -\nabla \cdot \boldsymbol{\Psi} = A \cos(k q_x) \]

The Zel’dovich solution is exact (for this 1D perturbation) until shell crossing — the moment when particle trajectories intersect. To see when this happens, consider the Jacobian of the Lagrangian-to-Eulerian map: \[ \frac{\partial x}{\partial q_x} = 1 + a \frac{d\Psi_x}{dq_x} = 1 - a A \cos(k q_x) \] Mass conservation requires \(\bar{\rho}\,dq = \rho\,dx\) (the mass in a Lagrangian volume element is conserved), so the Eulerian density is \[ \rho = \frac{\bar{\rho}}{|\partial x / \partial q_x|} \] When \(\partial x / \partial q_x \to 0\), the density diverges — neighboring particles that started at \(q\) and \(q + dq\) have been compressed to the same Eulerian position \(x\). Their trajectories cross, forming a caustic. After this point, multiple particle streams overlap at the same location and the single-valued Zel’dovich map breaks down.

The first crossing occurs when \(\cos(k q_x) = 1\) and \(aA = 1\), i.e., at \(a_{\text{cross}} = 1/A\). We will choose \(A\) and the final scale factor such that we stay well before shell crossing.
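
The approach to shell crossing can be tabulated directly from the Jacobian. The sketch below uses the amplitude and mode number adopted for the test in this section (\(A = 0.5\), \(n = 2\)); the helper names and the sampling resolution are illustrative.

```python
import math

A, n_mode, L = 0.5, 2, 1.0          # amplitude and mode number used for the test
k = 2.0 * math.pi * n_mode / L

def jacobian(q_x, a):
    """dx/dq for the single-mode Zel'dovich map."""
    return 1.0 - a * A * math.cos(k * q_x)

def peak_density(a, n_samples=1000):
    """Largest rho / rho_bar = 1 / |dx/dq| over the box (valid before crossing)."""
    qs = [L * i / n_samples for i in range(n_samples)]
    return max(1.0 / abs(jacobian(q, a)) for q in qs)

a_cross = 1.0 / A
for a in (0.1, 1.0, 1.5, 1.9):
    print(a, peak_density(a))
print("shell crossing at a =", a_cross)
```

At \(a = 1.5\) the peak density is already \(4\bar{\rho}\), so the test probes a mildly nonlinear regime even though it stops before the caustic forms.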

Note: Why This Is an Exact Test

For a single Fourier mode along \(\hat{x}\), the Zel’dovich approximation is actually the exact solution to the full nonlinear equations of motion (before shell crossing), not just a linear approximation.

The physical picture is illuminating. A perturbation purely along \(\hat{x}\) means particles only move in \(x\) — they form infinite 2D sheets (uniform in \(y\) and \(z\)) that slide back and forth along \(\hat{x}\). The gravitational force on a sheet depends only on how many sheets are to its left vs. right (just like parallel plates in electrostatics). In a periodic box, this count is determined purely by the ordering of the sheets, which is preserved until shell crossing. So as long as no sheets pass through each other, each sheet feels a force that depends only on its Lagrangian label \(q_x\) — not on the detailed positions of other sheets. The sheets evolve independently, and the Zel’dovich solution describes their motion exactly.

Mathematically: the Poisson equation is linear, so the force from a sinusoidal \(\delta\) is sinusoidal, and the resulting motion preserves the sinusoidal structure. Any disagreement between the simulation and the Zel’dovich prediction (before shell crossing) beyond the expected discretization error from the finite grid and timestep points to a bug.

Setting Up the Test

Code
# Simulation parameters
Ng = 64          # grid cells per dimension
Np = 64          # particles per dimension (one particle per cell)
L = 1.0          # box size (code units)
Omega_m = 1.0    # Einstein-de Sitter

# Precompute Green's function for this grid
green_x, green_y, green_z = make_green_function(Ng, L)

# Quick sanity check: uniform grid should give constant density
pos_test = make_uniform_grid(Np, L)
rho_test = cic_deposit(pos_test, Ng, L)
print(f"Sanity check: {Np}^3 = {Np**3} particles on {Ng}^3 grid")
print(f"  Density: min={rho_test.min():.4f}, max={rho_test.max():.4f} "
      f"(expected {(Np/Ng)**3:.4f})")

# Zel'dovich test parameters
n_mode = 2          # mode number (2 full wavelengths across the box)
A = 0.5             # initial amplitude (shell crossing at a = 1/A = 2)
k_mode = 2 * jnp.pi * n_mode / L

# Initial scale factor and final (before shell crossing)
a_init = 0.1
a_final = 1.5       # a_final * A = 0.75, before shell crossing at a=2

# Number of timesteps
n_steps = 100

print(f"\nMode: n={n_mode}, k={k_mode:.4f}")
print(f"Amplitude: A={A}")
print(f"Shell crossing at a = {1/A:.1f}")
print(f"Running from a={a_init} to a={a_final} ({n_steps} steps)")
print(f"Final displacement amplitude: a*A/k = {a_final * A / k_mode:.4f} "
      f"(in units of L)")
Sanity check: 64^3 = 262144 particles on 64^3 grid
  Density: min=1.0000, max=1.0000 (expected 1.0000)

Mode: n=2, k=12.5664
Amplitude: A=0.5
Shell crossing at a = 2.0
Running from a=0.1 to a=1.5 (100 steps)
Final displacement amplitude: a*A/k = 0.0597 (in units of L)

Zel’dovich Initial Conditions

Code
def make_zeldovich_ic(Np, L, n_mode, A, a_init):
    """
    Create Zel'dovich initial conditions for a single sine-wave mode.

    Parameters
    ----------
    Np : int
        Particles per dimension.
    L : float
        Box size.
    n_mode : int
        Mode number (number of wavelengths across the box).
    A : float
        Initial overdensity amplitude.
    a_init : float
        Initial scale factor.

    Returns
    -------
    pos : array, shape (Np^3, 3)
        Initial positions.
    mom : array, shape (Np^3, 3)
        Initial momenta.
    q : array, shape (Np^3, 3)
        Lagrangian (grid) positions.
    """
    # Lagrangian grid
    q = make_uniform_grid(Np, L)
    k = 2 * jnp.pi * n_mode / L

    # Displacement field: Psi_x = -(A/k) sin(k q_x)
    psi_x = -(A / k) * jnp.sin(k * q[:, 0])

    # Positions: x = q + D_+(a) * Psi = q + a * Psi  (EdS: D_+ = a)
    pos = q.at[:, 0].add(a_init * psi_x)
    pos = pos % L  # periodic wrapping

    # Momenta: p = a^{3/2} * Psi  (EdS, H_0=1, m=1)
    mom = jnp.zeros_like(q)
    mom = mom.at[:, 0].set(a_init**1.5 * psi_x)

    return pos, mom, q

pos_init, mom_init, q_grid = make_zeldovich_ic(Np, L, n_mode, A, a_init)

Zel’dovich Prediction at Final Time

Code
def zeldovich_prediction(q, n_mode, A, a, L):
    """Zel'dovich position and momentum at scale factor a."""
    k = 2 * jnp.pi * n_mode / L
    psi_x = -(A / k) * jnp.sin(k * q[:, 0])

    pos_exact = q.at[:, 0].add(a * psi_x) % L
    mom_exact = jnp.zeros_like(q)
    mom_exact = mom_exact.at[:, 0].set(a**1.5 * psi_x)

    return pos_exact, mom_exact

Running the Simulation

Code
# Run to final time
pos_final, mom_final = run_simulation(
    pos_init, mom_init, a_init, a_final, n_steps,
    green_x, green_y, green_z, Ng, L, Omega_m
)

# Zel'dovich prediction at final time
pos_exact, mom_exact = zeldovich_prediction(q_grid, n_mode, A, a_final, L)

# --- Helpers for visualization ---

# Extract 1D slice: particles along the x-axis with fixed q_y, q_z.
# With indexing='ij', stride by Np^2 to vary x at fixed y,z.
slice_idx = jnp.arange(0, Np**3, Np**2)
q_x = q_grid[slice_idx, 0]
sort_idx = jnp.argsort(q_x)
q_x_sorted = q_x[sort_idx]

def displacement(pos, q, L):
    """x-displacement from grid, with periodic unwrapping."""
    dx = pos[:, 0] - q[:, 0]
    dx = dx - L * jnp.round(dx / L)
    return dx

Time Evolution

Before comparing in detail at the final time, let us visualize how the displacement evolves. We run to several intermediate scale factors and compare with the Zel’dovich prediction at each epoch.

Code
a_snapshots = [0.2, 0.5, 1.0, 1.5]
n_steps_snap = 100

fig, ax = plt.subplots(1, 1, figsize=(10, 5))

colors = plt.cm.viridis(np.linspace(0.15, 0.85, len(a_snapshots)))

for a_snap, color in zip(a_snapshots, colors):
    # Run simulation to this epoch
    pos_s, mom_s = run_simulation(
        pos_init, mom_init, a_init, a_snap, n_steps_snap,
        green_x, green_y, green_z, Ng, L, Omega_m
    )

    # Zel'dovich prediction
    pos_z, _ = zeldovich_prediction(q_grid, n_mode, A, a_snap, L)

    # Extract 1D slice
    dx_s = displacement(pos_s, q_grid, L)[slice_idx][sort_idx]
    dx_z = displacement(pos_z, q_grid, L)[slice_idx][sort_idx]

    ax.plot(q_x_sorted, dx_z, '-', color=color, lw=2)
    ax.plot(q_x_sorted, dx_s, '.', color=color, ms=3, alpha=0.7,
            label=f'$a = {a_snap}$')

ax.set_xlabel('$q_x$ (Lagrangian coordinate)')
ax.set_ylabel('$x - q_x$ (displacement)')
ax.legend(ncol=2)
ax.set_title(f'Displacement evolution: $n={n_mode}$, $A={A}$, '
             f'shell crossing at $a={1/A:.0f}$')
ax.axhline(0, color='0.7', lw=0.5)
plt.tight_layout()
plt.show()
Figure 20.3: Time evolution of the x-displacement for a single sine-wave perturbation. Solid lines show the exact Zel’dovich solution; dots show the PM simulation. The displacement grows linearly with \(a\) (since \(D_+ = a\) in EdS) and the sinusoidal shape steepens as the system approaches shell crossing at \(a = 1/A = 2\).

Comparing Results at Final Time

We compare the simulation output against the Zel’dovich prediction. Since the perturbation is along \(x\) only, we look at the \(x\)-displacement and \(x\)-momentum as functions of the Lagrangian coordinate \(q_x\).

Code
fig, axes = plt.subplots(2, 1, figsize=(10, 7), sharex=True)

dx_sim = displacement(pos_final, q_grid, L)[slice_idx][sort_idx]
dx_exact = displacement(pos_exact, q_grid, L)[slice_idx][sort_idx]

# Momentum
px_sim = mom_final[slice_idx, 0][sort_idx]
px_exact = mom_exact[slice_idx, 0][sort_idx]

# --- Displacement ---
ax = axes[0]
ax.plot(q_x_sorted, dx_exact, '-', color='C1', lw=2, label='Zel\'dovich (exact)')
ax.plot(q_x_sorted, dx_sim, '.', color='C0', ms=3, alpha=0.7, label='PM simulation')
ax.set_ylabel('$x - q_x$ (displacement)')
ax.legend()
ax.set_title(f'Single sine wave: $n={n_mode}$, $A={A}$, '
             f'$a_i={a_init}$ → $a_f={a_final}$, {n_steps} steps')

# --- Momentum ---
ax = axes[1]
ax.plot(q_x_sorted, px_exact, '-', color='C1', lw=2, label='Zel\'dovich (exact)')
ax.plot(q_x_sorted, px_sim, '.', color='C0', ms=3, alpha=0.7, label='PM simulation')
ax.set_ylabel('$p_x$ (momentum)')
ax.set_xlabel('$q_x$ (Lagrangian coordinate)')
ax.legend()

plt.tight_layout()
plt.show()
Figure 20.4: Comparison of the PM simulation (blue dots) against the exact Zel’dovich solution (orange line) for a single sine-wave perturbation. Top: x-displacement from the grid. Bottom: x-momentum. The agreement validates the CIC deposit, Poisson solver, force interpolation, and KDK integrator.
Code
fig, axes = plt.subplots(2, 1, figsize=(10, 5), sharex=True)

dx_err = (dx_sim - dx_exact)
px_err = (px_sim - px_exact)

ax = axes[0]
ax.plot(q_x_sorted, dx_err, '.', color='C0', ms=3, alpha=0.7)
ax.axhline(0, color='0.5', lw=0.5)
ax.set_ylabel('$\\Delta(x - q_x)$')
ax.set_title('Absolute error relative to Zel\'dovich solution')

ax = axes[1]
ax.plot(q_x_sorted, px_err, '.', color='C0', ms=3, alpha=0.7)
ax.axhline(0, color='0.5', lw=0.5)
ax.set_ylabel('$\\Delta p_x$')
ax.set_xlabel('$q_x$ (Lagrangian coordinate)')

plt.tight_layout()
plt.show()

# Summary statistics
print(f"Displacement: max |error| = {jnp.max(jnp.abs(dx_err)):.2e}")
print(f"Momentum:     max |error| = {jnp.max(jnp.abs(px_err)):.2e}")
Figure 20.5: Absolute error in displacement and momentum relative to the exact Zel’dovich solution. The errors combine both time-stepping error and spatial PM error (CIC window, discrete grid). See the convergence tests below to disentangle these contributions.
Displacement: max |error| = 3.58e-03
Momentum:     max |error| = 4.30e-03

The maximum errors are a few times \(10^{-3}\) in both displacement and momentum. They have two sources: (1) time-stepping error from the KDK integrator, and (2) spatial error from the PM force calculation (CIC window function, discrete grid). We disentangle these below.

Convergence with Timestep

Let us verify the second-order convergence explicitly. A subtlety: if we compare against the exact Zel’dovich solution, the error is dominated by the spatial PM error (CIC window, discrete particle sampling) rather than the temporal error from the integrator. These spatial errors are independent of timestep count, creating a floor.

To isolate the time-stepping error, we compare each run against a high-resolution reference run (1000 steps) using the same PM code. The spatial errors cancel in the difference, leaving only the integrator error.

Code
# High-resolution reference
pos_ref, mom_ref = run_simulation(
    pos_init, mom_init, a_init, a_final, 1000,
    green_x, green_y, green_z, Ng, L, Omega_m
)
dx_ref = displacement(pos_ref, q_grid, L)[slice_idx][sort_idx]
px_ref = mom_ref[slice_idx, 0][sort_idx]

step_counts = [5, 10, 20, 50, 100, 200]
dx_errors = []
px_errors = []

for ns in step_counts:
    pos_f, mom_f = run_simulation(
        pos_init, mom_init, a_init, a_final, ns,
        green_x, green_y, green_z, Ng, L, Omega_m
    )
    dx_s = displacement(pos_f, q_grid, L)[slice_idx][sort_idx]
    px_s = mom_f[slice_idx, 0][sort_idx]

    dx_errors.append(float(jnp.max(jnp.abs(dx_s - dx_ref))))
    px_errors.append(float(jnp.max(jnp.abs(px_s - px_ref))))

fig, ax = plt.subplots(1, 1, figsize=(8, 5))
ax.loglog(step_counts, dx_errors, 'o-', color='C0', label='max $|\\Delta x|$')
ax.loglog(step_counts, px_errors, 's-', color='C1', label='max $|\\Delta p_x|$')

# Reference line: N^{-2}
ns_ref = np.array(step_counts, dtype=float)
ax.loglog(ns_ref, dx_errors[0] * (ns_ref[0] / ns_ref)**2, '--',
          color='0.5', label='$\\propto N_{\\mathrm{steps}}^{-2}$')

ax.set_xlabel('Number of timesteps')
ax.set_ylabel('Maximum absolute error (vs 1000-step reference)')
ax.legend()
ax.set_title(f'Time-stepping convergence: mode $n={n_mode}$, $A={A}$, '
             f'$a: {a_init}$ → ${a_final}$')
plt.tight_layout()
plt.show()
Figure 20.6: Convergence of the KDK leapfrog integrator. Errors are measured relative to a high-resolution reference run (1000 steps) to isolate the time-stepping error from the spatial PM error. Both displacement and momentum errors decrease as \(N_{\text{steps}}^{-2}\), confirming second-order convergence.
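The reference-run technique can be tried on a problem small enough to check by hand. The sketch below applies a kick-drift-kick leapfrog to a unit harmonic oscillator (a stand-in for the PM force, not the lecture's `run_simulation`) and fits the slope of the error against a 1000-step reference. The function name `kdk_oscillator` and the choice of test problem are assumptions.

```python
import numpy as np

def kdk_oscillator(x0, p0, t_end, n_steps):
    """Kick-drift-kick leapfrog for H = p^2/2 + x^2/2 (force = -x)."""
    dt = t_end / n_steps
    x, p = x0, p0
    for _ in range(n_steps):
        p -= 0.5 * dt * x    # half kick
        x += dt * p          # full drift
        p -= 0.5 * dt * x    # half kick
    return x, p

# Fine reference run with the same integrator
x_ref, _ = kdk_oscillator(1.0, 0.0, 1.0, 1000)

step_counts = np.array([5, 10, 20, 50, 100, 200])
errors = np.array([abs(kdk_oscillator(1.0, 0.0, 1.0, n)[0] - x_ref)
                   for n in step_counts])

# Slope of log(error) vs log(n_steps); close to -2 for a second-order scheme
slope = np.polyfit(np.log(step_counts), np.log(errors), 1)[0]
print(f"fitted convergence slope: {slope:.2f}")
```

Because the reference itself carries a small error, the fitted slope is only approximately \(-2\); the deviation grows as the step count approaches that of the reference.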

Convergence with Grid Resolution

The timestep convergence test above isolates the integrator error by comparing against a reference run with the same spatial resolution. But how large is the spatial PM error itself? We can test this by running the same problem at increasing grid resolution \(N_g = N_p\) (keeping 1 particle per cell and using enough timesteps that the time-stepping error is negligible) and comparing against the exact Zel’dovich solution.

Code
grid_sizes = [32, 64, 128]
n_steps_hires = 500  # enough steps to make time-stepping error negligible

dx_spatial_errors = []
px_spatial_errors = []

for N in grid_sizes:
    # Set up at this resolution
    green_xi, green_yi, green_zi = make_green_function(N, L)
    pos_i, mom_i, q_i = make_zeldovich_ic(N, L, n_mode, A, a_init)
    slice_i = jnp.arange(0, N**3, N**2)

    # Run simulation
    pos_f, mom_f = run_simulation(
        pos_i, mom_i, a_init, a_final, n_steps_hires,
        green_xi, green_yi, green_zi, N, L, Omega_m
    )

    # Exact solution
    pos_e, mom_e = zeldovich_prediction(q_i, n_mode, A, a_final, L)

    # Errors on the 1D slice
    q_xi = q_i[slice_i, 0]
    sort_i = jnp.argsort(q_xi)

    dx_f = displacement(pos_f, q_i, L)[slice_i][sort_i]
    dx_e = displacement(pos_e, q_i, L)[slice_i][sort_i]
    px_f = mom_f[slice_i, 0][sort_i]
    px_e = mom_e[slice_i, 0][sort_i]

    dx_spatial_errors.append(float(jnp.max(jnp.abs(dx_f - dx_e))))
    px_spatial_errors.append(float(jnp.max(jnp.abs(px_f - px_e))))

    print(f"Ng=Np={N:3d}: max |Δx| = {dx_spatial_errors[-1]:.2e}, "
          f"max |Δp| = {px_spatial_errors[-1]:.2e}")

fig, ax = plt.subplots(1, 1, figsize=(8, 5))
ax.loglog(grid_sizes, dx_spatial_errors, 'o-', color='C0',
          label='max $|\\Delta x|$')
ax.loglog(grid_sizes, px_spatial_errors, 's-', color='C1',
          label='max $|\\Delta p_x|$')

# Reference line: N^{-1}
ns_ref = np.array(grid_sizes, dtype=float)
ax.loglog(ns_ref, dx_spatial_errors[0] * (ns_ref[0] / ns_ref)**1, '--',
          color='0.5', label='$\\propto N_g^{-1}$')

ax.set_xlabel('Grid size $N_g = N_p$')
ax.set_ylabel('Maximum absolute error (vs exact Zel\'dovich)')
ax.legend()
ax.set_title(f'Spatial convergence: mode $n={n_mode}$, $A={A}$, '
             f'$a: {a_init}$ → ${a_final}$, {n_steps_hires} steps')
plt.tight_layout()
plt.show()
Ng=Np= 32: max |Δx| = 7.50e-03, max |Δp| = 9.23e-03
Ng=Np= 64: max |Δx| = 3.58e-03, max |Δp| = 4.30e-03
Ng=Np=128: max |Δx| = 1.77e-03, max |Δp| = 2.11e-03
Figure 20.7: Spatial convergence of the PM force. Each run uses 500 timesteps (so time-stepping error is negligible) at increasing grid resolution. The error against the exact Zel’dovich solution decreases with resolution, demonstrating that the PM spatial error is the dominant contribution to the total error at coarse grids.

The spatial error decreases as \(\sim N_g^{-1}\), not \(N_g^{-2}\). This may seem surprising since the discrete gradient and Laplacian operators are both second-order accurate. The bottleneck is CIC itself: CIC is a first-order (piecewise linear) interpolation scheme, producing a density field that is \(C^0\) but not \(C^1\). The resulting force — the gradient of the potential — is discontinuous at cell boundaries, with \(\mathcal{O}(h)\) jumps. This limits the maximum force error at particle positions to \(\mathcal{O}(h) = \mathcal{O}(N_g^{-1})\), regardless of how accurate the Poisson solver is.
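The quoted scaling can be read off directly from the printed errors. A minimal sketch, using only the numbers in the output above; the helper `observed_order` is a hypothetical name:

```python
import numpy as np

grid_sizes = np.array([32, 64, 128])
dx_err = np.array([7.50e-3, 3.58e-3, 1.77e-3])   # max |Δx| from the printed output
px_err = np.array([9.23e-3, 4.30e-3, 2.11e-3])   # max |Δp| from the printed output

def observed_order(n, err):
    """Local convergence order p between successive resolutions, err ∝ n^(-p)."""
    return np.log(err[:-1] / err[1:]) / np.log(n[1:] / n[:-1])

print("displacement order:", np.round(observed_order(grid_sizes, dx_err), 2))
print("momentum order:    ", np.round(observed_order(grid_sizes, px_err), 2))
```

Each doubling of \(N_g\) roughly halves both errors, which corresponds to an observed order close to 1.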

Higher-order mass assignment schemes improve this: TSC (second-order B-spline) gives continuous forces and \(\mathcal{O}(N_g^{-2})\) spatial convergence.
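For comparison, the 1D assignment weights of CIC and TSC can be written down in a few lines. These are the standard textbook forms in units of the cell size, not code from this lecture, and the function names are hypothetical. Both sets of weights sum to one (mass conservation) and reproduce the particle offset as their first moment.

```python
import numpy as np

def cic_weights(d):
    """CIC weights for the two nearest grid points.
    d in [0, 1) is the offset from the left point, in cell units."""
    return np.array([1.0 - d, d])

def tsc_weights(d):
    """TSC weights for the three nearest grid points (left, center, right).
    d in [-0.5, 0.5] is the offset from the nearest point, in cell units."""
    return np.array([0.5 * (0.5 - d)**2, 0.75 - d**2, 0.5 * (0.5 + d)**2])

d = 0.2
w_cic, w_tsc = cic_weights(d), tsc_weights(d)
print("CIC weights:", w_cic, "sum =", w_cic.sum())
print("TSC weights:", w_tsc, "sum =", w_tsc.sum())
```

TSC spreads each particle over three cells per dimension instead of two, so in 3D it touches 27 cells rather than 8; that extra cost is the price of the smoother assignment.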

20.11 Cosmological Initial Conditions

We now put the pieces together and run a cosmological simulation with realistic initial conditions drawn from the matter power spectrum.

The Recipe

To generate cosmological ICs we need:

  1. A matter power spectrum \(P(k)\) at the initial redshift — we use CLASS to compute this.
  2. A Gaussian random field \(\hat{\delta}(\mathbf{k})\) on the grid, with variance set by \(P(k)\): each Fourier mode is drawn as \[ \hat{\delta}(\mathbf{k}) = \sqrt{P(k) / V} \; (g_1 + i\, g_2) / \sqrt{2} \] where \(g_1, g_2\) are independent standard normals and \(V = L^3\) is the box volume, so that \(\langle |\hat{\delta}(\mathbf{k})|^2 \rangle = P(k)/V\). The factor \(1/V\) converts the continuous power spectrum (dimensions of volume) to the dimensionless variance of the discrete Fourier mode.
  3. Zel’dovich displacements from the density field: \(\hat{\Psi}_i(\mathbf{k}) = -i k_i / k^2 \, \hat{\delta}(\mathbf{k})\), then \(\mathbf{x} = \mathbf{q} + D_+(a_{\text{init}}) \boldsymbol{\Psi}\) and \(\mathbf{p} = a_{\text{init}}^{3/2} \boldsymbol{\Psi}\) (EdS).
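The mode-variance statement in step 2 is easy to check numerically. The sketch below draws many independent realizations of a single mode and compares the sample variance with \(P(k)/V\); the value of `P_k` is an illustrative placeholder, not a CLASS output.

```python
import numpy as np

rng = np.random.default_rng(0)
P_k = 1.0e4          # illustrative P(k) in (Mpc/h)^3, not a CLASS value
V = 200.0**3         # box volume in (Mpc/h)^3

# Many independent draws of one Fourier mode
n_draws = 200_000
g1 = rng.standard_normal(n_draws)
g2 = rng.standard_normal(n_draws)
delta_k = np.sqrt(P_k / V) * (g1 + 1j * g2) / np.sqrt(2.0)

print(f"target   <|delta_k|^2> = {P_k / V:.4e}")
print(f"measured <|delta_k|^2> = {np.mean(np.abs(delta_k)**2):.4e}")
```

The \(1/\sqrt{2}\) is what makes the complex amplitude have unit variance: the real and imaginary parts each contribute one half.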

Power Spectrum from CLASS

Code
from classy import Class

# Set up CLASS with Planck-like cosmology
cosmo = Class()
cosmo.set({
    'output': 'mPk',
    'h': 0.7,
    'omega_b': 0.02237,
    'omega_cdm': 0.1200,
    'A_s': 2.1e-9,
    'n_s': 0.9649,
    'P_k_max_h/Mpc': 50.0,
    'z_max_pk': 100.0
})
cosmo.compute()

h_class = cosmo.h()
print(f"CLASS cosmology: Omega_m = {cosmo.Omega_m():.4f}, "
      f"h = {h_class:.3f}, sigma8 = {cosmo.sigma8():.4f}")
CLASS cosmology: Omega_m = 0.2906, h = 0.700, sigma8 = 0.8300
Note: EdS Approximation

Our PM code uses Einstein–de Sitter dynamics (\(\Omega_m = 1\)). We use the CLASS power spectrum for the shape of \(P(k)\) (which encodes the transfer function, BAO, etc.) but evolve under EdS gravity. This is a common simplification in pedagogical codes — the large-scale structure looks qualitatively correct, but the growth rate and normalization differ from \(\Lambda\)CDM. A production code would use the correct \(H(a)\) and growth factor for the target cosmology.
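To get a feel for how much the EdS approximation changes the normalization, the linear growth factor for flat \(\Lambda\)CDM can be evaluated from the standard integral \(D_+(a) = \frac{5\Omega_m}{2} E(a) \int_0^a \frac{da'}{(a' E(a'))^3}\) with \(E^2 = \Omega_m a^{-3} + 1 - \Omega_m\). The sketch below assumes matter plus a cosmological constant only (radiation neglected) and normalizes \(D_+ \to a\) at early times; `growth_factor` is a hypothetical helper, and \(\Omega_m = 0.2906\) is the value CLASS reported above.

```python
import numpy as np

def growth_factor(a, Omega_m, n=4001):
    """Linear growth factor D_+(a) for flat matter + Lambda (assumption:
    no radiation), normalized so that D_+ -> a at early times."""
    Omega_L = 1.0 - Omega_m
    x = np.linspace(0.0, a, n)
    # 1 / (x E(x))^3 rewritten so that it is finite at x = 0
    integrand = x**1.5 / (Omega_m + Omega_L * x**3)**1.5
    # Trapezoid rule written out by hand
    integral = np.sum(0.5 * (integrand[1:] + integrand[:-1]) * np.diff(x))
    E = np.sqrt(Omega_m / a**3 + Omega_L)
    return 2.5 * Omega_m * E * integral

a_init, Omega_m = 0.02, 0.2906
ratio_lcdm = growth_factor(a_init, Omega_m) / growth_factor(1.0, Omega_m)
ratio_eds = growth_factor(a_init, 1.0) / growth_factor(1.0, 1.0)
print(f"D(a_init)/D(1): LCDM = {ratio_lcdm:.4f}, EdS = {ratio_eds:.4f}")
```

Under these assumptions the \(\Lambda\)CDM ratio comes out larger than the EdS value of 0.02, because growth slows once the cosmological constant dominates. Scaling a \(z = 0\) spectrum back with \(D_+ = a\) therefore starts the simulation with less power than the target cosmology would have at \(z = 49\).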

Generating Gaussian Initial Conditions

Code
# Simulation parameters for cosmological run
Ng_cosmo = 128
Np_cosmo = 128
L_cosmo = 200.0       # Mpc/h (comoving box size)
z_init = 49.0
a_init_cosmo = 1.0 / (1.0 + z_init)

# Build the P(k) interpolator at z=0.
# The Zel'dovich displacement field Psi is defined via delta_0 (the
# present-day density), so we need P(k, z=0). The growth factor D_+(a)
# then scales the displacement to the initial time.
k_arr = np.logspace(-4, np.log10(50.0), 2000)  # k in h/Mpc
pk_arr = np.array([cosmo.pk(k * h_class, 0.0) * h_class**3
                   for k in k_arr])  # P(k, z=0) in (Mpc/h)^3

# Clean up CLASS
cosmo.struct_cleanup()

# Interpolate in log-log space
from scipy.interpolate import interp1d
log_pk_interp = interp1d(np.log(k_arr), np.log(pk_arr),
                         kind='cubic', fill_value=-100.0,
                         bounds_error=False)

def pk_of_k(k):
    """P(k) in (Mpc/h)^3, k in h/Mpc."""
    return np.exp(log_pk_interp(np.log(k)))

print(f"Box: L = {L_cosmo} Mpc/h, Ng = Np = {Ng_cosmo}")
print(f"Starting at z = {z_init} (a = {a_init_cosmo:.4f})")
print(f"Fundamental mode: k_f = {2*np.pi/L_cosmo:.4f} h/Mpc")
print(f"Nyquist: k_Ny = {np.pi*Ng_cosmo/L_cosmo:.2f} h/Mpc")
Box: L = 200.0 Mpc/h, Ng = Np = 128
Starting at z = 49.0 (a = 0.0200)
Fundamental mode: k_f = 0.0314 h/Mpc
Nyquist: k_Ny = 2.01 h/Mpc
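Interpolating in log-log space is a natural choice because \(P(k)\) is close to a power law over wide ranges of \(k\), and a power law is a straight line in log-log coordinates. A minimal sketch with a toy power law (not a CLASS spectrum), using linear `np.interp` in place of the cubic `interp1d` above so that only NumPy is needed:

```python
import numpy as np

# Coarse table of a toy power law P(k) = 2e4 * k^-1.5 (illustrative only)
k_tab = np.logspace(-2, 1, 16)
P_tab = 2.0e4 * k_tab**-1.5

# Points at which to evaluate the interpolant
k_new = np.logspace(-1.9, 0.9, 200)
P_true = 2.0e4 * k_new**-1.5

# Interpolate log P against log k, then exponentiate
P_loglog = np.exp(np.interp(np.log(k_new), np.log(k_tab), np.log(P_tab)))
# Naive interpolation of P against k for comparison
P_linear = np.interp(k_new, k_tab, P_tab)

err_loglog = np.max(np.abs(P_loglog / P_true - 1.0))
err_linear = np.max(np.abs(P_linear / P_true - 1.0))
print(f"max relative error: log-log = {err_loglog:.1e}, linear = {err_linear:.1e}")
```

For an exact power law the log-log interpolant is accurate to rounding error, while interpolating \(P\) against \(k\) directly on the same coarse table is off by several percent.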
Code
def make_cosmo_ic(Np, Ng, L, pk_func, a_init, seed=42):
    """
    Generate cosmological Zel'dovich ICs from a power spectrum.

    Parameters
    ----------
    Np : int
        Particles per dimension.
    Ng : int
        Grid cells per dimension.
    L : float
        Box size (same units as P(k), e.g. Mpc/h).
    pk_func : callable
        P(k) function, k in h/Mpc, returns (Mpc/h)^3.
    a_init : float
        Initial scale factor.
    seed : int
        Random seed.

    Returns
    -------
    pos, mom, q : arrays
        Positions, momenta, and Lagrangian grid.
    """
    h = L / Ng
    V = L**3

    # Wavenumber grid (angular wavenumber k in h/Mpc)
    # fftfreq returns frequency nu = n/(N*h); multiply by 2pi for k
    kfreq = 2 * np.pi * np.fft.fftfreq(Ng, d=h)   # k in h/Mpc
    kx, ky, kz = np.meshgrid(kfreq, kfreq, kfreq, indexing='ij')
    kmag = np.sqrt(kx**2 + ky**2 + kz**2)
    kmag[0, 0, 0] = 1.0  # avoid division by zero

    # Power spectrum on the grid
    pk_grid = pk_func(kmag)
    pk_grid[0, 0, 0] = 0.0  # no mean overdensity

    # Generate Gaussian random field by "coloring" white noise.
    # Start with real-space white noise, FFT it (reality condition is
    # automatic since the input is real), then multiply by sqrt(P(k))
    # to imprint the power spectrum.
    #
    # Normalization: E[|noise_hat_k|^2] = Ng^3 for white noise.
    # We want E[|delta_hat_k|^2] = Ng^6 P(k) / V, so multiply
    # noise_hat by Ng^{3/2} sqrt(P(k)/V).
    rng = np.random.default_rng(seed)
    white_noise = rng.standard_normal((Ng, Ng, Ng))
    noise_hat = np.fft.fftn(white_noise)

    amplitude = Ng**1.5 * np.sqrt(pk_grid / V)
    delta_hat = amplitude * noise_hat
    delta_hat[0, 0, 0] = 0.0  # no mean overdensity

    # Displacement field: Psi_i(k) = -i k_i / k^2 * delta_hat(k)
    inv_k2 = 1.0 / (kmag**2)
    inv_k2[0, 0, 0] = 0.0

    psi_hat_x = -1j * kx * inv_k2 * delta_hat
    psi_hat_y = -1j * ky * inv_k2 * delta_hat
    psi_hat_z = -1j * kz * inv_k2 * delta_hat

    # Transform to real space
    psi_x = np.fft.ifftn(psi_hat_x).real
    psi_y = np.fft.ifftn(psi_hat_y).real
    psi_z = np.fft.ifftn(psi_hat_z).real

    # Lagrangian grid
    q = make_uniform_grid(Np, L)
    q_np = np.array(q)

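    # Interpolate displacement to particle positions.
    # Direct indexing is only valid for Np == Ng, where each particle
    # sits at a cell center, so fail loudly otherwise.
    if Np != Ng:
        raise ValueError("make_cosmo_ic assumes Np == Ng (one particle per cell)")
    idx = np.round(q_np / h - 0.5).astype(int) % Ng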

    psi_at_q = np.stack([
        psi_x[idx[:, 0], idx[:, 1], idx[:, 2]],
        psi_y[idx[:, 0], idx[:, 1], idx[:, 2]],
        psi_z[idx[:, 0], idx[:, 1], idx[:, 2]],
    ], axis=-1)

    # Zel'dovich: x = q + D_+(a) * Psi, p = a^{3/2} * Psi  (EdS)
    pos = jnp.array((q_np + a_init * psi_at_q) % L)
    mom = jnp.array(a_init**1.5 * psi_at_q)

    return pos, mom, q

pos_cosmo, mom_cosmo, q_cosmo = make_cosmo_ic(
    Np_cosmo, Ng_cosmo, L_cosmo, pk_of_k, a_init_cosmo
)
print(f"Particles: {pos_cosmo.shape[0]:,}")
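# Undo the periodic wrap before measuring the displacement, so that a
# particle that crossed the box edge does not contribute a jump of L.
disp_cosmo = np.array(pos_cosmo - q_cosmo)
disp_cosmo = disp_cosmo - L_cosmo * np.round(disp_cosmo / L_cosmo)
print(f"Initial rms displacement: {np.std(disp_cosmo):.4f} Mpc/h")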
Particles: 2,097,152
Initial rms displacement: 0.1098 Mpc/h
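The normalization described in the comments of `make_cosmo_ic` can be checked numerically on a small grid. The sketch below uses a toy spectrum (its amplitude and turnover scale are placeholders, not CLASS output) and verifies two things: the FFT of unit white noise has mean power \(N_g^3\) per mode, and the real-space variance of the colored field matches \(\sum_{\mathbf{k}} P(k)/V\).

```python
import numpy as np

Ng, L = 32, 100.0                      # small illustrative grid and box
V = L**3
rng = np.random.default_rng(1)

# White noise in real space; its FFT has E[|noise_hat|^2] = Ng^3 per mode
noise_hat = np.fft.fftn(rng.standard_normal((Ng, Ng, Ng)))
white_power = np.mean(np.abs(noise_hat)**2) / Ng**3

# Toy spectrum P(k) = 50 / (1 + (k / 0.5)^2); both constants are placeholders
kfreq = 2 * np.pi * np.fft.fftfreq(Ng, d=L / Ng)
kx, ky, kz = np.meshgrid(kfreq, kfreq, kfreq, indexing='ij')
kmag = np.sqrt(kx**2 + ky**2 + kz**2)
pk_grid = 50.0 / (1.0 + (kmag / 0.5)**2)
pk_grid[0, 0, 0] = 0.0                 # no mean overdensity

# Color the noise with the same amplitude factor as make_cosmo_ic
delta_hat = Ng**1.5 * np.sqrt(pk_grid / V) * noise_hat
delta_complex = np.fft.ifftn(delta_hat)
delta = delta_complex.real

var_measured = np.var(delta)
var_expected = np.sum(pk_grid) / V
print(f"white-noise power / Ng^3 = {white_power:.3f}")
print(f"variance: measured = {var_measured:.4f}, expected = {var_expected:.4f}")
```

The imaginary part of the inverse transform is zero up to rounding because the amplitude depends only on \(|\mathbf{k}|\), so the Hermitian symmetry of `noise_hat` is preserved. The measured variance agrees with the expectation only up to sample variance of this one realization.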

Running the Cosmological Simulation

Code
# Precompute Green's function for the cosmological grid
green_cx, green_cy, green_cz = make_green_function(Ng_cosmo, L_cosmo)

# Run from z=49 to z=0 (a=0.02 to a=1)
a_final_cosmo = 1.0
n_steps_cosmo = 100

print(f"Running {Np_cosmo}^3 simulation: a = {a_init_cosmo:.4f} → {a_final_cosmo}")
print(f"  {n_steps_cosmo} steps, Ng = {Ng_cosmo}, L = {L_cosmo} Mpc/h")

pos_cosmo_final, mom_cosmo_final = run_simulation(
    pos_cosmo, mom_cosmo, a_init_cosmo, a_final_cosmo, n_steps_cosmo,
    green_cx, green_cy, green_cz, Ng_cosmo, L_cosmo, 1.0
)
print("Done!")
Running 128^3 simulation: a = 0.0200 → 1.0
  100 steps, Ng = 128, L = 200.0 Mpc/h
Done!

The Cosmic Web

We visualize the result by depositing particles onto the grid and showing a thin slice through the density field.

Code
# Deposit final positions onto grid
rho_final = cic_deposit(pos_cosmo_final, Ng_cosmo, L_cosmo)
rho_bar = Np_cosmo**3 / Ng_cosmo**3
delta_final = rho_final / rho_bar - 1.0

# Take a thin slice (single grid plane)
slice_z = Ng_cosmo // 2

fig, ax = plt.subplots(1, 1, figsize=(8, 8))

im = ax.imshow(
    np.log10(np.maximum(1 + np.array(delta_final[:, :, slice_z]), 1e-6)).T,
    origin='lower',
    extent=[0, L_cosmo, 0, L_cosmo],
    cmap='magma',
    vmin=-0.5,
    vmax=2.0
)
ax.set_xlabel('$x$ [Mpc/$h$]')
ax.set_ylabel('$y$ [Mpc/$h$]')
ax.set_title(f'$z = 0$ density slice ($N_g = {Ng_cosmo}$, '
             f'$L = {L_cosmo}$ Mpc/$h$)')
plt.colorbar(im, ax=ax, label='$\\log_{10}(1 + \\delta)$', shrink=0.8)
plt.tight_layout()
plt.show()

# Statistics
print("Density field statistics at z=0:")
print(f"  min(delta) = {delta_final.min():.2f}")
print(f"  max(delta) = {delta_final.max():.1f}")
print(f"  rms(delta) = {jnp.std(delta_final):.2f}")
Figure 20.8: Slice through the evolved density field at \(z = 0\) from our \(128^3\) PM simulation with cosmological initial conditions. The familiar features of the cosmic web — filaments, nodes, and voids — are clearly visible. The color scale shows \(\log_{10}(1 + \delta)\).
Density field statistics at z=0:
  min(delta) = -1.00
  max(delta) = 494.1
  rms(delta) = 3.68
Code
a_snaps = [0.1, 0.2, 0.5, 1.0]
z_snaps = [1/a - 1 for a in a_snaps]

fig, axes = plt.subplots(1, 4, figsize=(18, 4.5),
                         gridspec_kw={'wspace': 0.05, 'right': 0.92})

for ax, a_snap, z_snap in zip(axes, a_snaps, z_snaps):
    pos_s, _ = run_simulation(
        pos_cosmo, mom_cosmo, a_init_cosmo, a_snap, n_steps_cosmo,
        green_cx, green_cy, green_cz, Ng_cosmo, L_cosmo, 1.0
    )
    rho_s = cic_deposit(pos_s, Ng_cosmo, L_cosmo)
    delta_s = rho_s / rho_bar - 1.0

    im = ax.imshow(
        np.log10(np.maximum(1 + np.array(delta_s[:, :, slice_z]), 1e-6)).T,
        origin='lower',
        extent=[0, L_cosmo, 0, L_cosmo],
        cmap='magma',
        vmin=-0.5,
        vmax=2.0
    )
    ax.set_title(f'$z = {z_snap:.0f}$')
    ax.set_xlabel('$x$ [Mpc/$h$]')
    if ax == axes[0]:
        ax.set_ylabel('$y$ [Mpc/$h$]')
    else:
        ax.set_yticklabels([])
    ax.set_aspect('equal')

cbar_ax = fig.add_axes([0.93, 0.15, 0.015, 0.7])
fig.colorbar(im, cax=cbar_ax, label='$\\log_{10}(1 + \\delta)$')
plt.show()
Figure 20.9: Evolution of the cosmic web from \(z = 9\) to \(z = 0\). Structure grows hierarchically: small-scale perturbations collapse first, then larger structures assemble through merging and accretion along filaments.

Animation

Finally, let us watch the cosmic web form. For each of 50 log-spaced scale factors between \(a_{\text{init}}\) and \(a = 1\), we rerun the simulation from the initial conditions and store a density slice, then assemble the slices into a movie.

Code
from matplotlib.animation import FuncAnimation
from IPython.display import Video

# Scale factors for the movie frames (log-spaced)
n_frames = 50
a_frames = np.exp(np.linspace(np.log(a_init_cosmo), np.log(a_final_cosmo), n_frames))

# Run simulation to each epoch and store density slices
slices = []
for a_snap in a_frames:
    n_steps_snap = max(10, int(100 * (a_snap - a_init_cosmo) / (a_final_cosmo - a_init_cosmo)))
    pos_s, _ = run_simulation(
        pos_cosmo, mom_cosmo, a_init_cosmo, a_snap, n_steps_snap,
        green_cx, green_cy, green_cz, Ng_cosmo, L_cosmo, 1.0
    )
    rho_s = cic_deposit(pos_s, Ng_cosmo, L_cosmo)
    delta_s = rho_s / rho_bar - 1.0
    slices.append(
        np.log10(np.maximum(1 + np.array(delta_s[:, :, slice_z]), 1e-6)).T
    )

# Build animation
fig, ax = plt.subplots(1, 1, figsize=(7, 7))
im = ax.imshow(slices[0], origin='lower',
               extent=[0, L_cosmo, 0, L_cosmo],
               cmap='magma', vmin=-0.5, vmax=2.0)
ax.set_xlabel('$x$ [Mpc/$h$]')
ax.set_ylabel('$y$ [Mpc/$h$]')
title = ax.set_title(f'$z = {1/a_frames[0] - 1:.1f}$')
fig.colorbar(im, ax=ax, label='$\\log_{10}(1 + \\delta)$', shrink=0.8)
fig.tight_layout()

def update(frame):
    im.set_data(slices[frame])
    z_val = 1/a_frames[frame] - 1
    title.set_text(f'$z = {z_val:.1f}$')
    return [im, title]

anim = FuncAnimation(fig, update, frames=n_frames, interval=150, blit=True)

# Save as MP4
movie_path = 'cosmic_web_evolution.mp4'
anim.save(movie_path, writer='ffmpeg', dpi=120)
plt.close(fig)

Video(movie_path, embed=True, html_attributes='controls loop autoplay muted width="700"')
Figure 20.10: Animation of the cosmic web forming from \(z = 49\) to \(z = 0\). The density field grows from nearly uniform initial conditions through gravitational instability alone — no hydrodynamics, no star formation, just dark matter and gravity on a \(128^3\) PM grid.

Every frame above was produced by the PM algorithm we built from scratch in this lecture: CIC deposit, Poisson solver, force interpolation, and KDK integrator.