Optax

Optax is a gradient processing and optimization library for JAX.