TransFuse: Fusing Transformers and CNNs for Medical Image Segmentation
Paper: https://link.springer.com/content/pdf/10.1007/978-3-030-87193-2.pdf?pdf=button
Code:
1) Motivation, Objectives and Related Works :
Motivation:
Medical image segmentation, a prerequisite of numerous clinical needs, has benefited significantly from recent advances in convolutional neural networks (CNNs).
However, CNNs exhibit a general limitation in modeling explicit long-range relations, and existing remedies, which resort to building deep encoders with aggressive downsampling operations, lead to redundantly deep networks and loss of localized details.
Hence, the segmentation task awaits a better solution to improve the efficiency of modeling global contexts while maintaining a strong grasp of low-level details.
Objectives:
Propose a novel parallel-in-branch architecture, TransFuse, to address this challenge.
TransFuse combines Transformers and CNNs in a parallel style, so that both global dependencies and low-level spatial details can be captured efficiently in a much shallower manner.
A fusion technique, the BiFusion module, is created to efficiently fuse the multi-level features from both branches.
Extensive experiments demonstrate that TransFuse achieves new state-of-the-art results on both 2D and 3D medical image datasets, including polyp, skin lesion, hip, and prostate segmentation, with a significant reduction in parameters and improvement in inference speed.
Introduction:
Convolutional neural networks (CNNs) have attained unparalleled performance in numerous medical image segmentation tasks [9,12], such as multi-organ segmentation, liver lesion segmentation, brain 3D MRI, etc., as they have proved powerful at building hierarchical, task-specific feature representations when trained end-to-end. Despite the immense success of CNN-based methodologies, their lack of efficiency in capturing global context information remains a challenge. Sensing global information comes at the cost of efficiency, because existing works obtain it by generating very large receptive fields, which requires consecutive down-sampling and stacking of convolutional layers until the network is deep enough. This brings several drawbacks: 1) training of very deep nets is affected by the diminishing feature reuse problem [23], where low-level features are washed out by consecutive multiplications; 2) local information crucial to dense prediction tasks, e.g., pixel-wise segmentation, is discarded, as the spatial resolution is reduced gradually; 3) training parameter-heavy deep nets with small medical image datasets tends to be unstable and prone to overfitting. Some studies [29] use the non-local self-attention mechanism to model global context; however, the computational complexity of these modules typically grows quadratically with spatial size, so they can only be applied appropriately to low-resolution maps.
Transformer, originally used to model sequence-to-sequence predictions in NLP tasks [26], has recently attracted tremendous interest in the computer vision community. The first purely self-attention-based vision transformer (ViT) for image recognition was proposed in [7]; it obtained competitive results on ImageNet [6] with the prerequisite of being pretrained on a large external dataset. SETR [32] replaces the encoders with transformers in conventional encoder-decoder based networks to successfully achieve state-of-the-art (SOTA) results on the natural image segmentation task. While Transformer is good at modeling global context, it shows limitations in capturing fine-grained details, especially for medical images. We independently find that a SETR-like pure transformer-based segmentation network produces unsatisfactory performance, due to the lack of spatial inductive bias in modeling local information (also reported in [4]).
To enjoy the benefits of both, efforts have been made to combine CNNs with Transformers, e.g., TransUnet [4], which first utilizes CNNs to extract low-level features that are then passed through transformers to model global interaction. With skip-connections incorporated, TransUnet sets new records in the CT multi-organ segmentation task. However, past works mainly focus on replacing convolution with transformer layers or stacking the two in a sequential manner. To further unleash the power of CNNs plus Transformers in medical image segmentation, in this paper, we propose a different architecture, TransFuse, which runs a shallow CNN-based encoder and a transformer-based segmentation network in parallel, followed by our proposed BiFusion module, where features from the two branches are fused together to jointly make predictions. TransFuse possesses several advantages: 1) both low-level spatial features and high-level semantic context can be effectively captured; 2) it does not require very deep nets, which alleviates the gradient vanishing and diminishing feature reuse problems; 3) it largely improves efficiency in model size and inference speed, enabling deployment not only in the cloud but also at the edge. To the best of our knowledge, TransFuse is the first parallel-in-branch model synthesizing CNN and Transformer. Experiments demonstrate its superior performance against other competing SOTA works.
Related Works:
Contribution:
2) Methodology:
As shown in Fig. 1, TransFuse consists of two parallel branches processing information differently: 1) the CNN branch, which gradually increases the receptive field and encodes features from local to global; 2) the Transformer branch, which starts with global self-attention and recovers local details at the end. Features of the same resolution extracted from both branches are fed into our proposed BiFusion module, where self-attention and the bilinear Hadamard product are applied to selectively fuse the information. Then, the multi-level fused feature maps are combined to generate the segmentation using gated skip-connections [20]. There are two main benefits of the proposed branch-in-parallel approach: firstly, by leveraging the merits of CNNs and Transformers, we argue that TransFuse can capture global information without building very deep nets while preserving sensitivity to low-level context; secondly, our proposed BiFusion module may simultaneously exploit the different characteristics of CNNs and Transformers during feature extraction, thus making the fused representation powerful and compact.
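A minimal PyTorch skeleton of this parallel-in-branch wiring is sketched below. All sub-modules (cnn_branch, transformer_branch, bifusions, ag_skips, head) are hypothetical placeholders standing in for the components described above; this is an illustration of the data flow, not the authors' released implementation.

```python
import torch.nn as nn

class TransFuseSkeleton(nn.Module):
    """Illustrative wiring of the parallel-in-branch design (placeholder modules)."""
    def __init__(self, cnn_branch, transformer_branch, bifusions, ag_skips, head):
        super().__init__()
        self.cnn = cnn_branch                      # returns (g0, g1, g2) at 1/16, 1/8, 1/4
        self.transformer = transformer_branch      # returns (t0, t1, t2) at 1/16, 1/8, 1/4
        self.bifusions = nn.ModuleList(bifusions)  # one BiFusion module per resolution
        self.ag_skips = nn.ModuleList(ag_skips)    # gated skip-connections between scales
        self.head = head                           # resize + conv -> segmentation map

    def forward(self, x):
        g = self.cnn(x)          # CNN branch: local-to-global features
        t = self.transformer(x)  # Transformer branch: global-to-local features
        f = [fuse(ti, gi) for fuse, ti, gi in zip(self.bifusions, t, g)]
        out = f[0]               # coarsest fused map, f^0
        for skip, fi in zip(self.ag_skips, f[1:]):
            out = skip(out, fi)  # upsample + attention-gated combination
        return self.head(out)
```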
Transformer Branch
The design of the Transformer branch follows the typical encoder-decoder architecture. Specifically, the input image x ∈ R^(H×W×3) is first evenly divided into N = H/S × W/S patches, where S is typically set to 16. The patches are then flattened and passed into a linear embedding layer with output dimension D0, obtaining the raw embedding sequence e ∈ R^(N×D0). To utilize the spatial prior, a learnable positional embedding of the same dimension is added to e. The resulting embedding z0 ∈ R^(N×D0) is the input to the Transformer encoder, which contains L layers of multi-headed self-attention (MSA) and Multilayer Perceptron (MLP). We highlight that the self-attention (SA) mechanism, which is the core principle of the Transformer, updates the state of each embedded patch by aggregating information globally in every layer:
SA(zi) = softmax(qi k^T / √Dh) v, (1)
where [q, k, v] = z Wqkv, Wqkv ∈ R^(D0×3Dh) is the projection matrix and the vectors zi ∈ R^(1×D0), qi ∈ R^(1×Dh) are the i-th rows of z and q, respectively. MSA is an extension of SA that concatenates multiple SAs and projects the latent dimension back to R^(D0), and MLP is a stack of dense layers (refer to [7] for details of MSA and MLP). Layer normalization is applied to the output of the last transformer layer to obtain the encoded sequence zL ∈ R^(N×D0). For the decoder part, we use the progressive upsampling (PUP) method, as in SETR [32]. Specifically, we first reshape zL back to t0 ∈ R^(H/16×W/16×D0), which can be viewed as a 2D feature map with D0 channels. We then use two consecutive standard upsampling-convolution layers to recover the spatial resolution, obtaining t1 ∈ R^(H/8×W/8×D1) and t2 ∈ R^(H/4×W/4×D2), respectively. The feature maps of different scales t0, t1 and t2 are saved for late fusion with the corresponding feature maps of the CNN branch.
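The Transformer branch can be approximated with standard PyTorch components as below. The dimensions D0/D1/D2, the depth and head counts, and the use of nn.TransformerEncoder in place of the paper's ViT/DeiT backbone are assumptions for illustration; the PUP decoder corresponds to the two upsample-plus-convolution stages described above.

```python
import torch
import torch.nn as nn

class TransformerBranchSketch(nn.Module):
    """Sketch of the Transformer branch: patch embedding + encoder + PUP decoder.
    Hyperparameters and the encoder choice are assumptions, not the paper's backbone."""
    def __init__(self, img_size=224, patch=16, d0=384, d1=128, d2=64, depth=8, heads=6):
        super().__init__()
        self.grid = img_size // patch                    # H/16 = W/16 token grid
        self.embed = nn.Conv2d(3, d0, kernel_size=patch, stride=patch)  # linear patch embedding
        self.pos = nn.Parameter(torch.zeros(1, self.grid * self.grid, d0))  # learnable positions
        layer = nn.TransformerEncoderLayer(d_model=d0, nhead=heads, batch_first=True)
        self.encoder = nn.TransformerEncoder(layer, num_layers=depth)
        self.norm = nn.LayerNorm(d0)
        # progressive upsampling (PUP): two upsample + conv stages, 1/16 -> 1/8 -> 1/4
        self.up1 = nn.Sequential(nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False),
                                 nn.Conv2d(d0, d1, 3, padding=1), nn.ReLU(inplace=True))
        self.up2 = nn.Sequential(nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False),
                                 nn.Conv2d(d1, d2, 3, padding=1), nn.ReLU(inplace=True))

    def forward(self, x):
        z = self.embed(x).flatten(2).transpose(1, 2)     # (B, N, D0) token sequence
        z = self.norm(self.encoder(z + self.pos))        # L encoder layers + final LayerNorm
        b, n, d = z.shape
        t0 = z.transpose(1, 2).reshape(b, d, self.grid, self.grid)  # (B, D0, H/16, W/16)
        t1 = self.up1(t0)                                # (B, D1, H/8, W/8)
        t2 = self.up2(t1)                                # (B, D2, H/4, W/4)
        return t0, t1, t2
```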
CNN Branch
Traditionally, features are progressively downsampled to H/32 × W/32 and hundreds of layers are employed in deep CNNs to obtain global context, which results in very deep models that drain resources. Considering the benefits brought by Transformers, we remove the last block from the original CNN pipeline and take advantage of the Transformer branch to obtain global context information instead. This gives us not only a shallower model but also retains richer local information. For example, ResNet-based models typically have five blocks, each of which downsamples the feature maps by a factor of two. We take the outputs from the 4th (g0 ∈ R^(H/16×W/16×C0)), 3rd (g1 ∈ R^(H/8×W/8×C1)) and 2nd (g2 ∈ R^(H/4×W/4×C2)) blocks to fuse with the results from the Transformer branch (Fig. 1). Moreover, our CNN branch is flexible in that any off-the-shelf convolutional network can be applied.
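A possible realization of the CNN branch is sketched below, assuming a torchvision ResNet-34 backbone (the actual backbone choice and pretraining setup are not fixed by this description): the last (1/32) block is dropped and the 1/4, 1/8, and 1/16 feature maps are exposed for fusion.

```python
import torch.nn as nn
from torchvision.models import resnet34

class CNNBranchSketch(nn.Module):
    """Sketch of the CNN branch: a ResNet with its last block removed, exposing
    features at 1/4, 1/8, and 1/16 resolution. Backbone choice is an assumption."""
    def __init__(self):
        super().__init__()
        backbone = resnet34(weights=None)  # load pretrained weights here if desired
        self.stem = nn.Sequential(backbone.conv1, backbone.bn1, backbone.relu, backbone.maxpool)
        self.block2 = backbone.layer1   # output at H/4 x W/4
        self.block3 = backbone.layer2   # output at H/8 x W/8
        self.block4 = backbone.layer3   # output at H/16 x W/16
        # backbone.layer4 (1/32) is dropped: global context comes from the Transformer branch

    def forward(self, x):
        x = self.stem(x)
        g2 = self.block2(x)   # low-level features, fused with t2
        g1 = self.block3(g2)  # fused with t1
        g0 = self.block4(g1)  # fused with t0
        return g0, g1, g2
```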
BiFusion Module
To effectively combine the encoded features from CNNs and Transformers, we propose a new BiFusion module (refer to Fig. 1) that incorporates both self-attention and multi-modal fusion mechanisms. Specifically, we obtain the fused feature representation f^i, i = 0, 1, 2 by the following operations:
t̂^i = ChannelAttn(t^i), ĝ^i = SpatialAttn(g^i), b̂^i = Conv(t^i W1^i ⊙ g^i W2^i), f^i = Residual([b̂^i, t̂^i, ĝ^i]), (2)
where W1^i ∈ R^(Di×Li), W2^i ∈ R^(Ci×Li), ⊙ is the Hadamard product and Conv is a 3 × 3 convolution layer. The channel attention is implemented as the SE-Block proposed in [10] to promote global information from the Transformer branch. The spatial attention is adopted from the CBAM [30] block as a spatial filter to enhance local details and suppress irrelevant regions, as low-level CNN features could be noisy. The Hadamard product then models the fine-grained interaction between features from the two branches. Finally, the interaction feature b̂^i and the attended features t̂^i, ĝ^i are concatenated and passed through a Residual block. The resulting feature f^i effectively captures both the global and local context for the current spatial resolution. To generate the final segmentation, the f^i are combined using the attention-gated (AG) skip-connection [20], where f̂^(i+1) = Conv([Up(f̂^i), AG(f^(i+1), Up(f̂^i))]) and f̂^0 = f^0, as in Fig. 1.
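The BiFusion computation in Eq. (2) might look as follows; the channel counts, the SE reduction ratio, the 7×7 kernel of the spatial attention, and the exact form of the Residual block are assumptions chosen for illustration, not taken from the paper's code.

```python
import torch
import torch.nn as nn

class BiFusionSketch(nn.Module):
    """Sketch of BiFusion: SE-style channel attention on the Transformer feature t,
    CBAM-style spatial attention on the CNN feature g, a bilinear Hadamard
    interaction, and a residual fusion of the three (layer sizes are assumptions)."""
    def __init__(self, ch_t, ch_g, ch_int, ch_out, reduction=4):
        super().__init__()
        # SE block (channel attention) for the Transformer branch
        self.se = nn.Sequential(nn.AdaptiveAvgPool2d(1),
                                nn.Conv2d(ch_t, ch_t // reduction, 1), nn.ReLU(inplace=True),
                                nn.Conv2d(ch_t // reduction, ch_t, 1), nn.Sigmoid())
        # CBAM-style spatial attention for the CNN branch
        self.spatial = nn.Sequential(nn.Conv2d(2, 1, kernel_size=7, padding=3), nn.Sigmoid())
        # bilinear Hadamard interaction: W1, W2 as 1x1 convs, then a 3x3 conv
        self.w1 = nn.Conv2d(ch_t, ch_int, 1)
        self.w2 = nn.Conv2d(ch_g, ch_int, 1)
        self.conv = nn.Conv2d(ch_int, ch_int, 3, padding=1)
        # residual fusion of the concatenated [b_hat, t_hat, g_hat]
        self.fuse = nn.Sequential(nn.Conv2d(ch_t + ch_g + ch_int, ch_out, 3, padding=1),
                                  nn.BatchNorm2d(ch_out), nn.ReLU(inplace=True))
        self.skip = nn.Conv2d(ch_t + ch_g + ch_int, ch_out, 1)

    def forward(self, t, g):
        t_hat = t * self.se(t)                                   # channel-attended Transformer feature
        pooled = torch.cat([g.mean(dim=1, keepdim=True),
                            g.max(dim=1, keepdim=True).values], dim=1)
        g_hat = g * self.spatial(pooled)                         # spatially-attended CNN feature
        b_hat = self.conv(self.w1(t) * self.w2(g))               # Hadamard interaction feature
        cat = torch.cat([b_hat, t_hat, g_hat], dim=1)
        return self.fuse(cat) + self.skip(cat)                   # residual block output f^i
```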
Loss Function
The full network is trained end-to-end with the weighted IoU loss and weighted binary cross-entropy loss L = L^w_IoU + L^w_bce, where boundary pixels receive larger weights [17]. The segmentation prediction is generated by a simple head, which directly resizes the input feature maps to the original resolution and applies convolution layers to generate M maps, where M is the number of classes. Following [8], we use deep supervision to improve the gradient flow by additionally supervising the transformer branch and the first fusion branch. The final training loss is given by L = αL(G, head(f̂^2)) + γL(G, head(t2)) + βL(G, head(f^0)), where α, γ, β are tunable hyperparameters and G is the ground truth.
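A sketch of the training objective is given below, assuming the boundary-weighted IoU/BCE formulation popularized by F3Net/PraNet-style code for [17]; the pooling window, the weighting constant, and the default values of α, γ, β are illustrative assumptions, not values taken from the paper.

```python
import torch
import torch.nn.functional as F

def structure_loss(pred, mask):
    """Weighted IoU + weighted BCE: boundary pixels get larger weights via a
    local-average trick (kernel size 31 and constant 5 are assumptions).
    pred: logits (B,1,H,W); mask: binary ground truth (B,1,H,W), float."""
    weit = 1 + 5 * torch.abs(F.avg_pool2d(mask, 31, stride=1, padding=15) - mask)
    wbce = F.binary_cross_entropy_with_logits(pred, mask, reduction="none")
    wbce = (weit * wbce).sum(dim=(2, 3)) / weit.sum(dim=(2, 3))

    prob = torch.sigmoid(pred)
    inter = ((prob * mask) * weit).sum(dim=(2, 3))
    union = ((prob + mask) * weit).sum(dim=(2, 3))
    wiou = 1 - (inter + 1) / (union - inter + 1)
    return (wbce + wiou).mean()

def total_loss(G, pred_final, pred_trans, pred_fuse, alpha=0.5, gamma=0.3, beta=0.2):
    """Deep-supervised loss over the final fused map (f_hat^2), the Transformer-branch
    map (t2), and the first fusion map (f^0); weight values are illustrative defaults."""
    return (alpha * structure_loss(pred_final, G)
            + gamma * structure_loss(pred_trans, G)
            + beta * structure_loss(pred_fuse, G))
```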
3) Personal Ideas:
Method 1:
Method 2:
References: