Pet Segmentation with UNet

Author

Archer Liu

Published

March 21, 2025

Introduction

In this report, I’ll present a UNet-based segmentation model trained on the Oxford-IIIT Pet Dataset from Kaggle. The model aims to segment pets from images accurately using a deep learning approach. The trained model has been saved to disk and will be loaded for evaluation and visualization.

Environment Setup

# Load necessary libraries
import os
import numpy as np
import torch
import torch.nn as nn
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
import matplotlib.pyplot as plt
from PIL import Image
from IPython.display import display, Markdown

# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
Using device: cuda

Dataset and Data Loading

Dataset Description

The Oxford-IIIT Pet Dataset consists of images of pets and their corresponding binary masks:

  • Images: JPEG format (RGB)
  • Masks: PNG format (binary mask with white for pet and black for background)

Data Loading

To load the data efficiently, I define a custom PetDataset class that takes separate transforms for images and masks.

class PetDataset(Dataset):
    def __init__(self, image_dir, mask_dir, transform=None, mask_transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.mask_transform = mask_transform
        self.image_files = sorted(os.listdir(image_dir))
        self.mask_files = sorted(os.listdir(mask_dir))

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        img_path = os.path.join(self.image_dir, self.image_files[idx])
        mask_path = os.path.join(self.mask_dir, self.mask_files[idx])

        # Load image as RGB and mask as single-channel grayscale
        image = Image.open(img_path).convert("RGB")
        mask = Image.open(mask_path).convert("L")

        if self.transform:
            image = self.transform(image)
        if self.mask_transform:
            mask = self.mask_transform(mask)

        return image, mask

Visualizing Data Samples

To understand the data, let’s visualize a few samples from the training set, including images and their corresponding masks.

Training samples
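
The figure above was produced with a small plotting helper. A minimal sketch of such a helper is shown below; show_samples is a hypothetical name, and it assumes samples come from the PetDataset defined above (either as PIL images, or as tensors once the transforms from the next section are applied). The exact plotting code used for the report may differ.

def show_samples(dataset, n=3):
    # Plot n (image, mask) pairs side by side
    fig, axes = plt.subplots(n, 2, figsize=(8, 4 * n))
    for i in range(n):
        image, mask = dataset[i]
        if torch.is_tensor(image):
            image = image.permute(1, 2, 0)  # CHW -> HWC for matplotlib
            mask = mask.squeeze(0)
        axes[i, 0].imshow(image)
        axes[i, 0].set_title("Image")
        axes[i, 1].imshow(mask, cmap="gray")
        axes[i, 1].set_title("Mask")
        for ax in axes[i]:
            ax.axis("off")
    plt.tight_layout()
    plt.show()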

Data Transformation and Augmentation

The following transformations are applied to keep inputs consistent:

  • Resizing images and masks to a fixed size.
  • Converting images to tensors, which scales pixel values to [0, 1].
  • No additional normalization for masks, so they retain their binary values.

DATA_DIR = "../data"
IMAGE_SIZE = 256
BATCH_SIZE = 16

data_transforms = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),
])

# Mask transformations (no normalization)
mask_transforms = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),
])

# Create datasets
# Create dataset with separate image and mask transforms
full_dataset = PetDataset(
    image_dir=os.path.join(DATA_DIR, "images"),
    mask_dir=os.path.join(DATA_DIR, "annotations"),
    transform=data_transforms,
    mask_transform=mask_transforms
)

# Train-validation-test split (70/15/15); a fixed generator seed makes the split reproducible
train_size = int(0.7 * len(full_dataset))
val_size = int(0.15 * len(full_dataset))
test_size = len(full_dataset) - train_size - val_size
train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(
    full_dataset, [train_size, val_size, test_size],
    generator=torch.Generator().manual_seed(42)
)

# Data loaders
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
valid_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)
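
As a quick check (illustrative only), a single batch from the training loader should yield image tensors of shape (16, 3, 256, 256) and mask tensors of shape (16, 1, 256, 256):

images, masks = next(iter(train_loader))
print(images.shape, masks.shape)
# torch.Size([16, 3, 256, 256]) torch.Size([16, 1, 256, 256])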

Model Architecture: UNet

The UNet model is a fully convolutional network designed to perform precise segmentation by combining:

  1. Encoder: Captures spatial context.
  2. Bottleneck: Bridge between encoder and decoder.
  3. Decoder: Restores spatial resolution.
  4. Skip Connections: Preserve fine details by merging low-level features from the encoder.

UNet Architecture

# DoubleConv Block
class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(DoubleConv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.conv(x)

# UNet Model
class PetUNet(nn.Module):
    def __init__(self, in_channels=3, out_channels=1):
        super(PetUNet, self).__init__()
        self.enc1 = DoubleConv(in_channels, 64)
        self.enc2 = DoubleConv(64, 128)
        self.enc3 = DoubleConv(128, 256)
        self.enc4 = DoubleConv(256, 512)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.bottleneck = DoubleConv(512, 1024)
        self.up4 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
        self.dec4 = DoubleConv(1024, 512)
        self.up3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.dec3 = DoubleConv(512, 256)
        self.up2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.dec2 = DoubleConv(256, 128)
        self.up1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.dec1 = DoubleConv(128, 64)
        self.out_conv = nn.Conv2d(64, out_channels, kernel_size=1)

    def forward(self, x):
        enc1 = self.enc1(x)
        enc2 = self.enc2(self.pool(enc1))
        enc3 = self.enc3(self.pool(enc2))
        enc4 = self.enc4(self.pool(enc3))
        bottleneck = self.bottleneck(self.pool(enc4))
        dec4 = self.dec4(torch.cat([self.up4(bottleneck), enc4], dim=1))
        dec3 = self.dec3(torch.cat([self.up3(dec4), enc3], dim=1))
        dec2 = self.dec2(torch.cat([self.up2(dec3), enc2], dim=1))
        dec1 = self.dec1(torch.cat([self.up1(dec2), enc1], dim=1))
        return torch.sigmoid(self.out_conv(dec1))
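
As a sanity check (illustrative only), a 256×256 RGB batch should map to a single-channel mask of the same spatial size, with sigmoid outputs in (0, 1):

x = torch.randn(1, 3, 256, 256)
with torch.no_grad():
    y = PetUNet()(x)
print(y.shape)  # torch.Size([1, 1, 256, 256])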

Model Training

Note: The model has already been trained and saved. The training code is provided for reproducibility but is not executed in this report.

epochs = 20
learning_rate = 1e-4

model = PetUNet(in_channels=3, out_channels=1).to(device)
criterion = nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# results = trainer(model, criterion, optimizer, train_loader, valid_loader, epochs=epochs)
# torch.save(model.state_dict(), "results/models/pet_unet.pth")
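
The trainer helper referenced above is not defined in this report. A minimal sketch of what such a loop might look like is given below; the actual implementation may differ (e.g., in logging, checkpointing, or early stopping).

def trainer(model, criterion, optimizer, train_loader, valid_loader, epochs=20):
    # Sketch of a training loop; the original helper may differ
    history = {"train_loss": [], "valid_loss": []}
    for epoch in range(epochs):
        model.train()
        train_loss = 0.0
        for images, masks in train_loader:
            images, masks = images.to(device), masks.to(device)
            optimizer.zero_grad()
            loss = criterion(model(images), masks)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
        model.eval()
        valid_loss = 0.0
        with torch.no_grad():
            for images, masks in valid_loader:
                images, masks = images.to(device), masks.to(device)
                valid_loss += criterion(model(images), masks).item()
        history["train_loss"].append(train_loss / len(train_loader))
        history["valid_loss"].append(valid_loss / len(valid_loader))
        print(f"Epoch {epoch + 1}/{epochs} | "
              f"train loss: {history['train_loss'][-1]:.4f} | "
              f"valid loss: {history['valid_loss'][-1]:.4f}")
    return history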

Loading the Trained Model

# Load the saved model
model = PetUNet().to(device)
model.load_state_dict(torch.load("../results/models/pet_unet.pth", map_location=device))
model.eval()
print("Model loaded successfully!")
Model loaded successfully!

Model Evaluation

Evaluate the model on the test set and visualize predictions to assess performance.
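
The test helper used below is not shown in this report. A plausible sketch is given here, assuming predictions are thresholded at 0.5 and the Dice coefficient is averaged over batches; dice_coefficient is a hypothetical name, and the original helper may differ.

def dice_coefficient(pred, target, eps=1e-7):
    # Dice = 2|A ∩ B| / (|A| + |B|) on binarized masks
    pred = (pred > 0.5).float()
    target = (target > 0.5).float()
    intersection = (pred * target).sum()
    return (2.0 * intersection + eps) / (pred.sum() + target.sum() + eps)

def test(model, criterion, loader):
    # Average BCE loss and Dice coefficient over the test set
    model.eval()
    total_loss, total_dice = 0.0, 0.0
    with torch.no_grad():
        for images, masks in loader:
            images, masks = images.to(device), masks.to(device)
            preds = model(images)
            total_loss += criterion(preds, masks).item()
            total_dice += dice_coefficient(preds, masks).item()
    n = len(loader)
    return {"test_loss": total_loss / n, "test_dice": total_dice / n}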

test_results = test(model, criterion, test_loader)
test_loss = round(test_results['test_loss'], 4)
test_dice = round(test_results['test_dice'], 4)
print(f"Test Loss: {test_loss} | Test Dice: {test_dice}")
Test Loss: 0.1025 | Test Dice: 0.9536

Let’s compare the predicted masks with the ground truth masks as a qualitative check.

Test predictions
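
A sketch of how such a comparison grid could be generated is shown below; show_predictions is a hypothetical helper name, and the report’s actual plotting code may differ.

def show_predictions(model, loader, n=3):
    # Plot image, ground-truth mask, and thresholded prediction for n samples
    model.eval()
    images, masks = next(iter(loader))
    with torch.no_grad():
        preds = model(images.to(device)).cpu()
    fig, axes = plt.subplots(n, 3, figsize=(12, 4 * n))
    for i in range(n):
        axes[i, 0].imshow(images[i].permute(1, 2, 0))
        axes[i, 0].set_title("Image")
        axes[i, 1].imshow(masks[i].squeeze(0), cmap="gray")
        axes[i, 1].set_title("Ground truth")
        axes[i, 2].imshow((preds[i, 0] > 0.5).float(), cmap="gray")
        axes[i, 2].set_title("Prediction")
        for ax in axes[i]:
            ax.axis("off")
    plt.tight_layout()
    plt.show()

# show_predictions(model, test_loader)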

Conclusion

The UNet-based pet segmentation model achieved a Test Loss of 0.1025 and a Test Dice Score of 0.9536, indicating high accuracy in distinguishing pets from the background. The model demonstrated consistent performance across the dataset, showcasing the effectiveness of the UNet architecture for binary segmentation tasks.

Future work could involve experimenting with advanced architectures like UNet++ or Attention UNet, and leveraging transfer learning to improve generalization. Extending this approach to multi-class segmentation would also broaden its applicability.

Reference

Parkhi, O. M., Vedaldi, A., Zisserman, A., & Jawahar, C. V. (2012). Cats and Dogs. In IEEE Conference on Computer Vision and Pattern Recognition.