
Variational Autoencoders — Learning Latent Representations

The reparameterisation trick, KL divergence loss, and why VAEs enable controllable generation through structured latent spaces.

35–40 min read · March 2026
Before any math — autoencoders vs variational autoencoders

A regular autoencoder compresses images to a point. A VAE compresses images to a region: a probability distribution. That one change makes the latent space smooth, structured, and usable for generation.

A standard autoencoder has an encoder that maps an image to a fixed latent vector z, and a decoder that maps z back to an image. Trained to minimise reconstruction error, it learns an efficient compression. But the latent space it creates is fragmented — arbitrary points in it decode to garbage because the model was never trained to handle points other than the exact codes it memorised for training images.

A VAE changes the encoder's output from a single point to a probability distribution — specifically a Gaussian defined by mean μ and variance σ². During training, the latent code z is sampled from this distribution rather than fixed. A regularisation term (KL divergence) forces all these distributions to stay close to a standard normal N(0, I). The result: the entire latent space is covered continuously — any point you sample from N(0, I) decodes to a meaningful image.

🧠 Analogy — read this first

A regular autoencoder is like a library where each book has a specific assigned shelf location. The shelves between books are empty — if you reach between two books you get nothing. A VAE is like a library organised by topic, with smooth transitions between subjects — books on cricket shade gradually into books on other sports, then into general fitness. Any point on the shelf has something meaningful. You can navigate by sliding from one location to another and find related content throughout.

The KL divergence term is the librarian enforcing this organisation. Without it, the encoder would cram all books into tiny clusters and leave most of the shelf empty — efficient but not navigable.

The architecture

Encoder, reparameterisation, decoder — every component explained

VAE data flow — from image to distribution to reconstruction
• Input x · (B, C, H, W) · original image, pixel values normalised to [0, 1]
• Encoder E(x) → h · (B, hidden) · CNN or MLP feature extractor, same as a classification backbone
• μ = fc_mean(h) · (B, latent_dim) · mean of the posterior q(z|x), one value per latent dimension
• log σ² = fc_logvar(h) · (B, latent_dim) · log variance, kept in log space for numerical stability (variance must be positive)
• z = μ + σ ⊙ ε · (B, latent_dim) · reparameterisation trick: ε ~ N(0, I) is the random part, μ and σ carry gradients
• Decoder D(z) → x̂ · (B, C, H, W) · reconstructed image, trained to match the input x
The reparameterisation trick — why it exists and what it does
Problem: z ~ N(μ, σ²) is not differentiable — sampling breaks backprop.
  If we sample z directly, the gradients ∂loss/∂μ and ∂loss/∂σ cannot be computed.
Solution: z = μ + σ ⊙ ε, where ε ~ N(0, I).
  Now ε is the random part (no gradient needed) and μ, σ are deterministic transformations:
  ∂loss/∂μ = ∂loss/∂z × ∂z/∂μ = ∂loss/∂z × 1 ← flows through
  ∂loss/∂σ = ∂loss/∂z × ∂z/∂σ = ∂loss/∂z × ε ← flows through
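The two gradient claims above can be verified numerically with autograd. A minimal sketch (the specific values of μ, σ, and ε are arbitrary):

```python
import torch

torch.manual_seed(0)
mu    = torch.tensor([0.5], requires_grad=True)
sigma = torch.tensor([1.2], requires_grad=True)
eps   = torch.randn(1)                  # the random part, no gradient needed

z = mu + sigma * eps                    # reparameterised sample
z.backward()

print(mu.grad)                          # tensor([1.]): dz/dmu = 1
print(torch.allclose(sigma.grad, eps))  # True: dz/dsigma = eps
```

If z had been drawn with `torch.normal(mu, sigma)` instead, `mu.grad` and `sigma.grad` would not reflect the sample at all, which is exactly the problem the trick solves.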
The loss function

ELBO — Evidence Lower Bound — reconstruction loss plus KL divergence

The VAE is trained to maximise the ELBO (Evidence Lower Bound), a lower bound on the log likelihood of the data. In practice we minimise the negative ELBO, which splits into two terms: the reconstruction loss (how well the decoder reproduces the input) and the KL divergence (how close the encoder's distribution is to N(0, I)). These two terms are in tension: the KL term alone would collapse all encodings onto N(0, I), losing all information, while the reconstruction term wants to preserve all of it. The balance between them creates the structured latent space.

ELBO derivation — the two terms and what they enforce

ELBO = E[log p(x|z)] − KL(q(z|x) || p(z))

Term 1: E[log p(x|z)] — Reconstruction
  • How well does the decoder reconstruct x from a sampled z?
  • Binary cross-entropy for image pixels in [0, 1]; MSE is also common.
  • Maximise → the decoder gets better at reconstruction.

Term 2: KL(q(z|x) || p(z)) — Regularisation
  • How far is q(z|x) = N(μ, σ²) from the prior p(z) = N(0, I)?
  • Closed form: −0.5 × Σ(1 + log σ² − μ² − σ²)
  • Minimise → the encoder's distributions stay near the standard normal.

Training loss = reconstruction loss + β × KL   (β = 1: standard VAE; β > 1: β-VAE)
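The closed-form KL can be sanity-checked against PyTorch's `torch.distributions`. A quick verification sketch, not part of the model:

```python
import torch
from torch.distributions import Normal, kl_divergence

torch.manual_seed(0)
mu      = torch.randn(4)
log_var = torch.randn(4)
sigma   = torch.exp(0.5 * log_var)

# Closed form: -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
kl_closed = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())

# Reference: per-dimension KL(N(mu, sigma^2) || N(0, 1)), summed
prior  = Normal(torch.zeros(4), torch.ones(4))
kl_ref = kl_divergence(Normal(mu, sigma), prior).sum()

print(torch.allclose(kl_closed, kl_ref, atol=1e-6))  # True
```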
python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np

torch.manual_seed(42)

# ── Full VAE implementation — convolutional ───────────────────────────
class ConvVAE(nn.Module):
    """
    Convolutional VAE for 64×64 RGB images.
    Encoder: image → (μ, log σ²)
    Decoder: z → image
    """
    def __init__(self, latent_dim: int = 128):
        super().__init__()
        self.latent_dim = latent_dim

        # ── Encoder ───────────────────────────────────────────────────
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 32, 4, 2, 1),   # 64 → 32
            nn.ReLU(),
            nn.Conv2d(32, 64, 4, 2, 1),  # 32 → 16
            nn.ReLU(),
            nn.Conv2d(64, 128, 4, 2, 1), # 16 → 8
            nn.ReLU(),
            nn.Conv2d(128, 256, 4, 2, 1),# 8 → 4
            nn.ReLU(),
            nn.Flatten(),                 # 256 × 4 × 4 = 4096
        )
        self.fc_mu      = nn.Linear(4096, latent_dim)
        self.fc_log_var = nn.Linear(4096, latent_dim)

        # ── Decoder ───────────────────────────────────────────────────
        self.fc_decode = nn.Linear(latent_dim, 4096)
        self.decoder   = nn.Sequential(
            nn.ConvTranspose2d(256, 128, 4, 2, 1), # 4 → 8
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, 4, 2, 1),  # 8 → 16
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, 4, 2, 1),   # 16 → 32
            nn.ReLU(),
            nn.ConvTranspose2d(32, 3, 4, 2, 1),    # 32 → 64
            nn.Sigmoid(),   # pixel values in [0, 1]
        )

    def encode(self, x):
        h       = self.encoder(x)
        mu      = self.fc_mu(h)
        log_var = self.fc_log_var(h)
        return mu, log_var

    def reparameterise(self, mu, log_var):
        if self.training:
            std = torch.exp(0.5 * log_var)
            eps = torch.randn_like(std)
            return mu + std * eps
        else:
            return mu   # at inference: use mean, no sampling noise

    def decode(self, z):
        h = self.fc_decode(z).view(-1, 256, 4, 4)
        return self.decoder(h)

    def forward(self, x):
        mu, log_var = self.encode(x)
        z           = self.reparameterise(mu, log_var)
        x_recon     = self.decode(z)
        return x_recon, mu, log_var

# ── ELBO loss ─────────────────────────────────────────────────────────
def elbo_loss(x_recon, x, mu, log_var, beta: float = 1.0):
    """
    ELBO = Reconstruction loss + β × KL divergence
    beta=1:  standard VAE
    beta>1:  β-VAE — stronger disentanglement, worse reconstruction
    beta<1:  reconstruction focus — sharper outputs, less structured latent
    """
    # Reconstruction: binary cross-entropy per pixel, summed over batch
    recon_loss = F.binary_cross_entropy(x_recon, x, reduction='sum')

    # KL divergence: −0.5 × Σ(1 + log σ² − μ² − σ²)
    # Closed-form for diagonal Gaussian vs standard normal
    kl_loss = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())

    return (recon_loss + beta * kl_loss) / x.size(0)   # normalise by batch size

# ── Shape check ───────────────────────────────────────────────────────
vae = ConvVAE(latent_dim=128)
x   = torch.rand(4, 3, 64, 64)   # batch of 4 random images

x_recon, mu, log_var = vae(x)
loss = elbo_loss(x_recon, x, mu, log_var, beta=1.0)

total_params = sum(p.numel() for p in vae.parameters())
print("ConvVAE shapes:")
print(f"  Input:         {tuple(x.shape)}")
print(f"  Reconstructed: {tuple(x_recon.shape)}")
print(f"  μ:             {tuple(mu.shape)}")
print(f"  log σ²:        {tuple(log_var.shape)}")
print(f"  ELBO loss:     {loss.item():.4f}")
print(f"  Parameters:    {total_params:,}")
Training a VAE

Complete training pipeline with KL annealing

A critical practical detail: if you start training with the full KL term, the encoder immediately collapses all posteriors to N(0, I) because that minimises KL loss trivially — the reconstruction loss hasn't had time to build useful representations yet. KL annealing fixes this: start with β=0 (pure reconstruction), gradually increase β to 1 over the first 10–20 epochs. The encoder first learns to reconstruct, then learns to organise the latent space.

python
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
from torch.utils.data import Dataset, DataLoader

torch.manual_seed(42)

# ── Synthetic dataset — fashion product images ────────────────────────
class SyntheticFashionDataset(Dataset):
    """
    Simulates 6 fashion categories with distinct colour signatures.
    In production: torchvision.datasets.ImageFolder with real images.
    """
    CATEGORY_COLOURS = {
        0: [0.8, 0.2, 0.2],  # kurta — warm red
        1: [0.8, 0.6, 0.1],  # saree — gold
        2: [0.2, 0.3, 0.7],  # jeans — denim blue
        3: [0.3, 0.3, 0.3],  # sneakers — grey
        4: [0.6, 0.4, 0.1],  # watch — brown leather
        5: [0.6, 0.1, 0.5],  # handbag — purple
    }

    def __init__(self, n: int = 600):
        self.n = n
        self.labels = np.random.randint(0, 6, n)

    def __len__(self): return self.n

    def __getitem__(self, i):
        label  = self.labels[i]
        colour = self.CATEGORY_COLOURS[label]
        img    = np.random.randn(3, 64, 64) * 0.1
        for c in range(3):
            img[c] += colour[c]
        img = np.clip(img, 0, 1).astype(np.float32)
        return torch.FloatTensor(img), label

dataset    = SyntheticFashionDataset(n=600)
train_size = 480
val_size   = 120
train_ds, val_ds = torch.utils.data.random_split(dataset, [train_size, val_size])
train_ld = DataLoader(train_ds, batch_size=32, shuffle=True)
val_ld   = DataLoader(val_ds,   batch_size=32)

# ── Smaller VAE for demonstration ─────────────────────────────────────
class SmallVAE(nn.Module):
    def __init__(self, latent_dim=32):
        super().__init__()
        self.latent_dim = latent_dim
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 32, 4, 2, 1), nn.ReLU(),   # 32
            nn.Conv2d(32, 64, 4, 2, 1), nn.ReLU(),  # 16
            nn.Conv2d(64, 128, 4, 2, 1), nn.ReLU(), # 8
            nn.Flatten(),
        )
        self.fc_mu      = nn.Linear(128*8*8, latent_dim)
        self.fc_log_var = nn.Linear(128*8*8, latent_dim)
        self.fc_decode  = nn.Linear(latent_dim, 128*8*8)
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(128, 64, 4, 2, 1), nn.ReLU(),
            nn.ConvTranspose2d(64, 32, 4, 2, 1),  nn.ReLU(),
            nn.ConvTranspose2d(32, 3, 4, 2, 1),
            nn.Sigmoid(),
        )

    def encode(self, x):
        h = self.encoder(x)
        return self.fc_mu(h), self.fc_log_var(h)

    def reparameterise(self, mu, lv):
        if self.training:
            return mu + torch.exp(0.5 * lv) * torch.randn_like(mu)
        return mu

    def decode(self, z):
        return self.decoder(self.fc_decode(z).view(-1, 128, 8, 8))

    def forward(self, x):
        mu, lv = self.encode(x)
        z = self.reparameterise(mu, lv)
        return self.decode(z), mu, lv

vae       = SmallVAE(latent_dim=32)
optimizer = optim.Adam(vae.parameters(), lr=1e-3)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=20)

def kl_anneal(epoch, warmup=10):
    """KL annealing: linearly ramp β from 0 to 1 over the warmup epochs."""
    return min(1.0, (epoch - 1) / warmup)   # epoch 1 → β=0, pure reconstruction

print(f"Training VAE on fashion dataset:")
print(f"{'Epoch':>6} {'Train loss':>12} {'Recon':>10} {'KL':>8} {'β':>6}")
print("─" * 46)

for epoch in range(1, 21):
    beta = kl_anneal(epoch, warmup=10)
    vae.train()
    total_loss = recon_total = kl_total = 0

    for imgs, _ in train_ld:
        optimizer.zero_grad()
        x_recon, mu, lv = vae(imgs)

        recon = F.binary_cross_entropy(x_recon, imgs, reduction='sum') / imgs.size(0)
        kl    = -0.5 * torch.sum(1 + lv - mu.pow(2) - lv.exp()) / imgs.size(0)
        loss  = recon + beta * kl

        loss.backward()
        nn.utils.clip_grad_norm_(vae.parameters(), 1.0)
        optimizer.step()

        total_loss   += loss.item()
        recon_total  += recon.item()
        kl_total     += kl.item()

    scheduler.step()
    n = len(train_ld)
    if epoch % 4 == 0:
        print(f"  {epoch:>4}  {total_loss/n:>12.2f}  {recon_total/n:>10.2f}  "
              f"{kl_total/n:>8.2f}  {beta:>6.2f}")

print("\nKL annealing: β ramps from 0 to 1 over the first 10 epochs")
print(f"  Without annealing: KL collapses all encodings to N(0,I) immediately")
print(f"  With annealing: encoder first learns reconstruction, then latent structure")
Using the latent space

Interpolation, generation, and anomaly detection — the three VAE superpowers

python
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

# Assume vae is trained from previous section
vae.eval()

# ── Application 1: Interpolation between two fashion items ───────────
print("=" * 50)
print("1. LATENT SPACE INTERPOLATION")
print("=" * 50)

# Two images from different categories
img_kurta  = torch.rand(1, 3, 64, 64) * 0.6 + 0.2   # warm colours
img_jeans  = torch.rand(1, 3, 64, 64) * 0.3 + 0.1   # cool dark colours

with torch.no_grad():
    mu_kurta, _  = vae.encode(img_kurta)
    mu_jeans, _  = vae.encode(img_jeans)

    print("Interpolating kurta → jeans:")
    for t in [0.0, 0.25, 0.5, 0.75, 1.0]:
        z_interp  = (1 - t) * mu_kurta + t * mu_jeans
        img_interp = vae.decode(z_interp)
        print(f"  t={t:.2f}: mean={img_interp.mean():.4f}  "
              f"R={img_interp[0,0].mean():.3f}  "
              f"B={img_interp[0,2].mean():.3f}")

# ── Application 2: Random generation from prior ───────────────────────
print("\n" + "=" * 50)
print("2. GENERATING NEW FASHION ITEMS")
print("=" * 50)

with torch.no_grad():
    # Sample from standard normal — the prior p(z)
    z_samples  = torch.randn(8, vae.latent_dim)
    generated  = vae.decode(z_samples)
    print(f"Generated {len(generated)} new images from N(0,I)")
    for i, img in enumerate(generated):
        print(f"  Item {i+1}: mean={img.mean():.4f}  std={img.std():.4f}")

# ── Application 3: Anomaly detection ─────────────────────────────────
print("\n" + "=" * 50)
print("3. ANOMALY DETECTION")
print("=" * 50)
print("""
# Core idea: normal items reconstruct well, anomalies reconstruct poorly.
# Reconstruction error = how "normal" an item is.

# Train VAE only on normal (non-defective) product images.
# At inference: compute reconstruction error for each item.
# High error → anomaly (defect, wrong item, etc.)

def anomaly_score(vae, image, n_samples=10):
    vae.eval()
    with torch.no_grad():
        mu, log_var = vae.encode(image)
        # Average reconstruction error over multiple samples
        # (reduces noise from stochastic sampling)
        errors = []
        for _ in range(n_samples):
            std   = torch.exp(0.5 * log_var)
            z     = mu + std * torch.randn_like(std)  # sample explicitly: eval mode skips noise
            recon = vae.decode(z)
            error = F.mse_loss(recon, image, reduction='mean').item()
            errors.append(error)
        return np.mean(errors)

# Normal item: low reconstruction error
# Defective item: high reconstruction error
# Threshold at 95th percentile of validation set errors

# Used by: quality control at garment factories,
# fraud detection (unusual transaction patterns),
# medical imaging (lesion detection)
""")

# ── Application 4: Attribute manipulation ─────────────────────────────
print("4. LATENT ATTRIBUTE MANIPULATION")
print("""
# Learn direction vectors for specific attributes:
# colour_direction = mean(blue_items_latents) - mean(red_items_latents)

# Then at inference:
# z_blue_version = z_original + alpha * colour_direction
# decoded = vae.decode(z_blue_version)
# → same product, different colour, without re-photographing it

# Example production uses: e-commerce product colour variations,
# style variations of catalogue images.
""")
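The direction-vector recipe can be sketched numerically. The latents below are synthetic stand-ins, not outputs of a trained encoder, with dimension 0 playing the role of the colour attribute:

```python
import torch

torch.manual_seed(0)
latent_dim = 32

# Synthetic stand-in latents: pretend dimension 0 tracks colour
blue_latents = torch.randn(50, latent_dim)
blue_latents[:, 0] += 2.0
red_latents = torch.randn(50, latent_dim)
red_latents[:, 0] -= 2.0

# Attribute direction = difference of group means in latent space
colour_direction = blue_latents.mean(dim=0) - red_latents.mean(dim=0)

# Shift one red item towards blue; in practice decode z_shifted with the VAE
alpha      = 1.0
z_original = red_latents[0]
z_shifted  = z_original + alpha * colour_direction

print(colour_direction[0].item() > 3.0)   # True: dominated by the colour dimension
```

With real encodings the groups come from `vae.encode` on labelled images; everything else is identical.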
Controlling the latent space

β-VAE — disentangled representations where each dimension has meaning

In a standard VAE (β=1), the latent dimensions are not necessarily interpretable — dimension 7 might encode a mixture of colour, texture, and shape simultaneously. β-VAE increases the KL weight (β > 1), forcing the encoder to use each latent dimension more independently. With enough pressure, individual dimensions learn to represent single factors of variation — one dimension for colour, one for shape, one for size. This is called disentanglement.

python
import torch
import numpy as np

# ── β-VAE: just change beta in the loss ──────────────────────────────
def beta_vae_loss(x_recon, x, mu, log_var, beta: float):
    """
    β-VAE loss with controllable disentanglement.
    beta=1:   standard VAE — best reconstruction
    beta=4:   moderate disentanglement
    beta=10+: strong disentanglement, blurrier reconstruction
    """
    recon = torch.nn.functional.binary_cross_entropy(
        x_recon, x, reduction='sum') / x.size(0)
    kl    = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp()) / x.size(0)
    return recon + beta * kl, recon.item(), kl.item()

# ── Demonstrate effect of β on KL and reconstruction ─────────────────
print("Effect of β on training dynamics:")
print(f"{'β':>6} {'Effect':>50}")
print("─" * 60)
betas_guide = [
    (0.0,  'Pure reconstruction — AE mode, no latent regularisation'),
    (0.5,  'Light regularisation — sharp images, some structure'),
    (1.0,  'Standard VAE — balanced reconstruction and structure'),
    (4.0,  'Moderate β-VAE — some disentanglement, slightly blurry'),
    (10.0, 'Strong β-VAE — clear disentanglement, blurry outputs'),
    (100., 'Over-regularised — all encodings collapse to N(0,I)'),
]
for b, effect in betas_guide:
    print(f"  {b:>4.1f}  {effect}")

# ── Measuring disentanglement — latent traversal ──────────────────────
print("""
# Latent traversal: fix all dimensions, vary one → visualise effect
# Well-disentangled: varying dim 3 changes ONLY colour
# Poorly-entangled: varying dim 3 changes colour AND shape AND texture

def latent_traversal(vae, image, dim: int, values: list):
    vae.eval()
    with torch.no_grad():
        mu, _ = vae.encode(image)
        images = []
        for v in values:
            z        = mu.clone()
            z[0, dim] = v         # vary only this dimension
            images.append(vae.decode(z))
        return images

# Traversal values — sweep from -3 to +3 (3 standard deviations)
# traversal_values = torch.linspace(-3, 3, 10)
# for dim in range(vae.latent_dim):
#     imgs = latent_traversal(vae, test_image, dim, traversal_values)
#     # If imgs show monotonic smooth change in one attribute → disentangled
""")

# ── Choosing β for production ─────────────────────────────────────────
print("Practical β selection guide:")
use_cases = [
    ('Image reconstruction / compression',    1.0,  'Prioritise fidelity'),
    ('Data augmentation',                      1.0,  'Good quality variations'),
    ('Anomaly detection',                      1.0,  'Sharp recon error signal'),
    ('Attribute manipulation',                 4.0,  'Partial disentanglement'),
    ('Drug discovery (molecular generation)',  4.0,  'Structured chemical space'),
    ('Interpretability research',             10.0,  'Maximum disentanglement'),
]
print(f"  {'Use case':<40} {'β':>5}  Reason")
print("  " + "─" * 60)
for use, b, reason in use_cases:
    print(f"  {use:<40} {b:>5.1f}  {reason}")
Errors you will hit

Every common VAE mistake — explained and fixed

KL loss goes to zero from the first epoch — posterior collapse
Why it happens

The encoder ignores the input and always outputs μ=0, σ=1 (standard normal), minimising KL to zero trivially. The decoder then learns to ignore z and generate blurry average images from scratch. This is posterior collapse — a known VAE failure mode. Happens when KL weight is too high relative to reconstruction loss, or when the decoder is too powerful (can reconstruct well without using z).

Fix

Use KL annealing: start with β=0 and linearly increase to β=1 over 10–20 epochs. The reconstruction signal trains first, forcing the encoder to put useful information into z before the KL term regularises it. Also reduce decoder capacity or add a bottleneck. Check: if the encoder outputs μ ≈ 0 and log σ² ≈ 0 for every input, so that q(z|x) ≈ N(0, I) regardless of x, posterior collapse has occurred.
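One way to run that check is a small diagnostic over a batch of encoder outputs. The thresholds below are heuristic assumptions, not canonical values:

```python
import torch

def posterior_collapse_stats(mu: torch.Tensor, log_var: torch.Tensor):
    """Heuristic collapse check over a batch of (mu, log_var) encoder outputs.

    Collapse signature: mu barely varies across inputs (encoder ignores x)
    while log_var sits near 0, i.e. every posterior equals the prior N(0, I).
    """
    mu_spread   = mu.var(dim=0).mean().item()   # input-dependence of mu
    mean_logvar = log_var.mean().item()         # 0 means sigma ~ 1, the prior's scale
    collapsed   = mu_spread < 1e-3 and abs(mean_logvar) < 1e-2
    return mu_spread, mean_logvar, collapsed

# Healthy encoder: mu depends on the input, sigma below 1
print(posterior_collapse_stats(torch.randn(32, 16), torch.full((32, 16), -1.0))[2])  # False

# Collapsed encoder: identical (mu, log_var) = (0, 0) for every input
print(posterior_collapse_stats(torch.zeros(32, 16), torch.zeros(32, 16))[2])         # True
```

In training, feed it the `mu` and `lv` tensors from a validation batch every few epochs.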

VAE outputs blurry images — reconstructions look smeared and averaged
Why it happens

Binary cross-entropy reconstruction loss treats each pixel independently and optimises the expected pixel value. When there is uncertainty about a pixel (the model is unsure between two values), it outputs the average — which appears as blur. This is fundamental to the BCE/MSE pixel-wise reconstruction objective, not a bug. The KL term adds further pressure to average over modes.

Fix

This is an inherent limitation of VAEs. Partial mitigations: a perceptual loss (compare VGG features instead of raw pixels), an adversarial discriminator loss added on top of the ELBO (the VAE-GAN approach), or a discrete latent space (VQ-VAE). For genuinely sharp images, use a diffusion model instead; they are designed to avoid this averaging problem. Reducing β also shifts the balance towards reconstruction fidelity.

RuntimeError: Expected all tensors to be on the same device during reparameterise
Why it happens

torch.randn_like(std) creates a tensor on the same device as std, but if std was computed on CPU and the model was later moved to GPU (or vice versa), there is a device mismatch. More commonly the error comes from constructing the noise manually with torch.randn(...) instead of torch.randn_like(std): the former is created on the default device (usually CPU) regardless of where std lives.

Fix

Always use torch.randn_like(std), which creates the noise on the same device (and with the same dtype) as std. Avoid torch.randn(std.shape), which lands on the default device, usually CPU. If you need to construct it manually: eps = torch.randn(std.shape, device=std.device). Verify the model's device before training: print(next(vae.parameters()).device) should match your training data's device.
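The device mismatch is hard to reproduce on a CPU-only machine, but the same inheritance rule applies to dtype, which demonstrates the behaviour anywhere:

```python
import torch

std = torch.ones(3, dtype=torch.float64)  # stand-in for a tensor on another device/dtype
eps = torch.randn_like(std)               # inherits device AND dtype from std

print(eps.dtype)                          # torch.float64, matches std
print(torch.randn(3).dtype)               # torch.float32, the defaults (device: CPU)
```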

Reconstruction loss NaN — training diverges after a few batches
Why it happens

Binary cross-entropy expects both predictions and targets in [0, 1]. If input images are not normalised to [0, 1] — for example raw uint8 values 0–255, or values pushed slightly above 1.0 by augmentation — the loss errors out or diverges, because BCE takes the log of the predictions. The same happens when the Sigmoid activation is missing from the decoder output, sending predicted pixel values outside [0, 1].

Fix

Always normalise input images to [0, 1] before passing them to the VAE; T.ToTensor() divides uint8 values by 255 automatically. Ensure the decoder's final activation is nn.Sigmoid(). For extra safety, clamp the predictions before computing BCE: x_recon = x_recon.clamp(1e-8, 1 - 1e-8) prevents log(0). Alternatively, drop the final Sigmoid and use F.binary_cross_entropy_with_logits, which is numerically stable. If images are normalised to [-1, 1] (as after ImageNet-style normalisation), use MSE loss instead of BCE; MSE works for any range.
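Both safeguards side by side, as a minimal sketch on random tensors:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.rand(4, 3, 8, 8)        # targets in [0, 1], as BCE requires

# Option 1: clamp predictions away from exactly 0 and 1 before BCE
x_recon  = torch.rand(4, 3, 8, 8).clamp(1e-8, 1 - 1e-8)
loss_bce = F.binary_cross_entropy(x_recon, x, reduction='sum') / x.size(0)

# Option 2: drop the decoder's final Sigmoid and pass raw logits to the
# fused, numerically stable version (sigmoid + log computed in one step)
logits      = torch.randn(4, 3, 8, 8)
loss_logits = F.binary_cross_entropy_with_logits(logits, x, reduction='sum') / x.size(0)

print(torch.isfinite(loss_bce).item(), torch.isfinite(loss_logits).item())  # True True
```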

What comes next

You understand latent variable models. Next: the architecture that generates the sharpest images ever produced by AI.

GANs are sharp but unstable. VAEs are stable but blurry. Diffusion models get the best of both — they are stable to train, produce sharp photorealistic outputs, and avoid mode collapse entirely. Module 63 explains the forward noising process, the reverse denoising network, and how Stable Diffusion uses a VAE latent space to make diffusion fast enough for practical use.

Next — Module 63 · Generative AI
Diffusion Models and Stable Diffusion

Forward noise, reverse denoising, DDPM, latent diffusion — how Stable Diffusion generates photorealistic images from text.

coming soon

🎯 Key Takeaways

  • A regular autoencoder maps each image to a fixed point in latent space — the space between points is empty and decodes to garbage. A VAE maps each image to a probability distribution (Gaussian with mean μ and variance σ²) and regularises all distributions to stay near N(0, I). Any point sampled from N(0, I) decodes to a meaningful image.
  • The reparameterisation trick makes VAE training possible: instead of sampling z ~ N(μ, σ²) directly (which breaks gradients), compute z = μ + σ × ε where ε ~ N(0, I). The random ε is independent of the parameters — gradients flow through μ and σ normally.
  • ELBO loss has two terms: reconstruction loss (BCE or MSE — how well does decoder reproduce the input) and KL divergence (−0.5 × Σ(1 + log σ² − μ² − σ²) — how close is the encoder distribution to N(0, I)). These are in tension — the balance creates a structured, navigable latent space.
  • KL annealing is essential for stable training: start β=0 (pure reconstruction) and linearly increase to β=1 over 10–20 epochs. Without annealing the KL term causes posterior collapse — the encoder ignores the input and outputs N(0, I) trivially, and the decoder learns to generate average blurry images without using z.
  • β-VAE (β > 1) increases KL weight to encourage disentanglement — individual latent dimensions learn to represent independent factors (colour, shape, size). β=1 gives best reconstruction quality. β=4 gives partial disentanglement. β≥10 gives strong disentanglement but noticeably blurry outputs.
  • Three production applications: interpolation (smooth transition between two encoded images by linearly blending their latent vectors), anomaly detection (high reconstruction error = unusual item — train only on normal items), and attribute manipulation (compute direction vectors in latent space for specific attributes like colour and add them to new encodings).