jax.grad: Computes the gradient of a function.
jax.jit: Compiles functions for speed (using XLA).
jax.vmap: Vectorizes functions (automatic batching); a short sketch of it follows the example below.

The Rosenbrock function is a classic optimization test case:
\[f(x, y) = (a-x)^2 + b(y-x^2)^2\]
Usually \(a = 1, b = 100\). It has a global minimum at \((1, 1)\).
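For reference, the gradient that jax.grad will construct automatically has the closed form
\[\frac{\partial f}{\partial x} = -2(a-x) - 4bx(y-x^2), \qquad \frac{\partial f}{\partial y} = 2b(y-x^2),\]
and both partial derivatives vanish only at \((1, 1)\) when \(a = 1\), which is why gradient descent should converge toward that point.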
import jax.numpy as jnp
from jax import grad, jit
def rosenbrock(params):
    x, y = params
    return (1.0 - x)**2 + 100.0 * (y - x**2)**2
# Build the gradient function and JIT-compile it
grad_f = jit(grad(rosenbrock))
# Initial point
params = jnp.array([0.0, 0.0])
# Gradient descent: 1000 steps with a fixed learning rate
learning_rate = 0.001
for _ in range(1000):
    grads = grad_f(params)
    params = params - learning_rate * grads
print(f"Final parameters: {params}")
print(f"Loss: {rosenbrock(params)}")