Remove modelling

- Move files to extras/ for now
- Adjust references to modelling
- Or add TO DO items to adjust later
Mine Çetinkaya-Rundel
2021-02-21 21:29:24 +00:00
parent f9109aadfe
commit 55803fc8a3
13 changed files with 8 additions and 13 deletions


@@ -0,0 +1,785 @@
# Model basics
## Introduction
The goal of a model is to provide a simple low-dimensional summary of a dataset.
In the context of this book we're going to use models to partition data into patterns and residuals.
Strong patterns will hide subtler trends, so we'll use models to help peel back layers of structure as we explore a dataset.
However, before we can start using models on interesting, real, datasets, you need to understand the basics of how models work.
For that reason, this chapter of the book is unique because it uses only simulated datasets.
These datasets are very simple, and not at all interesting, but they will help you understand the essence of modelling before you apply the same techniques to real data in the next chapter.
There are two parts to a model:
1. First, you define a **family of models** that express a precise, but generic, pattern that you want to capture.
For example, the pattern might be a straight line, or a quadratic curve.
You will express the model family as an equation like `y = a_1 * x + a_2` or `y = a_1 * x ^ a_2`.
Here, `x` and `y` are known variables from your data, and `a_1` and `a_2` are parameters that can vary to capture different patterns.
2. Next, you generate a **fitted model** by finding the model from the family that is the closest to your data.
This takes the generic model family and makes it specific, like `y = 3 * x + 7` or `y = 9 * x ^ 2`.
It's important to understand that a fitted model is just the closest model from a family of models.
That implies that you have the "best" model (according to some criteria); it doesn't imply that you have a good model and it certainly doesn't imply that the model is "true".
George Box puts this well in his famous aphorism:
> All models are wrong, but some are useful.
It's worth reading the fuller context of the quote:
> Now it would be very remarkable if any system existing in the real world could be exactly represented by any simple model.
> However, cunningly chosen parsimonious models often do provide remarkably useful approximations.
> For example, the law PV = RT relating pressure P, volume V and temperature T of an "ideal" gas via a constant R is not exactly true for any real gas, but it frequently provides a useful approximation and furthermore its structure is informative since it springs from a physical view of the behavior of gas molecules.
>
> For such a model there is no need to ask the question "Is the model true?".
> If "truth" is to be the "whole truth" the answer must be "No".
> The only question of interest is "Is the model illuminating and useful?".
The goal of a model is not to uncover truth, but to discover a simple approximation that is still useful.
### Prerequisites
In this chapter we'll use the modelr package, which wraps around base R's modelling functions to make them work naturally in a pipe.
```{r setup, message = FALSE}
library(tidyverse)
library(modelr)
options(na.action = na.warn)
```
## A simple model
Let's take a look at the simulated dataset `sim1`, included with the modelr package.
It contains two continuous variables, `x` and `y`.
Let's plot them to see how they're related:
```{r}
ggplot(sim1, aes(x, y)) +
geom_point()
```
You can see a strong pattern in the data.
Let's use a model to capture that pattern and make it explicit.
It's our job to supply the basic form of the model.
In this case, the relationship looks linear, i.e. `y = a_1 + a_2 * x`.
Let's start by getting a feel for what models from that family look like by randomly generating a few and overlaying them on the data.
For this simple case, we can use `geom_abline()` which takes a slope and intercept as parameters.
Later on we'll learn more general techniques that work with any model.
```{r}
models <- tibble(
a1 = runif(250, -20, 40),
a2 = runif(250, -5, 5)
)
ggplot(sim1, aes(x, y)) +
geom_abline(aes(intercept = a1, slope = a2), data = models, alpha = 1/4) +
geom_point()
```
There are 250 models on this plot, but a lot are really bad!
We need to find the good models by making precise our intuition that a good model is "close" to the data.
We need a way to quantify the distance between the data and a model.
Then we can fit the model by finding the values of `a_1` and `a_2` that generate the model with the smallest distance from this data.
One easy place to start is to find the vertical distance between each point and the model, as in the following diagram.
(Note that I've shifted the x values slightly so you can see the individual distances.)
```{r, echo = FALSE}
dist1 <- sim1 %>%
mutate(
dodge = rep(c(-1, 0, 1) / 20, 10),
x1 = x + dodge,
pred = 7 + x1 * 1.5
)
ggplot(dist1, aes(x1, y)) +
geom_abline(intercept = 7, slope = 1.5, colour = "grey40") +
geom_point(colour = "grey40") +
geom_linerange(aes(ymin = y, ymax = pred), colour = "#3366FF")
```
This distance is just the difference between the y value given by the model (the **prediction**), and the actual y value in the data (the **response**).
To compute this distance, we first turn our model family into an R function.
This takes the model parameters and the data as inputs, and gives values predicted by the model as output:
```{r}
model1 <- function(a, data) {
a[1] + data$x * a[2]
}
model1(c(7, 1.5), sim1)
```
Next, we need some way to compute an overall distance between the predicted and actual values.
In other words, the plot above shows 30 distances: how do we collapse that into a single number?
One common way to do this in statistics is to use the "root-mean-squared deviation".
We compute the difference between actual and predicted, square them, average them, and then take the square root.
This distance has lots of appealing mathematical properties, which we're not going to talk about here.
You'll just have to take my word for it!
```{r}
measure_distance <- function(mod, data) {
diff <- data$y - model1(mod, data)
sqrt(mean(diff ^ 2))
}
measure_distance(c(7, 1.5), sim1)
```
Now we can use purrr to compute the distance for all the models defined above.
We need a helper function because our distance function expects the model as a numeric vector of length 2.
```{r}
sim1_dist <- function(a1, a2) {
measure_distance(c(a1, a2), sim1)
}
models <- models %>%
mutate(dist = purrr::map2_dbl(a1, a2, sim1_dist))
models
```
Next, let's overlay the 10 best models on to the data.
I've coloured the models by `-dist`: this is an easy way to make sure that the best models (i.e. the ones with the smallest distance) get the brightest colours.
```{r}
ggplot(sim1, aes(x, y)) +
geom_point(size = 2, colour = "grey30") +
geom_abline(
aes(intercept = a1, slope = a2, colour = -dist),
data = filter(models, rank(dist) <= 10)
)
```
We can also think about these models as observations, and visualise them with a scatterplot of `a1` vs `a2`, again coloured by `-dist`.
We can no longer directly see how the model compares to the data, but we can see many models at once.
Again, I've highlighted the 10 best models, this time by drawing red circles underneath them.
```{r}
ggplot(models, aes(a1, a2)) +
geom_point(data = filter(models, rank(dist) <= 10), size = 4, colour = "red") +
geom_point(aes(colour = -dist))
```
Instead of trying lots of random models, we could be more systematic and generate an evenly spaced grid of points (this is called a grid search).
I picked the parameters of the grid roughly by looking at where the best models were in the plot above.
```{r}
grid <- expand.grid(
a1 = seq(-5, 20, length = 25),
a2 = seq(1, 3, length = 25)
) %>%
mutate(dist = purrr::map2_dbl(a1, a2, sim1_dist))
grid %>%
ggplot(aes(a1, a2)) +
geom_point(data = filter(grid, rank(dist) <= 10), size = 4, colour = "red") +
geom_point(aes(colour = -dist))
```
When you overlay the best 10 models back on the original data, they all look pretty good:
```{r}
ggplot(sim1, aes(x, y)) +
geom_point(size = 2, colour = "grey30") +
geom_abline(
aes(intercept = a1, slope = a2, colour = -dist),
data = filter(grid, rank(dist) <= 10)
)
```
You could imagine iteratively making the grid finer and finer until you narrowed in on the best model.
But there's a better way to tackle that problem: a numerical minimisation tool called Newton-Raphson search.
The intuition of Newton-Raphson is pretty simple: you pick a starting point and look around for the steepest slope.
You then ski down that slope a little way, and then repeat again and again, until you can't go any lower.
In R, we can do that with `optim()`:
```{r}
best <- optim(c(0, 0), measure_distance, data = sim1)
best$par
ggplot(sim1, aes(x, y)) +
geom_point(size = 2, colour = "grey30") +
geom_abline(intercept = best$par[1], slope = best$par[2])
```
Don't worry too much about the details of how `optim()` works.
It's the intuition that's important here.
If you have a function that defines the distance between a model and a dataset, an algorithm that can minimise that distance by modifying the parameters of the model can find the best model.
The neat thing about this approach is that it will work for any family of models that you can write an equation for.
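To see that generality in action, here's a minimal sketch (the helper names `model2()` and `measure_distance2()` are introduced here just for illustration) that uses the same distance-minimisation idea to fit a quadratic family to `sim1`:
```{r}
# A quadratic family: y = a_1 + a_2 * x + a_3 * x^2
model2 <- function(a, data) {
  a[1] + a[2] * data$x + a[3] * data$x^2
}

# Same root-mean-squared distance as before, but for the quadratic family
measure_distance2 <- function(mod, data) {
  diff <- data$y - model2(mod, data)
  sqrt(mean(diff^2))
}

best2 <- optim(c(0, 0, 0), measure_distance2, data = sim1)
best2$par
```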
There's one more approach that we can use for this model, because it's a special case of a broader family: linear models.
A linear model has the general form `y = a_1 + a_2 * x_1 + a_3 * x_2 + ... + a_n * x_(n - 1)`.
So this simple model is equivalent to a general linear model where n is 2 and `x_1` is `x`.
R has a tool specifically designed for fitting linear models called `lm()`.
`lm()` has a special way to specify the model family: formulas.
Formulas look like `y ~ x`, which `lm()` will translate to a function like `y = a_1 + a_2 * x`.
We can fit the model and look at the output:
```{r}
sim1_mod <- lm(y ~ x, data = sim1)
coef(sim1_mod)
```
These are exactly the same values we got with `optim()`!
Behind the scenes `lm()` doesn't use `optim()` but instead takes advantage of the mathematical structure of linear models.
Using some connections between geometry, calculus, and linear algebra, `lm()` actually finds the closest model in a single step, using a sophisticated algorithm.
This approach is both faster and guarantees that it finds the global minimum.
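To get a flavour of that single-step idea, here's a sketch that solves the least-squares problem directly with the normal equations (this is just an illustration; internally `lm()` uses a QR decomposition rather than this formula):
```{r}
# Build the model matrix by hand: a column of ones (intercept) plus x
X <- cbind(1, sim1$x)
# Solve (X'X) b = X'y for the coefficient vector b
solve(t(X) %*% X, t(X) %*% sim1$y)
```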
### Exercises
1. One downside of the linear model is that it is sensitive to unusual values because the distance incorporates a squared term.
Fit a linear model to the simulated data below, and visualise the results.
Rerun a few times to generate different simulated datasets.
What do you notice about the model?
```{r}
sim1a <- tibble(
x = rep(1:10, each = 3),
y = x * 1.5 + 6 + rt(length(x), df = 2)
)
```
2. One way to make linear models more robust is to use a different distance measure.
For example, instead of root-mean-squared distance, you could use mean-absolute distance:
```{r}
measure_distance <- function(mod, data) {
diff <- data$y - model1(mod, data)
mean(abs(diff))
}
```
Use `optim()` to fit this model to the simulated data above and compare it to the linear model.
3. One challenge with performing numerical optimisation is that it's only guaranteed to find one local optimum.
What's the problem with optimising a three parameter model like this?
```{r}
model1 <- function(a, data) {
a[1] + data$x * a[2] + a[3]
}
```
## Visualising models
For simple models, like the one above, you can figure out what pattern the model captures by carefully studying the model family and the fitted coefficients.
And if you ever take a statistics course on modelling, you're likely to spend a lot of time doing just that.
Here, however, we're going to take a different tack.
We're going to focus on understanding a model by looking at its predictions.
This has a big advantage: every type of predictive model makes predictions (otherwise what use would it be?) so we can use the same set of techniques to understand any type of predictive model.
It's also useful to see what the model doesn't capture, the so-called residuals which are left after subtracting the predictions from the data.
Residuals are powerful because they allow us to use models to remove striking patterns so we can study the subtler trends that remain.
### Predictions
To visualise the predictions from a model, we start by generating an evenly spaced grid of values that covers the region where our data lies.
The easiest way to do that is to use `modelr::data_grid()`.
Its first argument is a data frame, and for each subsequent argument it finds the unique variables and then generates all combinations:
```{r}
grid <- sim1 %>%
data_grid(x)
grid
```
(This will get more interesting when we start to add more variables to our model.)
Next we add predictions.
We'll use `modelr::add_predictions()` which takes a data frame and a model.
It adds the predictions from the model to a new column in the data frame:
```{r}
grid <- grid %>%
add_predictions(sim1_mod)
grid
```
(You can also use this function to add predictions to your original dataset.)
Next, we plot the predictions.
You might wonder about all this extra work compared to just using `geom_abline()`.
But the advantage of this approach is that it will work with *any* model in R, from the simplest to the most complex.
You're only limited by your visualisation skills.
For more ideas about how to visualise more complex model types, you might try <http://vita.had.co.nz/papers/model-vis.html>.
```{r}
ggplot(sim1, aes(x)) +
geom_point(aes(y = y)) +
geom_line(aes(y = pred), data = grid, colour = "red", size = 1)
```
### Residuals
The flip-side of predictions are **residuals**.
The predictions tell you the pattern that the model has captured, and the residuals tell you what the model has missed.
The residuals are just the distances between the observed and predicted values that we computed above.
We add residuals to the data with `add_residuals()`, which works much like `add_predictions()`.
Note, however, that we use the original dataset, not a manufactured grid.
This is because to compute residuals we need actual y values.
```{r}
sim1 <- sim1 %>%
add_residuals(sim1_mod)
sim1
```
There are a few different ways to understand what the residuals tell us about the model.
One way is to simply draw a frequency polygon to help us understand the spread of the residuals:
```{r}
ggplot(sim1, aes(resid)) +
geom_freqpoly(binwidth = 0.5)
```
This helps you calibrate the quality of the model: how far away are the predictions from the observed values?
Note that the average of the residuals will always be 0.
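You can check that directly (a quick sketch using the `resid` column we just added):
```{r}
# The residuals from a least-squares fit with an intercept average to zero
# (up to floating point error)
mean(sim1$resid)
```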
You'll often want to recreate plots using the residuals instead of the original predictor.
You'll see a lot of that in the next chapter.
```{r}
ggplot(sim1, aes(x, resid)) +
geom_ref_line(h = 0) +
geom_point()
```
This looks like random noise, suggesting that our model has done a good job of capturing the patterns in the dataset.
### Exercises
1. Instead of using `lm()` to fit a straight line, you can use `loess()` to fit a smooth curve.
Repeat the process of model fitting, grid generation, predictions, and visualisation on `sim1` using `loess()` instead of `lm()`.
How does the result compare to `geom_smooth()`?
2. `add_predictions()` is paired with `gather_predictions()` and `spread_predictions()`.
How do these three functions differ?
3. What does `geom_ref_line()` do?
What package does it come from?
Why is displaying a reference line in plots showing residuals useful and important?
4. Why might you want to look at a frequency polygon of absolute residuals?
What are the pros and cons compared to looking at the raw residuals?
## Formulas and model families
You've seen formulas before when using `facet_wrap()` and `facet_grid()`.
In R, formulas provide a general way of getting "special behaviour".
Rather than evaluating the values of the variables right away, they capture them so they can be interpreted by the function.
The majority of modelling functions in R use a standard conversion from formulas to functions.
You've seen one simple conversion already: `y ~ x` is translated to `y = a_1 + a_2 * x`.
If you want to see what R actually does, you can use the `model_matrix()` function.
It takes a data frame and a formula and returns a tibble that defines the model equation: each column in the output is associated with one coefficient in the model, and the function is always of the form `y = a_1 * out_1 + a_2 * out_2`.
For the simplest case of `y ~ x1` this shows us something interesting:
```{r}
df <- tribble(
~y, ~x1, ~x2,
4, 2, 5,
5, 1, 6
)
model_matrix(df, y ~ x1)
```
The way that R adds the intercept to the model is just by having a column that is full of ones.
By default, R will always add this column.
If you don't want the intercept, you need to explicitly drop it with `-1`:
```{r}
model_matrix(df, y ~ x1 - 1)
```
The model matrix grows in an unsurprising way when you add more variables to the model:
```{r}
model_matrix(df, y ~ x1 + x2)
```
This formula notation is sometimes called "Wilkinson-Rogers notation", and was initially described in *Symbolic Description of Factorial Models for Analysis of Variance*, by G. N. Wilkinson and C. E. Rogers, <https://www.jstor.org/stable/2346786>.
It's worth digging up and reading the original paper if you'd like to understand the full details of the modelling algebra.
The following sections expand on how this formula notation works for categorical variables, interactions, and transformation.
### Categorical variables
Generating a function from a formula is straightforward when the predictor is continuous, but things get a bit more complicated when the predictor is categorical.
Imagine you have a formula like `y ~ sex`, where sex could either be male or female.
It doesn't make sense to convert that to a formula like `y = a_0 + a_1 * sex` because `sex` isn't a number - you can't multiply it!
Instead what R does is convert it to `y = a_0 + a_1 * sex_male` where `sex_male` is one if `sex` is male and zero otherwise:
```{r}
df <- tribble(
~ sex, ~ response,
"male", 1,
"female", 2,
"male", 1
)
model_matrix(df, response ~ sex)
```
You might wonder why R doesn't also create a `sexfemale` column.
The problem is that it would create a column that is perfectly predictable based on the other columns (i.e. `sexfemale = 1 - sexmale`).
Unfortunately the exact details of why this is a problem are beyond the scope of this book, but basically it creates a model family that is too flexible, and will have infinitely many models that are equally close to the data.
Fortunately, however, if you focus on visualising predictions you don't need to worry about the exact parameterisation.
Let's look at some data and models to make that concrete.
Here's the `sim2` dataset from modelr:
```{r}
ggplot(sim2) +
geom_point(aes(x, y))
```
We can fit a model to it, and generate predictions:
```{r}
mod2 <- lm(y ~ x, data = sim2)
grid <- sim2 %>%
data_grid(x) %>%
add_predictions(mod2)
grid
```
Effectively, a model with a categorical `x` will predict the mean value for each category.
(Why? Because the mean minimises the root-mean-squared distance.) That's easy to see if we overlay the predictions on top of the original data:
```{r}
ggplot(sim2, aes(x)) +
geom_point(aes(y = y)) +
geom_point(data = grid, aes(y = pred), colour = "red", size = 4)
```
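You can also confirm this numerically; here's a quick sketch that computes the per-category means directly, which should match the `pred` column in `grid`:
```{r}
# The model's predictions are just the mean of y within each level of x
sim2 %>%
  group_by(x) %>%
  summarise(mean_y = mean(y))
```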
You can't make predictions about levels that you didn't observe.
Sometimes you'll do this by accident so it's good to recognise this error message:
```{r, error = TRUE}
tibble(x = "e") %>%
add_predictions(mod2)
```
### Interactions (continuous and categorical)
What happens when you combine a continuous and a categorical variable?
`sim3` contains a categorical predictor and a continuous predictor.
We can visualise it with a simple plot:
```{r}
ggplot(sim3, aes(x1, y)) +
geom_point(aes(colour = x2))
```
There are two possible models you could fit to this data:
```{r}
mod1 <- lm(y ~ x1 + x2, data = sim3)
mod2 <- lm(y ~ x1 * x2, data = sim3)
```
When you add variables with `+`, the model will estimate each effect independent of all the others.
It's possible to fit the so-called interaction by using `*`.
For example, `y ~ x1 * x2` is translated to `y = a_0 + a_1 * x1 + a_2 * x2 + a_12 * x1 * x2`.
Note that whenever you use `*`, both the interaction and the individual components are included in the model.
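As a quick illustration (a sketch using a tiny made-up data frame, `df_int`, rather than any of the simulated datasets), you can compare the columns that the two formulas generate:
```{r}
# Hypothetical data frame with two continuous predictors
df_int <- tribble(
  ~y, ~x1, ~x2,
   4,   2,   5,
   5,   1,   6
)
model_matrix(df_int, y ~ x1 + x2)  # intercept, x1, x2
model_matrix(df_int, y ~ x1 * x2)  # adds an x1:x2 column for the product
```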
To visualise these models we need two new tricks:
1. We have two predictors, so we need to give `data_grid()` both variables.
It finds all the unique values of `x1` and `x2` and then generates all combinations.
2. To generate predictions from both models simultaneously, we can use `gather_predictions()` which adds each prediction as a row.
The complement of `gather_predictions()` is `spread_predictions()` which adds each prediction to a new column.
Together this gives us:
```{r}
grid <- sim3 %>%
data_grid(x1, x2) %>%
gather_predictions(mod1, mod2)
grid
```
We can visualise the results for both models on one plot using facetting:
```{r}
ggplot(sim3, aes(x1, y, colour = x2)) +
geom_point() +
geom_line(data = grid, aes(y = pred)) +
facet_wrap(~ model)
```
Note that the model that uses `+` has the same slope for each line, but different intercepts.
The model that uses `*` has a different slope and intercept for each line.
Which model is better for this data?
We can take a look at the residuals.
Here I've facetted by both model and `x2` because it makes it easier to see the pattern within each group.
```{r}
sim3 <- sim3 %>%
gather_residuals(mod1, mod2)
ggplot(sim3, aes(x1, resid, colour = x2)) +
geom_point() +
facet_grid(model ~ x2)
```
There is little obvious pattern in the residuals for `mod2`.
The residuals for `mod1` show that the model has clearly missed some pattern in `b`, and, to a lesser extent, in `c` and `d`.
You might wonder if there's a precise way to tell which of `mod1` or `mod2` is better.
There is, but it requires a lot of mathematical background, and we don't really care.
Here, we're interested in a qualitative assessment of whether or not the model has captured the pattern that we're interested in.
### Interactions (two continuous)
Let's take a look at the equivalent model for two continuous variables.
Initially things proceed almost identically to the previous example:
```{r}
mod1 <- lm(y ~ x1 + x2, data = sim4)
mod2 <- lm(y ~ x1 * x2, data = sim4)
grid <- sim4 %>%
data_grid(
x1 = seq_range(x1, 5),
x2 = seq_range(x2, 5)
) %>%
gather_predictions(mod1, mod2)
grid
```
Note my use of `seq_range()` inside `data_grid()`.
Instead of using every unique value of `x`, I'm going to use a regularly spaced grid of five values between the minimum and maximum numbers.
It's probably not super important here, but it's a useful technique in general.
There are two other useful arguments to `seq_range()`:
- `pretty = TRUE` will generate a "pretty" sequence, i.e. something that looks nice to the human eye.
This is useful if you want to produce tables of output:
```{r}
seq_range(c(0.0123, 0.923423), n = 5)
seq_range(c(0.0123, 0.923423), n = 5, pretty = TRUE)
```
- `trim = 0.1` will trim off 10% of the tail values.
This is useful if the variables have a long tailed distribution and you want to focus on generating values near the center:
```{r}
x1 <- rcauchy(100)
seq_range(x1, n = 5)
seq_range(x1, n = 5, trim = 0.10)
seq_range(x1, n = 5, trim = 0.25)
seq_range(x1, n = 5, trim = 0.50)
```
- `expand = 0.1` is in some sense the opposite of `trim`: it expands the range by 10%.
```{r}
x2 <- c(0, 1)
seq_range(x2, n = 5)
seq_range(x2, n = 5, expand = 0.10)
seq_range(x2, n = 5, expand = 0.25)
seq_range(x2, n = 5, expand = 0.50)
```
Next let's try and visualise that model.
We have two continuous predictors, so you can imagine the model like a 3d surface.
We could display that using `geom_tile()`:
```{r}
ggplot(grid, aes(x1, x2)) +
geom_tile(aes(fill = pred)) +
facet_wrap(~ model)
```
That doesn't suggest that the models are very different!
But that's partly an illusion: our eyes and brains are not very good at accurately comparing shades of colour.
Instead of looking at the surface from the top, we could look at it from either side, showing multiple slices:
```{r, asp = 1/2}
ggplot(grid, aes(x1, pred, colour = x2, group = x2)) +
geom_line() +
facet_wrap(~ model)
ggplot(grid, aes(x2, pred, colour = x1, group = x1)) +
geom_line() +
facet_wrap(~ model)
```
This shows you that an interaction between two continuous variables works basically the same way as an interaction between a categorical and a continuous variable.
An interaction says that there's not a fixed offset: you need to consider both values of `x1` and `x2` simultaneously in order to predict `y`.
You can see that even with just two continuous variables, coming up with good visualisations is hard.
But that's reasonable: you shouldn't expect it will be easy to understand how three or more variables simultaneously interact!
But again, we're saved a little because we're using models for exploration, and you can gradually build up your model over time.
The model doesn't have to be perfect, it just has to help you reveal a little more about your data.
I spent some time looking at the residuals to see if I could figure out whether `mod2` did better than `mod1`.
I think it does, but it's pretty subtle.
You'll have a chance to work on it in the exercises.
### Transformations
You can also perform transformations inside the model formula.
For example, `log(y) ~ sqrt(x1) + x2` is transformed to `log(y) = a_1 + a_2 * sqrt(x1) + a_3 * x2`.
If your transformation involves `+`, `*`, `^`, or `-`, you'll need to wrap it in `I()` so R doesn't treat it like part of the model specification.
For example, `y ~ x + I(x ^ 2)` is translated to `y = a_1 + a_2 * x + a_3 * x^2`.
If you forget the `I()` and specify `y ~ x ^ 2 + x`, R will compute `y ~ x * x + x`.
`x * x` means the interaction of `x` with itself, which is the same as `x`.
R automatically drops redundant variables so `x + x` becomes `x`, meaning that `y ~ x ^ 2 + x` specifies the function `y = a_1 + a_2 * x`.
That's probably not what you intended!
Again, if you get confused about what your model is doing, you can always use `model_matrix()` to see exactly what equation `lm()` is fitting:
```{r}
df <- tribble(
~y, ~x,
1, 1,
2, 2,
3, 3
)
model_matrix(df, y ~ x^2 + x)
model_matrix(df, y ~ I(x^2) + x)
```
Transformations are useful because you can use them to approximate non-linear functions.
If you've taken a calculus class, you may have heard of Taylor's theorem, which says you can approximate any smooth function with an infinite sum of polynomials.
That means you can use a polynomial function to get arbitrarily close to a smooth function by fitting an equation like `y = a_1 + a_2 * x + a_3 * x^2 + a_4 * x ^ 3`.
Typing that sequence by hand is tedious, so R provides the helper function `poly()`:
```{r}
model_matrix(df, y ~ poly(x, 2))
```
However there's one major problem with using `poly()`: outside the range of the data, polynomials rapidly shoot off to positive or negative infinity.
One safer alternative is to use the natural spline, `splines::ns()`.
```{r}
library(splines)
model_matrix(df, y ~ ns(x, 2))
```
Let's see what that looks like when we try and approximate a non-linear function:
```{r}
sim5 <- tibble(
x = seq(0, 3.5 * pi, length = 50),
y = 4 * sin(x) + rnorm(length(x))
)
ggplot(sim5, aes(x, y)) +
geom_point()
```
I'm going to fit five models to this data.
```{r}
mod1 <- lm(y ~ ns(x, 1), data = sim5)
mod2 <- lm(y ~ ns(x, 2), data = sim5)
mod3 <- lm(y ~ ns(x, 3), data = sim5)
mod4 <- lm(y ~ ns(x, 4), data = sim5)
mod5 <- lm(y ~ ns(x, 5), data = sim5)
grid <- sim5 %>%
data_grid(x = seq_range(x, n = 50, expand = 0.1)) %>%
gather_predictions(mod1, mod2, mod3, mod4, mod5, .pred = "y")
ggplot(sim5, aes(x, y)) +
geom_point() +
geom_line(data = grid, colour = "red") +
facet_wrap(~ model)
```
Notice that the extrapolation outside the range of the data is clearly bad.
This is the downside to approximating a function with a polynomial.
But this is a very real problem with every model: the model can never tell you if the behaviour is true when you start extrapolating outside the range of the data that you have seen.
You must rely on theory and science.
### Exercises
1. What happens if you repeat the analysis of `sim2` using a model without an intercept?
What happens to the model equation?
What happens to the predictions?
2. Use `model_matrix()` to explore the equations generated for the models I fit to `sim3` and `sim4`.
Why is `*` a good shorthand for interaction?
3. Using the basic principles, convert the formulas in the following two models into functions.
(Hint: start by converting the categorical variable into 0-1 variables.)
```{r, eval = FALSE}
mod1 <- lm(y ~ x1 + x2, data = sim3)
mod2 <- lm(y ~ x1 * x2, data = sim3)
```
4. For `sim4`, which of `mod1` and `mod2` is better?
I think `mod2` does a slightly better job at removing patterns, but it's pretty subtle.
Can you come up with a plot to support my claim?
## Missing values
Missing values obviously cannot convey any information about the relationship between the variables, so modelling functions will drop any rows that contain missing values.
R's default behaviour is to silently drop them, but `options(na.action = na.warn)` (run in the prerequisites) makes sure you get a warning.
```{r}
df <- tribble(
~x, ~y,
1, 2.2,
2, NA,
3, 3.5,
4, 8.3,
NA, 10
)
mod <- lm(y ~ x, data = df)
```
To suppress the warning, set `na.action = na.exclude`:
```{r}
mod <- lm(y ~ x, data = df, na.action = na.exclude)
```
You can always see exactly how many observations were used with `nobs()`:
```{r}
nobs(mod)
```
## Other model families
This chapter has focussed exclusively on the class of linear models, which assume a relationship of the form `y = a_1 * x1 + a_2 * x2 + ... + a_n * xn`.
Linear models additionally assume that the residuals have a normal distribution, which we haven't talked about.
There are a large set of model classes that extend the linear model in various interesting ways.
Some of them are:
- **Generalised linear models**, e.g. `stats::glm()`.
Linear models assume that the response is continuous and the error has a normal distribution.
Generalised linear models extend linear models to include non-continuous responses (e.g. binary data or counts).
They work by defining a distance metric based on the statistical idea of likelihood.
- **Generalised additive models**, e.g. `mgcv::gam()`, extend generalised linear models to incorporate arbitrary smooth functions.
That means you can write a formula like `y ~ s(x)` which becomes an equation like `y = f(x)` and let `gam()` estimate what that function is (subject to some smoothness constraints to make the problem tractable).
- **Penalised linear models**, e.g. `glmnet::glmnet()`, add a penalty term to the distance that penalises complex models (as defined by the distance between the parameter vector and the origin).
This tends to make models that generalise better to new datasets from the same population.
- **Robust linear models**, e.g. `MASS::rlm()`, tweak the distance to downweight points that are very far away.
This makes them less sensitive to the presence of outliers, at the cost of being not quite as good when there are no outliers.
- **Trees**, e.g. `rpart::rpart()`, attack the problem in a completely different way than linear models.
They fit a piece-wise constant model, splitting the data into progressively smaller and smaller pieces.
Trees aren't terribly effective by themselves, but they are very powerful when used in aggregate by models like **random forests** (e.g. `randomForest::randomForest()`) or **gradient boosting machines** (e.g. `xgboost::xgboost()`).
These models all work similarly from a programming perspective.
Once you've mastered linear models, you should find it easy to master the mechanics of these other model classes.
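As a minimal sketch of that shared interface (assuming the rpart package is available; it is not part of this chapter's prerequisites), the same grid/predict/plot workflow from earlier carries over unchanged:
```{r}
# Fit a regression tree to sim1, then visualise its piecewise constant predictions
mod_tree <- rpart::rpart(y ~ x, data = sim1)

sim1 %>%
  data_grid(x) %>%
  add_predictions(mod_tree) %>%
  ggplot(aes(x, pred)) +
  geom_line()
```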
Being a skilled modeller is a mixture of some good general principles and having a big toolbox of techniques.
Now that you've learned some general tools and one useful class of models, you can go on and learn more classes from other sources.


@@ -0,0 +1,495 @@
# Model building
## Introduction
In the previous chapter you learned how linear models work, and learned some basic tools for understanding what a model is telling you about your data.
The previous chapter focussed on simulated datasets.
This chapter will focus on real data, showing you how you can progressively build up a model to aid your understanding of the data.
We will take advantage of the fact that you can think about a model partitioning your data into pattern and residuals.
We'll find patterns with visualisation, then make them concrete and precise with a model.
We'll then repeat the process, but replace the old response variable with the residuals from the model.
The goal is to transition from implicit knowledge in the data and your head to explicit knowledge in a quantitative model.
This makes it easier to apply to new domains, and easier for others to use.
For very large and complex datasets this will be a lot of work.
There are certainly alternative approaches - a more machine learning approach is simply to focus on the predictive ability of the model.
These approaches tend to produce black boxes: the model does a really good job at generating predictions, but you don't know why.
This is a totally reasonable approach, but it does make it hard to apply your real world knowledge to the model.
That, in turn, makes it difficult to assess whether or not the model will continue to work in the long-term, as fundamentals change.
For most real models, I'd expect you to use some combination of this approach and a more classic automated approach.
It's a challenge to know when to stop.
You need to figure out when your model is good enough, and when additional investment is unlikely to pay off.
I particularly like this quote from reddit user Broseidon241:
> A long time ago in art class, my teacher told me "An artist needs to know when a piece is done. You can't tweak something into perfection - wrap it up. If you don't like it, do it over again. Otherwise begin something new".
> Later in life, I heard "A poor seamstress makes many mistakes. A good seamstress works hard to correct those mistakes. A great seamstress isn't afraid to throw out the garment and start over."
>
> -- Broseidon241, <https://www.reddit.com/r/datascience/comments/4irajq>
### Prerequisites
We'll use the same tools as in the previous chapter, but add in some real datasets: `diamonds` from ggplot2, and `flights` from nycflights13.
We'll also need lubridate in order to work with the date/times in `flights`.
```{r setup, message = FALSE}
library(tidyverse)
library(modelr)
options(na.action = na.warn)
library(nycflights13)
library(lubridate)
```
## Why are low quality diamonds more expensive? {#diamond-prices}
In previous chapters we've seen a surprising relationship between the quality of diamonds and their price: low quality diamonds (poor cuts, bad colours, and inferior clarity) have higher prices.
```{r dev = "png"}
ggplot(diamonds, aes(cut, price)) + geom_boxplot()
ggplot(diamonds, aes(color, price)) + geom_boxplot()
ggplot(diamonds, aes(clarity, price)) + geom_boxplot()
```
Note that the worst diamond color is J (slightly yellow), and the worst clarity is I1 (inclusions visible to the naked eye).
### Price and carat
It looks like lower quality diamonds have higher prices because there is an important confounding variable: the weight (`carat`) of the diamond.
The weight of the diamond is the single most important factor for determining the price of the diamond, and lower quality diamonds tend to be larger.
```{r}
ggplot(diamonds, aes(carat, price)) +
geom_hex(bins = 50)
```
We can make it easier to see how the other attributes of a diamond affect its relative `price` by fitting a model to separate out the effect of `carat`.
But first, let's make a couple of tweaks to the diamonds dataset to make it easier to work with:
1. Focus on diamonds smaller than 2.5 carats (99.7% of the data).
2. Log-transform the carat and price variables.
```{r}
diamonds2 <- diamonds %>%
filter(carat <= 2.5) %>%
mutate(lprice = log2(price), lcarat = log2(carat))
```
Together, these changes make it easier to see the relationship between `carat` and `price`:
```{r}
ggplot(diamonds2, aes(lcarat, lprice)) +
geom_hex(bins = 50)
```
The log-transformation is particularly useful here because it makes the pattern linear, and linear patterns are the easiest to work with.
Let's take the next step and remove that strong linear pattern.
We first make the pattern explicit by fitting a model:
```{r}
mod_diamond <- lm(lprice ~ lcarat, data = diamonds2)
```
Then we look at what the model tells us about the data.
Note that I back transform the predictions, undoing the log transformation, so I can overlay the predictions on the raw data:
```{r}
grid <- diamonds2 %>%
data_grid(carat = seq_range(carat, 20)) %>%
mutate(lcarat = log2(carat)) %>%
add_predictions(mod_diamond, "lprice") %>%
mutate(price = 2 ^ lprice)
ggplot(diamonds2, aes(carat, price)) +
geom_hex(bins = 50) +
geom_line(data = grid, colour = "red", size = 1)
```
That tells us something interesting about our data.
If we believe our model, then the large diamonds are much cheaper than expected.
This is probably because no diamond in this dataset costs more than \$19,000.
Now we can look at the residuals, which verifies that we've successfully removed the strong linear pattern:
```{r}
diamonds2 <- diamonds2 %>%
add_residuals(mod_diamond, "lresid")
ggplot(diamonds2, aes(lcarat, lresid)) +
geom_hex(bins = 50)
```
Importantly, we can now re-do our motivating plots using those residuals instead of `price`.
```{r dev = "png"}
ggplot(diamonds2, aes(cut, lresid)) + geom_boxplot()
ggplot(diamonds2, aes(color, lresid)) + geom_boxplot()
ggplot(diamonds2, aes(clarity, lresid)) + geom_boxplot()
```
Now we see the relationship we expect: as the quality of the diamond increases, so too does its relative price.
To interpret the `y` axis, we need to think about what the residuals are telling us, and what scale they are on.
A residual of -1 indicates that `lprice` was 1 unit lower than a prediction based solely on its weight.
Because $2^{-1} = 1/2$, points with a residual of -1 are half the expected price, and points with a residual of 1 are twice the predicted price.
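To make that concrete, here's a quick sketch that converts the log2 residuals into multiplicative price ratios (the `ratio` column is introduced here just for illustration):
```{r}
# ratio = actual price / price predicted from weight alone
diamonds2 %>%
  mutate(ratio = 2 ^ lresid) %>%
  select(carat, price, lresid, ratio) %>%
  head(3)
```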
### A more complicated model
If we wanted to, we could continue to build up our model, moving the effects we've observed into the model to make them explicit.
For example, we could include `color`, `cut`, and `clarity` into the model so that we also make explicit the effect of these three categorical variables:
```{r}
mod_diamond2 <- lm(lprice ~ lcarat + color + cut + clarity, data = diamonds2)
```
This model now includes four predictors, so it's getting harder to visualise.
Fortunately, they're currently all independent, which means that we can plot them individually in four plots.
To make the process a little easier, we're going to use the `.model` argument to `data_grid`:
```{r}
grid <- diamonds2 %>%
data_grid(cut, .model = mod_diamond2) %>%
add_predictions(mod_diamond2)
grid
ggplot(grid, aes(cut, pred)) +
geom_point()
```
If the model needs variables that you haven't explicitly supplied, `data_grid()` will automatically fill them in with a "typical" value.
For continuous variables, it uses the median, and for categorical variables it uses the most common value (or values, if there's a tie).
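You can see where those "typical" values come from with a quick sketch (this just illustrates the claim; `lcarat` is continuous and `clarity` is categorical):
```{r}
# The value filled in for a continuous predictor is its median ...
median(diamonds2$lcarat)
# ... and for a categorical predictor it's the most common level
diamonds2 %>% count(clarity, sort = TRUE) %>% head(1)
```
Now let's look at the residuals from this bigger model: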
```{r}
diamonds2 <- diamonds2 %>%
add_residuals(mod_diamond2, "lresid2")
ggplot(diamonds2, aes(lcarat, lresid2)) +
geom_hex(bins = 50)
```
This plot indicates that there are some diamonds with quite large residuals - remember a residual of 2 indicates that the diamond is 4x the price that we expected.
It's often useful to look at unusual values individually:
```{r}
diamonds2 %>%
filter(abs(lresid2) > 1) %>%
add_predictions(mod_diamond2) %>%
mutate(pred = round(2 ^ pred)) %>%
select(price, pred, carat:table, x:z) %>%
arrange(price)
```
Nothing really jumps out at me here, but it's probably worth spending time considering whether this indicates a problem with our model, or whether there are errors in the data.
If there are mistakes in the data, this could be an opportunity to buy diamonds that have been incorrectly priced too low.
### Exercises
1. In the plot of `lcarat` vs. `lprice`, there are some bright vertical strips.
What do they represent?
2. If `log(price) = a_0 + a_1 * log(carat)`, what does that say about the relationship between `price` and `carat`?
3. Extract the diamonds that have very high and very low residuals.
Is there anything unusual about these diamonds?
Are they particularly bad or good, or do you think these are pricing errors?
4. Does the final model, `mod_diamond2`, do a good job of predicting diamond prices?
Would you trust it to tell you how much to spend if you were buying a diamond?
## What affects the number of daily flights?
Let's work through a similar process for a dataset that seems even simpler at first glance: the number of flights that leave NYC per day.
This is a really small dataset --- only 365 rows and 2 columns --- and we're not going to end up with a fully realised model, but as you'll see, the steps along the way will help us better understand the data.
Let's get started by counting the number of flights per day and visualising it with ggplot2.
```{r}
daily <- flights %>%
mutate(date = make_date(year, month, day)) %>%
group_by(date) %>%
summarise(n = n())
daily
ggplot(daily, aes(date, n)) +
geom_line()
```
### Day of week
Understanding the long-term trend is challenging because there's a very strong day-of-week effect that dominates the subtler patterns.
Let's start by looking at the distribution of flight numbers by day-of-week:
```{r}
daily <- daily %>%
mutate(wday = wday(date, label = TRUE))
ggplot(daily, aes(wday, n)) +
geom_boxplot()
```
There are fewer flights on weekends because most travel is for business.
The effect is particularly pronounced on Saturday: you might sometimes leave on Sunday for a Monday morning meeting, but it's very rare that you'd leave on Saturday as you'd much rather be at home with your family.
One way to remove this strong pattern is to use a model.
First, we fit the model, and display its predictions overlaid on the original data:
```{r}
mod <- lm(n ~ wday, data = daily)
grid <- daily %>%
data_grid(wday) %>%
add_predictions(mod, "n")
ggplot(daily, aes(wday, n)) +
geom_boxplot() +
geom_point(data = grid, colour = "red", size = 4)
```
Next we compute and visualise the residuals:
```{r}
daily <- daily %>%
add_residuals(mod)
daily %>%
ggplot(aes(date, resid)) +
geom_ref_line(h = 0) +
geom_line()
```
Note the change in the y-axis: now we are seeing the deviation from the expected number of flights, given the day of week.
This plot is useful because now that we've removed much of the large day-of-week effect, we can see some of the subtler patterns that remain:
1. Our model seems to fail starting in June: you can still see a strong regular pattern that our model hasn't captured.
Drawing a plot with one line for each day of the week makes the cause easier to see:
```{r}
ggplot(daily, aes(date, resid, colour = wday)) +
geom_ref_line(h = 0) +
geom_line()
```
Our model fails to accurately predict the number of flights on Saturday: during summer there are more flights than we expect, and during fall there are fewer.
We'll see how we can do better to capture this pattern in the next section.
2. There are some days with far fewer flights than expected:
```{r}
daily %>%
filter(resid < -100)
```
If you're familiar with American public holidays, you might spot New Year's day, July 4th, Thanksgiving and Christmas.
There are some others that don't seem to correspond to public holidays.
You'll work on those in one of the exercises.
3. There seems to be some smoother long term trend over the course of a year.
We can highlight that trend with `geom_smooth()`:
```{r}
daily %>%
ggplot(aes(date, resid)) +
geom_ref_line(h = 0) +
geom_line(colour = "grey50") +
geom_smooth(se = FALSE, span = 0.20)
```
There are fewer flights in January (and December), and more in summer (May-Sep).
We can't do much with this pattern quantitatively, because we only have a single year of data.
But we can use our domain knowledge to brainstorm potential explanations.
### Seasonal Saturday effect
Let's first tackle our failure to accurately predict the number of flights on Saturday.
A good place to start is to go back to the raw numbers, focussing on Saturdays:
```{r}
daily %>%
filter(wday == "Sat") %>%
ggplot(aes(date, n)) +
geom_point() +
geom_line() +
scale_x_date(NULL, date_breaks = "1 month", date_labels = "%b")
```
(I've used both points and lines to make it more clear what is data and what is interpolation.)
I suspect this pattern is caused by summer holidays: many people go on holiday in the summer, and people don't mind travelling on Saturdays for vacation.
Looking at this plot, we might guess that summer holidays are from early June to late August.
That seems to line up fairly well with the [state's school terms](http://schools.nyc.gov/Calendar/2013-2014+School+Year+Calendars.htm): summer break in 2013 was Jun 26--Sep 9.
Why are there more Saturday flights in spring than fall?
I asked some American friends and they suggested that it's less common to plan family vacations during fall because of the big Thanksgiving and Christmas holidays.
We don't have the data to know for sure, but it seems like a plausible working hypothesis.
Let's create a "term" variable that roughly captures the three school terms, and check our work with a plot:
```{r}
term <- function(date) {
cut(date,
breaks = ymd(20130101, 20130605, 20130825, 20140101),
labels = c("spring", "summer", "fall")
)
}
daily <- daily %>%
mutate(term = term(date))
daily %>%
filter(wday == "Sat") %>%
ggplot(aes(date, n, colour = term)) +
geom_point(alpha = 1/3) +
geom_line() +
scale_x_date(NULL, date_breaks = "1 month", date_labels = "%b")
```
(I manually tweaked the dates to get nice breaks in the plot. Using a visualisation to help you understand what your function is doing is a really powerful and general technique.)
It's useful to see how this new variable affects the other days of the week:
```{r}
daily %>%
ggplot(aes(wday, n, colour = term)) +
geom_boxplot()
```
It looks like there is significant variation across the terms, so fitting a separate day of week effect for each term is reasonable.
This improves our model, but not as much as we might hope:
```{r}
mod1 <- lm(n ~ wday, data = daily)
mod2 <- lm(n ~ wday * term, data = daily)
daily %>%
gather_residuals(without_term = mod1, with_term = mod2) %>%
ggplot(aes(date, resid, colour = model)) +
geom_line(alpha = 0.75)
```
We can see the problem by overlaying the predictions from the model on to the raw data:
```{r}
grid <- daily %>%
data_grid(wday, term) %>%
add_predictions(mod2, "n")
ggplot(daily, aes(wday, n)) +
geom_boxplot() +
geom_point(data = grid, colour = "red") +
facet_wrap(~ term)
```
Our model is finding the *mean* effect, but we have a lot of big outliers, so the mean tends to be far away from the typical value.
We can alleviate this problem by using a model that is robust to the effect of outliers: `MASS::rlm()`.
This greatly reduces the impact of the outliers on our estimates, and gives a model that does a good job of removing the day of week pattern:
```{r, warn = FALSE}
mod3 <- MASS::rlm(n ~ wday * term, data = daily)
daily %>%
add_residuals(mod3, "resid") %>%
ggplot(aes(date, resid)) +
geom_hline(yintercept = 0, size = 2, colour = "white") +
geom_line()
```
It's now much easier to see the long-term trend, and the positive and negative outliers.
### Computed variables
If you're experimenting with many models and many visualisations, it's a good idea to bundle the creation of variables up into a function so there's no chance of accidentally applying a different transformation in different places.
For example, we could write:
```{r}
compute_vars <- function(data) {
data %>%
mutate(
term = term(date),
wday = wday(date, label = TRUE)
)
}
```
Another option is to put the transformations directly in the model formula:
```{r}
wday2 <- function(x) wday(x, label = TRUE)
mod3 <- lm(n ~ wday2(date) * term(date), data = daily)
```
Either approach is reasonable.
Making the transformed variable explicit is useful if you want to check your work, or use it in a visualisation.
But you can't easily use transformations (like splines) that return multiple columns.
Including the transformations in the model function makes life a little easier when you're working with many different datasets because the model is self contained.
### Time of year: an alternative approach
In the previous section we used our domain knowledge (how the US school term affects travel) to improve the model.
An alternative to using our knowledge explicitly in the model is to give the data more room to speak.
We could use a more flexible model and allow that to capture the pattern we're interested in.
A simple linear trend isn't adequate, so we could try using a natural spline to fit a smooth curve across the year:
```{r}
library(splines)
mod <- MASS::rlm(n ~ wday * ns(date, 5), data = daily)
daily %>%
data_grid(wday, date = seq_range(date, n = 13)) %>%
add_predictions(mod) %>%
ggplot(aes(date, pred, colour = wday)) +
geom_line() +
geom_point()
```
We see a strong pattern in the numbers of Saturday flights.
This is reassuring, because we also saw that pattern in the raw data.
It's a good sign when you get the same signal from different approaches.
### Exercises
1. Use your Google sleuthing skills to brainstorm why there were fewer than expected flights on Jan 20, May 26, and Sep 1.
(Hint: they all have the same explanation.) How would these days generalise to another year?
2. What do the three days with high positive residuals represent?
How would these days generalise to another year?
```{r}
daily %>%
slice_max(n = 3, resid)
```
3. Create a new variable that splits the `wday` variable into terms, but only for Saturdays, i.e. it should have `Thurs`, `Fri`, but `Sat-summer`, `Sat-spring`, `Sat-fall`.
How does this model compare with the model with every combination of `wday` and `term`?
4. Create a new `wday` variable that combines the day of week, term (for Saturdays), and public holidays.
What do the residuals of that model look like?
5. What happens if you fit a day of week effect that varies by month (i.e. `n ~ wday * month`)?
Why is this not very helpful?
6. What would you expect the model `n ~ wday + ns(date, 5)` to look like?
Knowing what you know about the data, why would you expect it to be not particularly effective?
7. We hypothesised that people leaving on Sundays are more likely to be business travellers who need to be somewhere on Monday.
Explore that hypothesis by seeing how it breaks down based on distance and time: if it's true, you'd expect to see more Sunday evening flights to places that are far away.
8. It's a little frustrating that Sunday and Saturday are on separate ends of the plot.
Write a small function to set the levels of the factor so that the week starts on Monday.
## Learning more about models
We have only scratched the absolute surface of modelling, but you have hopefully gained some simple, but general-purpose tools that you can use to improve your own data analyses.
It's OK to start simple!
As you've seen, even very simple models can make a dramatic difference in your ability to tease out interactions between variables.
These modelling chapters are even more opinionated than the rest of the book.
I approach modelling from a somewhat different perspective to most others, and there is relatively little space devoted to it.
Modelling really deserves a book on its own, so I'd highly recommend that you read at least one of these three books:
- *Statistical Modeling: A Fresh Approach* by Danny Kaplan, <http://project-mosaic-books.com/?page_id=13>.
This book provides a gentle introduction to modelling, where you build your intuition, mathematical tools, and R skills in parallel.
The book replaces a traditional "introduction to statistics" course, providing a curriculum that is up-to-date and relevant to data science.
- *An Introduction to Statistical Learning* by Gareth James, Daniela Witten, Trevor Hastie, and Robert Tibshirani, <http://www-bcf.usc.edu/~gareth/ISL/> (available online for free).
This book presents a family of modern modelling techniques collectively known as statistical learning.
For an even deeper understanding of the math behind the models, read the classic *Elements of Statistical Learning* by Trevor Hastie, Robert Tibshirani, and Jerome Friedman, <https://web.stanford.edu/~hastie/Papers/ESLII.pdf> (also available online for free).
- *Applied Predictive Modeling* by Max Kuhn and Kjell Johnson, <http://appliedpredictivemodeling.com>.
This book is a companion to the **caret** package and provides practical tools for dealing with real-life predictive modelling challenges.

extra/model/model-many.Rmd Normal file

@@ -0,0 +1,622 @@
# Many models
## Introduction
In this chapter you're going to learn three powerful ideas that help you to work with large numbers of models with ease:
1. Using many simple models to better understand complex datasets.
2. Using list-columns to store arbitrary data structures in a data frame.
For example, this will allow you to have a column that contains linear models.
3. Using the **broom** package, by David Robinson, to turn models into tidy data.
This is a powerful technique for working with large numbers of models because once you have tidy data, you can apply all of the techniques that you've learned about earlier in the book.
We'll start by diving into a motivating example using data about life expectancy around the world.
It's a small dataset but it illustrates how important modelling can be for improving your visualisations.
We'll use a large number of simple models to partition out some of the strongest signals so we can see the subtler signals that remain.
We'll also see how model summaries can help us pick out outliers and unusual trends.
The following sections will dive into more detail about the individual techniques:
1. In [list-columns](#list-columns-1), you'll learn more about the list-column data structure, and why it's valid to put lists in data frames.
2. In [creating list-columns], you'll learn the three main ways in which you'll create list-columns.
3. In [simplifying list-columns] you'll learn how to convert list-columns back to regular atomic vectors (or sets of atomic vectors) so you can work with them more easily.
4. In [making tidy data with broom], you'll learn about the full set of tools provided by broom, and see how they can be applied to other types of data structure.
This chapter is somewhat aspirational: if this book is your first introduction to R, this chapter is likely to be a struggle.
It requires you to have deeply internalised ideas about modelling, data structures, and iteration.
So don't worry if you don't get it --- just put this chapter aside for a few months, and come back when you want to stretch your brain.
### Prerequisites
Working with many models requires many of the packages of the tidyverse (for data exploration, wrangling, and programming) and modelr to facilitate modelling.
```{r setup, message = FALSE}
library(modelr)
library(tidyverse)
```
## gapminder
To motivate the power of many simple models, we're going to look into the "gapminder" data.
This data was popularised by Hans Rosling, a Swedish doctor and statistician.
If you've never heard of him, stop reading this chapter right now and go watch one of his videos!
He is a fantastic data presenter and illustrates how you can use data to present a compelling story.
A good place to start is this short video filmed in conjunction with the BBC: <https://www.youtube.com/watch?v=jbkSRLYSojo>.
The gapminder data summarises the progression of countries over time, looking at statistics like life expectancy and GDP.
The data is easy to access in R, thanks to Jenny Bryan who created the gapminder package:
```{r}
library(gapminder)
gapminder
```
In this case study, we're going to focus on just three variables to answer the question "How does life expectancy (`lifeExp`) change over time (`year`) for each country (`country`)?".
A good place to start is with a plot:
```{r}
gapminder %>%
ggplot(aes(year, lifeExp, group = country)) +
geom_line(alpha = 1/3)
```
This is a small dataset: it only has \~1,700 observations and 3 variables.
But it's still hard to see what's going on!
Overall, it looks like life expectancy has been steadily improving.
However, if you look closely, you might notice some countries that don't follow this pattern.
How can we make those countries easier to see?
One way is to use the same approach as in the last chapter: there's a strong signal (overall linear growth) that makes it hard to see subtler trends.
We'll tease these factors apart by fitting a model with a linear trend.
The model captures steady growth over time, and the residuals will show what's left.
You already know how to do that if we had a single country:
```{r, out.width = "33%", fig.asp = 1, fig.width = 3, fig.align='default'}
nz <- filter(gapminder, country == "New Zealand")
nz %>%
ggplot(aes(year, lifeExp)) +
geom_line() +
ggtitle("Full data = ")
nz_mod <- lm(lifeExp ~ year, data = nz)
nz %>%
add_predictions(nz_mod) %>%
ggplot(aes(year, pred)) +
geom_line() +
ggtitle("Linear trend + ")
nz %>%
add_residuals(nz_mod) %>%
ggplot(aes(year, resid)) +
geom_hline(yintercept = 0, colour = "white", size = 3) +
geom_line() +
ggtitle("Remaining pattern")
```
How can we easily fit that model to every country?
### Nested data
You could imagine copy and pasting that code multiple times; but you've already learned a better way!
Extract out the common code with a function and repeat using a map function from purrr.
This problem is structured a little differently to what you've seen before.
Instead of repeating an action for each variable, we want to repeat an action for each country, a subset of rows.
To do that, we need a new data structure: the **nested data frame**.
To create a nested data frame we start with a grouped data frame, and "nest" it:
```{r}
by_country <- gapminder %>%
group_by(country, continent) %>%
nest()
by_country
```
(I'm cheating a little by grouping on both `continent` and `country`. Given `country`, `continent` is fixed, so this doesn't add any more groups, but it's an easy way to carry an extra variable along for the ride.)
This creates a data frame that has one row per group (per country), and a rather unusual column: `data`.
`data` is a list of data frames (or tibbles, to be precise).
This seems like a crazy idea: we have a data frame with a column that is a list of other data frames!
I'll explain shortly why I think this is a good idea.
The `data` column is a little tricky to look at because it's a moderately complicated list, and we're still working on good tools to explore these objects.
Unfortunately using `str()` is not recommended as it will often produce very long output.
But if you pluck out a single element from the `data` column you'll see that it contains all the data for that country (in this case, Afghanistan).
```{r}
by_country$data[[1]]
```
Note the difference between a standard grouped data frame and a nested data frame: in a grouped data frame, each row is an observation; in a nested data frame, each row is a group.
Another way to think about a nested dataset is we now have a meta-observation: a row that represents the complete time course for a country, rather than a single point in time.
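As a quick sanity check (a small sketch, purely illustrative), compare the number of rows in the grouped and nested versions: grouping leaves one row per country-year observation, while nesting collapses to one row per country.

```{r}
# Grouping does not change the number of rows; nesting gives one row per group.
nrow(group_by(gapminder, country, continent))
nrow(by_country)
```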
### List-columns
Now that we have our nested data frame, we're in a good position to fit some models.
We have a model-fitting function:
```{r}
country_model <- function(df) {
lm(lifeExp ~ year, data = df)
}
```
And we want to apply it to every data frame.
The data frames are in a list, so we can use `purrr::map()` to apply `country_model` to each element:
```{r}
models <- map(by_country$data, country_model)
```
However, rather than leaving the list of models as a free-floating object, I think it's better to store it as a column in the `by_country` data frame.
Storing related objects in columns is a key part of the value of data frames, and why I think list-columns are such a good idea.
In the course of working with these countries, we are going to have lots of lists where we have one element per country.
So why not store them all together in one data frame?
In other words, instead of creating a new object in the global environment, we're going to create a new variable in the `by_country` data frame.
That's a job for `dplyr::mutate()`:
```{r}
by_country <- by_country %>%
mutate(model = map(data, country_model))
by_country
```
This has a big advantage: because all the related objects are stored together, you don't need to manually keep them in sync when you filter or arrange.
The semantics of the data frame takes care of that for you:
```{r}
by_country %>%
filter(continent == "Europe")
by_country %>%
arrange(continent, country)
```
If your list of data frames and list of models were separate objects, you have to remember that whenever you re-order or subset one vector, you need to re-order or subset all the others in order to keep them in sync.
If you forget, your code will continue to work, but it will give the wrong answer!
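To make that risk concrete, here's a small hypothetical sketch (the object names are made up for illustration): after subsetting the data list, the model list still has one element per country, so positions no longer line up.

```{r}
# Hypothetical illustration: parallel free-floating objects fall out of sync.
separate_data <- by_country$data
separate_models <- by_country$model

# Subset the data to Europe, but forget to subset the models...
europe_data <- separate_data[by_country$continent == "Europe"]

# ...and the two lists no longer have matching lengths or positions.
length(europe_data)
length(separate_models)
```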
### Unnesting
Previously we computed the residuals of a single model with a single dataset.
Now we have 142 data frames and 142 models.
To compute the residuals, we need to call `add_residuals()` with each model-data pair:
```{r}
by_country <- by_country %>%
mutate(
resids = map2(data, model, add_residuals)
)
by_country
```
But how can you plot a list of data frames?
Instead of struggling to answer that question, let's turn the list of data frames back into a regular data frame.
Previously we used `nest()` to turn a regular data frame into a nested data frame, and now we do the opposite with `unnest()`:
```{r}
resids <- unnest(by_country, resids)
resids
```
Note that each regular column is repeated once for each row of the nested tibble.
Now that we have a regular data frame, we can plot the residuals:
```{r}
resids %>%
ggplot(aes(year, resid)) +
geom_line(aes(group = country), alpha = 1 / 3) +
geom_smooth(se = FALSE)
```
Facetting by continent is particularly revealing:
```{r}
resids %>%
ggplot(aes(year, resid, group = country)) +
geom_line(alpha = 1 / 3) +
facet_wrap(~continent)
```
It looks like we've missed some mild patterns.
There's also something interesting going on in Africa: we see some very large residuals which suggests our model isn't fitting so well there.
We'll explore that more in the next section, attacking it from a slightly different angle.
### Model quality
Instead of looking at the residuals from the model, we could look at some general measurements of model quality.
You learned how to compute some specific measures in the previous chapter.
Here we'll show a different approach using the broom package.
The broom package provides a general set of functions to turn models into tidy data.
Here we'll use `broom::glance()` to extract some model quality metrics.
If we apply it to a model, we get a data frame with a single row:
```{r}
broom::glance(nz_mod)
```
We can use `mutate()` and `unnest()` to create a data frame with a row for each country:
```{r}
glance <- by_country %>%
mutate(glance = map(model, broom::glance)) %>%
select(country, continent, glance) %>%
unnest(glance)
glance
```
(Pay attention to the variables that aren't printed: there's a lot of useful stuff there.)
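One way to see every column (just an aside, not part of the original workflow) is `glimpse()`, which prints one variable per line:

```{r}
# One line per variable, so nothing gets truncated.
glance %>%
  glimpse()
```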
With this data frame in hand, we can start to look for models that don't fit well:
```{r}
glance %>%
arrange(r.squared)
```
The worst models all appear to be in Africa.
Let's double check that with a plot.
Here we have a relatively small number of observations and a discrete variable, so `geom_jitter()` is effective:
```{r}
glance %>%
ggplot(aes(continent, r.squared)) +
geom_jitter(width = 0.5)
```
We could pull out the countries with particularly bad $R^2$ and plot the data:
```{r}
bad_fit <- filter(glance, r.squared < 0.25)
gapminder %>%
semi_join(bad_fit, by = "country") %>%
ggplot(aes(year, lifeExp, colour = country)) +
geom_line()
```
We see two main effects here: the tragedies of the HIV/AIDS epidemic and the Rwandan genocide.
### Exercises
1. A linear trend seems to be slightly too simple for the overall trend.
Can you do better with a quadratic polynomial?
How can you interpret the coefficients of the quadratic?
    (Hint: you might want to transform `year` so that it has mean zero.)
2. Explore other methods for visualising the distribution of $R^2$ per continent.
You might want to try the ggbeeswarm package, which provides similar methods for avoiding overlaps as jitter, but uses deterministic methods.
3. To create the last plot (showing the data for the countries with the worst model fits), we needed two steps: we created a data frame with one row per country and then semi-joined it to the original dataset.
It's possible to avoid this join if we use `unnest()` instead of `unnest(.drop = TRUE)`.
How?
## List-columns {#list-columns-1}
Now that you've seen a basic workflow for managing many models, let's dive back into some of the details.
In this section, we'll explore the list-column data structure in a little more detail.
It's only recently that I've really appreciated the idea of the list-column.
List-columns are implicit in the definition of the data frame: a data frame is a named list of equal length vectors.
A list is a vector, so it's always been legitimate to use a list as a column of a data frame.
However, base R doesn't make it easy to create list-columns, and `data.frame()` treats a list as a list of columns:
```{r}
data.frame(x = list(1:3, 3:5))
```
You can prevent `data.frame()` from doing this with `I()`, but the result doesn't print particularly well:
```{r}
data.frame(
x = I(list(1:3, 3:5)),
y = c("1, 2", "3, 4, 5")
)
```
Tibble alleviates this problem by being lazier (`tibble()` doesn't modify its inputs) and by providing a better print method:
```{r}
tibble(
x = list(1:3, 3:5),
y = c("1, 2", "3, 4, 5")
)
```
It's even easier with `tribble()` as it can automatically work out that you need a list:
```{r}
tribble(
~x, ~y,
1:3, "1, 2",
3:5, "3, 4, 5"
)
```
List-columns are often most useful as an intermediate data structure.
They're hard to work with directly, because most R functions work with atomic vectors or data frames, but the advantage of keeping related items together in a data frame is worth a little hassle.
Generally there are three parts of an effective list-column pipeline:
1. You create the list-column using one of `nest()`, `summarise()` + `list()`, or `mutate()` + a map function, as described in [Creating list-columns].
2. You create other intermediate list-columns by transforming existing list columns with `map()`, `map2()` or `pmap()`.
For example, in the case study above, we created a list-column of models by transforming a list-column of data frames.
3. You simplify the list-column back down to a data frame or atomic vector, as described in [Simplifying list-columns].
## Creating list-columns
Typically, you won't create list-columns with `tibble()`.
Instead, you'll create them from regular columns, using one of three methods:
1. With `tidyr::nest()` to convert a grouped data frame into a nested data frame where you have a list-column of data frames.
2. With `mutate()` and vectorised functions that return a list.
3. With `summarise()` and summary functions that return multiple results.
Alternatively, you might create them from a named list, using `tibble::enframe()`.
Generally, when creating list-columns, you should make sure they're homogeneous: each element should contain the same type of thing.
There are no checks to make sure this is true, but if you use purrr and remember what you've learned about type-stable functions, you should find it happens naturally.
### With nesting
`nest()` creates a nested data frame, which is a data frame with a list-column of data frames.
In a nested data frame each row is a meta-observation: the other columns give variables that define the observation (like country and continent above), and the list-column of data frames gives the individual observations that make up the meta-observation.
There are two ways to use `nest()`.
So far you've seen how to use it with a grouped data frame.
When applied to a grouped data frame, `nest()` keeps the grouping columns as is, and bundles everything else into the list-column:
```{r}
gapminder %>%
group_by(country, continent) %>%
nest()
```
You can also use it on an ungrouped data frame, specifying which columns you want to nest:
```{r}
gapminder %>%
nest(data = c(year:gdpPercap))
```
### From vectorised functions
Some useful functions take an atomic vector and return a list.
For example, in [strings] you learned about `stringr::str_split()` which takes a character vector and returns a list of character vectors.
If you use that inside mutate, you'll get a list-column:
```{r}
df <- tribble(
~x1,
"a,b,c",
"d,e,f,g"
)
df %>%
mutate(x2 = stringr::str_split(x1, ","))
```
`unnest()` knows how to handle these lists of vectors:
```{r}
df %>%
mutate(x2 = stringr::str_split(x1, ",")) %>%
unnest(x2)
```
(If you find yourself using this pattern a lot, make sure to check out `tidyr::separate_rows()` which is a wrapper around this common pattern).
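For example, here's a minimal sketch of the same split using `separate_rows()` directly on `df`, splitting `x1` in place rather than creating a new column:

```{r}
# One row per comma-separated value of x1.
df %>%
  separate_rows(x1, sep = ",")
```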
Another example of this pattern is using `map()`, `map2()`, or `pmap()` from purrr.
For example, we could take the final example from [Invoking different functions] and rewrite it to use `mutate()`:
```{r}
sim <- tribble(
~f, ~params,
"runif", list(min = -1, max = 1),
"rnorm", list(sd = 5),
"rpois", list(lambda = 10)
)
sim %>%
mutate(sims = invoke_map(f, params, n = 10))
```
Note that technically `sim` isn't homogeneous because it contains both double and integer vectors.
However, this is unlikely to cause many problems since integers and doubles are both numeric vectors.
### From multivalued summaries
One restriction of `summarise()` is that it only works with summary functions that return a single value.
That means that you can't use it with functions like `quantile()` that return a vector of arbitrary length:
```{r, error = TRUE}
mtcars %>%
group_by(cyl) %>%
summarise(q = quantile(mpg))
```
You can, however, wrap the result in a list!
This obeys the contract of `summarise()`, because each summary is now a list (a vector) of length 1.
```{r}
mtcars %>%
group_by(cyl) %>%
summarise(q = list(quantile(mpg)))
```
To make useful results with `unnest()`, you'll also need to capture the probabilities:
```{r}
probs <- c(0.01, 0.25, 0.5, 0.75, 0.99)
mtcars %>%
group_by(cyl) %>%
summarise(p = list(probs), q = list(quantile(mpg, probs))) %>%
unnest(c(p, q))
```
### From a named list
Data frames with list-columns provide a solution to a common problem: what do you do if you want to iterate over both the elements of a list and their names (or other metadata)?
Instead of trying to jam everything into one object, it's often easier to make a data frame: one column can contain the names, and one column can contain the elements as a list-column.
An easy way to create such a data frame from a list is `tibble::enframe()`.
```{r}
x <- list(
a = 1:5,
b = 3:4,
c = 5:6
)
df <- enframe(x)
df
```
The advantage of this structure is that it generalises in a straightforward way - names are useful if you have a character vector of metadata, but don't help if you have other types of data, or multiple vectors.
Now if you want to iterate over names and values in parallel, you can use `map2()`:
```{r}
df %>%
mutate(
smry = map2_chr(name, value, ~ stringr::str_c(.x, ": ", .y[1]))
)
```
### Exercises
1. List all the functions that you can think of that take an atomic vector and return a list.
2. Brainstorm useful summary functions that, like `quantile()`, return multiple values.
3. What's missing in the following data frame?
How does `quantile()` return that missing piece?
Why isn't that helpful here?
```{r}
mtcars %>%
group_by(cyl) %>%
summarise(q = list(quantile(mpg))) %>%
unnest(q)
```
4. What does this code do?
    Why might it be useful?
```{r, eval = FALSE}
mtcars %>%
group_by(cyl) %>%
summarise_all(list(list))
```
## Simplifying list-columns
To apply the techniques of data manipulation and visualisation you've learned in this book, you'll need to simplify the list-column back to a regular column (an atomic vector), or set of columns.
The technique you'll use to collapse back down to a simpler structure depends on whether you want a single value per element, or multiple values:
1. If you want a single value, use `mutate()` with `map_lgl()`, `map_int()`, `map_dbl()`, and `map_chr()` to create an atomic vector.
2. If you want many values, use `unnest()` to convert list-columns back to regular columns, repeating the rows as many times as necessary.
These are described in more detail below.
### List to vector
If you can reduce your list column to an atomic vector then it will be a regular column.
For example, you can always summarise an object with its type and length, so this code will work regardless of what sort of list-column you have:
```{r}
df <- tribble(
~x,
letters[1:5],
1:3,
runif(5)
)
df %>% mutate(
type = map_chr(x, typeof),
length = map_int(x, length)
)
```
This is the same basic information that you get from the default tbl print method, but now you can use it for filtering.
This is a useful technique if you have a heterogeneous list, and want to filter out the parts that aren't working for you.
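For example, here's a small sketch (assuming we only want to keep the non-character elements of `x`) that uses the `type` summary to filter:

```{r}
# Drop the character element, keeping the integer and double vectors.
df %>%
  mutate(type = map_chr(x, typeof)) %>%
  filter(type != "character")
```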
Don't forget about the `map_*()` shortcuts - you can use `map_chr(x, "apple")` to extract the string stored in `apple` for each element of `x`.
This is useful for pulling apart nested lists into regular columns.
Use the `.null` argument to provide a value to use if the element is missing (instead of returning `NULL`):
```{r}
df <- tribble(
~x,
list(a = 1, b = 2),
list(a = 2, c = 4)
)
df %>% mutate(
a = map_dbl(x, "a"),
b = map_dbl(x, "b", .null = NA_real_)
)
```
### Unnesting
`unnest()` works by repeating the regular columns once for each element of the list-column.
For example, in the following very simple example we repeat the first row four times (because the first element of `y` has length four), and the second row once:
```{r}
tibble(x = 1:2, y = list(1:4, 1)) %>% unnest(y)
```
This means that you can't simultaneously unnest two columns that contain a different number of elements:
```{r, error = TRUE}
# Ok, because y and z have the same number of elements in
# every row
df1 <- tribble(
~x, ~y, ~z,
1, c("a", "b"), 1:2,
2, "c", 3
)
df1
df1 %>% unnest(c(y, z))
# Doesn't work because y and z have a different number of elements
df2 <- tribble(
~x, ~y, ~z,
1, "a", 1:2,
2, c("b", "c"), 3
)
df2
df2 %>% unnest(c(y, z))
```
The same principle applies when unnesting list-columns of data frames.
You can unnest multiple list-cols as long as all the data frames in each row have the same number of rows.
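Here's a small sketch (the tibble and its columns are made up for illustration): within each row, the two nested tibbles have the same number of rows, so both list-columns can be unnested together.

```{r}
df3 <- tibble(
  x = 1:2,
  a = list(tibble(p = 1:2), tibble(p = 3L)),
  b = list(tibble(q = c("a", "b")), tibble(q = "c"))
)
# Row 1 expands to two rows, row 2 to one row.
df3 %>% unnest(c(a, b))
```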
### Exercises
1. Why might the `lengths()` function be useful for creating atomic vector columns from list-columns?
2. List the most common types of vector found in a data frame.
What makes lists different?
## Making tidy data with broom
The broom package provides three general tools for turning models into tidy data frames:
1. `broom::glance(model)` returns a row for each model.
Each column gives a model summary: either a measure of model quality, or complexity, or a combination of the two.
2. `broom::tidy(model)` returns a row for each coefficient in the model.
Each column gives information about the estimate or its variability.
3. `broom::augment(model, data)` returns a row for each row in `data`, adding extra values like residuals, and influence statistics.
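As a short sketch (purely illustrative), here are all three applied to the New Zealand model fitted earlier in this chapter:

```{r}
# One row summarising the whole model:
broom::glance(nz_mod)
# One row per coefficient:
broom::tidy(nz_mod)
# One row per observation in nz, with fitted values and residuals appended:
broom::augment(nz_mod, data = nz)
```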
extra/model/model.Rmd Normal file
@@ -0,0 +1,69 @@
# (PART) Model {.unnumbered}
# Introduction {#model-intro}
Now that you are equipped with powerful programming tools we can finally return to modelling.
You'll use your new tools of data wrangling and programming to fit many models and understand how they work.
The focus of this book is on exploration, not confirmation or formal inference.
But you'll learn a few basic tools that help you understand the variation within your models.
```{r echo = FALSE, out.width = "75%"}
knitr::include_graphics("diagrams/data-science-model.png")
```
The goal of a model is to provide a simple low-dimensional summary of a dataset.
Ideally, the model will capture true "signals" (i.e. patterns generated by the phenomenon of interest), and ignore "noise" (i.e. random variation that you're not interested in).
Here we only cover "predictive" models, which, as the name suggests, generate predictions.
There is another type of model that we're not going to discuss: "data discovery" models.
These models don't make predictions, but instead help you discover interesting relationships within your data.
(These two categories of models are sometimes called supervised and unsupervised, but I don't think that terminology is particularly illuminating.)
This book is not going to give you a deep understanding of the mathematical theory that underlies models.
It will, however, build your intuition about how statistical models work, and give you a family of useful tools that allow you to use models to better understand your data:
- In [model basics], you'll learn how models work mechanistically, focussing on the important family of linear models.
You'll learn general tools for gaining insight into what a predictive model tells you about your data, focussing on simple simulated datasets.
- In [model building], you'll learn how to use models to pull out known patterns in real data.
Once you have recognised an important pattern it's useful to make it explicit in a model, because then you can more easily see the subtler signals that remain.
- In [many models], you'll learn how to use many simple models to help understand complex datasets.
This is a powerful technique, but to access it you'll need to combine modelling and programming tools.
These topics are notable because of what they don't include: any tools for quantitatively assessing models.
That is deliberate: precisely quantifying a model requires a couple of big ideas that we just don't have the space to cover here.
For now, you'll rely on qualitative assessment and your natural scepticism.
In [Learning more about models], we'll point you to other resources where you can learn more.
## Hypothesis generation vs. hypothesis confirmation
In this book, we are going to use models as a tool for exploration, completing the trifecta of the tools for EDA that were introduced in Part 1.
This is not how models are usually taught, but as you will see, models are an important tool for exploration.
Traditionally, the focus of modelling is on inference, or confirming that a hypothesis is true.
Doing this correctly is not complicated, but it is hard.
There is a pair of ideas that you must understand in order to do inference correctly:
1. Each observation can either be used for exploration or confirmation, not both.
2. You can use an observation as many times as you like for exploration, but you can only use it once for confirmation.
As soon as you use an observation twice, you've switched from confirmation to exploration.
This is necessary because to confirm a hypothesis you must use data independent of the data that you used to generate the hypothesis.
Otherwise you will be overly optimistic.
There is absolutely nothing wrong with exploration, but you should never sell an exploratory analysis as a confirmatory analysis because it is fundamentally misleading.
If you are serious about doing a confirmatory analysis, one approach is to split your data into three pieces before you begin the analysis:
1. 60% of your data goes into a **training** (or exploration) set.
You're allowed to do anything you like with this data: visualise it and fit tons of models to it.
2. 20% goes into a **query** set.
You can use this data to compare models or visualisations by hand, but you're not allowed to use it as part of an automated process.
3. 20% is held back for a **test** set.
You can only use this data ONCE, to test your final model.
This partitioning allows you to explore the training data, occasionally generating candidate hypotheses that you check with the query set.
When you are confident you have the right model, you can check it once with the test data.
(Note that even when doing confirmatory modelling, you will still need to do EDA. If you don't do any EDA you will remain blind to the quality problems with your data.)
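Here's a minimal sketch of that 60%/20%/20% split (the dataset and the use of `sample()` are illustrative assumptions, not prescribed above):

```{r}
set.seed(123)

# Randomly assign each row of mtcars to one of the three sets.
split <- sample(
  c("training", "query", "test"),
  size = nrow(mtcars),
  replace = TRUE,
  prob = c(0.6, 0.2, 0.2)
)

training <- mtcars[split == "training", ]
query    <- mtcars[split == "query", ]
test     <- mtcars[split == "test", ]
```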