In digital image processing and computer vision, image segmentation is the process of partitioning a digital image into multiple segments (sets of pixels, also known as image objects). The goal of segmentation is to simplify and/or change the representation of an image into something that is more meaningful and easier to analyze. Image segmentation is typically used to locate objects and boundaries (lines, curves, etc.) in images. More precisely, image segmentation is the process of assigning a label to every pixel in an image such that pixels with the same label share certain characteristics.
The result of image segmentation is a set of segments that collectively cover the entire image, or a set of contours extracted from the image (see edge detection). Each of the pixels in a region are similar with respect to some characteristic or computed property, such as color, intensity, or texture. Adjacent regions are significantly different with respect to the same characteristic(s). — Wikipedia
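To make the definition concrete, here is a minimal sketch in plain Python (not MMSegmentation) of the simplest possible segmentation: thresholding intensities so that every pixel receives a label, and pixels with the same label share an intensity band. The function name threshold_segment and the threshold values are illustrative assumptions; a semantic segmentation model learns this pixel-to-label mapping instead of using fixed thresholds.

```python
import bisect
from typing import List


def threshold_segment(image: List[List[int]], thresholds: List[int]) -> List[List[int]]:
    """Label each pixel with the index of the intensity band it falls into.

    `thresholds` must be sorted ascending; a pixel with value v gets the
    label equal to the number of thresholds that are <= v.
    """
    return [[bisect.bisect_right(thresholds, v) for v in row] for row in image]


# A toy 3x4 grayscale "image" (intensities 0-255).
image = [
    [10, 12, 200, 210],
    [11, 90, 100, 205],
    [15, 95, 98, 220],
]
label_map = threshold_segment(image, thresholds=[50, 150])
for row in label_map:
    print(row)
# [0, 0, 2, 2]
# [0, 1, 1, 2]
# [0, 1, 1, 2]
```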
MMSegmentation is an open source semantic segmentation toolbox based on PyTorch. It is part of the open-mmlab project developed by the Multimedia Laboratory, CUHK.
In this article, we will see how to train our own model on custom data using mmsegmentation, the GitHub repository by open-mmlab.
What makes MMSegmentation so great?
• It is designed to run training and inference efficiently on GPUs.
• It decomposes the segmentation framework into separate components, so you can easily build a customized semantic segmentation framework by combining different modules.
• It supports training on multiple GPUs.
• It supports multiple datasets such as Cityscapes, VOC12aug, and PASCAL Context.
• It provides ready-made configs for most state-of-the-art models.
To train on a customized dataset, the following steps are necessary:
- Add a new dataset class.
- Create a config file accordingly.
- Perform training and evaluation.
1. Add a new dataset
Datasets in MMSegmentation require images and semantic segmentation maps to be placed in folders with the same prefix. To support a new dataset, we may need to modify the original file structure.
In this tutorial, we give an example of converting the dataset. You may refer to the docs for details about dataset reorganization.
We use the Stanford Background Dataset as an example. The dataset contains 715 images chosen from the existing public datasets LabelMe, MSRC, PASCAL VOC and Geometric Context. Images from these datasets are mainly outdoor scenes, each approximately 320-by-240 pixels. In this tutorial, we use the region annotations as labels. There are 8 classes in total: sky, tree, road, grass, water, building, mountain, and foreground object.
A training pair consists of an image in img_dir and an annotation in ann_dir whose file names are the same apart from the suffix.
Datasets are easy to find on many sites. To make your own annotations on the original color images, I highly recommend using this tool.
You can download my copy of the Stanford Background Dataset from here.
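The naming convention behind a training pair can be illustrated in plain Python. This is a hypothetical helper, pair_by_stem, showing the convention only, not how CustomDataset is implemented; it assumes images end in .jpg and annotations in .png (the suffixes the dataset class in this article uses) and pairs files whose names match once the suffix is removed.

```python
from typing import Dict, List, Tuple


def pair_by_stem(img_files: List[str], ann_files: List[str],
                 img_suffix: str = ".jpg",
                 seg_map_suffix: str = ".png") -> List[Tuple[str, str]]:
    """Pair each image with the annotation whose name matches once suffixes are removed."""
    # Map base name -> annotation file name.
    ann_by_stem: Dict[str, str] = {
        name[: -len(seg_map_suffix)]: name
        for name in ann_files if name.endswith(seg_map_suffix)
    }
    pairs = []
    for name in sorted(img_files):
        if not name.endswith(img_suffix):
            continue  # skip files that are not images
        stem = name[: -len(img_suffix)]
        if stem in ann_by_stem:  # images without an annotation are left out
            pairs.append((name, ann_by_stem[stem]))
    return pairs


pairs = pair_by_stem(
    img_files=["6000131.jpg", "6000124.jpg", "6000999.jpg", "notes.txt"],
    ann_files=["6000124.png", "6000131.png"],
)
print(pairs)  # [('6000124.jpg', '6000124.png'), ('6000131.jpg', '6000131.png')]
```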
First, we install MMSegmentation. We will use PyTorch 1.5 with CUDA 10.1.
# Install PyTorch
!pip install -U torch==1.5.0+cu101 torchvision==0.6.0+cu101 -f
# Install MMCV
!pip install mmcv-full==latest+torch1.5.0+cu101 -f
Clone the repository
!rm -rf mmsegmentation
!git clone mmsegmentation
!pip install -e
Import some libraries
import matplotlib.pyplot as plt
from mmseg.apis import inference_segmentor, init_segmentor, show_result_pyplot
from mmseg.core.evaluation import get_palette
Copy your data folder into the mmsegmentation directory. Remember that the data folder contains an images folder (original images) and a labels folder (mask annotations). In the code below, we define 8 classes, each shown with its own palette color (R, G, B). Change these values to match your own data.
import os.path as osp
import mmcv
import numpy as np
from PIL import Image
# Set up paths
data_root = 'data'
img_dir = 'images'
ann_dir = 'labels'
# define classes and palette for better visualization
classes = ('sky', 'tree', 'road', 'grass', 'water', 'bldg', 'mntn', 'fg obj')
palette = [[128, 128, 128], [129, 127, 38], [120, 69, 125], [53, 125, 34],
[0, 11, 123], [118, 20, 12], [122, 81, 25], [241, 134, 51]]
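MMSegmentation expects each pixel of a label map to hold the class index (0 to 7 here), and the value 255 is treated as "ignore" (the same value the Pad step later uses for seg_pad_val). If your annotation tool exports color masks instead of index masks, which is an assumption about your tool, you can convert them with the palette above. A minimal sketch in plain Python; rgb_to_index is a hypothetical helper, and sending unknown colors to 255 is my choice rather than a library rule. In practice you would do this with NumPy on whole arrays and save the result as a single-channel PNG.

```python
from typing import List, Sequence, Tuple

IGNORE_INDEX = 255  # assumption: colors not in the palette become "ignore"

palette = [[128, 128, 128], [129, 127, 38], [120, 69, 125], [53, 125, 34],
           [0, 11, 123], [118, 20, 12], [122, 81, 25], [241, 134, 51]]


def rgb_to_index(mask_rgb: List[List[Tuple[int, int, int]]],
                 palette: Sequence[Sequence[int]]) -> List[List[int]]:
    """Replace each (r, g, b) color with the index of that color in the palette."""
    color_to_index = {tuple(color): idx for idx, color in enumerate(palette)}
    return [[color_to_index.get(tuple(px), IGNORE_INDEX) for px in row]
            for row in mask_rgb]


mask_rgb = [
    [(128, 128, 128), (128, 128, 128), (0, 11, 123)],
    [(53, 125, 34), (241, 134, 51), (1, 2, 3)],
]
print(rgb_to_index(mask_rgb, palette))  # [[0, 0, 4], [3, 7, 255]]
```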
Split the data into a training set (80%) and a validation set (20%).
# split train/val set (files are taken in directory listing order, not shuffled)
split_dir = 'splits'
mmcv.mkdir_or_exist(osp.join(data_root, split_dir))
filename_list = [osp.splitext(filename)[0] for filename in mmcv.scandir(
    osp.join(data_root, ann_dir), suffix='.png')]
with open(osp.join(data_root, split_dir, 'train.txt'), 'w') as f:
    # select first 4/5 as train set
    train_length = int(len(filename_list)*4/5)
    f.writelines(line + '\n' for line in filename_list[:train_length])
with open(osp.join(data_root, split_dir, 'val.txt'), 'w') as f:
    # select last 1/5 as val set
    f.writelines(line + '\n' for line in filename_list[train_length:])
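The split code takes the first 80% of the files in whatever order mmcv.scandir returns them, so it is not actually random. If you want a reproducible random split, a minimal sketch is to sort the names, shuffle them with a seeded generator, and cut at 80%. The helper name split_train_val and the seed value are assumptions; you would call it on filename_list and write the two lists to train.txt and val.txt in the same way.

```python
import random
from typing import List, Tuple


def split_train_val(names: List[str], train_ratio: float = 0.8,
                    seed: int = 0) -> Tuple[List[str], List[str]]:
    """Return (train, val) after a seeded shuffle, so the split is random but repeatable."""
    shuffled = sorted(names)  # sort first: the result no longer depends on directory order
    random.Random(seed).shuffle(shuffled)  # private generator, global random state untouched
    cut = int(len(shuffled) * train_ratio)
    return shuffled[:cut], shuffled[cut:]


names = [f"img_{i:03d}" for i in range(10)]
train, val = split_train_val(names)
print(len(train), len(val))  # 8 2
```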
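Next, we define the new dataset class StandfordBackgroundDataset. CustomDataset already implements load_annotations, which reads the split file and pairs each image with its annotation, so we only need to set the classes, the palette and the file suffixes, and register the class so that the config can refer to it by name.
from mmseg.datasets.builder import DATASETS
from mmseg.datasets.custom import CustomDataset

@DATASETS.register_module()
class StandfordBackgroundDataset(CustomDataset):
    CLASSES = classes
    PALETTE = palette

    def __init__(self, split, **kwargs):
        super().__init__(img_suffix='.jpg', seg_map_suffix='.png',
                         split=split, **kwargs)
        assert osp.exists(self.img_dir) and self.split is not None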
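MMSegmentation finds dataset classes by name through a registry: a class decorated with @DATASETS.register_module() can be referenced in the config by its string name, and building a dataset essentially looks up the "type" key and calls that class with the remaining keys. The toy re-implementation below makes the mechanism visible; Registry, ToyDataset and build are illustrative names, not the mmcv API.

```python
from typing import Any, Callable, Dict


class Registry:
    """Toy name -> class registry (illustration only; not mmcv.Registry)."""

    def __init__(self, name: str) -> None:
        self.name = name
        self._modules: Dict[str, type] = {}

    def register_module(self) -> Callable[[type], type]:
        """Return a decorator that records the class under its own name."""
        def decorator(cls: type) -> type:
            if cls.__name__ in self._modules:
                raise KeyError(f"{cls.__name__} is already registered in {self.name}")
            self._modules[cls.__name__] = cls
            return cls
        return decorator

    def build(self, cfg: Dict[str, Any]) -> Any:
        """Look up cfg['type'] and call that class with the remaining keys."""
        args = dict(cfg)  # copy so the caller's config is left untouched
        cls = self._modules[args.pop("type")]
        return cls(**args)


DATASETS = Registry("dataset")


@DATASETS.register_module()
class ToyDataset:
    def __init__(self, split: str) -> None:
        self.split = split


dataset_cfg = dict(type="ToyDataset", split="splits/train.txt")
ds = DATASETS.build(dataset_cfg)
print(type(ds).__name__, ds.split)  # ToyDataset splits/train.txt
```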
2. Create a config file
Next, we modify the config for training. To speed up the process, we fine-tune the model from pretrained weights. Here I choose PSPNet with its 40,000-iteration schedule (the 40k config); for this tutorial, training is later shortened to 200 iterations. All config files are in the configs directory, which is easy to explore.
from mmcv import Config
cfg = Config.fromfile('configs/pspnet/')
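Fine-tuning works because most pretrained weights fit the new model unchanged; only the final classification layers depend on the number of classes. Loading a Cityscapes checkpoint (19 classes) into a model with 8 classes should therefore reuse the backbone and most of the heads, while the class-specific layers cannot be copied. mmcv loads checkpoints non-strictly by default and, as far as I know, reports those layers as a size mismatch in a warning instead of failing. The plain-Python sketch below only compares parameter shapes to show which entries would be reused; check_checkpoint and the parameter names and shapes are illustrative assumptions, not MMSegmentation code.

```python
from typing import Dict, Tuple

Shape = Tuple[int, ...]


def check_checkpoint(checkpoint: Dict[str, Shape],
                     model: Dict[str, Shape]) -> Dict[str, str]:
    """Classify each checkpoint entry as 'loaded', 'size mismatch' or 'unexpected'."""
    report = {}
    for key, shape in checkpoint.items():
        if key not in model:
            report[key] = "unexpected"      # no such parameter in the new model
        elif model[key] != shape:
            report[key] = "size mismatch"   # same name, different shape: not copied
        else:
            report[key] = "loaded"          # reused as the starting point
    return report


# Illustrative parameter names and shapes: 19 Cityscapes classes vs. our 8.
checkpoint = {
    "backbone.layer1.0.conv1.weight": (64, 64, 1, 1),
    "decode_head.conv_seg.weight": (19, 512, 1, 1),
    "decode_head.conv_seg.bias": (19,),
}
model = {
    "backbone.layer1.0.conv1.weight": (64, 64, 1, 1),
    "decode_head.conv_seg.weight": (8, 512, 1, 1),
    "decode_head.conv_seg.bias": (8,),
}
report = check_checkpoint(checkpoint, model)
for key, status in report.items():
    print(f"{key}: {status}")
```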
Since the given config is used to train PSPNet on the Cityscapes dataset, we need to modify it accordingly for our new dataset.
from mmseg.apis import set_random_seed
# Since we use only one GPU, BN is used instead of SyncBN
cfg.norm_cfg = dict(type='BN', requires_grad=True)
cfg.model.backbone.norm_cfg = cfg.norm_cfg
cfg.model.decode_head.norm_cfg = cfg.norm_cfg
cfg.model.auxiliary_head.norm_cfg = cfg.norm_cfg
# modify num classes of the model in decode/auxiliary head
cfg.model.decode_head.num_classes = 8
cfg.model.auxiliary_head.num_classes = 8
# Modify dataset type and path
cfg.dataset_type = 'StandfordBackgroundDataset'
cfg.data_root = data_root
cfg.img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
cfg.crop_size = (256, 256)
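cfg.train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations'),
    dict(type='Resize', img_scale=(320, 240), ratio_range=(0.5, 2.0)),
    dict(type='RandomCrop', crop_size=cfg.crop_size, cat_max_ratio=0.75),
    dict(type='RandomFlip', flip_ratio=0.5),
    dict(type='Normalize', **cfg.img_norm_cfg),
    dict(type='Pad', size=cfg.crop_size, pad_val=0, seg_pad_val=255),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_semantic_seg']),
]
cfg.test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        img_scale=(320, 240),
        # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
        flip=False,
        transforms=[
            dict(type='Resize', keep_ratio=True),
            dict(type='Normalize', **cfg.img_norm_cfg),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img']),
        ])
]

cfg.data.train.type = cfg.dataset_type
cfg.data.train.data_root = cfg.data_root
cfg.data.train.img_dir = img_dir
cfg.data.train.ann_dir = ann_dir
cfg.data.train.pipeline = cfg.train_pipeline
cfg.data.train.split = 'splits/train.txt'

cfg.data.val.type = cfg.dataset_type
cfg.data.val.data_root = cfg.data_root
cfg.data.val.img_dir = img_dir
cfg.data.val.ann_dir = ann_dir
cfg.data.val.pipeline = cfg.test_pipeline
cfg.data.val.split = 'splits/val.txt'

cfg.data.test.type = cfg.dataset_type
cfg.data.test.data_root = cfg.data_root
cfg.data.test.img_dir = img_dir
cfg.data.test.ann_dir = ann_dir
cfg.data.test.pipeline = cfg.test_pipeline
cfg.data.test.split = 'splits/val.txt'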
# Fine-tune from the PSPNet checkpoint pre-trained on Cityscapes
cfg.load_from = 'checkpoints/pspnet_r50-d8_512x1024_40k_cityscapes_20200605_003338-2966598c.pth'
# Set up working dir to save files and logs.
cfg.work_dir = './work_dirs/tutorial'
cfg.total_iters = 200
cfg.log_config.interval = 10
cfg.evaluation.interval = 200
cfg.checkpoint_config.interval = 200
# Set seed to facilitate reproducing the result
cfg.seed = 0
set_random_seed(0, deterministic=False)
cfg.gpu_ids = range(1)
# Let's have a look at the final config used for training
print(f'Config:\n{cfg.pretty_text}')
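The cat_max_ratio=0.75 argument of RandomCrop is easy to overlook. As I understand MMSegmentation's implementation, it re-samples the crop window (up to 10 times) until the crop contains more than one class and the most frequent class covers less than 75% of the labeled pixels, ignoring pixels marked 255; if every try fails, the last sampled window is used anyway. The plain-Python sketch below only picks the crop origin; crop, dominant_ratio and random_crop are hypothetical helpers, and the real transform works on NumPy arrays and crops the image and mask together.

```python
import random
from collections import Counter
from typing import List, Optional, Tuple

IGNORE_INDEX = 255  # label value excluded from the class statistics


def crop(mask: List[List[int]], top: int, left: int, h: int, w: int) -> List[List[int]]:
    """Return the h x w window of mask whose top-left corner is (top, left)."""
    return [row[left:left + w] for row in mask[top:top + h]]


def dominant_ratio(patch: List[List[int]]) -> Tuple[int, float]:
    """Return (number of classes, share of the most frequent class), ignoring 255."""
    counts = Counter(v for row in patch for v in row if v != IGNORE_INDEX)
    total = sum(counts.values())
    if total == 0:
        return 0, 1.0
    return len(counts), max(counts.values()) / total


def random_crop(mask: List[List[int]], h: int, w: int, cat_max_ratio: float = 0.75,
                tries: int = 10, rng: Optional[random.Random] = None) -> Tuple[int, int]:
    """Pick a crop origin, re-sampling while one class dominates the crop."""
    if rng is None:
        rng = random.Random()
    margin_h, margin_w = max(len(mask) - h, 0), max(len(mask[0]) - w, 0)

    def sample() -> Tuple[int, int]:
        return rng.randint(0, margin_h), rng.randint(0, margin_w)

    top, left = sample()
    for _ in range(tries):
        n_classes, ratio = dominant_ratio(crop(mask, top, left, h, w))
        if n_classes > 1 and ratio < cat_max_ratio:
            break  # acceptable crop: no single class dominates
        top, left = sample()  # otherwise draw a new window
    return top, left  # after `tries` failures the last draw is used as-is


rng = random.Random(0)
mask = [[0] * 6 + [1] * 2 for _ in range(8)]  # class 0 covers 75% of the mask
top, left = random_crop(mask, 4, 4, rng=rng)
print(top, left, dominant_ratio(crop(mask, top, left, 4, 4)))
```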
3. Train and evaluate
from mmseg.datasets import build_dataset
from mmseg.models import build_segmentor
from mmseg.apis import train_segmentor
# Build the dataset
datasets = [build_dataset(cfg.data.train)]
# Build the segmentor
model = build_segmentor(
    cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
# Add an attribute for visualization convenience
model.CLASSES = datasets[0].CLASSES
# Create work_dir
mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))
train_segmentor(model, datasets, cfg, distributed=False, validate=True,
                meta=dict())
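With validate=True, the model is evaluated on the validation split every cfg.evaluation.interval iterations, and the PSPNet Cityscapes configs use mean intersection-over-union (mIoU) as the metric. To make that number less of a black box, here is a plain-Python sketch of per-class IoU and its mean, skipping pixels whose ground truth is 255 and classes that appear in neither prediction nor ground truth. It mirrors the definition, not MMSegmentation's vectorized implementation; mean_iou is a hypothetical helper.

```python
import math
from typing import Sequence

IGNORE_INDEX = 255  # pixels with this ground-truth label are not scored


def mean_iou(pred: Sequence[Sequence[int]], gt: Sequence[Sequence[int]],
             num_classes: int) -> float:
    """Mean IoU over classes that occur in the prediction or the ground truth."""
    inter = [0] * num_classes
    union = [0] * num_classes
    for p_row, g_row in zip(pred, gt):
        for p, g in zip(p_row, g_row):
            if g == IGNORE_INDEX:
                continue
            union[g] += 1          # pixel belongs to the gt class region
            if p == g:
                inter[g] += 1      # correct pixel: counted once in the union
            elif 0 <= p < num_classes:
                union[p] += 1      # wrong pixel also enlarges the predicted class region
    ious = [i / u for i, u in zip(inter, union) if u > 0]
    return sum(ious) / len(ious) if ious else math.nan


gt = [[0, 0, 1, 1], [0, 0, 1, 255]]
pred = [[0, 0, 1, 0], [0, 1, 1, 1]]
print(mean_iou(pred, gt, num_classes=2))  # class 0: 3/5, class 1: 2/4, mean 0.55
```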
Inference with the trained model
img = mmcv.imread('data/images/6000124.jpg')
model.cfg = cfg
result = inference_segmentor(model, img)
plt.figure(figsize=(8, 6))
show_result_pyplot(model, img, result, palette)
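show_result_pyplot draws the prediction by painting each pixel with its class color from palette and blending that color map with the original image. A plain-Python sketch of the blending step, assuming an even 50/50 blend and ignoring the RGB/BGR channel-order handling the real code performs; overlay is a hypothetical helper name, and the palette is the one defined earlier in this article.

```python
from typing import List, Sequence, Tuple

Pixel = Tuple[int, int, int]

palette = [[128, 128, 128], [129, 127, 38], [120, 69, 125], [53, 125, 34],
           [0, 11, 123], [118, 20, 12], [122, 81, 25], [241, 134, 51]]


def overlay(image: List[List[Pixel]], seg: List[List[int]],
            palette: Sequence[Sequence[int]], opacity: float = 0.5) -> List[List[Pixel]]:
    """Blend each pixel with the palette color of its predicted class."""
    out = []
    for img_row, seg_row in zip(image, seg):
        out.append([
            # weighted average per channel, truncated to int like a uint8 cast
            tuple(int(c * (1 - opacity) + k * opacity) for c, k in zip(px, palette[label]))
            for px, label in zip(img_row, seg_row)
        ])
    return out


image = [[(100, 100, 100), (100, 100, 100)]]
seg = [[0, 7]]  # class 0 ('sky') and class 7 ('fg obj')
print(overlay(image, seg, palette))  # [[(114, 114, 114), (170, 117, 75)]]
```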
• Link to tutorial colab notebook.