Improving Generalizability of CNN models

By Boshika Tara, Ashna Arya, Leonard So


Generalizability is one of the most challenging problems in model performance. It describes how well a model performs on unseen data, and most deep learning models that perform poorly have overfit the training data. Models use inductive bias to generalize beyond the data they were trained on. Convolutional neural networks (CNNs) are a class of neural network models widely used in image classification and recognition tasks. CNNs learn spatial inductive bias very well, which makes them the model of choice for many computer vision and image processing tasks. CNN architectures are built to handle naturally occurring image transformations and changes, though there is debate and some research around whether a CNN's invariance mechanisms actually function the way they should (a topic that deserves its own post).

Due to the volatility of transformations and changes in image datasets, it is important that we do not rely solely on a CNN's built-in invariance mechanisms to achieve high model accuracy on unseen data. Our team learned this firsthand while building our first set of CNN models for image recognition and classification. The image data the team is working with contains a high degree of natural transformation within images. In addition, the corpus has imbalanced classes and poor-quality images. Due to these factors, all of our CNN prototype models exhibited overfitting. In this post, we outline some of the techniques our team is currently using to improve the generalizability of our prototype CNN models.

There are many ways to increase the generalizability of CNN models. In this post we focus largely on the techniques we are utilizing, which also happen to be the two most popular types of solutions: data augmentation and data normalization. Research has shown that changes to the training data have a large impact on overall accuracy, and data augmentation and normalization cover most of the approaches from this data-space perspective.


Data augmentation is a popular technique often utilized to artificially expand the training data space, i.e., the collection of possible data variations the model sees during training, in order to improve generalizability. It is most effective when dealing with small or imbalanced datasets that are insufficient for properly training a representative model. In practice, this often occurs when the purpose of the model is quite specific, leading to a limited dataset. This technique is particularly relevant for image data because data manipulation on images is much more intuitive, interpretable, and varied in terms of its effect on expanding the data space. The two main categories of data augmentation in the image space are basic image manipulation and deep learning-based data augmentation.

Basic image manipulation includes transformations such as random cropping, color space augmentations, translation, reflection, random distortions, and other basic geometric transformations. Intuitively, basic image manipulation is designed to remove biases the model might pick up from where the image data is sourced. For example, random translations can help remove positional bias when the subject appears primarily in the same part of an image, which would otherwise reduce the model's ability to predict accurately on test images where the subject appears in a variety of positions. It should be noted that transformation effects have to be contextualized with the data and the goal the model is trying to achieve.

For example, random flips and reflections on the MNIST dataset might cause training data to contain a flipped image of a 6 that looks like a 9 but has the label of a 6.

We can use TensorFlow to implement basic image manipulation in two ways: as a data preprocessing step, or as a layer inside the model. For preprocessing, one can use TensorFlow methods (tf.image), preprocessing layers, or custom image manipulation functions to apply image transformations, which can then be applied to TensorFlow datasets through map.
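As a minimal sketch of the preprocessing-step approach, the following applies a few tf.image transformations through Dataset.map. The augment function, the specific transformations, and the dummy data are illustrative assumptions, not our production pipeline:

```python
import tensorflow as tf

def augment(image, label):
    # Illustrative augmentations: horizontal flip, brightness jitter,
    # and a random crop after padding, all from tf.image.
    image = tf.image.random_flip_left_right(image)
    image = tf.image.random_brightness(image, max_delta=0.1)
    image = tf.image.resize_with_pad(image, 36, 36)
    image = tf.image.random_crop(image, size=[32, 32, 3])
    return image, label

# Dummy tensors standing in for a real image dataset.
images = tf.random.uniform([8, 32, 32, 3])
labels = tf.zeros([8], dtype=tf.int32)

ds = tf.data.Dataset.from_tensor_slices((images, labels))
ds = ds.map(augment, num_parallel_calls=tf.data.AUTOTUNE).batch(4)
```

Because map runs per element, the same function can be reused across training and validation pipelines (typically with augmentation disabled for validation).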

This separate preprocessing process can only run on CPUs, because TensorFlow reserves GPU operations for model training, so parallelizability is more limited. TensorFlow 2 also provides experimental preprocessing layers that can be added to a model before the training layers to preprocess the data, and one can also create a custom layer with their own image manipulation methods. These preprocessing layers benefit from GPU parallelism and fit more compactly with the model design that TensorFlow uses.
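A sketch of the in-model approach might look like the following. Note that these layers were promoted out of the tf.keras.layers.experimental.preprocessing namespace in TensorFlow 2.6; the layer choices and toy architecture here are illustrative assumptions:

```python
import tensorflow as tf
from tensorflow.keras import layers

# Augmentation layers are active during training and become no-ops
# at inference time.
data_augmentation = tf.keras.Sequential([
    layers.RandomFlip("horizontal"),
    layers.RandomRotation(0.1),
    layers.RandomTranslation(0.1, 0.1),
])

model = tf.keras.Sequential([
    tf.keras.Input(shape=(32, 32, 3)),
    data_augmentation,          # augmentation runs on the GPU with the model
    layers.Conv2D(16, 3, activation="relu"),
    layers.GlobalAveragePooling2D(),
    layers.Dense(10),
])
```

Bundling augmentation into the model this way also guarantees that any exported model applies the same preprocessing it was trained with.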

Another data augmentation approach utilizes machine learning to learn networks that can artificially create data samples resembling the initial training dataset. The most popular approach uses generative adversarial networks (GANs) to increase the robustness of networks to adversarial attacks and to generally improve the generalizability of the network. One can use these networks to search for helpful training data augmentations. The impact of GANs can be quite beneficial, particularly for protection against adversarial test samples. However, GANs as a solution for data augmentation are more appropriate for established, larger datasets and models that need that extra bit of protection against data imbalances, so they are not always a feasible solution.


In addition to data augmentation, normalization is another popular technique often used as part of the data preparation process to improve generalizability. The goal of normalization is to constrain the numeric values in a dataset to a standard range that is easier to process. Essentially, the numeric features in the dataset are rescaled while still maintaining the general distribution and ratios of the source data. Normalization is extremely common in computer vision tasks; in particular, it is often applied to the input, but it can also take place inside the network itself.

Normalizing the input in image classification problems often involves changing the range of pixel intensities to bring the values into a common range. This normalization often takes the form of scaling all pixel values to range between 0 and 1, which can be done by dividing all the pixel values in each image by 255.
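The divide-by-255 scaling can be sketched in a couple of lines; the dummy uint8 batch below stands in for real image data:

```python
import numpy as np

# Dummy 8-bit image batch in place of a real dataset.
images = np.random.randint(0, 256, size=(4, 32, 32, 3), dtype=np.uint8)

# Scale pixel intensities from [0, 255] down to [0.0, 1.0].
normalized = images.astype("float32") / 255.0
```

The cast to float32 matters: integer division would silently truncate every scaled value to zero.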

The ImageDataGenerator class in Keras provides a variety of techniques for scaling pixel values before feeding the images into the model. The three most popular ones are pixel normalization (scale values to range between 0 and 1), centering (scale values to have a zero mean), and standardization (scale values to have a zero mean and unit variance). The code snippet below shows how to configure the generator to perform each type of scaling:
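A minimal configuration sketch for the three scaling options, using the ImageDataGenerator constructor arguments for each (variable names are our own):

```python
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# Pixel normalization: scale values into [0, 1].
normalize_gen = ImageDataGenerator(rescale=1.0 / 255.0)

# Centering: subtract the featurewise mean computed on the training set.
center_gen = ImageDataGenerator(featurewise_center=True)

# Standardization: zero mean and unit variance across the training set.
standardize_gen = ImageDataGenerator(featurewise_center=True,
                                     featurewise_std_normalization=True)
```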

After configuring the generator, the next steps are calculating the mean on the training set and preparing iterators to scale the images:
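Those two steps map onto the generator's fit and flow methods. A sketch with dummy data in place of a real training set:

```python
import numpy as np
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# Dummy training data standing in for real images and labels.
x_train = np.random.randint(0, 256, size=(64, 32, 32, 3)).astype("float32")
y_train = np.random.randint(0, 10, size=(64,))

datagen = ImageDataGenerator(featurewise_center=True,
                             featurewise_std_normalization=True)

# fit() computes the mean and std over the training set; this is
# required before using any featurewise option.
datagen.fit(x_train)

# flow() prepares an iterator that yields scaled batches.
train_iter = datagen.flow(x_train, y_train, batch_size=16)
x_batch, y_batch = next(train_iter)
```

The same fitted statistics should also be applied to the validation and test sets, so that all splits are scaled identically.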

While it is common practice to normalize the input, normalization can also occur inside the network. Internal normalization is particularly useful for training deeper networks with many layers, as they are sensitive to the initial random weights and the configuration of the learning algorithm. As the weights are updated after each mini-batch during training, the distribution of inputs to layers deep in the network may change. This change in the distribution of inputs to layers is referred to as "internal covariate shift" and can cause the learning algorithm to forever chase a moving target. The most common technique used to mitigate this effect is batch normalization, which standardizes the inputs to a layer for each mini-batch. It is typically applied after the convolution and before the activation.
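The conv-then-norm-then-activation ordering can be sketched as follows; the toy architecture is an illustrative assumption:

```python
import tensorflow as tf
from tensorflow.keras import layers

model = tf.keras.Sequential([
    tf.keras.Input(shape=(32, 32, 3)),
    # Bias is redundant when batch norm immediately follows the conv,
    # since the norm's beta parameter plays the same role.
    layers.Conv2D(16, 3, use_bias=False),
    layers.BatchNormalization(),     # after the convolution...
    layers.Activation("relu"),       # ...and before the activation
    layers.GlobalAveragePooling2D(),
    layers.Dense(10),
])
```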

Batch normalization has the effect of stabilizing the learning process and dramatically reducing the number of training epochs needed to train a deep network. It not only provides regularization, reducing generalization error, but also accelerates training. One drawback of batch norm is its dependency on batch size for accurate statistical estimation. A newer technique called Batch Renormalization addresses this by introducing two correction parameters that gradually shift the layer toward using moving-average statistics rather than per-batch statistics.
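In Keras, Batch Renormalization is exposed through the renorm flag on the standard BatchNormalization layer rather than as a separate layer; the clipping values below are illustrative assumptions, not recommendations:

```python
import tensorflow as tf

# renorm=True enables Batch Renormalization; renorm_clipping bounds
# the correction parameters (r, d) so training starts close to
# ordinary batch norm.
bn_renorm = tf.keras.layers.BatchNormalization(
    renorm=True,
    renorm_clipping={"rmax": 3.0, "rmin": 1.0 / 3.0, "dmax": 5.0},
)

out = bn_renorm(tf.random.uniform([4, 8]), training=True)
```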


Data augmentation and normalization are two prevalent techniques used to improve generalizability; two other methods commonly used to prevent overfitting are regularization and reducing architectural complexity. Coupled with the two techniques covered above, these methods can be effective at improving the generalizability of machine learning models.

