


```python
import jax
import jax.numpy as jnp
import optax  # gradient-based optimizers

# Define network forward pass: alternating affine layers and ReLU activations
def predict(params, x):
    for W, b in params[:-1]:
        x = jax.nn.relu(x @ W + b)
    W, b = params[-1]  # final layer: no activation
    return x @ W + b

# Loss function: mean squared error over the batch
def loss_fn(params, x_batch, y_batch):
    preds = predict(params, x_batch)
    return jnp.mean((preds - y_batch) ** 2)

# Gradient via reverse-mode autodiff, compiled with XLA
grad_fn = jax.jit(jax.grad(loss_fn))

# Initialize params as a list of (W, b) pairs (layer sizes here are illustrative)
def init_params(key, sizes):
    keys = jax.random.split(key, len(sizes) - 1)
    return [(jax.random.normal(k, (m, n)) * jnp.sqrt(2.0 / m), jnp.zeros(n))
            for k, m, n in zip(keys, sizes[:-1], sizes[1:])]

params = init_params(jax.random.PRNGKey(0), [64, 128, 128, 1])

# Adam optimizer (adaptive stochastic gradient descent)
optimizer = optax.adam(learning_rate=1e-3)
opt_state = optimizer.init(params)

# Mini-batch training loop; `dataloader` is assumed to yield (x_batch, y_batch) arrays
for x_batch, y_batch in dataloader:
    grads = grad_fn(params, x_batch, y_batch)
    updates, opt_state = optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
```

Unlike ridge regression or LASSO, the neural network loss surface is non-convex
Classical convex optimization guarantees do not apply
Yet in practice, large overparameterized networks train reliably. Why?
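One way to see the non-convexity is permutation symmetry: swapping two hidden units (together with the matching output weights) gives a different parameter vector with an identical input-output map, and hence identical loss. Applied at a minimizer, this produces multiple distinct minima, which a strictly convex loss cannot have. The toy one-hidden-layer sketch below illustrates this; the weight values are arbitrary placeholders.

```python
import jax.numpy as jnp

# Toy 1-hidden-layer network: permuting the hidden units changes the parameters
# but not the function, so applying the permutation to any minimizer yields
# another minimizer, ruling out strict convexity of the loss.
def f(params, x):
    (W1, b1), (W2, b2) = params
    return jnp.tanh(x @ W1 + b1) @ W2 + b2

W1 = jnp.array([[1.0, -2.0]])   # 1 input -> 2 hidden units (arbitrary values)
b1 = jnp.array([0.5, -0.5])
W2 = jnp.array([[2.0], [3.0]])  # 2 hidden units -> 1 output
b2 = jnp.array([0.1])

perm = jnp.array([1, 0])        # swap the two hidden units
params_a = [(W1, b1), (W2, b2)]
params_b = [(W1[:, perm], b1[perm]), (W2[perm, :], b2)]

x = jnp.linspace(-1.0, 1.0, 5).reshape(-1, 1)
print(jnp.allclose(f(params_a, x), f(params_b, x)))  # True: different params, same function
```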


| Method | Mechanism | Analogy |
|---|---|---|
| Weight decay | \(\ell_2\) penalty on weights | Ridge regression |
| Dropout | Random neuron silencing | Ensemble averaging |
| Early stopping | Stop before validation loss rises | Cross-validation |
| Batch norm | Normalize activations per batch | Conditioning |
| Data augmentation | Random transforms of training data | Effectively enlarges \(n\) |
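As a minimal sketch, the first two rows of the table could be wired into the optax training loop above; the decay strength and dropout rate below are placeholder values, not recommendations.

```python
import jax
import jax.numpy as jnp
import optax

# Weight decay: swap AdamW (Adam with a decoupled l2-style penalty on the weights)
# in place of plain Adam in the loop above.
optimizer = optax.adamw(learning_rate=1e-3, weight_decay=1e-4)

# Dropout: randomly silence activations at train time (inverted dropout), rescaling
# the survivors so the expected activation is unchanged; disable at evaluation time.
def dropout(key, x, rate=0.2):
    keep = jax.random.bernoulli(key, 1.0 - rate, x.shape)  # mask of surviving units
    return jnp.where(keep, x / (1.0 - rate), 0.0)
```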
A convolutional layer hardcodes local connectivity and weight sharing: the same filter \(\mathbf{W}\) is applied at every spatial position,
\[(\mathbf{W} * \mathbf{x})_{i,j} = \sum_{k,\ell} W_{k,\ell} \cdot x_{i+k,\, j+\ell}\]
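A minimal sketch of this operation as code ("valid" padding, single channel; the filter and input sizes are illustrative):

```python
import jax.numpy as jnp

def conv2d(W, x):
    # Direct transcription of the sum above (cross-correlation, "valid" padding):
    # out[i, j] = sum_{k, l} W[k, l] * x[i + k, j + l]
    kh, kw = W.shape
    out_h = x.shape[0] - kh + 1
    out_w = x.shape[1] - kw + 1
    return jnp.array([[jnp.sum(W * x[i:i + kh, j:j + kw])
                       for j in range(out_w)]
                      for i in range(out_h)])

x = jnp.arange(25.0).reshape(5, 5)        # toy 5x5 "image"
W = jnp.array([[1.0, 0.0], [0.0, -1.0]])  # 2x2 filter, shared across all positions
print(conv2d(W, x).shape)                 # (4, 4)
```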
Rather than fixed local connectivity, attention learns which inputs to weight: \[\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right) V\]
Inductive bias: relational structure is data-dependent, not hardcoded
More flexible than CNNs and GNNs—but requires more data to learn useful structure
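A minimal, single-head sketch of the attention computation above (sequence length and \(d_k\) chosen arbitrarily):

```python
import jax
import jax.numpy as jnp

def scaled_dot_product_attention(Q, K, V):
    # softmax(Q K^T / sqrt(d_k)) V: each query forms data-dependent weights
    # over all keys, then takes the weighted average of the corresponding values.
    d_k = Q.shape[-1]
    scores = Q @ K.T / jnp.sqrt(d_k)
    weights = jax.nn.softmax(scores, axis=-1)  # each row sums to 1
    return weights @ V

key = jax.random.PRNGKey(0)
kq, kk, kv = jax.random.split(key, 3)
Q = jax.random.normal(kq, (8, 16))  # 8 tokens, d_k = 16
K = jax.random.normal(kk, (8, 16))
V = jax.random.normal(kv, (8, 16))
print(scaled_dot_product_attention(Q, K, V).shape)  # (8, 16)
```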
Bioengineering applications: