If you can’t explain it simply, you don’t understand it well enough. — Albert Einstein
Disclaimer: This article draws and expands upon material from (1) Christoph Molnar’s excellent book on Interpretable Machine Learning which I definitely recommend to the curious reader, (2) a deep learning visualization workshop from Harvard ComputeFest 2020, as well as (3) material from CS282R at Harvard University taught by Ike Lage and Hima Lakkaraju, who are both prominent researchers in the field of interpretability and explainability. This article is meant to condense and summarize the field of interpretable machine learning to the average data scientist and to stimulate interest in the subject.
Machine learning systems are increasingly employed in complex, high-stakes settings such as medicine (e.g. radiology, drug development), financial technology (e.g. stock price prediction, digital financial advisors), and even law (e.g. case summarization, litigation prediction). Despite this increased use, we still lack sufficient techniques for explaining and interpreting the decisions of these algorithms, deep learning models in particular. This can be very problematic in areas where the decisions of algorithms must be explainable or attributable to certain features due to laws or regulations (such as the right to explanation), or where accountability is required.
The need for algorithmic accountability has been highlighted many times. Among the most notable cases are Google’s facial recognition algorithm, which labeled some black people as gorillas, and Uber’s self-driving car, which ran a stop sign. Because Google was unable to fix the algorithm and remove the bias that caused the issue, it resolved the problem by removing words relating to monkeys from the Google Photos search engine. This illustrates the alleged black box nature of many machine learning algorithms.
The black box problem is predominantly associated with the supervised machine learning paradigm due to its predictive nature.
Accuracy alone is no longer enough.
Academics in deep learning are acutely aware of this interpretability and explainability problem. Whilst some argue that these models are essentially black boxes, several techniques have been developed in recent years for visualizing aspects of deep neural networks, such as the features and representations they have learned. The term info-besity has been thrown around to refer to the difficulty of providing transparency when decisions are made on the basis of many individual features, due to an overload of information. The field of interpretability and explainability in machine learning has exploded since 2015 and there are now dozens of papers on the subject, some of which can be found in the references.
As we will see in this article, these visualization techniques are not sufficient for completely explaining the complex representations learned by deep learning algorithms, but hopefully, you will be convinced that the black box interpretation of deep learning is not true — we just need better techniques to be able to understand and interpret these models.
The Black Box
All algorithms in machine learning are to some extent black boxes. One of the key ideas of machine learning is that the models are data-driven: the model is configured from the data. This fundamentally leads us to problems such as (1) how we should interpret the models, (2) how to ensure they are transparent in their decision making, and (3) how to make sure the results of the algorithm are fair and statistically valid.
For something like linear regression, the models are very well understood and highly interpretable. When we move to something like a support vector machine (SVM) or a random forest model, things get a bit more difficult. In this sense, there is no white or black box algorithm in machine learning; interpretability exists on a spectrum, or as a ‘gray box’ of varying grayness.
It just so happens that at the far end of our ‘gray’ area is the neural network. Even further in this gray area is the deep neural network. When you have a deep neural network with 1.5 billion parameters, as the GPT-2 language model has, it becomes extremely difficult to interpret the representations that the model has learned.
In February 2020, Microsoft released the largest deep neural network in existence (probably not for long), Turing-NLG. This network contains 17 billion parameters, which is around 1/5th of the 85 billion neurons present in the human brain (although in a neural network, parameters represent connections, of which there are ~100 trillion in the human brain). Clearly, interpreting a 17 billion parameter neural network will be incredibly difficult, but its performance may be far superior to other models because it can be trained on huge amounts of data without becoming saturated — this is the idea that more complex representations can be stored by a model with a greater number of parameters.
Obviously, the representations are there, we just do not understand them fully, and thus we must come up with better techniques to be able to interpret the models. Sadly, it is more difficult than reading coefficients as one is able to do in linear regression!
Often, we do not care how an algorithm came to a specific decision, particularly when it is deployed in a low-risk environment. In these scenarios, our selection of algorithms is not constrained by any interpretability requirement. However, if interpretability is important, as it often is in high-risk environments, then we may have to accept a tradeoff between accuracy and interpretability.
So what techniques are available to help us better interpret and understand our models? It turns out there are many of these, and it is helpful to make a distinction between what these different types of techniques help us to examine.
Local vs. Global
Techniques can be local, to help us study a small portion of the network, as is the case when looking at individual filters in a neural network.
Techniques can be global, allowing us to build up a better picture of the model as a whole. This could include visualizations of the weight distributions in a deep neural network, or visualizations of how inputs propagate through the layers of the network.
Model-Specific vs. Model-Agnostic
A technique that is highly model-specific is only suitable for use with a single type of model. For example, layer visualization is only applicable to neural networks, whereas partial dependency plots can be utilized for many different types of models and would be described as model-agnostic.
Model-specific techniques generally involve examining the structure of algorithms or intermediate representations, whereas model-agnostic techniques generally involve examining the input or output data distribution.
I will discuss all of the above techniques throughout this article, but will also discuss where and how they can be put to use to help provide us with insight into our models.
Being Right for the Right Reasons
One of the issues that arises from our lack of model explainability is that we do not know what the model has actually learned from its training data. This is best illustrated with an apocryphal example (there is some debate as to the truth of the story, but the lessons we can draw from it are nonetheless valuable).
Hide and Seek
According to AI folklore, in the 1960s, the U.S. Army was interested in developing a neural network algorithm that was able to detect tanks in images. Researchers developed an algorithm that was able to do this with remarkable accuracy, and everyone was pretty happy with the result.
However, when the algorithm was tested on additional images, it performed very poorly. This confused the researchers as the results had been so positive during development. After a while of everyone scratching their heads, one of the researchers noticed that when looking at the two sets of images, the sky was darker in one set of images than the other.
It became clear that the algorithm had not actually learned to detect tanks that were camouflaged, but instead was looking at the brightness of the sky!
Whilst this story exaggerates one of the common criticisms of deep learning, there is truth to the fact that in a neural network, and especially a deep neural network, you do not really know what the model is learning.
This powerful criticism and the increasing importance of deep learning in academia and industry are what have led to an increased focus on interpretability and explainability. If an industry professional cannot convince their client that they understand what the model they built is doing, should it really be used when there are large risks, such as financial losses or people’s lives?
Interpretability
At this point, you might be asking yourself how visualization can help us to interpret a model, given that there may be an infinite number of viable interpretations. Defining and measuring what interpretability means is not a trivial task, and there is little consensus on how to evaluate it.
There is no mathematical definition of interpretability. Two proposed definitions in the literature are:
Interpretability is the degree to which a human can understand the cause of a decision. — Tim Miller
Interpretability is the degree to which a human can consistently predict the model’s result. — Been Kim
The higher the interpretability of a machine learning model, the easier it is for someone to comprehend why certain decisions or predictions have been made. A model is more interpretable than another model if its decisions are easier for a human to comprehend than decisions from the other model. One way we can start to evaluate model interpretability is via a quantifiable proxy.
A proxy is something that is highly correlated with what we are interested in studying but is fundamentally different from the object of interest. Proxies tend to be simpler to measure than the object of interest, or like in this case, just measurable — whereas our object of interest (like interpretability) may not be.
The idea of proxies is prevalent in many fields, one of which is psychology, where they are used to measure abstract concepts. The most famous proxy is probably the intelligence quotient (IQ), which is a proxy for intelligence. Whilst the correlation between IQ and intelligence is not 100%, it is high enough that we can gain some useful information about intelligence from measuring IQ. There is no known way to measure intelligence directly.
An algorithm that uses dimensionality reduction to allow us to visualize high-dimensional data in a lower-dimensional space provides us with a proxy to visualize the data distribution. Similarly, a set of training images provides us with a proxy of the full data distribution of interest, but will inevitably be somewhat different to the true distribution (if you did a good job constructing the training set, it should not differ too much from a given test set).
What about post-hoc explanations?
Post-hoc explanations (or explaining after the fact) can be useful but sometimes misleading. These merely provide a plausible rationalization for the algorithmic behavior of a black box, not necessarily concrete evidence and so should be used cautiously. Post-hoc rationalization can be done with quantifiable proxies, and some of the techniques we will discuss do this.
Choosing a Visualization
Designing a visualization requires us to think about the following factors:
- The audience to whom we are presenting (the who): is this being done for debugging purposes? To convince a client? To convince a peer-reviewer for a research article?
- The objective of the visualization (the what): are we trying to understand the inputs (such as whether EXIF metadata from an image is being read correctly so that an image does not enter a CNN sideways), outputs, or parameter distributions of our model? Are we interested in how inputs evolve through the network or in a static feature of the network like a feature map or filter?
- The model being developed (the how): clearly, if you are not using a neural network, you cannot visualize feature maps of a network layer. Similarly, feature importance can be used for some models, such as XGBoost or Random Forest algorithms, but not others. Thus the model selection inherently biases what techniques can be used, and some techniques are more general and versatile than others. Developing multiple models can provide more versatility in what can be examined.
Deep models present unique challenges for visualization: we can answer the same questions about the model, but our method of interrogation must change! Because of the importance of this, we will mainly focus on deep learning visualization for the rest of the article.
Subfields of Deep Learning Visualization
There are largely three subfields of deep learning visualization literature:
- Interpretability & Explainability: helping to understand how deep learning models make decisions and their learned representations.
- Debugging & Improving: helping model curators and developers construct and troubleshoot their models, with the hope of expediting the iterative experimentation process to ultimately improve performance.
- Teaching Deep Learning: helping to educate amateur users about artificial intelligence — more specifically, machine learning.
Why is interpreting a neural network so difficult?
To understand why interpreting a neural network is difficult and non-intuitive, we have to understand what the network is doing to our data.
Essentially, the data we pass to the input layer — this could be an image or a set of relevant features for predicting a variable — can be plotted to form some complex distribution like that shown in the image below (this is only a 2D representation, imagine it in 1000 dimensions).
If we ran this data through a linear classifier, the model would try its best to separate the data, but since we are limited to a hypothesis class that only contains linear functions, our model will perform poorly since a large portion of the data is not linearly separable.
This is where neural networks come in. The neural network is a very special function. It has been proven that a neural network with a single hidden layer and a suitable non-linear activation function can approximate any continuous function on a compact (closed and bounded) input domain to arbitrary accuracy, as long as we have enough nodes in the hidden layer. This is known as the universal approximation theorem.
It turns out that the more nodes we have, the larger the class of functions we can represent. If we have a network with only ten nodes and are trying to use it to classify a million images, the network will quickly saturate and reach maximum capacity. If we have 10 million parameters, it will be able to learn a much better representation of the data, as the number of non-linear transformations increases. We say this model has a larger model capacity.
People use deep neural networks instead of a single layer because the number of neurons needed in a single-layer network increases exponentially with model capacity. The abstraction of hidden layers significantly reduces the need for more neurons, but this comes at a cost for interpretability. The deeper we go, the less interpretable the network becomes.
The non-linear transformations of the neural network allow us to remap our data into a linearly separable space. At the output layer of a neural network, it then becomes trivial for us to separate our initially non-linearly-separable data into two classes using a linear classifier, as illustrated below.
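The remapping idea can be made concrete with the classic XOR problem. The sketch below is only an illustration: the hidden-layer weights (`W_hidden`, `b_hidden`, `w_out`) are hand-picked assumptions rather than trained values, chosen to show that a linear readout works on the hidden representation even though no linear model works on the raw inputs.

```python
import numpy as np

# The four XOR points: no straight line separates the two classes in input space.
X = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=float)
y = np.array([0, 1, 1, 0], dtype=float)

# Best linear fit on the raw inputs (least squares with an intercept column).
X_aug = np.hstack([X, np.ones((4, 1))])
coef, *_ = np.linalg.lstsq(X_aug, y, rcond=None)
linear_pred = X_aug @ coef  # every point receives the same score

def relu(z):
    return np.maximum(z, 0.0)

# Hand-picked (not trained) hidden layer with two ReLU units.
W_hidden = np.array([[1.0, 1.0],
                     [1.0, 1.0]])   # column j holds the weights into hidden unit j
b_hidden = np.array([0.0, -1.0])
H = relu(X @ W_hidden + b_hidden)   # the remapped data

# In the hidden space, a single linear readout reproduces XOR exactly.
w_out = np.array([1.0, -2.0])
hidden_pred = H @ w_out

print("linear model on raw inputs:", np.round(linear_pred, 3))
print("linear readout on hidden layer:", hidden_pred)
```

The linear model assigns every point the same score, so no threshold can separate the classes, whereas the readout on the two hidden units matches the XOR labels.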
The question is, how do we know what is going on within this multi-layer non-linear transformation, which may contain millions of parameters?
Imagine a GAN model (two networks fighting each other in order to mimic the distribution of the input data) working on a 512×512 image dataset. When images are introduced into a neural network, each pixel becomes a feature of the neural network. For an image of this size, the number of features is 262,144. This means we are performing potentially 8 or 9 convolutional and non-linear transformations on over 200,000 features. How can one interpret this?
Go even more extreme to the case of 1024×1024 images, such as those generated by NVIDIA’s implementation of StyleGAN. Since the number of pixels increases by a factor of four with a doubling of image size, we would have over a million features as our input to the GAN. So we now have a one million feature neural network, performing convolutional operations and non-linear activations, and doing this over a dataset of hundreds of thousands of images.
Hopefully, I have convinced you that interpreting deep neural networks is profoundly difficult. Although the operations of a neural network may seem simple, they can produce wildly complex outcomes via some form of emergence.
Visualizations
For the remainder of this article, I will discuss visualization techniques that can be used for deep neural networks, since they present the greatest challenge in the interpretability and explainability of machine learning.
Weight Histograms
Weight histograms are generally applicable to any data type, so I have chosen to cover these first. Weight histograms can be very useful in determining the overall distribution of weights across a deep neural network. In general, a histogram displays the number of occurrences of each value (or range of values) relative to the others. Whether the distribution of weights is uniform, normal, or takes on some other ordered structure can tell us useful information.
For example, if we want to check that all our network layers are learning from a given batch, we can see how the weight distributions change after training on the batch. Whilst this may not seem the most useful visualization at first, we can still gain valuable insight from weight histograms.
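As a rough, framework-free sketch of this check, the snippet below compares weight histograms before and after a simulated update. The layer names, shapes, and the simulated update are all hypothetical; in practice the weight arrays would come from your own model snapshots.

```python
import numpy as np

def weight_histogram(weights, bins=20, value_range=(-0.5, 0.5)):
    """Count how many weights fall into each bin of a fixed range."""
    counts, edges = np.histogram(np.ravel(weights), bins=bins, range=value_range)
    return counts, edges

def layer_changed(before, after, **hist_kwargs):
    """Return True if the weight histogram differs between two snapshots."""
    counts_before, _ = weight_histogram(before, **hist_kwargs)
    counts_after, _ = weight_histogram(after, **hist_kwargs)
    return not np.array_equal(counts_before, counts_after)

rng = np.random.default_rng(0)

# Hypothetical snapshot of two layers before training on a batch.
layers_before = {
    "dense_1": rng.normal(0.0, 0.1, size=(64, 32)),
    "dense_2": rng.normal(0.0, 0.1, size=(32, 10)),
}

# Simulated snapshot after the batch: dense_1 is updated, dense_2 is frozen.
layers_after = {
    "dense_1": layers_before["dense_1"] + rng.normal(0.0, 0.05, size=(64, 32)),
    "dense_2": layers_before["dense_2"].copy(),
}

status = {name: layer_changed(layers_before[name], layers_after[name])
          for name in layers_before}
print(status)
```

Comparing binned counts is a coarse test: a layer whose weights moved only within their bins would look unchanged, so a finer binning or a direct comparison of the arrays may be preferable for debugging.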
Shown below are weight and bias histograms for a four-layer network in TensorBoard, TensorFlow’s main visualization tool.
For those of you who are not familiar with it, another tool for plotting weight distributions is Weights and Biases (W&B), a relatively new company specializing in experiment tracking for deep learning. When training a large network such as a GAN with millions of parameters, the experiment tracking provided by W&B is very helpful for logging purposes and offers more functionality than TensorBoard (and is free for those of you in academia).
Saliency Maps
Going back to the tank problem we discussed previously, how could we troubleshoot this network to ensure the classifier is examining the correct portions of an image to make its predictions? One way to do this is with saliency maps.
Saliency maps were proposed in the 2013 paper “Deep Inside Convolutional Networks: Visualising Image Classification Models and Saliency Maps”, along with class maximization (discussed later). The idea behind them is fairly simple. First, we compute the gradient of the output class score with respect to the input image. This tells us how the classification changes in response to small changes in each of the input image pixels. If the gradient at a pixel is positive, then a small increase in that pixel's value increases the output value. By visualizing the gradients, we can examine which pixels are the most important for activation and ensure that portions of the image being examined correspond to the object of interest.
Saliency maps provide us with a method for computing the spatial support of a given class in a given image (image-specific class saliency map). This means that we can look at a classification output from a convolutional network, perform backpropagation, and look at which parts of the image were involved in classifying the image as a given class.
Another simple adjustment to the saliency method known as rectified saliency can be used. This involves clipping negative gradients during the backpropagation step so as to only propagate positive gradient information. Thus, only information related to an increase in output is communicated. You can find more in the paper “Visualizing and Understanding Convolutional Networks.”
Given an image with pixel locations (i, j) and c color channels (red, green, and blue in RGB images), we backpropagate the output to find the derivative w that corresponds to each pixel and channel. We then take the maximum of the absolute value of these derivatives across all color channels and use this as the ij-th value of the saliency map M, i.e. $M_{ij} = \max_c |w_{(i,j,c)}|$.
Saliency maps can easily be visualized for Keras models using the ‘visualize_saliency’ and ‘visualize_saliency_with_losses’ functions from the keras-vis package.
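The computation of the saliency map M can also be sketched without any deep learning framework. The example below is a minimal illustration, not the implementation used by any library: it approximates the gradient with finite differences instead of backpropagation, and `toy_class_score` is a hypothetical stand-in for a real network's class score.

```python
import numpy as np

def numerical_gradient(score_fn, image, eps=1e-4):
    """Central-difference estimate of d(score)/d(pixel) for every pixel and channel."""
    image = image.astype(float)
    grad = np.zeros_like(image)
    for idx in np.ndindex(image.shape):
        bumped_up = image.copy()
        bumped_down = image.copy()
        bumped_up[idx] += eps
        bumped_down[idx] -= eps
        grad[idx] = (score_fn(bumped_up) - score_fn(bumped_down)) / (2 * eps)
    return grad

def saliency_map(score_fn, image):
    """M[i, j] = max over color channels of |d(score)/d(image[i, j, c])|."""
    grad = numerical_gradient(score_fn, image)
    return np.abs(grad).max(axis=-1)

# Hypothetical toy "classifier": the class score depends only on pixel (1, 2).
toy_weights = np.zeros((4, 4, 3))
toy_weights[1, 2] = [0.5, -2.0, 1.0]

def toy_class_score(image):
    return float(np.sum(toy_weights * image))

image = np.random.default_rng(0).random((4, 4, 3))
M = saliency_map(toy_class_score, image)
print(np.round(M, 3))
```

For this toy score, the map is zero everywhere except at pixel (1, 2), where it equals the largest absolute channel weight. With a real network, automatic differentiation would replace the finite-difference loop, which is far too slow for full-size images.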
Occlusion Maps
A similar technique to saliency mapping for discerning the importance of pixels in an image’s prediction is occlusion mapping. In occlusion mapping, we are still developing a map related to an image’s output. However, this time we are interested in how blocking out part of the image affects the prediction output of the image.
Occlusion-based methods systematically occlude (block out) portions of the input image using a grey square and monitor the classifier output. The image below, which shows an image classifier aiming to predict melanoma, clearly shows that the model is localizing the object within the scene, as the probability of the correct class drops significantly when the object is occluded (the heat map gets darker in the regions where the melanoma is, because occluding them reduces the classifier's output for the correct class).
Occlusion mapping is fairly simple to implement as it just involves distorting the image at a given pixel location and saving the prediction output to plot in a heat map. A good implementation of this on GitHub by Akshay Chawla can be found here.
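A minimal sketch of the sliding-patch procedure is shown below. It is not the GitHub implementation mentioned above; the patch size, stride, and fill value are arbitrary choices, and `toy_score` is a hypothetical classifier that only looks at a small central region of the image.

```python
import numpy as np

def occlusion_map(score_fn, image, patch=2, stride=2, fill=0.5):
    """Slide a gray patch over the image and record the drop in the class score."""
    base_score = score_fn(image)
    height, width = image.shape[:2]
    rows = range(0, height - patch + 1, stride)
    cols = range(0, width - patch + 1, stride)
    heat = np.zeros((len(rows), len(cols)))
    for r_i, r in enumerate(rows):
        for c_i, c in enumerate(cols):
            occluded = image.copy()
            occluded[r:r + patch, c:c + patch] = fill   # block out this region
            heat[r_i, c_i] = base_score - score_fn(occluded)
    return heat

# Hypothetical toy setup: a bright 2x2 "object" in the middle of a dark image.
image = np.zeros((6, 6))
image[2:4, 2:4] = 1.0

def toy_score(img):
    # Stand-in for a classifier's probability of the correct class.
    return float(img[2:4, 2:4].mean())

heat = occlusion_map(toy_score, image)
print(heat)
```

The only non-zero entry in the heat map is the patch that covers the object, which is the behaviour we hope to see from a classifier that is looking at the right part of the image.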
Class Maximization
One very powerful technique in studying neural networks is class maximization. This allows us to view the exemplar of a class, i.e. the input that would cause the class value of the classifier to be maximized in the output. For image data, we would call this the image exemplar of a class. Mathematically, this corresponds to:
$$x^* = \arg\max_x S_c(x)$$
Where x* corresponds to the image exemplar of class c, and S_c(x) denotes the classifier's output score for class c given input x. This notation says we want the image that gives us the maximum possible output for class c, which can be interpreted as asking: what is the perfect c?
The outputs of this from a large scale classification network are fascinating. Below are some images generated by Nguyen, Yosinski, and Clune in their 2016 paper on deep convolutional network visualization. They performed class maximization on a deep convolutional neural network which was trained on the ILSVRC-2013 dataset.
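To make the optimization behind class maximization concrete, here is a minimal sketch using gradient ascent on the input. It assumes a toy linear classifier with scores given by the rows of a hypothetical matrix `W`, and it adds an L2 penalty on the input (an assumption made here so that the maximum is finite); a real network would need automatic differentiation for the gradient.

```python
import numpy as np

def class_maximization(class_weights, reg=0.5, lr=0.1, steps=200):
    """Gradient ascent on the input x to maximize  w_c . x - reg * ||x||^2."""
    x = np.zeros_like(class_weights)
    for _ in range(steps):
        grad = class_weights - 2.0 * reg * x   # gradient of the objective w.r.t. x
        x = x + lr * grad                      # ascent step
    return x

# Hypothetical linear classifier: 3 classes, 5 input features.
W = np.array([
    [0.2, -0.1, 0.0, 0.5, -0.3],
    [1.0, -2.0, 0.5, 0.0, 0.3],
    [-0.4, 0.6, 0.1, -0.2, 0.9],
])

target_class = 1
x_star = class_maximization(W[target_class])
print(np.round(x_star, 4))
```

For this toy objective the optimum is known in closed form, w_c / (2 · reg), so the result can be checked directly. For a deep network there is no closed form, and the quality of the exemplar depends heavily on the regularizer and the starting point.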
Activation Maximization
Similar to class maximization, activation maximization helps us to visualize the exemplar of convolutional filters. Class maximization is a special case of activation maximization whereby the output softmax layer of a classification algorithm is maximized. Mathematically, activation maximization can be described as:
$$x^* = \arg\max_x a_{l,f}(x)$$
Where x* corresponds to the exemplar of hidden layer l or filter f in a deep neural network, and a_{l,f}(x) denotes the activation of that layer or filter for input x. This notation says we want the input (an image in the case of a convolutional network) that maximizes the activation of the filter or layer. This is illustrated below for the 8 layers of a deep convolutional neural network.
LIME (Local Interpretable Model-Agnostic Explanations)
LIME stands for local interpretable model-agnostic explanations and even has its own Python package. Because the method was designed to be model-agnostic, it can be applied to many different machine learning models. It was first shown in papers by Marco Tulio Ribeiro and colleagues, including “Model-Agnostic Interpretability of Machine Learning” and ‘“Why Should I Trust You?”: Explaining the Predictions of Any Classifier’, both published in 2016.
Local surrogate models are interpretable models that are used to explain individual predictions of black box machine learning models. LIME is an implementation of local surrogate models.
Surrogate models are trained to approximate the predictions of the underlying black box model.
Instead of training a global surrogate model, LIME focuses on training local surrogate models to explain individual predictions.
In LIME, we perturb the input and analyze how our predictions change. Despite how it may sound, this is very different from occlusion mapping and saliency mapping. Our aim is to approximate the underlying model, f, using an interpretable model, g (such as a linear model with a few coefficients) from a set of possible models, G, at a given location governed by a proximity measure, πₓ. We also add a regularizer, Ω, to make sure the interpretable model is as simple as possible. This is expressed in the following equation, where L is the loss measuring how poorly g approximates f in the neighborhood defined by πₓ:
$$\text{explanation}(x) = \arg\min_{g \in G} L(f, g, \pi_x) + \Omega(g)$$
LIME for images works differently than LIME for tabular data and text. Intuitively, it would not make much sense to perturb individual pixels, since many more than one pixel contribute to one class. Randomly changing individual pixels would probably not change the predictions by much. Therefore, variations of the images are created by segmenting the image into “superpixels” and turning superpixels off or on.
Superpixels are interconnected pixels with similar colors and can be turned off by replacing each pixel with a user-defined color such as gray. The user can also specify a probability for turning off a superpixel in each permutation.
The fidelity measure (how well the interpretable model approximates the black box predictions, given by our loss value L) gives us a good idea of how reliable the interpretable model is in explaining the black box predictions in the neighborhood of the data instance of interest.
LIME is also one of the few methods that works for tabular data, text and images.
Note that we can also generate global surrogate models, which follow the same idea but are used as an approximate model for the entire black box algorithm, not just a localized subset of the algorithm.
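The local surrogate idea can be sketched in a few lines of NumPy for tabular data. This is not the `lime` package's API: the function names, the Gaussian perturbations, the exponential kernel, and the kernel width are all assumptions chosen for illustration, and the regularizer Ω is replaced by simply restricting g to a linear model.

```python
import numpy as np

def black_box(X):
    """Hypothetical black box model: non-linear in both features."""
    return np.sin(X[:, 0]) + X[:, 1] ** 2

def local_surrogate(predict_fn, x, n_samples=500, noise=0.1, kernel_width=0.25, seed=0):
    """Fit a proximity-weighted linear model g around the instance x."""
    rng = np.random.default_rng(seed)
    Z = x + rng.normal(0.0, noise, size=(n_samples, x.size))   # perturbed samples
    targets = predict_fn(Z)                                     # black box outputs

    # Proximity measure pi_x: samples closer to x receive larger weights.
    dist = np.linalg.norm(Z - x, axis=1)
    weights = np.exp(-(dist ** 2) / kernel_width ** 2)

    # Weighted least squares via rescaling rows by sqrt(weight).
    design = np.hstack([Z - x, np.ones((n_samples, 1))])
    sqrt_w = np.sqrt(weights)[:, None]
    coef, *_ = np.linalg.lstsq(design * sqrt_w, targets * sqrt_w[:, 0], rcond=None)

    # Fidelity: weighted mean squared error of g against the black box.
    residual = targets - design @ coef
    fidelity_loss = float(np.sum(weights * residual ** 2) / np.sum(weights))
    return coef[:-1], coef[-1], fidelity_loss

x = np.array([0.0, 1.0])
slopes, intercept, fidelity_loss = local_surrogate(black_box, x)
print("local slopes:", np.round(slopes, 3), "fidelity loss:", round(fidelity_loss, 5))
```

Around x = (0, 1) the black box behaves approximately like x₀ + 2x₁ plus a constant, so the fitted slopes should be close to 1 and 2. The small fidelity loss indicates that the linear explanation is trustworthy only in this neighborhood, not globally.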
Partial Dependency Plots
The partial dependence plot shows the marginal effect one or two features have on the predicted outcome of a machine learning model. If we are analyzing the market price of a metal like gold using a dataset with a hundred features, including the value of gold in previous days, we will find that the price of gold has a much higher dependence on some features than others. For example, the gold price might be closely linked to the oil price, whilst not strongly linked to the price of avocados. This information becomes visible in a partial dependency plot.
Note that this is not the same as a linear regression model. If this were performed on a linear regression model, each of the partial dependency plots would be linear. The partial dependency plot allows us to see the relationship in its full complexity, which may be linear, exponential, or some other complex relationship.
One of the main pitfalls of the partial dependency plot is that it can only realistically show a 2D interpretation involving one or two features. Thus, modeling higher-order interaction terms between multiple variables is difficult.
There is also an inherent assumption of independence of the variables, which is often not the case (such as a correlation between height and weight, which are two common features in medical datasets). These correlations between variables may render one of them redundant or present issues to the algorithm due to multicollinearity. Where this becomes a problem, using Accumulated Local Effects (ALE) is much preferred, as it does not suffer from the same pitfalls as partial dependency plots when it comes to collinearity.
To avoid overinterpreting the results in data-sparse feature regions it is helpful to add a rug plot to the bottom of the partial dependency plot to see where data-rich and data-sparse regions are present.
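The computation behind a one-feature partial dependency plot is short enough to sketch directly. In the example below, `model` is a hypothetical fitted model and the data is synthetic; the function simply forces the chosen feature to each grid value for every row and averages the predictions.

```python
import numpy as np

def partial_dependence(predict_fn, X, feature, grid):
    """Average prediction when `feature` is set to each grid value for all rows."""
    pd_values = []
    for value in grid:
        X_mod = X.copy()
        X_mod[:, feature] = value          # overwrite the feature for every instance
        pd_values.append(predict_fn(X_mod).mean())
    return np.array(pd_values)

def model(X):
    """Hypothetical model: quadratic in feature 0, linear in feature 1."""
    return X[:, 0] ** 2 + 3.0 * X[:, 1]

rng = np.random.default_rng(0)
X = rng.uniform(-1.0, 1.0, size=(200, 2))
grid = np.linspace(-1.0, 1.0, 5)

pd_curve = partial_dependence(model, X, feature=0, grid=grid)
print(np.round(pd_curve, 3))
```

Because this toy model is additive, the curve is exactly the quadratic effect of feature 0 shifted by the average contribution of feature 1. Note that overwriting a feature for every row can create unrealistic data points when features are correlated, which is the independence assumption discussed above.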
Individual Conditional Expectation (ICE)
ICE is similar to partial dependency plots, except that a separate line is plotted for each instance in the dataset. Thus, the partial dependency plot gives us an averaged view of how the output variable depends on a feature variable, whereas ICE allows us to see the instance-specific dependence. This is useful when interactions between features are present that could be masked when looking at the average result, but become very apparent when using ICE.
Different types of ICE plots also exist, such as centered and derivative ICE plots, but they essentially provide the same information in different forms.
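The masking effect of averaging is easy to demonstrate with a toy interaction. In this sketch, `model` is a hypothetical model in which the effect of feature 0 flips sign depending on feature 1; the centering convention (subtracting each curve's value at the first grid point) is one common choice, assumed here for illustration.

```python
import numpy as np

def ice_curves(predict_fn, X, feature, grid):
    """One prediction curve per instance as `feature` sweeps over the grid."""
    curves = np.empty((X.shape[0], len(grid)))
    for j, value in enumerate(grid):
        X_mod = X.copy()
        X_mod[:, feature] = value
        curves[:, j] = predict_fn(X_mod)
    return curves

def model(X):
    """Hypothetical model with a pure interaction between features 0 and 1."""
    return X[:, 0] * X[:, 1]

# Half of the instances have feature 1 = +1, the other half have -1.
X = np.array([[0.0, 1.0], [0.0, -1.0], [0.0, 1.0], [0.0, -1.0]])
grid = np.linspace(-1.0, 1.0, 5)

curves = ice_curves(model, X, feature=0, grid=grid)
pd_curve = curves.mean(axis=0)          # the partial dependence is the average ICE curve
centered = curves - curves[:, [0]]      # centered ICE: anchor each curve at the first grid point

print("ICE curves:\n", curves)
print("partial dependence:", pd_curve)
```

The averaged curve is flat, suggesting feature 0 has no effect, while the individual curves reveal slopes of +1 and -1 depending on the value of feature 1.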
Shapley Values
The Shapley value is a concept drawn from cooperative game theory, developed in 1953 by Lloyd Shapley. In cooperative game theory, the Shapley value assigns each player a payout based on their average marginal contribution over all possible orderings of the players. When applied to machine learning, we treat each feature as a player in the game, with the features working together to produce the prediction, which can be considered the payout. The Shapley value assigns a portion of the payout to each feature based on its contribution to the output value.
For example, if you are looking at house prices and you remove a single feature from the analysis, how does this affect the model prediction? If the predicted value goes down by an amount, we can infer that this feature contributed this much to the prediction. Of course, it is not exactly that simple: we must perform this computation for every possible combination of features, which means we need to evaluate the model on 2ˣ feature subsets, where x is the number of features.
Thus, the Shapley value is the average marginal contribution of a feature value across all possible coalitions:
$$\phi_i = \sum_{S \subseteq \{x_1, \ldots, x_p\} \setminus \{x_i\}} \frac{|S|!\,(p - |S| - 1)!}{p!} \left( val(S \cup \{x_i\}) - val(S) \right)$$
This equation may look daunting, so let’s examine it piece by piece from right to left. To find the marginal contribution of our feature xᵢ, we calculate the prediction value of our model, val, using a feature subset S that does not contain xᵢ, and we subtract this from the prediction value of the same subset with xᵢ added. We then weight this difference by the fraction of feature orderings in which xᵢ joins exactly the subset S (p is the total number of features), and sum these weighted contributions over all subsets. Thus, we now have a value which is essentially the average contribution of a feature for a trained model using every possible subset of features.
This discussion may seem quite abstract, so an example would be helpful. The example used in Christoph Molnar’s book, involving apartment prices, is an excellent one to consider. Suppose we have features for predicting the price which include (1) the size of the apartment (numeric), (2) the proximity to a nearby park (binary), and (3) the floor of the building the apartment is on. To calculate the Shapley values for each feature, we take every possible subset of features and predict the output in each case (including the case with no features). We then take the weighted sum of the marginal contributions of each feature.
A player can be an individual feature value, e.g. for tabular data, but a player can also be a group of feature values. For example, to explain an image, pixels can be grouped to superpixels and the prediction distributed among them.
As far as I know, there is no official package for Shapley values in Python, but there are some repositories available that have implemented it for machine learning. One such package can be found here.
The main disadvantage of the Shapley value is that it is very computationally expensive and time-consuming for large numbers of features due to the exponential increase in the number of possible permutations for a linear increase in the number of features. Thus, for applications where the number of features is very large, the Shapley value is typically approximated using a subset of feature permutations.
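For a handful of features, the exact Shapley computation described above can be written with the standard library alone. The sketch below uses a hypothetical value function for the apartment example; the price contributions (including the size and park interaction) are invented numbers for illustration, not figures from the book.

```python
from itertools import combinations
from math import factorial

def shapley_values(value_fn, n_features):
    """Exact Shapley values by enumerating every coalition of the other features."""
    players = range(n_features)
    phi = [0.0] * n_features
    for i in players:
        others = [j for j in players if j != i]
        for size in range(len(others) + 1):
            # Fraction of orderings in which feature i joins a coalition of this size.
            weight = factorial(size) * factorial(n_features - size - 1) / factorial(n_features)
            for subset in combinations(others, size):
                coalition = frozenset(subset)
                marginal = value_fn(coalition | {i}) - value_fn(coalition)
                phi[i] += weight * marginal
    return phi

# Hypothetical payout: predicted price (in thousands) given the known features.
SIZE, PARK, FLOOR = 0, 1, 2

def apartment_value(coalition):
    value = 100.0                       # prediction with no features known
    if SIZE in coalition:
        value += 50.0
    if PARK in coalition:
        value += 20.0
    if FLOOR in coalition:
        value += 10.0
    if SIZE in coalition and PARK in coalition:
        value += 30.0                   # interaction: large apartment near a park
    return value

phi = shapley_values(apartment_value, n_features=3)
print(phi)
```

The interaction bonus is split evenly between size and park, and the three values sum to the difference between the full prediction and the empty-coalition prediction. The nested enumeration grows exponentially with the number of features, which is why sampling-based approximations are used in practice.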
Anchors
Anchors were first introduced in a 2018 paper by Marco Tulio Ribeiro, Sameer Singh, and Carlos Guestrin, the same researchers that created LIME. The method has its own Python package, developed by Marco, and is also available in the Alibi package for Python.
Anchors address a key shortcoming of local explanation methods like LIME which proxy the local behavior of the model in a linear way. It is, however, unclear to what extent the explanation holds up in the region around the instance to be explained since both the model and data can exhibit non-linear behavior in the neighborhood of the instance. This approach can easily lead to overconfidence in the explanation and misleading conclusions on unseen but similar instances. The anchor algorithm tackles this issue by incorporating coverage, the region where the explanation applies, into the optimization problem.
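The two quantities that the anchor algorithm trades off can be illustrated on a tiny tabular example. Precision is the fraction of perturbed instances satisfying the anchor for which the model keeps the original prediction, and coverage is the fraction of the data to which the anchor applies. The sketch below is an assumption-laden toy: the `black_box` model, the feature values, and the helper names are hypothetical, and it enumerates a small dataset exhaustively instead of sampling perturbations and searching over candidate anchors the way the actual algorithm does.

```python
from itertools import product

# Hypothetical dataset: every combination of (income, age group, owns home).
data = list(product(["low", "mid", "high"], ["young", "old"], [0, 1]))


def black_box(row):
    """Toy classifier standing in for the model we want to explain."""
    income, _age_group, owns_home = row
    return 1 if income == "high" or (income == "mid" and owns_home == 1) else 0


def anchor_holds(anchor, row):
    """An anchor maps feature index -> required value."""
    return all(row[j] == value for j, value in anchor.items())


def precision(anchor, instance, rows, model):
    """Share of perturbed rows (anchored features held fixed) that keep the
    same prediction as the instance being explained."""
    target = model(instance)
    hits = 0
    for row in rows:
        perturbed = list(row)
        for j, value in anchor.items():
            perturbed[j] = value  # hold the anchored features fixed
        hits += model(tuple(perturbed)) == target
    return hits / len(rows)


def coverage(anchor, rows):
    """Share of rows in the data to which the anchor applies."""
    return sum(anchor_holds(anchor, row) for row in rows) / len(rows)


instance = ("high", "young", 0)

if __name__ == "__main__":
    for anchor in ({0: instance[0]}, {1: instance[1]}):
        print(anchor,
              "precision:", precision(anchor, instance, data, black_box),
              "coverage:", round(coverage(anchor, data), 3))
```

In this toy setting, fixing the income to "high" pins the prediction in every perturbation but applies to only a third of the data, whereas fixing the age group applies to half of the data but preserves the prediction only half of the time. A good anchor is one with high precision and as much coverage as possible.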
Similar to LIME, anchors can be used on text, tabular, and image data. For images, we first segment them into superpixels whilst still maintaining local image structure. The interpretable representation then consists of the presence or absence of each superpixel in the anchor. Several image segmentation techniques can be used to split an image into superpixels, such as slic or quickshift.
The algorithm supports a number of standard image segmentation algorithms (felzenszwalb, slic and quickshift) and allows the user to provide a custom segmentation function.
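To show what "presence or absence of each superpixel" means in practice, here is a small sketch in plain Python. It is not how slic, quickshift, or felzenszwalb work: as an assumption for brevity, a simple grid tiling plays the role of a custom segmentation function, and the helper names `grid_segments` and `apply_mask` are hypothetical.

```python
def grid_segments(height, width, block):
    """Label each pixel with the id of the block-by-block tile containing it.

    This is a deliberately crude stand-in for a real superpixel algorithm.
    """
    tiles_per_row = (width + block - 1) // block
    return [[(r // block) * tiles_per_row + (c // block) for c in range(width)]
            for r in range(height)]


def apply_mask(image, segments, present, fill=0):
    """Keep pixels whose superpixel is marked present; blank out the rest."""
    return [[pixel if present[segment] else fill
             for pixel, segment in zip(image_row, segment_row)]
            for image_row, segment_row in zip(image, segments)]


# A 4x4 single-channel "image" with pixel values 1..16.
image = [[1, 2, 3, 4],
         [5, 6, 7, 8],
         [9, 10, 11, 12],
         [13, 14, 15, 16]]

if __name__ == "__main__":
    segments = grid_segments(4, 4, block=2)   # four 2x2 superpixels
    interpretable = [1, 0, 0, 1]              # keep superpixels 0 and 3
    for row in apply_mask(image, segments, interpretable):
        print(row)
```

The binary vector `interpretable` is the representation the explainer works with: a candidate anchor corresponds to a set of superpixels that are kept fixed, while the remaining superpixels are perturbed to see whether the prediction changes.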
Counterfactuals
Counterfactuals are the opposite of anchors. Anchors are features that when present are sufficient to anchor a prediction (i.e. prevent it from being changed by altering other features). In the anchor section, we looked at an example where these anchors were superpixels of an image. Every superpixel in the image that was not part of an anchor was, in fact, a counterfactual — we can alter the prediction by altering the counterfactuals, and not by altering the anchors.
Counterfactuals were first proposed in the Wachter et al. (2017) paper titled “Counterfactual explanations without opening the black box: Automated decisions and the GDPR”. The basic idea of counterfactuals is that we want to find the smallest change to the smallest number of features that produces the output we desire.
A counterfactual explanation of a prediction describes the smallest change to the feature values that changes the prediction to a predefined output.
This may sound like an underdefined task, as there are many ways in which we could alter our instance in order for it to meet our desired output. This phenomenon is known as the ‘Rashomon effect’ and, as a result, we must cast our problem in the form of an optimization problem. We want to change as few features as possible, and change these features by the smallest amount possible, whilst also keeping the counterfactual likely given the joint distribution of the data. The loss function for our optimization problem can be cast as
$$L(x, x', y', \lambda) = \lambda \cdot \left(\hat{f}(x') - y'\right)^2 + d(x, x')$$
The first term of the loss function represents the quadratic distance between the model prediction f̂(x’) and the desired output y’. The second term represents a distance metric between the original instance x and the counterfactual instance x’. The quadratic term has a scaling parameter, λ, that weighs the importance of the prediction matching the desired output against the distance between the original instance x and the counterfactual instance x’.
The distance metric we use is the Manhattan distance because the counterfactual should not only be close to the original instance but should also change as few features as possible. The distance function is described as
$$d(x, x') = \sum_{j=1}^{p} \frac{|x_j - x'_j|}{\mathrm{MAD}_j}$$
where p is the number of features and, following Wachter et al., each absolute difference is scaled by MAD_j, the median absolute deviation of feature j over the dataset.
If we have a small scaling parameter, the distance metric becomes more important and we prefer to see counterfactuals that are close to our original instance. If we have a large scaling parameter, the prediction becomes more important and we are laxer on how close the counterfactual is to our original instance.
When we run our algorithm, we do not need to select a value for our scaling parameter. Instead, the authors suggest that the user supplies a tolerance, ϵ, which represents how far we will tolerate the prediction being from our desired output. This is represented as
$$\left|\hat{f}(x') - y'\right| \le \epsilon$$
Our optimization problem can then succinctly be described as
$$\arg\min_{x'} \max_{\lambda} L(x, x', y', \lambda)$$
The optimization mechanism for counterfactuals can be described as a ‘growing spheres’ approach, whereby the input instance, x, the desired output value, y’, and the tolerance parameter, ϵ, are given by the user. Initially, a small value for the scaling parameter, λ, is set. A random instance within the current ‘sphere’ of allowed counterfactuals is sampled and then used as a starting point for optimization until the instance satisfies the above constraint (i.e. until the difference between the prediction and the desired output is within our tolerance). We then add this instance to our list of counterfactuals and increase the value of λ, which is effectively growing the size of our ‘sphere’. We repeat this procedure, generating a list of counterfactuals. At the end of the procedure, we select the counterfactual which minimizes the loss function.
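The loop of "optimize, check the tolerance, increase λ" can be sketched as follows. This is a simplified illustration under several assumptions rather than the procedure used by any library: the linear `predict` model and all helper names are hypothetical, the distance is a plain unweighted Manhattan distance, the search starts from the original instance instead of a random sample, a deterministic coordinate pattern search stands in for whichever optimizer one would normally use, and the function returns the first candidate within tolerance instead of collecting a list.

```python
def predict(point):
    """Hypothetical regression model standing in for the black box."""
    return 0.5 * point[0] + 2.0 * point[1]


def manhattan(a, b):
    return sum(abs(ai - bi) for ai, bi in zip(a, b))


def pattern_search(loss, start, step=1.0, min_step=1e-4):
    """Derivative-free minimizer: try +/- step on each coordinate and
    halve the step whenever no single move improves the loss."""
    point, best = list(start), loss(start)
    while step >= min_step:
        improved = False
        for j in range(len(point)):
            for direction in (1.0, -1.0):
                candidate = list(point)
                candidate[j] += direction * step
                value = loss(candidate)
                if value < best:
                    point, best, improved = candidate, value, True
        if not improved:
            step /= 2.0
    return point


def find_counterfactual(model, x, target, eps, lam=0.1, max_rounds=30):
    """Increase lam until the optimized candidate is within eps of target."""
    candidate = list(x)
    for _ in range(max_rounds):
        def loss(x_prime, lam=lam):
            return lam * (model(x_prime) - target) ** 2 + manhattan(x, x_prime)

        candidate = pattern_search(loss, candidate)
        if abs(model(candidate) - target) <= eps:
            return candidate, lam
        lam *= 2.0  # put more weight on hitting the desired output
    raise RuntimeError("no counterfactual found within the tolerance")


x = [1.0, 1.0]  # original instance, predicted value 2.5

if __name__ == "__main__":
    counterfactual, final_lam = find_counterfactual(predict, x, target=4.0, eps=0.05)
    print("counterfactual:", [round(v, 3) for v in counterfactual])
    print("prediction:", round(predict(counterfactual), 3), "lambda:", final_lam)
```

In this toy problem the second feature moves the prediction four times as much per unit as the first, so the Manhattan penalty pushes the search to change only the second feature. That is the sparsity behaviour the distance term is meant to encourage.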
Counterfactuals are implemented in the Python package ALIBI, which you can read about here (they also have an alternate description that may be helpful and clearer than my own).
Other Techniques
There are other techniques that I have not touched upon here which I refer the interested reader to. These include, but are not limited to:
Dimensional Reduction Techniques (PCA, t-SNE)
SHapley Additive exPlanations (SHAP)
A good repository of topics on machine learning interpretability can also be found on this GitHub page which covers papers, lectures, and other blogs with material on the subject.
Final Comments
Deep learning visualization is a complex topic that has only begun to be researched in the last few years. However, it will become more important as deep learning techniques become more integrated into our data-driven society. Most of us may value performance over understanding, but I think that being able to interpret and explain models will provide a competitive edge for individuals and companies in the future. There will certainly be a market for it.
Visualization is not the only method, or the best method, of interpreting or explaining the results of deep neural networks, but it is certainly one method, and it can provide us with useful insight into the decision-making process of complex networks.
The problem is that a single metric, such as classification accuracy, is an incomplete description of most real-world tasks. — Doshi-Velez and Kim 2017
References
Here are papers that I referenced in this article as well as papers I think the reader may find informative on the topic of algorithmic interpretability and explainability.
[1] Towards A Rigorous Science of Interpretable Machine Learning — Doshi-Velez and Kim, 2017
[2] The Mythos of Model Interpretability — Lipton, 2017
[3] Transparency: Motivations and Challenges — Weller, 2019
[4] An Evaluation of the Human-Interpretability of Explanation — Lage et al., 2019
[5] Manipulating and Measuring Model Interpretability — Poursabzi-Sangdeh, 2018
[6] Interpretable Classifiers Using Rules and Bayesian Analysis: Building a Better Stroke Prediction Model — Letham and Rudin, 2015
[7] Interpretable Decision Sets: A Joint Framework for Description and Prediction — Lakkaraju et al., 2016
[8] Deep Learning for Case-Based Reasoning through Prototypes: A Neural Network that Explains Its Predictions — Li et al., 2017
[9] The Bayesian Case Model: A Generative Approach for Case-Based Reasoning and Prototype Classification — Kim et al., 2014
[10] Learning Optimized Risk Scores — Ustun and Rudin, 2017
[11] Intelligible Models for HealthCare: Predicting Pneumonia Risk and Hospital 30-day Readmission — Caruana et al., 2015
[12] “Why Should I Trust You?” Explaining the Predictions of Any Classifier — Ribeiro et al., 2016
[13] Stop Explaining Black Box Machine Learning Models for High Stakes Decisions and Use Interpretable Models Instead — Rudin, 2019
[14] Interpretation of Neural Networks is Fragile — Ghorbani et al., 2019
[15] Visualizing Deep Neural Network Decisions: Prediction Difference Analysis — Zintgraf et al., 2017
[16] Sanity Checks for Saliency Maps — Adebayo et al., 2018
[17] A Unified Approach to Interpreting Model Predictions — Lundberg and Lee, 2017
[18] Interpretability Beyond Feature Attribution: Quantitative Testing with Concept Activation Vectors (TCAV) — Kim et al., 2018
[19] Counterfactual Explanations Without Opening the Black Box: Automated Decisions and the GDPR — Wachter et al., 2018
[20] Actionable Recourse in Linear Classification — Ustun et al., 2018
[21] Causal Interpretations of Black-Box Models — Zhao and Hastie, 2018
[22] Learning Cost-Effective and Interpretable Treatment Regimes — Lakkaraju and Rudin, 2017
[23] Human-in-the-Loop Interpretability Prior — Lage et al., 2018
[24] Faithful and Customizable Explanations of Black Box Models — Lakkaraju et al., 2019
[25] Understanding Black-box Predictions via Influence Functions — Koh and Liang, 2017
[26] Simplicity Creates Inequity: Implications for Fairness, Stereotypes, and Interpretability — Kleinberg and Mullainathan, 2019
[27] Understanding Neural Networks Through Deep Visualization — Yosinski et al., 2015
[28] Deep Inside Convolutional Networks: Visualising Image Classification Models and Saliency Maps — Simonyan, Vedaldi, and Zisserman, 2014
[29] Multifaceted Feature Visualization: Uncovering the Different Types of Features Learned By Each Neuron in Deep Neural Networks — Nguyen, Yosinski, and Clune, 2016
[30] Explanation in Artificial Intelligence: Insights from the Social Sciences — Miller, 2017
[31] Examples Are Not Enough, Learn to Criticize! Criticism for Interpretability — Kim, Khanna, and Koyejo, 2016
[32] What’s Inside the Black Box? AI Challenges for Lawyers and Researchers — Yu and Spina Ali, 2019
This article was originally published on Towards Data Science and re-published to TOPBOTS with permission from the author.
Knut Hellan says
It might not be an official package, but there is shap for python: https://github.com/slundberg/shap