Chapter 11 The Math and Interpretation of Neural Network Models
11.1 Neural Networks Regression
Model for Regression: \(y = f_\mathbf{w}(\mathbf{X}) + \epsilon\), where \(\epsilon\sim \mathcal{N}(0, \sigma^2)\) and \(f_\mathbf{w}\) is a neural network with parameters \(\mathbf{w}\).
Training Objective: find \(\mathbf{w}\) to maximize the joint log-likelihood of our data. This is equivalent to minimizing the Mean Squared Error (MSE), \[ \min_{\mathbf{w}}\, \mathcal{L}(\mathbf{w}) = \frac{1}{N}\sum^N_{n=1} \left(y_n - f_\mathbf{w}(x_n)\right)^2 \]
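To see why the two objectives coincide, here is a sketch of the derivation, assuming the observations are independent given the inputs and that \(\sigma^2\) is a fixed constant rather than a learned parameter:

\[\begin{align} \log p(y_1, \ldots, y_N \mid x_1, \ldots, x_N, \mathbf{w}) &= \sum^N_{n=1} \log \mathcal{N}\left(y_n;\, f_\mathbf{w}(x_n), \sigma^2\right)\\ &= -\frac{N}{2}\log\left(2\pi\sigma^2\right) - \frac{1}{2\sigma^2}\sum^N_{n=1} \left(y_n - f_\mathbf{w}(x_n)\right)^2 \end{align}\]

The first term does not depend on \(\mathbf{w}\), and the second term is the sum of squared errors multiplied by the negative constant \(-\frac{1}{2\sigma^2}\). Maximizing the log-likelihood over \(\mathbf{w}\) is therefore the same as minimizing the sum of squared errors, and dividing by \(N\) gives the MSE without changing the minimizer.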
Optimizing the Training Objective: For linear regression (when \(f_\mathbf{w}\) is a linear function), we computed the gradient of \(\mathcal{L}\) with respect to the model parameters \(\mathbf{w}\), set it equal to zero and solved for the optimal \(\mathbf{w}\) analytically. For logistic regression, we computed the gradient and used (stochastic) gradient descent to “solve for where the gradient is zero”.
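As a reminder of what these two strategies look like, here is a minimal sketch on a linear model with synthetic data. The data, the learning rate `eta` and the number of steps are made-up values for illustration, and the iterative version uses plain (full-batch) gradient descent rather than the stochastic variant.

```python
import numpy as np

rng = np.random.default_rng(0)

# Synthetic data (an assumption for illustration): y = 1 - 2x + noise
N = 200
X = np.column_stack([np.ones(N), rng.normal(size=N)])  # bias column + one covariate
w_true = np.array([1.0, -2.0])
y = X @ w_true + 0.1 * rng.normal(size=N)

def mse_grad(w, X, y):
    """Gradient of (1/N) * sum_n (y_n - x_n^T w)^2 with respect to w."""
    return -2.0 / len(y) * X.T @ (y - X @ w)

# Strategy 1: set the gradient to zero and solve analytically.
# This gives the normal equations X^T X w = X^T y.
w_closed = np.linalg.solve(X.T @ X, X.T @ y)

# Strategy 2: follow the negative gradient until it is (nearly) zero.
w_gd = np.zeros(2)
eta = 0.1  # learning rate (hypothetical choice)
for _ in range(500):
    w_gd = w_gd - eta * mse_grad(w_gd, X, y)

print("closed form:     ", w_closed)
print("gradient descent:", w_gd)
```

Both strategies need the gradient of \(\mathcal{L}\); only the first one needs the equation "gradient equals zero" to be solvable by hand.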
Can we do the same when \(f_\mathbf{w}\) is a neural network?
11.1.1 Why It’s Hard to Differentiate a Neural Network
Computing the gradient for any parameter in a neural network requires us to use the chain rule:
\[\begin{align} \frac{\partial}{\partial t} g(h(t)) = g'(h(t))h'(t),\quad& \text{or}\quad\frac{\partial g}{\partial t} = \frac{\partial g}{\partial h} \frac{\partial h}{\partial t} \end{align}\]
This is because a neural network is just a big composition of functions.
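To make the chain rule concrete, here is a small sketch for a single node whose output is an activation applied to a weighted input, \(a = \text{act}(w x)\). The Gaussian-shaped activation and the numbers `w`, `x` are illustrative choices; the analytic derivative is compared against a finite-difference approximation.

```python
import numpy as np

def act(t):
    """Gaussian-shaped activation, act(t) = exp(-0.5 * t^2)."""
    return np.exp(-0.5 * t**2)

def act_prime(t):
    """Derivative of act with respect to its argument."""
    return -t * np.exp(-0.5 * t**2)

def node_output(w, x):
    return act(w * x)

def d_node_output_dw(w, x):
    # Chain rule: (outer derivative evaluated at s = w * x) * (inner derivative ds/dw = x)
    return act_prime(w * x) * x

w, x, eps = 0.7, 1.3, 1e-6  # made-up values
analytic = d_node_output_dw(w, x)
numeric = (node_output(w + eps, x) - node_output(w - eps, x)) / (2 * eps)
print(analytic, numeric)
```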
Let’s try to differentiate, by hand, the following neural network with activation function \(h(t) = e^{-0.5 t^2}\), viewed as a composition of functions.
We will first give a name to each parameter or weight in the neural network, \(w_{ij}^l\), where the superscript \(l\) indicates the layer from which the weight originates and the subscript \(ij\) indicates that this is a weight on an edge going from the \(i\)-th node in the \(l\)-th layer to the \(j\)-th node in the \((l+1)\)-th layer.
We will denote the output of the \(i\)-th node in the \(l\)-th layer as \(a_i^l\). We will denote the input to the \(i\)-th node in the \(l\)-th layer (a linear combination of the outputs of nodes in the previous layer) as \(s_i^l\).
In the “Forwards” column, we show you how each quantity \(a_i^l\) and \(s_i^l\) is computed as a function of the quantities in the previous layer.
To compute each partial derivative \(\frac{\partial \mathcal{L}}{\partial w_{ij}^l}\), we start with the definition of the loss function \(\mathcal{L}\) in terms of \(a_0^3\), and then apply the chain rule \[ \frac{\partial \mathcal{L}}{\partial s_{0}^3} = \frac{\partial \mathcal{L}}{\partial a_{0}^3}\frac{\partial a_{0}^3}{\partial s_{0}^3} \] to get a partial derivative of the loss function with respect to a quantity one step further back in the neural network graph.
By repeating the application of the chain rule to each layer of the neural network, we eventually get all partial derivatives \(\frac{\partial \mathcal{L}}{\partial w_{ij}^l}\).
You see that differentiating even a tiny network is a complex and time-consuming process!
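The layer-by-layer bookkeeping can be sketched in code. The network below is a hypothetical, even smaller one than the network discussed above: one input, one hidden node with the Gaussian-shaped activation, and one linear output node, so there are only two weights, here called `w1` and `w2`. Each line of the backward pass is one application of the chain rule, and the result is checked against finite differences.

```python
import numpy as np

def act(t):
    return np.exp(-0.5 * t**2)

def act_prime(t):
    return -t * np.exp(-0.5 * t**2)

def loss(w1, w2, x, y):
    """Squared error of the tiny chain network x -> act(w1 * x) -> w2 * (.)"""
    return (y - w2 * act(w1 * x)) ** 2

def loss_grads(w1, w2, x, y):
    # Forward pass: compute and store every intermediate quantity.
    s1 = w1 * x      # input to the hidden node
    a1 = act(s1)     # output of the hidden node
    s2 = w2 * a1     # input to the output node
    a2 = s2          # linear output node
    # Backward pass: one chain-rule step per line, from the loss toward the input.
    dL_da2 = -2.0 * (y - a2)
    dL_ds2 = dL_da2 * 1.0
    dL_dw2 = dL_ds2 * a1
    dL_da1 = dL_ds2 * w2
    dL_ds1 = dL_da1 * act_prime(s1)
    dL_dw1 = dL_ds1 * x
    return dL_dw1, dL_dw2

w1, w2, x, y, eps = 0.4, -1.2, 1.5, 0.3, 1e-6  # made-up values
g1, g2 = loss_grads(w1, w2, x, y)
num1 = (loss(w1 + eps, w2, x, y) - loss(w1 - eps, w2, x, y)) / (2 * eps)
num2 = (loss(w1, w2 + eps, x, y) - loss(w1, w2 - eps, x, y)) / (2 * eps)
print(g1, num1)
print(g2, num2)
```

Notice that `dL_ds2` is computed once and reused for both weights. This reuse of intermediate derivatives is what makes the procedure worth turning into an algorithm.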
11.1.2 Differentiating Neural Networks: Backpropagation
Luckily, the process of differentiating a neural network by iteratively applying the chain rule can be automated!
The backpropagation algorithm automatically computes the gradient of a neural network. Training with it consists of an initialization step followed by three repeated phases:
0. (Initialize) initialize the network parameters \(\mathbf{w}\)
1. Repeat:
   1. (Forward Pass) compute all intermediate values \(s_{i}^l\) and \(a_{i}^l\) for the given covariates \(\mathbf{X}\)
   2. (Backward Pass) compute all the gradients \(\frac{\partial \mathcal{L}}{\partial w^l_{ij}}\)
   3. (Update Parameters) update each parameter by adding \(-\eta \frac{\partial \mathcal{L}}{\partial w^l_{ij}}\), where \(\eta\) is the learning rate
We will see on Thursday that backpropagation is a special instance of reverse mode automatic differentiation – a method of algorithmically computing exact gradients for functions defined by combinations of simple functions, by drawing the composition of functions as a graph and then computing gradients with a forward pass followed by a backward pass.
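The phases above can be sketched for a network with one hidden layer. This is a minimal illustration under stated assumptions, not a reference implementation: the architecture (10 hidden nodes with the Gaussian-shaped activation, a linear output), the bias terms `b1` and `b2`, the synthetic data, the learning rate and the number of steps are all hypothetical choices.

```python
import numpy as np

def act(t):
    return np.exp(-0.5 * t**2)

def act_prime(t):
    return -t * np.exp(-0.5 * t**2)

def forward(params, X):
    """Forward pass: return the hidden inputs S1, hidden outputs A1 and network output."""
    W1, b1, W2, b2 = params
    S1 = X @ W1 + b1       # shape (N, H)
    A1 = act(S1)           # shape (N, H)
    out = A1 @ W2 + b2     # shape (N, 1), linear output layer
    return S1, A1, out

def backward(params, X, y, S1, A1, out):
    """Backward pass: gradients of the MSE with respect to W1, b1, W2, b2."""
    W1, b1, W2, b2 = params
    N = X.shape[0]
    d_out = -2.0 / N * (y - out)   # dL/d(out)
    dW2 = A1.T @ d_out
    db2 = d_out.sum(axis=0)
    dA1 = d_out @ W2.T
    dS1 = dA1 * act_prime(S1)      # chain rule through the activation
    dW1 = X.T @ dS1
    db1 = dS1.sum(axis=0)
    return dW1, db1, dW2, db2

def mse(params, X, y):
    return np.mean((y - forward(params, X)[2]) ** 2)

rng = np.random.default_rng(0)
X = rng.uniform(-2, 2, size=(100, 1))             # synthetic covariates
y = np.sin(2 * X) + 0.1 * rng.normal(size=(100, 1))

# 0. Initialize
H = 10
params = [rng.normal(size=(1, H)), np.zeros(H),
          0.1 * rng.normal(size=(H, 1)), np.zeros(1)]
loss_before = mse(params, X, y)

eta = 0.01  # learning rate (hypothetical choice)
for _ in range(2000):
    S1, A1, out = forward(params, X)                       # 1. forward pass
    grads = backward(params, X, y, S1, A1, out)            # 2. backward pass
    params = [p - eta * g for p, g in zip(params, grads)]  # 3. update parameters

loss_after = mse(params, X, y)
print("MSE before:", loss_before, "after:", loss_after)
```

The forward pass stores `S1` and `A1` because the backward pass needs them, which is why the two passes always come in this order.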
11.2 Interpreting Neural Networks
Linear models are easy to interpret. Once we’ve found the MLE of the model parameters, we can formulate scientific hypotheses about the relationship between the outcome \(Y\) and the covariates \(\mathbf{X}\):
\[\begin{align} \widehat{\text{income}} = 2 * \text{education (yr)} + 3.1 * \text{married} - 1.5 * \text{gaps in work history} \end{align}\]
What do the weights of a neural network tell you about the relationship between the covariates and the outcome?
We might be tempted to conclude that neural networks are uninterpretable due to their complexity. But just because we can’t understand neural networks by inspecting the value of the individual weights, it does not mean that we can’t understand them.
In The Mythos of Model Interpretability, the author surveys a large number of methods for interpreting deep models.
11.2.1 Example 1: Can Neural Network Models Make Use of Human Concepts?
(with Anita Mahinpei, Justin Clark, Ike Lage, Finale Doshi-Velez)
What if, instead of building complex non-linear models based on raw inputs, we build simple linear models based on human-interpretable concepts? We use a neural network to predict concepts from inputs and then use a linear model to predict the outcome from the concepts. We interpret the relationship between the outcome and the concepts via the linear model. These models are called concept bottleneck models.
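The two-stage structure can be sketched as follows. Everything here is hypothetical and only illustrates the shape of the model: the layer sizes, the randomly drawn (untrained) network weights, and the linear coefficients `beta` are made up, and a real concept bottleneck model would learn them from data with concept annotations.

```python
import numpy as np

rng = np.random.default_rng(0)
D, H, K = 5, 8, 3  # input dimension, hidden width, number of concepts (made up)

# Stage 1: a small neural network mapping raw inputs to concept scores.
W1 = rng.normal(size=(D, H))   # untrained, random weights for illustration
W2 = rng.normal(size=(H, K))

def predict_concepts(X):
    hidden = np.tanh(X @ W1)
    return 1.0 / (1.0 + np.exp(-(hidden @ W2)))  # each concept score lies in (0, 1)

# Stage 2: a linear model mapping concepts to the outcome.
beta = np.array([2.0, -1.0, 0.5])  # one interpretable coefficient per concept
beta0 = 0.1

def predict_outcome(concepts):
    return concepts @ beta + beta0

X = rng.normal(size=(4, D))
c = predict_concepts(X)
y_hat = predict_outcome(c)

# Because stage 2 is linear, raising concept 0 by 0.1 changes every
# prediction by exactly 0.1 * beta[0], regardless of the raw input.
shifted = predict_outcome(c + np.array([0.1, 0.0, 0.0]))
print(shifted - y_hat)
```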
In The Promises and Pitfalls of Black-box Concept Learning Models, we examine the advantages and drawbacks of these models.
11.2.2 Example 2: Can Neural Network Models Learn to Explore Hypothetical Scenarios?
(with Michael Downs, Jonathan Chu, Wisoo Song, Yaniv Yacoby, Finale Doshi-Velez)
Rather than explaining why the model made a decision, it’s often more helpful to explain how to change the data in order to change the model’s decision. This modified input is a counter-factual. In CRUDS: Counterfactual Recourse Using Disentangled Subspaces, we study how to automatically generate counter-factual explanations that can help users achieve a favorable outcome from a decision system.
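To illustrate only the basic idea of a counterfactual (this is not the CRUDS method), here is a sketch for the simplest possible decision system, a linear classifier that returns a favorable decision when \(\mathbf{w}^\top \mathbf{x} + b > 0\). The weights, the input and the `overshoot` parameter are made-up values. For such a classifier, the smallest change in Euclidean distance that flips the decision moves the input along \(\mathbf{w}\) until it just crosses the decision boundary.

```python
import numpy as np

def closest_counterfactual(x, w, b, overshoot=1e-3):
    """Smallest L2 change to x that flips the sign of w @ x + b.

    Projects x onto the decision boundary w @ x + b = 0 and then steps
    slightly past it (controlled by the hypothetical `overshoot` parameter).
    """
    score = w @ x + b
    step = -(1.0 + overshoot) * score / (w @ w)
    return x + step * w

w = np.array([0.8, -0.5])  # made-up classifier weights
b = -1.0
x = np.array([0.5, 1.0])   # an input with an unfavorable decision

x_cf = closest_counterfactual(x, w, b)
print("original score:      ", w @ x + b)
print("counterfactual input:", x_cf)
print("counterfactual score:", w @ x_cf + b)
```

For neural network models the decision boundary is not a hyperplane, and a counterfactual also needs to be realistic and actionable for the user, which is what makes the general problem hard.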
11.2.3 Example 3: A Powerful Generalization of Feature Importance for Neural Network Models
In An explainable deep-learning algorithm for the detection of acute intracranial haemorrhage from small datasets, the authors build a neural network model to detect acute intracranial haemorrhage (ICH) and classify five ICH subtypes.
Model classifications are explained by highlighting the pixels that contributed the most to the decision. The highlighted regions tend to overlap with ‘bleeding points’ annotated by neuroradiologists on the images.
11.2.4 Example 4: The Perils of Explanations
In How machine-learning recommendations influence clinician treatment selections: the example of antidepressant selection, the authors found that clinicians interacting with incorrect recommendations paired with simple explanations experienced a significant reduction in treatment selection accuracy.
Take-away: Incorrect ML recommendations may adversely impact clinician treatment selections, and explanations are insufficient for addressing overreliance on imperfect ML algorithms.
11.3 Neural Network Models and Generalization
Complex models have low bias – they can model a wide range of functions, given enough samples.
But complex models like neural networks can use their ‘extra’ capacity to explain non-meaningful features of the training data that are unlikely to appear in the test data (i.e. noise). These models have high variance – they are very sensitive to small changes in the training data, leading to a drastic decrease in performance from train to test settings.
Polynomial Model with Modest Degree | Neural Network Model |
---|---|
Just as in the case of linear and polynomial models, we can prevent neural networks from overfitting (i.e. poor generalization due to high variance) by regularization or by ensembling a large number of models.
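As a reminder of how regularization acts in the polynomial case, here is a minimal sketch comparing an unregularized degree-9 polynomial fit with an \(\ell_2\)-regularized (ridge) fit on a small synthetic dataset. The data-generating function, the degree, the sample sizes and the penalty strength `lam` are hypothetical choices, and for simplicity the penalty is also applied to the intercept. The same idea carries over to neural networks by adding a penalty on the weights to the training objective.

```python
import numpy as np

rng = np.random.default_rng(0)

def make_data(n):
    """Synthetic data (an assumption for illustration): y = sin(3x) + noise."""
    x = rng.uniform(-1, 1, size=n)
    y = np.sin(3 * x) + 0.2 * rng.normal(size=n)
    return x, y

def features(x, degree=9):
    # Columns are 1, x, x^2, ..., x^degree
    return np.vander(x, degree + 1, increasing=True)

def fit_ols(Phi, y):
    return np.linalg.lstsq(Phi, y, rcond=None)[0]

def fit_ridge(Phi, y, lam):
    # Minimizes ||y - Phi w||^2 + lam * ||w||^2
    A = Phi.T @ Phi + lam * np.eye(Phi.shape[1])
    return np.linalg.solve(A, Phi.T @ y)

def mse(w, Phi, y):
    return np.mean((y - Phi @ w) ** 2)

x_tr, y_tr = make_data(15)    # small training set
x_te, y_te = make_data(500)   # larger held-out set
Phi_tr, Phi_te = features(x_tr), features(x_te)

w_ols = fit_ols(Phi_tr, y_tr)
w_reg = fit_ridge(Phi_tr, y_tr, lam=1e-2)

for name, w in [("unregularized", w_ols), ("ridge", w_reg)]:
    print(name,
          "| train MSE:", mse(w, Phi_tr, y_tr),
          "| test MSE:", mse(w, Phi_te, y_te),
          "| weight norm:", np.linalg.norm(w))
```

The unregularized fit always has the lower training error, since that is exactly what it minimizes. The quantity to compare is the gap between training and test error for each model.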
However, a new body of work, such as Deep Double Descent: Where Bigger Models and More Data Hurt, shows that very wide neural networks (with far more parameters than there are data observations) can actually cease to overfit once the width surpasses a certain threshold. In fact, as the width of a neural network approaches infinity, training the neural network with gradient descent becomes equivalent to kernel regression (this kernel is called the neural tangent kernel)!