GitHub Repositories

A collection of useful GitHub repositories for JAX-based Reinforcement Learning.

JAX RL Implementations

Reinforcement Learning in JAX.

  • CleanRL

    High-quality single-file implementations of deep reinforcement learning algorithms with research-friendly features (PPO, DQN, C51, DDPG, TD3, SAC, PPG).

  • Dopamax

    Reinforcement learning in pure JAX.

  • Earl

    Reinforcement learning with Equinox

  • FISOR

    [ICLR 2024] The official implementation of 'Safe Offline Reinforcement Learning with Feasibility-Guided Diffusion Model'

  • Implicit Q-Learning

    This repository contains the official implementation of Offline Reinforcement Learning with Implicit Q-Learning

  • Jax-Baseline

    Jax-Baseline is a Reinforcement Learning implementation using JAX and Flax/Haiku libraries, mirroring the functionality of Stable-Baselines.

  • JAX-Corl

    Clean single-file implementations of offline RL algorithms in JAX

  • JaxCQL

    Conservative Q-Learning in JAX

  • JaxGCRL

    Goal-Conditioned Reinforcement Learning with JAX

  • JaxIRL

    JaxIRL contains JAX implementations of algorithms for inverse reinforcement learning.

  • JaxMARL

    Multi-Agent Reinforcement Learning with JAX

  • JaxRL

    JAX (Flax) implementation of algorithms for Deep Reinforcement Learning with continuous action spaces.

  • jym

    JAX implementation of RL algorithms and vectorized environments

  • Mava

    A research-friendly codebase for fast experimentation of multi-agent reinforcement learning in JAX.

  • pax

    Scalable Opponent Shaping Experiments in JAX

  • popjaxrl

    Benchmarking RL for POMDPs in Pure JAX [Code for 'Structured State Space Models for In-Context Reinforcement Learning' (NeurIPS 2023)]

  • purejaxql

    Simple single-file baselines for Q-Learning in a pure-GPU setting

  • PureJaxRL

    Really Fast End-to-End Jax RL Implementations.

  • Reinforced-lib

    Reinforcement learning library.

  • Rejax

    Hardware-Accelerated Reinforcement Learning Algorithms in pure Jax!

  • RL Basics

    Simple single-file implementations of deep reinforcement learning algorithms.

  • RL-Flax

    Various reinforcement learning algorithms written in Jax + Flax.

  • RLax

    Building blocks for implementing RL agents.

  • rlbase_stable

    A codebase that implements simple reinforcement learning algorithms in JAX.

  • SBX

    Stable Baselines Jax (SB3 + Jax) RL algorithms

  • sebulba

    🪐 The Sebulba architecture to scale reinforcement learning on Cloud TPUs in JAX

  • Skrl

    Modular reinforcement learning library (on PyTorch and JAX) with support for NVIDIA Isaac Gym, Omniverse Isaac Gym and Isaac Lab

  • Stoix

    A research-friendly codebase for fast experimentation of single-agent reinforcement learning in JAX.

  • UniFloral

    Unified Implementations of Offline Reinforcement Learning Algorithms

  • Wsrl

    JAX implementation of WSRL and RL baselines | ICLR 2025


Network Libraries

Neural networks and scientific computing in JAX.

  • Equinox

    Elegant easy-to-use neural networks + scientific computing in JAX.

  • Flax

    High-performance neural network library for JAX.


Training

Repositories that support and speed up the training process in JAX.

  • EasyDeL

    Accelerate and optimize performance with streamlined training and serving options in JAX.

  • Flashbax

    Accelerated Replay Buffers in JAX.

  • Minari

    A standard format for offline reinforcement learning datasets, with popular reference datasets and related utilities


Optimization Libraries

Gradient processing and optimization in JAX.

  • Optax

    Optax is a gradient processing and optimization library for JAX.


RL Environments in JAX

Reinforcement Learning environments written in JAX.

  • Brax

    Massively parallel rigid-body physics simulation on accelerator hardware.

  • Evojax

    EvoJAX is a scalable, general-purpose, hardware-accelerated neuroevolution toolkit.

  • Evorl

    EvoRL is a fully GPU-accelerated framework for Evolutionary Reinforcement Learning, implemented with JAX

  • Gymnax

    RL Environments in JAX.

  • JaxMARL

    Multi-Agent Reinforcement Learning with JAX

  • Jumanji

    A diverse suite of scalable reinforcement learning environments in JAX.

  • Kinetix

    Reinforcement learning on general 2D physics environments in JAX.

  • Navix

    Accelerated MiniGrid environments with JAX

  • Pgx

    Vectorized RL game environments in JAX

  • popjym

    POPGym Library in JAX

  • XLand-MiniGrid

    JAX-accelerated Meta-Reinforcement Learning Environments Inspired by XLand and MiniGrid
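
As a rough illustration of why environments written in JAX can be vectorized and accelerated, the sketch below defines a hypothetical toy environment as pure functions and steps 1024 copies in parallel with `jax.vmap` and `jax.jit`. The environment, its `reset`/`step` signatures, and the reward are assumptions made up for this example; they are not the API of any library listed above.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy environment: a point on a line that should reach the origin.
# The state is a single scalar position; there is no hidden mutable state.
def reset(key):
    return jax.random.uniform(key, (), minval=-1.0, maxval=1.0)

def step(state, action):
    # action 1 moves right by 0.1, any other action moves left by 0.1.
    new_state = state + jnp.where(action == 1, 0.1, -0.1)
    reward = -jnp.abs(new_state)        # closer to the origin is better
    done = jnp.abs(new_state) < 0.05    # episode ends near the origin
    return new_state, reward, done

# Because reset and step are pure functions, they can be vectorized over a
# batch of environments and compiled for the accelerator.
batched_reset = jax.vmap(reset)
batched_step = jax.jit(jax.vmap(step))

num_envs = 1024
keys = jax.random.split(jax.random.PRNGKey(0), num_envs)
states = batched_reset(keys)
actions = jnp.ones(num_envs, dtype=jnp.int32)  # every environment moves right
next_states, rewards, dones = batched_step(states, actions)
```

Real environment libraries add observation and action spaces, episode resets, and richer state, but the pattern of pure functions plus `vmap` and `jit` is the part this sketch aims to show.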


Tools

Additional repositories that may be of interest.

  • Chex

    Chex is a library of utilities for helping to write reliable JAX code.

  • Distrax

    Distrax is a lightweight library of probability distributions and bijectors. It acts as a JAX-native reimplementation of a subset of TensorFlow Probability (TFP).

  • Haliax

    Named Tensors for Legible Deep Learning in JAX.

  • JaxTyping

    Type annotations and runtime checking for shape and dtype of JAX/NumPy/PyTorch/etc. arrays.

  • Orbax

    Orbax provides common checkpointing and persistence utilities for JAX users.

  • Treescope

    An interactive HTML pretty-printer for machine learning research in IPython notebooks.