Use new data_grid instead of expand

hadley 2016-07-27 16:42:14 -05:00
parent 40052d1d52
commit 4c789ab8e9
2 changed files with 15 additions and 16 deletions


@ -45,7 +45,7 @@ The goal of a model is not to uncover truth, but to discover a simple approximat
We need a couple of packages specifically designed for modelling, and all the packages you've used before for EDA.
-```{r setup, message = FALSE}
+```{r setup, message = FALSE, cache = FALSE}
# Modelling functions
library(modelr)
options(na.action = na.warn)
@ -53,7 +53,6 @@ options(na.action = na.warn)
# EDA tools
library(ggplot2)
library(dplyr)
-library(tidyr)
```
## A simple model
@ -243,10 +242,10 @@ It's also useful to see what the model doesn't capture, the so called residuals
### Predictions
-To visualise the predictions from a model, we start by generating an evenly spaced grid of values that covers the region where our data lies. The easiest way to do that is to use `tidyr::expand()`. Its first argument is a data frame, and for each subsequent argument it finds the unique variables and then generates all combinations:
+To visualise the predictions from a model, we start by generating an evenly spaced grid of values that covers the region where our data lies. The easiest way to do that is to use `modelr::data_grid()`. Its first argument is a data frame, and for each subsequent argument it finds the unique variables and then generates all combinations:
```{r}
-grid <- sim1 %>% expand(x)
+grid <- sim1 %>% data_grid(x)
grid
```
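A minimal sketch of what the replacement call returns, assuming the `modelr` package (which ships the `sim1` dataset) is installed. With a single variable, `data_grid()` should give one row per unique value of `x`:

```r
library(modelr)

# One row per distinct value of x in sim1
grid <- data_grid(sim1, x)

# The grid contains exactly the unique x values from the data
n_unique <- length(unique(sim1$x))
nrow(grid) == n_unique
```

Because the grid is built from the data itself, it follows the observed range of `x` without any hard-coded limits.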
@ -377,7 +376,7 @@ We can fit a model to it, and generate predictions:
mod2 <- lm(y ~ x, data = sim2)
grid <- sim2 %>%
-expand(x) %>%
+data_grid(x) %>%
add_predictions(mod2)
grid
```
@ -416,7 +415,7 @@ When you add variables with `+`, the model will estimate each effect independent
To visualise these models we need two new tricks:
-1. We have two predictors, so we need to give `expand()` two variables.
+1. We have two predictors, so we need to give `data_grid()` both variables.
It finds all the unique values of `x1` and `x2` and then generates all
combinations.
@ -429,7 +428,7 @@ Together this gives us:
```{r}
grid <- sim3 %>%
-expand(x1, x2) %>%
+data_grid(x1, x2) %>%
gather_predictions(mod1, mod2)
grid
```
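To illustrate the "all combinations" behaviour with two variables, here is a small sketch on a hypothetical toy data frame (`toy` and its values are made up for illustration and are not part of the book's data):

```r
library(modelr)

# Hypothetical toy data: only three of the four (x1, x2) pairs are observed
toy <- data.frame(
  x1 = c(1, 1, 2),
  x2 = c("a", "b", "a"),
  y  = c(0.5, 1.2, 2.1)
)

# data_grid() crosses the unique values of each variable,
# so the unobserved pair (2, "b") appears in the grid as well
toy_grid <- data_grid(toy, x1, x2)
toy_grid
```

The grid has one row per combination of unique values, not one row per observed pair, which is what lets the model predictions cover the whole region.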
@ -467,7 +466,7 @@ mod1 <- lm(y ~ x1 + x2, data = sim4)
mod2 <- lm(y ~ x1 * x2, data = sim4)
grid <- sim4 %>%
-expand(
+data_grid(
x1 = seq_range(x1, 5),
x2 = seq_range(x2, 5)
) %>%
@ -475,7 +474,7 @@ grid <- sim4 %>%
grid
```
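The `seq_range()` calls in the grid above can be tried in isolation. This is a rough sketch on a made-up vector `x`, assuming `modelr` is installed; the `pretty` and `expand` arguments shown are options of `modelr::seq_range()`:

```r
library(modelr)

# Hypothetical input: unevenly spaced values between 0 and 1
x <- c(0, 0.3, 0.9, 1)

# Five evenly spaced values from min(x) to max(x)
even <- seq_range(x, 5)

# pretty = TRUE asks for round, human-friendly breakpoints instead
nice <- seq_range(x, 5, pretty = TRUE)

# expand = 0.1 widens the range by 10% before generating the sequence
wide <- seq_range(x, 5, expand = 0.1)

even
nice
wide
```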
-Note my use of `seq_range()` inside `expand()`. Instead of using every unique value of `x`, I'm going to use a regularly spaced grid of five values between the minimum and maximum numbers. It's probably not super important here, but it's a useful technique in general. There are two other useful arguments to `seq_range()`:
+Note my use of `seq_range()` inside `data_grid()`. Instead of using every unique value of `x`, I'm going to use a regularly spaced grid of five values between the minimum and maximum numbers. It's probably not super important here, but it's a useful technique in general. There are two other useful arguments to `seq_range()`:
* `pretty = TRUE` will generate a "pretty" sequence, i.e. something that looks
nice to the human eye. This is useful if you want to produce tables of
@ -554,7 +553,7 @@ model_matrix(df, y ~ poly(x, 2))
However there's one major problem with using `poly()`: outside the range of the data, polynomials rapidly shoot off to positive or negative infinity. One safer alternative is to use the natural spline, `splines::ns()`.
-```{r}
+```{r, cache = FALSE}
library(splines)
model_matrix(df, y ~ ns(x, 2))
```
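The difference between `poly()` and `ns()` outside the data range can be checked with a small experiment. This is only a sketch on simulated data (`toy`, the sine curve, and the degree of 5 are arbitrary choices for illustration):

```r
library(splines)

set.seed(1)
# Hypothetical data: a noisy sine curve observed on [0, 3]
toy <- data.frame(x = seq(0, 3, length.out = 50))
toy$y <- sin(toy$x) + rnorm(50, sd = 0.1)

mod_poly <- lm(y ~ poly(x, 5), data = toy)
mod_ns   <- lm(y ~ ns(x, 5), data = toy)

# Three equally spaced points beyond the observed range
outside <- data.frame(x = c(4, 5, 6))
pred_poly <- predict(mod_poly, outside)
pred_ns   <- predict(mod_ns, outside)

# A natural spline is linear beyond its boundary knots, so equally spaced
# points have a second difference of (numerically) zero; the polynomial
# keeps curving
d_ns   <- diff(pred_ns, differences = 2)
d_poly <- diff(pred_poly, differences = 2)
d_ns
d_poly
```

Linear extrapolation is not guaranteed to be right either, but it cannot run away as quickly as a high-degree polynomial.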
@ -581,7 +580,7 @@ mod4 <- lm(y ~ ns(x, 4), data = sim5)
mod5 <- lm(y ~ ns(x, 5), data = sim5)
grid <- sim5 %>%
-expand(x = seq_range(x, n = 50, expand = 0.1)) %>%
+data_grid(x = seq_range(x, n = 50, expand = 0.1)) %>%
gather_predictions(mod1, mod2, mod3, mod4, mod5, .pred = "y")
ggplot(sim5, aes(x, y)) +
@ -610,7 +609,7 @@ sim6 <- tibble(
mod <- lm(y ~ x1 * x2, data = sim6)
grid <- sim6 %>%
-expand(
+data_grid(
x1 = seq_range(x1, 10),
x2 = c(0, 0.5, 1, 1.5)
) %>%


@ -86,7 +86,7 @@ Then we look at what the model tells us about the data. Note that I back transfo
```{r}
grid <- diamonds2 %>%
-expand(carat = seq_range(carat, 20)) %>%
+data_grid(carat = seq_range(carat, 20)) %>%
mutate(lcarat = log2(carat)) %>%
add_predictions(mod_diamond, "lprice") %>%
mutate(price = 2 ^ lprice)
@ -213,7 +213,7 @@ One way to remove this strong pattern is to use a model. First, we fit the model
mod <- lm(n ~ wday, data = daily)
grid <- daily %>%
-expand(wday) %>%
+data_grid(wday) %>%
add_predictions(mod, "n")
ggplot(daily, aes(wday, n)) +
@ -340,7 +340,7 @@ We can see the problem by overlaying the predictions from the model on to the ra
```{r}
grid <- daily %>%
-expand(wday, term) %>%
+data_grid(wday, term) %>%
add_predictions(mod2, "n")
ggplot(daily, aes(wday, n)) +
@ -372,7 +372,7 @@ library(splines)
mod <- MASS::rlm(n ~ wday * ns(date, 5), data = daily)
daily %>%
-tidyr::expand(wday, date = seq_range(date, n = 13)) %>%
+data_grid(wday, date = seq_range(date, n = 13)) %>%
add_predictions(mod) %>%
ggplot(aes(date, pred, colour = wday)) +
geom_line() +