2.220446049250313e-16
2.220446049250313e-16
1.1920929e-07
8.881784197001252e-16
def f(x):
    """Evaluate the test function cos(2*pi*x) + sin(4*pi*x + 0.25).

    Works element-wise on NumPy arrays as well as on scalars.
    """
    return np.cos(2*np.pi*x) + np.sin((4*np.pi*x) + 0.25)


def df(x):
    """Evaluate the exact first derivative of ``f``.

    Used as the reference value when measuring finite-difference errors.
    """
    return (-2*np.pi*np.sin(2*np.pi*x)
            + 4*np.pi*np.cos((4*np.pi*x) + 0.25))


# Point at which the derivative approximations are evaluated.
x_ref = 0.65
# Sample the unit interval for plotting f and its tangent at x_ref.
x = np.linspace(0, 1, 101)

# Tangent line through (x_ref, f(x_ref)), written in slope-intercept form:
# slope df(x_ref), intercept f(x_ref) - df(x_ref)*x_ref.
tangent = df(x_ref)*x + (f(x_ref) - df(x_ref)*x_ref)

fig, ax = plt.subplots(figsize=(5, 4))
ax.plot(x, f(x), 'b', label='f(x)')
ax.plot(x_ref, f(x_ref), 'ro', label='$x_0$')
# The dashed curve is the tangent line at x_0 (its slope is f'(x_0)), not
# the derivative function itself, so the legend names it as the tangent.
ax.plot(x, tangent, 'm--', label='Tangent at $x_0$')
ax.grid(alpha=0.5)
ax.set_xlabel('x')
ax.set_ylabel('f(x)')
ax.legend()
plt.tight_layout()
plt.show()
# Step sizes from 1 down to 1e-5, logarithmically spaced.
h = np.logspace(0, -5, 75)

# Forward, backward and central difference quotients at x_ref.
fwd = (f(x_ref + h) - f(x_ref)) / h
bwd = (f(x_ref) - f(x_ref - h)) / h
cen = (f(x_ref + h) - f(x_ref - h)) / (2*h)

# Absolute error of each quotient against the exact derivative.
err_fwd, err_bwd, err_cen = (
    abs(approx - df(x_ref)) for approx in (fwd, bwd, cen)
)

# Log-log plot of error versus step size, one curve per scheme.
fig, ax = plt.subplots(figsize=(5, 4))
for curve, line_fmt, extra in (
    (err_fwd, 'b--', {'label': 'Forward', 'alpha': 0.75}),
    (err_bwd, 'r-.', {'label': 'Backward', 'alpha': 0.75}),
    (err_cen, 'm-', {'label': 'Central'}),
):
    ax.loglog(h, curve, line_fmt, **extra)
ax.grid(which='both', alpha=0.5)
ax.set(xlabel='h (step size)', ylabel='Absolute Error')
ax.legend()
plt.tight_layout()
plt.show()
def ddf(x):
    """Evaluate the exact second derivative of cos(2*pi*x) + sin(4*pi*x + 0.25)."""
    return (-4*(np.pi**2)*np.cos(2*np.pi*x)
            - 16*(np.pi**2)*np.sin((4*np.pi*x)+0.25))


def dddf(x):
    """Evaluate the exact third derivative of cos(2*pi*x) + sin(4*pi*x + 0.25)."""
    return (8*(np.pi**3)*np.sin(2*np.pi*x)
            - 64*(np.pi**3)*np.cos((4*np.pi*x)+0.25))
# Overlay the measured errors with reference curves built from the exact
# higher derivatives: |f''(x_ref)|*h/2 and |f'''(x_ref)|*h^2/6.
fig, ax = plt.subplots(figsize=(5, 4))
for curve, line_fmt, extra in (
    (err_fwd, 'b--', {'alpha': 0.4, 'label': 'Forward'}),
    (err_bwd, 'r-.', {'alpha': 0.4, 'label': 'Backward'}),
    (abs(ddf(x_ref))*h/2, 'k:', {'label': 'O(h)'}),
    (err_cen, 'm-', {'alpha': 0.4, 'label': 'Central'}),
    (abs(dddf(x_ref))*(h**2)/6, 'k-', {'label': '$O(h^2)$'}),
):
    ax.loglog(h, curve, line_fmt, **extra)
ax.grid(which='both', alpha=0.5)
ax.set(xlabel='h (step size)', ylabel='Absolute Error')
ax.legend(fontsize=8)
plt.tight_layout()
plt.show()
def ddddf(x):
    """Evaluate the exact fourth derivative of cos(2*pi*x) + sin(4*pi*x + 0.25)."""
    return (16*(np.pi**4)*np.cos(2*np.pi*x)
            + 256*(np.pi**4)*np.sin((4*np.pi*x)+0.25))


def dddddf(x):
    """Evaluate the exact fifth derivative of cos(2*pi*x) + sin(4*pi*x + 0.25)."""
    return (-32*(np.pi**5)*np.sin(2*np.pi*x)
            + 1024*(np.pi**5)*np.cos((4*np.pi*x)+0.25))
# Four candidate finite-difference formulas for f'(x_ref), each evaluated
# on the whole array of step sizes h.
A = (-f(x_ref+2*h) + 4*f(x_ref+h) - 3*f(x_ref)) / (2*h)
B = (f(x_ref+2*h) + f(x_ref+h) - f(x_ref) - f(x_ref-h)) / (4*h)
C = (2*f(x_ref+h) + 3*f(x_ref) - 6*f(x_ref-h) + f(x_ref-2*h)) / (6*h)
D = (-f(x_ref+2*h) + 8*f(x_ref+h) - 8*f(x_ref-h) + f(x_ref-2*h)) / (12*h)

fig, ax = plt.subplots(figsize=(5, 4))

# Measured absolute error of each scheme (marker plus solid line).
schemes = {'A': (A, 'b^-'), 'B': (B, 'r.-'), 'C': (C, 'g*-'), 'D': (D, 'co-')}
for tag, (approx, line_fmt) in schemes.items():
    ax.loglog(h, abs(approx - df(x_ref)), line_fmt,
              label=f'Scheme {tag}', alpha=0.5)

# Reference curves h^p * |f^(p+1)(x_ref)| / (p+1)! for p = 1..4, drawn so
# the slope of each scheme's error can be read off by comparison.
guides = (
    ('O(h)', 'k:', abs(ddf(x_ref))*h/2),
    ('$O(h^2)$', 'k--', abs(dddf(x_ref))*(h**2)/6),
    ('$O(h^3)$', 'k-.', abs(ddddf(x_ref))*(h**3)/24),
    ('$O(h^4)$', 'k-', abs(dddddf(x_ref))*(h**4)/120),
)
for tag, line_fmt, bound in guides:
    ax.loglog(h, bound, line_fmt, label=tag, alpha=0.6)

ax.grid(which='both', alpha=0.5)
ax.set(xlabel='h', ylabel='Absolute Error')
ax.legend(fontsize=7, ncol=2)
plt.tight_layout()
plt.show()
# Much smaller step sizes (1e-10 .. 1e-3) to expose floating-point effects.
h_fine = np.logspace(-10, -3, 500)

# Same three difference quotients as before, now on the fine grid.
fwd_q = (f(x_ref + h_fine) - f(x_ref)) / h_fine
bwd_q = (f(x_ref) - f(x_ref - h_fine)) / h_fine
cen_q = (f(x_ref + h_fine) - f(x_ref - h_fine)) / (2*h_fine)

# Absolute error of each quotient against the exact derivative.
err_fwd_q, err_bwd_q, err_cen_q = (
    abs(quotient - df(x_ref)) for quotient in (fwd_q, bwd_q, cen_q)
)

fig, ax = plt.subplots(figsize=(5, 4))
for curve, line_fmt, name in (
    (err_fwd_q, 'b--', 'Forward'),
    (err_bwd_q, 'r-.', 'Backward'),
    (err_cen_q, 'm-', 'Central'),
):
    ax.loglog(h_fine, curve, line_fmt, label=name, alpha=0.75)
ax.grid(which='both', alpha=0.5)
ax.set(xlabel='h (step size)', ylabel='Absolute Error')
ax.legend()
plt.tight_layout()
plt.show()
# Function values at the three stencil points, evaluated once and reused.
f_fwd = f(x_ref + h_fine)
f_bwd = f(x_ref - h_fine)
f_mid = f(x_ref)

# Relative uncertainty of each numerator: the float spacing of the two
# operands, summed, divided by the magnitude of their difference.
num_rel_fwd = (np.spacing(f_fwd) + np.spacing(f_mid)) / abs(f_fwd - f_mid)
num_rel_bwd = (np.spacing(f_mid) + np.spacing(f_bwd)) / abs(f_mid - f_bwd)
num_rel_cen = (np.spacing(f_fwd) + np.spacing(f_bwd)) / abs(f_fwd - f_bwd)

# Relative uncertainty of the denominator h itself.
den_rel = np.spacing(h_fine) / h_fine

# For every scheme draw the measured error (faint) and the quantization
# estimate |quotient| * (numerator + denominator relative uncertainty).
fig, ax = plt.subplots(figsize=(5, 4))
for measured, measured_fmt, measured_lbl, quotient, num_rel, est_fmt, est_lbl in (
    (err_fwd_q, 'b--', 'Fwd error', fwd_q, num_rel_fwd, 'b', 'Fwd quant.'),
    (err_bwd_q, 'r-.', 'Bwd error', bwd_q, num_rel_bwd, 'r', 'Bwd quant.'),
    (err_cen_q, 'm-', 'Cen error', cen_q, num_rel_cen, 'm', 'Cen quant.'),
):
    ax.loglog(h_fine, measured, measured_fmt, alpha=0.25, label=measured_lbl)
    ax.loglog(h_fine, abs(quotient)*(num_rel+den_rel), est_fmt, label=est_lbl)
ax.grid(which='both', alpha=0.5)
ax.set(xlabel='h', ylabel='Absolute Error')
ax.legend(fontsize=7)
plt.tight_layout()
plt.show()
# Function values on the five-point stencil x_ref + k*h, k = -2..2.
f_m2 = f(x_ref-2*h)
f_m1 = f(x_ref-h)
f_0 = f(x_ref)
f_p1 = f(x_ref+h)
f_p2 = f(x_ref+2*h)

# Second-derivative approximations: forward, backward, central and the
# five-point central formula.
FDxx = (f_p2 - 2*f_p1 + f_0) / h**2
BDxx = (f_0 - 2*f_m1 + f_m2) / h**2
CDxx = (f_p1 - 2*f_0 + f_m1) / h**2
CDxxO4 = (-f_p2 + 16*f_p1 - 30*f_0 + 16*f_m1 - f_m2) / (12*h**2)

# Measured errors against the exact second derivative, followed by the
# same style of reference slope curves used for the first derivative.
fig, ax = plt.subplots(figsize=(5, 4))
for curve, line_fmt, name, opacity in (
    (abs(FDxx - ddf(x_ref)), 'b^-', 'FD', 0.45),
    (abs(BDxx - ddf(x_ref)), 'r.-', 'BD', 0.45),
    (abs(CDxx - ddf(x_ref)), 'g*-', 'CD', 0.25),
    (abs(CDxxO4 - ddf(x_ref)), 'co-', 'CD $O(h^4)$', 0.25),
    (abs(ddf(x_ref))*h/2, 'k^:', 'O(h)', 0.4),
    (abs(dddf(x_ref))*(h**2)/6, 'k.:', '$O(h^2)$', 0.5),
    (abs(dddddf(x_ref))*(h**4)/120, 'ko:', '$O(h^4)$', 0.5),
):
    ax.loglog(h, curve, line_fmt, label=name, alpha=opacity)
ax.grid(which='both', alpha=0.5)
ax.set(xlabel='h', ylabel='Absolute Error')
ax.legend(fontsize=7)
plt.tight_layout()
plt.show()
“Autodiff is criminally underused”
- jax.grad: Computes the gradient of a function.
- jax.jit: Compiles functions for speed (using XLA).
- jax.vmap: Vectorizes functions (automatic batching).

The Rosenbrock function is a classic optimization test case: \[f(x, y) = (a-x)^2 + b(y-x^2)^2\] Usually \(a=1, b=100\). With these values it has a global minimum at \((1, 1)\).
import jax.numpy as jnp
from jax import grad, jit
def rosenbrock(params, a=1.0, b=100.0):
    """Evaluate the Rosenbrock function (a - x)**2 + b * (y - x**2)**2.

    Args:
        params: Pair ``(x, y)``; anything that unpacks into two values.
        a: Offset in the first term. Defaults to 1.0.
        b: Weight of the second term. Defaults to 100.0.

    Returns:
        The function value at ``(x, y)``. With the defaults this is zero
        at ``(1, 1)``.
    """
    x, y = params
    return (a - x)**2 + b * (y - x**2)**2
# JIT-compiled gradient of the Rosenbrock objective with respect to params.
grad_f = jit(grad(rosenbrock))

# Step size and starting point for the descent.
learning_rate = 0.001
params = jnp.array([0.0, 0.0])

# Run 1000 plain gradient-descent updates from the starting point.
for _ in range(1000):
    grads = grad_f(params)
    step = learning_rate * grads
    params = params - step

print(f"Final parameters: {params}")
print(f"Loss: {rosenbrock(params)}")
Haunschild et al. — flexible model exploration with AD
Workflow overview from Haunschild et al.
Research landscape: making functions differentiable for ML