
Semantic Segmentation — Pixel-Level Classification

U-Net architecture, skip connections, and how segmentation powers medical imaging and autonomous vehicles. Label every pixel in one forward pass.

35–40 min March 2026
Section 09 · Computer Vision
Before any code — what segmentation adds over detection

Object detection draws rectangles. Semantic segmentation colours every pixel with its class — no rectangles, no approximations, pixel-perfect boundaries.

A radiologist reading a chest X-ray does not draw a box around the tumour and call it done. They need to know the exact boundary: how large it is, which tissue is affected, where it ends. A bounding box cannot answer these questions. Semantic segmentation can. It produces a mask: every pixel labelled as "tumour", "healthy tissue", or "background".

Practical examples in India: Ola and Uber's dashcam systems segment road, vehicles, pedestrians, and lane markings pixel-by-pixel for driver safety scoring. Agri-tech startups segment satellite images into crop types for yield forecasting. Quality control systems at garment factories segment defect regions in fabric images to measure defect area precisely.

The output of segmentation is a mask — a 2D array of the same height and width as the input image, where each value is a class index. For a 3-class problem (background=0, road=1, vehicle=2), the mask contains integers 0, 1, or 2 at every pixel position.
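A minimal sketch of such a mask in PyTorch, using the 3-class example above on a made-up 4×6 image. It also shows how per-pixel model scores of shape (C, H, W) become a mask through argmax over the class dimension. The scores here are random stand-ins, not output from a trained model.

```python
import torch

# Class indices from the example above: background=0, road=1, vehicle=2
# A hand-made 4x6 ground-truth mask: same H and W as the image, one class index per pixel
mask = torch.tensor([
    [0, 0, 0, 0, 0, 0],
    [0, 0, 2, 2, 0, 0],
    [1, 1, 2, 2, 1, 1],
    [1, 1, 1, 1, 1, 1],
])
print(mask.shape)  # torch.Size([4, 6])

# Pixels per class: 10 background, 10 road, 4 vehicle
counts = torch.bincount(mask.flatten(), minlength=3)
print(counts.tolist())

# A segmentation model outputs C scores per pixel: shape (C, H, W)
logits = torch.randn(3, 4, 6)      # random stand-in for real model output
pred_mask = logits.argmax(dim=0)   # (H, W): index of the highest-scoring class
print(pred_mask.shape)             # torch.Size([4, 6])
```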

🧠 Analogy — read this first

Colouring a map. Detection is like placing stickers on a map — one sticker per city, approximately where each city is. Segmentation is like colouring the map by region — every pixel of India is coloured by state, every coastline is traced exactly, every river is coloured blue. Much more precise, much more useful for geography.

The challenge: to colour pixels precisely, the model needs to understand both the broad context (what is in the image) and fine spatial detail (exactly where boundaries are). Pooling layers in CNNs lose spatial detail. U-Net's skip connections restore it — that is the key architectural insight.
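To see how pooling loses boundary detail, here is a small sketch on a toy input: a 1-pixel-wide vertical line in an 8×8 map is max-pooled 2× and then upsampled back with nearest-neighbour interpolation. It comes back 2 pixels wide, because the exact position inside each 2×2 window is gone. That lost position is the detail U-Net's skip connections carry past the bottleneck.

```python
import torch
import torch.nn.functional as F

# 8x8 "feature map" with a 1-pixel-wide vertical line at column 3
x = torch.zeros(1, 1, 8, 8)
x[..., 3] = 1.0

pooled = F.max_pool2d(x, kernel_size=2)                           # (1, 1, 4, 4)
restored = F.interpolate(pooled, scale_factor=2, mode='nearest')  # back to (1, 1, 8, 8)

print(x[0, 0, 0].tolist())         # [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0]
print(restored[0, 0, 0].tolist())  # [0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0]
```

Upsampling alone cannot tell whether the line was in column 2 or column 3; only the pre-pooling feature map knows.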

Three types of segmentation

Semantic vs instance vs panoptic — what each one produces

Three segmentation tasks on the same image
Semantic Segmentation
Output: Class label per pixel — all cars same colour, all pedestrians same colour
Asks: What class is each pixel?
Used for: Road scene understanding, medical imaging, satellite analysis
Limitation: Cannot distinguish individual instances — two cars get same label
Instance Segmentation
Output: Unique mask per object instance — car 1 and car 2 get different colours
Asks: Which object does each pixel belong to?
Used for: Counting objects, tracking individuals, robotic grasping
Limitation: Does not label background pixels — gaps between instances unlabelled
Panoptic Segmentation
Output: Every pixel labelled — stuff (background, sky, road) + things (each car, person)
Asks: What class and which instance is each pixel?
Used for: Autonomous driving, complete scene understanding
Limitation: Most complex — requires both semantic and instance heads
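The semantic-segmentation limitation is easy to see in code. The sketch below assumes a toy semantic mask where class 2 means "vehicle" and splits the vehicle pixels into 4-connected regions. This recovers separate cars only when they do not touch. Two cars that share an edge merge into one blob, which is why instance segmentation needs its own per-object output. The `connected_components` flood-fill helper is a hypothetical one written for illustration, not a library function.

```python
from collections import deque
import numpy as np

def connected_components(binary: np.ndarray) -> tuple[np.ndarray, int]:
    """Hypothetical helper: label 4-connected True regions, return (labels, count)."""
    labels = np.zeros(binary.shape, dtype=np.int32)
    count = 0
    h, w = binary.shape
    for i in range(h):
        for j in range(w):
            if binary[i, j] and labels[i, j] == 0:
                count += 1                      # start a new region
                labels[i, j] = count
                queue = deque([(i, j)])
                while queue:                    # breadth-first flood fill
                    y, x = queue.popleft()
                    for ny, nx in ((y - 1, x), (y + 1, x), (y, x - 1), (y, x + 1)):
                        if 0 <= ny < h and 0 <= nx < w and binary[ny, nx] and labels[ny, nx] == 0:
                            labels[ny, nx] = count
                            queue.append((ny, nx))
    return labels, count

VEHICLE = 2
# Both semantic masks contain two cars; every car pixel simply gets class 2
apart = np.zeros((4, 8), dtype=np.int64)
apart[1:3, 0:3] = VEHICLE      # car A
apart[1:3, 5:8] = VEHICLE      # car B, separated by background

touching = np.zeros((4, 8), dtype=np.int64)
touching[1:3, 0:4] = VEHICLE   # car A
touching[1:3, 4:8] = VEHICLE   # car B, bumper to bumper with car A

print(connected_components(apart == VEHICLE)[1])     # 2 regions
print(connected_components(touching == VEHICLE)[1])  # 1 region: the cars merge
```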
The architecture

U-Net — encoder, bottleneck, decoder, and skip connections

U-Net (Ronneberger et al., 2015) was designed for medical image segmentation with very few training images. Its key insight: the encoder (contracting path) captures what is in the image by progressively downsampling. The decoder (expanding path) restores spatial resolution. Skip connections copy feature maps directly from encoder to decoder at each scale — providing fine spatial detail that pooling destroyed. The result: precise pixel boundaries even from a small dataset.

U-Net architecture: encoder, bottleneck, decoder, skip connections
Input: (B, 3, 256, 256)
Encoder Block 1: (B, 64, 256, 256) → skip to Decoder Block 1
MaxPool → Encoder Block 2: (B, 128, 128, 128) → skip to Decoder Block 2
MaxPool → Encoder Block 3: (B, 256, 64, 64) → skip to Decoder Block 3
MaxPool → Encoder Block 4: (B, 512, 32, 32) → skip to Decoder Block 4
Bottleneck (MaxPool + DoubleConv): (B, 1024, 16, 16)
Upsample + Concat + Decoder Block 4: (B, 512, 32, 32)
Upsample + Concat + Decoder Block 3: (B, 256, 64, 64)
Upsample + Concat + Decoder Block 2: (B, 128, 128, 128)
Upsample + Concat + Decoder Block 1: (B, 64, 256, 256)
Output Conv 1×1: (B, n_classes, 256, 256)

Skip connections copy encoder feature maps and concatenate them with upsampled decoder feature maps at the same resolution. This gives the decoder both high-level semantics (from bottleneck) and fine spatial detail (from skip connections).
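One decoder step in isolation, as a sketch using the shapes from the Decoder Block 4 row of the diagram (batch size 1 assumed). The bottleneck output is upsampled 2×, concatenated with the encoder's 512-channel skip along the channel dimension, and only then convolved. The 1536-channel intermediate is why the decoder's first convolution must accept skip channels plus upsampled channels.

```python
import torch
import torch.nn as nn

bottleneck = torch.randn(1, 1024, 16, 16)  # decoder input coming from the bottleneck
skip = torch.randn(1, 512, 32, 32)         # saved output of Encoder Block 4

up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
x = up(bottleneck)                         # (1, 1024, 32, 32): resolution restored
x = torch.cat([skip, x], dim=1)            # (1, 1536, 32, 32): channels stacked
conv = nn.Conv2d(1024 + 512, 512, kernel_size=3, padding=1)
out = conv(x)                              # (1, 512, 32, 32): matches the diagram

print(tuple(x.shape), tuple(out.shape))
```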

python
import torch
import torch.nn as nn
import torch.nn.functional as F

# ── U-Net from scratch — every component explicit ─────────────────────

class DoubleConv(nn.Module):
    """Two consecutive Conv2d + BN + ReLU — the basic U-Net block."""
    def __init__(self, in_ch: int, out_ch: int):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )
    def forward(self, x): return self.net(x)

class Down(nn.Module):
    """MaxPool then DoubleConv — one encoder step."""
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.pool = nn.MaxPool2d(2)
        self.conv = DoubleConv(in_ch, out_ch)
    def forward(self, x): return self.conv(self.pool(x))

class Up(nn.Module):
    """Upsample, concatenate skip connection, then DoubleConv."""
    def __init__(self, up_ch: int, skip_ch: int, out_ch: int, bilinear: bool = True):
        super().__init__()
        if bilinear:
            # Parameter-free upsampling: channel count stays at up_ch
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        else:
            # Learned 2x upsampling: channel count also stays at up_ch
            self.up = nn.ConvTranspose2d(up_ch, up_ch, kernel_size=2, stride=2)
        # After concatenation the tensor has skip_ch + up_ch channels
        self.conv = DoubleConv(skip_ch + up_ch, out_ch)

    def forward(self, x, skip):
        x = self.up(x)
        # Pad if sizes differ (input not perfectly divisible by 16)
        dy = skip.size(2) - x.size(2)
        dx = skip.size(3) - x.size(3)
        x  = F.pad(x, [dx // 2, dx - dx // 2, dy // 2, dy - dy // 2])
        # Concatenate skip connection: this is the key U-Net operation
        x  = torch.cat([skip, x], dim=1)
        return self.conv(x)

class UNet(nn.Module):
    def __init__(self, in_channels: int = 3, n_classes: int = 2,
                 features: tuple = (64, 128, 256, 512)):
        super().__init__()
        # Encoder
        self.inc   = DoubleConv(in_channels, features[0])
        self.down1 = Down(features[0], features[1])
        self.down2 = Down(features[1], features[2])
        self.down3 = Down(features[2], features[3])
        # Bottleneck
        self.down4 = Down(features[3], features[3] * 2)
        # Decoder: each Up block convolves (upsampled channels + skip channels)
        self.up1   = Up(features[3] * 2, features[3], features[3])  # 1024 + 512 → 512
        self.up2   = Up(features[3],     features[2], features[2])  # 512 + 256 → 256
        self.up3   = Up(features[2],     features[1], features[1])  # 256 + 128 → 128
        self.up4   = Up(features[1],     features[0], features[0])  # 128 + 64 → 64
        # Output
        self.outc  = nn.Conv2d(features[0], n_classes, kernel_size=1)

    def forward(self, x):
        # Encoder — save skip connections
        x1 = self.inc(x)     # (B, 64,  H,   W)
        x2 = self.down1(x1)  # (B, 128, H/2, W/2)
        x3 = self.down2(x2)  # (B, 256, H/4, W/4)
        x4 = self.down3(x3)  # (B, 512, H/8, W/8)
        x5 = self.down4(x4)  # (B,1024, H/16,W/16)  ← bottleneck

        # Decoder — receive skip connections
        x  = self.up1(x5, x4)  # concat → (B, 512, H/8, W/8)
        x  = self.up2(x,  x3)  # concat → (B, 256, H/4, W/4)
        x  = self.up3(x,  x2)  # concat → (B, 128, H/2, W/2)
        x  = self.up4(x,  x1)  # concat → (B, 64,  H,   W)

        return self.outc(x)     # (B, n_classes, H, W)

# ── Shape check ───────────────────────────────────────────────────────
model = UNet(in_channels=3, n_classes=4)  # 4 classes: bg, road, vehicle, pedestrian
x     = torch.randn(2, 3, 256, 256)
out   = model(x)

total = sum(p.numel() for p in model.parameters())
print("U-Net architecture:")
print(f"  Input:   {tuple(x.shape)}")
print(f"  Output:  {tuple(out.shape)}  ← same H×W as input, n_classes channels")
print(f"  Params:  {total:,}")
print("\nOutput interpretation:")
print("  out[b, c, h, w] = logit for class c at pixel (h, w) in batch item b")
print("  argmax over dim=1 → predicted class per pixel")
Training loop

Loss functions, masks, and the complete training pipeline

Segmentation training is similar to classification but operates at the pixel level. The target is not a single integer per image — it is a 2D mask of shape (H, W) where each value is a class index. The loss is cross-entropy computed over all pixels simultaneously. Class imbalance is severe in segmentation — background pixels vastly outnumber foreground pixels in most tasks. Weighted loss or Dice loss addresses this.
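What "cross-entropy over all pixels" means concretely: nn.CrossEntropyLoss accepts logits of shape (B, C, H, W) with targets of shape (B, H, W). With the default reduction='mean', the result equals ordinary classification cross-entropy after flattening every pixel into its own sample. A small sketch with random tensors (the sizes are arbitrary choices):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, C, H, W = 2, 4, 8, 8
logits = torch.randn(B, C, H, W)          # per-pixel class scores
targets = torch.randint(0, C, (B, H, W))  # per-pixel class indices

criterion = nn.CrossEntropyLoss()
seg_loss = criterion(logits, targets)     # segmentation form

# Same loss in classification form: one row per pixel
flat_logits = logits.permute(0, 2, 3, 1).reshape(-1, C)  # (B*H*W, C)
flat_targets = targets.reshape(-1)                        # (B*H*W,)
flat_loss = criterion(flat_logits, flat_targets)

print(torch.allclose(seg_loss, flat_loss))  # True
```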

Dice loss — the segmentation-specific loss function
Dice = 2 × |A ∩ B| / (|A| + |B|)
Dice Loss = 1 − Dice
A = predicted mask (probabilities), B = ground-truth mask (binary)
Dice = 1.0 → perfect overlap; Dice = 0.0 → no overlap
Advantage: handles class imbalance naturally. When Dice is computed per class and averaged, each class is normalised by its own size, so background pixels do not dominate
Production: Combined Loss = CrossEntropy + DiceLoss (best of both)
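A worked example of the formula on two tiny binary masks, as a sketch. The prediction covers 3 pixels, the ground truth 2, and they share 2, so Dice = 2 × 2 / (3 + 2) = 0.8 and Dice loss = 0.2. The `smooth` term used in the training code below is left out here so the arithmetic matches the formula exactly.

```python
import torch

pred = torch.tensor([[1., 1., 0.],
                     [1., 0., 0.]])   # A: 3 predicted foreground pixels
true = torch.tensor([[1., 1., 0.],
                     [0., 0., 0.]])   # B: 2 ground-truth foreground pixels

intersection = (pred * true).sum()    # |A ∩ B| = 2
dice = 2 * intersection / (pred.sum() + true.sum())
print(dice.item(), 1 - dice.item())   # Dice ≈ 0.8, Dice loss ≈ 0.2

# "Soft" Dice: with probabilities instead of hard 0/1 values the same formula applies;
# |A ∩ B| becomes sum(p * g) and |A| becomes sum(p)
soft_pred = torch.tensor([[0.9, 0.8, 0.1],
                          [0.6, 0.0, 0.0]])
soft_dice = 2 * (soft_pred * true).sum() / (soft_pred.sum() + true.sum())
print(soft_dice.item())
```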
python
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torch.utils.data import Dataset, DataLoader

# ── Dice loss implementation ──────────────────────────────────────────
class DiceLoss(nn.Module):
    def __init__(self, smooth: float = 1.0):
        super().__init__()
        self.smooth = smooth

    def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """
        logits:  (B, C, H, W) — raw model output (before softmax)
        targets: (B, H, W)    — integer class indices
        """
        n_classes = logits.size(1)
        probs     = torch.softmax(logits, dim=1)

        # One-hot encode targets: (B, H, W) → (B, C, H, W)
        one_hot = torch.zeros_like(probs)
        one_hot.scatter_(1, targets.unsqueeze(1), 1)

        # Compute Dice per class, then average
        dice_per_class = []
        for c in range(n_classes):
            pred_c = probs[:, c].reshape(-1)
            true_c = one_hot[:, c].reshape(-1)
            intersection = (pred_c * true_c).sum()
            dice = (2 * intersection + self.smooth) / (
                pred_c.sum() + true_c.sum() + self.smooth
            )
            dice_per_class.append(dice)

        return 1 - torch.stack(dice_per_class).mean()

class CombinedLoss(nn.Module):
    """CrossEntropy + Dice — standard for segmentation."""
    def __init__(self, ce_weight=0.5, dice_weight=0.5,
                  class_weights=None):
        super().__init__()
        self.ce      = nn.CrossEntropyLoss(weight=class_weights)
        self.dice    = DiceLoss()
        self.ce_w    = ce_weight
        self.dice_w  = dice_weight

    def forward(self, logits, targets):
        return (self.ce_w * self.ce(logits, targets)
                + self.dice_w * self.dice(logits, targets))

# ── Synthetic segmentation dataset ───────────────────────────────────
class SyntheticSegDataset(Dataset):
    """Simulates a road scene segmentation dataset."""
    CLASSES = {0: 'background', 1: 'road', 2: 'vehicle', 3: 'pedestrian'}

    def __init__(self, n: int = 200, img_size: int = 128):
        self.n, self.sz = n, img_size

    def __len__(self): return self.n

    def __getitem__(self, idx):
        np.random.seed(idx)
        img  = np.random.randint(30, 220, (3, self.sz, self.sz)).astype(np.float32) / 255
        mask = np.zeros((self.sz, self.sz), dtype=np.int64)

        # Road: bottom half
        mask[self.sz//2:, :] = 1
        # Vehicles: random rectangles in road region
        for _ in range(np.random.randint(1, 4)):
            x, y = np.random.randint(0, self.sz-30), np.random.randint(self.sz//2, self.sz-20)
            mask[y:y+20, x:x+30] = 2
        # Pedestrians: small rectangles
        for _ in range(np.random.randint(0, 3)):
            x, y = np.random.randint(0, self.sz-10), np.random.randint(self.sz//3, self.sz-30)
            mask[y:y+30, x:x+10] = 3

        return torch.FloatTensor(img), torch.LongTensor(mask)

# ── Full training loop ────────────────────────────────────────────────
from torch.utils.data import random_split

dataset  = SyntheticSegDataset(n=200, img_size=128)
train_ds, val_ds = random_split(dataset, [160, 40])
train_ld = DataLoader(train_ds, batch_size=8, shuffle=True)
val_ld   = DataLoader(val_ds,   batch_size=8)

# Class weights: in this synthetic data background covers about half the pixels and
# road most of the rest; vehicles and pedestrians are rare
class_weights = torch.tensor([0.5, 1.5, 3.0, 4.0])  # upweight rarer classes
model     = UNet(in_channels=3, n_classes=4, features=[32, 64, 128, 256])
criterion = CombinedLoss(class_weights=class_weights)
optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.01)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=20)

def pixel_accuracy(logits, targets):
    preds = logits.argmax(dim=1)
    return (preds == targets).float().mean().item()

def mean_iou(logits, targets, n_classes=4):
    preds  = logits.argmax(dim=1)
    ious   = []
    for c in range(n_classes):
        pred_c = (preds == c)
        true_c = (targets == c)
        inter  = (pred_c & true_c).sum().float()
        union  = (pred_c | true_c).sum().float()
        if union > 0:
            ious.append((inter / union).item())
    return np.mean(ious) if ious else 0.0

print("Training U-Net on road scene segmentation:")
print(f"{'Epoch':>6} {'Train loss':>12} {'Val acc':>10} {'Val mIoU':>10}")
print("─" * 42)

for epoch in range(1, 21):
    model.train()
    total_loss = 0
    for imgs, masks in train_ld:
        optimizer.zero_grad()
        logits = model(imgs)
        loss   = criterion(logits, masks)
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        total_loss += loss.item()
    scheduler.step()

    if epoch % 4 == 0:
        model.eval()
        accs, ious = [], []
        with torch.no_grad():
            for imgs, masks in val_ld:
                logits = model(imgs)
                accs.append(pixel_accuracy(logits, masks))
                ious.append(mean_iou(logits, masks))
        print(f"{epoch:>6} {total_loss/len(train_ld):>12.4f} "
              f"{np.mean(accs):>10.4f} {np.mean(ious):>10.4f}")
Measuring segmentation quality

Pixel accuracy, IoU, and mIoU — the segmentation metric family

Pixel accuracy — fraction of correctly classified pixels — is misleading when classes are imbalanced. A model that predicts "background" for every pixel gets 90% pixel accuracy on a dataset where 90% of pixels are background. The correct metrics are per-class IoU and mean IoU (mIoU).

python
import torch
import numpy as np

def segmentation_metrics(pred_mask: torch.Tensor,
                          true_mask: torch.Tensor,
                          n_classes: int,
                          ignore_index: int = 255) -> dict:
    """
    Compute comprehensive segmentation metrics.
    pred_mask: (H, W) — predicted class indices
    true_mask: (H, W) — ground truth class indices
    """
    # Ignore unlabelled pixels (index 255 used in many datasets)
    valid = true_mask != ignore_index
    pred  = pred_mask[valid]
    true  = true_mask[valid]

    # Pixel accuracy
    pixel_acc = (pred == true).float().mean().item()

    # Per-class IoU
    class_ious = {}
    for c in range(n_classes):
        pred_c = (pred == c)
        true_c = (true == c)
        inter  = (pred_c & true_c).sum().float()
        union  = (pred_c | true_c).sum().float()
        if union > 0:
            class_ious[c] = (inter / union).item()

    # Frequency-weighted IoU — weights by class pixel frequency
    freq_iou = 0.0
    total_pixels = valid.sum().float()
    for c, iou in class_ious.items():
        freq = (true == c).sum().float() / total_pixels
        freq_iou += freq * iou

    return {
        'pixel_accuracy': pixel_acc,
        'mean_iou':       np.mean(list(class_ious.values())),
        'freq_iou':       float(freq_iou),  # float() also works when no class is present
        'class_iou':      class_ious,
    }

# ── Demonstrate metric sensitivity to class imbalance ─────────────────
H, W = 256, 256
n_classes = 4

# Scenario: ~80% background, ~15% road, ~4% vehicle, ~1% pedestrian
true_mask = torch.zeros(H, W, dtype=torch.long)
true_mask[H-48:, :]         = 1  # road (bottom 48 rows)
true_mask[H-48:H-16, 0:80]  = 2  # vehicle inside the road region
true_mask[100:164, 120:130] = 3  # pedestrian (64 × 10 pixels)

# Model that predicts everything as background
all_background = torch.zeros_like(true_mask)

# Model that predicts perfectly
perfect = true_mask.clone()

for name, pred in [('All background', all_background), ('Perfect', perfect)]:
    m = segmentation_metrics(pred, true_mask, n_classes)
    print(f"{name}:")
    print(f"  Pixel accuracy: {m['pixel_accuracy']:.4f}  ← misleading for all-bg!")
    print(f"  Mean IoU:       {m['mean_iou']:.4f}  ← correctly shows all-bg is bad")
    for c, iou in m['class_iou'].items():
        classes = ['background', 'road', 'vehicle', 'pedestrian']
        print(f"    {classes[c]:<12}: IoU = {iou:.4f}")
    print()

print("Lesson: Always report mIoU, not just pixel accuracy.")
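The training loop earlier averages per-batch mIoU. That can drift from the mIoU of the whole validation set, because a class missing from a batch is skipped for that batch. A common alternative, sketched here, accumulates one C×C confusion matrix across all batches and computes IoU once at the end. `update_confusion` and `iou_from_confusion` are hypothetical helper names; torch.bincount does the counting.

```python
import torch

def update_confusion(conf: torch.Tensor, pred: torch.Tensor, true: torch.Tensor,
                     n_classes: int, ignore_index: int = 255) -> torch.Tensor:
    """Add one batch of (..., H, W) predictions to a (C, C) matrix: rows=true, cols=pred."""
    valid = true != ignore_index
    idx = true[valid] * n_classes + pred[valid]   # encode each (true, pred) pair as one index
    conf += torch.bincount(idx, minlength=n_classes ** 2).reshape(n_classes, n_classes)
    return conf

def iou_from_confusion(conf: torch.Tensor) -> torch.Tensor:
    """Per-class IoU = TP / (TP + FP + FN); NaN (0/0) for classes absent from both."""
    tp = conf.diag().float()
    fp = conf.sum(dim=0).float() - tp   # predicted as c, truly another class
    fn = conf.sum(dim=1).float() - tp   # truly c, predicted as another class
    return tp / (tp + fp + fn)

n_classes = 3
conf = torch.zeros(n_classes, n_classes, dtype=torch.long)
batches = [  # (pred, true) pairs of tiny 2x2 masks
    (torch.tensor([[0, 1], [1, 1]]), torch.tensor([[0, 1], [0, 1]])),
    (torch.tensor([[2, 2], [0, 0]]), torch.tensor([[2, 0], [0, 0]])),
]
for pred, true in batches:
    conf = update_confusion(conf, pred, true, n_classes)

ious = iou_from_confusion(conf)
miou = torch.nanmean(ious)   # skip NaN classes when averaging
print(conf.tolist())
print(ious.tolist(), miou.item())
```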
Production approach

DeepLab and SegFormer — pretrained segmentation models for fine-tuning

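U-Net trained from scratch needs a lot of labelled masks to generalise beyond a narrow domain; the original paper got by with few images only through heavy data augmentation. For most production tasks, fine-tune a pretrained segmentation model instead. DeepLabV3/DeepLabV3+ (Google) and SegFormer (NVIDIA) are two widely used families: SegFormer checkpoints with ImageNet-pretrained encoders and ADE20K- or Cityscapes-fine-tuned heads are on the Hugging Face Hub, and torchvision ships DeepLabV3 with weights trained on a subset of COCO.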

python
# pip install transformers torch torchvision

from transformers import (
    SegformerForSemanticSegmentation,
    SegformerImageProcessor,
)
import torch
import numpy as np
from PIL import Image

# ── Load pretrained SegFormer ─────────────────────────────────────────
# b0=lightest, b5=heaviest. b2 is a good production balance.
model_name = 'nvidia/segformer-b0-finetuned-ade-512-512'
processor  = SegformerImageProcessor.from_pretrained(model_name)
model      = SegformerForSemanticSegmentation.from_pretrained(model_name)
model.eval()

print(f"SegFormer-b0:")
print(f"  Parameters: {sum(p.numel() for p in model.parameters()):,}")
print(f"  Classes:    {model.config.num_labels} (ADE20K dataset)")

# ── Inference ─────────────────────────────────────────────────────────
img = Image.fromarray(
    np.random.randint(50, 200, (512, 512, 3), dtype=np.uint8)
)
inputs  = processor(images=img, return_tensors='pt')
with torch.no_grad():
    outputs = model(**inputs)

# SegFormer outputs at 1/4 resolution — upsample to original size
logits = outputs.logits          # (1, 150, 128, 128)
upsampled = torch.nn.functional.interpolate(
    logits, size=img.size[::-1],  # (H, W)
    mode='bilinear', align_corners=False,
)
pred_mask = upsampled.argmax(dim=1).squeeze().numpy()
print(f"\nPrediction mask shape: {pred_mask.shape}")
print(f"Unique classes predicted: {np.unique(pred_mask).tolist()[:10]}")

# ── Fine-tuning SegFormer on custom classes ───────────────────────────
print("""
# Fine-tuning SegFormer for custom segmentation:

from transformers import SegformerForSemanticSegmentation, TrainingArguments, Trainer

# 1. Load with custom number of classes
model = SegformerForSemanticSegmentation.from_pretrained(
    'nvidia/segformer-b2-finetuned-ade-512-512',
    num_labels=4,                    # your classes
    id2label={0:'bg', 1:'road', 2:'vehicle', 3:'pedestrian'},
    label2id={'bg':0, 'road':1, 'vehicle':2, 'pedestrian':3},
    ignore_mismatched_sizes=True,    # replaces the head for your classes
)

# 2. Dataset returns: pixel_values (3,H,W) + labels (H,W) integer mask
# 3. TrainingArguments as for classification; set remove_unused_columns=False if the dataset applies transforms on the fly
# 4. Custom compute_metrics using mIoU
# 5. trainer.train()

# Key difference from classification fine-tuning:
# - Labels are 2D masks not 1D class vectors
# - Loss is CrossEntropy over all pixels (model handles this internally)
# - Evaluation uses mIoU not accuracy
""")

# ── torchvision pretrained segmentation models ─────────────────────────
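import torchvision.models.segmentation as seg_models

# DeepLabV3 with a ResNet-50 backbone, randomly initialised here so nothing is downloaded.
# aux_loss=True adds the auxiliary FCN head that produces out['aux'].
deeplab = seg_models.deeplabv3_resnet50(weights=None, weights_backbone=None,
                                        num_classes=4, aux_loss=True)
# In production: weights=seg_models.DeepLabV3_ResNet50_Weights.DEFAULT loads weights trained
# on a COCO subset with the 21 Pascal VOC classes; then replace the final classifier layer
# (deeplab.classifier[4]) with a Conv2d for your own number of classes.
deeplab.eval()

x = torch.randn(2, 3, 256, 256)
with torch.no_grad():
    out = deeplab(x)
print("\nDeepLabV3 output:")
print(f"  'out' key shape:  {tuple(out['out'].shape)}  ← main prediction")
print(f"  'aux' key shape:  {tuple(out['aux'].shape)}  ← auxiliary loss head")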
Errors you will hit

Every common segmentation mistake — explained and fixed

CrossEntropyLoss fails with a target-shape error: the mask is (B, 1, H, W) instead of (B, H, W)
Why it happens

The mask tensor has an extra channel dimension. When loading masks with PIL or OpenCV and converting to tensor, a greyscale mask of shape (H, W) becomes (1, H, W) after ToTensor(). CrossEntropyLoss expects (B, H, W) — the extra dimension causes a shape mismatch.

Fix

Squeeze the mask tensor before computing loss: mask = mask.squeeze(1). Or in your Dataset's __getitem__, convert the PIL mask to a numpy array and then to a LongTensor directly: mask = torch.tensor(np.array(mask_pil), dtype=torch.long) — this gives shape (H, W) without the extra channel. Never use ToTensor() on segmentation masks — it adds a channel dimension and also normalises values to [0, 1], destroying integer class indices.

Model predicts only the background class for every pixel — mIoU near zero
Why it happens

Severe class imbalance: background pixels dominate, so standard cross-entropy can be driven down by predicting background everywhere. Missing class weights or an incorrectly applied Dice loss make this worse. With 90% background pixels, an all-background model already achieves 90% pixel accuracy and a comparatively low cross-entropy loss, so the model settles into this easy minimum and never learns to segment the foreground.

Fix

Add class weights inversely proportional to frequency: weights = 1 / class_pixel_frequencies, normalised. Pass to CrossEntropyLoss: criterion = nn.CrossEntropyLoss(weight=class_weights.to(device)). Use Dice loss or combined CE + Dice — Dice is not dominated by frequent classes. Verify masks are loaded correctly: print(torch.unique(mask)) in __getitem__ to confirm foreground class indices are present.
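The weighting recipe above, as a minimal sketch. It counts pixels per class over the training masks with torch.bincount, inverts the frequencies, and rescales so the weights sum to the number of classes (one common normalisation; others, such as median-frequency balancing, also exist). The random masks are placeholders for a real dataset, and `inverse_frequency_weights` is a hypothetical helper name.

```python
import torch
import torch.nn as nn

def inverse_frequency_weights(masks: torch.Tensor, n_classes: int) -> torch.Tensor:
    """masks: integer tensor (N, H, W). Returns (n_classes,) float weights."""
    counts = torch.bincount(masks.flatten(), minlength=n_classes).float()
    freqs = counts / counts.sum()
    # Inverse frequency; classes with no pixels get weight 0 instead of infinity
    weights = torch.where(counts > 0, 1.0 / freqs, torch.zeros_like(freqs))
    return weights * n_classes / weights.sum()   # normalise: weights sum to n_classes

# Placeholder masks: roughly 90% background (0) and 10% foreground (1)
torch.manual_seed(0)
masks = (torch.rand(16, 64, 64) < 0.1).long()

class_weights = inverse_frequency_weights(masks, n_classes=2)
print(class_weights)  # background weight well below 1, foreground well above 1
criterion = nn.CrossEntropyLoss(weight=class_weights)  # add .to(device) on GPU
```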

U-Net skip connections fail — RuntimeError: sizes don't match for concatenation
Why it happens

Input image dimensions are not divisible by 2^(number of pooling layers). U-Net with 4 pooling layers requires the input to be divisible by 16. With an input of (256, 341), the width shrinks 341 → 170 → 85 → 42 → 21 through the pooling layers, because each pool floors odd sizes. Upsampling doubles exactly (21 → 42 → 84), so at 1/4 resolution the encoder skip connection is 85 pixels wide while the upsampled decoder map is only 84, and torch.cat fails.

Fix

Resize inputs to dimensions divisible by 16 (or 2^n_pooling): use T.Resize to the nearest valid size. Or add padding in the Up module's forward pass: use F.pad to match encoder and decoder spatial dimensions before concatenation. The provided UNet implementation already handles this with F.pad — ensure you use the padded version.
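If resizing would distort the image, another common option is to pad the input up to the next multiple of 16 before the forward pass and crop the logits back afterwards. A sketch, where `pad_to_multiple` is a hypothetical helper and reflect padding is one reasonable choice (zero padding also works):

```python
import torch
import torch.nn.functional as F

def pad_to_multiple(x: torch.Tensor, multiple: int = 16):
    """Pad (B, C, H, W) on the bottom/right so H and W are divisible by `multiple`."""
    h, w = x.shape[-2:]
    pad_h = (multiple - h % multiple) % multiple
    pad_w = (multiple - w % multiple) % multiple
    # F.pad order for the last two dims: (left, right, top, bottom)
    return F.pad(x, (0, pad_w, 0, pad_h), mode='reflect'), (h, w)

x = torch.randn(1, 3, 256, 341)
padded, (h, w) = pad_to_multiple(x)
print(tuple(padded.shape))   # (1, 3, 256, 352): 352 = 22 * 16

# After running the model on `padded`, crop the logits back to the original size
logits = torch.randn(1, 4, *padded.shape[-2:])   # stand-in for model(padded)
logits = logits[..., :h, :w]
print(tuple(logits.shape))   # (1, 4, 256, 341)
```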

SegFormer logits are at 1/4 resolution — predictions look blocky and low-resolution
Why it happens

SegFormer outputs logits at 1/4 of the input resolution (e.g. 128×128 for a 512×512 input). This is by design: the first stage of its hierarchical Transformer encoder already works at 1/4 resolution, and the lightweight all-MLP decoder fuses the multi-scale features at that scale. If you take argmax directly on the 1/4-resolution logits and use that as the final mask, you get a coarse, blocky prediction.

Fix

Always upsample SegFormer logits back to the original image resolution before computing the final mask: upsampled = F.interpolate(logits, size=(H, W), mode='bilinear', align_corners=False). Interpolate the logits (before argmax) bilinearly; never interpolate the integer mask (after argmax) bilinearly, and use mode='nearest' if a label mask has to be resized. For training, you can pass full-resolution label masks: SegformerForSemanticSegmentation upsamples its logits to the label size internally before computing the cross-entropy loss.
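Why bilinear interpolation is wrong for label masks, in a small sketch: resizing a 2×2 mask that contains only classes 0 and 3 creates in-between values such as 0.75 and 2.25, and rounding them invents classes 1 and 2 along the boundary. Nearest-neighbour keeps only labels that were already present.

```python
import torch
import torch.nn.functional as F

mask = torch.tensor([[0, 3],
                     [0, 3]])
m = mask[None, None].float()   # (1, 1, 2, 2): F.interpolate needs batch and channel dims

bilinear = F.interpolate(m, size=(4, 4), mode='bilinear', align_corners=False)
nearest = F.interpolate(m, size=(4, 4), mode='nearest')

print(bilinear[0, 0, 0].tolist())                 # values between 0 and 3 appear
print(bilinear.round().long().unique().tolist())  # [0, 1, 2, 3]: classes 1 and 2 invented
print(nearest.long().unique().tolist())           # [0, 3]: only the original labels
```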

What comes next

You can segment any image. Next: get ImageNet-level features without ImageNet-level compute.

You have built segmentation from scratch and used pretrained models. Both required labelled masks, which are expensive to collect. Module 59 covers transfer learning for vision: how to use a ResNet or EfficientNet backbone pretrained on ImageNet as a feature extractor for your own task, freezing early layers and fine-tuning later layers. The same technique underpins most production computer vision systems, including those at Indian startups: building on ImageNet representations instead of training from scratch.

Next — Module 59 · Computer Vision
Transfer Learning — Fine-Tuning Pretrained Vision Models

Feature extraction vs fine-tuning, layer freezing, and choosing the right backbone for your task.

coming soon

🎯 Key Takeaways

  • Semantic segmentation assigns a class label to every pixel — output is a 2D mask of shape (H, W) with integer class indices. Unlike detection (bounding boxes) it traces exact boundaries. Unlike classification (one label per image) it works at pixel granularity.
  • U-Net has two paths: the encoder (downsampling with MaxPool) captures what is in the image, the decoder (upsampling) restores spatial resolution. Skip connections copy encoder feature maps directly to the decoder at each scale — providing fine spatial detail that pooling destroyed. This is why U-Net produces sharp precise boundaries.
  • Input dimensions must be divisible by 2^(number of pooling layers). U-Net with 4 pooling layers requires input divisible by 16. Use F.pad in the decoder to handle any size mismatches between encoder skip connections and upsampled decoder features.
  • Never use ToTensor() on segmentation masks — it adds a channel dimension and normalises to [0, 1], destroying integer class indices. Convert masks with: torch.tensor(np.array(mask_pil), dtype=torch.long) for shape (H, W) with correct integer values.
  • Pixel accuracy is misleading for imbalanced datasets — always use mIoU (mean Intersection over Union). A model predicting all-background gets high pixel accuracy but near-zero mIoU. Use class-weighted CrossEntropyLoss or Dice loss to prevent the model from collapsing to predicting the majority class.
  • For production: fine-tune SegFormer (ADE20K or Cityscapes checkpoints) or DeepLabV3 (torchvision's COCO-subset weights). This requires far fewer labelled images than training U-Net from scratch. SegFormer outputs at 1/4 resolution, so always upsample with F.interpolate(logits, size=(H,W), mode="bilinear") before argmax for the final prediction.