jaxtyping v0.2.37

https://github.com/patrick-kidger/jaxtyping/releases/tag/v0.2.37

  • Added fp8 dtypes:

    • jaxtyping.Float8e4m3b11fnuz
    • jaxtyping.Float8e4m3fn
    • jaxtyping.Float8e4m3fnuz
    • jaxtyping.Float8e5m2
    • jaxtyping.Float8e5m2fnuz
  • Added static type-checking compatibility when decorating dataclasses with @jaxtyped.

  • Error messages are now pretty-printed using the wadler_lindig library. In particular, PyTorch tensors and similar array objects are no longer printed out in their entirety; they are summarised as just their shape and dtype.
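
The new fp8 dtype names encode the bit layout of each format. The sketch below is not part of jaxtyping: it is a small stdlib-only helper (parse_fp8_name and Fp8Layout are hypothetical names) that decodes the names listed above. It assumes they follow the ml_dtypes naming convention, where "e4m3" means 4 exponent bits and 3 mantissa bits, "b11" an exponent bias of 11, "fn" a finite-only format with no infinities, and "uz" a format with no negative zero.

```python
import re
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class Fp8Layout:
    """Bit layout decoded from an fp8 dtype name (illustrative only)."""

    exponent_bits: int
    mantissa_bits: int
    explicit_bias: Optional[int]  # e.g. 11 for "b11"; None if the name gives no bias
    finite_only: bool  # "fn": the format has no infinities
    unsigned_zero: bool  # "uz": the format has no negative zero

    @property
    def total_bits(self) -> int:
        # One sign bit plus the exponent and mantissa fields.
        return 1 + self.exponent_bits + self.mantissa_bits


# Assumed name shape: Float8 e<exp> m<man> [b<bias>] [fn] [uz]
_FP8_NAME = re.compile(r"^Float8e(\d)m(\d)(?:b(\d+))?(fn)?(uz)?$")


def parse_fp8_name(name: str) -> Fp8Layout:
    """Decode a name such as 'Float8e4m3fn' into its bit layout."""
    match = _FP8_NAME.match(name)
    if match is None:
        raise ValueError(f"not a recognised fp8 dtype name: {name!r}")
    exp, man, bias, fn, uz = match.groups()
    return Fp8Layout(
        exponent_bits=int(exp),
        mantissa_bits=int(man),
        explicit_bias=int(bias) if bias is not None else None,
        finite_only=fn is not None,
        unsigned_zero=uz is not None,
    )


if __name__ == "__main__":
    for dtype_name in (
        "Float8e4m3b11fnuz",
        "Float8e4m3fn",
        "Float8e4m3fnuz",
        "Float8e5m2",
        "Float8e5m2fnuz",
    ):
        print(dtype_name, parse_fp8_name(dtype_name))
```

Every name in the list decodes to one sign bit plus exponent and mantissa fields totalling 8 bits, which is a quick sanity check that the naming convention assumed here is being read correctly.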