[SimpleClick] SimpleClick: Interactive Image Segmentation with Simple Vision Transformers
Paper: https://arxiv.org/pdf/2210.11006v2.pdf
Code:
1) Motivation, Objectives and Related Works:
Motivation:
Click-based interactive image segmentation aims at extracting objects with limited user clicking.
A hierarchical backbone is the de-facto architecture for current methods.
The plain, non-hierarchical Vision Transformer (ViT) has not yet been explored for interactive segmentation.
Objectives: SimpleClick
Propose the first plain backbone method for interactive segmentation.
With the plain backbone pretrained as a masked autoencoder (MAE), SimpleClick achieves state-of-the-art performance.
Remarkably, our method achieves 4.15 NoC@90 on SBD, improving on the previous best result by 21.8%.
Related Works:
Click-based approaches:
Mainly lie in two orthogonal directions:
The development of more effective backbone networks. ==> Different hierarchical backbones, including both ConvNets [29,42] and ViTs [10,32], have been developed for interactive segmentation.
The exploration of more elaborate refinement modules built upon the backbone. ==> Various refinement modules, including local refinement [10, 29] and click imitation [33], have been proposed to further boost segmentation performance.
In this work, we delve into the former direction and focus on exploring a plain backbone for interactive segmentation.
Interactive image segmentation
Various applications:
Video understanding [5, 48]
self-driving [7]
medical imaging [31, 40].
Early works:
[6, 16, 18, 39] tackle this problem using graphs defined over image pixels.
However, these methods only focus on low-level image features, and therefore tend to have difficulty with complex objects.
ConvNets:
[10, 29, 42, 46, 47] have evolved as the dominant architecture for high-quality interactive segmentation.
Different interaction types:
bounding boxes [46]
polygons [1]
clicks [42] (the most common interaction type due to their simplicity and well-established training and evaluation protocols)
scribbles [44]
combinations [50].
Exploring improvements to the backbone.
Xu et al. [47] first proposed a click simulation strategy that has been adopted by follow-up work [10, 33, 42].
DEXTR [35] extracts a target object by specifying its four extreme points (left-most, right-most, top, bottom pixels).
FCA-Net [30] demonstrates the critical role of the first click for better segmentation.
FocalClick [10] uses SegFormer [45] as the backbone network and achieves state-of-the-art segmentation results with high computational efficiency.
iSegFormer [32] uses a Swin Transformer [34] as the backbone network for interactive segmentation on medical images.
Exploring elaborate refinement modules built upon the backbone.
FocalClick [10] and FocusCut [29] propose similar local refinement modules for high-quality segmentation.
PseudoClick [33] proposes a click-imitation mechanism by estimating the next click to further reduce human annotation cost.
Vision Transformers for Non-Interactive Segmentation
ViT-based approaches [17, 24, 43, 45, 49] have shown competitive performance on segmentation tasks compared to ConvNets.
The original ViT [13] is a non-hierarchical architecture that only maintains single-scale feature maps throughout.
SETR [52] and Segmenter [43] use the original ViT as the encoder for semantic segmentation.
To allow for more efficient segmentation, the Swin Transformer [34] reintroduces a computational hierarchy into the original ViT architecture using shifted window attention, leading to a highly efficient hierarchical ViT backbone.
SegFormer [45] designs hierarchical feature representations based on the original ViT using overlapped patch merging, combined with a lightweight MLP decoder for efficient segmentation.
HRViT [17] integrates a high-resolution multi-branch architecture with ViTs to learn multi-scale representations.
Recently, the original ViT has been reintroduced as a competitive backbone for semantic segmentation [8] and object detection [25], with the aid of MAE [21] pretraining and window attention.
Hierarchical Backbone
Predominant architecture for current interactive segmentation methods:
ResNet [22],
hierarchical ViTs such as the Swin Transformer [34].
To increase the receptive field size:
ConvNets have to progressively downsample feature maps to capture more global contextual information ==> require a feature pyramid network such as FPN [27] to aggregate multi-scale representations for high-quality segmentation.
This reasoning no longer applies for a plain ViT, in which global information can be captured from the first self-attention block.
Because all feature maps in the ViT are of the same resolution, the motivation for an FPN-like feature pyramid no longer holds either.
A plain ViT can serve as a strong backbone for object detection [25].
This finding indicates that a general-purpose ViT backbone may be suitable for other tasks as well, which decouples pretraining from finetuning and transfers the benefits of readily available pretrained ViT models (e.g. MAE [21]) to these tasks.
However, although this design is simple and has been proven effective, it has not yet been explored in interactive segmentation.
Contribution:
SimpleClick, the first plain-backbone method for interactive image segmentation.
SimpleClick achieves state-of-the-art performance on natural images and shows strong generalizability on medical images.
SimpleClick meets the computational efficiency requirement for a practical annotation tool, highlighting its readiness for real-world applications.
2) Methodology:
SimpleClick, the first plain-backbone method for interactive segmentation:
The core of SimpleClick is a plain ViT backbone that maintains single-scale representations throughout.
Only use the last feature map from the plain backbone to build a simple feature pyramid for segmentation, largely decoupling the general-purpose backbone from the segmentation-specific modules.
To make SimpleClick more efficient, we use a lightweight MLP decoder to transform the simple feature pyramid into a segmentation.
Adaptation of Plain-ViT Backbone
Use a plain ViT [13] as our backbone network, which only maintains single-scale feature maps throughout.
The patch embedding layer divides the input image into non-overlapping fixed-size patches (e.g. 16×16 for ViT-B); each patch is flattened and linearly projected to a fixed-length vector (e.g. 768 dimensions for ViT-B).
The resulting sequence of vectors is fed into a stack of Transformer blocks (e.g. 12 for ViT-B) for self-attention.
Implement SimpleClick with three backbones: ViT-B, ViT-L, and ViT-H (Tab. 1 shows the number of parameters for the three backbones).
The three backbones were pretrained on ImageNet-1k as MAEs [21].
We adapt the pretrained backbones to higher-resolution inputs during finetuning using non-shifting window attention aided by a few global self-attention blocks (e.g. 2 for ViT-B), as introduced in ViTDet [25].
Since the last feature map is subject to all the attention blocks, it should have the strongest representation. Therefore, we only use the last feature map to build a simple multi-scale feature pyramid.
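To make the patch-embedding step concrete, below is a minimal PyTorch sketch assuming the ViT-B configuration (16×16 patches, 768-dimensional embeddings); the class and variable names are illustrative and not taken from the SimpleClick codebase.

```python
import torch
import torch.nn as nn

class PatchEmbed(nn.Module):
    """Split an image into non-overlapping 16x16 patches and linearly project each patch."""
    def __init__(self, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        # A strided convolution is equivalent to flattening each patch and applying a linear layer.
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):                     # x: (B, 3, H, W)
        x = self.proj(x)                      # (B, 768, H/16, W/16)
        x = x.flatten(2).transpose(1, 2)      # (B, N, 768) with N = (H/16) * (W/16) patch tokens
        return x                              # token sequence fed to the stack of Transformer blocks

# Example: a 448x448 finetuning input yields 28x28 = 784 tokens of dimension 768.
tokens = PatchEmbed()(torch.randn(1, 3, 448, 448))
print(tokens.shape)  # torch.Size([1, 784, 768])
```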
Simple Feature Pyramid
A feature pyramid is commonly produced by an FPN [27] to combine features from different stages.
For the plain backbone, a feature pyramid can be generated in a much simpler way: by a set of parallel convolutional or deconvolutional layers using only the last feature map of the backbone.
As shown in Fig. 2, given the input ViT feature map, a multi-scale feature map can be produced by four convolutions with different strides. Though this simple feature pyramid design was first demonstrated in ViTDet [25] for object detection, we show in this work that it is also effective for interactive segmentation (a minimal sketch follows below).
We also propose two additional variants (Fig. 6) as part of an ablation study (Sec. 4.3).
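The simple feature pyramid can be sketched as four parallel (de)convolutional branches applied to the single stride-16 ViT feature map, in the spirit of ViTDet [25]; the exact layer choices and channel widths below are assumptions for illustration.

```python
import torch
import torch.nn as nn

class SimpleFeaturePyramid(nn.Module):
    """Build {1/4, 1/8, 1/16, 1/32} feature maps from the last ViT feature map (stride 16)."""
    def __init__(self, dim=768):
        super().__init__()
        self.scale4 = nn.Sequential(                                   # stride 16 -> 4
            nn.ConvTranspose2d(dim, dim // 2, kernel_size=2, stride=2),
            nn.GELU(),
            nn.ConvTranspose2d(dim // 2, dim // 4, kernel_size=2, stride=2))
        self.scale8 = nn.ConvTranspose2d(dim, dim // 2, kernel_size=2, stride=2)  # stride 16 -> 8
        self.scale16 = nn.Identity()                                   # keep stride 16
        self.scale32 = nn.MaxPool2d(kernel_size=2, stride=2)           # stride 16 -> 32

    def forward(self, x):             # x: (B, 768, H/16, W/16), the last backbone feature map
        return [self.scale4(x), self.scale8(x), self.scale16(x), self.scale32(x)]

feats = SimpleFeaturePyramid()(torch.randn(1, 768, 28, 28))
print([f.shape[-1] for f in feats])  # [112, 56, 28, 14]
```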
All-MLP Segmentation Head
We implement a lightweight segmentation head using only MLP layers.
It takes in the simple feature pyramid and produces a segmentation probability map of scale 1/4, followed by an upsampling operation to recover the original resolution.
Note that this segmentation head avoids computationally demanding components and only accounts for up to 1% of the model parameters (Tab. 1). The key insight is that with a powerful pretrained backbone, a lightweight segmentation head is sufficient for interactive segmentation. The proposed all-MLP segmentation head works in three steps. First, each feature map from the simple feature pyramid goes through an MLP layer to transform it to an identical channel dimension (i.e. C2 in Fig. 2). Second, all feature maps are upsampled to the same resolution (i.e. 1/4 in Fig. 2) for concatenation. Third, the concatenated features are fused by another MLP layer to produce a single-channel feature map, followed by a sigmoid function to obtain a segmentation probability map, which is then transformed to a binary segmentation given a predefined threshold (i.e. 0.5).
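A minimal sketch of the three-step all-MLP head described above; the input channel widths and the common embedding dimension (256) are assumed values, and 1×1 convolutions stand in for the per-pixel MLP layers.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MLPSegHead(nn.Module):
    """All-MLP head: unify channels, upsample to 1/4 scale, concatenate, fuse to one channel."""
    def __init__(self, in_dims=(192, 384, 768, 768), embed_dim=256):
        super().__init__()
        # Step 1: per-level MLP (1x1 conv) mapping each pyramid level to a common channel dimension.
        self.unify = nn.ModuleList([nn.Conv2d(d, embed_dim, kernel_size=1) for d in in_dims])
        # Step 3: fusion MLP producing a single-channel map.
        self.fuse = nn.Conv2d(embed_dim * len(in_dims), 1, kernel_size=1)

    def forward(self, feats):                  # feature maps at strides 4, 8, 16, 32
        target = feats[0].shape[-2:]           # 1/4 resolution
        # Step 2: upsample every level to 1/4 resolution and concatenate.
        ups = [F.interpolate(m(f), size=target, mode='bilinear', align_corners=False)
               for m, f in zip(self.unify, feats)]
        prob = torch.sigmoid(self.fuse(torch.cat(ups, dim=1)))   # (B, 1, H/4, W/4)
        return prob                            # threshold at 0.5 for the binary segmentation

feats = [torch.randn(1, c, s, s) for c, s in zip((192, 384, 768, 768), (112, 56, 28, 14))]
print(MLPSegHead()(feats).shape)  # torch.Size([1, 1, 112, 112])
```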
Other Modules
The user clicks are encoded in a two-channel disk map, one channel for positive clicks and the other for negative clicks. The positive clicks should be placed on the foreground, while the negative clicks should be placed on the background. The previous segmentation and the two-channel click map are concatenated as a three-channel map for patch embedding. Two separate patch embedding layers operate on the image and the concatenated three-channel map, respectively. The two inputs are patchified, flattened, and projected to two vector sequences of the same dimension, followed by an element-wise addition before inputting them into the self-attention blocks.
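A sketch of the click encoding described above: clicks are rasterized into a two-channel disk map and concatenated with the previous segmentation into a three-channel auxiliary input; the disk radius is an assumed value.

```python
import torch

def clicks_to_disk_maps(pos_clicks, neg_clicks, height, width, radius=5):
    """Rasterize clicks into a two-channel disk map: channel 0 for positive, channel 1 for negative."""
    maps = torch.zeros(2, height, width)
    ys, xs = torch.meshgrid(torch.arange(height), torch.arange(width), indexing='ij')
    for channel, clicks in enumerate((pos_clicks, neg_clicks)):
        for (cy, cx) in clicks:
            disk = (ys - cy) ** 2 + (xs - cx) ** 2 <= radius ** 2
            maps[channel][disk] = 1.0
    return maps

# Concatenate with the previous segmentation to form the three-channel auxiliary input,
# which goes through its own patch embedding and is added element-wise to the image embedding.
prev_seg = torch.zeros(1, 448, 448)                     # previous mask (empty before the first click)
click_maps = clicks_to_disk_maps([(100, 200)], [], 448, 448)
aux_input = torch.cat([prev_seg, click_maps], dim=0)    # (3, 448, 448)
```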
3) Experimental Results:
Dataset:
We conducted experiments on 10 public datasets including 7 natural image datasets and 3 medical datasets. The details are as follows:
GrabCut [39]: 50 images (50 instances), each with clear foreground and background differences.
Berkeley [36]: 96 images (100 instances); this dataset shares a small portion of images with GrabCut.
DAVIS [38]: 50 videos; we only use the same 345 frames as used in [10, 29, 33, 42] for evaluation.
Pascal VOC [14]: 1449 images (3427 instances) in the validation set. We only test on the validation set.
SBD [20]: 8498 training images (20172 instances) and 2857 validation images (6671 instances). Following previous works [10, 29, 42], we train our model on the training set and evaluate on the validation set.
COCO [28]+LVIS [19] (C+L): COCO contains 118K training images (1.2M instances); LVIS shares the same images with COCO but has much higher segmentation quality. We combine the two datasets for training.
ssTEM [15]: two image stacks, each containing 20 medical images. We use the same stack that was used in [33].
BraTS [4]: 369 magnetic resonance image (MRI) volumes; we test on the same 369 slices used in [33].
OAIZIB [2]: 507 MRI volumes; we test on the same 150 slices (300 instances) as used in [32].
Metrics:
Following previous works [29, 41, 42], we automatically simulate user clicks by comparing the current segmentation with the gold standard. In this simulation, the next click is placed at the center of the region with the largest error. We use the Number of Clicks (NoC) as the evaluation metric, i.e. the number of clicks required to achieve a target Intersection over Union (IoU). We set two target IoUs: 85% and 90%, denoted NoC@85 and NoC@90 respectively. The maximum number of clicks for each instance is set to 20. We also use the average IoU given k clicks (mIoU@k) as an evaluation metric to measure the segmentation quality given a fixed number of clicks.
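A sketch of the click-simulation rule described above, assuming the next click is placed at the interior point of the largest error region (found with a distance transform); the actual protocol follows [29, 41, 42] and may differ in detail.

```python
import numpy as np
from scipy import ndimage

def iou(pred, gt):
    """IoU between two boolean masks."""
    union = np.logical_or(pred, gt).sum()
    return np.logical_and(pred, gt).sum() / max(union, 1)

def next_click(pred, gt):
    """Simulate the next click on the largest error region between prediction and gold standard."""
    false_neg = np.logical_and(gt, ~pred)            # missed foreground -> positive click
    false_pos = np.logical_and(pred, ~gt)            # spurious foreground -> negative click
    is_positive = false_neg.sum() >= false_pos.sum()
    error = false_neg if is_positive else false_pos
    labels, n = ndimage.label(error)
    if n == 0:
        return None                                  # prediction already matches the gold standard
    sizes = ndimage.sum(error, labels, index=range(1, n + 1))
    largest = labels == (np.argmax(sizes) + 1)
    dist = ndimage.distance_transform_edt(largest)   # point farthest from the region boundary
    y, x = np.unravel_index(np.argmax(dist), dist.shape)
    return (y, x), is_positive
```

NoC@90 then corresponds to the number of such simulated clicks (capped at 20) needed before iou(pred, gt) reaches 0.9.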
Implementation Details:
We implement our models using Python and PyTorch [37]. We implement three models based on three vanilla ViT models (i.e. ViT-B, ViT-L, and ViT-H). These backbone models are initialized with the MAE pretrained weights, and then are finetuned end-to-end with other modules. We train our models on either SBD or COCO+LVIS for 55 epochs; the initial learning rate is set to 5 × 10^-5 and decreases to 5 × 10^-6 after epoch 50. We set the batch size to 140 for ViT-Base, 72 for ViT-Large, and 32 for ViT-Huge to fit the models into GPU memory. All our models are trained on four NVIDIA RTX A6000 GPUs. We use the following data augmentation techniques: random resizing (scale range from 0.75 to 1.25), random flipping and rotation, random brightness contrast, and random cropping. Though the ViT backbone was pretrained on images of size 224×224, we finetune on 448×448 with non-shifting window attention for better performance. We optimize using Adam with β1 = 0.9, β2 = 0.999.
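The optimizer and learning-rate schedule above can be expressed as the following PyTorch sketch; the placeholder model and the empty training loop are illustrative only.

```python
import torch

model = torch.nn.Linear(8, 1)  # placeholder standing in for the full SimpleClick network
optimizer = torch.optim.Adam(model.parameters(), lr=5e-5, betas=(0.9, 0.999))
# Drop the learning rate from 5e-5 to 5e-6 after epoch 50 (55 epochs in total).
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[50], gamma=0.1)

for epoch in range(55):
    # ... one training epoch over SBD or COCO+LVIS with the augmentations listed above ...
    scheduler.step()
```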
Backbone Pretraining
Our backbone models are pretrained as MAEs [21] on ImageNet-1K [11]. In MAE pretraining, the ViT models reconstruct the randomly masked pixels of images while learning a universal representation. This simple self-supervised approach turns out to be an efficient and scalable way to pretrain ViT models [21]. In this work, we do not perform pretraining ourselves. Instead, we simply use the readily available pretrained MAE weights from [21].
End-to-end Finetuning
With the pretrained backbone, we finetune our model end-to-end on the interactive segmentation task. The finetuning pipeline can be briefly described as follows. First, we automatically simulate clicks based on the current segmentation and gold standard segmentation, without a human-in-the-loop providing the clicks. Specifically, we use a combination of random and iterative click simulation strategies, inspired by RITM [42]. The random click simulation strategy generates clicks in parallel, without considering the order of the clicks. The iterative click simulation strategy generates clicks iteratively, where the next click is placed on the erroneous region of a prediction obtained using the previous clicks. This strategy is more similar to human clicking behavior. Second, we incorporate the segmentation from the previous interaction as an additional input for the backbone, further improving the segmentation quality. This also allows our method to refine an existing segmentation, which is a desired feature for a practical annotation tool. We use the normalized focal loss [42] (NFL) to train all our models. Previous works [10, 42] show that NFL converges faster and achieves better performance than the widely used binary cross-entropy loss for interactive segmentation tasks. Similar training pipelines have been proposed by RITM [42] and its follow-up works [9, 10, 33].
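A sketch of the normalized focal loss (NFL) idea, where the focal weights are rescaled so their total matches that of plain binary cross entropy; gamma = 2 is an assumed default, and details may differ from the RITM [42] implementation.

```python
import torch

def normalized_focal_loss(logits, target, gamma=2.0, eps=1e-8):
    """Focal loss with per-pixel weights renormalized to sum to the number of pixels."""
    prob = torch.sigmoid(logits)
    pt = torch.where(target > 0.5, prob, 1.0 - prob)            # probability of the true class
    weight = (1.0 - pt) ** gamma                                 # focal down-weighting of easy pixels
    weight = weight * (weight.numel() / (weight.sum() + eps))   # normalization step
    return (-weight * torch.log(pt.clamp_min(eps))).mean()

loss = normalized_focal_loss(torch.randn(2, 1, 64, 64),
                             torch.randint(0, 2, (2, 1, 64, 64)).float())
```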
Inference
There are two inference modes: automatic evaluation and human evaluation. For automatic evaluation, clicks are automatically simulated based on the current segmentation and gold standard. For human evaluation, a human-in-the-loop provides all clicks based on their subjective evaluation of current segmentation results. We use automatic evaluation for quantitative analyses and human evaluation for a qualitative assessment of the interactive segmentation behavior.
Experimental Results:
Comparison with Previous Results
We show in Tab. 2 the comparisons with previous state-of-the-art results. Our models achieve the best performance on all five benchmarks. Remarkably, when trained on the SBD training set, our ViT-H model achieves 4.15 NoC@90 on the SBD validation set, outperforming the previous best score by 21.8%. Since the SBD validation set contains the largest number of instances (6671 instances) among the five benchmarks, this improvement is convincing. When trained on COCO+LVIS, our models also achieve state-of-the-art performance on all benchmarks. Fig. 7 shows several segmentation cases on DAVIS, including the worst case. Note that the DAVIS dataset requires high-quality segmentations because all its instances have a high-quality gold standard. Our models still achieve state-of-the-art results on DAVIS without using specific modules, such as a local refinement module [10], which is beneficial for high-quality segmentation. Fig. 3 shows that our method converges better than other methods with sufficient clicks, leading to fewer failure cases, as shown in Fig. 4. For these analyses, we only report results on SBD and Pascal VOC, which are the two largest datasets for evaluation.
Out-of-Domain Evaluation on Medical Images
We further evaluate the generalizability of our models on three medical image datasets: ssTEM [15], BraTS [3], and OAIZIB [2]. Tab. 3 reports the evaluation results on these three datasets. Fig. 5 shows the convergence analysis on BraTS and OAIZIB. Overall, our models generalize well to medical images. We also find that the models trained on larger datasets (i.e. C+L) generalize better than the models trained on smaller datasets (i.e. SBD).
Ablations:
In this section, we ablate the backbone finetuning and the feature pyramid design. Tab. 4 shows the ablation results. By default, we finetune the backbone along with the other modules. As an ablation, we freeze the backbone during finetuning, leading to significantly worse performance. This result is expected, considering that the ViT backbone accounts for most of the model parameters (Tab. 1). For the second ablation, we compare the default simple feature pyramid design with two variants depicted in Fig. 6 (i.e. (b) and (c)). First, we observe that the multi-scale representation matters for the feature pyramid: ablating the multi-scale property of the simple feature pyramid considerably degrades performance. We also notice that the last feature map from the backbone is strong enough to build the feature pyramid: the parallel feature pyramid generated from multi-stage feature maps of the backbone does not surpass the simple feature pyramid that only uses the last feature map of the backbone.
Computational Analysis
Tab. 5 shows a comparison of computational requirements with respect to model parameters, FLOPs, GPU memory consumption, and speed; the speed is measured in seconds per click (SPC). Fig. 1 shows the interactive segmentation performance of each method in terms of FLOPs. In Fig. 1 and Tab. 5, each method is denoted by its backbone. For a fair comparison, we evaluate all methods on the same benchmark (i.e. GrabCut) and on the same computer (GPU: NVIDIA RTX A6000, CPU: Intel Silver×2). We only calculate the FLOPs of a single forward pass. For methods like FocusCut, which require multiple forward passes per click, the FLOPs may be much higher than reported. By default, our method takes fixed-size 448×448 images as input. Even for our ViT-H model, the speed (132 ms) and memory consumption (3.22 GB) are sufficient to meet the requirements of a practical annotation tool.
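Seconds per click (SPC) can be approximated by timing one forward pass per simulated click, as in this illustrative sketch (the function and its arguments are assumptions, not the paper's benchmarking code).

```python
import time
import torch

def seconds_per_click(model, image, n_clicks=20):
    """Average forward-pass latency per click on a CUDA device."""
    model.eval()
    with torch.no_grad():
        torch.cuda.synchronize()
        start = time.perf_counter()
        for _ in range(n_clicks):
            model(image)                      # one forward pass per simulated click
        torch.cuda.synchronize()
    return (time.perf_counter() - start) / n_clicks
```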