FURAX: a modular JAX toolbox for solving inverse problems in science

Simon Biquard, Pierre Chanial, Wassim Kabalan

2024-12-19

Outline

  1. Motivations and goals: why and for what?
  2. Presentation of the framework
  3. Applications in CMB science

Motivations and Goals

  • Inverse problems: given observed data \(d_{\text{obs}} = F(p)\), find the parameters \(p\)
  • Pure Python, open source: https://github.com/CMBSciPol/furax
    • pip install furax (work in progress!)
  • Modular & extensible: easy to experiment with new ideas
  • High-performance: JAX (see next slide)


  • Maximum likelihood and template map-making as in El Bouhargani et al. 2021
  • Non-ideal, frequency-dependent optical components
  • Provide tools for next-generation experiments (SO, CMB-S4, LiteBIRD)
    • Interfaces with TOAST, sotodlib
    • Large data sets: multi-GPU parallelization is underway

What is JAX?

From the JAX website:

JAX is a library for array-oriented numerical computation (à la NumPy), with automatic differentiation and JIT compilation to enable high-performance machine learning research.

Key features

  • NumPy-like interface with CPU/GPU/TPU support in local and distributed environments
  • Just-in-time (JIT) compilation via OpenXLA
  • Automatic differentiation
  • Automatic vectorization

PyTrees

FURAX relies on PyTrees to represent the data.

Example: random sky with 3 components

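import jax

# HealpixLandscape is provided by FURAX (import not shown)
key1, key2, key3 = jax.random.split(jax.random.PRNGKey(0), 3)  # one PRNG key per component
nside = 64  # HEALPix resolution parameter
sky = {
  'cmb': HealpixLandscape(nside, 'IQU').normal(key1),
  'dust': HealpixLandscape(nside, 'IQU').normal(key2),
  'synchrotron': HealpixLandscape(nside, 'IQU').normal(key3),
}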

HealpixLandscape(nside, 'IQU') describes the sky layout, and its normal(key) method returns an instance of StokesIQUPyTree, a container for the Stokes parameters I, Q, U.

Use FrequencyLandscape to generalize to multiple frequencies.

import numpy as np

frequencies = np.array([93, 145])  # GHz
sky = {
  'cmb': FrequencyLandscape(nside, frequencies, 'IQU').normal(key1),
  'dust': FrequencyLandscape(nside, frequencies, 'IQU').normal(key2),
  'synchrotron': FrequencyLandscape(nside, frequencies, 'IQU').normal(key3),
}
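To make the PyTree idea concrete without depending on FURAX, here is a minimal sketch in plain NumPy. `StokesIQU`, `tree_map` and `tree_leaves` are hypothetical stand-ins (JAX ships real equivalents in `jax.tree_util`), and the `(n_freq, npix)` array layout for a multi-frequency sky is an assumption, not necessarily FURAX's layout.

```python
from dataclasses import dataclass, fields, replace

import numpy as np


@dataclass
class StokesIQU:
    """Hypothetical stand-in for StokesIQUPyTree: one array per Stokes parameter."""
    i: np.ndarray
    q: np.ndarray
    u: np.ndarray


def tree_map(fn, tree):
    """Apply fn to every array leaf, keeping the container structure (toy version)."""
    if isinstance(tree, dict):
        return {key: tree_map(fn, value) for key, value in tree.items()}
    if isinstance(tree, StokesIQU):
        return replace(tree, **{f.name: tree_map(fn, getattr(tree, f.name)) for f in fields(tree)})
    return fn(tree)  # leaf: a NumPy array


def tree_leaves(tree):
    """Flatten the tree into a list of its array leaves (toy version)."""
    if isinstance(tree, dict):
        return [leaf for value in tree.values() for leaf in tree_leaves(value)]
    if isinstance(tree, StokesIQU):
        return [leaf for f in fields(tree) for leaf in tree_leaves(getattr(tree, f.name))]
    return [tree]


def random_sky(rng, npix, frequencies):
    """Random I, Q, U maps of shape (n_freq, npix) (assumed layout)."""
    shape = (len(frequencies), npix)
    return StokesIQU(*(rng.normal(size=shape) for _ in range(3)))


nside = 4
npix = 12 * nside**2  # number of HEALPix pixels: 192
frequencies = np.array([93, 145])  # GHz
rng = np.random.default_rng(0)
sky = {name: random_sky(rng, npix, frequencies) for name in ('cmb', 'dust', 'synchrotron')}

doubled = tree_map(lambda leaf: 2.0 * leaf, sky)  # same structure, every map scaled
print(len(tree_leaves(sky)))    # 9 leaves: 3 components x 3 Stokes parameters
print(doubled['dust'].q.shape)  # (2, 192)
```

Because every operation is a map over leaves, the same code works whether the sky has one component or ten, which is what makes the dict-of-components representation convenient.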

Operators

The base class AbstractLinearOperator provides a default implementation for the usual linear algebra operations.

Operation | FURAX | Comment
Addition | A + B |
Composition | A @ B |
Multiplication by scalar | k * A | Returns the composition of a HomothetyOperator and A
Transpose | A.T | Derived through JAX autodiff, but can be overridden
Inverse | A.I | Uses the CG solver by default; can be overridden or configured with a context manager
Block assembly | BlockColumnOperator([A, B]), BlockDiagonalOperator([A, B]), BlockRowOperator([A, B]) | Handle any PyTree of operators: Block*Operator({'a': A, 'b': B})
Flattened dense matrix | A.as_matrix() |
Algebraic reduction | A.reduce() |
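The operations in the table can all be built on top of a single matrix-vector product. The sketch below uses a hypothetical `ToyOperator` class (not FURAX's `AbstractLinearOperator`) to show the idea: sums and compositions wrap the underlying functions, `as_matrix()` applies the operator to each unit vector, and the transpose here simply falls back on that dense matrix. FURAX instead obtains the transpose through JAX autodiff (JAX exposes linear transposition as `jax.linear_transpose`), which avoids forming the matrix.

```python
import numpy as np


class ToyOperator:
    """Hypothetical minimal linear operator defined by its matrix-vector product."""

    def __init__(self, matvec, in_size, out_size):
        self.matvec = matvec
        self.in_size = in_size
        self.out_size = out_size

    def __call__(self, x):
        return self.matvec(x)

    def __add__(self, other):  # A + B
        return ToyOperator(lambda x: self(x) + other(x), self.in_size, self.out_size)

    def __matmul__(self, other):  # A @ B: apply B first, then A
        return ToyOperator(lambda x: self(other(x)), other.in_size, self.out_size)

    def __rmul__(self, k):  # k * A: scale the output (a homothety composed with A)
        return ToyOperator(lambda x: k * self(x), self.in_size, self.out_size)

    def as_matrix(self):
        # Column j of the dense matrix is the operator applied to the j-th unit vector
        return np.column_stack([self(e) for e in np.eye(self.in_size)])

    @property
    def T(self):
        # Dense fallback for this sketch only
        matrix = self.as_matrix()
        return ToyOperator(lambda y: matrix.T @ y, self.out_size, self.in_size)


rng = np.random.default_rng(0)
a, b = rng.normal(size=(3, 4)), rng.normal(size=(3, 4))
c = rng.normal(size=(4, 4))
A = ToyOperator(lambda x: a @ x, in_size=4, out_size=3)
B = ToyOperator(lambda x: b @ x, in_size=4, out_size=3)
C = ToyOperator(lambda x: c @ x, in_size=4, out_size=4)

combined = 2.0 * (A + B) @ C  # parsed as (2.0 * (A + B)) @ C
```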

Operators (cont’d)

Generic operator | Description
IdentityOperator |
HomothetyOperator |
DiagonalOperator |
BroadcastDiagonalOperator | Non-square operator for broadcasting
TensorOperator | For dense matrix operations
IndexOperator | Can be used to project skies onto time-ordered data
MoveAxisOperator |
ReshapeOperator |
RavelOperator |
SymmetricBandToeplitzOperator | Methods: direct, FFT, overlap-and-save
Block*Operator | Block assembly operators (column, diagonal, row)

Applied operator | Description
QURotationOperator |
HWPOperator | Ideal HWP
LinearPolarizerOperator | Ideal linear polarizer
CMBOperator | Parametrized CMB SED
DustOperator | Parametrized dust SED
SynchrotronOperator | Parametrized synchrotron SED
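An IndexOperator used as a pointing matrix is a gather: each time sample reads the pixel it points at. Its transpose is the matching scatter-add, which accumulates samples into pixels. A minimal NumPy sketch, with a hypothetical random `pixel_index` standing in for real pointing, checked with the usual dot-product test \(\langle \mathbf P s, d\rangle = \langle s, \mathbf P^\top d\rangle\):

```python
import numpy as np

rng = np.random.default_rng(1)
npix, nsamples = 8, 20
pixel_index = rng.integers(0, npix, size=nsamples)  # hypothetical pointing: pixel seen at each sample


def project(sky):
    """P s: for each time sample, read the sky value in the observed pixel (gather)."""
    return sky[pixel_index]


def backproject(tod):
    """P^T d: add each time sample into the pixel it observed (scatter-add)."""
    sky = np.zeros(npix)
    np.add.at(sky, pixel_index, tod)  # unbuffered, so repeated pixels accumulate
    return sky


sky = rng.normal(size=npix)
tod = rng.normal(size=nsamples)
lhs = project(sky) @ tod      # <P s, d>
rhs = sky @ backproject(tod)  # <s, P^T d>, equal for a correct transpose
hits = backproject(np.ones(nsamples))  # hit count per pixel: the diagonal of P^T P
```

In JAX, where arrays are immutable, the scatter-add would be written `jnp.zeros(npix).at[pixel_index].add(tod)`.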

Algebraic reductions: Rotations

Classic acquisition model with an ideal linear polarizer \(\mathbf{M}_{\textrm{LP}}\), an ideal half-wave plate \(\mathbf{M}_{\textrm{HWP}}\), rotations \(\mathbf{R}\) and the pointing matrix \(\mathbf P\):

\[ \mathbf{H} = \mathbf{M}_{\textrm{LP}} \, \mathbf{R}_{2\theta} \, \mathbf{R}_{-2\phi} \, \mathbf{M}_{\textrm{HWP}} \, \mathbf{R}_{2\phi} \, \mathbf{R}_{2\psi} \, \mathbf{P} \]

where

  • \(\theta\): polarizer angle
  • \(\phi\): half-wave plate angle
  • \(\psi\): telescope angle

This chain is reduced automatically to the much simpler

\[ \mathbf{H} = \mathbf{M}_{\textrm{LP}} \, \mathbf{R}_{-2\theta + 4\phi + 2\psi}\, \mathbf{P} \]
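The reduction rests on three identities: rotations compose additively, \(\mathbf R_a \mathbf R_b = \mathbf R_{a+b}\); an ideal HWP is a reflection of the (Q, U) plane, so \(\mathbf R_a \mathbf M_{\textrm{HWP}} = \mathbf M_{\textrm{HWP}} \mathbf R_{-a}\); and \(\mathbf M_{\textrm{LP}} \mathbf M_{\textrm{HWP}} = \mathbf M_{\textrm{LP}}\) for a polarizer aligned with Q. Moving \(\mathbf M_{\textrm{HWP}}\) to the left through \(\mathbf R_{2\theta - 2\phi}\) flips that angle, leaving \(\mathbf R_{-2\theta + 2\phi} \mathbf R_{2\phi + 2\psi}\). The NumPy check below assumes these conventions (rotation sign, polarizer axis), which may differ from FURAX's:

```python
import numpy as np


def rotation(angle):
    """Rotation of the (Q, U) Stokes plane by `angle`, acting on (I, Q, U)."""
    c, s = np.cos(angle), np.sin(angle)
    return np.array([[1.0, 0.0, 0.0], [0.0, c, -s], [0.0, s, c]])


M_HWP = np.diag([1.0, 1.0, -1.0])  # ideal half-wave plate: flips the sign of U
M_LP = 0.5 * np.array([[1.0, 1.0, 0.0], [1.0, 1.0, 0.0], [0.0, 0.0, 0.0]])  # ideal polarizer along Q

rng = np.random.default_rng(2)
theta, phi, psi = rng.uniform(0.0, 2.0 * np.pi, size=3)

full_chain = (M_LP @ rotation(2 * theta) @ rotation(-2 * phi) @ M_HWP
              @ rotation(2 * phi) @ rotation(2 * psi))
reduced = M_LP @ rotation(-2 * theta + 4 * phi + 2 * psi)
```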

CMB Applications


Credits: Ema Tsang, Wassim Kabalan, Amalia Villarrubia & the whole SciPol team

Generalized Map-Making

Classic data model

\[ d = \mathbf{P}s + n \]

  • \(d\): time-ordered data
  • \(\mathbf{P}\): pointing matrix (telescope scanning)
  • \(s\): discretized sky signal
  • \(n\): stochastic contribution (noise)


Optimal (GLS) solution

\[ \widehat{s} = (\mathbf{P}^\top \mathbf{N}^{-1} \mathbf{P})^{-1} \mathbf{P}^\top \mathbf{N}^{-1} d \]
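Since FURAX computes A.I with a CG solver by default, the GLS estimate amounts to solving the normal equations \((\mathbf{P}^\top \mathbf{N}^{-1} \mathbf{P}) \widehat{s} = \mathbf{P}^\top \mathbf{N}^{-1} d\) iteratively, without ever forming the inverse. The sketch below uses a textbook CG, a hypothetical toy pointing (`pixel_index`) and white noise; it is not FURAX code.

```python
import numpy as np


def conjugate_gradient(apply_a, b, tol=1e-10, max_iter=1000):
    """Solve A x = b for symmetric positive-definite A, given only the map x -> A x."""
    x = np.zeros_like(b)
    r = b - apply_a(x)
    p = r.copy()
    rs = r @ r
    for _ in range(max_iter):
        ap = apply_a(p)
        alpha = rs / (p @ ap)
        x += alpha * p
        r -= alpha * ap
        rs_new = r @ r
        if np.sqrt(rs_new) <= tol * np.linalg.norm(b):
            break
        p = r + (rs_new / rs) * p
        rs = rs_new
    return x


rng = np.random.default_rng(3)
npix, nsamples = 16, 400
# Hypothetical pointing that hits every pixel at least once
pixel_index = np.concatenate([np.arange(npix), rng.integers(0, npix, size=nsamples - npix)])
noise_var = rng.uniform(0.5, 2.0, size=nsamples)  # white noise: N = diag(noise_var)


def pointing(s):  # P
    return s[pixel_index]


def pointing_t(d):  # P^T
    out = np.zeros(npix)
    np.add.at(out, pixel_index, d)
    return out


s_true = rng.normal(size=npix)
d = pointing(s_true)  # noiseless data, so the GLS estimate should recover s_true
rhs = pointing_t(d / noise_var)  # P^T N^{-1} d
s_hat = conjugate_gradient(lambda s: pointing_t(pointing(s) / noise_var), rhs)
```

With white noise and this pointing, \(\mathbf{P}^\top \mathbf{N}^{-1} \mathbf{P}\) is diagonal and could be inverted directly; CG pays off once correlated noise or non-trivial instrument operators make the system dense.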

Generalized parametric data model

\[ d_{\nu, i, t} = \sum_{c, p} \int_{\textrm{BP}_\nu} d\nu' \, \mathbf{M}^{(\gamma)}_{\nu', i, t, p} \mathbf{A}^{(\beta)}_{\nu', t, c, p} s_{c, p} + n_{\nu, i, t} \]

  • \(\mathbf{M}^{(\gamma)}\): instrument matrix with parameters \(\gamma\) (pointing, HWP parameters, bandpasses, beam properties, etc.)
  • \(\mathbf{A}^{(\beta)}\): mixing matrix with parameters \(\beta\) (modeling of CMB, astrophysical foregrounds, atmosphere, etc.)
  • \(\mathbf{H} = \mathbf{MA}\) is the generalized pointing operator
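As an illustration of what a parametrized mixing matrix \(\mathbf{A}^{(\beta)}\) can look like, the sketch below builds one column per component from common SED parametrizations in Rayleigh-Jeans (RJ) units: the CMB blackbody derivative, a modified blackbody for dust and a power law for synchrotron. The exact formulas, units and reference frequencies used by FURAX's CMBOperator, DustOperator and SynchrotronOperator are not given here; bandpasses are treated as delta functions, and the parameter values (\(\beta_d = 1.54\), \(T_d = 20\) K, \(\beta_s = -3\)) are only illustrative.

```python
import numpy as np

H_OVER_K = 6.62607015e-34 / 1.380649e-23  # Planck constant / Boltzmann constant, in K s
T_CMB = 2.7255  # CMB temperature in K


def _cmb_rj_factor(nu):
    """Blackbody derivative at T_CMB, i.e. the K_CMB -> K_RJ conversion factor."""
    x = H_OVER_K * nu * 1e9 / T_CMB  # nu in GHz
    return x**2 * np.exp(x) / np.expm1(x) ** 2


def cmb_sed(nu, nu0):
    """CMB in RJ units, normalized at nu0."""
    return _cmb_rj_factor(nu) / _cmb_rj_factor(nu0)


def dust_sed(nu, nu0, beta_d, temp_d):
    """Modified blackbody nu^beta_d B_nu(temp_d), in RJ units, normalized at nu0."""
    x = H_OVER_K * nu * 1e9 / temp_d
    x0 = H_OVER_K * nu0 * 1e9 / temp_d
    return (nu / nu0) ** (beta_d + 1) * np.expm1(x0) / np.expm1(x)


def synchrotron_sed(nu, nu0, beta_s):
    """Power law in RJ units, normalized at nu0."""
    return (nu / nu0) ** beta_s


frequencies = np.array([93.0, 145.0, 220.0])  # GHz
nu0 = 145.0  # hypothetical common reference frequency, GHz
A = np.column_stack([
    cmb_sed(frequencies, nu0),
    dust_sed(frequencies, nu0, beta_d=1.54, temp_d=20.0),
    synchrotron_sed(frequencies, nu0, beta_s=-3.0),
])  # shape (n_freq, n_comp): per pixel, d_nu = A @ (s_cmb, s_dust, s_sync)
```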

Time-domain noise correlations

Within a period where the noise is stationary, its time-domain covariance \(\mathbf{N}\) has a symmetric Toeplitz structure: \(N_{tt'}\) depends only on the lag \(|t - t'|\).

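SymmetricBandToeplitzOperator exploits this structure with an optimized matrix-vector product in \(\mathcal O(n \log \lambda)\) (overlap-and-save method), where \(n\) is the number of samples and \(\lambda\) the bandwidth of the matrix.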

Impact of gaps. Credits: B3DCMB
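A symmetric band Toeplitz product is a convolution with a short symmetric kernel. Computing it directly costs \(\mathcal O(n\lambda)\); overlap-and-save splits the signal into blocks a few times the kernel length and convolves each block with FFTs, so each block costs \(\mathcal O(\lambda \log \lambda)\) and the total is \(\mathcal O(n \log \lambda)\). A NumPy sketch of both methods (function names and the example autocorrelation are hypothetical; this is not FURAX's implementation):

```python
import numpy as np


def band_toeplitz_matvec_direct(t, x):
    """y_i = sum_k t[|k|] x_{i+k} for |k| <= b: a 'same'-mode convolution (needs len(x) >= 2b+1)."""
    kernel = np.concatenate([t[:0:-1], t])  # [t_b, ..., t_1, t_0, t_1, ..., t_b]
    return np.convolve(x, kernel, mode='same')


def band_toeplitz_matvec_overlap_save(t, x, fft_size=None):
    """Same product via FFTs on blocks of a few kernel lengths: O(n log b) instead of O(n b)."""
    kernel = np.concatenate([t[:0:-1], t])
    m, n, b = len(kernel), len(x), len(t) - 1
    if fft_size is None:
        fft_size = 1 << int(np.ceil(np.log2(2 * m)))
    step = fft_size - m + 1            # valid outputs per block
    n_full = n + m - 1                 # length of the full linear convolution
    n_blocks = -(-n_full // step)      # ceiling division
    padded = np.zeros((n_blocks - 1) * step + fft_size)
    padded[m - 1 : m - 1 + n] = x      # m - 1 leading zeros
    kernel_fft = np.fft.rfft(kernel, fft_size)
    full = np.empty(n_blocks * step)
    for k in range(n_blocks):
        start = k * step
        block = np.fft.irfft(np.fft.rfft(padded[start : start + fft_size]) * kernel_fft, fft_size)
        full[start : start + step] = block[m - 1 :]  # drop the circularly wrapped outputs
    return full[b : b + n]             # 'same' mode: keep the centred n outputs


rng = np.random.default_rng(4)
t = np.exp(-np.arange(17) / 4.0)  # hypothetical noise autocorrelation, bandwidth b = 16
x = rng.normal(size=1000)
y_direct = band_toeplitz_matvec_direct(t, x)
y_fft = band_toeplitz_matvec_overlap_save(t, x)
```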

Restoring stationarity

Gaps in the data break this stationarity. To work around the problem, one solution is to fill the gaps with synthetic samples that are statistically consistent with the noise.

FURAX’s GapFillingOperator computes a constrained noise realization from an estimate of the noise correlations.

Toy example: only the gaps are modified.
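The gap-filling algorithm itself is not detailed here. As a stand-in, the sketch below uses the textbook way to draw a constrained Gaussian realization: sample the gap values from their distribution conditioned on the observed samples, with mean \(\mathbf C_{go}\mathbf C_{oo}^{-1} d_o\) and covariance \(\mathbf C_{gg} - \mathbf C_{go}\mathbf C_{oo}^{-1}\mathbf C_{og}\). It builds dense covariance matrices, so it only suits toy sizes; the AR(1) noise model and all names are hypothetical.

```python
import numpy as np


def fill_gaps(data, gap_mask, autocorr, rng, jitter=1e-10):
    """Replace samples where gap_mask is True by a draw from p(gaps | observed samples).

    Dense O(n^3) conditional-Gaussian sampling; only suitable for toy sizes.
    """
    n = len(data)
    lags = np.abs(np.subtract.outer(np.arange(n), np.arange(n)))
    cov = np.where(lags < len(autocorr), autocorr[np.minimum(lags, len(autocorr) - 1)], 0.0)
    gap, obs = gap_mask, ~gap_mask
    c_go = cov[np.ix_(gap, obs)]
    c_oo = cov[np.ix_(obs, obs)]
    c_gg = cov[np.ix_(gap, gap)]
    weights = np.linalg.solve(c_oo, c_go.T).T  # C_go C_oo^{-1}
    mean = weights @ data[obs]                 # conditional mean
    cond_cov = c_gg - weights @ c_go.T         # conditional covariance
    chol = np.linalg.cholesky(cond_cov + jitter * np.eye(gap.sum()))
    filled = data.copy()
    filled[gap] = mean + chol @ rng.standard_normal(gap.sum())
    return filled


rng = np.random.default_rng(5)
n, rho = 300, np.exp(-1.0 / 5.0)
noise = np.empty(n)  # unit-variance AR(1) noise with autocorrelation rho**lag
noise[0] = rng.standard_normal()
for k in range(1, n):
    noise[k] = rho * noise[k - 1] + np.sqrt(1.0 - rho**2) * rng.standard_normal()
autocorr = rho ** np.arange(n)

gap_mask = np.zeros(n, dtype=bool)
gap_mask[100:130] = True
data = noise.copy()
data[gap_mask] = 50.0  # flagged samples (e.g. glitches) carry no usable information
filled = fill_gaps(data, gap_mask, autocorr, rng)
```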

Non-Ideal HWP

Realistic HWP operator

  • several layers stacked
  • transmission + reflection of incident EM field at the boundaries

Cf. great presentation by Miguel Gomes yesterday!

Mueller matrix coefficients in angle-frequency space for the SO mid-frequency SAT

Component separation

TL;DR: does everything fgbuster does, but better

  • Furax operators efficiently represent the mixing matrix
  • hardware accelerated
  • easy access to gradients and Hessians

Beyond fgbuster

  • Automatic cluster detection for spectral index parameters
  • Very flexible model: straightforward extensions to other sky components, different objective functions, etc.

The cost of evaluating the likelihood function is reduced by a factor of 10 for \(\textrm{nside} \geq 64\).

This will power the map-based pipeline for \(r\) estimation in SO (cf. presentation by Baptiste Jost earlier today).

Atmosphere decontamination: time-domain component separation

Goal: For the Simons Observatory, characterize the observed atmospheric template from the recorded time-ordered data to separate the atmosphere from the sky signal we are after.

Detector array scanning the sky (signal has arbitrary units)

Data Model

\[ d_{\text{atm}} = \mathbf{A}(\text{pwv}) \mathbf{P}(\vec{w}) s_{\text{atm}} + n \]

with parameters:

  • Wind velocity: \(\vec{w} = (w_x, w_y)\)
  • Precipitable Water Vapour (PWV): roughly, the amplitude of the atmospheric fluctuations

Estimate parameters by minimizing the spectral likelihood.

\[ \boxed{ \langle \delta \mathcal{S}_\text{spec}(\vec{w}, \text{pwv} \mid d_{\text{atm}}) \rangle = - d_{\text{atm}}^\top \mathbf{N}^{-1} (\mathbf{AP}) \Big[ (\mathbf{AP})^\top \mathbf{N}^{-1} (\mathbf{AP}) \Big]^{-1} (\mathbf{AP})^\top \mathbf{N}^{-1} d_{\text{atm}} } \]

Atmosphere decontamination: spectral likelihood minimization


Spectral likelihood values in the \((w_x, w_y)\) plane for a fixed PWV value.

Minimization is done by brute force: we compute \(\langle \delta \mathcal{S}_\text{spec}(w_x, w_y \mid \text{pwv}_{\text{sim}}) \rangle\) for 22,500 different combinations of \((w_x, w_y)\).

Proof of concept: we can recover the wind parameters!
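The brute-force search can be reproduced on a toy problem. The sketch below replaces the wind and PWV parameters with a single spectral index \(\beta\) of a hypothetical two-component model (a flat SED plus a power law) and evaluates the spectral likelihood in its usual sign convention, \(-\sum_p (\mathbf A^\top \mathbf N^{-1} d_p)^\top (\mathbf A^\top \mathbf N^{-1} \mathbf A)^{-1} (\mathbf A^\top \mathbf N^{-1} d_p)\), whose minimum marks the best-fitting parameters. Frequencies, noise level and grid are assumptions.

```python
import numpy as np

frequencies = np.array([30.0, 70.0, 100.0, 150.0, 220.0])  # GHz, hypothetical
nu0 = 100.0
noise_var = np.full(len(frequencies), 0.01)  # white noise variance per frequency, hypothetical


def mixing_matrix(beta):
    """Toy model: a flat 'CMB-like' SED and a power law with spectral index beta."""
    return np.column_stack([np.ones_like(frequencies), (frequencies / nu0) ** beta])


def spectral_likelihood(beta, d):
    """-sum_p (A^T N^-1 d_p)^T (A^T N^-1 A)^-1 (A^T N^-1 d_p) for d of shape (n_freq, n_pix)."""
    a = mixing_matrix(beta)
    at_ninv = a.T / noise_var      # A^T N^{-1}, shape (n_comp, n_freq)
    projected = at_ninv @ d        # A^T N^{-1} d_p for every pixel
    solved = np.linalg.solve(at_ninv @ a, projected)
    return -np.sum(projected * solved)


rng = np.random.default_rng(6)
beta_true = -3.0
sky = rng.normal(size=(2, 200))  # component amplitudes per pixel
noise = rng.normal(size=(len(frequencies), 200)) * np.sqrt(noise_var)[:, None]
d = mixing_matrix(beta_true) @ sky + noise

beta_grid = np.linspace(-4.0, -2.0, 201)  # brute-force scan over the parameter
values = np.array([spectral_likelihood(beta, d) for beta in beta_grid])
beta_best = beta_grid[np.argmin(values)]
```

The same expression, with \(\mathbf A\) built from sky SEDs, is also the objective typically minimized in parametric component separation.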

Future work:

  • make the likelihood smooth (differentiable) by interpolating the pointing matrix coefficients
  • use a gradient-based minimization algorithm
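A 1-D toy shows why interpolation matters. With nearest-pixel pointing, the simulated data is piecewise constant in the wind speed, so its derivative is zero almost everywhere and a gradient-based optimizer gets no signal; with linear interpolation between neighbouring pixels, the data varies continuously and the derivative is informative. The frozen-screen model (sample position = offset + w t) and all names are hypothetical, and central finite differences stand in for JAX autodiff.

```python
import numpy as np

sky = np.sin(np.linspace(0.0, 2.0 * np.pi, 64, endpoint=False))  # toy 1-D atmospheric screen, periodic
times = np.arange(10.0)
offset = 3.1  # hypothetical detector position at t = 0, in pixel units


def sample_nearest(position):
    """Nearest-pixel pointing: piecewise constant in position."""
    return sky[np.round(position).astype(int) % len(sky)]


def sample_linear(position):
    """Linear interpolation between the two neighbouring pixels: continuous in position."""
    left = np.floor(position).astype(int)
    frac = position - left
    return (1.0 - frac) * sky[left % len(sky)] + frac * sky[(left + 1) % len(sky)]


def simulate(w, sampler):
    """Frozen screen drifting at speed w: the detector sees position offset + w * t."""
    return sampler(offset + w * times)


def finite_difference(sampler, w, eps=1e-3):
    """Central-difference derivative of the simulated data with respect to w."""
    return (simulate(w + eps, sampler) - simulate(w - eps, sampler)) / (2.0 * eps)


w = 0.73
grad_nearest = finite_difference(sample_nearest, w)  # all zeros: no gradient signal
grad_linear = finite_difference(sample_linear, w)    # nonzero wherever the screen has a slope
```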

Conclusion

  • CMB polarization analysis mixes instrument + foregrounds + cosmology
  • Need for robust and efficient tools
  • We are building a Python-based toolbox
    • using JAX for performance and portability
    • able to represent complex instrument models
  • Applications: map-making, component separation, atmosphere decontamination, etc.
  • Goals for early 2025
    • unbiased map-making + map-based pipeline for SO-SAT
    • cluster component separation for LiteBIRD

If you are interested, check out our repository on GitHub: CMBSciPol/furax.

This work is part of the ERC project SciPol (https://scipol.in2p3.fr/).