# Working with many models
In this chapter you're going to learn three powerful techniques that allow you to work with large numbers of models in a straightforward way. You will combine purrr, tidyr, dplyr, and broom.
We'll dive into a quick case study using data about life expectancy and then dig into the details in the following sections. In the next chapter, you'll use these techniques after generating multiple datasets in other ways.
### Prerequisites
```{r setup, message = FALSE}
# Standard data manipulation and visualisation
library(dplyr)
library(ggplot2)
# Tools for working with models
library(broom)
library(modelr)
# Tools for working with lots of models
library(purrr)
library(tidyr)
```
## gapminder
We're going to explore these ideas using the "Gapminder" data. This data was popularised by Hans Rosling. If you've never heard of him, you should stop now and watch one of his data presentations, like <https://www.youtube.com/watch?v=jbkSRLYSojo>.
We're going to use a subset of the full data as included in the gapminder package, by Jenny Bryan:
```{r}
library(gapminder)
gapminder
```
We're going to focus on just three variables to answer one question: how does life expectancy (`lifeExp`) change over time (`year`) for each country (`country`)? We can attempt to display this with a simple line chart:
```{r}
gapminder %>%
  ggplot(aes(year, lifeExp, group = country)) +
  geom_line()
```
This is a small dataset, only ~1700 observations and three variables, but it's still hard to see what's going on in this plot. Overall, it looks like life expectancy has been steadily improving over time, but if you look closely you might spot some countries that don't follow this pattern.
We're going to make the unusual patterns easier to see by following the same approach as the previous chapter. We'll capture the linear trend in a model and predict the results. You already know how to do this if we had a single country:
```{r}
nz <- filter(gapminder, country == "New Zealand")
nz_mod <- lm(lifeExp ~ year, data = nz)

# Predictions from the linear model (stored in a column called `pred`)
nz %>%
  add_predictions(nz_mod) %>%
  ggplot(aes(year, pred)) +
  geom_line()

# Residuals: what the linear trend leaves unexplained (column `resid`)
nz %>%
  add_residuals(nz_mod) %>%
  ggplot(aes(year, resid)) +
  geom_hline(yintercept = 0, colour = "white", size = 3) +
  geom_line()
```
But what can we do if we want to fit that model to every country?
### Nested data
You could imagine copying and pasting that code multiple times. But you've already learned a better way of handling that pattern: extract the common code into a function and repeat it using a map function from purrr.
This problem is structured a little differently because we want to repeat something for each country, a subset of rows, rather than each variable. So first we need to make a list of data frames. There are lots of ways to do this (for example, you could use `split()` from base R), but we're going to use a function from tidyr, called `nest()`:
```{r}
by_country <- gapminder %>%
  group_by(country, continent) %>%
  nest()
by_country
```
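
For comparison, here is a rough sketch of the base R route mentioned above. It assumes only the gapminder package, and the name `country_list` is made up for this illustration. `split()` returns a plain named list, so the country lives in the list names rather than in a column next to the data.

```{r}
library(gapminder)

# split() cuts the data frame into one piece per level of the factor
country_list <- split(gapminder, gapminder$country)

length(country_list)           # one element per country
names(country_list)[1:3]       # the country names become the list names
country_list[["New Zealand"]]  # each element is an ordinary data frame
```

This works, but anything else you compute per country has to be kept in a separate object, which is the problem the nested data frame avoids.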
This creates a data frame that has one row per country, and a rather unusual column: `data`. `data` is a list of data frames. This is a pretty crazy idea: we have a data frame with a column that is a list of other data frames! I'll explain shortly why I think this is a good idea.
If you look at one of the elements of the `data` column you'll see that it contains all the data for that country (Afghanistan in this case):
```{r}
by_country$data[[1]]
```
We've changed from a standard "grouped" data frame, where each row is an observation, and the groups are stored as an index, to a __nested__ data frame where each row is one group, and the full data is stored in a list-column.
### List-columns
Now we need to take each element of that list of data frames and fit a model to it. We can do that with `purrr::map()`:
```{r}
country_model <- function(df) {
  lm(lifeExp ~ year, data = df)
}
models <- map(by_country$data, country_model)
```
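
Because the result of `map()` is an ordinary list of `lm` objects, you can pull a single number out of each one with a typed variant such as `map_dbl()`. The sketch below rebuilds the objects from this section so that it runs on its own; the name `slopes` is introduced here purely for illustration.

```{r}
library(dplyr)
library(tidyr)
library(purrr)
library(gapminder)

by_country <- gapminder %>%
  group_by(country, continent) %>%
  nest()

country_model <- function(df) {
  lm(lifeExp ~ year, data = df)
}
models <- map(by_country$data, country_model)

# Extract the estimated yearly change in life expectancy from every model
slopes <- map_dbl(models, function(mod) coef(mod)[["year"]])
summary(slopes)
```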
However, rather than leaving that as a separate object that's floating around in the global environment, I think it's better to store it as a variable in the `by_country` data frame.
This is the basic reason that I think list-columns are such a good idea. We are going to have lots of objects where we have one per country. So why not store them all together in one data frame?
```{r}
by_country <- by_country %>%
  mutate(model = map(data, country_model))
by_country
```
This has a big advantage: because all the related objects are stored together, you don't need to manually keep them all in sync when you filter or arrange. dplyr takes care of that for you!
```{r}
by_country %>% filter(continent == "Europe")
by_country %>% arrange(continent, country)
```
If your list of data frames and list of models were separate objects, you would have to remember that whenever you re-order or subset one, you need to re-order or subset all the others. It's easy to forget, and you end up with vectors that are no longer synchronised. Your code will continue to work, but it will give the wrong answer!
### Unnesting
Previously we computed the residuals of a single model with a single dataset. Now we have 142 data frames and 142 models. To compute the residuals, we need to call `add_residuals()` once for each data frame and model pair, which means iterating over both lists in parallel with `map2()`:
```{r}
by_country %>% mutate(
  resids = map2(data, model, add_residuals)
)
```
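
If iterating over two lists in parallel is unfamiliar, this toy sketch shows the mechanics in isolation. The lists `heights` and `widths` are invented for the illustration and have nothing to do with the gapminder data.

```{r}
library(purrr)

heights <- list(1, 2, 3)
widths <- list(10, 20, 30)

# map2() calls the function with the first elements of both lists,
# then with the second elements, and so on
areas <- map2(heights, widths, function(h, w) h * w)
areas

# map2_dbl() does the same but returns a numeric vector instead of a list
map2_dbl(heights, widths, function(h, w) h * w)
```

In the chunk above, `data` and `model` play the roles of the two lists, and `add_residuals()` is the function applied to each pair.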
But how can you plot a list of data frames? Well, what if we could turn it back into a regular data frame? Previously we used `nest()` to turn a regular data frame into a nested data frame; now we need to do the opposite with `unnest()`:
```{r}
resids <- by_country %>%
  mutate(resids = map2(data, model, add_residuals)) %>%
  unnest(resids)
resids
```
And we can plot the results:
```{r}
resids %>%
  ggplot(aes(year, resid, group = country)) +
  geom_line(alpha = 1 / 3) +
  facet_wrap(~continent)
```
There's something interesting going on in Africa: we see large residuals, which suggests our model isn't fitting so well there. We'll explore that more in the next section, attacking it from a slightly different angle.
### Model quality
Instead of looking at the residuals from the model, we could look at some general measurements of model quality. A convenient way to compute these is the broom package, by David Robinson.
`broom::glance()` gives us a data frame with a single row. Each column gives a model summary: either a measure of model quality, or complexity, or a combination of the two:
```{r}
glance(nz_mod)
```
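
To see the shape of what `glance()` returns without depending on the case study, here is a small sketch that uses a stand-in model on the built-in `mtcars` data. The names `car_mod` and `car_summary` are made up for this example.

```{r}
library(broom)

# A stand-in linear model on a built-in dataset (not part of the case study)
car_mod <- lm(mpg ~ wt, data = mtcars)
car_summary <- glance(car_mod)

nrow(car_summary)      # a single row for the single model
names(car_summary)     # each summary statistic is an ordinary column
car_summary$r.squared  # so individual measures can be pulled out by name
```

Because the result is a one-row data frame, it fits naturally into a list-column with one element per model.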
We can use the same technique as with residuals to compute this for each country:
```{r}
by_country %>%
  mutate(glance = map(model, glance)) %>%
  unnest(glance)
```
But this includes all the list-columns. This is the default when `unnest()` works on single-row data frames (because it's possible to keep them all in sync). To suppress these columns we use `.drop = TRUE`:
```{r}
glance <- by_country %>%
  mutate(glance = map(model, glance)) %>%
  unnest(glance, .drop = TRUE)
glance
```
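
The `.drop` argument may not behave the same way in every version of tidyr. One alternative that does not rely on it, sketched here under that assumption, is to remove the list-columns explicitly with `select()` before unnesting. The names `quality` and `model_quality` are introduced only for this example, and the chunk rebuilds the nested data frame so that it runs on its own.

```{r}
library(dplyr)
library(tidyr)
library(purrr)
library(broom)
library(gapminder)

by_country <- gapminder %>%
  group_by(country, continent) %>%
  nest() %>%
  ungroup() %>%
  mutate(model = map(data, function(df) lm(lifeExp ~ year, data = df)))

model_quality <- by_country %>%
  mutate(quality = map(model, glance)) %>%
  select(-data, -model) %>%  # drop the other list-columns by hand
  unnest(quality)

model_quality
```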
We could look at which countries have the worst fits:
```{r}
glance %>% arrange(r.squared)
```
The worst models all appear to be in Africa. We could show this graphically. Here we have a relatively small number of observations and a discrete variable, so `geom_jitter()` is effective:
```{r}
glance %>%
  ggplot(aes(continent, r.squared)) +
  geom_jitter(width = 0.5)
```
We could pull out the countries with particularly bad $R^2$ and plot the data:
```{r}
bad_fit <- filter(glance, r.squared < 0.25)
bad_fit
gapminder %>%
  semi_join(bad_fit, by = "country") %>%
  ggplot(aes(year, lifeExp, colour = country)) +
  geom_line()
```
We see two main effects here: the tragedies of the HIV/AIDS epidemic, and the Rwandan genocide.
### Exercises
1.  Explore other methods for visualising the distribution of $R^2$ per
    continent. You might want to try `ggbeeswarm`, which provides similar
    methods for avoiding overlaps as jitter, but with less randomness.
## List-columns
The idea of a list column is powerful. The contract of a data frame is that it's a named list of vectors, where each vector has the same length. A list is a vector, and a list can contain anything, so you can put anything in a list-column of a data frame.
Generally, you should make sure that your list columns are homogeneous: each element should contain the same type of thing. There are no checks to make sure this is true, but if you use purrr and remember what you've learned about type-stable functions you should find it happens naturally.
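
As a small sketch of both points, the chunk below builds a data frame with a list-column and then checks by hand that every element holds the same type of thing. The column names `id` and `values` are invented for the example.

```{r}
library(tibble)
library(purrr)

df <- tibble(
  id = 1:3,
  values = list(1:2, 3:5, 6:9)
)

# The list-column is a list with one element per row, like any other column
is.list(df$values)
length(df$values) == nrow(df)

# Nothing enforces homogeneity, so check it yourself when it matters
element_types <- map_chr(df$values, typeof)
all(element_types == "integer")
```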
### Compared to base R
List columns are possible in base R, but conventions in `data.frame()` make creating and printing them a bit of a headache:
```{r, error = TRUE}
# Doesn't work
data.frame(x = list(1:2, 3:5))
# Works, but doesn't print particularly well
data.frame(x = I(list(1:2, 3:5)), y = c("1, 2", "3, 4, 5"))
```
The functions in tibble don't have this problem:
```{r}
data_frame(x = list(1:2, 3:5), y = c("1, 2", "3, 4, 5"))
```
### With `mutate()` and `summarise()`
You might find yourself creating list-columns with `mutate()` and `summarise()`. For example:
```{r}
data_frame(x = c("a,b,c", "d,e,f,g")) %>%
  mutate(x = stringr::str_split(x, ","))
```
`unnest()` knows how to handle these lists of vectors as well as lists of data frames.
```{r}
data_frame(x = c("a,b,c", "d,e,f,g")) %>%
  mutate(x = stringr::str_split(x, ",")) %>%
  unnest()
```
(If you find yourself using this pattern a lot, make sure to check out `tidyr::separate_rows()`.)
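
As a sketch of that shortcut, `separate_rows()` does the split and the unnest in one step. The example reuses the same two strings; the name `long` is made up here.

```{r}
library(tibble)
library(tidyr)

df <- tibble(x = c("a,b,c", "d,e,f,g"))

# Split each string on commas and give every piece its own row
long <- separate_rows(df, x, sep = ",")
long
```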
This can be useful for summary functions like `quantile()` that return a vector of values:
```{r}
mtcars %>%
  group_by(cyl) %>%
  summarise(q = list(quantile(mpg))) %>%
  print() %>%
  unnest()
```
Although you probably also want to keep track of which output corresponds to which input:
```{r}
probs <- c(0.01, 0.25, 0.5, 0.75, 0.99)
mtcars %>%
  group_by(cyl) %>%
  summarise(p = list(probs), q = list(quantile(mpg, probs))) %>%
  unnest()
```
And even just `list()` can be a useful summary function (when?). It is a summary function because it takes a vector of length n, and returns a vector of length 1:
```{r}
mtcars %>% group_by(cyl) %>% summarise(list(mpg))
```
This is an effective replacement for `split()` in base R (but instead of working with vectors it works with data frames).
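
To make the comparison concrete, here is a rough side-by-side sketch. The names `via_summarise` and `via_split` are invented for the illustration. Both approaches produce the same pieces, and the difference is where the group label ends up.

```{r}
library(dplyr)

# List-column version: one row per group, with the group label kept as a column
via_summarise <- mtcars %>%
  group_by(cyl) %>%
  summarise(mpg = list(mpg))

# Base R version: a named list, with the group label kept only in the names
via_split <- split(mtcars$mpg, mtcars$cyl)

via_summarise$cyl
names(via_split)

# The pieces themselves should match once the names are set aside
all.equal(unname(via_split), via_summarise$mpg)
```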
### Exercises
## Nesting and unnesting
More details about `unnest()` options.