Tips for training a 3D-Unet model for segmentation tasks.
Semantic segmentation is, in essence, a classification problem: the network has to assign a class label to every pixel of an image. There are many possible ways to do this. Training a U-net seemed simple at first, but in the end it involved a lot of research and experimentation. Here are some tips I learned while training a 3D-Unet for semantic segmentation of lung tissue to identify infection from Covid-19.
What is a U-net?
U-net is a fully convolutional network. The lack of dense (fully connected) layers makes it easy to train and less demanding on computational resources. It comprises two sets of convolutional layers working as an encoder and a decoder. The first set of layers downsamples the image, creating feature maps at progressively lower resolutions. The second set of layers tries to recreate the image at its original resolution through a series of upsampling steps.
Through the series of downsampling steps, the network can detect the most important features without destroying the shapes or textures of objects. The aim of the reconstruction stage is to make up for the spatial detail lost in the encoding stage.
A 3D-Unet handles 3D images with the same U-net structure: a 3D volume can be viewed as a stack of 2D images with additional depth information.
Before training
Before training, make sure that the data you want to train on is normalized and resized for the model. The main purpose of normalization is to make computation efficient by scaling values into the range 0 to 1; it also reduces the influence of noise in the data.
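As a minimal sketch, min-max normalization for a CT volume might look like the following. The clipping window is an assumption on my part (a common choice for lung CT in Hounsfield units), not something fixed by the method; tune it for your own data.

```python
import numpy as np

def normalize_scan(volume, clip_min=-1000.0, clip_max=400.0):
    """Clip a CT volume to an intensity window and scale it to [0, 1].

    The default window is an illustrative choice for lung CT in
    Hounsfield units -- adjust it for your own dataset.
    """
    volume = np.clip(volume.astype(np.float32), clip_min, clip_max)
    return (volume - clip_min) / (clip_max - clip_min)
```

Clipping first also removes extreme outlier intensities, which is one way normalization helps suppress noise.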
Images used for medical image segmentation are high-resolution three-dimensional (3D) volumes. To reduce memory usage on GPUs, we can use a patch-based method, which divides a large image into small patches and trains the model on these patches.
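A bare-bones version of patch extraction could look like this. It is a sketch under simplifying assumptions: patches are non-overlapping, and edge regions that do not fill a whole patch are dropped (a real pipeline would pad the volume or use overlapping strides).

```python
import numpy as np

def extract_patches(volume, patch_size=(64, 64, 64)):
    """Split a 3D volume into non-overlapping patches of patch_size.

    Edge voxels that do not fill a whole patch are dropped for
    simplicity; pad the volume or overlap strides in practice.
    """
    pz, py, px = patch_size
    z, y, x = volume.shape
    patches = []
    for i in range(0, z - pz + 1, pz):
        for j in range(0, y - py + 1, py):
            for k in range(0, x - px + 1, px):
                patches.append(volume[i:i + pz, j:j + py, k:k + px])
    return np.stack(patches)
```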
Choosing the hyperparameters
Hyperparameters are the parameters that need to be set before training a model; they govern the entire training process. They include, but are not limited to, the following.
Activation functions
For all layers except the last, ReLU or ELU is a good choice of activation function. However, since segmentation is a pixelwise classification problem, softmax is the ideal choice for the output layer.
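In Keras, the softmax output layer can be a 1×1×1 convolution, which produces a per-voxel probability distribution over the classes. The class count below is purely illustrative:

```python
from tensorflow.keras import layers

# Hidden layers use ReLU/ELU; the output layer is a 1x1x1 convolution
# with softmax, giving each voxel a probability over the classes.
# num_classes = 4 is an illustrative value, not from the original setup.
num_classes = 4
output_layer = layers.Conv3D(num_classes, kernel_size=1, activation="softmax")
```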
Weight initialization
Weight initialization, as the name suggests, sets the initial weights of the neurons, and it affects how quickly the network trains. In Keras, weights are initialized with Glorot initialization by default. In my implementation I chose He initialization, which goes well with non-linear activation functions such as ReLU and ELU.
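Switching from the Keras default (Glorot) to He initialization is a one-argument change per layer; a sketch:

```python
from tensorflow.keras import layers

# Glorot ("glorot_uniform") is the Keras default kernel initializer.
# Passing "he_normal" switches to He initialization, which pairs well
# with ReLU/ELU activations. Filter count here is illustrative.
conv = layers.Conv3D(
    filters=32,
    kernel_size=3,
    padding="same",
    activation="relu",
    kernel_initializer="he_normal",
)
```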
Loss Functions
In my training I tried many loss functions and found that a combination of loss functions worked better. In semantic segmentation, especially with medical images, the usual choice is the Dice similarity coefficient. The Dice coefficient is a measure of overlap between the prediction and the ground truth, with scores in the range 0 to 1, where 1 means the prediction and the ground truth overlap perfectly:

Dice = 2TP / (2TP + FP + FN)

Here TP represents true positives, FN false negatives and FP false positives. The Dice coefficient increases when there are more true positives relative to the false components.
We can generalize the Dice coefficient with the Tversky index. It has parameters α and β that add a penalty to false negatives and false positives respectively:

Tversky = TP / (TP + α·FN + β·FP)
The Tversky index helps in controlling false classifications. When α = β = 0.5, it reduces to the Dice coefficient. In segmentation tasks, from my research and experience, α = 0.3 and β = 0.7 worked best, controlling false positives and increasing the accuracy of segmentation.
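A minimal TensorFlow sketch of the Tversky index, following the convention used above (α weights false negatives, β weights false positives; note that some papers swap the two):

```python
import tensorflow as tf

def tversky_index(y_true, y_pred, alpha=0.3, beta=0.7, smooth=1e-6):
    """Tversky index: alpha weights false negatives, beta false
    positives (the convention used in the text; the literature
    sometimes swaps them). `smooth` avoids division by zero.
    """
    y_t = tf.cast(tf.reshape(y_true, [-1]), tf.float32)
    y_p = tf.cast(tf.reshape(y_pred, [-1]), tf.float32)
    tp = tf.reduce_sum(y_t * y_p)
    fn = tf.reduce_sum(y_t * (1.0 - y_p))
    fp = tf.reduce_sum((1.0 - y_t) * y_p)
    return (tp + smooth) / (tp + alpha * fn + beta * fp + smooth)
```

With α = β = 0.5 this reduces to the Dice coefficient, as noted above.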
I had a multiclass segmentation problem, so what about categorical crossentropy? Crossentropy roughly quantifies how much the distribution of the ground truth differs from the prediction; for multiclass problems, we can use categorical crossentropy as a loss function.
I used a combination of Tversky loss and categorical crossentropy as the loss function to train my 3D-Unet.
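One way to combine the two terms is a simple sum; a self-contained sketch follows (the Tversky term is defined inline so the snippet runs on its own, and the 1:1 weighting is my assumption, not necessarily the original setup):

```python
import tensorflow as tf

def tversky_loss(y_true, y_pred, alpha=0.3, beta=0.7, smooth=1e-6):
    # Tversky loss: alpha weights false negatives, beta false
    # positives, matching the convention in the text above.
    y_t = tf.cast(tf.reshape(y_true, [-1]), tf.float32)
    y_p = tf.cast(tf.reshape(y_pred, [-1]), tf.float32)
    tp = tf.reduce_sum(y_t * y_p)
    fn = tf.reduce_sum(y_t * (1.0 - y_p))
    fp = tf.reduce_sum((1.0 - y_t) * y_p)
    return 1.0 - (tp + smooth) / (tp + alpha * fn + beta * fp + smooth)

def combined_loss(y_true, y_pred):
    # Equal weighting of the two terms is an illustrative choice;
    # the relative weights are themselves hyperparameters to tune.
    cce = tf.reduce_mean(
        tf.keras.losses.categorical_crossentropy(y_true, y_pred))
    return cce + tversky_loss(y_true, y_pred)
```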
Learning rate and epochs
We can start training with a higher learning rate when we use a learning-rate scheduler or Keras's ReduceLROnPlateau callback. This helps achieve faster convergence in a controlled manner.
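A sketch of the ReduceLROnPlateau setup; the specific factor, patience, and floor values are illustrative starting points, not the original training configuration:

```python
from tensorflow.keras.callbacks import ReduceLROnPlateau

# Start with a relatively high learning rate and let the callback cut
# it when validation loss stops improving. Values are illustrative.
reduce_lr = ReduceLROnPlateau(
    monitor="val_loss",
    factor=0.5,   # halve the learning rate on each plateau
    patience=5,   # epochs with no improvement before reducing
    min_lr=1e-6,  # never go below this rate
)
# model.fit(..., callbacks=[reduce_lr])
```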
Training
Downsampling layers: by applying pooling after the convolutional layers, the image is downsampled.
Upsampling
I used transposed convolutional layers for the decoder. Unlike fixed upsampling, transposed convolutions have trainable weights, so along with upsampling they learn during training how to fill in details in the upsampled image, which is exactly what we want in semantic segmentation. Concatenation with the feature maps at the corresponding encoder level provides the precise spatial locations for these details.
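One decoder step can be sketched as below: upsample with a transposed convolution, concatenate the matching encoder feature map (the skip connection), then convolve. This is a sketch of a single block, not a full 3D-Unet, and the layer sizes are illustrative.

```python
from tensorflow.keras import layers

def decoder_block(x, skip, filters):
    """One decoder step: transposed convolution doubles the spatial
    resolution, then the matching encoder feature map is concatenated
    to recover precise spatial locations before further convolution.
    """
    x = layers.Conv3DTranspose(filters, kernel_size=2, strides=2,
                               padding="same")(x)
    x = layers.Concatenate()([x, skip])
    x = layers.Conv3D(filters, kernel_size=3, padding="same",
                      activation="relu",
                      kernel_initializer="he_normal")(x)
    return x
```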
Finally, training a neural network from scratch to get desirable results can take more than a few epochs. Careful data preparation, the right hyperparameters, and the right loss and callback functions are the key to faster convergence.