Cutting Edge Model Footprint with Pruning and Distillation
TL;DR
- Shrinking edge model footprint is two complementary moves, not one.
- Structured pruning removes whole channels so you get real speedups, and knowledge distillation trains a small model to mimic a large one.
- I walk through importance scoring, channel removal, recovery fine-tuning, the distillation loss, and the export step.
There’s a persistent myth that you pick one compression technique. You quantize, or you prune, or you distill. In practice the teams shipping the smallest, fastest edge models stack them, and the order matters. I learned this the slow way after spending two weeks pruning a model 60% and watching it run exactly as fast as before, because I’d used unstructured pruning and the sparse weights still occupied dense tensors.
Edge model footprint has three axes: file size on disk, resident memory at runtime, and latency per inference. Quantization mostly attacks the first two. Structured pruning and distillation attack all three — and they attack the architecture, which is why they compose well with the bit-width tricks. If 4-bit quantization is on your roadmap too, my walkthrough on quantizing SLMs to 4-bit GGUF is the natural follow-on once the architecture is lean.
This post is the architecture-level work. We take a model, score channel importance, physically remove the least useful channels, recover accuracy with fine-tuning, then go further with distillation — training a genuinely smaller student against the pruned model as teacher. Every step is runnable PyTorch 2.6.
Structured vs. unstructured pruning
This is the distinction that wasted my two weeks, so it goes first.
Unstructured pruning zeros individual weights. A weight matrix becomes sparse, with zeros scattered through it. The problem: a dense matmul still multiplies by those zeros. The file only shrinks if you store it in a sparse or compressed format, and inference only speeds up if the runtime has kernels that exploit sparsity. Most edge runtimes don’t. On a Cortex-A76 running a dense runtime you’ll see no latency change at all.
Structured pruning removes whole units — entire output channels of a convolution, entire rows of a linear layer. The resulting tensors are smaller and still dense. A standard matmul on a standard runtime is genuinely faster because there’s less arithmetic. This is what you want for edge.
The cost is that structured pruning is coarser. Removing a whole channel takes useful weights along with dead ones, so accuracy drops further per parameter removed. You buy that back with recovery fine-tuning.
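To make the "smaller and still dense" point concrete, here is a rough parameter count for two stacked convolutions before and after halving the first one's output channels. The layer sizes (64, 128, 256 channels, 3x3 kernels) are made up for illustration, and the sketch ignores BatchNorm parameters and assumes ungrouped convolutions.

```python
def conv_params(in_ch: int, out_ch: int, k: int, bias: bool = True) -> int:
    """Parameter count of a k x k convolution with groups=1."""
    return out_ch * in_ch * k * k + (out_ch if bias else 0)


def pair_params(in_ch: int, mid_ch: int, out_ch: int, k: int) -> int:
    """Two stacked convs: in_ch -> mid_ch -> out_ch."""
    return conv_params(in_ch, mid_ch, k) + conv_params(mid_ch, out_ch, k)


# Hypothetical layer sizes, chosen only for the arithmetic.
before = pair_params(64, 128, 256, k=3)
after = pair_params(64, 64, 256, k=3)  # first conv keeps half its output channels
print(f"before={before} after={after} ratio={after / before:.2%}")
```

Halving one layer's outputs also halves the next layer's input weights, which is why the pair shrinks by about half rather than a quarter.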
Step 1: Importance scoring
You need a criterion for which channels to drop. The simplest defensible one is the L2 norm of each output channel’s weights — low norm means the channel contributes little to the output. A better one weights by the gradient (a Taylor-expansion estimate of the loss change from removing the channel). I’ll show both and use the Taylor score.
# importance.py (PyTorch 2.6)
import torch
import torch.nn as nn

@torch.no_grad()
def l2_channel_importance(layer: nn.Conv2d) -> torch.Tensor:
    """L2 norm per output channel. Shape: [out_channels]."""
    w = layer.weight.data  # [out, in, kh, kw]
    return w.flatten(1).norm(p=2, dim=1)

@torch.no_grad()  # keep the score itself out of the autograd graph
def taylor_channel_importance(layer: nn.Conv2d) -> torch.Tensor:
    """First-order Taylor importance: |weight * grad| summed per channel.

    Requires a backward pass to have populated .grad first."""
    if layer.weight.grad is None:
        raise RuntimeError("run a backward pass before Taylor scoring")
    contribution = (layer.weight * layer.weight.grad).abs()
    return contribution.flatten(1).sum(dim=1)  # [out_channels]
To populate gradients for the Taylor score, run a few hundred calibration batches through the model with loss.backward() and accumulate. Don’t call optimizer.step() — you only want the gradients, not a weight update.
# calibrate.py
import torch

def accumulate_taylor_grads(model, loader, loss_fn, device, n_batches=200):
    # eval mode: gradients still flow, but BatchNorm running stats and
    # dropout stay fixed, so calibration does not alter the model
    model.eval()
    model.zero_grad(set_to_none=True)
    seen = 0
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        loss = loss_fn(model(x), y)
        loss.backward()  # grads accumulate across batches
        seen += 1
        if seen >= n_batches:
            break
    # grads now hold the sum over n_batches; that's fine for ranking
    return model
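The two criteria can rank channels differently, which is the reason to prefer the gradient-weighted one. The toy below re-implements both scores on plain Python lists, one flat list of weights per channel. The numbers are invented, and it is a sketch of the idea, not a replacement for the tensor versions above.

```python
import math


def l2_scores(weights):
    """L2 norm of each channel's weights."""
    return [math.sqrt(sum(w * w for w in channel)) for channel in weights]


def taylor_scores(weights, grads):
    """Sum of |weight * grad| per channel (first-order Taylor estimate)."""
    return [
        sum(abs(w * g) for w, g in zip(ch_w, ch_g))
        for ch_w, ch_g in zip(weights, grads)
    ]


def rank(scores):
    """Channel indices ordered from most to least important."""
    return sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)


# Invented values: channel 0 has large weights but almost no gradient,
# channel 1 has small weights that the loss is very sensitive to.
weights = [[2.0, 2.0], [0.5, 0.5], [1.0, 1.0]]
grads = [[0.01, 0.01], [1.0, 1.0], [0.1, 0.1]]

print("L2 ranking:    ", rank(l2_scores(weights)))
print("Taylor ranking:", rank(taylor_scores(weights, grads)))
```

Here the L2 norm would keep channel 0 first and drop channel 1, while the Taylor score does the opposite because the loss barely reacts to channel 0.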
Step 2: Physically removing channels
Zeroing weights isn’t enough — you must rebuild the layers smaller. When you drop output channels from a conv layer, the next layer’s input channels must shrink to match. This bookkeeping is the part that gets miscoded.
# prune.py
import torch
import torch.nn as nn

def prune_conv_bn_pair(
    conv: nn.Conv2d,
    bn: nn.BatchNorm2d,
    next_conv: nn.Conv2d,
    keep_ratio: float,
    importance: torch.Tensor,
):
    """Remove output channels from `conv`, fix up `bn` and `next_conv`.

    Assumes plain convolutions (groups=1) and a default affine BatchNorm.
    Returns three freshly-shaped modules."""
    n_keep = max(1, int(conv.out_channels * keep_ratio))
    keep_idx = torch.argsort(importance, descending=True)[:n_keep]
    keep_idx, _ = torch.sort(keep_idx)  # preserve channel order
    dev, dt = conv.weight.device, conv.weight.dtype

    # --- new conv: fewer output channels ---
    new_conv = nn.Conv2d(
        conv.in_channels, n_keep,
        kernel_size=conv.kernel_size, stride=conv.stride,
        padding=conv.padding, dilation=conv.dilation,
        bias=conv.bias is not None, device=dev, dtype=dt,
    )
    new_conv.weight.data = conv.weight.data[keep_idx].clone()
    if conv.bias is not None:
        new_conv.bias.data = conv.bias.data[keep_idx].clone()

    # --- new batchnorm: matching channel count, same eps/momentum ---
    new_bn = nn.BatchNorm2d(n_keep, eps=bn.eps, momentum=bn.momentum,
                            device=dev, dtype=dt)
    new_bn.weight.data = bn.weight.data[keep_idx].clone()
    new_bn.bias.data = bn.bias.data[keep_idx].clone()
    new_bn.running_mean.data = bn.running_mean.data[keep_idx].clone()
    new_bn.running_var.data = bn.running_var.data[keep_idx].clone()

    # --- next conv: fewer INPUT channels (must mirror what we kept) ---
    new_next = nn.Conv2d(
        n_keep, next_conv.out_channels,
        kernel_size=next_conv.kernel_size, stride=next_conv.stride,
        padding=next_conv.padding, dilation=next_conv.dilation,
        bias=next_conv.bias is not None, device=dev, dtype=dt,
    )
    new_next.weight.data = next_conv.weight.data[:, keep_idx].clone()
    if next_conv.bias is not None:
        new_next.bias.data = next_conv.bias.data.clone()
    return new_conv, new_bn, new_next
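The subtlety: next_conv.weight.data[:, keep_idx] indexes the input dimension (dim 1). Indexing the output dimension by mistake usually surfaces as a shape error on the next forward pass, unless the channel counts happen to line up. The quieter bug is keeping the right number of input channels but the wrong ones, for example slicing [:, :n_keep] instead of [:, keep_idx]: every shape matches, the model runs, and the output is nonsense. Always run a forward pass on a dummy tensor right after pruning, and compare validation accuracy before and after to catch the silent case: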
# validate.py
import torch

def assert_forward_ok(model, input_shape, device):
    model.eval()
    dummy = torch.randn(*input_shape, device=device)
    with torch.no_grad():
        out = model(dummy)
    assert torch.isfinite(out).all(), "pruned model produced non-finite output"
    print(f"forward OK, output shape {tuple(out.shape)}")
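The index bookkeeping is easier to see without tensors. The sketch below mirrors the argsort-then-sort selection from prune_conv_bn_pair on nested lists, then applies the same indices to the output rows of one layer and the input columns of the next. The helper names and the tiny weight tables are hypothetical.

```python
def select_keep_indices(importance, keep_ratio):
    """Top-k channels by importance, returned in original channel order."""
    n_keep = max(1, int(len(importance) * keep_ratio))
    ranked = sorted(range(len(importance)), key=lambda i: importance[i], reverse=True)
    return sorted(ranked[:n_keep])


def slice_outputs(weight, keep):
    """weight[out][in]: keep only the selected output rows."""
    return [weight[o] for o in keep]


def slice_inputs(weight, keep):
    """weight[out][in]: keep only the selected input columns."""
    return [[row[i] for i in keep] for row in weight]


importance = [0.1, 0.9, 0.5, 0.7]
keep = select_keep_indices(importance, keep_ratio=0.5)

conv_w = [[1, 1], [2, 2], [3, 3], [4, 4]]      # 4 outputs, 2 inputs
next_w = [[10, 20, 30, 40], [50, 60, 70, 80]]  # 2 outputs, 4 inputs

pruned_conv = slice_outputs(conv_w, keep)
pruned_next = slice_inputs(next_w, keep)
misaligned_next = [row[: len(keep)] for row in next_w]  # same shape, wrong channels

print("kept channels:", keep)
print("pruned_next:", pruned_next)
print("misaligned :", misaligned_next)
```

The misaligned version has exactly the same shape as the correct one, so a shape check or a finite-output check alone would not flag it.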
Step 3: Recovery fine-tuning
A freshly pruned model has a noticeable accuracy drop — often 3 to 8 points depending on how aggressive you were. Recovery fine-tuning gets most of it back. Use a low learning rate; you’re nudging, not retraining.
# recover.py
import torch

def recover_finetune(model, train_loader, val_loader, loss_fn, device, epochs=5):
    opt = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-5)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs)
    best_acc = 0.0
    for epoch in range(epochs):
        model.train()
        for x, y in train_loader:
            x, y = x.to(device), y.to(device)
            opt.zero_grad(set_to_none=True)
            loss = loss_fn(model(x), y)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step()
        sched.step()
        acc = evaluate(model, val_loader, device)
        best_acc = max(best_acc, acc)
        print(f"epoch {epoch}: val_acc={acc:.4f}")
    return best_acc

@torch.no_grad()
def evaluate(model, loader, device):
    model.eval()
    correct = total = 0
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        pred = model(x).argmax(dim=1)
        correct += (pred == y).sum().item()
        total += y.numel()
    return correct / total
Prune in iterations, not one shot. Removing 50% of channels in a single pass is hard to recover from. Removing 15%, fine-tuning, removing another 15%, fine-tuning again, and so on lands at the same final ratio with several points more accuracy retained.
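Because each round removes a fraction of what is left, the per-round cuts compound. The helper below is a hypothetical sketch, not code from the pipeline above. It lists the channel count after each round for a given per-round drop and final keep ratio, so you can see how many prune-then-fine-tune cycles a target costs.

```python
def pruning_schedule(channels: int, drop_per_round: float, target_keep: float):
    """Channel counts after each pruning round, stopping at the target.

    Each round removes `drop_per_round` of the channels that remain, and the
    last round is clamped so it lands exactly on the target count.
    """
    if not 0.0 < drop_per_round < 1.0:
        raise ValueError("drop_per_round must be between 0 and 1")
    floor = max(1, int(channels * target_keep))
    counts = []
    current = channels
    while current > floor:
        current = max(floor, int(current * (1.0 - drop_per_round)))
        counts.append(current)  # fine-tune here before the next round
    return counts


schedule = pruning_schedule(256, drop_per_round=0.15, target_keep=0.5)
print(f"{len(schedule)} rounds: {schedule}")
```

For a 256-channel layer, 15% per round reaches a 50% keep ratio in five rounds, with the final round trimmed to hit the target.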
Step 4: Knowledge distillation for the real win
Pruning trims an existing architecture. Distillation lets you design a deliberately small student — fewer layers, narrower width — and train it to imitate the pruned model’s behavior, not just the dataset labels. The student learns from the teacher’s full output distribution, which carries far more information than a one-hot label.
The distillation loss blends a soft term (match the teacher’s softened logits) with a hard term (match the true labels).
# distill.py (PyTorch 2.6)
import torch
import torch.nn as nn
import torch.nn.functional as F

class DistillationLoss(nn.Module):
    """Hinton-style KD loss: weighted soft + hard targets."""

    def __init__(self, temperature: float = 4.0, alpha: float = 0.7):
        super().__init__()
        self.t = temperature
        self.alpha = alpha  # weight on the soft (teacher-matching) term

    def forward(self, student_logits, teacher_logits, targets):
        # Soft term: KL divergence between softened distributions.
        soft_student = F.log_softmax(student_logits / self.t, dim=1)
        soft_teacher = F.softmax(teacher_logits / self.t, dim=1)
        soft_loss = F.kl_div(soft_student, soft_teacher, reduction="batchmean")
        soft_loss = soft_loss * (self.t ** 2)  # restore gradient scale
        # Hard term: ordinary cross-entropy on real labels.
        hard_loss = F.cross_entropy(student_logits, targets)
        return self.alpha * soft_loss + (1.0 - self.alpha) * hard_loss
The temperature ** 2 factor is not optional. Softening logits by T shrinks the gradient by 1/T^2; multiplying back keeps the soft term on the same scale as the hard term so alpha actually means what you think.
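To see what the temperature does to the teacher's targets, here is a pure-Python sketch of a temperature-scaled softmax and the same blended loss for a single example. It is a teaching aid with made-up logits, not a drop-in for DistillationLoss; the function names are hypothetical.

```python
import math


def softmax_t(logits, temperature=1.0):
    """Softmax of logits / temperature, shifted by the max for stability."""
    scaled = [z / temperature for z in logits]
    peak = max(scaled)
    exps = [math.exp(s - peak) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def kd_loss(student_logits, teacher_logits, target, temperature=4.0, alpha=0.7):
    """Single example: alpha * T^2 * KL(teacher || student) + (1 - alpha) * CE."""
    p_teacher = softmax_t(teacher_logits, temperature)
    p_student = softmax_t(student_logits, temperature)
    kl = sum(
        pt * (math.log(pt) - math.log(ps))
        for pt, ps in zip(p_teacher, p_student)
        if pt > 0.0
    )
    soft = kl * temperature ** 2
    hard = -math.log(softmax_t(student_logits, 1.0)[target])
    return alpha * soft + (1.0 - alpha) * hard


teacher_logits = [8.0, 2.0, 0.0]  # made-up values
for t in (1.0, 4.0, 8.0):
    probs = softmax_t(teacher_logits, t)
    print(f"T={t}: {[round(p, 3) for p in probs]}")

print("loss:", kd_loss([3.0, 1.0, 0.5], teacher_logits, target=0))
```

At T=1 nearly all the probability mass sits on the top class. Raising T spreads it across the other classes, which is the extra signal the student learns from.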
# train_student.py
import torch

# Assumes the earlier snippets are saved as distill.py and recover.py.
from distill import DistillationLoss
from recover import evaluate

def train_student(student, teacher, train_loader, val_loader, device, epochs=30):
    teacher.eval()  # teacher is frozen; never updated
    for p in teacher.parameters():
        p.requires_grad_(False)
    kd_loss = DistillationLoss(temperature=4.0, alpha=0.7)
    opt = torch.optim.AdamW(student.parameters(), lr=3e-4, weight_decay=1e-4)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs)
    for epoch in range(epochs):
        student.train()
        for x, y in train_loader:
            x, y = x.to(device), y.to(device)
            with torch.no_grad():
                t_logits = teacher(x)
            s_logits = student(x)
            loss = kd_loss(s_logits, t_logits, y)
            opt.zero_grad(set_to_none=True)
            loss.backward()
            opt.step()
        sched.step()
        acc = evaluate(student, val_loader, device)
        print(f"epoch {epoch}: student val_acc={acc:.4f}")
    return student
A well-distilled student frequently beats the same architecture trained from scratch on labels alone by 2 to 5 points — the teacher’s soft targets encode similarity structure (“this 7 looks a bit like a 1”) that hard labels throw away.
Step 5: Export the lean model
Once pruned and distilled, export to ONNX so an edge runtime can load it.
# export.py
import torch

def export_onnx(model, sample_input, path="edge_model.onnx"):
    model.eval()
    torch.onnx.export(
        model, (sample_input,), path,  # model inputs are passed as a tuple
        input_names=["input"], output_names=["output"],
        dynamic_axes={"input": {0: "batch"}, "output": {0: "batch"}},
        # PyTorch 2.6 dynamo-based exporter; supported opsets vary by
        # release, so check the docs for the version you have installed
        opset_version=20, dynamo=True,
    )
    print(f"exported to {path}")
The dynamo=True exporter in PyTorch 2.6 traces through more dynamic Python than the legacy tracer; see the PyTorch docs for operator coverage caveats.
Common Pitfalls
- Unstructured pruning expecting a speedup. Zeroed weights in a dense tensor don’t make inference faster on a typical edge runtime. Use structured pruning for latency wins.
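- Forgetting downstream layers. Pruning a conv’s outputs without shrinking the next conv’s inputs is the single most common bug. It normally fails fast with a shape mismatch. The dangerous variant is shrinking the next conv with the wrong channel indices, which runs and produces garbage.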
- One-shot aggressive pruning. Removing half the channels in a single pass loses accuracy you can’t recover. Iterate.
- Dropping the temperature-squared term. Without T**2, your soft loss is silently scaled down and alpha is meaningless.
- Leaving the teacher in train mode. A teacher with active dropout or updating BatchNorm stats produces noisy, drifting targets. Freeze it and call .eval().
Troubleshooting
Symptom: pruned model file is smaller but inference latency is unchanged.
Cause: unstructured pruning, or the channels were zeroed rather than physically removed.
Fix: rebuild layers at the smaller channel count as in prune_conv_bn_pair. Confirm with sum(p.numel() for p in model.parameters()) before and after — it must actually drop.
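Alongside the parameter count, it helps to time the model the same way before and after pruning. The helper below is a minimal, hypothetical sketch built on the standard library only. It times any zero-argument callable, so in practice you would pass something like lambda: model(dummy) under torch.no_grad(), an assumption about your setup.

```python
import statistics
import time


def measure_latency_ms(fn, warmup: int = 10, iters: int = 50) -> float:
    """Median wall-clock latency of fn() in milliseconds."""
    for _ in range(warmup):  # let caches and lazy initialisation settle
        fn()
    samples = []
    for _ in range(iters):
        start = time.perf_counter()
        fn()
        samples.append((time.perf_counter() - start) * 1e3)
    return statistics.median(samples)


# Stand-in workload so the sketch runs on its own.
latency = measure_latency_ms(lambda: sum(range(10_000)), warmup=3, iters=20)
print(f"median latency: {latency:.3f} ms")
```

The median is used rather than the mean so a single slow outlier does not dominate. On an accelerator that runs asynchronously, fn would also need to wait for the device to finish before returning.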
Symptom: shape mismatch error on the forward pass after pruning.
Cause: the next layer’s input channels weren’t updated to match removed output channels.
Fix: prune layers in connected pairs and run assert_forward_ok immediately after each pruning step to localize the break.
Symptom: student accuracy plateaus well below the teacher.
Cause: temperature too low (soft targets nearly one-hot, no extra signal) or the student is genuinely under-capacity.
Fix: raise temperature to 4–8 and re-tune alpha. If it still won’t close, the student architecture is too small — add a layer or width.
Symptom: ONNX export fails with an unsupported-operator error.
Cause: a custom op or a dynamic control-flow pattern the exporter can’t trace.
Fix: bump opset_version, or replace the offending module with a traceable equivalent. The dynamo=True exporter handles more cases than the legacy path.
What’s Next
Pruning plus distillation gives you a model that’s smaller in file size, lighter in memory, and genuinely faster — because you cut the architecture, not just the bit width. Quantize the result for another 2 to 4x on disk and memory, and profile the final model on the actual target board before you call it done. Architecture compression and bit-width compression stack; doing both is how edge models get small enough to ship.
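As back-of-the-envelope arithmetic for how the two kinds of compression stack, the sketch below estimates weight storage from a parameter count, the fraction of parameters kept, and the bit width. The 5M-parameter model is hypothetical, and the estimate ignores quantization scales, metadata, and file-format overhead.

```python
def estimated_weight_bytes(n_params: int, param_keep: float = 1.0, bits: int = 32) -> int:
    """Approximate bytes for the weights alone, rounded up to whole bytes."""
    kept = int(n_params * param_keep)
    return (kept * bits + 7) // 8


n_params = 5_000_000  # hypothetical starting model
for keep, bits in [(1.0, 32), (0.5, 32), (0.5, 8), (0.5, 4)]:
    size_mb = estimated_weight_bytes(n_params, keep, bits) / 1e6
    print(f"keep={keep:.0%} bits={bits:>2}: {size_mb:.2f} MB")
```

The two factors multiply, so halving the parameters and dropping from 32-bit to 8-bit weights gives roughly an 8x reduction on this estimate.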