JAX-Fluids: A Differentiable Fluid Dynamics Package for Machine-Learning Accelerated CFD
[Figure: General classification (programming language, code architectures). Static stretched mesh finite volume simulations.]
Fluid flows are omnipresent in nature and engineering applications. The numerical prediction of compressible fluid flows is challenging due to intrinsically nonlinear mechanisms and complex flow phenomena such as turbulence, shock waves, material interfaces, and chemical reactions. Accurate computational fluid dynamics (CFD) simulations require sophisticated numerical methods tailored to these problems. The recent success of machine learning (ML) in solving partial differential equations (PDEs) has shown that ML has the potential to assist and accelerate conventional CFD. The symbiosis of ML and CFD requires novel CFD solvers and paradigms. To facilitate machine-learning-assisted computational fluid dynamics, we have developed JAX-Fluids. JAX-Fluids is a comprehensive, fully differentiable CFD solver for compressible flows written in JAX. It solves the compressible single-phase, two-phase, and multicomponent Navier-Stokes equations. Every routine of JAX-Fluids is automatically differentiable, which allows end-to-end optimization of numerical methods.
Key features of JAX-Fluids include, among others:
- Seamless integration of machine learning models into a multi-physics CFD solver.
- Differentiable algorithms allowing end-to-end optimization of data-driven models.
- Rapid prototyping capability in the high-level programming language Python.
- Efficient scaling on CPU, GPU, and TPU HPC systems.
Our current efforts include:
- Development towards a differentiable multi-physics solver.
- Development of numerical methods for two-phase flows.
- End-to-end optimization of surrogate models.
- Inverse problems.
- Uncertainty quantification.
Machine learning integration
JAX-Fluids was developed to facilitate research at the intersection of computational fluid dynamics and machine learning. Integration and hybridization of machine learning models with established CFD algorithms is a key feature of JAX-Fluids. For this reason, we have chosen JAX as the backend of our CFD solver. This allows seamless integration of data-driven models into the CFD solver and end-to-end optimization of neural network models.
Integration of neural networks
JAX is a Python library designed for high-performance numerical computing and machine learning. Integrating neural-network-based subroutines into JAX-Fluids is straightforward. For example, we have developed a neural-network-based WENO3-NN cell-face reconstruction which can replace the classical WENO3-JS reconstruction. The schematic of the architecture shows that the input and output of WENO3-NN are identical to those of the classical WENO3-JS scheme. In WENO3-NN, however, a neural network replaces the analytical smoothness measure.
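The idea of a drop-in replacement can be sketched in a few lines of JAX. The block below is a minimal illustration, not the JAX-Fluids implementation: `weno3_reconstruct`, `js_weights`, and `make_nn_weights` are hypothetical names, and the tiny two-layer network, its input features, and its softmax output are assumptions chosen only to show that both variants share the same inputs and outputs. `js_weights` follows the textbook WENO3-JS formulas.

```python
import jax
import jax.numpy as jnp


def js_weights(u_m, u_c, u_p, eps=1e-6):
    """Textbook WENO3-JS nonlinear weights from analytical smoothness indicators."""
    beta0 = (u_c - u_m) ** 2          # smoothness of the upwind stencil
    beta1 = (u_p - u_c) ** 2          # smoothness of the central stencil
    alpha0 = (1.0 / 3.0) / (beta0 + eps) ** 2
    alpha1 = (2.0 / 3.0) / (beta1 + eps) ** 2
    total = alpha0 + alpha1
    return alpha0 / total, alpha1 / total


def make_nn_weights(params):
    """Hypothetical learned replacement: a small MLP maps stencil differences to weights."""
    def nn_weights(u_m, u_c, u_p):
        features = jnp.stack([jnp.abs(u_c - u_m), jnp.abs(u_p - u_c)], axis=-1)
        hidden = jnp.tanh(features @ params["w1"] + params["b1"])
        logits = hidden @ params["w2"] + params["b2"]
        w = jax.nn.softmax(logits, axis=-1)   # weights are positive and sum to one
        return w[..., 0], w[..., 1]
    return nn_weights


def weno3_reconstruct(u, weight_fn=js_weights):
    """Left-biased reconstruction at faces i+1/2 for interior cells i = 1..n-2."""
    u_m, u_c, u_p = u[:-2], u[1:-1], u[2:]
    w0, w1 = weight_fn(u_m, u_c, u_p)
    p0 = -0.5 * u_m + 1.5 * u_c       # candidate from cells (i-1, i)
    p1 = 0.5 * u_c + 0.5 * u_p        # candidate from cells (i, i+1)
    return w0 * p0 + w1 * p1


# Randomly initialised (untrained) parameters, for illustration only.
k1, k2 = jax.random.split(jax.random.PRNGKey(0))
params = {
    "w1": 0.1 * jax.random.normal(k1, (2, 8)), "b1": jnp.zeros(8),
    "w2": 0.1 * jax.random.normal(k2, (8, 2)), "b2": jnp.zeros(2),
}

u_step = jnp.array([0.0, 0.0, 0.0, 1.0, 1.0, 1.0])
faces_js = weno3_reconstruct(u_step)                           # classical weights
faces_nn = weno3_reconstruct(u_step, make_nn_weights(params))  # same call, learned weights
```

Because only `weight_fn` changes, the rest of the solver does not need to know which variant is in use, and gradients with respect to `params` flow through `weno3_reconstruct` like through any other JAX function.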
End-to-end optimization
The entire CFD algorithm is automatically differentiable, which allows neural networks to be trained end to end. We illustrate the end-to-end optimization loop in JAX-Fluids with the training process of a machine-learned ILES model (ML-ILES). Starting from an initial condition, we unroll the ML-CFD model with its trainable parameters for a number of time steps in JAX-Fluids and obtain a spatio-temporal trajectory (labeled JXF). Comparing the JXF trajectory with a ground truth (here, a filtered DNS trajectory), we calculate the loss. Using the JAX routine jax.grad, we obtain the gradient of the loss with respect to the trainable parameters, which we use to update the model parameters.
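The loop described above can be sketched with a toy problem. The block below is a minimal sketch under stated assumptions: a one-parameter explicit diffusion step stands in for the ML-CFD model, a reference run with a known parameter stands in for the filtered DNS trajectory, and plain gradient descent stands in for the optimizer. The names `step`, `rollout`, and `loss_fn` are hypothetical and are not part of the JAX-Fluids API; only `jax.grad`, `jax.jit`, and `jax.lax.scan` are real JAX routines.

```python
import jax
import jax.numpy as jnp


def step(u, theta):
    """One explicit time step of a periodic 1D diffusion model (stand-in solver)."""
    laplacian = jnp.roll(u, 1) - 2.0 * u + jnp.roll(u, -1)
    return u + theta * laplacian


def rollout(theta, u0, n_steps):
    """Unroll the solver for n_steps and return the whole trajectory."""
    def body(u, _):
        u_next = step(u, theta)
        return u_next, u_next
    _, trajectory = jax.lax.scan(body, u0, None, length=n_steps)
    return trajectory


def loss_fn(theta, u0, target):
    """Mean squared error between the unrolled trajectory and the reference."""
    trajectory = rollout(theta, u0, target.shape[0])
    return jnp.mean((trajectory - target) ** 2)


# Initial condition and reference trajectory (generated with theta = 0.2).
x = jnp.linspace(0.0, 2.0 * jnp.pi, 16, endpoint=False)
u0 = jnp.sin(x)
target = rollout(0.2, u0, 20)

# Differentiate through the entire unrolled simulation.
grad_fn = jax.jit(jax.grad(loss_fn))

theta = 0.05            # initial guess for the trainable parameter
learning_rate = 0.3
initial_loss = loss_fn(theta, u0, target)
for _ in range(100):
    theta = theta - learning_rate * grad_fn(theta, u0, target)
final_loss = loss_fn(theta, u0, target)
```

In a realistic setting the scalar `theta` would be replaced by a pytree of network parameters and the update by an optimizer step, but the structure (unroll, compare, differentiate, update) stays the same.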
Applications
JAX-Fluids is a versatile solver for simulating intricate multi-phase, reactive, and turbulent flows. Its multi-phase models are particularly effective in scenarios involving the interaction of strong shocks with drops or bubbles. JAX-Fluids also enables detailed simulations of reactive flows with complex chemistry, and it supports the numerical investigation of turbulent flows. Using the sharp-interface level-set method, JAX-Fluids allows complex geometries to be incorporated as immersed boundaries, which further extends its flexibility.
Multi-phase flows
JAX-Fluids implements two different two-phase models: a level-set based sharp-interface method (LSM) and a five-equation diffuse-interface method (DIM). The LSM maintains a sharp interface throughout the simulation while the DIM allows artificial mixing in a narrow region around the interface. The entire LSM and DIM algorithms are automatically differentiable.
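The difference between a sharp and a diffuse interface representation can be illustrated for a circular bubble. The block below is a schematic sketch, not the JAX-Fluids algorithm: the function names, the sign convention (negative inside the bubble), the cell-wise sign test, and the tanh profile with its width are all assumptions made for illustration.

```python
import jax.numpy as jnp


def bubble_level_set(x, y, center, radius):
    """Signed distance to a circular interface: negative inside, positive outside."""
    return jnp.sqrt((x - center[0]) ** 2 + (y - center[1]) ** 2) - radius


def sharp_indicator(phi):
    """Sharp view: every cell belongs to exactly one phase, decided by the sign of phi."""
    return jnp.where(phi < 0.0, 1.0, 0.0)


def diffuse_volume_fraction(phi, width):
    """Diffuse view: the phase indicator varies smoothly over a narrow band."""
    return 0.5 * (1.0 - jnp.tanh(phi / width))


# Cell centers of a uniform 64 x 64 grid on the unit square.
n = 64
centers = (jnp.arange(n) + 0.5) / n
x, y = jnp.meshgrid(centers, centers, indexing="ij")

phi = bubble_level_set(x, y, center=(0.5, 0.5), radius=0.25)
alpha_sharp = sharp_indicator(phi)
alpha_diffuse = diffuse_volume_fraction(phi, width=2.0 / n)

# Cells in which the two phases are mixed in the diffuse representation.
mixed_cells = jnp.sum((alpha_diffuse > 0.01) & (alpha_diffuse < 0.99))
bubble_area_sharp = jnp.sum(alpha_sharp) / n**2
```

Both fields are built from ordinary `jax.numpy` operations, so quantities derived from them remain differentiable with respect to, for example, the bubble radius in the diffuse case.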
Shock-bubble interaction
Shock-drop interaction
Shock-bubble array interaction simulated on TPUs
Diffuse-interface method (DIM) | Level-set method (LSM)
Reactive multi-component flows
JAX-Fluids can simulate reactive multicomponent flows. In particular, we solve the compressible reactive multicomponent Navier-Stokes equations with detailed chemistry and transport models. Material properties and reaction mechanisms can be conveniently set up using Cantera input files. The implementation has been verified on canonical problems, e.g., by comparing JAX-Fluids results with Cantera for hydrogen-oxygen combustion (San Diego mechanism) at stoichiometric conditions and for methane-oxygen combustion (GRI-Mech 3.0 mechanism). Furthermore, we studied a reactive shock-bubble simulation in which a shock wave impinges upon a hydrogen-oxygen-xenon bubble. The mixture ignites once shock focusing has raised pressure and temperature to suitable values.
Canonical validation
Reactive shock-bubble (hydrogen-oxygen-xenon bubble)
Compressible turbulent flows
JAX-Fluids supports the simulation of compressible turbulent flows featuring shock waves. In particular, JAX-Fluids can perform Direct Numerical Simulations (DNS) and Implicit Large-Eddy Simulations (ILES). We use state-of-the-art nonlinear discretization schemes with tailored subgrid-scale behavior, such as ALDM or TENO-type schemes with adaptive dissipation control. Recently, we have started exploring Machine-Learned Implicit Large-Eddy Simulations (ML-ILES).
Turbulent boundary layer at Ma = 2
Turbulent channel at Ma = 1.5
Compressible Homogeneous Isotropic Turbulence (HIT)
Immersed boundary
We utilize the sharp-interface level-set formulation to model static and dynamic immersed solid boundaries. This allows the simulation of fluid flows in and around complex geometries. One-way and two-way coupling between fluid and solid phases are supported. For one-way coupling, temperature and velocity of the rigid solid body are specified by the user.
Publications
Bezgin, D., Buhendwa, A. & Adams, N. A. (2023): JAX-Fluids: A fully-differentiable high-order computational fluid dynamics solver for compressible two-phase flows. Computer Physics Communications, 282, 108527.