Overview
This chapter turns the theory of linear regression into an end-to-end PyTorch implementation on a real dataset. You will implement a custom model, an MSE loss, gradient descent, feature scaling, and L2 regularisation, and interpret the learned weights.
You Will Learn
- Implementing LinearRegression as an nn.Module with learnable parameters
- Computing MSE loss manually and via PyTorch
- Normalising features correctly using only training statistics
- Monitoring training and validation loss curves for convergence
- Inspecting and interpreting learned weights on a real dataset
- Extending to polynomial regression with and without L2 regularisation
Main Content
Custom LinearRegression Module
Instead of relying on high-level scikit-learn wrappers, you build a LinearRegression module derived from nn.Module. Parameters w and b are registered as nn.Parameter tensors so that autograd tracks them. The forward method performs a single matrix multiplication plus bias addition. This design pattern mirrors how you will construct more complex neural architectures later.
Implementing and Using MSE Loss
Although PyTorch provides nn.MSELoss, implementing loss = ((y_pred - y_true) ** 2).mean() makes the computation completely transparent. You confirm that autograd correctly computes gradients by checking that small perturbations to w and b change the loss as expected. This understanding is crucial for debugging when gradients later vanish or explode in deep networks.
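A small sketch of this gradient check, comparing autograd's gradient of the manual MSE against a central finite difference (the tensor sizes, seed, and tolerance are illustrative choices, not from the chapter):

```python
import torch

torch.manual_seed(0)
# Double precision keeps the finite-difference comparison tight.
X = torch.randn(8, 3, dtype=torch.float64)
y_true = torch.randn(8, 1, dtype=torch.float64)
w = torch.randn(3, 1, dtype=torch.float64, requires_grad=True)
b = torch.zeros(1, dtype=torch.float64, requires_grad=True)

def mse(w, b):
    y_pred = X @ w + b
    return ((y_pred - y_true) ** 2).mean()

loss = mse(w, b)
loss.backward()  # autograd gradients now sit in w.grad and b.grad

# The manual MSE agrees with PyTorch's built-in loss.
assert torch.allclose(loss, torch.nn.functional.mse_loss(X @ w + b, y_true))

# Central finite difference for the first weight component.
eps = 1e-4
w_plus = w.detach().clone(); w_plus[0, 0] += eps
w_minus = w.detach().clone(); w_minus[0, 0] -= eps
fd_grad = (mse(w_plus, b.detach()) - mse(w_minus, b.detach())) / (2 * eps)
```

Because the MSE is quadratic in w, the central difference should match the autograd gradient up to floating-point error.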
Training on the Diabetes Dataset
Using sklearn’s Diabetes dataset, you construct feature and target tensors and split them into training and validation sets. Features are standardised using the training set mean and standard deviation only. Training proceeds over multiple epochs: in each epoch you iterate over mini-batches (or the full batch), compute predictions and loss, call loss.backward(), and step the optimiser. Plotting training and validation loss against epochs reveals whether the model is converging and whether overfitting is occurring.
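A sketch of this pipeline; the split point, learning rate, and epoch count are illustrative choices, and torch.nn.Linear stands in for the custom module so the snippet is self-contained:

```python
import torch
from sklearn.datasets import load_diabetes

data = load_diabetes()
X = torch.tensor(data.data, dtype=torch.float32)
y = torch.tensor(data.target, dtype=torch.float32).unsqueeze(1)

# Simple hold-out split: first 350 rows train, the rest validate.
X_train, X_val = X[:350], X[350:]
y_train, y_val = y[:350], y[350:]

# Standardise with *training* statistics only, then reuse them for validation.
mu = X_train.mean(dim=0)
sigma = X_train.std(dim=0)
X_train = (X_train - mu) / sigma
X_val = (X_val - mu) / sigma

model = torch.nn.Linear(X_train.shape[1], 1)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)

train_losses, val_losses = [], []
for epoch in range(200):
    model.train()
    optimizer.zero_grad()
    loss = ((model(X_train) - y_train) ** 2).mean()  # full-batch MSE
    loss.backward()
    optimizer.step()

    model.eval()
    with torch.no_grad():
        val_loss = ((model(X_val) - y_val) ** 2).mean()
    train_losses.append(loss.item())
    val_losses.append(val_loss.item())
```

Plotting train_losses and val_losses against epoch gives the convergence curves discussed above.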
Interpreting Weights and Regularisation Effects
Once training has converged, you inspect w. Features with large positive weights are strong positive predictors; large negative weights are strong negative predictors; near-zero weights are effectively irrelevant. Training models with different λ values for L2 regularisation shows how increasing λ shrinks coefficients and can stabilise estimates in the presence of multicollinearity. Comparing performance across λ values on the validation set reinforces the bias–variance trade-off empirically.
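A minimal sketch of this λ comparison, using synthetic near-collinear features so the shrinkage effect is visible (the data, hyperparameters, and the fit helper are assumptions for illustration, not from the chapter; the bias is omitted since the synthetic target is centred):

```python
import torch

torch.manual_seed(0)

# Two nearly collinear features: x2 is almost a copy of x1.
n = 200
x1 = torch.randn(n, 1)
x2 = x1 + 0.01 * torch.randn(n, 1)
X = torch.cat([x1, x2], dim=1)
y = 3.0 * x1 + 0.5 * torch.randn(n, 1)

def fit(lam: float, steps: int = 2000, lr: float = 0.05) -> torch.Tensor:
    """Gradient descent on MSE plus an explicit L2 penalty lam * ||w||^2."""
    w = torch.zeros(2, 1, requires_grad=True)
    opt = torch.optim.SGD([w], lr=lr)
    for _ in range(steps):
        opt.zero_grad()
        loss = ((X @ w - y) ** 2).mean() + lam * (w ** 2).sum()
        loss.backward()
        opt.step()
    return w.detach()

for lam in [0.0, 0.1, 1.0]:
    w = fit(lam)
    print(f"lambda={lam}: w = {w.flatten().tolist()}")
```

Increasing λ shrinks the coefficient norm; with SGD the same penalty can equivalently be applied via the optimiser's weight_decay argument.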
Examples
LinearRegression Module in PyTorch
Minimal implementation of a linear regression model with learnable parameters.
import torch
import torch.nn as nn
class LinearRegression(nn.Module):
    def __init__(self, in_features: int):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(in_features, 1))
        self.bias = nn.Parameter(torch.zeros(1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x @ self.weight + self.bias
Train Loop with MSE
One full training loop with manual MSE.
model = LinearRegression(in_features=X_train.shape[1]).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
for epoch in range(200):
    model.train()
    optimizer.zero_grad()
    preds = model(X_train)
    loss = ((preds - y_train) ** 2).mean()
    loss.backward()
    optimizer.step()
Common Mistakes
Normalising using statistics computed on the full dataset
Why: Leaks information from validation/test into training and inflates reported performance.
Fix: Compute mean and std on the training partition only; apply the same transform to validation and test.
Forgetting to call optimizer.zero_grad()
Why: Gradients accumulate across iterations, leading to incorrect updates and potential divergence.
Fix: Zero gradients at the start of each optimisation step or right after optimizer.step().
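A minimal demonstration of this accumulation behaviour on a single scalar parameter (a toy setup, not the chapter's training loop):

```python
import torch

w = torch.tensor([1.0], requires_grad=True)

loss = (2 * w).sum()
loss.backward()
g1 = w.grad.clone()  # d(2w)/dw = 2

loss = (2 * w).sum()
loss.backward()      # no zero_grad(): the new gradient is ADDED
g2 = w.grad.clone()

print(g1.item(), g2.item())  # 2.0 4.0
```

After the second backward() the stored gradient is double the true gradient, which is exactly the bug the zero_grad() call prevents.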
Mini Exercises
1. Modify the training script to track both training and validation MSE each epoch and plot them on the same graph. What patterns indicate overfitting?
2. Extend the feature matrix with polynomial terms (e.g., x, x², x³) for a single selected feature and compare MSE with and without L2 regularisation.