[Swin] Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
Microsoft Research Asia
{Pure Transformer, Hierarchical, Window Self-Attention, Shifted Window}
Challenges in adapting Transformer from language to vision arise from differences between the two domains:
Large variations in the scale of visual entities. (Differences in the spatial size of input images, and in the size of objects within each image.)
Much higher resolution of pixels in images compared to words in text. (Text splits naturally into word tokens at whitespace, while ViT has to chop the image into fixed 16x16 patches.)
Swin Transformer
Capably serves as a general-purpose backbone for computer vision.
A hierarchical Transformer whose representation is computed with shifted windows.
Flexibility to model at various scales and has linear computational complexity with respect to image size.
Shifted Window
Limiting self-attention computation to non-overlapping local windows. (only calculate self-attention in a local area)
Allowing for cross-window connection. (a window covers different patches in successive layers)
An image entering a SWIN transformer is first divided into non-overlapping patches, just like in ViT.
Each patch is treated as a token with a 4x4x3=48-dimensional feature, where 4 is the height and width of the square patch and 3 is the number of RGB channels.
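For example, a 224x224x3 input image yields (224/4) x (224/4) = 56 x 56 = 3,136 tokens, each a 48-dimensional vector.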
(NerdFact: Transformers are permutation invariant, meaning that if you reorder the input tokens and pass them through an encoder, you get the same outputs, just reordered; that is why we add positional encodings. But simple positional encodings are not enough to capture the spatial correlations in an image, as we saw in ViT. SWIN, with its window partitioning and relative position bias, is not permutation invariant.)
Patch Merging combines each 2x2 group of neighboring patches into one, downsampling the feature map resolution by 2x and increasing the depth of each patch by 2x.
Stage 1: (H/4, W/4, C) ==> Stage 2: (H/8, W/8, 2C) ==> Stage 3: (H/16, W/16, 4C) ==> Stage 4: (H/32, W/32, 8C)
Example: For our cat example image, windowing patches look like this. Here each window has 3x3 patches, so M=3. In the paper, each window has MxM = 7x7 patches. And each patch is 4x4x3=48 pixels.
How do we take 4 windows worth of patches and merge them into one window of the same size (in patches)?
We concatenate each 2x2 group of neighboring patches depth-wise, so each merged patch's dimension becomes 4C instead of C; a linear layer then projects it from 4C down to 2C, and we get an MxMx2C merged window!
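A minimal PyTorch sketch of this merging step (tensor names are hypothetical; the official implementation also applies a LayerNorm before the projection, omitted here):

import torch
import torch.nn as nn

def patch_merging(x, reduction):
    # x: (B, H, W, C) patch features; reduction: nn.Linear(4*C, 2*C)
    x0 = x[:, 0::2, 0::2, :]   # top-left patch of every 2x2 group
    x1 = x[:, 1::2, 0::2, :]   # bottom-left
    x2 = x[:, 0::2, 1::2, :]   # top-right
    x3 = x[:, 1::2, 1::2, :]   # bottom-right
    x = torch.cat([x0, x1, x2, x3], dim=-1)   # (B, H/2, W/2, 4C)
    return reduction(x)                       # (B, H/2, W/2, 2C)

C = 96
x = torch.randn(1, 56, 56, C)
merged = patch_merging(x, nn.Linear(4 * C, 2 * C, bias=False))
print(merged.shape)   # torch.Size([1, 28, 28, 192])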
Unlike the simple global MSA (multi-headed self-attention) computed in a typical encoder block, a SWIN transformer block places two encoders in series, where the output of the first encoder is fed to the second one:
Window-based MSA (W-MSA) module
Shifted-window MSA (SW-MSA) module.
Both compute self-attention locally within each non-overlapping window, i.e., a group of neighboring patches.
The first module uses a regular window partitioning strategy.
The next module adopts a windowing configuration that is shifted from that of the preceding layer, by displacing the windows.
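A toy sketch of how the two block types alternate inside a stage (the blocks here are hypothetical stand-in lambdas, not the real attention blocks):

M = 7   # window size (M x M patches per window)

def apply_stage(x, blocks):
    # Even-indexed blocks use regular windows (W-MSA, shift 0);
    # odd-indexed blocks use shifted windows (SW-MSA, shift M // 2).
    for i, block in enumerate(blocks):
        shift = 0 if i % 2 == 0 else M // 2
        x = block(x, shift)
    return x

# Stand-in blocks that only report the shift they would use.
blocks = [lambda x, s: print("block with shift", s) or x for _ in range(4)]
apply_stage("features", blocks)   # prints shifts 0, 3, 0, 3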
(Left) Transformer Encoder Block
(Right) SWIN Transformer🚀 Block
Window-based Self Attention (W-MSA):
SWIN transformer uses Encoder blocks from the original Transformer architecture.
Each encoder block is made of a Multi-headed Self-attention Module and a Feed-forward Network.
In ViT, the multi-headed self-attention (MSA) uses dot-product attention to compute the attention encoding of each patch w.r.t. all other patches in the input image. So in the figure below, for ViT on the right, if we want to compute attention for the top-left green patch, we attend to all other tokens, which makes the computation quadratic in the number of tokens!
In SWIN, W-MSA uses fixed-sized windows, such that each window contains a fixed number of patches. For example, the left figure has 3x6 patches, but in the paper the authors take square windows, each with MxM patches. Now, to compute the attention encoding of the top-left patch, we only attend to the patches inside its window.
This approach is a lot more efficient and scalable than attending to all tokens for every token.
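A minimal single-head sketch of the window partition and the per-window attention shapes (no projections or multiple heads, just to show where the MxM restriction enters):

import torch
import torch.nn.functional as F

def window_partition(x, M):
    # (B, H, W, C) -> (num_windows * B, M*M, C) non-overlapping windows
    B, H, W, C = x.shape
    x = x.view(B, H // M, M, W // M, M, C)
    return x.permute(0, 1, 3, 2, 4, 5).reshape(-1, M * M, C)

M, C = 7, 96
x = torch.randn(1, 56, 56, C)                            # h = w = 56 patches
windows = window_partition(x, M)                         # (64, 49, 96): 64 windows of 7x7 tokens
scores = windows @ windows.transpose(1, 2) / C ** 0.5    # (64, 49, 49) scores per window
out = F.softmax(scores, dim=-1) @ windows                # each token attends only to the 49 tokens in its window
print(out.shape)                                         # torch.Size([64, 49, 96])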
Supposing each window contains M×M patches, the computational complexities of a global MSA module and a window-based one on an image of h×w patches are:
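From the paper:
Ω(MSA) = 4hwC² + 2(hw)²C
Ω(W-MSA) = 4hwC² + 2M²hwC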
where the former is quadratic to patch number hw, and the latter is linear when M is fixed (set to 7 by default).
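For example, at Stage 1 of Swin-T (h = w = 56 patches, C = 96, M = 7), the quadratic term 2(hw)²C ≈ 1.9×10⁹, while the window term 2M²hwC ≈ 3.0×10⁷, i.e. a hw/M² = 3136/49 = 64× reduction in the attention cost.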
Self Attention Across “Shifted” Windows (SW-MSA):
Problem: If we rely only on window-based attention, the connections between windows are missing, which limits the model's representational power. Cross-window connections are important for performing vision tasks.
To capture attention across windows (see the figure below, right), the authors use SW-MSA, Shifted Window Multi-headed Self-Attention:
Taking the output of the W-MSA.
Shifting all windows by half of their height and width
Compute W-MSA in shifted windows
For example, shifting all 4 windows by 4/2=2 patches down and then 2 patches left gives you the new shifted windows shown on the right.
Efficient Batch Computation of Shifted Windows:
Problem: With shifted window partitioning, the border windows are smaller than MxM, and the number of windows increases from ⌈h/M⌉ × ⌈w/M⌉ to (⌈h/M⌉ + 1) × (⌈w/M⌉ + 1).
In CNNs, we pad the image to apply a filter on borders.
In SWIN, the authors could have simply padded the smaller windows too, but since the number of windows increases, that would make the computational cost grow.
“Reducing inference time is critical because it can later be traded off with accuracy by using larger networks.”
Solution: Cyclic shifting toward the top-left direction
Our Input with Window Partition.
Imagine we cyclically shift the feature map toward the top-left by 2x2 patches; the patches from regions A, B, and C then wrap around and fill the empty spaces that open up on the opposite side.
Now we take each window and apply a mask, as shown in the image, to make sure attention is only computed among the desired parts within a window.
Finally, we reverse the shift, moving the wrapped-around patches from the bottom-right back to the top-left, and repeat! This process lets us compute cross-window attention efficiently, since the number of windows remains the same as in W-MSA.
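A minimal sketch of the cyclic shift and its reverse with torch.roll (the attention mask that keeps wrapped-around regions from attending to each other is omitted here):

import torch

M = 4
shift = M // 2                               # shift by half the window size
x = torch.arange(64).reshape(1, 8, 8, 1)     # (B, H, W, C) toy feature map

# Cyclic shift toward the top-left: the top/left strips (A, B, C) wrap to the bottom-right.
shifted = torch.roll(x, shifts=(-shift, -shift), dims=(1, 2))

# ... window partition + masked attention would run on `shifted` here ...

# Reverse the shift so every patch returns to its original position.
restored = torch.roll(shifted, shifts=(shift, shift), dims=(1, 2))
assert torch.equal(restored, x)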
Relative Position Bias (Rel. Pos.):
A relative position bias B is added to each attention head when computing similarity (see the formula below).
The paper observes significant improvements over counterparts that omit this bias term or that use absolute position embedding instead.
Further adding absolute position embedding to the input drops performance slightly.
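Concretely, the paper computes attention within each window as
Attention(Q, K, V) = SoftMax(QKᵀ/√d + B)V,
where Q, K, V are the M²×d query, key, and value matrices of a window and B is an M²×M² bias; the bias values are taken from a smaller learned table of size (2M−1)×(2M−1), indexed by the relative position between each pair of patches in the window.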
Difference with ViT:
The grid of windows is shifted by half of the window size.
By limiting the attention to be inside the window, the computational complexity is linear if the window size is fixed.
It gives up a key property of attention in ViT, where every patch can attend to every other patch in a single attention step. The Swin Transformer solves this problem by alternating between W-MSA and SW-MSA in consecutive layers, allowing information to propagate to a larger area as the network goes deeper.
Architecture:
Swin Transformer can progressively produce feature maps with a smaller resolution while increasing the number of channels in the feature maps.
The authors divide the flow of an input image through a SWIN transformer into 4 stages; let's look at each of them.
Input:
An input image HxWx3 is passed through a Patch Partition, to split it into fixed-sized patches.
Each patch is treated as a “token” and its feature is set as a concatenation of the raw pixel RGB values.
A patch size of 4×4 is used and thus the feature dimension of each patch is 4×4×3=48.
Output ==> (H/4, W/4, 48)
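A minimal sketch of this partition as a pure reshape (shapes only, no learned parameters):

import torch

H, W, P = 224, 224, 4                        # image size and patch size
img = torch.randn(1, H, W, 3)                # (B, H, W, 3) input image

# Split into non-overlapping 4x4 patches and flatten each into a 48-dim token.
tokens = img.view(1, H // P, P, W // P, P, 3)
tokens = tokens.permute(0, 1, 3, 2, 4, 5).reshape(1, H // P, W // P, P * P * 3)
print(tokens.shape)                          # torch.Size([1, 56, 56, 48])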
Stage 1:
A Linear Embedding layer (48xC) is applied to this raw-valued feature to project it to an arbitrary dimension C. Output feature map of size: H/4 x W/4 x C.
This feature map is then passed through SWIN Transformer Blocks (x2). The size of inputs and outputs remains the same.
Output ==> (H/4, W/4, C)
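In common implementations (e.g., the official repository), the Patch Partition and Linear Embedding are fused into a single strided convolution; a sketch assuming C = 96:

import torch
import torch.nn as nn

C = 96
patch_embed = nn.Conv2d(3, C, kernel_size=4, stride=4)   # 4x4 patches -> C-dim tokens
img = torch.randn(1, 3, 224, 224)                        # (B, 3, H, W)
feat = patch_embed(img)
print(feat.shape)                                        # torch.Size([1, 96, 56, 56]) = (B, C, H/4, W/4)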
Stage 2:
To produce a hierarchical representation, the number of tokens is reduced by Patch Merging layers as the network gets deeper.
The feature map from Stage 1 is now passed through a Patch Merging layer, which combines each (2x2) group of neighboring patches, downsampling the resolution by 2x and increasing the feature map depth by 2x. Output feature map of size: H/8 x W/8 x 2C.
This feature map is passed through SWIN transformer blocks (x2), which keep its dimensions intact.
Output ==> (H/8, W/8, 2C)
Stage 3 and 4:
The procedure is repeated twice, as “Stage 3” and “Stage 4”, with output resolutions of H/16×W/16 and H/32×W/32, respectively.
These stages jointly produce a hierarchical representation with the same feature map resolutions as those of typical convolutional networks, such as VGGNet and ResNet, so the Swin Transformer can conveniently replace the backbone networks in existing methods for various vision tasks.
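For a concrete example, with a 224×224 input and Swin-T (C = 96), the stage outputs are 56×56×96 (Stage 1) → 28×28×192 (Stage 2) → 14×14×384 (Stage 3) → 7×7×768 (Stage 4), directly comparable to the stride-4/8/16/32 feature maps of a ResNet.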
Scale:
The base model, called Swin-B, has a model size and computational complexity similar to ViT-B/DeiT-B.
Swin-T, Swin-S, and Swin-L are versions with about 0.25×, 0.5×, and 2× the model size and computational complexity, respectively. The complexities of Swin-T and Swin-S are similar to those of ResNet-50 (DeiT-S) and ResNet-101, respectively.
The window size is set to M=7 by default. The query dimension of each head is d=32, and the expansion layer of each MLP is α=4, for all experiments. The architecture hyper-parameters of these model variants are:
Swin-T: C = 96, layer numbers = {2, 2, 6, 2}
Swin-S: C = 96, layer numbers ={2, 2, 18, 2}
Swin-B: C = 128, layer numbers ={2, 2, 18, 2}
Swin-L: C = 192, layer numbers ={2, 2, 18, 2}
where C is the channel number of the hidden layers in the first stage.
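A small sketch that tabulates these variants and derives the per-stage channel and head counts (heads are inferred from the d = 32 query dimension, i.e. heads = channels / 32):

# {variant: (C, layer numbers per stage)} as listed above
variants = {
    "Swin-T": (96,  (2, 2, 6, 2)),
    "Swin-S": (96,  (2, 2, 18, 2)),
    "Swin-B": (128, (2, 2, 18, 2)),
    "Swin-L": (192, (2, 2, 18, 2)),
}

for name, (C, depths) in variants.items():
    channels = [C * 2 ** i for i in range(4)]   # C, 2C, 4C, 8C across the 4 stages
    heads = [ch // 32 for ch in channels]       # query dimension d = 32 per head
    print(name, "channels:", channels, "heads:", heads, "layers:", depths)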