Figuring out why your ML model is consistently less accurate on certain classes than on others might help you increase not only its overall accuracy but also its adversarial robustness.
Machine Learning (ML) models, and Deep Learning (DL) models in particular, can achieve impressive results, especially on unstructured data such as images and text. However, the (supervised) ML framework has a fundamental limitation: the distributions we end up applying ML models to are NOT always the ones we trained them on. This leads to models that seem accurate but are, in fact, brittle.
Adversarial attacks exploit this brittleness by adding imperceptible perturbations to the input images that force the model to flip its predictions.
For example, lighting a “Stop” sign in a specific way can make a traffic-sign model predict it as a “Speed 30” sign (Fig. 1). Similarly, just by rotating an image or by replacing some words in the text of a medical diagnosis with their synonyms, you can fool models into giving a wrong medical diagnosis or a wrong risk score, respectively (Fig. 2).
Since a model's adversarial robustness is tied to its security and safety, it is important to be aware that adversarial attacks exist and to understand their implications.
In this article, we will use a model trained on the CIFAR10 public dataset to:
📌 Investigate the intuition that the model’s inability to correctly predict an image also leads to higher susceptibility to adversarial attacks.
📌 Measure class disparities, i.e. compare the model's performance across the 10 classes.
📌 Conduct a root cause analysis to pinpoint the causes of these class differences, which will hopefully help us not only fix the misclassifications but also increase the adversarial robustness of the model.
“cat”, “bird” and “dog” classes are harder to correctly classify and easier to attack
We trained a simple 20-layer network based on the well-known ResNet architecture, which achieves 89.4% accuracy on the validation set. We then plotted the per-class miss rate to check for disparities between the classes (Fig. 3).
The plot shows that the “cat”, “bird” and “dog” classes are harder to classify correctly than the rest of the classes.
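The per-class miss rate is simply the fraction of each class's validation images that the model gets wrong. A minimal NumPy sketch, assuming you already have integer arrays of true and predicted labels (the helper name `per_class_miss_rate` and the toy labels are illustrative, not taken from the original analysis):

```python
import numpy as np

# Standard CIFAR10 label order.
CIFAR10_CLASSES = ["airplane", "automobile", "bird", "cat", "deer",
                   "dog", "frog", "horse", "ship", "truck"]


def per_class_miss_rate(y_true, y_pred, num_classes=10):
    """Return an array whose entry c is the fraction of class-c samples misclassified."""
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    # Count the samples of each true class and how many of them were wrong.
    totals = np.bincount(y_true, minlength=num_classes)
    errors = np.bincount(y_true[y_pred != y_true], minlength=num_classes)
    # Classes with no samples get a miss rate of 0 instead of a division by zero.
    return np.divide(errors, totals, out=np.zeros(num_classes), where=totals > 0)


# Toy usage with made-up labels (not the article's data).
y_true = np.array([3, 3, 3, 5, 5, 0, 0, 0])
y_pred = np.array([5, 3, 5, 5, 3, 0, 0, 0])
rates = per_class_miss_rate(y_true, y_pred)
for idx in np.argsort(rates)[::-1][:2]:
    print(f"{CIFAR10_CLASSES[idx]}: {rates[idx]:.2f}")  # prints "cat: 0.67", then "dog: 0.50"
```

On the real validation set, sorting the classes by this rate gives the ranking shown in the per-class plot.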
We then applied two kinds of adversarial attacks:
An untargeted attack, which is considered successful when the predicted class label changes (to any other label).
A targeted attack with the least-likely target, which is considered successful only when the predicted class label changes specifically to the label in which the model has the least confidence for that instance.
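The article does not say which attack algorithm was used. Purely as an illustration, here is a one-step Fast Gradient Sign Method (FGSM) sketch in NumPy against a toy linear softmax classifier, which stands in for the ResNet (with a real network, the input gradient would come from the framework's autodiff). The untargeted variant moves the input *up* the loss of the true label; the least-likely variant picks the class with the lowest predicted probability and moves the input *down* the loss of that target. The function names, the pixel range [0, 1] and the `eps` values are assumptions.

```python
import numpy as np


def softmax(z):
    # Subtract the max for numerical stability.
    e = np.exp(z - z.max())
    return e / e.sum()


def cross_entropy(W, b, x, label):
    # Loss of a linear softmax model: -log softmax(Wx + b)[label].
    return -np.log(softmax(W @ x + b)[label])


def input_gradient(W, b, x, label):
    """Gradient of cross_entropy with respect to the input x: W^T (p - onehot(label))."""
    p = softmax(W @ x + b)
    p[label] -= 1.0
    return W.T @ p


def fgsm_untargeted(W, b, x, y_true, eps):
    # Step each pixel by eps in the direction that increases the true-label loss.
    x_adv = x + eps * np.sign(input_gradient(W, b, x, y_true))
    return np.clip(x_adv, 0.0, 1.0)  # stay inside the assumed pixel range


def fgsm_least_likely(W, b, x, eps):
    # Target = the class the model is least confident about for this input.
    target = int(np.argmin(softmax(W @ x + b)))
    # Step in the direction that decreases the loss of the target class.
    x_adv = x - eps * np.sign(input_gradient(W, b, x, target))
    return np.clip(x_adv, 0.0, 1.0), target


# Toy usage: random weights stand in for a trained model.
rng = np.random.default_rng(0)
W, b = rng.normal(size=(10, 32)), np.zeros(10)
x = np.full(32, 0.5)                    # a flat "image" of 32 pixels
y = int(np.argmax(softmax(W @ x + b)))  # treat the clean prediction as the label
x_adv = fgsm_untargeted(W, b, x, y, eps=0.1)
x_ll, target = fgsm_least_likely(W, b, x, eps=0.1)
```

A single FGSM step is the simplest choice; iterative variants repeat smaller steps and usually reach higher success rates, especially for the harder least-likely target.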
Afterwards, we plotted the per-class attack-success rate, i.e. the percentage of successful attacks within each class (note: each class in the test set has 1,000 images). We observe that the most successfully attacked classes are the same ones that are most often misclassified (Fig. 4).
This is intuitive and, to some extent, expected: the fact that the model misclassifies some instances of a class suggests that it relies on features that are not very relevant to that class, so adding perturbations makes the model’s job even harder and the attacker’s job easier.
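The attack-success rate per class can be computed the same way as the miss rate. The hedged sketch below encodes the two success criteria described above (untargeted: the prediction changed at all; least-likely: the prediction became exactly the chosen target), groups the successes by each image's true class, and checks whether the most-attacked and most-missed classes coincide. The helper names and the choice of k are hypothetical, and the sketch uses every image it is given, since the article does not say whether only correctly classified images were attacked.

```python
import numpy as np


def attack_success(clean_pred, adv_pred, targets=None):
    """Boolean mask of successful attacks.

    Untargeted (targets is None): success if the prediction changed at all.
    Targeted: success only if the new prediction equals the chosen target.
    """
    clean_pred, adv_pred = np.asarray(clean_pred), np.asarray(adv_pred)
    if targets is None:
        return adv_pred != clean_pred
    return adv_pred == np.asarray(targets)


def per_class_success_rate(y_true, success, num_classes=10):
    # Group successes by the true class of each attacked image.
    y_true = np.asarray(y_true)
    totals = np.bincount(y_true, minlength=num_classes)
    hits = np.bincount(y_true[np.asarray(success, dtype=bool)], minlength=num_classes)
    return np.divide(hits, totals, out=np.zeros(num_classes), where=totals > 0)


def top_k_overlap(rates_a, rates_b, k=3):
    """Classes that rank among the k highest values of both rate arrays."""
    top_a = set(np.argsort(rates_a)[::-1][:k].tolist())
    top_b = set(np.argsort(rates_b)[::-1][:k].tolist())
    return sorted(top_a & top_b)
```

Comparing `top_k_overlap(miss_rates, success_rates)` for each attack type is one simple way to quantify the overlap the figure shows visually.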
Root cause analysis
We identified three possible root causes for the class disparities in the model’s predictions:
1. Mislabeled/Confusing Training Data
Data collection is probably the most costly and time-consuming part of most machine learning projects, so it is perfectly reasonable to expect this arduous process to introduce some mistakes. We discovered that the CIFAR10 training set contains some images that are either mislabeled or confusing even for humans (Fig. 6).
Data poisoning is also linked to adversarial attacks, although in that case the poisoned data are carefully crafted by an attacker. It has been shown that even a single poisoned image can affect the predictions on multiple test images (Fig. 7).
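The article does not describe how the suspicious training images were found. One common heuristic, shown here as a hedged sketch rather than the authors' procedure, is to obtain out-of-sample predicted probabilities for every training image (for example via cross-validation) and flag images whose recorded label receives a low probability while another class is predicted confidently. The function name and both thresholds are arbitrary assumptions, and flagged images still need a human review.

```python
import numpy as np


def flag_suspicious_labels(probs, labels, label_prob_max=0.1, other_prob_min=0.8):
    """Return indices of samples whose recorded label looks doubtful.

    probs:  (n_samples, n_classes) out-of-sample predicted probabilities.
    labels: (n_samples,) integer labels as recorded in the dataset.
    A sample is flagged when the model gives its recorded label at most
    `label_prob_max` and some other class at least `other_prob_min`.
    """
    probs = np.asarray(probs, dtype=float)
    labels = np.asarray(labels)
    rows = np.arange(len(labels))
    p_given = probs[rows, labels]
    # Highest probability among the classes other than the recorded one.
    others = probs.copy()
    others[rows, labels] = -np.inf
    p_other = others.max(axis=1)
    mask = (p_given <= label_prob_max) & (p_other >= other_prob_min)
    # Most suspicious first: lowest probability for the recorded label.
    idx = np.flatnonzero(mask)
    return idx[np.argsort(p_given[idx])]
```

Using out-of-sample probabilities matters: a model that has memorized a wrong label during training will usually agree with it on that same image.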
2. Is this a “cat” or a “dog”?
The “cat” class has the highest miss rate of all the classes, while the “dog” class has the third highest.
Since these animals share several features (four legs, ears, a tail), we suspected that our model either could not extract meaningful features to distinguish them or had learned better “dog” features than “cat” features.
Saliency map explanations support this suspicion: the model has not learned some distinctive cat characteristics, such as the pointy ears and nose, and instead focuses on the animal’s whole face or body (Fig. 8).
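The article does not specify which saliency method produced these maps. As an illustration, the simplest variant, the vanilla gradient map (the absolute gradient of a class score with respect to the input pixels, reduced over colour channels), is sketched below for a toy two-layer ReLU network in NumPy with a hand-written backward pass. With a real ResNet, the same gradient would come from the framework's autodiff. The network, its shapes and the function names are assumptions.

```python
import numpy as np


def forward(params, x):
    """Toy two-layer ReLU network: returns class scores and hidden pre-activations."""
    W1, b1, W2, b2 = params
    pre = W1 @ x + b1
    return W2 @ np.maximum(pre, 0.0) + b2, pre


def vanilla_saliency(params, image, cls):
    """|d score_cls / d pixel|, reduced over colour channels by taking the max.

    image: (H, W, C) array; it is flattened before the forward pass.
    """
    W1, _, W2, _ = params
    x = image.reshape(-1)
    _, pre = forward(params, x)
    # Backpropagate the class score: through W2, the ReLU mask, then W1.
    grad_hidden = W2[cls] * (pre > 0)
    grad_x = W1.T @ grad_hidden
    return np.abs(grad_x).reshape(image.shape).max(axis=-1)
```

Overlaying the resulting (H, W) map on the image shows which pixels most influence the score; for a cat image, high values on the ears and nose would indicate the model uses those cues.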
3. Is this a “bird” or an “airplane”?
A similar situation occurs with the “bird” and “airplane” classes; in this case, the blue background confuses the model (Fig. 9).
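Both the cat/dog and the bird/airplane findings concern pairs of classes that the model mixes up. A quick way to surface such pairs, not necessarily the one used for this analysis, is to rank the off-diagonal cells of the confusion matrix. The NumPy sketch below uses hypothetical helper names; in the standard CIFAR10 label order, index 0 is “airplane”, 2 is “bird”, 3 is “cat” and 5 is “dog”.

```python
import numpy as np


def confusion_matrix(y_true, y_pred, num_classes=10):
    """cm[i, j] = number of samples with true class i predicted as class j."""
    cm = np.zeros((num_classes, num_classes), dtype=int)
    # np.add.at accumulates correctly even when (i, j) pairs repeat.
    np.add.at(cm, (np.asarray(y_true), np.asarray(y_pred)), 1)
    return cm


def most_confused_pairs(cm, top=3):
    """Return (true, predicted, count) triples for the largest off-diagonal cells."""
    off = cm.copy()
    np.fill_diagonal(off, 0)  # ignore correct predictions
    flat = np.argsort(off, axis=None)[::-1][:top]
    rows, cols = np.unravel_index(flat, off.shape)
    return [(int(r), int(c), int(off[r, c])) for r, c in zip(rows, cols)]
```

On real data you would call `most_confused_pairs(confusion_matrix(y_true, y_pred))` with the validation labels and predictions; pairs such as (3, 5) and (5, 3) near the top would match the cat/dog confusion described above.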
👉 Good data means a good model: spend some time probing your data and try to detect if there are any systematic errors in your training set.
👉 Use explanation methods as a debugger to understand why your model misses certain groups of instances more often than others.
👉 Apply adversarial attacks to test how vulnerable your model is.