Class distribution plots for machine learning in R and ggplot2


During the feature selection phase of a machine learning classification project, it is often useful to visualize the class distributions to get a sense of which features best separate the classes and how a model might use each feature to separate them.

I have experimented with many different ways of visualizing these densities in the R programming language using the ggplot2 package. Here I document a couple of approaches that I've found most helpful for this type of analysis.

To start, the class distributions can be plotted with a very straightforward one-line call to the ggplot() function.
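A minimal sketch of such a call is shown below. It assumes the `Pima.tr` data set from the MASS package (which ships with R) as a stand-in for the article's diabetes data; there the glucose column is `glu` and the class column `type` has levels "No" and "Yes". The exact data set and column names used in the original are not shown here.

```r
library(ggplot2)

# Assumption: MASS::Pima.tr stands in for the article's diabetes data
data(Pima.tr, package = "MASS")

# One density panel per class, stacked vertically so the glucose axes line up
p <- ggplot(Pima.tr, aes(x = glu)) + geom_density() + facet_grid(type ~ .)
```

Calling `print(p)` (or typing `p` at the console) draws the chart.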

This code produces the following chart, which in this case clearly shows how the two classes (positive or negative for diabetes) separate with respect to blood glucose level. While blood glucose level doesn't completely separate the classes, it is clear that higher blood glucose (above 150) is strongly associated with the diabetes class.
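The visual reading of the chart can be checked numerically by comparing the share of diabetic records on either side of the 150 threshold. This sketch again assumes MASS::Pima.tr (columns `glu` and `type`) as a stand-in, so the exact proportions may differ from those in the article's data set.

```r
data(Pima.tr, package = "MASS")

yes  <- Pima.tr$type == "Yes"  # TRUE for diabetic records
high <- Pima.tr$glu > 150      # TRUE above the glucose threshold

# Proportion of diabetic records below/at and above the threshold
share_yes <- c(
  "glu <= 150" = mean(yes[!high]),
  "glu > 150"  = mean(yes[high])
)
share_yes
```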

Sometimes the differences between the class distributions are much more subtle than in the example above. In these cases, superimposing the two class distributions on top of one another helps illustrate exactly where the two classes differ with respect to a feature. The following code does this.
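A minimal sketch of the superimposed version, under the same assumption that MASS::Pima.tr stands in for the original data: mapping the class to `fill` draws one density per class on shared axes, and a partial `alpha` keeps the overlapping region visible.

```r
library(ggplot2)

# Assumption: MASS::Pima.tr stands in for the article's diabetes data
data(Pima.tr, package = "MASS")

# Both class densities on the same axes; transparency shows where they overlap
p <- ggplot(Pima.tr, aes(x = glu, fill = type)) +
  geom_density(alpha = 0.4) +
  labs(x = "Plasma glucose concentration", fill = "Diabetes")
```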

This produces the following chart.

Superimposed blood glucose density plot for the two classes: has diabetes (positive) and does not have diabetes (negative).

Continuous XOR problem. If you look at the distribution of classes with respect to x1 or x2 on a standalone basis, you would not be able to tell that the classes are separable. However, when visualized together, separability becomes apparent.

It is also useful to see how combinations of features interact. For example, in the continuous XOR example to the right, you would not see any class separation when looking at the class densities on either the x1 or x2 dimension alone as we were able to do with blood glucose level and diabetes above. You need to look at them together to see that they can be separated.
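The sketch below illustrates why the single-feature view fails on this kind of problem. It simulates a hypothetical continuous XOR data set (the article's own XOR data is not shown): points are uniform on the square from -1 to 1, and class "A" is assigned when exactly one coordinate is positive. Given either class, each coordinate is still uniform, so the two class densities on x1 and on x2 overlap almost completely.

```r
library(ggplot2)

# Hypothetical continuous XOR data
set.seed(1)
n  <- 2000
x1 <- runif(n, -1, 1)
x2 <- runif(n, -1, 1)
xor_df <- data.frame(x1 = x1, x2 = x2,
                     class = ifelse(xor(x1 > 0, x2 > 0), "A", "B"))

# Stack x1 and x2 into long format so each feature gets its own panel
long <- rbind(
  data.frame(feature = "x1", value = xor_df$x1, class = xor_df$class),
  data.frame(feature = "x2", value = xor_df$x2, class = xor_df$class)
)

# Superimposed class densities per feature: nearly identical curves
p <- ggplot(long, aes(x = value, fill = class)) +
  geom_density(alpha = 0.4) +
  facet_wrap(~ feature)
```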

One approach to solving this issue is to use a scatter plot like the one to the right. However, when you do this for more than a small number of data points, the resulting chart is often too noisy to reveal any relationship.
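For reference, a scatter plot of the same kind of hypothetical XOR data takes only a few lines. The data here is simulated and perfectly clean, so the four quadrants show up clearly; with real, overlapping classes and many more points, overplotting is what makes this view hard to read, and small, semi-transparent points only partly help.

```r
library(ggplot2)

# Hypothetical continuous XOR data: class "A" when exactly one coordinate is positive
set.seed(1)
n <- 2000
xor_df <- data.frame(x1 = runif(n, -1, 1), x2 = runif(n, -1, 1))
xor_df$class <- ifelse(xor(xor_df$x1 > 0, xor_df$x2 > 0), "A", "B")

# Small, semi-transparent points to reduce overplotting
p <- ggplot(xor_df, aes(x1, x2, colour = class)) +
  geom_point(size = 0.8, alpha = 0.5)
```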

A better approach is to use the data to estimate a two-dimensional density function for each class. The difference between these two densities represents the relative difference in class density at each point in the 2D space. That is, where the difference is zero, the two classes are equally dense at that point, and where it is non-zero, one class is denser than the other.

This composite density function can then be plotted using a contour plot, a heatmap, or another type of 3D representation of the function. I've found that the most attractive way to visualize it is through a combined heatmap and contour plot with a custom color scheme. One tricky part of getting this to come together properly is making sure that a composite density of zero always maps to the color in the middle of a diverging palette. For this, I give credit to Baptiste Auguie, whose code I've borrowed for the script below.
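The sketch below follows the same idea but is not Baptiste Auguie's script, which is not reproduced here. It assumes MASS::Pima.tr (columns `glu`, `bmi`, `type`) as the data, estimates one 2D kernel density per class with `MASS::kde2d()` on a shared grid, and subtracts them. To keep zero at the center of the diverging palette, it swaps in ggplot2's `scale_fill_gradient2()` with `midpoint = 0` and symmetric limits. Because each `kde2d()` estimate is normalized on its own, the difference compares class-conditional densities and ignores how many records each class has.

```r
library(ggplot2)
library(MASS)  # kde2d(); also provides the Pima.tr data
data(Pima.tr, package = "MASS")

n_grid <- 100
# Shared limits so both class densities are evaluated on the same grid
lims <- c(range(Pima.tr$glu), range(Pima.tr$bmi))

# One 2D kernel density estimate per class ("No" and "Yes")
dens <- lapply(split(Pima.tr, Pima.tr$type), function(d)
  kde2d(d$glu, d$bmi, n = n_grid, lims = lims))

# Composite surface: positive where diabetic records are relatively denser.
# expand.grid varies glu fastest, matching the column-major order of z.
grid <- expand.grid(glu = dens$Yes$x, bmi = dens$Yes$y)
grid$diff <- as.vector(dens$Yes$z - dens$No$z)

# Symmetric limits put zero exactly at the palette midpoint (white)
m <- max(abs(grid$diff))

p <- ggplot(grid, aes(glu, bmi)) +
  geom_raster(aes(fill = diff)) +
  geom_contour(aes(z = diff), colour = "grey30", bins = 10) +
  scale_fill_gradient2(low = "steelblue", mid = "white", high = "firebrick",
                       midpoint = 0, limits = c(-m, m)) +
  labs(x = "Plasma glucose", y = "Body mass index",
       fill = "Density\n(Yes - No)")
```

Using symmetric limits, rather than relying on `midpoint` alone, keeps equal positive and negative differences equally saturated, so neither class looks artificially dominant.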

This yields the following chart:

As we saw before, higher glucose values are correlated with a higher likelihood of diabetes. However, if body mass is lower or higher than average, the relationship is not as strong.