This post is part of a series in which we will cover the following topics.
1. What is Semantic Segmentation?
Semantic Segmentation is an image analysis task in which we classify each pixel in the image into a class.
This is similar to what we humans do all the time by default. Whenever we look at something, we try to “segment” what portion of the image belongs to which class/label/category.
Essentially, Semantic Segmentation is the technique through which we can achieve this in computers.
You can read more about Segmentation in our post on Image Segmentation.
This post will focus on Semantic Segmentation.
So, let’s say we have the following image.
After semantic segmentation, you get the following output:
As you can see, each pixel in the image is classified to its respective class. For example, the person is one class, the bike is another and the third is the background.
This is, in most simple terms, what Semantic Segmentation is.
2. Applications of Semantic Segmentation
The most common use cases for Semantic Segmentation are:
2.1. Autonomous Driving
In autonomous driving, the computer driving the car needs to have a good understanding of the road scene in front of it. It is important to segment out objects like cars, pedestrians, lanes, and traffic signs. We cover this application in great detail in our upcoming Deep Learning course with PyTorch.
2.2. Facial Segmentation
Facial Segmentation is used for segmenting each part of the face into semantically similar regions – lips, eyes, etc. This can be useful in many real-world applications. One very interesting application is a virtual makeover.
2.3. Indoor Object Segmentation
Can you guess where this is used? In AR (Augmented Reality) and VR (Virtual Reality). AR applications can segment the entire indoor area to understand the position of chairs, tables, people, walls, and other similar objects, and thus can place and manipulate virtual objects efficiently.
2.4. Geo Land Sensing
Geo Land Sensing is a way of categorising each pixel in satellite images into a class so that we can track the land cover of each area. So, if heavy deforestation is taking place in some area, appropriate measures can be taken. There can be many more applications using semantic segmentation on satellite images.
Let us see how to perform semantic segmentation using PyTorch and Torchvision.
3. Semantic Segmentation using torchvision
We will look at two Deep Learning based models for Semantic Segmentation – Fully Convolutional Network (FCN) and DeepLab v3. These models have been trained on a subset of the COCO Train 2017 dataset that corresponds to the PASCAL VOC dataset. The models support 20 object categories, plus a background class, giving 21 classes in total.
You can use the Colab Notebook to follow along with the tutorial.
3.1. Input and Output
Before we get started, let us understand the inputs and outputs of the models.
These models expect a 3-channel image (RGB) which is normalized with the Imagenet mean and standard deviation, i.e.,
mean = [0.485, 0.456, 0.406], std = [0.229, 0.224, 0.225]
So, the input dimension is [Ni x Ci x Hi x Wi], where:
- Ni -> the batch size
- Ci -> the number of channels (which is 3)
- Hi -> the height of the image
- Wi -> the width of the image

And the output dimension of the model is [No x Co x Ho x Wo], where:
- No -> the batch size (same as Ni)
- Co -> the number of classes in the dataset
- Ho -> the height of the image (the same as Hi in almost all cases)
- Wo -> the width of the image (the same as Wi in almost all cases)
NOTE: The output of the torchvision segmentation models is an OrderedDict and not a torch.Tensor. During inference (.eval() mode), the output is an OrderedDict with just one key – out. This out key holds the output, and its corresponding value has the shape [No x Co x Ho x Wo].
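To make this note concrete, here is a minimal sketch (it uses the FCN model that we will load properly in the next section) showing the OrderedDict and the shape of the value stored under the out key:

import torch
from torchvision import models

# Any torchvision segmentation model returns an OrderedDict during inference
model = models.segmentation.fcn_resnet101(pretrained=True).eval()
dummy = torch.rand(1, 3, 224, 224)   # [Ni x Ci x Hi x Wi]
with torch.no_grad():
    output = model(dummy)
print(type(output))                  # <class 'collections.OrderedDict'>
print(output['out'].shape)           # torch.Size([1, 21, 224, 224]), i.e. [No x Co x Ho x Wo]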
Now, we are ready to play 🙂
3.2. FCN with Resnet-101 backbone
FCN – Fully Convolutional Networks – are among the first successful attempts at using Neural Networks for the task of Semantic Segmentation. We cover FCNs and some other models in great detail in our upcoming course on Deep Learning with PyTorch. Let us see how to use the model in Torchvision.
3.2.1. Load the model
Let’s load up the FCN!
from torchvision import models

fcn = models.segmentation.fcn_resnet101(pretrained=True).eval()
And that’s it! We have a pretrained FCN model with a Resnet101 backbone. The pretrained=True flag will download the model if it is not already present in the cache. The .eval() method will put the model in inference mode.
3.2.2. Load the Image
Next, let’s get an image! We download an image of a bird directly from a URL and save it. We use PIL to load the image.
from PIL import Image
import matplotlib.pyplot as plt
import torch

!wget -nv https://static.independent.co.uk/s3fs-public/thumbnails/image/2018/04/10/19/pinyon-jay-bird.jpg -O bird.png
img = Image.open('./bird.png')
plt.imshow(img); plt.show()
3.2.3. Pre-process the image
In order to get the image into the right format for inference using the model, we need to pre-process it and normalize it!
So, for the pre-processing steps, we carry out the following.
- Resize the image so that its shorter side is 256 pixels
- CenterCrop it to (224 x 224)
- Convert it to a Tensor – all the values in the image will be scaled so that they lie between [0, 1] instead of the original range of 0 to 255
- Normalize it with the Imagenet-specific values mean = [0.485, 0.456, 0.406] and std = [0.229, 0.224, 0.225]
And lastly, we unsqueeze the image so that it becomes
[1 x C x H x W] from
[C x H x W]. This is required since we need a batch while passing it through the network.
# Apply the transformations needed
import torchvision.transforms as T

trf = T.Compose([T.Resize(256),
                 T.CenterCrop(224),
                 T.ToTensor(),
                 T.Normalize(mean = [0.485, 0.456, 0.406],
                             std = [0.229, 0.224, 0.225])])
inp = trf(img).unsqueeze(0)
Let’s see what the above code cell does.
Torchvision has many useful functions. One of them is transforms, which is used to pre-process images. T.Compose is a function that takes in a list in which each element is a transform, and it returns an object through which we can pass images so that all the required transforms are applied to them.
Let’s take a look at the transforms applied on the images:
T.Resize(256): Resizes the image so that its shorter side is 256 pixels
T.CenterCrop(224): Center-crops the image to a resulting size of 224 x 224
T.ToTensor(): Converts the image to a torch.Tensor and scales the values to [0, 1]
T.Normalize(mean, std): Normalizes the image with the given mean and standard deviation
3.2.4. Forward pass through network
Now that we have the image all preprocessed and ready, let’s pass it through the model and get the output. As we mentioned earlier, the output of the model is an OrderedDict, so we need to take the out key from it to get the output tensor.
# Pass the input through the net
out = fcn(inp)['out']
print (out.shape)
torch.Size([1, 21, 224, 224])
out is the final output of the model. And as we can see, its shape is [1 x 21 x H x W], as discussed earlier. Since the model was trained on 21 classes, the output has 21 channels.

Now what we need to do is convert this 21-channelled output into a 2D image (a 1-channel image), where each pixel of that image corresponds to a class!

This 2D image (of shape [H x W]) will have each pixel corresponding to a class label, so each (x, y) pixel in this 2D image will hold a number between 0 and 20 representing a class.

And how do we get there from [1 x 21 x H x W]? We take the index of the maximum value at each pixel position, which represents the class.
import numpy as np

om = torch.argmax(out.squeeze(), dim=0).detach().cpu().numpy()
print (om.shape)
As we can see, now we have a 2D image where each pixel corresponds to a class. The last thing to do is to take this 2D image and convert it into a segmentation map where each class label is converted into an RGB color, thus helping in easy visualization.
3.2.5. Decode Output
We will use the following function to convert this 2D image to an RGB image where each label is mapped to its corresponding color.
# Define the helper function
def decode_segmap(image, nc=21):

    label_colors = np.array([(0, 0, 0),  # 0=background
                 # 1=aeroplane, 2=bicycle, 3=bird, 4=boat, 5=bottle
                 (128, 0, 0), (0, 128, 0), (128, 128, 0), (0, 0, 128), (128, 0, 128),
                 # 6=bus, 7=car, 8=cat, 9=chair, 10=cow
                 (0, 128, 128), (128, 128, 128), (64, 0, 0), (192, 0, 0), (64, 128, 0),
                 # 11=dining table, 12=dog, 13=horse, 14=motorbike, 15=person
                 (192, 128, 0), (64, 0, 128), (192, 0, 128), (64, 128, 128), (192, 128, 128),
                 # 16=potted plant, 17=sheep, 18=sofa, 19=train, 20=tv/monitor
                 (0, 64, 0), (128, 64, 0), (0, 192, 0), (128, 192, 0), (0, 64, 128)])

    r = np.zeros_like(image).astype(np.uint8)
    g = np.zeros_like(image).astype(np.uint8)
    b = np.zeros_like(image).astype(np.uint8)

    for l in range(0, nc):
        idx = image == l
        r[idx] = label_colors[l, 0]
        g[idx] = label_colors[l, 1]
        b[idx] = label_colors[l, 2]

    rgb = np.stack([r, g, b], axis=2)
    return rgb
Let’s see what we are doing inside this function!
First, the variable
label_colors stores the colors for each of the classes according to the index. So, the color for the first class which is
background is stored at the
0th index of the
label_colors list. The second class, which is
aeroplane, is stored at index
1 and so on.
Now, we have to create an RGB image from the 2D image we have. So, what we do is create empty 2D matrices for all 3 channels.

So, r, g, and b are arrays which will form the RGB channels of the final image, and each of these arrays is of shape [H x W] (which is the same as the shape of the 2D image).

Now, we loop over each class color we stored in label_colors, and we get the indexes in the image where that particular class label is present. Then, for each channel, we put its corresponding color at those pixels where that class label is present.

Finally, we stack the 3 separate channels to form an RGB image.
Okay! Now, let’s use this function to see the final segmented output!
rgb = decode_segmap(om)
plt.imshow(rgb); plt.show()
And there we go! We have the segmented output of the image.
That’s the bird!
Note that the image after segmentation is smaller than the original image as in the preprocessing step the image is resized and cropped.
3.2.6. Final Result
Next, let’s move all this under one function and play with a few more images!
def segment(net, path, show_orig=True):
    img = Image.open(path)
    if show_orig:
        plt.imshow(img); plt.axis('off'); plt.show()
    # Comment the Resize and CenterCrop for better inference results
    trf = T.Compose([T.Resize(256),
                     T.CenterCrop(224),
                     T.ToTensor(),
                     T.Normalize(mean = [0.485, 0.456, 0.406],
                                 std = [0.229, 0.224, 0.225])])
    inp = trf(img).unsqueeze(0)
    out = net(inp)['out']
    om = torch.argmax(out.squeeze(), dim=0).detach().cpu().numpy()
    rgb = decode_segmap(om)
    plt.imshow(rgb); plt.axis('off'); plt.show()
And let’s get a new image!
!wget -nv https://images.pexels.com/photos/1996333/pexels-photo-1996333.jpeg -O horse.png
segment(fcn, './horse.png')
Wasn’t that interesting? Now let’s move on to one of the State-of-the-Art architectures in Semantic Segmentation – DeepLab.
3.3. Semantic Segmentation using DeepLab
DeepLab is a Semantic Segmentation Architecture that came out of Google Brain. Let’s see how we can use it.
dlab = models.segmentation.deeplabv3_resnet101(pretrained=True).eval()
Let’s see how we can perform semantic segmentation on the same image using this model! We will use the same function we defined above.
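As a quick sketch of that call (assuming the horse image downloaded above is still saved as ./horse.png):

segment(dlab, './horse.png')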
So, there you go! You can see that the DeepLab model has segmented the horse almost perfectly!
3.4. Multiple Objects
If we take a more complex image, then we can start to see some differences in the results obtained using both the models.
Let’s try that out!
!wget -nv "https://images.pexels.com/photos/2385051/pexels-photo-2385051.jpeg" -O dog-park.png img = Image.open('./dog-park.png') plt.imshow(img); plt.show() print ('Segmenatation Image on FCN') segment(fcn, path='./dog-park.png', show_orig=False) print ('Segmenatation Image on DeepLabv3') segment(dlab, path='./dog-park.png', show_orig=False)
As you can see, both the models perform quite well! But there are cases where the models fail miserably.
4. Comparison of the two models

Till now, we have seen how the code works and how the outputs look qualitatively. In this section, we will discuss the quantitative aspects of the models and compare them with each other on the basis of the following 3 metrics:
- Inference time on CPU and GPU
- Size of the model
- GPU memory used during inference
4.1. Inference Time
We have used Google Colab to get these numbers, and you can check out the code for the same in the shared Notebooks.
We can see that DeepLab model is slightly slower than FCN.
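The exact benchmarking code is in the shared notebook; the sketch below shows roughly how such a measurement can be done (time_inference here is a hypothetical helper, and the GPU timing assumes a CUDA device is available so that torch.cuda.synchronize() can be called before reading the clock):

import time

def time_inference(model, inp, device, runs=10):
    # Move the model and input to the target device
    model = model.to(device).eval()
    inp = inp.to(device)
    with torch.no_grad():
        model(inp)                       # warm-up pass
        if device == 'cuda':
            torch.cuda.synchronize()
        start = time.time()
        for _ in range(runs):
            model(inp)
        if device == 'cuda':
            torch.cuda.synchronize()
    return (time.time() - start) / runs  # average seconds per forward pass

print('FCN on CPU      :', time_inference(fcn, inp, 'cpu'))
print('DeepLabv3 on CPU:', time_inference(dlab, inp, 'cpu'))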
4.2. Model Size
Model size is the size of the weights file for the model. DeepLab is a slightly bigger model than FCN.
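One simple way to check this yourself (a sketch, not the exact code from the notebook) is to save each model's state_dict to disk and look at the resulting file sizes:

import os

# Save the weights and compare the file sizes in megabytes
torch.save(fcn.state_dict(), 'fcn.pt')
torch.save(dlab.state_dict(), 'dlab.pt')
print('FCN size (MB)      :', os.path.getsize('fcn.pt') / 1e6)
print('DeepLabv3 size (MB):', os.path.getsize('dlab.pt') / 1e6)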
4.3. GPU Memory requirements
We have used an NVIDIA GTX 1080 Ti GPU for this and found that both models take around 1.2GB for a 224×224 sized image.
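If you want to reproduce this on your own GPU, a minimal sketch (assuming a CUDA device and a PyTorch version that provides torch.cuda.reset_peak_memory_stats()) looks like this:

# Rough measurement of peak GPU memory used during a single forward pass
model = fcn.to('cuda').eval()
inp_gpu = inp.to('cuda')
torch.cuda.reset_peak_memory_stats()
with torch.no_grad():
    model(inp_gpu)
print('Peak GPU memory (GB):', torch.cuda.max_memory_allocated() / 1e9)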
We will discuss other computer vision problems using PyTorch and Torchvision in our next posts. Stay tuned!