[SAM] Segment Anything Model
{Prompt, Interactive Segmentation, SA-1B Dataset, Zero-shot}
Reducing the need for task-specific modeling expertise, training compute, and custom data annotation for image segmentation.
Build a foundation model for object segmentation.
SAM = Interactive Segmentation + Automatic Segmentation
Segment Anything: Simultaneously develop a general, promptable segmentation model and use it to create a segmentation dataset of unprecedented scale.
SAM allows users to interactively segment objects (click, bounding box, text)
SAM can output multiple valid masks.
SAM can automatically find and mask all objects in an image.
SAM can generate a segmentation mask for any prompt in real time after precomputing the image embedding, allowing for real-time interaction with the model.
Models aligning paired text and images: CLIP [82] (contrastive language-image pretraining) and ALIGN [55] (use contrastive learning to train text and image encoders that align the two modalities).
Models that generate images: DALL-E [83].
[66, 44, 117, 60]
Task: Promptable Segmentation Task - return a valid segmentation mask given any segmentation prompt.
Model:
A powerful image encoder computes an image embedding.
A prompt encoder embeds prompts.
The two information sources are combined in a lightweight mask decoder that predicts segmentation masks.
Data Engine: Co-develop the model with model-in-the-loop dataset annotation.
Assisted-manual.
Semi-automatic.
Fully automatic.
Dataset: SA-1B, 1B+ masks from 11M licensed and privacy-preserving images.
Inspiration: “prompting” techniques - where foundation models can perform zero-shot and few-shot learning for new datasets and tasks.
Constraints: Run in real time on a CPU in a web browser, so that annotators can use SAM interactively and annotate efficiently.
Result: Return a valid segmentation mask for any prompt.
Framework:
An image encoder produces a one-time embedding for the image.
A prompt encoder converts any prompt into an embedding vector in real-time.
These two information sources are then combined in a lightweight decoder that predicts segmentation masks.
After the image embedding is computed, SAM can produce a segment in just 50 milliseconds given any prompt in a web browser.
SAM has three components:
An image encoder.
A flexible prompt encoder.
A light-weight mask decoder.
We build on Transformer vision models [14, 33, 20, 62] with specific tradeoffs for (amortized) real-time performance.
Parameters:
Encoder: 632M parameters.
Decoder (prompt and mask): 4M parameters.
Input: An input resolution of 1024×1024 pixels, obtained by rescaling the image and padding the shorter side.
Encoder:
Use an MAE [47] pre-trained Vision Transformer (ViT) [33] with minimal adaptations to process high-resolution inputs, specifically a ViT-H/16 with 14×14 windowed attention and four equally-spaced global attention blocks, following [62].
To reduce the channel dimension, following [62], we use a 1×1 convolution to get to 256 channels, followed by a 3×3 convolution also with 256 channels. Each convolution is followed by a layer normalization [4].
Output: a 16× downscaled embedding of the input image (e.g., 64×64).
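To make the "neck" concrete, here is a minimal PyTorch sketch of the channel-reduction step described above (1×1 conv to 256 channels, then a 3×3 conv, each followed by a channel-wise layer normalization). The LayerNorm2d helper, the assumed ViT-H width of 1280, and the function name make_neck are assumptions of this sketch, not the reference implementation.

```python
import torch
import torch.nn as nn

class LayerNorm2d(nn.Module):
    """Channel-wise LayerNorm for NCHW tensors; a minimal stand-in for the normalization used in the neck."""
    def __init__(self, num_channels: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(num_channels))
        self.bias = nn.Parameter(torch.zeros(num_channels))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        mean = x.mean(dim=1, keepdim=True)
        var = x.var(dim=1, keepdim=True, unbiased=False)
        x = (x - mean) / torch.sqrt(var + self.eps)
        return self.weight[:, None, None] * x + self.bias[:, None, None]

def make_neck(vit_channels: int = 1280, out_channels: int = 256) -> nn.Sequential:
    # Reduce the ViT channel dimension to 256: 1x1 conv -> LN -> 3x3 conv -> LN.
    return nn.Sequential(
        nn.Conv2d(vit_channels, out_channels, kernel_size=1, bias=False),
        LayerNorm2d(out_channels),
        nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False),
        LayerNorm2d(out_channels),
    )

# A 1024x1024 input with patch size 16 gives a 64x64 grid of patch features.
vit_features = torch.randn(1, 1280, 64, 64)    # dummy ViT-H feature map (width 1280 assumed)
image_embedding = make_neck()(vit_features)    # -> (1, 256, 64, 64)
```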
Sparse (foreground/background points, bounding boxes, text)
Represent points and boxes by positional encodings summed with learned embeddings for each prompt type.
Free-form text with an off-the-shelf text encoder from CLIP.
Dense (masks).
Dense prompts are embedded using convolutions and summed element-wise with the image embedding.
Sparse prompts are mapped to 256-dimensional vectorial embeddings as follows.
A point is represented as the sum of a positional encoding [95] of the point’s location and one of two learned embeddings that indicate if the point is either in the foreground or background.
A box is represented by an embedding pair: (1) the positional encoding of its top-left corner summed with a learned embedding representing “top-left corner” and (2) the same structure but using a learned embedding indicating “bottom-right corner”.
To represent free-form text we use the text encoder from CLIP [82] (any text encoder is possible in general). We focus on geometric prompts for the remainder of this section and discuss text prompts in depth in §D.5.
Dense prompts (i.e., masks) have a spatial correspondence with the image.
Masks are input at a 4× lower resolution than the image, then downscaled an additional 4× using two 2×2, stride-2 convolutions with output channels 4 and 16, respectively.
A final 1×1 convolution maps the channel dimension to 256.
Each layer is separated by GELU activations [50] and layer normalization.
The mask and image embedding are then added element-wise. If there is no mask prompt, a learned embedding representing “no mask” is added to each image embedding location.
Positional encoding: 2D Fourier-feature PE (concatenated sine and cosine features).
Ref:
Learnable Fourier Features for Multi-Dimensional Spatial Positional Encoding
Fourier Features Let Networks Learn High Frequency Functions in Low Dimensional Domains
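As a rough illustration of the 2D Fourier-feature positional encoding referenced above, here is a minimal sketch: 2D coordinates are projected by a fixed random Gaussian matrix and mapped through sine and cosine. The feature count (chosen so the output is 256-d, matching SAM's embedding width) and the scale are assumptions of this sketch.

```python
import torch

class RandomFourierPE:
    """2D random-Fourier-feature positional encoding: project (x, y) with a fixed Gaussian matrix,
    then concatenate sine and cosine. 128 frequencies give a 256-d output."""
    def __init__(self, num_feats: int = 128, scale: float = 1.0):
        self.B = scale * torch.randn(2, num_feats)   # fixed random projection (not learned)

    def __call__(self, coords: torch.Tensor) -> torch.Tensor:
        # coords: (..., 2) with x, y normalized to [0, 1].
        proj = 2 * torch.pi * coords @ self.B                          # (..., num_feats)
        return torch.cat([torch.sin(proj), torch.cos(proj)], dim=-1)   # (..., 2 * num_feats)

pe = RandomFourierPE()
points = torch.tensor([[0.25, 0.75], [0.50, 0.50]])   # two normalized (x, y) prompt locations
print(pe(points).shape)                               # torch.Size([2, 256])
```

A point prompt embedding can then be formed, as described above, by adding to this positional encoding a learned embedding indicating foreground or background.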
This module efficiently maps the image embedding and a set of prompt embeddings to an output mask.
Before applying our decoder, we first insert into the set of prompt embeddings (for simplicity, called prompt tokens) a learned output token embedding (output token) that will be used at the decoder’s output, analogous to the [class] token in ViT [33] (cf. BERT’s [CLS] token).
Decoder design: Each decoder layer performs 4 steps:
Self-attention on the tokens.
Cross-attention from tokens (as queries) to the image embedding.
A point-wise MLP updates each token.
Cross-attention from the image embedding (as queries) to tokens.
To ensure the decoder has access to critical geometric information the positional encodings are added to the image embedding whenever they participate in an attention layer. Additionally, the entire original prompt tokens (including their positional encodings) are re-added to the updated tokens whenever they participate in an attention layer. This allows for a strong dependence on both the prompt token’s geometric location and type.
After running the decoder, we upsample the updated image embedding by 4× with two transposed convolutional layers (now it’s downscaled 4× relative to the input image). Then, the tokens attend once more to the image embedding and we pass the updated output token embedding to a small 3-layer MLP that outputs a vector matching the channel dimension of the upscaled image embedding. Finally, we predict a mask with a spatially point-wise product between the upscaled image embedding and the MLP’s output. The transformer uses an embedding dimension of 256.
The transformer MLP blocks have a large internal dimension of 2048, but the MLP is applied only to the prompt tokens for which there are relatively few (rarely greater than 20). However, in cross-attention layers where we have a 64×64 image embedding, we reduce the channel dimension of the queries, keys, and values by 2× to 128 for computational efficiency. All attention layers use 8 heads.
The transposed convolutions used to upscale the output image embedding are 2×2, stride 2 with output channel dimensions of 64 and 32 and have GELU activations. They are separated by layer normalization.
Similar to BERT's [CLS] token and Vision Transformer [class] token.
One token representing the IoU and three tokens representing the output masks (part, subpart, whole) are added before the sparse token embeddings (points, boxes, and text).
As described, a single input prompt may be ambiguous in the sense that it corresponds to multiple valid masks, and the model will learn to average over these masks.
We eliminate this problem with a simple modification: instead of predicting a single mask, we use a small number of output tokens and predict multiple masks simultaneously.
By default we predict three masks, since we observe that three layers (whole, part, and subpart) are often enough to describe nested masks.
During training, we compute the loss (described shortly) between the ground truth and each of the predicted masks, but only backpropagate from the lowest loss. This is a common technique used for models with multiple outputs [15, 45, 64].
For use in applications, we’d like to rank predicted masks, so we add a small head (operating on an additional output token) that estimates the IoU between each predicted mask and the object it covers.
Ambiguity is much rarer with multiple prompts and the three output masks will usually become similar.
To minimize the computation of degenerate losses at training and ensure the single unambiguous mask receives a regular gradient signal, we only predict a single mask when more than one prompt is given.
This is accomplished by adding a fourth output token for an additional mask prediction. This fourth mask is never returned for a single prompt and is the only mask returned for multiple prompts.
We supervise mask prediction with a linear combination of focal loss [65] and dice loss [73] in a 20:1 ratio of focal loss to dice loss, following [20, 14]. Unlike [20, 14], we observe that auxiliary deep supervision after each decoder layer is unhelpful. The IoU prediction head is trained with mean-square-error loss between the IoU prediction and the predicted mask’s IoU with the ground truth mask. It is added to the mask loss with a constant scaling factor of 1.0.
Input Resolution: 1024×1024 pixels, obtained by rescaling the image so its longest side is 1024 and padding the shorter side.
Image Encoder: Transfer Images to Embedding.
Masked Auto-Encoder (MAE) [47] pre-trained Vision Transformer (ViT) [33]
Detail:
A ViT-H/16 with 14×14 windowed attention and four equally-spaced global attention blocks [62].
To reduce the channel dimension, following [62], use a 1×1 convolution to get to 256 channels, followed by a 3×3 convolution also with 256 channels. Each convolution is followed by a layer normalization [4].
Output: a 16× downscaled embedding of the input image (e.g., 64×64).
Flexible Prompt Encoder: Sparse + Dense
Sparse: Encode points and bounding boxes with positional encodings [95] summed with learned embeddings for each prompt type, and encode free-form text with an off-the-shelf text encoder from CLIP [82].
Sparse prompts are mapped to 256-dimensional vectorial embeddings.
A point is represented as the sum of a positional encoding [95] of the point’s location and one of two learned embeddings that indicate if the point is either in the foreground or background.
A box is represented by an embedding pair: (1) the positional encoding of its top-left corner summed with a learned embedding representing “top-left corner” and (2) the same structure but using a learned embedding indicating “bottom right corner”.
To represent free-form text we use the text encoder from CLIP [82] (any text encoder is possible in general).
Dense: Masks are embedded using convolutions and summed element-wise with the image embedding.
Dense prompts (i.e., masks) have a spatial correspondence with the image.
Masks are input at a 4× lower resolution than the image, then downscaled an additional 4× using two 2×2, stride-2 convolutions with output channels 4 and 16, respectively.
A final 1×1 convolution maps the channel dimension to 256.
Each layer is separated by GELU activations [50] and layer normalization.
The mask and image embedding are then added element-wise.
If there is no mask prompt, a learned embedding representing “no mask” is added to each image embedding location.
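A minimal sketch of this dense-prompt path (two 2×2, stride-2 convolutions with 4 and 16 output channels, then a 1×1 convolution to 256, with layer normalization and GELU between layers), assuming a 1024×1024 input image so the mask prompt arrives at 256×256. The LayerNorm2d helper repeats the channel-wise layer normalization from the encoder-neck sketch above; the function name is an assumption.

```python
import torch
import torch.nn as nn

class LayerNorm2d(nn.Module):
    # Channel-wise LayerNorm for NCHW tensors (same helper as in the encoder-neck sketch above).
    def __init__(self, num_channels: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(num_channels))
        self.bias = nn.Parameter(torch.zeros(num_channels))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        mean = x.mean(dim=1, keepdim=True)
        var = x.var(dim=1, keepdim=True, unbiased=False)
        x = (x - mean) / torch.sqrt(var + self.eps)
        return self.weight[:, None, None] * x + self.bias[:, None, None]

def make_mask_downscaler(embed_dim: int = 256) -> nn.Sequential:
    return nn.Sequential(
        nn.Conv2d(1, 4, kernel_size=2, stride=2),     # 256x256 mask prompt -> 128x128, 4 channels
        LayerNorm2d(4),
        nn.GELU(),
        nn.Conv2d(4, 16, kernel_size=2, stride=2),    # 128x128 -> 64x64, 16 channels
        LayerNorm2d(16),
        nn.GELU(),
        nn.Conv2d(16, embed_dim, kernel_size=1),      # map to the 256-d embedding space
    )

mask_prompt = torch.randn(1, 1, 256, 256)               # mask supplied at 4x lower resolution than the 1024x1024 image
dense_embedding = make_mask_downscaler()(mask_prompt)   # (1, 256, 64, 64); added element-wise to the image embedding
```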
Light-weight Mask Decoder: Efficiently maps the image embedding, prompt embeddings, and an output token to a mask.
Model:
A variant of Transformer Decoder Block [14, 20, 103] with a dynamic mask prediction head.
Using prompt self-attention and cross-attention in two directions (prompt-to-image embedding and vice-versa) to update all embeddings.
Framework:
First, insert into the set of prompt embeddings a learned output token embedding that will be used at the decoder’s output, analogous to the [class] token in ViT [33]. For simplicity, these are collectively referred to as “tokens”.
Each decoder layer performs 4 steps:
(1) self-attention on the tokens
(2) cross-attention from tokens (as queries) to the image embedding
(3) a point-wise MLP updates each token
(4) cross-attention from the image embedding (as queries) to tokens. This last step updates the image embedding with prompt information.
During cross-attention, the image embedding is treated as a set of 64² (4,096) 256-dimensional vectors. Each self/cross-attention and MLP has a residual connection [49], layer normalization, and a dropout [93] of 0.1 at training.
The next decoder layer takes the updated tokens and the updated image embedding from the previous layer. We use a two-layer decoder.
To ensure the decoder has access to critical geometric information the positional encodings are added to the image embedding whenever they participate in an attention layer.
Additionally, the entire original prompt tokens (including their positional encodings) are re-added to the updated tokens whenever they participate in an attention layer. This allows for a strong dependence on both the prompt token’s geometric location and type.
After running the decoder, we upsample the updated image embedding by 4× with two transposed convolutional layers (now it’s downscaled 4× relative to the input image). Then, the tokens attend once more to the image embedding and we pass the updated output token embedding to a small 3-layer MLP that outputs a vector matching the channel dimension of the upscaled image embedding.
Finally, we predict a mask with a spatially point-wise product between the upscaled image embedding and the MLP’s output. The transformer uses an embedding dimension of 256.
The transformer MLP blocks have a large internal dimension of 2048, but the MLP is applied only to the prompt tokens for which there are relatively few (rarely greater than 20).
However, in cross-attention layers where we have a 64×64 image embedding, we reduce the channel dimension of the queries, keys, and values by 2× to 128 for computational efficiency.
All attention layers use 8 heads.
The transposed convolutions used to upscale the output image embedding are 2×2, stride 2 with output channel dimensions of 64 and 32 and have GELU activations. They are separated by layer normalization.
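A rough PyTorch sketch of one decoder layer performing the four steps above. Residual connections and layer norms are kept, but the positional-encoding re-addition, the 0.1 dropout, and the 2× channel reduction in cross-attention are omitted for brevity; the class name and the use of nn.MultiheadAttention are assumptions of this sketch rather than the original implementation.

```python
import torch
import torch.nn as nn

class TwoWayLayer(nn.Module):
    """One decoder layer: (1) token self-attention, (2) token-to-image cross-attention,
    (3) per-token MLP, (4) image-to-token cross-attention that writes prompt information
    back into the image embedding."""
    def __init__(self, dim: int = 256, heads: int = 8, mlp_dim: int = 2048):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.cross_t2i = nn.MultiheadAttention(dim, heads, batch_first=True)   # tokens query the image
        self.cross_i2t = nn.MultiheadAttention(dim, heads, batch_first=True)   # image queries the tokens
        self.mlp = nn.Sequential(nn.Linear(dim, mlp_dim), nn.GELU(), nn.Linear(mlp_dim, dim))
        self.norms = nn.ModuleList([nn.LayerNorm(dim) for _ in range(4)])

    def forward(self, tokens: torch.Tensor, image: torch.Tensor):
        # tokens: (B, N_tokens, 256); image: (B, 64*64, 256), the image embedding flattened to a sequence.
        tokens = self.norms[0](tokens + self.self_attn(tokens, tokens, tokens)[0])
        tokens = self.norms[1](tokens + self.cross_t2i(tokens, image, image)[0])
        tokens = self.norms[2](tokens + self.mlp(tokens))
        image = self.norms[3](image + self.cross_i2t(image, tokens, tokens)[0])
        return tokens, image

tokens = torch.randn(1, 8, 256)         # output tokens + sparse prompt tokens
image = torch.randn(1, 64 * 64, 256)    # image embedding treated as 64^2 256-d vectors
for layer in [TwoWayLayer(), TwoWayLayer()]:   # two-layer decoder
    tokens, image = layer(tokens, image)
```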
Making the model ambiguity-aware.
As described, a single input prompt may be ambiguous in the sense that it corresponds to multiple valid masks, and the model will learn to average over these masks.
We eliminate this problem with a simple modification: instead of predicting a single mask, we use a small number of output tokens and predict multiple masks simultaneously.
By default, we predict three masks, since we observe that three layers (whole, part, and subpart) are often enough to describe nested masks.
During training, we compute the loss (described shortly) between the ground truth and each of the predicted masks, but only backpropagate from the lowest loss.
This is a common technique used for models with multiple outputs [15, 45, 64].
For use in applications, we’d like to rank predicted masks, so we add a small head (operating on an additional output token) that estimates the IoU between each predicted mask and the object it covers.
Ambiguity is much rarer with multiple prompts and the three output masks will usually become similar. To minimize the computation of degenerate losses at training and ensure the single unambiguous mask receives a regular gradient signal, we only predict a single mask when more than one prompt is given. This is accomplished by adding a fourth output token for an additional mask prediction. This fourth mask is never returned for a single prompt and is the only mask returned for multiple prompts.
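A minimal sketch of the "backpropagate only from the lowest loss" rule over the three predicted masks. The per-mask loss is pluggable; plain binary cross-entropy is used here only as a stand-in (the actual 20:1 focal + dice combination is sketched under Loss Function below), and the tensor shapes are assumptions.

```python
import torch
import torch.nn.functional as F

def select_lowest_loss(pred_masks: torch.Tensor, gt_mask: torch.Tensor, loss_fn=None) -> torch.Tensor:
    # pred_masks: (B, 3, H, W) logits; gt_mask: (B, H, W) in {0, 1}.
    if loss_fn is None:
        # Stand-in per-mask loss; SAM uses a 20:1 focal + dice combination instead.
        loss_fn = lambda logits, gt: F.binary_cross_entropy_with_logits(
            logits, gt, reduction="none").mean(dim=(1, 2))
    per_mask = torch.stack(
        [loss_fn(pred_masks[:, i], gt_mask) for i in range(pred_masks.shape[1])], dim=1)  # (B, 3)
    best = per_mask.argmin(dim=1)                                # lowest-loss mask per example
    return per_mask.gather(1, best[:, None]).squeeze(1).mean()   # only this term is backpropagated

loss = select_lowest_loss(torch.randn(2, 3, 256, 256),
                          torch.randint(0, 2, (2, 256, 256)).float())
```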
Loss Function:
Focal loss [65] and dice loss [73].
Ratio: 20:1 (focal:dice), following [20, 14].
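A minimal sketch of this supervision: focal and dice losses on the mask logits combined in a 20:1 ratio, plus the mean-square-error loss for the IoU prediction head mentioned earlier (scaling factor 1.0). The focal-loss hyperparameters (alpha, gamma) are common defaults assumed here, not values stated in these notes.

```python
import torch
import torch.nn.functional as F

def focal_loss(logits: torch.Tensor, gt: torch.Tensor, alpha: float = 0.25, gamma: float = 2.0) -> torch.Tensor:
    ce = F.binary_cross_entropy_with_logits(logits, gt, reduction="none")
    p = torch.sigmoid(logits)
    p_t = p * gt + (1 - p) * (1 - gt)                 # probability of the true class
    alpha_t = alpha * gt + (1 - alpha) * (1 - gt)
    return (alpha_t * (1 - p_t) ** gamma * ce).mean()

def dice_loss(logits: torch.Tensor, gt: torch.Tensor, eps: float = 1.0) -> torch.Tensor:
    p, gt = torch.sigmoid(logits).flatten(1), gt.flatten(1)
    return (1 - (2 * (p * gt).sum(-1) + eps) / (p.sum(-1) + gt.sum(-1) + eps)).mean()

def sam_loss(mask_logits, gt_mask, pred_iou, true_iou):
    # true_iou: actual IoU of the binarized predicted mask with the ground truth (computed without gradients).
    mask_term = 20.0 * focal_loss(mask_logits, gt_mask) + 1.0 * dice_loss(mask_logits, gt_mask)
    iou_term = F.mse_loss(pred_iou, true_iou)         # IoU prediction head, scaling factor 1.0
    return mask_term + iou_term
```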
Following recent approaches [92, 37], we simulate an interactive segmentation setup during training. First, with equal probability either a foreground point or bounding box is selected randomly for the target mask. Points are sampled uniformly from the ground truth mask. Boxes are taken as the ground truth mask’s bounding box, with random noise added in each coordinate with standard deviation equal to 10% of the box sidelength, to a maximum of 20 pixels. This noise profile is a reasonable compromise between applications like instance segmentation, which produce a tight box around the target object, and interactive segmentation, where a user may draw a loose box.
After making a prediction from this first prompt, subsequent points are selected uniformly from the error region between the previous mask prediction and the ground truth mask. Each new point is foreground or background if the error region is a false negative or false positive, respectively. We also supply the mask prediction from the previous iteration as an additional prompt to our model. To provide the next iteration with maximal information, we supply the unthresholded mask logits instead of the binarized mask. When multiple masks are returned, the mask passed to the next iteration and used to sample the next point is the one with the highest predicted IoU.
We find diminishing returns after 8 iteratively sampled points (we have tested up to 16). Additionally, to encourage the model to benefit from the supplied mask, we also use two more iterations where no additional points are sampled. One of these iterations is randomly inserted among the 8 iteratively sampled points, and the other is always at the end. This gives 11 total iterations: one sampled initial input prompt, 8 iteratively sampled points, and two iterations where no new external information is supplied to the model so it can learn to refine its own mask predictions. We note that using a relatively large number of iterations is possible because our lightweight mask decoder requires less than 1% of the image encoder’s compute and, therefore, each iteration adds only a small overhead. This is unlike previous interactive methods that perform only one or a few interactive steps per optimizer update [70, 9, 37, 92].
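A minimal sketch of the prompt simulation described above: the first prompt is either a uniformly sampled foreground point or the ground-truth box with coordinate noise of standard deviation 10% of the side length (capped at 20 pixels); subsequent points are sampled uniformly from the error region, labeled foreground for false negatives and background for false positives. Function names and the dictionary format are assumptions of this sketch.

```python
import torch

def first_prompt(gt_mask: torch.Tensor) -> dict:
    # gt_mask: (H, W) binary ground-truth mask.
    ys, xs = torch.nonzero(gt_mask, as_tuple=True)
    if torch.rand(()) < 0.5:                                    # foreground point, sampled uniformly from the mask
        i = torch.randint(len(xs), ())
        return {"point": (xs[i].item(), ys[i].item()), "label": 1}
    x0, y0 = xs.min().item(), ys.min().item()                   # otherwise, the ground-truth box with noise
    x1, y1 = xs.max().item(), ys.max().item()
    box = torch.tensor([x0, y0, x1, y1], dtype=torch.float)
    side = torch.tensor([x1 - x0, y1 - y0, x1 - x0, y1 - y0], dtype=torch.float)
    noise = torch.clamp(torch.randn(4) * 0.1 * side, -20, 20)   # std = 10% of side length, capped at 20 px
    return {"box": (box + noise).tolist()}

def next_point(pred_mask: torch.Tensor, gt_mask: torch.Tensor) -> dict:
    # Sample uniformly from the error region (assumed non-empty); foreground if a false negative,
    # background if a false positive.
    error = pred_mask.bool() ^ gt_mask.bool()
    ys, xs = torch.nonzero(error, as_tuple=True)
    i = torch.randint(len(xs), ())
    y, x = ys[i], xs[i]
    return {"point": (x.item(), y.item()), "label": int(gt_mask[y, x].item())}
```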
We use the AdamW [68] optimizer (β1 = 0.9, β2 = 0.999) and a linear learning rate warmup [42] for 250 iterations and a step-wise learning rate decay schedule. The initial learning rate (lr), after warmup, is 8e−4. We train for 90k iterations (∼2 SA-1B epochs) and decrease the lr by a factor of 10 at 60k iterations and again at 86666 iterations. The batch size is 256 images. To regularize SAM, we set weight decay (wd) to 0.1 and apply drop path [53] (dp) with a rate of 0.4. We use a layer-wise learning rate decay [5] (ld) of 0.8. No data augmentation is applied. We initialize SAM from an MAE [47] pre-trained ViT-H.

We distribute training across 256 GPUs, due to the large image encoder and 1024×1024 input size. To limit GPU memory usage, we train with up to 64 randomly sampled masks per GPU. Additionally, we find that lightly filtering SA-1B masks to discard any that cover more than 90% of the image qualitatively improves results.

For ablations and other variations on training (e.g., text-to-mask §D.5), we deviate from the default recipe above as follows. When training with data from the first and second data engine stages only, we augment the input with large-scale jitter [40] with a scale range of [0.1, 2.0]. Intuitively, data augmentation may be helpful when training data is more limited. To train ViT-B and ViT-L, we use 180k iterations with batch size 128 distributed across 128 GPUs. We set lr = 8e−4/4e−4, ld = 0.6/0.8, wd = 0.1, and dp = 0.6/0.4 for ViT-B/L, respectively.
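A minimal sketch of the default optimization recipe quoted above (AdamW with β1 = 0.9, β2 = 0.999, a 250-iteration linear warmup to lr 8e−4, 10× decays at 60k and 86666 iterations, and weight decay 0.1). Layer-wise lr decay and drop path are omitted, and the model here is only a placeholder.

```python
import torch

model = torch.nn.Linear(256, 256)   # placeholder for SAM
optimizer = torch.optim.AdamW(model.parameters(), lr=8e-4, betas=(0.9, 0.999), weight_decay=0.1)

def lr_lambda(step: int) -> float:
    if step < 250:                   # linear warmup over 250 iterations
        return (step + 1) / 250
    if step < 60_000:                # full lr until the first decay
        return 1.0
    if step < 86_666:                # 10x decay at 60k iterations
        return 0.1
    return 0.01                      # second 10x decay at 86666 iterations

scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
# Training loop (90k iterations, batch size 256): loss.backward(); optimizer.step(); scheduler.step()
```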
The data was collected using SAM. It only takes about 14 seconds to interactively annotate a mask.
A data engine is built with three stages:
Assisted-manual: Annotators use SAM to interactively annotate images, and the newly annotated data is used to retrain SAM in turn. This cycle is repeated many times to iteratively improve both the model and the dataset; the image encoder is scaled from ViT-B to ViT-H along the way.
Semi-automatic: A mix of automatic and assisted annotation to increase the diversity of collected masks; SAM first proposes confident masks, and annotators then label the remaining objects that SAM has not segmented.
Fully automatic: Fully automatic mask generation, allowing the dataset to scale. A 32×32 grid of points is used as the prompt, each point yielding a set of candidate masks; the IoU prediction head is used to keep confident masks, and overlapping masks are filtered with NMS (see the sketch below).
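A rough sketch of this fully automatic stage: prompt SAM with a regular 32×32 grid of points, keep masks whose predicted IoU clears a threshold, and remove near-duplicates with NMS. The sam_predict callable, the threshold values, and the box-level (rather than mask-level) NMS are assumptions of this sketch, not details given in these notes.

```python
import torch
from torchvision.ops import nms

def auto_masks(image_embedding, sam_predict, grid: int = 32, iou_thresh: float = 0.88, nms_thresh: float = 0.7):
    # Regular grid of point prompts in normalized [0, 1] coordinates.
    xs = torch.linspace(0, 1, grid)
    points = torch.stack(torch.meshgrid(xs, xs, indexing="xy"), dim=-1).reshape(-1, 2)   # 32*32 = 1024 prompts

    masks, scores = [], []
    for p in points:
        mask, pred_iou = sam_predict(image_embedding, point=p)   # one mask + predicted IoU per prompt (simplified)
        if pred_iou > iou_thresh:                                # keep only confident masks
            masks.append(mask)
            scores.append(float(pred_iou))
    if not masks:
        return []

    # Remove near-duplicates with NMS over the mask bounding boxes.
    boxes = []
    for m in masks:
        ys, xs_ = torch.nonzero(m > 0, as_tuple=True)
        boxes.append([xs_.min().item(), ys.min().item(), xs_.max().item(), ys.max().item()])
    keep = nms(torch.tensor(boxes, dtype=torch.float), torch.tensor(scores), nms_thresh)
    return [masks[i] for i in keep]
```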
Final dataset: 1.1 billion segmentation masks collected on about 11 million licensed and privacy-preserving images.
SA-1B: 11M images, 1B+ masks
RAI Additional Details:
Inferring geographic information: Each image has a caption describing its contents and where it was taken.
Inferring geographic information for COCO and Open Images: retrieve geographic metadata by using the Flickr API.
Inferring income information: look up the income level of each image’s country using the levels defined by the World Bank.
Fairness in segmenting people: use More Inclusive Annotation for People (MIAP) test set annotations for Open Images.
Fairness in segmenting clothing.
Zero-Shot Single Point Valid Mask (compared with RITM, SimpleClick, FocalClick).
Zero-Shot Edge Detection (compared with EDETR).
Zero-Shot Object Proposals (compared with ViTDet-H).
Zero-Shot Instance Segmentation (compared with VITDet-H).
Zero-Shot Text-to-Mask (compared with RIT).
Probing the Latent Space of SAM.
Finally, we perform an initial investigation to qualitatively probe the latent space learned by SAM. In particular, we are interested in whether SAM is able to capture any semantics in its representation even though it is not trained with explicit semantic supervision. To do so, we compute mask embeddings by extracting an image embedding from SAM from an image crop around a mask and its horizontally flipped version, multiplying the image embedding by the binary mask, and averaging over spatial locations. In Fig. 17, we show 3 examples of a query mask and similar masks (in the latent space) in the same image. We observe that the nearest neighbors for each query show some, albeit imperfect, shape and semantic similarity. Although these results are preliminary, they indicate that the representations from SAM may be useful for a variety of purposes, such as further data labeling, understanding the contents of datasets, or as features for downstream tasks.
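A minimal sketch of the mask-embedding computation described above: encode a crop around the mask and its horizontal flip, multiply each image embedding by the (downscaled) binary mask, and average over spatial locations. encode_image is a placeholder for SAM's image encoder, and the shapes and averaging details are assumptions of this sketch.

```python
import torch
import torch.nn.functional as F

def mask_embedding(encode_image, crop: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    # crop: (3, H, W) image crop around the mask; mask: (H, W) binary mask for the same crop.
    feats = []
    for img, m in [(crop, mask), (torch.flip(crop, dims=[-1]), torch.flip(mask, dims=[-1]))]:
        emb = encode_image(img[None])                                        # (1, 256, h, w) image embedding
        m_low = F.interpolate(m[None, None].float(), size=emb.shape[-2:], mode="nearest")
        masked = emb * m_low                                                 # zero out locations outside the mask
        feats.append(masked.sum(dim=(-2, -1)) / m_low.sum().clamp(min=1))    # average over in-mask locations
    return torch.cat(feats, dim=0).mean(dim=0)                               # (256,), averaged over both views

# Nearest neighbours of this embedding (e.g., by cosine similarity) retrieve masks with similar shape/semantics.
```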