iSegFormer: Interactive Segmentation via Transformers with Application to 3D Knee MR Images
1) Motivation, Objectives and Related Works:
Motivation:
Interactive image segmentation has been widely applied to obtain high-quality voxel-level labels for medical images. The recent success of Transformers on various vision tasks has paved the road for developing Transformer-based interactive image segmentation approaches. However, such approaches remain largely unexplored and, in particular, have not been developed for 3D medical image segmentation.
To fill this research gap, we investigate Transformer-based interactive image segmentation and its application to 3D medical images. This is a nontrivial task due to two main challenges:
Limited memory for computationally inefficient Transformers.
Limited labels for 3D medical images.
Objectives:
To tackle the first challenge, we propose iSegFormer, a memory-efficient Transformer that combines a Swin Transformer with a lightweight multilayer perceptron (MLP) decoder.
To address the second challenge, we pretrain iSegFormer on a large amount of unlabeled data and then finetune it with only a limited number of segmented 2D slices.
We further propagate the 2D segmentations obtained by iSegFormer to unsegmented slices in 3D images using a pre-existing segmentation propagation model pretrained on videos. We evaluate iSegFormer on the public OAI-ZIB dataset for interactive knee cartilage segmentation.
Evaluation results show that iSegFormer outperforms its convolutional neural network (CNN) counterparts on interactive 2D knee cartilage segmentation, with competitive computational efficiency. When propagating the 2D interactive segmentations of 5 slices to the other unsegmented slices within the same 3D volume, we achieve a Dice score of 82.2% for 3D knee cartilage segmentation.
Related Works:
Interactive Medical Image Segmentation:
MIDeepSeg
MONAI
Vision Transformers:
ViT
SegFormer
Swin Transformer
Swin-Unet
Transfuse
UTNet
Interactive Video Object Segmentation (iVOS)
MiVOS
STCN
Contribution:
In this work, we aim to fill this research gap by investigating Transformer-based interactive image segmentation and its application to 3D medical images. This is a challenging task due to: 1) limited memory for computationally inefficient Transformers and 2) limited labels for 3D medical images. To tackle the first challenge, we propose iSegFormer, a memory-efficient Transformer that combines a Swin Transformer with a lightweight multilayer perceptron (MLP) decoder. With the efficient Swin Transformer blocks for hierarchical self-attention and the simple MLP decoder for aggregating both local and global attention, iSegFormer learns powerful representations while achieving high computational efficiency. To address the second challenge, we pretrain iSegFormer on a large amount of unlabeled data and then finetune it with only a limited number of segmented 2D slices. To extend iSegFormer to 3D interactive segmentation, we further combine it with a segmentation propagation module that propagates segmented 2D slices to unlabeled ones in the same image volume. When the propagated segmentations are not as desired, the user can refine them and start a new round of propagation if necessary. Specifically, we combine iSegFormer with STCN [14], which achieves state-of-the-art results on interactive video object segmentation. We use a pretrained STCN model without finetuning it on medical images.
Contributions:
We propose iSegFormer, a memory-efficient Transformer that combines a Swin Transformer with a lightweight MLP decoder, for interactive image segmentation.
iSegFormer outperforms its CNN counterparts for interactive 2D knee cartilage segmentation on the OAI-ZIB dataset, with computational efficiency comparable to CNNs. To the best of our knowledge, iSegFormer is the first Transformer-based approach for interactive medical image segmentation.
We further show that iSegFormer can be easily extended to interactive 3D medical image segmentation by combining it with a pre-existing segmentation propagation model trained on videos.
2) Methodology:
Network Architecture:
The proposed iSegFormer is a Transformer-based interactive 2D image segmentation approach that combines a Swin Transformer with a lightweight MLP decoder. As shown in Fig. 1, it can be easily extended to a 3D interactive segmentation approach by combining it with a segmentation propagation module (i.e., STCN [14]). This 3D interactive segmentation approach consists of an iSegFormer that obtains a 2D segmentation from user interactions and a segmentation propagation module that propagates segmented slices to unsegmented ones, resulting in a 3D segmentation. If the propagated segmentation results are not satisfactory, the user can refine them with further interactions and start a new round of propagation if necessary.
The network architecture of iSegFormer is shown in Fig. 1 (bottom). It uses a Swin Transformer as the segmentation backbone and two lightweight MLP layers as the decoder to produce the segmentation. Specifically, there are four Swin Transformer blocks for hierarchical self-attention and a simple MLP decoder that first aggregates both local and global attention and then produces the segmentation as the output. The input of iSegFormer is the concatenation of the image and a clicks encoding map (introduced in the next section). Since we want to make use of existing Swin Transformer models pretrained on ImageNet-21k [29], we do not change the number of input channels of the Swin Transformer blocks. To achieve this, we use element-wise addition instead of concatenation for merging image features and clicks encoding features after the patch embedding layers, which are linear projection layers that produce patch embeddings for self-attention. Note that there are two separate patch embedding layers in iSegFormer (one for the input image and the other for the clicks encoding map), though Fig. 1 only shows one for brevity. The clicks embedding is essential for extending a segmentation model to an interactive segmentation model, as it transforms the user's interactions from clicks into feature maps that can be fed into the network. For medical images, which typically have only a single gray channel, we simply replicate the gray channel three times to form an RGB input.
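As a rough illustration of this design, the PyTorch sketch below shows two separate patch embedding layers merged by element-wise addition, followed by a hierarchical encoder and a lightweight MLP decoder over multi-scale features. The convolutional stages are only stand-ins for the Swin Transformer blocks, and all channel sizes, module names, and shapes are illustrative assumptions rather than the authors' implementation.

import torch
import torch.nn as nn
import torch.nn.functional as F

class MLPDecoder(nn.Module):
    """Lightweight decoder: project each feature scale to a common dimension,
    upsample to the finest resolution, concatenate, fuse, and predict a mask."""
    def __init__(self, in_dims=(128, 256, 512, 1024), embed_dim=256, num_classes=1):
        super().__init__()
        self.proj = nn.ModuleList(nn.Linear(d, embed_dim) for d in in_dims)
        self.fuse = nn.Sequential(
            nn.Conv2d(embed_dim * len(in_dims), embed_dim, kernel_size=1),
            nn.BatchNorm2d(embed_dim),
            nn.ReLU(inplace=True),
        )
        self.head = nn.Conv2d(embed_dim, num_classes, kernel_size=1)

    def forward(self, feats):
        target_size = feats[0].shape[2:]
        outs = []
        for f, proj in zip(feats, self.proj):
            b, c, h, w = f.shape
            f = proj(f.flatten(2).transpose(1, 2))          # (B, HW, C) -> (B, HW, E)
            f = f.transpose(1, 2).reshape(b, -1, h, w)      # back to (B, E, H, W)
            outs.append(F.interpolate(f, size=target_size, mode="bilinear",
                                      align_corners=False))
        return self.head(self.fuse(torch.cat(outs, dim=1)))

class ISegFormerSketch(nn.Module):
    """Illustrative stand-in for iSegFormer: separate patch embeddings for the
    image and the clicks map, merged by element-wise addition, followed by a
    hierarchical encoder (plain convolutions replace the Swin blocks here)
    and the MLP decoder above."""
    def __init__(self, embed_dim=128):
        super().__init__()
        self.img_patch_embed = nn.Conv2d(3, embed_dim, kernel_size=4, stride=4)
        self.clk_patch_embed = nn.Conv2d(2, embed_dim, kernel_size=4, stride=4)
        dims = [embed_dim, embed_dim * 2, embed_dim * 4, embed_dim * 8]
        self.stages = nn.ModuleList(
            nn.Conv2d(dims[i - 1] if i > 0 else embed_dim, dims[i],
                      kernel_size=3, stride=1 if i == 0 else 2, padding=1)
            for i in range(4)
        )
        self.decoder = MLPDecoder(in_dims=dims)

    def forward(self, image, clicks_map):
        x = self.img_patch_embed(image) + self.clk_patch_embed(clicks_map)
        feats = []
        for stage in self.stages:
            x = stage(x)
            feats.append(x)
        logits = self.decoder(feats)
        return F.interpolate(logits, size=image.shape[2:], mode="bilinear",
                             align_corners=False)

# Example: a gray MR slice replicated to 3 channels plus a 2-channel clicks map.
model = ISegFormerSketch()
mask_logits = model(torch.randn(1, 3, 320, 480), torch.zeros(1, 2, 320, 480))

In the actual model, the four stages would be Swin Transformer blocks loaded from an ImageNet-21k checkpoint rather than the placeholder convolutions shown here.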
Clicks Encoding and Simulation:
We use clicks as the interaction mode due to their simplicity. Clicks can be either positive or negative: positive clicks indicate that particular points should be included in the segmentation, and negative clicks indicate that particular points should not be included in the segmentation.
Click Encoding:
We encode positive and negative clicks from coordinates to a 2-channel feature map with the same spatial size as the input image, following the strategy used in [30]. The clicks encoding map is fed into the network along with the input image, as shown in Fig. 1 (bottom). During training and inference, we automatically simulate clicks based on the ground truth and the current predicted segmentation for fast training and evaluation. A positive click is generated in the center of the false negative region of the predicted segmentation, and a negative click is generated in the center of the false positive region.
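The sketch below illustrates one plausible realization of this encoding and click-simulation scheme; the disk radius, the use of a Euclidean distance transform to locate the most interior point of the larger error region, and all function names are assumptions for illustration, not the exact procedure of [30] or [6].

import numpy as np
from scipy import ndimage

def encode_clicks(pos_clicks, neg_clicks, height, width, radius=5):
    """Rasterize click coordinates into a 2-channel map (positive, negative)
    with the same spatial size as the input image; each click is drawn as a
    small disk (the disk radius is an assumption)."""
    enc = np.zeros((2, height, width), dtype=np.float32)
    yy, xx = np.mgrid[:height, :width]
    for channel, clicks in enumerate((pos_clicks, neg_clicks)):
        for y, x in clicks:
            enc[channel][(yy - y) ** 2 + (xx - x) ** 2 <= radius ** 2] = 1.0
    return enc

def simulate_next_click(pred_mask, gt_mask):
    """Place the next simulated click at the most interior point of the larger
    error region: positive for false negatives, negative for false positives."""
    false_neg = np.logical_and(gt_mask, np.logical_not(pred_mask))
    false_pos = np.logical_and(pred_mask, np.logical_not(gt_mask))
    if not false_neg.any() and not false_pos.any():
        return None                                   # prediction already matches
    is_positive = false_neg.sum() >= false_pos.sum()  # correct the larger error first
    error = false_neg if is_positive else false_pos
    # the Euclidean distance transform peaks at the point farthest from the
    # region boundary, a simple proxy for the "center" of the error region
    dist = ndimage.distance_transform_edt(error)
    y, x = np.unravel_index(np.argmax(dist), dist.shape)
    return (int(y), int(x)), is_positive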
Random perturbations
During training, we add random perturbations to the simulated clicks to increase robustness, as adopted in [6].
During inference, we remove the randomness for deterministic evaluation. Note that click simulation requires the ground truth, and simulated clicks may differ from clicks generated by a human user. Therefore, we present some qualitative results obtained by human evaluation in the supplementary materials.
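As a minimal sketch of such training-time jitter (the uniform offset model and the max_shift value are assumptions, not the exact scheme of [6]):

import numpy as np

def perturb_click(click, height, width, max_shift=3, training=True):
    """Jitter a simulated click by a few pixels during training only; at
    inference the click is returned unchanged for deterministic evaluation."""
    if not training:
        return click
    y, x = click
    dy, dx = np.random.randint(-max_shift, max_shift + 1, size=2)
    return (int(np.clip(y + dy, 0, height - 1)),
            int(np.clip(x + dx, 0, width - 1)))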
Training and Inference Details:
For a fair comparison with RITM [6], we adopt most of the hyper-parameters used in RITM for training and inference. The iSegFormer models are trained on a class-agnostic binary segmentation task with the normalized focal loss (NFL) [31]. We randomly crop images to a size of 320 × 480 for training and adopt the same data augmentation techniques as RITM [6], including random scaling and resizing. We implement iSegFormer in PyTorch with the Adam optimizer. All experiments are conducted on an NVIDIA A6000 GPU. All models are trained for 55 epochs with a batch size of 32 (except the SegFormer and HRFormer models in Fig. 3). For more details, please refer to our codebase.
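The following sketch condenses these training settings into code. The normalized focal loss here is a simplified approximation (focal weights rescaled to unit mean) rather than the exact formulation of [31], and the learning rate and the (image, clicks_map, target) data-loader interface are assumptions.

import torch
import torch.nn.functional as F

def normalized_focal_loss(logits, targets, gamma=2.0, eps=1e-8):
    """Simplified approximation of a normalized focal loss: the focal weights
    (1 - p_t)^gamma are rescaled to unit mean so the loss stays on the same
    scale as plain cross-entropy (see [31] for the exact formulation)."""
    prob = torch.sigmoid(logits)
    p_t = prob * targets + (1.0 - prob) * (1.0 - targets)
    weight = (1.0 - p_t) ** gamma
    weight = weight / (weight.mean() + eps)
    bce = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
    return (weight * bce).mean()

def train(model, train_loader, epochs=55, lr=5e-4, device="cuda"):
    """Class-agnostic binary segmentation training; batches hold 320 x 480
    crops with a batch size of 32 (set in the data loader)."""
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    model.to(device).train()
    for _ in range(epochs):
        for image, clicks_map, target in train_loader:
            image = image.to(device)
            clicks_map = clicks_map.to(device)
            target = target.to(device).float()      # binary mask in {0, 1}
            loss = normalized_focal_loss(model(image, clicks_map), target)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()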
Extending to Interactive 3D Image Segmentation:
iSegFormer can be easily extended to a 3D interactive segmentation approach by combining it with a segmentation propagation module (i.e., STCN [14]). Since this is not our main contribution, we introduce the details in the appendix.
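For context, a high-level sketch of how the 2D interactive model and a propagation module could be composed into an interactive 3D workflow is given below; iseg_model, propagation_model, review_fn, and their interfaces are hypothetical placeholders, not the actual STCN or iSegFormer APIs.

def interactive_3d_segmentation(volume, iseg_model, propagation_model,
                                reference_slices=(40, 80, 120),
                                review_fn=None, max_rounds=3):
    """Hypothetical workflow: (1) interactively segment a few reference slices
    in 2D, (2) propagate those masks through the volume with a video-style
    propagation module, (3) let the user flag bad slices, refine them, and
    start another round of propagation if necessary."""
    segmented = {z: iseg_model.interactive_segment(volume[z])   # placeholder API
                 for z in reference_slices}
    volume_mask = None
    for _ in range(max_rounds):
        volume_mask = propagation_model.propagate(volume, segmented)  # placeholder API
        bad_slices = review_fn(volume_mask) if review_fn else []      # user feedback
        if not bad_slices:
            break
        for z in bad_slices:
            segmented[z] = iseg_model.interactive_segment(volume[z])
    return volume_mask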
3) Experimental Results:
Dataset:
The OAI-ZIB [32] dataset consists of 507 3D MR images with segmentations for the femur, tibia, tibial cartilage, and femoral cartilage. In this work, we only consider cartilage segmentation. Each 3D image contains 160 slices of size 384×384. We randomly split the dataset into 407 images for training, 50 for validation, and 50 for testing. Since we are interested in the setting where segmentations for the 3D images are limited, we only use three segmented slices of each image in the training and validation sets for developing iSegFormer, resulting in 1,221 training slices, 150 validation slices, and 150 testing slices. The three slices are selected at a fixed interval (i.e., slices 40, 80, and 120). We also use 9 other public datasets in our cross-domain evaluation experiments; please refer to the Cross-Domain Evaluation section below for details.
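To make the data protocol concrete, the snippet below sketches the volume split and slice selection; the random seed and the exact split procedure are assumptions for illustration.

import numpy as np

# Volume-level split (407 / 50 / 50) and fixed-interval slice selection; the
# seed and split procedure are illustrative assumptions.
rng = np.random.default_rng(seed=0)
order = rng.permutation(507)
train_ids, val_ids, test_ids = order[:407], order[407:457], order[457:]
slice_indices = (40, 80, 120)          # three annotated slices per volume
# 407 x 3 = 1,221 training slices; 50 x 3 = 150 validation and 150 testing slices.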
Metrics:
We use Number of Clicks (NoC) to measure the number of clicks required to achieve a predefined Intersection over Union (IoU) between predicted and ground truth segmentations. For example, NoC@85% measures the number of clicks required to obtain 85% IoU. We use an automatic evaluation procedure to simulate clicks during inference and report the quantitative results, following the practices used in [6]. We also perform a human evaluation for a qualitative study. For measuring the 3D segmentation results, we use the Dice Similarity Coefficient (DSC), sensitivity (SEN), and the positive predictive value (PPV).
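To make the evaluation protocol concrete, here is a minimal sketch of IoU, Dice, and a NoC@k loop; the cap on the number of clicks, the 0.5 threshold on the model output, and the model.predict interface are assumptions, and the click-simulation helper from the earlier sketch is reused.

import numpy as np

def iou(pred, gt):
    union = np.logical_or(pred, gt).sum()
    return np.logical_and(pred, gt).sum() / union if union > 0 else 1.0

def dice(pred, gt):
    total = pred.sum() + gt.sum()
    return 2.0 * np.logical_and(pred, gt).sum() / total if total > 0 else 1.0

def noc_at_k(model, image, gt, k=0.85, max_clicks=20):
    """Number of simulated clicks needed to reach IoU >= k, capped at
    max_clicks; `model.predict` is a placeholder interface and
    `simulate_next_click` is the helper from the earlier sketch."""
    clicks, pred = [], np.zeros_like(gt, dtype=bool)
    for n in range(1, max_clicks + 1):
        click = simulate_next_click(pred, gt)
        if click is None:                       # already a perfect prediction
            return n - 1
        clicks.append(click)
        pred = model.predict(image, clicks) > 0.5
        if iou(pred, gt) >= k:
            return n
    return max_clicks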
Experimental Results:
We compare iSegFormer with RITM [6], the state-of-the-art CNN-based approach, on interactive 2D femoral and tibial cartilage segmentation. Both RITM and iSegFormer are implemented with two segmentation backbones. For RITM, the backbones are UNet [33] and HRNet32 [34]. For iSegFormer, the backbones are the Swin Transformer base and large models. For a fair comparison, all models are trained on the OAI-ZIB training set under the same training settings.
Tab. 1 reports the comparison results for tibial and femoral cartilage segmentation on the 150 slices of the OAI-ZIB testing set. The results show that iSegFormer outperforms its CNN counterparts with very competitive speed and GPU memory consumption, demonstrating the effectiveness and efficiency of iSegFormer for interactive segmentation.
Ablations:
Comparison with Other Transformer Backbones:
To further demonstrate the efficiency of iSegFormer, we also implemented it using two recently proposed Transformer backbones for segmentation: HRFormer [11] and SegFormer [12]. As shown in Tab. 3, our proposed Swin Transformer-based segmentation backbone is much more memory-efficient than the other Transformer-based backbones.
Cross-Domain Evaluation:
We have shown that iSegFormer outperforms CNNs when trained with only 1,221 labeled 2D slices (labeling such a dataset amounts to labeling roughly eight 3D images of 160 slices each). However, in many applications no segmented slices are available, for example, when studying new medical image datasets. Therefore, it is important for trained interactive segmentation models to generalize to unseen objects or objects in different domains. In this cross-domain evaluation, we train iSegFormer and RITM models on the COCO+LVIS [35] dataset, which contains millions of high-quality labels for natural images. We then test the models on 5 natural image datasets (GrabCut, Berkeley, DAVIS, PascalVOC, and SBD) and 3 medical image datasets (ssTEM, BraTS, and OAI-ZIB).
The results are shown in Fig. 2. Although there is still a significant performance gap between in-domain and out-of-domain evaluations, both CNN and Transformer models generalize reasonably well to medical image datasets. Note that our models do not outperform their CNN counterpart in this experiment. We attribute this to the fact that HRNet is the best-performing model in RITM, with well-tuned hyperparameters, whereas we adopt most of those hyperparameters for our Transformers and spend little effort tuning them.
Results on Segmentation Propagation:
Given the interactively segmented 2D slices obtained by iSegFormer, we are now interested in 3D segmentation via a segmentation propagation model released with STCN [14]. The results in Tab. 2 show that the propagation results improve as more segmented slices are provided. With only 5 segmented slices, we achieve a Dice score of 82.2% for femoral cartilage segmentation. This is a very promising result considering that the segmentation propagation model was not trained on medical images. We hope this preliminary experiment will attract more research effort toward transferring knowledge from the video domain to the medical imaging domain.
Ablation Study:
We demonstrated above that iSegFormer performs better than its CNN counterparts. Other than the architecture difference, the biggest difference comes from the pretraining settings: our iSegFormer models are pretrained on ImageNet-21k, while the CNN models have two pretraining steps, first pretraining on ImageNet-21k and then finetuning on the COCO+LVIS dataset. In this study, we adopt different pretraining settings for a fairer comparison between Transformer and CNN models. Note that the pretraining task can be either classification (Cls) or interactive segmentation (iSeg). As shown in Tab. 4, pretraining on ImageNet-21k is essential for the success of iSegFormer. More details are included in the supplementary materials.