
Figure 1: Model visualization of pneumothorax detection. (a) The model correctly identifies pneumothorax in the right and left upper lungs. (b) The model correctly identifies pneumothorax in the right lower lung. Figure from Taylor et al., PLOS Medicine (2018).


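```python
import jax
import jax.numpy as jnp
import optax  # gradient-based optimizers

# Define network forward pass: ReLU hidden layers, linear output layer
def predict(params, x):
    for W, b in params[:-1]:
        x = jax.nn.relu(x @ W + b)
    W, b = params[-1]
    return x @ W + b

# Loss function: mean squared error
def loss_fn(params, x_batch, y_batch):
    preds = predict(params, x_batch)
    return jnp.mean((preds - y_batch) ** 2)

# Gradient via reverse-mode autodiff, JIT-compiled
grad_fn = jax.jit(jax.grad(loss_fn))

# Hypothetical setup so the example runs: a 4 -> 16 -> 1 network
# and a synthetic dataloader of (x, y) mini-batches with y = sum(x)
key_w1, key_w2, key_x = jax.random.split(jax.random.PRNGKey(0), 3)
params = [(0.1 * jax.random.normal(key_w1, (4, 16)), jnp.zeros(16)),
          (0.1 * jax.random.normal(key_w2, (16, 1)), jnp.zeros(1))]
X = jax.random.normal(key_x, (256, 4))
Y = X.sum(axis=1, keepdims=True)
dataloader = [(X[i:i + 32], Y[i:i + 32]) for i in range(0, 256, 32)]

# Training loop with Adam, an adaptive variant of stochastic gradient descent
optimizer = optax.adam(learning_rate=1e-3)
opt_state = optimizer.init(params)
for x_batch, y_batch in dataloader:
    grads = grad_fn(params, x_batch, y_batch)
    updates, opt_state = optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
```

Unlike ridge regression or LASSO, the neural network loss surface is non-convex.
Classical convex optimization guarantees therefore do not apply.
Yet in practice, large overparameterized networks train reliably. Why?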
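The non-convexity claim can be checked numerically. The sketch below assumes a hypothetical toy network with one input, two ReLU hidden units, and no biases. Swapping the two hidden units yields a second parameter setting \(\theta_b\) that computes exactly the same function as \(\theta_a\), so both fit the data perfectly. Convexity would then require \(L\big(\tfrac{1}{2}(\theta_a + \theta_b)\big) \le \tfrac{1}{2}\big(L(\theta_a) + L(\theta_b)\big) = 0\), but the midpoint has positive loss.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy network: f(x) = relu(x * w1) @ w2, two hidden units, no biases
def f(params, x):
    w1, w2 = params
    return jax.nn.relu(x[:, None] * w1) @ w2

def loss(params, x, y):
    return jnp.mean((f(params, x) - y) ** 2)

x = jnp.array([-1.0, 1.0, 2.0])
y = jax.nn.relu(x)  # target function: relu(x)

theta_a = (jnp.array([1.0, -1.0]), jnp.array([1.0, 0.0]))  # hidden unit 1 computes relu(x)
theta_b = (jnp.array([-1.0, 1.0]), jnp.array([0.0, 1.0]))  # same network, hidden units swapped
midpoint = jax.tree_util.tree_map(lambda a, b: 0.5 * (a + b), theta_a, theta_b)

print(loss(theta_a, x, y), loss(theta_b, x, y))  # both 0.0: a perfect fit
print(loss(midpoint, x, y))                      # 5/3 > 0, so the loss is not convex
```

The same hidden-unit permutation symmetry exists in any network with two or more hidden units per layer, which is one simple reason the loss surface has many equivalent minima rather than a single convex bowl.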


| Method | Mechanism | Analogy |
|---|---|---|
| Weight decay | \(\ell_2\) penalty on weights | Ridge regression |
| Dropout | Random neuron silencing | Ensemble averaging |
| Early stopping | Stop before validation loss rises | Cross-validation |
| Batch norm | Normalize activations per batch | Standardizing covariates (better conditioning) |
| Data augmentation | Random transforms of training data | Effectively enlarges \(n\) |

Figure 2: The convolution operation. A filter (or kernel) slides across the input image, computing a weighted sum of the local receptive field at each position to produce an activation map. Image credit: Bill Kromydas.
\[(\mathbf{W} * \mathbf{x})_{i,j} = \sum_{k,\ell} W_{k,\ell} \cdot x_{i+k,\, j+\ell}\]
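A direct translation of the formula, as a sketch assuming a single-channel 2-D input, stride 1, and no padding (the function name `conv2d_valid` and the toy inputs are hypothetical). Like the formula, and like the "convolution" layers in most deep learning libraries, it does not flip the kernel, so strictly speaking it computes a cross-correlation.

```python
import jax.numpy as jnp

def conv2d_valid(W, x):
    """(W * x)[i, j] = sum_{k,l} W[k, l] * x[i + k, j + l], over positions where W fits."""
    kh, kw = W.shape
    out_h, out_w = x.shape[0] - kh + 1, x.shape[1] - kw + 1
    out = jnp.zeros((out_h, out_w), dtype=x.dtype)
    # Accumulate one kernel tap at a time: W[k, l] times the input shifted by (k, l)
    for k in range(kh):
        for l in range(kw):
            out = out + W[k, l] * x[k:k + out_h, l:l + out_w]
    return out

x = jnp.arange(16.0).reshape(4, 4)        # toy 4x4 "image" with x[i, j] = 4i + j
W = jnp.array([[1.0, 0.0], [0.0, -1.0]])  # 2x2 diagonal-difference filter
print(conv2d_valid(W, x))                 # 3x3 activation map, every entry -5.0
```

The same four weights are reused at every position \((i, j)\); this weight sharing is what keeps the parameter count of a convolutional layer independent of the image size.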
Figure 3: Examples of graph-structured data. Biology is full of relational structure, from molecular bonds and gene regulatory networks to the connectome. Image credit: Rick Merritt.
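Relational structure like this is typically encoded as an adjacency matrix \(A\). A minimal sketch of one graph neural network layer, assuming mean aggregation over each node's neighbors plus itself, \(H' = \mathrm{ReLU}\big(\tilde{D}^{-1}(A + I)\,H\,W\big)\) with \(\tilde{D}\) the diagonal matrix of row sums of \(A + I\). This is one common variant among many; the graph, feature sizes, and the name `message_passing_layer` are hypothetical.

```python
import jax
import jax.numpy as jnp

def message_passing_layer(A, H, W):
    """One GNN layer: average features over each node's neighborhood, then transform."""
    A_hat = A + jnp.eye(A.shape[0])         # self-loops so each node keeps its own features
    deg = A_hat.sum(axis=1, keepdims=True)  # neighborhood size of each node (including itself)
    H_agg = (A_hat @ H) / deg               # mean of neighbor and self features
    return jax.nn.relu(H_agg @ W)           # shared linear map + nonlinearity

# Hypothetical 4-node path graph 0-1-2-3 (e.g., a small chain molecule)
A = jnp.array([[0, 1, 0, 0],
               [1, 0, 1, 0],
               [0, 1, 0, 1],
               [0, 0, 1, 0]], dtype=jnp.float32)
H = jax.random.normal(jax.random.PRNGKey(0), (4, 3))  # 3 input features per node
W = jax.random.normal(jax.random.PRNGKey(1), (3, 5))  # maps to 5 hidden features
H_next = message_passing_layer(A, H, W)               # shape (4, 5)
```

Because the same `W` is applied at every node and aggregation is a mean, relabeling the nodes only permutes the rows of the output (permutation equivariance), so the layer does not depend on an arbitrary node ordering.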
Figure 4: The attention mechanism. Input tokens are transformed into Queries (\(Q\)), Keys (\(K\)), and Values (\(V\)). The attention weights are computed by comparing Queries with Keys, then used to take a weighted average of Values. Image credit: Ebrahim Pichka.
Rather than fixed local connectivity, attention learns which inputs to weight: \[\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right) V\]
Inductive bias: relational structure is data-dependent, not hardcoded.
More flexible than CNNs and GNNs, but requires more data to learn useful structure.
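A single-head sketch of the attention formula, assuming no masking and that \(Q\), \(K\), \(V\) come from hypothetical learned projections `W_q`, `W_k`, `W_v` of token embeddings `X` (all drawn at random here for illustration).

```python
import jax
import jax.numpy as jnp

def attention(Q, K, V):
    """softmax(Q K^T / sqrt(d_k)) V for one head, without masking."""
    d_k = Q.shape[-1]
    scores = Q @ K.T / jnp.sqrt(d_k)           # (n_queries, n_keys) query-key similarities
    weights = jax.nn.softmax(scores, axis=-1)  # each row is a distribution over keys
    return weights @ V, weights

# Hypothetical input: 5 tokens with 16-dim embeddings, projected to d_k = d_v = 8
kx, kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 4)
X = jax.random.normal(kx, (5, 16))
W_q, W_k, W_v = (jax.random.normal(k, (16, 8)) / 4.0 for k in (kq, kk, kv))
out, weights = attention(X @ W_q, X @ W_k, X @ W_v)  # out: (5, 8), weights: (5, 5)
```

Unlike the convolution kernel, `weights` is computed from the input itself, so which tokens influence which is decided by the data rather than fixed by a local footprint.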
Bioengineering applications:
Figure 5: AlphaFold performance and architecture. (a) Competitive performance on CASP14 targets. (b–d) Examples of high-accuracy predictions (blue) compared to experimental structures (green), including complex domain packing and side-chain placement. (e) The model architecture, integrating MSA and template features through iterative recycling. Figure from Jumper et al., Nature (2021).
Figure 6: Cell Painting workflow. (a) Genetic or chemical perturbations are applied to cells in multi-well plates. (b) Cells are stained and imaged via microscopy. (c) Automated image analysis extracts morphological features. (d) Resulting morphological profiles are used for downstream tasks. Figure from Bray et al., Nature Protocols (2016).