jaxtyping v0.2.37
https://github.com/patrick-kidger/jaxtyping/releases/tag/v0.2.37
Added fp8 dtypes:
jaxtyping.Float8e4m3b11fnuz
jaxtyping.Float8e4m3fn
jaxtyping.Float8e4m3fnuz
jaxtyping.Float8e5m2
jaxtyping.Float8e5m2fnuz
Static type-checking compatibility when decorating dataclasses with @jaxtyped
Error messages are now pretty-printed using the wadler_lindig library. In particular, PyTorch tensors and similar objects are no longer printed out in their entirety; they are summarised as just their shape and dtype.