```python
import jax
import jax.numpy as jnp
from jaxify import jaxify

@jax.jit
@jax.vmap
@jaxify  # <-- Just decorate your function with @jaxify
def absolute_value(x):
    if x >= 0:  # <-- If block in a JIT-compiled function
        return x
    else:
        return -x

xs = jnp.arange(-1000, 1000)
ys = absolute_value(xs)  # <-- Runs at JAX speed!
print(ys)
```
The `@jaxify` decorator uses static analysis to rewrite Python constructs that JAX cannot trace, such as an `if` statement whose condition depends on a traced value, into JAX-compatible equivalents. The rewritten function is traceable by JAX, so you can apply JAX transformations such as `@jax.jit` and `@jax.vmap` to it directly.
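
To see what the decorator saves you from, here is a minimal sketch in plain JAX with no `jaxify` involved; the names `naive_abs` and `manual_abs` are hypothetical. The plain `if` fails under `jax.jit` because Python asks for a concrete `bool` from a traced value, while writing the branch by hand with `jax.lax.cond` works. This only approximates the kind of rewrite `@jaxify` performs and is not the library's actual generated code.

```python
import jax
import jax.numpy as jnp


def naive_abs(x):
    # Python's `if` needs a concrete bool, but under jit `x >= 0` is a tracer.
    if x >= 0:
        return x
    else:
        return -x


try:
    jax.jit(naive_abs)(jnp.float32(-3.0))
    naive_failed = False
except jax.errors.ConcretizationTypeError:
    # Raised when Python control flow tries to read a traced value as a bool.
    naive_failed = True


@jax.jit
@jax.vmap
def manual_abs(x):
    # Hand-written equivalent (hypothetical name): both branches are passed
    # as functions, and lax.cond selects between them using the traced predicate.
    return jax.lax.cond(x >= 0, lambda v: v, lambda v: -v, x)


xs = jnp.arange(-1000, 1000)
ys = manual_abs(xs)
print(naive_failed)  # True
print(ys[:3])
```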
The following Python constructs are currently supported within `@jaxify`-decorated functions:

| Construct | Works? | Notes |
|---|---|---|
| `if` statements | ✅ | Fully supported, including `elif` and `else` clauses. Translated to calls to `jax.lax.cond`. |
| `if` expressions (e.g. `a if b else c`) | ✅ | Translated to `jax.lax.cond`. |
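
An `if`-`elif`-`else` chain maps onto nested conditionals: the `elif` becomes a second `jax.lax.cond` inside the false branch of the first. The sketch below writes that nesting by hand in plain JAX for a hypothetical `sign` function, plus a single `jax.lax.cond` for a hypothetical `clip_to_ten` if expression. It illustrates the idea and is not necessarily the exact code `@jaxify` emits.

```python
import jax
import jax.numpy as jnp


@jax.jit
def sign(x):
    # Equivalent of:
    #     if x > 0:   return 1
    #     elif x < 0: return -1
    #     else:       return 0
    # The elif/else pair becomes an inner cond in the outer false branch.
    return jax.lax.cond(
        x > 0,
        lambda v: jnp.ones_like(v),
        lambda v: jax.lax.cond(
            v < 0, lambda w: -jnp.ones_like(w), lambda w: jnp.zeros_like(w), v
        ),
        x,
    )


@jax.jit
def clip_to_ten(x):
    # Equivalent of the if expression: x if x < 10 else 10
    return jax.lax.cond(x < 10, lambda v: v, lambda v: jnp.full_like(v, 10), x)


signs = jax.vmap(sign)(jnp.array([-2.0, 0.0, 3.0]))
print(signs)
print(clip_to_ten(jnp.float32(12.0)))
```

Under `jax.vmap` with a batched predicate, `jax.lax.cond` is converted to a select, so both branches are computed for every element.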

| Construct | Works? | Notes |
|---|---|---|
| `==`, `!=`, `<`, `>`, `<=`, `>=` | ✅ | Chained comparisons (e.g. `x < y <= z`) are supported by translation to the equivalent chain of individual comparisons. |
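
Python evaluates a chained comparison such as `lo < x <= hi` as `(lo < x) and (x <= hi)`, evaluating `x` only once. On traced values, the implicit `and` is what breaks, because it calls `bool()` on the first result. A minimal hand-written equivalent in plain JAX combines the individual comparisons elementwise with `jnp.logical_and`. The name `in_interval` is hypothetical, and `@jaxify` may translate the chain differently (for example through its own `and` handling).

```python
import jax
import jax.numpy as jnp


@jax.jit
def in_interval(lo, x, hi):
    # Equivalent of `lo < x <= hi`, split into its two comparisons and
    # combined with an elementwise logical AND that works on tracers.
    return jnp.logical_and(lo < x, x <= hi)


xs = jnp.array([0, 1, 5, 6])
mask = in_interval(0, xs, 5)
print(mask)
```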

| Construct | Works? | Notes |
|---|---|---|
| `and` / `or` | ✅ | Short-circuiting on traced values is supported via translation to `jax.lax.cond` calls. |
| `not` | ✅ | Translated to `jnp.logical_not` for single traced values. |
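
Python's `a and b` returns `a` when `a` is falsy and only evaluates `b` otherwise; `a or b` is the mirror image. With a traced boolean `a`, that short-circuit can be expressed as a `jax.lax.cond` whose branches return either `a` or `b`, so `b` is only computed in the branch that needs it. The sketch below writes this out by hand in plain JAX for the hypothetical functions `ratio_exceeds_one` and `is_extreme`, alongside `jnp.logical_not` for `not`. It illustrates the idea rather than reproducing the code `@jaxify` generates.

```python
import jax
import jax.numpy as jnp


@jax.jit
def ratio_exceeds_one(x, y):
    # Equivalent of `y != 0 and x / y > 1`: the division sits in the true
    # branch, so it only runs when the left operand is truthy.
    return jax.lax.cond(
        y != 0,
        lambda ops: ops[0] / ops[1] > 1,  # truthy left: result is right operand
        lambda ops: ops[1] != 0,  # falsy left: return the left operand (False)
        (x, y),
    )


@jax.jit
def is_extreme(x):
    # Equivalent of `x < -10 or x > 10`: the right operand is only
    # evaluated when the left one is False.
    return jax.lax.cond(
        x < -10,
        lambda v: v < -10,  # truthy left: return the left operand (True)
        lambda v: v > 10,  # falsy left: result is the right operand
        x,
    )


@jax.jit
def is_non_positive(x):
    # Equivalent of `not x > 0` on a single traced value.
    return jnp.logical_not(x > 0)


print(ratio_exceeds_one(3.0, 2.0), is_extreme(20.0), is_non_positive(0.0))
```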

| Construct | Works? | Notes |
|---|---|---|
| `match`-`case` | ✅⚠️ | Static values only. For traced values, use an `if`-`elif`-`else` chain or `jax.lax.switch` instead. |