Neural Networks

Aaron Meyer

Why can’t simpler models handle this?

  • The feature space is high-dimensional: images + mechanical signals
  • The relevant structure is nonlinear: cell type isn’t linearly separable in raw pixel space
  • The “right” features are unknown a priori: handcrafted features miss what matters
  • Neural networks learn a hierarchy of features directly from the data

Neural Network Architecture

From Neurons to Networks

  • A single artificial neuron computes: \[a = \sigma\left(\sum_{i} w_i x_i + b\right) = \sigma(\mathbf{w}^\top \mathbf{x} + b)\]
    • \(\mathbf{x}\): inputs; \(\mathbf{w}\): weights; \(b\): bias; \(\sigma\): activation function
  • A layer applies this operation to every neuron in parallel
  • A network stacks multiple layers: the output of one layer is the input to the next
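The neuron and layer definitions above can be sketched in a few lines of JAX. This is a minimal illustration, assuming a sigmoid for \(\sigma\); the names `neuron` and `layer` and the numeric values are made up for the example.

```python
import jax
import jax.numpy as jnp

def neuron(w, b, x):
    """Single neuron: a = sigma(w^T x + b), with sigma chosen as the sigmoid."""
    return jax.nn.sigmoid(jnp.dot(w, x) + b)

def layer(W, b, x):
    """A layer: each row of W holds the weight vector of one neuron."""
    return jax.nn.sigmoid(W @ x + b)

x = jnp.array([1.0, -2.0, 0.5])        # three inputs
W = jnp.array([[0.2, 0.1, -0.4],       # weights of neuron 0
               [-0.3, 0.5, 0.8]])      # weights of neuron 1
b = jnp.array([0.0, 0.1])              # one bias per neuron

a_layer = layer(W, b, x)               # both neurons in one matrix product
a_single = jnp.stack([neuron(W[j], b[j], x) for j in range(2)])  # one at a time
```

Both routes give the same activations; the matrix form is simply the per-neuron computation done in parallel.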

Feedforward network architecture

A fully connected feedforward network. Each node computes a weighted sum of its inputs, then applies an activation function.

Activation Functions

Common activation functions. ReLU dominates modern hidden layers; sigmoid and softmax are used in output layers for classification.
  • ReLU (rectified linear unit), \(\max(0, z)\): sparse activations, cheap to compute, and no saturation for positive inputs, which mitigates vanishing gradients; the default choice for hidden layers
  • Sigmoid: output in \((0,1)\)—useful for binary probability outputs
  • tanh: output in \((-1,1)\)—zero-centered, often used in recurrent networks
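As a rough illustration of why saturation matters, the sketch below evaluates the three activations and their derivatives at a few points. The input values are arbitrary, and the derivatives come from `jax.grad` applied elementwise with `jax.vmap`.

```python
import jax
import jax.numpy as jnp

z = jnp.array([-4.0, 0.0, 4.0])

relu_out = jax.nn.relu(z)          # max(0, z)
sigmoid_out = jax.nn.sigmoid(z)    # values in (0, 1)
tanh_out = jnp.tanh(z)             # values in (-1, 1)

# Elementwise derivatives d(activation)/dz via autodiff.
d_relu = jax.vmap(jax.grad(jax.nn.relu))(z)
d_sigmoid = jax.vmap(jax.grad(jax.nn.sigmoid))(z)
d_tanh = jax.vmap(jax.grad(jnp.tanh))(z)
```

At \(z = 4\) the sigmoid and tanh derivatives are already close to zero, while the ReLU derivative is still 1; at \(z = 0\) the sigmoid derivative peaks at 0.25.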

Backpropagation

Backprop Is Reverse-Mode Autodiff

  • Recall from the autodifferentiation lecture:
    • Reverse mode: one forward pass + one backward pass computes the gradient w.r.t. all parameters at the cost of \(O(1)\) function evaluations (a small constant multiple of the forward pass)
    • Cost ∝ cost of the forward pass, regardless of the number of parameters
  • Backpropagation is exactly reverse-mode autodiff applied to a neural network: \[\frac{\partial L}{\partial w_{ij}^{(\ell)}} = \frac{\partial L}{\partial a_j^{(\ell)}} \cdot \frac{\partial a_j^{(\ell)}}{\partial z_j^{(\ell)}} \cdot \frac{\partial z_j^{(\ell)}}{\partial w_{ij}^{(\ell)}}\] where \(z_j^{(\ell)} = \mathbf{w}_j^{(\ell)\top}\mathbf{a}^{(\ell-1)} + b_j^{(\ell)}\) and \(a_j^{(\ell)} = \sigma(z_j^{(\ell)})\)
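To make the chain rule above concrete, here is a small check on a single sigmoid neuron. It assumes a squared-error loss \(L = \tfrac{1}{2}(a - y)^2\), which the slide does not specify, and compares the hand-multiplied factors against `jax.grad`.

```python
import jax
import jax.numpy as jnp

def loss(w, b, x, y):
    z = jnp.dot(w, x) + b          # pre-activation z
    a = jax.nn.sigmoid(z)          # activation a = sigma(z)
    return 0.5 * (a - y) ** 2      # assumed loss L

w = jnp.array([0.5, -1.0, 2.0])
b = 0.1
x = jnp.array([1.0, 2.0, -0.5])
y = 1.0

# Forward pass: compute and keep the intermediates z and a.
z = jnp.dot(w, x) + b
a = jax.nn.sigmoid(z)

# Backward pass: multiply the three local derivatives from the chain rule.
dL_da = a - y                      # dL/da
da_dz = a * (1.0 - a)              # derivative of the sigmoid at z
dz_dw = x                          # dz/dw_i = x_i
manual_grad = dL_da * da_dz * dz_dw

auto_grad = jax.grad(loss)(w, b, x, y)   # gradient w.r.t. w (argument 0)
```

The two gradients agree to floating-point precision; autodiff just automates the bookkeeping across many layers.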

Training loop in practice

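import jax
import jax.numpy as jnp
import optax  # gradient-based optimizers

# Forward pass: ReLU hidden layers followed by a linear output layer
def predict(params, x):
    for W, b in params[:-1]:
        x = jax.nn.relu(x @ W + b)
    W, b = params[-1]
    return x @ W + b

# Mean squared error loss
def loss_fn(params, x_batch, y_batch):
    preds = predict(params, x_batch)
    return jnp.mean((preds - y_batch) ** 2)

# Gradient via reverse-mode autodiff (w.r.t. the first argument, params)
grad_fn = jax.jit(jax.grad(loss_fn))

# Illustrative setup (placeholder, not real data): small random weights for a
# 4 -> 16 -> 1 network and a toy "dataloader" of random mini-batches
def init_params(key, sizes):
    keys = jax.random.split(key, len(sizes) - 1)
    return [(0.1 * jax.random.normal(k, (m, n)), jnp.zeros(n))
            for k, m, n in zip(keys, sizes[:-1], sizes[1:])]

key_params, key_data = jax.random.split(jax.random.PRNGKey(0))
params = init_params(key_params, [4, 16, 1])
x_all = jax.random.normal(key_data, (256, 4))
y_all = jnp.sum(x_all ** 2, axis=1, keepdims=True)
dataloader = [(x_all[i:i + 32], y_all[i:i + 32]) for i in range(0, 256, 32)]

# Adam optimizer (an adaptive variant of stochastic gradient descent)
optimizer = optax.adam(learning_rate=1e-3)
opt_state = optimizer.init(params)

# One pass over the mini-batches: gradient, optimizer update, parameter update
for x_batch, y_batch in dataloader:
    grads = grad_fn(params, x_batch, y_batch)
    updates, opt_state = optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)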

The non-convex optimization landscape

  • Unlike ridge regression or LASSO, the neural network loss surface is non-convex

    • Multiple local minima exist
    • The landscape has saddle points and flat regions (“plateaus”)
  • Classical convex optimization guarantees do not apply

  • Yet in practice, large overparameterized networks train reliably. Why?

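    • In very high dimensions, most critical points are thought to be saddle points rather than poor local minima, and the local minima that are found tend to have similar loss values
    • Stochastic gradient descent (SGD) acts as an implicit regularizer: its gradient noise tends to favor “flat” minima, which often generalize better than sharp ones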
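A toy example makes non-convexity and saddle points tangible. The sketch below assumes the smallest possible "deep" model, a two-layer linear network with scalar weights fitting \(y = 1\) from \(x = 1\), so the loss is \((w_1 w_2 - 1)^2\). It is an illustration only, not a model from the lecture.

```python
import jax
import jax.numpy as jnp

def loss(theta):
    # Two scalar "layers" applied in sequence to x = 1, with target y = 1.
    return (theta[0] * theta[1] - 1.0) ** 2

theta_a = jnp.array([1.0, 1.0])      # one global minimum
theta_b = jnp.array([-1.0, -1.0])    # another global minimum
midpoint = 0.5 * (theta_a + theta_b)

loss_a = float(loss(theta_a))
loss_b = float(loss(theta_b))
loss_mid = float(loss(midpoint))     # higher than both endpoints: not convex

grad_mid = jax.grad(loss)(midpoint)              # zero gradient at the origin
eigvals = jnp.linalg.eigvalsh(jax.hessian(loss)(midpoint))  # one negative, one positive
```

The origin has zero gradient but Hessian eigenvalues of both signs, so it is a saddle point sitting between two equally good minima.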

Regularization in Neural Networks

The Gap Between Theory and Practice

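  • The Universal Approximation Theorem guarantees that a sufficiently wide network can approximate any continuous function on a compact domain to arbitrary accuracy; it says nothing about whether training will find those weights from finite data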
  • But in practice, without regularization:
    • Large networks memorize noise
    • Training diverges or oscillates
    • Generalization fails badly
  • Everything we know about regularization still applies, just adapted to the NN setting

Training curves without and with regularization

Training and validation loss with and without regularization. Without regularization, the network memorizes training data but fails to generalize. Early stopping (dashed line) is a simple regularizer.
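Early stopping as described in the caption can be sketched as a patience rule over the validation curve. The function name, the `patience` parameter, and the loss values below are hypothetical; real training code would also save the parameters at the best epoch.

```python
def early_stopping_epoch(val_losses, patience=3):
    """Return the index of the best epoch, stopping once validation loss
    has failed to improve for `patience` consecutive epochs."""
    best_loss = float("inf")
    best_epoch = -1
    for epoch, val_loss in enumerate(val_losses):
        if val_loss < best_loss:
            best_loss, best_epoch = val_loss, epoch
        elif epoch - best_epoch >= patience:
            break  # validation loss has stalled: stop training here
    return best_epoch

# Made-up validation curve: improves, then rises as the network overfits.
val_curve = [1.00, 0.70, 0.55, 0.50, 0.52, 0.56, 0.61, 0.45]
best = early_stopping_epoch(val_curve, patience=3)
```

With this curve the rule stops at epoch 6 and returns epoch 3, so the late dip at epoch 7 is never seen; a larger `patience` trades extra compute for robustness to such noise.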

Weight Decay

  • Weight decay adds an \(\ell_2\) penalty to the loss—directly analogous to ridge regression: \[L_\text{reg}(\boldsymbol{\theta}) = L(\boldsymbol{\theta}) + \frac{\lambda}{2}\|\boldsymbol{\theta}\|_2^2\]
  • Penalizes large weights; shrinks them toward zero during training
  • Bayesian interpretation: Gaussian prior \(\theta_j \sim \mathcal{N}(0, 1/\lambda)\) on each weight
    • From the Bayesian lecture: MAP estimation with a Gaussian prior = \(\ell_2\) regularization
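The effect of the penalty on the gradient can be checked directly. This sketch assumes a linear model with mean squared error as the data loss \(L\); the data and \(\lambda\) are arbitrary illustrative values.

```python
import jax
import jax.numpy as jnp

def data_loss(theta, X, y):
    return jnp.mean((X @ theta - y) ** 2)

def regularized_loss(theta, X, y, lam):
    # L_reg = L + (lambda / 2) * ||theta||_2^2
    return data_loss(theta, X, y) + 0.5 * lam * jnp.sum(theta ** 2)

X = jnp.array([[1.0, 2.0], [3.0, -1.0], [0.5, 0.5]])
y = jnp.array([1.0, 0.0, 2.0])
theta = jnp.array([0.3, -0.7])
lam = 0.1

g_data = jax.grad(data_loss)(theta, X, y)
g_reg = jax.grad(regularized_loss)(theta, X, y, lam)
# The penalty adds exactly lam * theta to the gradient, so a plain gradient
# step shrinks each weight by a factor (1 - lr * lam) before the data update.
```

This equivalence between an \(\ell_2\) penalty and multiplicative shrinkage holds for plain gradient descent; with adaptive optimizers such as Adam the two differ, which is why decoupled variants such as `optax.adamw` exist.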

Dropout

  • During each forward pass, randomly zero out a fraction \(p\) of neurons (Srivastava et al., 2014)
  • At test time, keep all neurons and scale activations by the keep probability \((1-p)\) to preserve their expected magnitude
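A minimal sketch of the two modes, following the convention on this slide (drop with probability \(p\) during training, scale by \(1-p\) at test time). The function names are made up for the example. Many libraries instead use "inverted" dropout, which divides by \(1-p\) during training and leaves test-time activations untouched.

```python
import jax
import jax.numpy as jnp

def dropout_train(key, a, p):
    """Training: zero each activation independently with probability p."""
    keep = jax.random.bernoulli(key, 1.0 - p, a.shape)
    return jnp.where(keep, a, 0.0)

def dropout_test(a, p):
    """Test time: keep every unit but scale by the keep probability (1 - p)."""
    return (1.0 - p) * a

p = 0.5
a = jnp.ones(10_000)                    # dummy activations, all equal to 1
key = jax.random.PRNGKey(0)

train_out = dropout_train(key, a, p)
test_out = dropout_test(a, p)
train_mean = float(train_out.mean())    # close to 1 - p on average
test_mean = float(test_out.mean())      # equal to 1 - p for this input
```

The training-time mean matches the test-time mean only in expectation, which is the sense in which the scaling "preserves expected magnitude".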

Dropout at training time: a random subset of neurons is silenced each forward pass. This forces the network to learn redundant representations.

Batch normalization

  • After each layer, normalize activations to have zero mean and unit variance across the training mini-batch: \[\hat{x}_i = \frac{x_i - \mu_\mathcal{B}}{\sqrt{\sigma_\mathcal{B}^2 + \epsilon}}, \qquad y_i = \gamma \hat{x}_i + \delta\]
  • Learnable scale \(\gamma\) and shift \(\delta\) restore representational capacity
  • Benefits:
    • Stabilizes optimization: prevents activations from exploding or vanishing
    • Acts as mild regularization (adding noise through batch statistics)
    • Allows higher learning rates → faster training
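The normalization formula above translates almost line for line into code. This is a training-mode sketch only, with random placeholder activations; at inference, implementations typically swap the batch statistics for running averages, which is not shown here.

```python
import jax
import jax.numpy as jnp

def batch_norm(x, gamma, delta, eps=1e-5):
    """Normalize each feature over the batch axis, then scale and shift."""
    mu = x.mean(axis=0)                       # per-feature batch mean
    var = x.var(axis=0)                       # per-feature batch variance
    x_hat = (x - mu) / jnp.sqrt(var + eps)
    return gamma * x_hat + delta

key = jax.random.PRNGKey(0)
x = 5.0 + 3.0 * jax.random.normal(key, (64, 4))   # batch of 64, 4 features

y = batch_norm(x, gamma=jnp.ones(4), delta=jnp.zeros(4))
y_shifted = batch_norm(x, gamma=2.0 * jnp.ones(4), delta=jnp.ones(4))
```

With \(\gamma = 1, \delta = 0\) each feature comes out with mean 0 and variance 1; other values of \(\gamma\) and \(\delta\) let the network choose a different scale and offset.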

Summary: regularization toolkit for neural networks

Method | Mechanism | Analogy
Weight decay | \(\ell_2\) penalty on weights | Ridge regression
Dropout | Random neuron silencing | Ensemble averaging
Early stopping | Stop before validation loss rises | Cross-validation
Batch norm | Normalize activations per batch | Conditioning
Data augmentation | Random transforms of training data | Effectively enlarges \(n\)

Structure as Inductive Bias

The Key Insight

  • A fully connected network is structure-agnostic: it treats all inputs symmetrically
  • This is flexible, but expensive: it must learn from scratch that nearby pixels are correlated, that molecular properties do not change under rotation, that sequences have order
  • Inductive biases bake known structure into the architecture:
    • Less data needed to learn the relevant function
    • Better generalization on finite datasets
    • Often leads to models that are interpretable in domain terms

Convolutional Neural Networks (CNNs)

  • Designed for spatially structured data (images, 1-D signals)
  • Two key inductive biases:
    1. Local connectivity: neurons connect only to nearby inputs—nearby pixels are more related
    2. Weight sharing: the same filter is applied across all spatial positions (translation equivariance)

\[(\mathbf{W} * \mathbf{x})_{i,j} = \sum_{k,\ell} W_{k,\ell} \cdot x_{i+k,\, j+\ell}\]

  • Many fewer parameters than fully connected layers for image-sized inputs
  • Applications in biology: cell morphology, histology, microscopy image analysis
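The convolution formula and both inductive biases can be checked on a tiny example. The sketch below is a deliberately naive loop implementation over "valid" positions (no padding); the 6×6 image and 3×3 filter are arbitrary, and real code would use an optimized library routine.

```python
import jax.numpy as jnp

def conv2d_valid(W, x):
    """(W * x)[i, j] = sum_{k,l} W[k, l] * x[i + k, j + l], valid positions only."""
    kh, kw = W.shape
    out_h, out_w = x.shape[0] - kh + 1, x.shape[1] - kw + 1
    return jnp.stack([
        jnp.stack([jnp.sum(W * x[i:i + kh, j:j + kw]) for j in range(out_w)])
        for i in range(out_h)
    ])

W = jnp.arange(9.0).reshape(3, 3)              # one shared 3x3 filter
x = jnp.zeros((6, 6)).at[2, 2].set(1.0)        # image with a single bright pixel
x_shift = jnp.zeros((6, 6)).at[2, 3].set(1.0)  # same image shifted one column right

out = conv2d_valid(W, x)
out_shift = conv2d_valid(W, x_shift)

conv_params = W.size                # weights in the shared filter
fc_params = x.size * out.size       # weights in a dense layer with the same input/output sizes
```

Shifting the input by one column shifts the output by one column (translation equivariance), and the shared filter needs 9 weights where a dense layer of the same shape would need 576.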

Graph Neural Networks (GNNs)

  • Designed for data with relational structure (graphs, molecules, interaction networks)
  • Update node features by aggregating information from neighbors: \[\mathbf{h}_v^{(\ell+1)} = \sigma\left(\mathbf{W}^{(\ell)} \cdot \text{AGGREGATE}\left(\left\{\mathbf{h}_u^{(\ell)} : u \in \mathcal{N}(v)\right\}\right)\right)\]
  • Key inductive biases:
    • Permutation symmetry over the node set: node-level outputs are equivariant and graph-level readouts are invariant to node relabeling (molecules have no canonical atom ordering)
    • Local neighborhood aggregation reflects chemical bonding
  • Applications in bioengineering:
    • Drug-target interaction prediction
    • Protein–protein interaction networks
    • Metabolic pathway modeling
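One message-passing step from the update rule above can be sketched with a dense adjacency matrix. This assumes mean aggregation for AGGREGATE and ReLU for \(\sigma\); the three-node path graph, features, and weights are made up for the example.

```python
import jax
import jax.numpy as jnp

def gnn_layer(A, H, W):
    """h_v' = relu(W @ mean({h_u : u in N(v)})), computed for all nodes at once.
    A: (n, n) adjacency matrix; H: (n, d) node features; W: (d_out, d) weights."""
    deg = jnp.maximum(A.sum(axis=1, keepdims=True), 1.0)   # guard isolated nodes
    agg = (A @ H) / deg                                    # mean over neighbors
    return jax.nn.relu(agg @ W.T)

# Path graph 0 - 1 - 2 with 2 features per node.
A = jnp.array([[0.0, 1.0, 0.0],
               [1.0, 0.0, 1.0],
               [0.0, 1.0, 0.0]])
H = jnp.array([[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]])
W = jnp.array([[1.0, -1.0], [0.5, 0.5], [0.0, 1.0]])

out = gnn_layer(A, H, W)

# Relabel the nodes and rerun: the outputs are relabeled in the same way.
perm = jnp.array([2, 0, 1])
P = jnp.eye(3)[perm]                       # permutation matrix
out_perm = gnn_layer(P @ A @ P.T, P @ H, W)
```

Relabeling the nodes permutes the rows of the output and leaves any sum over nodes unchanged, which is the symmetry the architecture builds in.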

Attention and Transformers

  • Rather than fixed local connectivity, attention learns which inputs to weight: \[\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right) V\]

  • Inductive bias: relational structure is data-dependent, not hardcoded

  • More flexible than CNNs and GNNs—but requires more data to learn useful structure

  • Bioengineering applications:

    • Protein language models (ESM, ProtTrans): treat protein sequence like text
    • AlphaFold2: combines attention over sequence + structural inductive biases
    • Single-cell transformers: model gene expression programs
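The attention formula above is short enough to write out directly. This is a single-head sketch without masking or learned projections; the shapes (5 queries, 7 keys, \(d_k = 8\)) and the random inputs are arbitrary.

```python
import jax
import jax.numpy as jnp

def attention(Q, K, V):
    """softmax(Q K^T / sqrt(d_k)) V; also returns the attention weights."""
    d_k = Q.shape[-1]
    scores = Q @ K.T / jnp.sqrt(d_k)             # (n_queries, n_keys)
    weights = jax.nn.softmax(scores, axis=-1)    # each row sums to 1
    return weights @ V, weights

key_q, key_k, key_v = jax.random.split(jax.random.PRNGKey(0), 3)
Q = jax.random.normal(key_q, (5, 8))    # 5 queries of dimension d_k = 8
K = jax.random.normal(key_k, (7, 8))    # 7 keys
V = jax.random.normal(key_v, (7, 4))    # 7 values of dimension 4

out, weights = attention(Q, K, V)
```

Each output row is a weighted average of the value vectors, with weights computed from the data rather than fixed by the architecture.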

Applications in Bioengineering

Drug Response Prediction (Revisited)

  • In the regularization lecture, we used elastic net to predict drug sensitivity from genomic features in the CCLE
  • Neural networks can capture nonlinear gene–drug interactions
  • With enough data, a deep network often outperforms LASSO/ridge
  • But: neural networks require more data and careful regularization
  • With only 500 cell lines (far fewer samples than genomic features), regularized linear models often win

Protein structure prediction: AlphaFold2

  • Why it works: massive structural inductive bias
    • Equivariance to rotation and translation of the protein backbone
    • Pair representation explicitly encodes residue–residue distances
    • Iterative “recycling” refines structure over multiple passes
  • Training data: ~170,000 experimentally solved structures (PDB)
  • Result: near-experimental accuracy on most proteins
  • Widely regarded as having largely solved the 50-year-old problem of predicting protein structure from sequence

Cell morphology as a phenotypic readout

  • Cell Painting (Bray et al., 2016): image cells stained with 6 fluorescent dyes; extract ~1,500 morphological features
  • Neural networks trained on raw images learn features that correlate with:
    • Mechanism of action of compounds
    • Disease phenotype
    • Genetic perturbation identity
  • Architecture: CNN backbone (treats image as spatially structured) + classification head

Review

Further Reading

Review Questions

  1. What is the role of the activation function in a neural network? What would happen if all activations were linear?
  2. Explain how backpropagation is a special case of reverse-mode autodifferentiation. What is computed in the forward pass, and what is computed in the backward pass?
  3. Why is the neural network loss surface non-convex? What practical consequences does this have for training?
  4. A colleague trains a network and observes that training loss continues to decrease while validation loss starts to rise. What is happening, and what would you do?
  5. What is an inductive bias? Give one example from each of: CNNs, GNNs, and attention mechanisms, and explain what assumption about the data each encodes.
  6. A fully connected network and a CNN are each trained on the same image classification task. The CNN uses far fewer parameters. Under what conditions might the simpler CNN outperform the larger fully connected network?
  7. AlphaFold2 has ~93M parameters trained on ~170K examples. How do you reconcile this with the double descent picture—shouldn’t such an overparameterized model need far more data?