import numpy as np
import matplotlib.pyplot as plt

# Test function and its exact derivative
f = lambda x: np.cos(2*np.pi*x) + np.sin((4*np.pi*x) + 0.25)
df = lambda x: (-2*np.pi*np.sin(2*np.pi*x)
                + 4*np.pi*np.cos((4*np.pi*x) + 0.25))
x_ref = 0.65
x = np.linspace(0, 1, 101)
fig, ax = plt.subplots(figsize=(5, 4))
ax.plot(x, f(x), 'b', label='f(x)')
ax.plot(x_ref, f(x_ref), 'ro', label='$x_0$')
tangent = df(x_ref)*x + (f(x_ref) - df(x_ref)*x_ref)
ax.plot(x, tangent, 'm--', label='tangent at $x_0$')
ax.grid(alpha=0.5)
ax.set_xlabel('x')
ax.set_ylabel('f(x)')
ax.legend()
plt.tight_layout()
plt.show()
Autodifferentiation
Optimization and Gradients
Optimization in Machine Learning
- Most ML problems involve finding parameters \(\theta\) that minimize a loss function \(L(\theta)\).
- How do we find the minimum?
- If \(L\) is differentiable, we can follow the negative gradient \(-\nabla L(\theta)\) downhill; if \(L\) is also convex, the minimum we reach is the global one.
The Role of Gradients
- The gradient \(\nabla L(\theta)\) is a vector of partial derivatives: \[\nabla L(\theta) = \left[ \frac{\partial L}{\partial \theta_1}, \frac{\partial L}{\partial \theta_2}, \dots, \frac{\partial L}{\partial \theta_N} \right]\]
- It tells us the direction of steepest ascent.
- Moving in the opposite direction (gradient descent) helps us find local minima, as sketched below.
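As a minimal sketch of the idea (the loss, step size, and iteration count below are arbitrary illustrative choices, not part of the original material): plain gradient descent on a two-parameter quadratic loss with a hand-derived gradient.

import numpy as np

# Hypothetical quadratic loss L(theta) = (theta_1 - 3)^2 + 10*(theta_2 + 1)^2
L = lambda theta: (theta[0] - 3.0)**2 + 10.0*(theta[1] + 1.0)**2
grad_L = lambda theta: np.array([2.0*(theta[0] - 3.0), 20.0*(theta[1] + 1.0)])

theta = np.zeros(2)                      # initial guess
lr = 0.05                                # learning rate (step size)
for _ in range(200):
    theta = theta - lr * grad_L(theta)   # step against the gradient
print(theta)                             # approaches the minimizer [3, -1]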
The Curse of Dimensionality
Growth of Parameter Space
- As the number of dimensions \(N\) increases, the volume of the space increases exponentially.
- If we want to sample a 10-dimensional space with 10 points per dimension, we need \(10^{10}\) points.
- Modern neural networks have millions or billions of parameters.
Why Gradients Matter Here
- In high-dimensional spaces, “guessing” the right direction is impossible.
- Gradients provide a local “compass,” pointing exactly where to go to decrease the loss.
- Without efficient gradients, deep learning would be computationally infeasible.
Numerical Differencing
Finite Differences
- The simplest way to approximate a derivative: \[f'(x) \approx \frac{f(x+h) - f(x)}{h}\]
- For a gradient in \(N\) dimensions, we need to evaluate \(f\) at least \(N+1\) times: one baseline evaluation plus one shifted evaluation per coordinate (see the sketch below).
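A minimal sketch of that cost (a generic forward-difference gradient, not code from the original slides): one baseline evaluation of \(f\), then one perturbed evaluation per coordinate.

import numpy as np

def fd_gradient(func, x, h=1e-6):
    """Forward-difference gradient of a scalar func at x: N+1 evaluations."""
    x = np.asarray(x, dtype=float)
    f0 = func(x)                  # 1 baseline evaluation
    g = np.zeros_like(x)
    for i in range(x.size):       # N perturbed evaluations
        xp = x.copy()
        xp[i] += h
        g[i] = (func(xp) - f0) / h
    return g

# Example: f(x) = sum(x^2) has exact gradient 2x
print(fd_gradient(lambda x: np.sum(x**2), np.array([1.0, -2.0, 0.5])))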
Limitations of Finite Differences
- Computational Cost: \(O(N)\) evaluations per gradient step. If \(N=1,000,000\), this is too slow.
- Accuracy:
- If \(h\) is too large, the approximation is poor (truncation error).
- If \(h\) is too small, we hit floating-point precision limits (round-off error); the two errors trade off against each other, as quantified below.
- Stability: Sensitive to noisy functions.
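As a rough rule of thumb (standard error analysis, not from the original slides): for the forward difference the truncation error scales like \(h\) and the round-off error like \(\varepsilon/h\), where \(\varepsilon\) is machine epsilon, so the total error \[E(h) \approx \frac{|f''(x)|}{2}\,h + \frac{2\varepsilon\,|f(x)|}{h}\] is smallest near \(h^* \sim \sqrt{\varepsilon} \approx 10^{-8}\) in double precision. This is the V-shaped behaviour visible in the round-off error plots later in this section.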
Demonstration: Test Function
Error vs. Step Size
h = np.logspace(0, -5, 75)
fwd = (f(x_ref + h) - f(x_ref)) / h
bwd = (f(x_ref) - f(x_ref - h)) / h
cen = (f(x_ref + h) - f(x_ref - h)) / (2*h)
err_fwd = abs(fwd - df(x_ref))
err_bwd = abs(bwd - df(x_ref))
err_cen = abs(cen - df(x_ref))
fig, ax = plt.subplots(figsize=(5, 4))
ax.loglog(h, err_fwd, 'b--', label='Forward', alpha=0.75)
ax.loglog(h, err_bwd, 'r-.', label='Backward', alpha=0.75)
ax.loglog(h, err_cen, 'm-', label='Central')
ax.grid(which='both', alpha=0.5)
ax.set_xlabel('h (step size)')
ax.set_ylabel('Absolute Error')
ax.legend()
plt.tight_layout()
plt.show()
Convergence Orders
ddf = lambda x: (-4*(np.pi**2)*np.cos(2*np.pi*x)
- 16*(np.pi**2)*np.sin((4*np.pi*x)+0.25))
dddf = lambda x: (8*(np.pi**3)*np.sin(2*np.pi*x)
- 64*(np.pi**3)*np.cos((4*np.pi*x)+0.25))
fig, ax = plt.subplots(figsize=(5, 4))
ax.loglog(h, err_fwd, 'b--', alpha=0.4, label='Forward')
ax.loglog(h, err_bwd, 'r-.', alpha=0.4, label='Backward')
ax.loglog(h, abs(ddf(x_ref))*h/2, 'k:', label='O(h)')
ax.loglog(h, err_cen, 'm-', alpha=0.4, label='Central')
ax.loglog(h, abs(dddf(x_ref))*(h**2)/6, 'k-', label='$O(h^2)$')
ax.grid(which='both', alpha=0.5)
ax.set_xlabel('h (step size)')
ax.set_ylabel('Absolute Error')
ax.legend(fontsize=8)
plt.tight_layout()
plt.show()
Higher-Order Schemes
ddddf = lambda x: (16*(np.pi**4)*np.cos(2*np.pi*x)
+ 256*(np.pi**4)*np.sin((4*np.pi*x)+0.25))
dddddf = lambda x: (-32*(np.pi**5)*np.sin(2*np.pi*x)
+ 1024*(np.pi**5)*np.cos((4*np.pi*x)+0.25))
A = (-f(x_ref+2*h) + 4*f(x_ref+h) - 3*f(x_ref)) / (2*h)
B = (f(x_ref+2*h) + f(x_ref+h) - f(x_ref) - f(x_ref-h)) / (4*h)
C = (2*f(x_ref+h) + 3*f(x_ref) - 6*f(x_ref-h) + f(x_ref-2*h)) / (6*h)
D = (-f(x_ref+2*h) + 8*f(x_ref+h) - 8*f(x_ref-h) + f(x_ref-2*h)) / (12*h)
fig, ax = plt.subplots(figsize=(5, 4))
for val, lbl, fmt in [(A, 'A', 'b^'), (B, 'B', 'r.'), (C, 'C', 'g*'), (D, 'D', 'co')]:
    ax.loglog(h, abs(val - df(x_ref)), fmt+'-', label=f'Scheme {lbl}', alpha=0.5)
for err, lbl, fmt in [(abs(ddf(x_ref))*h/2, 'O(h)', 'k:'),
                      (abs(dddf(x_ref))*(h**2)/6, '$O(h^2)$', 'k--'),
                      (abs(ddddf(x_ref))*(h**3)/24, '$O(h^3)$', 'k-.'),
                      (abs(dddddf(x_ref))*(h**4)/120, '$O(h^4)$', 'k-')]:
    ax.loglog(h, err, fmt, label=lbl, alpha=0.6)
ax.grid(which='both', alpha=0.5)
ax.set_xlabel('h')
ax.set_ylabel('Absolute Error')
ax.legend(fontsize=7, ncol=2)
plt.tight_layout()
plt.show()
Round-off Error
h_fine = np.logspace(-10, -3, 500)
fwd_q = (f(x_ref + h_fine) - f(x_ref)) / h_fine
bwd_q = (f(x_ref) - f(x_ref - h_fine)) / h_fine
cen_q = (f(x_ref + h_fine) - f(x_ref - h_fine)) / (2*h_fine)
err_fwd_q = abs(fwd_q - df(x_ref))
err_bwd_q = abs(bwd_q - df(x_ref))
err_cen_q = abs(cen_q - df(x_ref))
fig, ax = plt.subplots(figsize=(5, 4))
ax.loglog(h_fine, err_fwd_q, 'b--', label='Forward', alpha=0.75)
ax.loglog(h_fine, err_bwd_q, 'r-.', label='Backward', alpha=0.75)
ax.loglog(h_fine, err_cen_q, 'm-', label='Central', alpha=0.75)
ax.grid(which='both', alpha=0.5)
ax.set_xlabel('h (step size)')
ax.set_ylabel('Absolute Error')
ax.legend()
plt.tight_layout()
plt.show()
Machine Epsilon
print(7./3 - 4./3 - 1)
print(np.finfo(float).eps)
print(np.finfo(np.float32).eps)
print(np.spacing(5))
2.220446049250313e-16
2.220446049250313e-16
1.1920929e-07
8.881784197001252e-16
Visualizing Quantization Error
num_rel_fwd = ((np.spacing(f(x_ref+h_fine))
+ np.spacing(f(x_ref)))
/ abs(f(x_ref+h_fine) - f(x_ref)))
num_rel_bwd = ((np.spacing(f(x_ref))
+ np.spacing(f(x_ref-h_fine)))
/ abs(f(x_ref) - f(x_ref-h_fine)))
num_rel_cen = ((np.spacing(f(x_ref+h_fine))
+ np.spacing(f(x_ref-h_fine)))
/ abs(f(x_ref+h_fine) - f(x_ref-h_fine)))
den_rel = np.spacing(h_fine) / h_fine
fig, ax = plt.subplots(figsize=(5, 4))
ax.loglog(h_fine, err_fwd_q, 'b--', alpha=0.25, label='Fwd error')
ax.loglog(h_fine, abs(fwd_q)*(num_rel_fwd+den_rel), 'b', label='Fwd quant.')
ax.loglog(h_fine, err_bwd_q, 'r-.', alpha=0.25, label='Bwd error')
ax.loglog(h_fine, abs(bwd_q)*(num_rel_bwd+den_rel), 'r', label='Bwd quant.')
ax.loglog(h_fine, err_cen_q, 'm-', alpha=0.25, label='Cen error')
ax.loglog(h_fine, abs(cen_q)*(num_rel_cen+den_rel), 'm', label='Cen quant.')
ax.grid(which='both', alpha=0.5)
ax.set_xlabel('h')
ax.set_ylabel('Absolute Error')
ax.legend(fontsize=7)
plt.tight_layout()
plt.show()
Second-Order Finite Differences
FDxx = (f(x_ref+2*h) - 2*f(x_ref+h) + f(x_ref)) / h**2
BDxx = (f(x_ref) - 2*f(x_ref-h) + f(x_ref-2*h)) / h**2
CDxx = (f(x_ref+h) - 2*f(x_ref) + f(x_ref-h)) / h**2
CDxxO4 = (-f(x_ref+2*h) + 16*f(x_ref+h)
- 30*f(x_ref) + 16*f(x_ref-h)
- f(x_ref-2*h)) / (12*h**2)
fig, ax = plt.subplots(figsize=(5, 4))
ax.loglog(h, abs(FDxx - ddf(x_ref)), 'b^-', label='FD', alpha=0.45)
ax.loglog(h, abs(BDxx - ddf(x_ref)), 'r.-', label='BD', alpha=0.45)
ax.loglog(h, abs(CDxx - ddf(x_ref)), 'g*-', label='CD', alpha=0.25)
ax.loglog(h, abs(CDxxO4 - ddf(x_ref)), 'co-', label='CD $O(h^4)$', alpha=0.25)
ax.loglog(h, abs(ddf(x_ref))*h/2, 'k^:', label='O(h)', alpha=0.4)
ax.loglog(h, abs(dddf(x_ref))*(h**2)/6, 'k.:', label='$O(h^2)$', alpha=0.5)
ax.loglog(h, abs(dddddf(x_ref))*(h**4)/120, 'ko:', label='$O(h^4)$', alpha=0.5)
ax.grid(which='both', alpha=0.5)
ax.set_xlabel('h')
ax.set_ylabel('Absolute Error')
ax.legend(fontsize=7)
plt.tight_layout()
plt.show()
Autodifferentiation (AD)
Autodifferentiation Is Criminally Underused
What is Autodifferentiation?
- It is not symbolic differentiation (like SymPy/Mathematica).
- It is not numerical differentiation (finite differences).
- AD decomposes a program into a sequence of elementary operations (addition, multiplication, exp, sin, etc.) and applies the chain rule to each.
Symbolic vs. AD
- Symbolic differentiation can lead to “expression swell”—the resulting formula for the derivative can be much larger than the original function.
- AD keeps the computational cost proportional to the original function’s evaluation.
The Chain Rule
- If \(y = g(u)\) and \(u = f(x)\), then: \[\frac{dy}{dx} = \frac{dy}{du} \cdot \frac{du}{dx}\]
- AD automates this book-keeping across complex programs.
Forward and Reverse Mode AD
Forward Mode AD
- Computes the derivative “alongside” the function evaluation.
- We track \(v\) and its derivative \(\dot{v} = \frac{\partial v}{\partial x}\) for every intermediate variable.
- Efficient when the number of inputs \(N\) is small compared to the number of outputs \(M\) (\(f: \mathbb{R}^N \to \mathbb{R}^M\)).
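A minimal dual-number sketch of forward mode (illustrative only; real AD libraries are far more general): every value carries its derivative with respect to the chosen input, and each elementary operation updates both.

import math

class Dual:
    """A value v together with its derivative dv w.r.t. the chosen input."""
    def __init__(self, v, dv):
        self.v, self.dv = v, dv
    def __add__(self, other):
        return Dual(self.v + other.v, self.dv + other.dv)
    def __mul__(self, other):
        # Product rule: (uw)' = u'w + uw'
        return Dual(self.v * other.v, self.dv * other.v + self.v * other.dv)

def sin(d):
    # Chain rule: (sin u)' = cos(u) * u'
    return Dual(math.sin(d.v), math.cos(d.v) * d.dv)

# d/dx [ x * sin(x) ] at x = 2, seeding dx/dx = 1
x = Dual(2.0, 1.0)
y = x * sin(x)
print(y.v, y.dv)   # derivative equals sin(2) + 2*cos(2)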
Reverse Mode AD
- Also known as Backpropagation.
- Requires two passes:
- Forward Pass: Compute and store all intermediate values.
- Backward Pass: Compute derivatives starting from the output and moving toward the inputs.
- Efficient when the number of outputs \(M\) is small compared to the number of inputs \(N\) (e.g., a single scalar loss function).
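A hand-worked reverse-mode sketch (illustrative, not from the original slides) for \(L = xy + \sin x\): the forward pass stores intermediates, and the backward pass accumulates adjoints from the output back to each input.

import math

# Forward pass: compute and store intermediate values
x, y = 2.0, 3.0
a = x * y            # a = x*y
b = math.sin(x)      # b = sin(x)
L = a + b            # scalar output

# Backward pass: propagate adjoints dL/d(...) from the output to the inputs
dL_da = 1.0                               # from L = a + b
dL_db = 1.0
dL_dx = dL_da * y + dL_db * math.cos(x)   # via a = x*y and b = sin(x)
dL_dy = dL_da * x
print(dL_dx, dL_dy)                       # y + cos(x) = 3 + cos(2),  x = 2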
Why Reverse Mode (Usually) Wins for ML
- In ML, we usually have:
- Millions of parameters (Inputs).
- One scalar loss value (Output).
- Reverse mode computes the entire gradient in roughly the same time it takes to evaluate the function once!
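A small JAX sketch of this asymmetry (the function and sizes are arbitrary illustrative choices): for a scalar loss of many inputs, one jax.grad call (reverse mode) returns the whole gradient, while jax.jacfwd (forward mode) internally needs one tangent pass per input.

import jax
import jax.numpy as jnp

N = 1_000                                 # many inputs, one scalar output
w = jnp.linspace(-1.0, 1.0, N)

def loss(w):
    return jnp.sum(jnp.cos(w) * w**2)

g_rev = jax.grad(loss)(w)                 # reverse mode: one backward pass
g_fwd = jax.jacfwd(loss)(w)               # forward mode: ~N forward passes
print(g_rev.shape, jnp.allclose(g_rev, g_fwd))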
AD in Practice: JAX
What is JAX?
- A Python library for high-performance numerical computing.
- Developed by Google.
- “Numpy on steroids” with Autograd and JIT (Just-In-Time) compilation.
Key JAX Transformations
- jax.grad: Computes the gradient of a function.
- jax.jit: Compiles functions for speed (using XLA).
- jax.vmap: Vectorizes functions (automatic batching).
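A minimal sketch of the three transformations working together (a toy function, not from the original slides):

import jax
import jax.numpy as jnp

def loss(w):
    return jnp.sum(jnp.tanh(w)**2)        # scalar toy "loss"

grad_loss = jax.grad(loss)                # gradient function, same shape as w
fast_grad = jax.jit(grad_loss)            # XLA-compiled version
batched_grad = jax.vmap(grad_loss)        # maps over a leading batch axis

w = jnp.array([0.1, -0.3, 0.7])
print(fast_grad(w))                       # gradient at one point
print(batched_grad(jnp.stack([w, 2*w])))  # gradients for a batch of two points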
Coding Example: Rosenbrock Function
The Rosenbrock function is a classic optimization test case: \[f(x, y) = (a-x)^2 + b(y-x^2)^2\] Usually \(a=1, b=100\). It has a global minimum at \((1, 1)\).
import jax.numpy as jnp
from jax import grad, jit

def rosenbrock(params):
    x, y = params
    return (1.0 - x)**2 + 100.0 * (y - x**2)**2

# Compute the gradient function (JIT-compiled)
grad_f = jit(grad(rosenbrock))

# Initial point
params = jnp.array([0.0, 0.0])

# Gradient descent loop
learning_rate = 0.001
for i in range(1000):
    grads = grad_f(params)
    params = params - learning_rate * grads

print(f"Final parameters: {params}")
print(f"Loss: {rosenbrock(params)}")
AD Beyond ML: Flexible Scientific Computing
Early Example: Haunschild et al.
- An early demonstration of AD applied to flexibly explore different computational models.
- By making the model differentiable, parameters can be fit without hand-deriving gradients.
Their Workflow
Current Research: Making More Functions Differentiable
- A major focus of current research is extending differentiability to operations not traditionally considered smooth (sorting, argmax, rendering, simulation, etc.).
- This allows these operations to be embedded inside learned pipelines.
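One simple illustration of such a relaxation (a standard softmax trick, not tied to any specific paper mentioned here): a hard argmax has zero gradient almost everywhere, while a temperature-controlled soft argmax is smooth and differentiable.

import jax
import jax.numpy as jnp

def soft_argmax(scores, temperature=0.1):
    """Differentiable surrogate for argmax: expected index under a softmax."""
    weights = jax.nn.softmax(scores / temperature)
    idx = jnp.arange(scores.shape[0], dtype=scores.dtype)
    return jnp.sum(weights * idx)

scores = jnp.array([0.2, 1.5, 0.7])
print(soft_argmax(scores))                # close to 1, the hard argmax
print(jax.grad(soft_argmax)(scores))      # non-zero, usable gradient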
Summary
- Gradients are essential for high-dimensional optimization.
- Finite differences are too slow and inaccurate for large-scale ML.
- Autodifferentiation (specifically Reverse Mode) allows for efficient gradient computation.
- Tools like JAX make these techniques accessible and performant.
Review Questions
- Why is the curse of dimensionality a problem for optimization without gradients?
- What are the two main types of error in numerical finite differences?
- In what situation (input/output dimensions) is forward mode AD more efficient than reverse mode?
- Why is reverse mode AD (backpropagation) the standard for training deep neural networks?
- How does the computational cost of autodifferentiation compare to the cost of evaluating the original function?