Delete extra files
I'm confident that these won't make it into the final book
This commit is contained in:
parent
0c05d01794
commit
0b7f9f248b
|
@ -1,152 +0,0 @@
|
|||
|
||||
## Visualizing three or more variables
|
||||
|
||||
In general, outliers, clusters, and patterns become easier to spot as you look at the interaction of more and more variables. However, as you include more variables in your plot, data becomes harder to visualize.
|
||||
|
||||
You can extend scatterplots into three dimensions with the plotly, rgl, rglwidget, and threejs packages (among others). Each creates a "three dimensional," graph that you can rotate with your mouse. Below is an example from plotly, displayed as a static image.
|
||||
|
||||
```{r eval = FALSE}
|
||||
library(plotly)
|
||||
plot_ly(data = iris, x = Sepal.Length, y = Sepal.Width, z = Petal.Width,
|
||||
color = Species, type = "scatter3d", mode = "markers")
|
||||
```
|
||||
|
||||
```{r, echo = FALSE}
|
||||
knitr::include_graphics("images/EDA-plotly.png")
|
||||
```
|
||||
|
||||
You can extend this approach into n-dimensional hyperspace with the ggobi package, but you will soon notice a weakness of multidimensional graphs. You can only visualize multidimensional space by projecting it onto your two dimensional retinas. In the case of 3D graphics, you can combine 2D projections with rotation to create an intuitive illusion of 3D space, but the illusion ceases to be intuitive as soon as you add a fourth dimension.
|
||||
|
||||
This doesn't mean that you should ignore complex interactions in your data. You can explore multivariate relationships in several ways. You can
|
||||
|
||||
* visualize each combination of variables in a multivariate relationship, two at a time
|
||||
|
||||
* use aesthetics and facetting to add additional variables to a 2D plot
|
||||
|
||||
* use a clustering algorithm to spot clusters in multivariate space
|
||||
|
||||
* use a modeling algorithm to spot patterns and outliers in multivariate space
|
||||
|
||||
## Clusters
|
||||
|
||||
Cluster algorithms are automated tools that seek out clusters in n-dimensional space for you. Base R provides two easy to use clustering algorithms: hierarchical clustering and k means clustering.
|
||||
|
||||
### Hierarchical clustering
|
||||
|
||||
Hierarchical clustering uses a simple algorithm to locate groups of points that are near each other in n-dimensional space:
|
||||
|
||||
1. Identify the two points that are closest to each other
|
||||
2. Combine these points into a cluster
|
||||
3. Treat the new cluster as a point
|
||||
4. Repeat until all of the points are grouped into a single cluster
|
||||
|
||||
You can visualize the results of the algorithm as a dendrogram, and you can use the dendrogram to divide your data into any number of clusters. The figure below demonstrates how the algorithm would proceed in a two dimensional dataset.
|
||||
|
||||
```{r, echo = FALSE}
|
||||
knitr::include_graphics("images/EDA-hclust.png")
|
||||
```
|
||||
|
||||
To use hierarchical clustering in R, begin by selecting the numeric columns from your data; you can only apply hierarchical clustering to numeric data. Then apply the `dist()` function to the data and pass the results to `hclust()`. `dist()` computes the distances between your points in the n dimensional space defined by your numeric vectors. `hclust()` performs the clustering algorithm.
|
||||
|
||||
```{r}
|
||||
small_iris <- sample_n(iris, 50)
|
||||
|
||||
iris_hclust <- small_iris |>
|
||||
select(Sepal.Length, Sepal.Width, Petal.Length, Petal.Width) |>
|
||||
dist() |>
|
||||
hclust(method = "complete")
|
||||
```
|
||||
|
||||
Use `plot()` to visualize the results as a dendrogram. Each observation in the dataset will appear at the bottom of the dendrogram labeled by its rowname. You can use the labels argument to set the labels to something more informative.
|
||||
|
||||
```{r fig.height = 4}
|
||||
plot(iris_hclust, labels = small_iris$Species)
|
||||
```
|
||||
|
||||
To see how near two data points are to each other, trace the paths of the data points up through the tree until they intersect. The y value of the intersection displays how far apart the points are in n-dimensional space. Points that are close to each other will intersect at a small y value, points that are far from each other will intersect at a large y value. Groups of points that are near each other will look like "leaves" that all grow on the same "branch." The ordering of the x axis in the dendrogram is somewhat arbitrary (think of the tree as a mobile, each horizontal branch can spin around meaninglessly).
|
||||
|
||||
You can split your data into any number of clusters by drawing a horizontal line across the tree. Each vertical branch that the line crosses will represent a cluster that contains all of the points downstream from the branch. Move the line up the y axis to intersect fewer branches (and create fewer clusters), move the line down the y axis to intersect more branches and (create more clusters).
|
||||
|
||||
`cutree()` provides a useful way to split data points into clusters. Give cutree the output of `hclust()` as well as the number of clusters that you want to split the data into. `cutree()` will return a vector of cluster labels for your dataset. To visualize the results, map the output of `cutree()` to an aesthetic.
|
||||
|
||||
```{r}
|
||||
(clusters <- cutree(iris_hclust, 3))
|
||||
|
||||
ggplot(small_iris, aes(x = Sepal.Width, y = Sepal.Length)) +
|
||||
geom_point(aes(color = factor(clusters)))
|
||||
```
|
||||
|
||||
You can modify the hierarchical clustering algorithm by setting the method argument of hclust to one of "complete", "single", "average", or "centroid". The method determines how to measure the distance between two clusters or a lone point and a cluster, a measurement that affects the outcome of the algorithm.
|
||||
|
||||
```{r, echo = FALSE}
|
||||
knitr::include_graphics("images/EDA-linkage.png")
|
||||
```
|
||||
|
||||
* *complete* - Measures the greatest distance between any two points in the separate clusters. Tends to create distinct clusters and subclusters.
|
||||
|
||||
* *single* - Measures the smallest distance between any two points in the separate clusters. Tends to add points one at a time to existing clusters, creating ambiguously defined clusters.
|
||||
|
||||
* *average* - Measures the average distance between all combinations of points in the separate clusters. Tends to add points one at a time to existing clusters.
|
||||
|
||||
* *centroid* - Measures the distance between the average location of the points in each cluster.
|
||||
|
||||
|
||||
```{r fig.height = 4}
|
||||
small_iris |>
|
||||
select(Sepal.Length, Sepal.Width, Petal.Length, Petal.Width) |>
|
||||
dist() |>
|
||||
hclust(method = "single") |>
|
||||
plot(labels = small_iris$Species)
|
||||
```
|
||||
|
||||
|
||||
### K means clustering
|
||||
|
||||
K means clustering provides a simulation based alternative to hierarchical clustering. It identifies the "best" way to group your data into a predefined number of clusters. The figure below visualizes (in two dimensional space) the k means algorithm:
|
||||
|
||||
1. Randomly assign each data point to one of $k$ groups
|
||||
2. Compute the centroid of each group
|
||||
3. Reassign each point to the group whose centroid it is nearest to
|
||||
4. Repeat steps 2 and 3 until group memberships cease to change
|
||||
|
||||
```{r, echo = FALSE}
|
||||
knitr::include_graphics("images/EDA-kmeans.png")
|
||||
```
|
||||
|
||||
Use `kmeans()` to perform k means clustering with R. As with hierarchical clustering, you can only apply k means clustering to numerical data. Pass your numerical data to the `kmeans()` function, then set `center` to the number of clusters to search for ($k$) and `nstart` to the number of simulations to run. Since the results of k means clustering depend on the initial assignment of points to groups, which is random, R will run `nstart` simulations and then return the best results (as measured by the minimum sum of squared distances between each point and the centroid of the group it is assigned to). Finally, set the maximum number of iterations to let each simulation run in case the simulation cannot quickly find a stable grouping.
|
||||
|
||||
```{r}
|
||||
iris_kmeans <- small_iris |>
|
||||
select(Sepal.Length, Sepal.Width, Petal.Length, Petal.Width) |>
|
||||
kmeans(centers = 3, nstart = 20, iter.max = 50)
|
||||
|
||||
iris_kmeans$cluster
|
||||
```
|
||||
|
||||
Unlike `hclust()`, the k means algorithm does not provide an intuitive visual interface. Instead, `kmeans()` returns a kmeans class object. Subset the object with `$cluster` to access a list of cluster assignments for your dataset, e.g. `iris_kmeans$cluster`. You can visualize the results by mapping them to an aesthetic, or you can apply the results by passing them to dplyr's `group_by()` function.
|
||||
|
||||
```{r}
|
||||
ggplot(small_iris, aes(x = Sepal.Width, y = Sepal.Length)) +
|
||||
geom_point(aes(color = factor(iris_kmeans$cluster)))
|
||||
|
||||
small_iris |>
|
||||
group_by(iris_kmeans$cluster) |>
|
||||
summarise(n_obs = n(), avg_width = mean(Sepal.Width), avg_length = mean(Sepal.Length))
|
||||
```
|
||||
|
||||
|
||||
### Asking questions about clustering
|
||||
|
||||
Ask the same questions about clusters that you find with `hclust()` and `kmeans()` that you would ask about clusters that you find with a graph. Ask yourself:
|
||||
|
||||
* Do the clusters seem to identify real differences between your points? How can you tell?
|
||||
|
||||
* Are the points within each cluster similar in some way?
|
||||
|
||||
* Are the points in separate clusters different in some way?
|
||||
|
||||
* Might there be a mismatch between the number of clusters that you found and the number that exist in real life? Are only a couple of the clusters meaningful? Are there more clusters in the data than you found?
|
||||
|
||||
* How stable are the clusters if you rerun the algorithm?
|
||||
|
||||
Keep in mind that both algorithms _will always_ return a set of clusters, whether your data appears clustered or not. As a result, you should always be skeptical about the results. They can be quite insightful, but there is no reason to treat them as a fact without doing further research.
|
|
@ -1,312 +0,0 @@
|
|||
|
||||
## Heights data
|
||||
|
||||
Have you heard that a relationship exists between your height and your income? It sounds far-fetched---and maybe it is---but many people believe that taller people will be promoted faster and valued more for their work, an effect that increases their income. Could this be true?
|
||||
|
||||
Luckily, it is easy to measure someone's height, as well as their income, which means that we can collect data relevant to the question. In fact, the Bureau of Labor Statistics has been doing this in a controlled way for over 50 years. The BLS [National Longitudinal Surveys (NLS)](https://www.nlsinfo.org/) track the income, education, and life circumstances of a large cohort of Americans across several decades. In case you are wondering just how your tax dollars are being spent, the point of the NLS is not to study the relationship between height and income, that's just a lucky accident.
|
||||
|
||||
A small sample of the full dataset is included in modelr:
|
||||
|
||||
```{r}
|
||||
heights
|
||||
```
|
||||
|
||||
As well as `height` and `income` there are some other variables that might affect someone's income: `age`, `sex`, `race`, years of `education`, and their score on the `afqt` (Armed Forces Qualification Test).
|
||||
|
||||
Now that you have the data, you can visualize the relationship between height and income. But what does the data say? How would you describe the relationship?
|
||||
|
||||
```{r warnings = FALSE}
|
||||
ggplot(heights, aes(height, income)) +
|
||||
geom_point()
|
||||
```
|
||||
|
||||
First, let's address a distraction: the data is censored in an odd way. The y variable is income, which means that there are no y values less than zero. That's not odd. However, there are also no y values above $180,331. In fact, there are a line of unusual values at exactly $180,331. This is because the Bureau of Labor Statistics removed the top 2% of income values and replaced them with the mean value of the top 2% of values, an action that was not designed to enhance the usefulness of the data for data science.
|
||||
|
||||
```{r}
|
||||
n <- nrow(heights)
|
||||
heights <- heights |> filter(income < 150000)
|
||||
nrow(heights) / n
|
||||
```
|
||||
|
||||
I'm going to record the original number of observations in `n`. We'll come back to this every now and then to make sure that we haven't throw out too much of our data.
|
||||
|
||||
Also, you can see that heights have been rounded to the nearest inch so using boxplots will make it easier to see the pattern. We'll also remove the very tall and very short people so we can focus on the most typically heights:
|
||||
|
||||
```{r}
|
||||
heights <- heights |> filter(between(height, 59, 78))
|
||||
nrow(heights) / n
|
||||
|
||||
ggplot(heights, aes(height, income, group = height)) +
|
||||
geom_boxplot()
|
||||
```
|
||||
|
||||
(Throwing away data in the first pass at a model is perfectly acceptable: starting with a simple subset of a problem that you can easily solve is a good general strategy. But in a real analysis, once you've got the first simple model working, you really should come back and all look at the full dataset. Is removing the data still a good idea?)
|
||||
|
||||
You can see there seems to be a fairly weak relationship: as height increase the median wage also seems to increase. But how could we summarise that more quantitiatively?
|
||||
|
||||
## Linear models
|
||||
|
||||
One way is to use a linear model. A linear model is a very broad family of models: it encompasses all models that are a weighted sum of variables.
|
||||
|
||||
The formula specifies a family of models: for example, `income ~ height` describes the family of models specified by `x1 * income + x0`, where `x0` and `x1` are real numbers.
|
||||
|
||||
```{r}
|
||||
income ~ height
|
||||
```
|
||||
|
||||
We fit the model by supplying the family of models (the formula), and the data, to a model fitting function, `lm()`. `lm()` finds the single model in the family of models that is closest to the data:
|
||||
|
||||
```{r}
|
||||
h <- lm(income ~ height, data = heights)
|
||||
h
|
||||
```
|
||||
|
||||
We can extract the coefficients of this fitted model and write down the model it specifies:
|
||||
|
||||
```{r}
|
||||
coef(h)
|
||||
```
|
||||
|
||||
This tells says the model is $`r coef(h)[1]` + `r coef(h)[2]` * height$. In other words, one inch increase of height associated with an increase of \$937 in income.
|
||||
|
||||
|
||||
The definition that `lm()` uses for closeness is that it looks for a model that minimises the "root mean squared error".
|
||||
|
||||
`lm()` fits a straight line that describes the relationship between the variables in your formula. You can picture the result visually like this.
|
||||
|
||||
```{r}
|
||||
ggplot(heights, aes(height, income)) +
|
||||
geom_boxplot(aes(group = height)) +
|
||||
geom_smooth(method = lm, se = FALSE)
|
||||
```
|
||||
|
||||
`lm()` treats the variable(s) on the right-hand side of the formula as _explanatory variables_ that partially determine the value of the variable on the left-hand side of the formula, which is known as the _response variable_. In other words, it acts as if the _response variable_ is determined by a function of the _explanatory variables_. Linear regression is _linear_ because it finds the linear combination of the explanatory variables that best predict the response.
|
||||
|
||||
|
||||
### Exercises
|
||||
|
||||
1. What variables in `heights` do you expect to be most highly correlated with
|
||||
income? Use `cor()` plus `purrr::map_dbl()` to check your guesses.
|
||||
|
||||
1. Correlation only summarises the linear relationship between two continuous
|
||||
variables. There are some famous drawbacks to the correlation. What
|
||||
are they? Hint: google for Anscombe's quartet, read <https://xkcd.com/552/>.
|
||||
|
||||
### Categorical
|
||||
|
||||
Our model so far is extremely simple: it only uses one variable to try and predict income. We also know something else important: women tend to be shorter than men and tend to get paid less.
|
||||
|
||||
```{r}
|
||||
ggplot(heights, aes(height, colour = sex)) +
|
||||
geom_freqpoly(binwidth = 1)
|
||||
ggplot(heights, aes(income, colour = sex)) +
|
||||
geom_freqpoly(binwidth = 5000)
|
||||
```
|
||||
|
||||
What happens if we also include `sex` in the model?
|
||||
|
||||
```{r}
|
||||
h2 <- lm(income ~ height * sex, data = heights)
|
||||
grid <- heights |>
|
||||
expand(height, sex) |>
|
||||
add_predictions(h2, "income")
|
||||
|
||||
ggplot(heights, aes(height, income)) +
|
||||
geom_point() +
|
||||
geom_line(data = grid) +
|
||||
facet_wrap(~sex)
|
||||
```
|
||||
|
||||
Need to commment about predictions for tall women and short men - there is not a lot of data there. Need to be particularly sceptical.
|
||||
|
||||
`*` vs `+`.
|
||||
|
||||
```{r}
|
||||
h3 <- lm(income ~ height + sex, data = heights)
|
||||
grid <- heights |>
|
||||
expand(height, sex) |>
|
||||
gather_predictions(h2, h3)
|
||||
|
||||
ggplot(grid, aes(height, pred, colour = sex)) +
|
||||
geom_line() +
|
||||
facet_wrap(~model)
|
||||
```
|
||||
|
||||
### Continuous
|
||||
|
||||
There appears to be a relationship between a person's education and how poorly the model predicts their income. If we graph the model residuals against `education` above, we see that the more a person is educated, the worse the model underestimates their income:
|
||||
|
||||
But before we add a variable to our model, we need to do a little EDA + cleaning:
|
||||
|
||||
```{r}
|
||||
ggplot(heights, aes(education)) + geom_bar()
|
||||
heights_ed <- heights |> filter(education >= 12)
|
||||
nrow(heights) / n
|
||||
```
|
||||
|
||||
We could improve the model by adding education:
|
||||
|
||||
```{r}
|
||||
he1 <- lm(income ~ height + education, data = heights_ed)
|
||||
he2 <- lm(income ~ height * education, data = heights_ed)
|
||||
```
|
||||
|
||||
How can we visualise the results of this model? One way to think about it as a surface: we have a 2d grid of height and education, and point on that grid gets a predicted income.
|
||||
|
||||
```{r}
|
||||
grid <- heights_ed |>
|
||||
expand(height, education) |>
|
||||
gather_predictions(he1, he2)
|
||||
|
||||
ggplot(grid, aes(height, education, fill = pred)) +
|
||||
geom_raster() +
|
||||
facet_wrap(~model)
|
||||
```
|
||||
|
||||
It's easier to see what's going on in a line plot:
|
||||
|
||||
```{r}
|
||||
ggplot(grid, aes(height, pred, group = education)) +
|
||||
geom_line() +
|
||||
facet_wrap(~model)
|
||||
ggplot(grid, aes(education, pred, group = height)) +
|
||||
geom_line() +
|
||||
facet_wrap(~model)
|
||||
```
|
||||
|
||||
One of the big advantages to `+` instead of `*` is that because the terms are independent we display them using two simple plots instead of one complex plot:
|
||||
|
||||
```{r}
|
||||
heights_ed |>
|
||||
expand(
|
||||
height = seq_range(height, 10),
|
||||
education = mean(education, na.rm = TRUE)
|
||||
) |>
|
||||
add_predictions(he1, "income") |>
|
||||
ggplot(aes(height, income)) +
|
||||
geom_line()
|
||||
|
||||
heights_ed |>
|
||||
expand(
|
||||
height = mean(height, na.rm = TRUE),
|
||||
education = seq_range(education, 10)
|
||||
) |>
|
||||
add_predictions(he1, "income") |>
|
||||
ggplot(aes(education, income)) +
|
||||
geom_line()
|
||||
```
|
||||
|
||||
The full interaction suggests that height matters less as education increases. But which model is "better"? We'll come back to that question later.
|
||||
|
||||
What happens if we add the data back in to the plot? Do you get more or less sceptical about the results from this model?
|
||||
|
||||
You can imagine that if you had a model with four continuous predictions all interacting, that it would be pretty complicated to understand what's going in the model! And certainly you don't have to - it's totally fine to use a model simply as a tool for predicting new values, and in the next chapters you'll learn some techniques to help evaluate such models without looking at them. However, I think the more you can connect your understand of the domain to the model, the more likely you are to detect potential problems before they occur. The goal is not to undertand every last nuance of the model, but instead to understand more than what you did previously.
|
||||
|
||||
condvis.
|
||||
|
||||
### Categorical
|
||||
|
||||
|
||||
```{r}
|
||||
s <- lm(income ~ sex, data = heights)
|
||||
tidy(s)
|
||||
```
|
||||
|
||||
Every level of the factor except one receives its own coefficient. The missing level acts as a baseline.
|
||||
|
||||
To change the baseline, create a new factor with a new levels attribute. R will use the first level in the levels attribute as the baseline.
|
||||
|
||||
```{r}
|
||||
heights$sex <- factor(heights$sex, levels = c("male", "female"))
|
||||
```
|
||||
|
||||
```{r}
|
||||
hes <- lm(income ~ height + education + sex, data = heights)
|
||||
tidy(hes)
|
||||
```
|
||||
|
||||
```{r}
|
||||
heights |>
|
||||
group_by(sex) |>
|
||||
do(glance(lm(income ~ height, data = .)))
|
||||
```
|
||||
|
||||
```{r}
|
||||
hes2 <- lm(income ~ height + education * sex, data = heights)
|
||||
tidy(hes2)
|
||||
```
|
||||
|
||||
### Splines
|
||||
|
||||
But what if the relationship between variables is not linear? For example, the relationship between income and education does not seem to be linear:
|
||||
|
||||
```{r}
|
||||
ggplot(heights_ed, aes(education, income)) +
|
||||
geom_boxplot(aes(group = education)) +
|
||||
geom_smooth(se = FALSE)
|
||||
```
|
||||
|
||||
One way to introduce non-linearity into our model is to use transformed variants of the predictors.
|
||||
|
||||
```{r}
|
||||
mod_e1 <- lm(income ~ education, data = heights_ed)
|
||||
mod_e2 <- lm(income ~ education + I(education ^ 2) + I(education ^ 3), data = heights_ed)
|
||||
|
||||
heights_ed |>
|
||||
expand(education) |>
|
||||
gather_predictions(mod_e1, mod_e2) |>
|
||||
ggplot(aes(education, pred, colour = model)) +
|
||||
geom_point() +
|
||||
geom_line()
|
||||
```
|
||||
|
||||
This is a bit clunky because we have to surround each transformation with `I()`. This is because the rules of model algebra are a little different to usual algebra. `x ^ 2` is equivalent to `x * x` which in the modelling algebra is equivalent to `x + x + x:x` which is the same as `x`. This is useful because `(x + y + z)^2` fit all all major terms and second order interactions of x, y, and z.
|
||||
|
||||
```{r}
|
||||
mod_e1 <- lm(income ~ education, data = heights_ed)
|
||||
mod_e2 <- lm(income ~ poly(education, 2), data = heights_ed)
|
||||
mod_e3 <- lm(income ~ poly(education, 3), data = heights_ed)
|
||||
|
||||
heights_ed |>
|
||||
expand(education) |>
|
||||
gather_predictions(mod_e1, mod_e2, mod_e3) |>
|
||||
ggplot(aes(education, pred, colour = model)) +
|
||||
geom_point() +
|
||||
geom_line()
|
||||
```
|
||||
|
||||
However: there's one major problem with using `poly()`: outside the range of the data, polynomials are going to rapidly shoot off to positive or negative infinity.
|
||||
|
||||
```{r}
|
||||
tibble(education = seq(5, 25)) |>
|
||||
gather_predictions(mod_e1, mod_e2, mod_e3) |>
|
||||
ggplot(aes(education, pred, colour = model)) +
|
||||
geom_line()
|
||||
```
|
||||
|
||||
Splines avoid this problem by linearly interpolating outside the range of the data. This isn't great either, but it's a safer default when you don't know for sure what's going to happen.
|
||||
|
||||
```{r}
|
||||
library(splines)
|
||||
mod_e1 <- lm(income ~ education, data = heights_ed)
|
||||
mod_e2 <- lm(income ~ ns(education, 2), data = heights_ed)
|
||||
mod_e3 <- lm(income ~ ns(education, 3), data = heights_ed)
|
||||
|
||||
tibble(education = seq(5, 25)) |>
|
||||
gather_predictions(mod_e1, mod_e2, mod_e3) |>
|
||||
ggplot(aes(education, pred, colour = model)) +
|
||||
geom_line()
|
||||
```
|
||||
|
||||
|
||||
### Additive models
|
||||
|
||||
|
||||
```{r, dev = "png"}
|
||||
library(mgcv)
|
||||
gam(income ~ s(education), data = heights)
|
||||
|
||||
ggplot(data = heights, mapping = aes(x = education, y = income)) +
|
||||
geom_point() +
|
||||
geom_smooth(method = gam, formula = y ~ s(x))
|
||||
```
|
|
@ -1,325 +0,0 @@
|
|||
# Model assessment
|
||||
|
||||
In this chapter, you'll turn the tools of multiple models towards model assessment: learning how the model performs when given new data.
|
||||
So far we've focussed on models as tools for description, using models to help us understand the patterns in the data we have collected so far.
|
||||
But ideally a model will do more than just describe what we have seen so far - it will also help predict what will come next.
|
||||
|
||||
In other words, we want a model that doesn't just perform well on the sample, but also accurately summarises the underlying population.
|
||||
|
||||
In some industries this is primarily the use of models: you spend relatively little time fitting the model compared to how many times you use it.
|
||||
|
||||
There are two basic ways that a model can fail with new data:
|
||||
|
||||
- You can under- or over-fit the model.
|
||||
Underfitting is where you fail to model and important trend: you leave too much in the residuals, and not enough in the model.
|
||||
Overfitting is the opposite: you fit a trend to what is actually random noise: you've too put much model and not left enough in the residuals.
|
||||
Generally overfitting tends to be more of a problem than underfitting.
|
||||
|
||||
- The process that generates the data might change.
|
||||
There's nothing the model can do about this.
|
||||
You can protect yourself against this to some extent by creating models that you understand and applying your knowledge to the problem.
|
||||
Are these fundamentals likely to change?
|
||||
If you have a model that you are going to use again and again for a long time, you need to plan to maintain the model, regularly checking that it still makes sense.
|
||||
i.e. is the population the same?
|
||||
|
||||
<http://research.google.com/pubs/pub43146.html> <http://www.wired.com/2015/10/can-learn-epic-failure-google-flu-trends/>
|
||||
|
||||
The most common problem with a model that causes it to do poorly with new data is overfitting.
|
||||
|
||||
Obviously, there's a bit of a problem here: we don't have new data with which to check the model, and even if we did, we'd presumably use it to make the model better in the first place.
|
||||
One powerful technique of approaches can help us get around this problem: resampling.
|
||||
|
||||
There are two main resampling techniques that we're going to cover.
|
||||
|
||||
- We will use **cross-validation** to assess model quality.
|
||||
In cross-validation, you split the data into test and training sets.
|
||||
You fit the data to the training set, and evaluate it on the test set.
|
||||
This avoids intrinsic bias of using the same data to both fit the model and assess it's quality.
|
||||
However it introduces a new bias: you're not using all the data to fit the model so it's not going to be quite as good as it could be.
|
||||
|
||||
- We will use **boostrapping** to understand how stable (or how variable) the model is.
|
||||
If you sample data from the same population multiple times, how much does your model vary?
|
||||
Instead of going back to collect new data, you can use the best estimate of the population data: the data you've collected so far.
|
||||
The amazing idea of the bootstrap is that you can resample from the data you already have.
|
||||
|
||||
There are lots of high-level helpers to do these resampling methods in R.
|
||||
We're going to use the tools provided by the modelr package because they are explicit - you'll see exactly what's going on at each step.
|
||||
|
||||
<http://topepo.github.io/caret>.
|
||||
[Applied Predictive Modeling](https://amzn.com/1461468485), by Max Kuhn and Kjell Johnson.
|
||||
|
||||
If you're competing in competitions, like Kaggle, that are predominantly about creating good predictions, developing a good strategy for avoiding overfitting is very important.
|
||||
Otherwise you risk tricking yourself into thinking that you have a good model, when in reality you just have a model that does a good job of fitting your data.
|
||||
|
||||
There is a closely related family that uses a similar idea: model ensembles.
|
||||
However, instead of trying to find the best models, ensembles make use of all the models, acknowledging that even models that don't fit all the data particularly well can still model some subsets well.
|
||||
In general, you can think of model ensemble techniques as functions that take a list of models, and a return a single model that attempts to take the best part of each.
|
||||
|
||||
### Prerequisites
|
||||
|
||||
```{r setup, message = FALSE}
|
||||
# Standard data manipulation and visulisation
|
||||
library(dplyr)
|
||||
library(ggplot2)
|
||||
|
||||
# Tools for working with models
|
||||
library(broom)
|
||||
library(modelr)
|
||||
library(splines)
|
||||
|
||||
# Tools for working with lots of models
|
||||
library(purrr)
|
||||
library(tidyr)
|
||||
```
|
||||
|
||||
```{r}
|
||||
# Options that make your life easier
|
||||
options(
|
||||
contrasts = c("contr.treatment", "contr.treatment"),
|
||||
na.option = na.exclude
|
||||
)
|
||||
```
|
||||
|
||||
## Overfitting
|
||||
|
||||
Both bootstrapping and cross-validation help us to spot and remedy the problem of **over fitting**, where the model fits the data we've seen so far extremely well, but does a bad job of generalising to new data.
|
||||
|
||||
A classic example of over-fitting is to using a polynomial with too many degrees of freedom.
|
||||
|
||||
Bias - variance tradeoff.
|
||||
Simpler = more biased.
|
||||
Complex = more variable.
|
||||
Occam's razor.
|
||||
|
||||
```{r}
|
||||
true_model <- function(x) {
|
||||
1 + 2 * x + rnorm(length(x), sd = 0.25)
|
||||
}
|
||||
|
||||
df <- tibble(
|
||||
x = seq(0, 1, length = 20),
|
||||
y = true_model(x)
|
||||
)
|
||||
|
||||
df |>
|
||||
ggplot(aes(x, y)) +
|
||||
geom_point()
|
||||
```
|
||||
|
||||
We can create a model that fits this data very well:
|
||||
|
||||
```{r, message = FALSE}
|
||||
library(splines)
|
||||
my_model <- function(df) {
|
||||
lm(y ~ poly(x, 7), data = df)
|
||||
}
|
||||
|
||||
mod <- my_model(df)
|
||||
rmse(mod, df)
|
||||
|
||||
grid <- df |>
|
||||
expand(x = seq_range(x, 50))
|
||||
preds <- grid |>
|
||||
add_predictions(mod, var = "y")
|
||||
|
||||
df |>
|
||||
ggplot(aes(x, y)) +
|
||||
geom_line(data = preds) +
|
||||
geom_point()
|
||||
```
|
||||
|
||||
As we fit progressively more and more complicated models, the model error decreases:
|
||||
|
||||
```{r}
|
||||
fs <- list(
|
||||
y ~ x,
|
||||
y ~ poly(x, 2),
|
||||
y ~ poly(x, 3),
|
||||
y ~ poly(x, 4),
|
||||
y ~ poly(x, 5),
|
||||
y ~ poly(x, 6),
|
||||
y ~ poly(x, 7)
|
||||
)
|
||||
|
||||
models <- tibble(
|
||||
n = 1:7,
|
||||
f = fs,
|
||||
mod = map(f, lm, data = df),
|
||||
rmse = map2_dbl(mod, list(df), rmse)
|
||||
)
|
||||
|
||||
models |>
|
||||
ggplot(aes(n, rmse)) +
|
||||
geom_line(colour = "grey70") +
|
||||
geom_point(size = 3)
|
||||
```
|
||||
|
||||
But do you think this model will do well if we apply it to new data from the same population?
|
||||
|
||||
In real-life you can't easily go out and recollect your data.
|
||||
There are two approaches to help you get around this problem.
|
||||
I'll introduce them briefly here, and then we'll go into more depth in the following sections.
|
||||
|
||||
```{r}
|
||||
boot <- bootstrap(df, 100) |>
|
||||
mutate(
|
||||
mod = map(strap, my_model),
|
||||
pred = map2(list(grid), mod, add_predictions)
|
||||
)
|
||||
|
||||
boot |>
|
||||
unnest(pred) |>
|
||||
ggplot(aes(x, pred, group = .id)) +
|
||||
geom_line(alpha = 1/3)
|
||||
```
|
||||
|
||||
It's a little easier to see what's going on if we zoom on the y axis:
|
||||
|
||||
```{r}
|
||||
last_plot() +
|
||||
coord_cartesian(ylim = c(0, 5))
|
||||
```
|
||||
|
||||
(You might notice that while each individual model varies a lot, the average of all the models seems like it might not be that bad. That gives rise to a model ensemble technique called model averaging.)
|
||||
|
||||
Bootstrapping is a useful tool to help us understand how the model might vary if we'd collected a different sample from the population.
|
||||
A related technique is cross-validation which allows us to explore the quality of the model.
|
||||
It works by repeatedly splitting the data into two pieces.
|
||||
One piece, the training set, is used to fit, and the other piece, the test set, is used to measure the model quality.
|
||||
|
||||
The following code generates 100 test-training splits, holding out 20% of the data for testing each time.
|
||||
We then fit a model to the training set, and evaluate the error on the test set:
|
||||
|
||||
```{r}
|
||||
cv <- crossv_mc(df, 100) |>
|
||||
mutate(
|
||||
mod = map(train, my_model),
|
||||
rmse = map2_dbl(mod, test, rmse)
|
||||
)
|
||||
cv
|
||||
```
|
||||
|
||||
Obviously, a plot is going to help us see distribution more easily.
|
||||
I've added our original estimate of the model error as a white vertical line (where the same dataset is used for both training and testing), and you can see it's very optimistic.
|
||||
|
||||
```{r}
|
||||
cv |>
|
||||
ggplot(aes(rmse)) +
|
||||
geom_ref_line(v = rmse(mod, df)) +
|
||||
geom_freqpoly(binwidth = 0.2) +
|
||||
geom_rug()
|
||||
```
|
||||
|
||||
The distribution of errors is highly skewed: there are a few cases which have very high errors.
|
||||
These represent samples where we ended up with a few cases on all with low values or high values of x.
|
||||
Let's take a look:
|
||||
|
||||
```{r}
|
||||
filter(cv, rmse > 1.5) |>
|
||||
unnest(map(train, as.data.frame)) |>
|
||||
ggplot(aes(x, .id)) +
|
||||
geom_point() +
|
||||
xlim(0, 1)
|
||||
```
|
||||
|
||||
All of the models that fit particularly poorly were fit to samples that either missed the first one or two or the last one or two observation.
|
||||
Because polynomials shoot off to positive and negative, they give very bad predictions for those values.
|
||||
|
||||
Now that we've given you a quick overview and intuition for these techniques, let's dive in more detail.
|
||||
|
||||
## Resamples
|
||||
|
||||
### Building blocks
|
||||
|
||||
Both the boostrap and cross-validation are built on top of a "resample" object.
|
||||
In modelr, you can access these low-level tools directly with the `resample_*` functions.
|
||||
|
||||
These functions return an object of class "resample", which represents the resample in a memory efficient way.
|
||||
Instead of storing the resampled dataset itself, it instead stores the integer indices, and a "pointer" to the original dataset.
|
||||
This makes resamples take up much less memory.
|
||||
|
||||
```{r}
|
||||
x <- resample_bootstrap(as_tibble(mtcars))
|
||||
class(x)
|
||||
|
||||
x
|
||||
```
|
||||
|
||||
Most modelling functions call `as.data.frame()` on the `data` argument.
|
||||
This generates a resampled data frame.
|
||||
Because it's called automatically you can just pass the object.
|
||||
|
||||
```{r}
|
||||
lm(mpg ~ wt, data = x)
|
||||
```
|
||||
|
||||
If you get a strange error, it's probably because the modelling function doesn't do this, and you need to do it yourself.
|
||||
You'll also need to do it yourself if you want to `unnest()` the data so you can visualise it.
|
||||
If you want to just get the rows selected, you can use `as.integer()`.
|
||||
|
||||
### Dataframe API
|
||||
|
||||
`bootstrap()` and `crossv_mc()` are built on top of these simpler primitives.
|
||||
They are designed to work naturally in a model exploration environment by returning data frames.
|
||||
Each row of the data frame represents a single sample.
|
||||
They return slightly different columns:
|
||||
|
||||
- `boostrap()` returns a data frame with two columns:
|
||||
|
||||
```{r}
|
||||
bootstrap(df, 3)
|
||||
```
|
||||
|
||||
`strap` gives the bootstrap sample dataset, and `.id` assigns a unique identifier to each model (this is often useful for plotting)
|
||||
|
||||
- `crossv_mc()` return a data frame with three columns:
|
||||
|
||||
```{r}
|
||||
crossv_mc(df, 3)
|
||||
```
|
||||
|
||||
`train` contains the data that you should use to fit (train) the model, and `test` contains the data you should use to validate the model.
|
||||
Together, the test and train columns form an exclusive partition of the full dataset.
|
||||
|
||||
## Numeric summaries of model quality
|
||||
|
||||
When you start dealing with many models, it's helpful to have some rough way of comparing them so you can spend your time looking at the models that do the best job of capturing important features in the data.
|
||||
|
||||
One way to capture the quality of the model is to summarise the distribution of the residuals.
|
||||
For example, you could look at the quantiles of the absolute residuals.
|
||||
For this dataset, 25% of predictions are less than \$7,400 away, and 75% are less than \$25,800 away.
|
||||
That seems like quite a bit of error when predicting someone's income!
|
||||
|
||||
```{r}
|
||||
heights <- tibble(readRDS("data/heights.RDS"))
|
||||
h <- lm(income ~ height, data = heights)
|
||||
h
|
||||
|
||||
qae(h, heights)
|
||||
range(heights$income)
|
||||
```
|
||||
|
||||
You might be familiar with the $R^2$.
|
||||
That's a single number summary that rescales the variance of the residuals to between 0 (very bad) and 1 (very good):
|
||||
|
||||
```{r}
|
||||
rsquare(h, heights)
|
||||
```
|
||||
|
||||
$R^2$ can be interpreted as the amount of variation in the data explained by the model.
|
||||
Here we're explaining 3% of the total variation - not a lot!
|
||||
But I don't think worrying about the relative amount of variation explained is that useful; instead I think you need to consider whether the absolute amount of variation explained is useful for your project.
|
||||
|
||||
It's called the $R^2$ because for simple models like this, it's just the square of the correlation between the variables:
|
||||
|
||||
```{r}
|
||||
cor(heights$income, heights$height) ^ 2
|
||||
```
|
||||
|
||||
The $R^2$ is an ok single number summary, but I prefer to think about the unscaled residuals because it's easier to interpret in the context of the original data.
|
||||
As you'll also learn later, it's also a rather optimistic interpretation of the model.
|
||||
Because you're assessing the model using the same data that was used to fit it, it really gives more of an upper bound on the quality of the model, not a fair assessment.
|
||||
|
||||
## Bootstrapping
|
||||
|
||||
## Cross-validation
|
|
@ -1,785 +0,0 @@
|
|||
# Model basics
|
||||
|
||||
## Introduction
|
||||
|
||||
The goal of a model is to provide a simple low-dimensional summary of a dataset.
|
||||
In the context of this book we're going to use models to partition data into patterns and residuals.
|
||||
Strong patterns will hide subtler trends, so we'll use models to help peel back layers of structure as we explore a dataset.
|
||||
|
||||
However, before we can start using models on interesting, real datasets, you need to understand the basics of how models work.
|
||||
For that reason, this chapter of the book is unique because it uses only simulated datasets.
|
||||
These datasets are very simple, and not at all interesting, but they will help you understand the essence of modelling before you apply the same techniques to real data in the next chapter.
|
||||
|
||||
There are two parts to a model:
|
||||
|
||||
1. First, you define a **family of models** that express a precise, but generic, pattern that you want to capture.
|
||||
For example, the pattern might be a straight line, or a quadratic curve.
|
||||
You will express the model family as an equation like `y = a_1 * x + a_2` or `y = a_1 * x ^ a_2`.
|
||||
Here, `x` and `y` are known variables from your data, and `a_1` and `a_2` are parameters that can vary to capture different patterns.
|
||||
|
||||
2. Next, you generate a **fitted model** by finding the model from the family that is the closest to your data.
|
||||
This takes the generic model family and makes it specific, like `y = 3 * x + 7` or `y = 9 * x ^ 2`.
|
||||
|
||||
It's important to understand that a fitted model is just the closest model from a family of models.
|
||||
That implies that you have the "best" model (according to some criteria); it doesn't imply that you have a good model and it certainly doesn't imply that the model is "true".
|
||||
George Box puts this well in his famous aphorism:
|
||||
|
||||
> All models are wrong, but some are useful.
|
||||
|
||||
It's worth reading the fuller context of the quote:
|
||||
|
||||
> Now it would be very remarkable if any system existing in the real world could be exactly represented by any simple model.
|
||||
> However, cunningly chosen parsimonious models often do provide remarkably useful approximations.
|
||||
> For example, the law PV = RT relating pressure P, volume V and temperature T of an "ideal" gas via a constant R is not exactly true for any real gas, but it frequently provides a useful approximation and furthermore its structure is informative since it springs from a physical view of the behavior of gas molecules.
|
||||
>
|
||||
> For such a model there is no need to ask the question "Is the model true?".
|
||||
> If "truth" is to be the "whole truth" the answer must be "No".
|
||||
> The only question of interest is "Is the model illuminating and useful?".
|
||||
|
||||
The goal of a model is not to uncover truth, but to discover a simple approximation that is still useful.
|
||||
|
||||
### Prerequisites
|
||||
|
||||
In this chapter we'll use the modelr package which wraps around base R's modelling functions to make them work naturally in a pipe.
|
||||
|
||||
```{r setup, message = FALSE}
|
||||
library(tidyverse)
|
||||
|
||||
library(modelr)
|
||||
options(na.action = na.warn)
|
||||
```
|
||||
|
||||
## A simple model
|
||||
|
||||
Lets take a look at the simulated dataset `sim1`, included with the modelr package.
|
||||
It contains two continuous variables, `x` and `y`.
|
||||
Let's plot them to see how they're related:
|
||||
|
||||
```{r}
|
||||
ggplot(sim1, aes(x, y)) +
|
||||
geom_point()
|
||||
```
|
||||
|
||||
You can see a strong pattern in the data.
|
||||
Let's use a model to capture that pattern and make it explicit.
|
||||
It's our job to supply the basic form of the model.
|
||||
In this case, the relationship looks linear, i.e. `y = a_0 + a_1 * x`.
|
||||
Let's start by getting a feel for what models from that family look like by randomly generating a few and overlaying them on the data.
|
||||
For this simple case, we can use `geom_abline()` which takes a slope and intercept as parameters.
|
||||
Later on we'll learn more general techniques that work with any model.
|
||||
|
||||
```{r}
|
||||
models <- tibble(
|
||||
a1 = runif(250, -20, 40),
|
||||
a2 = runif(250, -5, 5)
|
||||
)
|
||||
|
||||
ggplot(sim1, aes(x, y)) +
|
||||
geom_abline(aes(intercept = a1, slope = a2), data = models, alpha = 1/4) +
|
||||
geom_point()
|
||||
```
|
||||
|
||||
There are 250 models on this plot, but a lot are really bad!
|
||||
We need to find the good models by making precise our intuition that a good model is "close" to the data.
|
||||
We need a way to quantify the distance between the data and a model.
|
||||
Then we can fit the model by finding the value of `a_0` and `a_1` that generate the model with the smallest distance from this data.
|
||||
|
||||
One easy place to start is to find the vertical distance between each point and the model, as in the following diagram.
|
||||
(Note that I've shifted the x values slightly so you can see the individual distances.)
|
||||
|
||||
```{r, echo = FALSE}
|
||||
dist1 <- sim1 |>
|
||||
mutate(
|
||||
dodge = rep(c(-1, 0, 1) / 20, 10),
|
||||
x1 = x + dodge,
|
||||
pred = 7 + x1 * 1.5
|
||||
)
|
||||
|
||||
ggplot(dist1, aes(x1, y)) +
|
||||
geom_abline(intercept = 7, slope = 1.5, colour = "grey40") +
|
||||
geom_point(colour = "grey40") +
|
||||
geom_linerange(aes(ymin = y, ymax = pred), colour = "#3366FF")
|
||||
```
|
||||
|
||||
This distance is just the difference between the y value given by the model (the **prediction**), and the actual y value in the data (the **response**).
|
||||
|
||||
To compute this distance, we first turn our model family into an R function.
|
||||
This takes the model parameters and the data as inputs, and gives values predicted by the model as output:
|
||||
|
||||
```{r}
|
||||
model1 <- function(a, data) {
|
||||
a[1] + data$x * a[2]
|
||||
}
|
||||
model1(c(7, 1.5), sim1)
|
||||
```
|
||||
|
||||
Next, we need some way to compute an overall distance between the predicted and actual values.
|
||||
In other words, the plot above shows 30 distances: how do we collapse that into a single number?
|
||||
|
||||
One common way to do this in statistics is to use the "root-mean-squared deviation".
|
||||
We compute the difference between actual and predicted, square them, average them, and then take the square root.
|
||||
This distance has lots of appealing mathematical properties, which we're not going to talk about here.
|
||||
You'll just have to take my word for it!
|
||||
|
||||
```{r}
|
||||
measure_distance <- function(mod, data) {
|
||||
diff <- data$y - model1(mod, data)
|
||||
sqrt(mean(diff ^ 2))
|
||||
}
|
||||
measure_distance(c(7, 1.5), sim1)
|
||||
```
|
||||
|
||||
Now we can use purrr to compute the distance for all the models defined above.
|
||||
We need a helper function because our distance function expects the model as a numeric vector of length 2.
|
||||
|
||||
```{r}
|
||||
sim1_dist <- function(a1, a2) {
|
||||
measure_distance(c(a1, a2), sim1)
|
||||
}
|
||||
|
||||
models <- models |>
|
||||
mutate(dist = purrr::map2_dbl(a1, a2, sim1_dist))
|
||||
models
|
||||
```
|
||||
|
||||
Next, let's overlay the 10 best models on to the data.
|
||||
I've coloured the models by `-dist`: this is an easy way to make sure that the best models (i.e. the ones with the smallest distance) get the brighest colours.
|
||||
|
||||
```{r}
|
||||
ggplot(sim1, aes(x, y)) +
|
||||
geom_point(size = 2, colour = "grey30") +
|
||||
geom_abline(
|
||||
aes(intercept = a1, slope = a2, colour = -dist),
|
||||
data = filter(models, rank(dist) <= 10)
|
||||
)
|
||||
```
|
||||
|
||||
We can also think about these models as observations, and visualising with a scatterplot of `a1` vs `a2`, again coloured by `-dist`.
|
||||
We can no longer directly see how the model compares to the data, but we can see many models at once.
|
||||
Again, I've highlighted the 10 best models, this time by drawing red circles underneath them.
|
||||
|
||||
```{r}
|
||||
ggplot(models, aes(a1, a2)) +
|
||||
geom_point(data = filter(models, rank(dist) <= 10), size = 4, colour = "red") +
|
||||
geom_point(aes(colour = -dist))
|
||||
```
|
||||
|
||||
Instead of trying lots of random models, we could be more systematic and generate an evenly spaced grid of points (this is called a grid search).
|
||||
I picked the parameters of the grid roughly by looking at where the best models were in the plot above.
|
||||
|
||||
```{r}
|
||||
grid <- expand.grid(
|
||||
a1 = seq(-5, 20, length = 25),
|
||||
a2 = seq(1, 3, length = 25)
|
||||
) |>
|
||||
mutate(dist = purrr::map2_dbl(a1, a2, sim1_dist))
|
||||
|
||||
grid |>
|
||||
ggplot(aes(a1, a2)) +
|
||||
geom_point(data = filter(grid, rank(dist) <= 10), size = 4, colour = "red") +
|
||||
geom_point(aes(colour = -dist))
|
||||
```
|
||||
|
||||
When you overlay the best 10 models back on the original data, they all look pretty good:
|
||||
|
||||
```{r}
|
||||
ggplot(sim1, aes(x, y)) +
|
||||
geom_point(size = 2, colour = "grey30") +
|
||||
geom_abline(
|
||||
aes(intercept = a1, slope = a2, colour = -dist),
|
||||
data = filter(grid, rank(dist) <= 10)
|
||||
)
|
||||
```
|
||||
|
||||
You could imagine iteratively making the grid finer and finer until you narrowed in on the best model.
|
||||
But there's a better way to tackle that problem: a numerical minimisation tool called Newton-Raphson search.
|
||||
The intuition of Newton-Raphson is pretty simple: you pick a starting point and look around for the steepest slope.
|
||||
You then ski down that slope a little way, and then repeat again and again, until you can't go any lower.
|
||||
In R, we can do that with `optim()`:
|
||||
|
||||
```{r}
|
||||
best <- optim(c(0, 0), measure_distance, data = sim1)
|
||||
best$par
|
||||
|
||||
ggplot(sim1, aes(x, y)) +
|
||||
geom_point(size = 2, colour = "grey30") +
|
||||
geom_abline(intercept = best$par[1], slope = best$par[2])
|
||||
```
|
||||
|
||||
Don't worry too much about the details of how `optim()` works.
|
||||
It's the intuition that's important here.
|
||||
If you have a function that defines the distance between a model and a dataset, an algorithm that can minimise that distance by modifying the parameters of the model can find the best model.
|
||||
The neat thing about this approach is that it will work for any family of models that you can write an equation for.
|
||||
|
||||
There's one more approach that we can use for this model, because it's a special case of a broader family: linear models.
|
||||
A linear model has the general form `y = a_1 + a_2 * x_1 + a_3 * x_2 + ... + a_n * x_(n - 1)`.
|
||||
So this simple model is equivalent to a general linear model where n is 2 and `x_1` is `x`.
|
||||
R has a tool specifically designed for fitting linear models called `lm()`.
|
||||
`lm()` has a special way to specify the model family: formulas.
|
||||
Formulas look like `y ~ x`, which `lm()` will translate to a function like `y = a_1 + a_2 * x`.
|
||||
We can fit the model and look at the output:
|
||||
|
||||
```{r}
|
||||
sim1_mod <- lm(y ~ x, data = sim1)
|
||||
coef(sim1_mod)
|
||||
```
|
||||
|
||||
These are exactly the same values we got with `optim()`!
|
||||
Behind the scenes `lm()` doesn't use `optim()` but instead takes advantage of the mathematical structure of linear models.
|
||||
Using some connections between geometry, calculus, and linear algebra, `lm()` actually finds the closest model in a single step, using a sophisticated algorithm.
|
||||
This approach is both faster, and guarantees that there is a global minimum.
|
||||
|
||||
### Exercises
|
||||
|
||||
1. One downside of the linear model is that it is sensitive to unusual values because the distance incorporates a squared term.
|
||||
Fit a linear model to the simulated data below, and visualise the results.
|
||||
Rerun a few times to generate different simulated datasets.
|
||||
What do you notice about the model?
|
||||
|
||||
```{r}
|
||||
sim1a <- tibble(
|
||||
x = rep(1:10, each = 3),
|
||||
y = x * 1.5 + 6 + rt(length(x), df = 2)
|
||||
)
|
||||
```
|
||||
|
||||
2. One way to make linear models more robust is to use a different distance measure.
|
||||
For example, instead of root-mean-squared distance, you could use mean-absolute distance:
|
||||
|
||||
```{r}
|
||||
measure_distance <- function(mod, data) {
|
||||
diff <- data$y - model1(mod, data)
|
||||
mean(abs(diff))
|
||||
}
|
||||
```
|
||||
|
||||
Use `optim()` to fit this model to the simulated data above and compare it to the linear model.
|
||||
|
||||
3. One challenge with performing numerical optimisation is that it's only guaranteed to find one local optimum.
|
||||
What's the problem with optimising a three parameter model like this?
|
||||
|
||||
```{r}
|
||||
model1 <- function(a, data) {
|
||||
a[1] + data$x * a[2] + a[3]
|
||||
}
|
||||
```
|
||||
|
||||
## Visualising models
|
||||
|
||||
For simple models, like the one above, you can figure out what pattern the model captures by carefully studying the model family and the fitted coefficients.
|
||||
And if you ever take a statistics course on modelling, you're likely to spend a lot of time doing just that.
|
||||
Here, however, we're going to take a different tack.
|
||||
We're going to focus on understanding a model by looking at its predictions.
|
||||
This has a big advantage: every type of predictive model makes predictions (otherwise what use would it be?) so we can use the same set of techniques to understand any type of predictive model.
|
||||
|
||||
It's also useful to see what the model doesn't capture, the so-called residuals which are left after subtracting the predictions from the data.
|
||||
Residuals are powerful because they allow us to use models to remove striking patterns so we can study the subtler trends that remain.
|
||||
|
||||
### Predictions
|
||||
|
||||
To visualise the predictions from a model, we start by generating an evenly spaced grid of values that covers the region where our data lies.
|
||||
The easiest way to do that is to use `modelr::data_grid()`.
|
||||
Its first argument is a data frame, and for each subsequent argument it finds the unique variables and then generates all combinations:
|
||||
|
||||
```{r}
|
||||
grid <- sim1 |>
|
||||
data_grid(x)
|
||||
grid
|
||||
```
|
||||
|
||||
(This will get more interesting when we start to add more variables to our model.)
|
||||
|
||||
Next we add predictions.
|
||||
We'll use `modelr::add_predictions()` which takes a data frame and a model.
|
||||
It adds the predictions from the model to a new column in the data frame:
|
||||
|
||||
```{r}
|
||||
grid <- grid |>
|
||||
add_predictions(sim1_mod)
|
||||
grid
|
||||
```
|
||||
|
||||
(You can also use this function to add predictions to your original dataset.)
|
||||
|
||||
Next, we plot the predictions.
|
||||
You might wonder about all this extra work compared to just using `geom_abline()`.
|
||||
But the advantage of this approach is that it will work with *any* model in R, from the simplest to the most complex.
|
||||
You're only limited by your visualisation skills.
|
||||
For more ideas about how to visualise more complex model types, you might try <http://vita.had.co.nz/papers/model-vis.html>.
|
||||
|
||||
```{r}
|
||||
ggplot(sim1, aes(x)) +
|
||||
geom_point(aes(y = y)) +
|
||||
geom_line(aes(y = pred), data = grid, colour = "red", size = 1)
|
||||
```
|
||||
|
||||
### Residuals
|
||||
|
||||
The flip-side of predictions are **residuals**.
|
||||
The predictions tell you the pattern that the model has captured, and the residuals tell you what the model has missed.
|
||||
The residuals are just the distances between the observed and predicted values that we computed above.
|
||||
|
||||
We add residuals to the data with `add_residuals()`, which works much like `add_predictions()`.
|
||||
Note, however, that we use the original dataset, not a manufactured grid.
|
||||
This is because to compute residuals we need actual y values.
|
||||
|
||||
```{r}
|
||||
sim1 <- sim1 |>
|
||||
add_residuals(sim1_mod)
|
||||
sim1
|
||||
```
|
||||
|
||||
There are a few different ways to understand what the residuals tell us about the model.
|
||||
One way is to simply draw a frequency polygon to help us understand the spread of the residuals:
|
||||
|
||||
```{r}
|
||||
ggplot(sim1, aes(resid)) +
|
||||
geom_freqpoly(binwidth = 0.5)
|
||||
```
|
||||
|
||||
This helps you calibrate the quality of the model: how far away are the predictions from the observed values?
|
||||
Note that the average of the residual will always be 0.
|
||||
|
||||
You'll often want to recreate plots using the residuals instead of the original predictor.
|
||||
You'll see a lot of that in the next chapter.
|
||||
|
||||
```{r}
|
||||
ggplot(sim1, aes(x, resid)) +
|
||||
geom_ref_line(h = 0) +
|
||||
geom_point()
|
||||
```
|
||||
|
||||
This looks like random noise, suggesting that our model has done a good job of capturing the patterns in the dataset.
|
||||
|
||||
### Exercises
|
||||
|
||||
1. Instead of using `lm()` to fit a straight line, you can use `loess()` to fit a smooth curve.
|
||||
Repeat the process of model fitting, grid generation, predictions, and visualisation on `sim1` using `loess()` instead of `lm()`.
|
||||
How does the result compare to `geom_smooth()`?
|
||||
|
||||
2. `add_predictions()` is paired with `gather_predictions()` and `spread_predictions()`.
|
||||
How do these three functions differ?
|
||||
|
||||
3. What does `geom_ref_line()` do?
|
||||
What package does it come from?
|
||||
Why is displaying a reference line in plots showing residuals useful and important?
|
||||
|
||||
4. Why might you want to look at a frequency polygon of absolute residuals?
|
||||
What are the pros and cons compared to looking at the raw residuals?
|
||||
|
||||
## Formulas and model families
|
||||
|
||||
You've seen formulas before when using `facet_wrap()` and `facet_grid()`.
|
||||
In R, formulas provide a general way of getting "special behaviour".
|
||||
Rather than evaluating the values of the variables right away, they capture them so they can be interpreted by the function.
|
||||
|
||||
The majority of modelling functions in R use a standard conversion from formulas to functions.
|
||||
You've seen one simple conversion already: `y ~ x` is translated to `y = a_1 + a_2 * x`.
|
||||
If you want to see what R actually does, you can use the `model_matrix()` function.
|
||||
It takes a data frame and a formula and returns a tibble that defines the model equation: each column in the output is associated with one coefficient in the model, the function is always `y = a_1 * out_1 + a_2 * out_2`.
|
||||
For the simplest case of `y ~ x1` this shows us something interesting:
|
||||
|
||||
```{r}
|
||||
df <- tribble(
|
||||
~y, ~x1, ~x2,
|
||||
4, 2, 5,
|
||||
5, 1, 6
|
||||
)
|
||||
model_matrix(df, y ~ x1)
|
||||
```
|
||||
|
||||
The way that R adds the intercept to the model is just by having a column that is full of ones.
|
||||
By default, R will always add this column.
|
||||
If you don't want, you need to explicitly drop it with `-1`:
|
||||
|
||||
```{r}
|
||||
model_matrix(df, y ~ x1 - 1)
|
||||
```
|
||||
|
||||
The model matrix grows in an unsurprising way when you add more variables to the model:
|
||||
|
||||
```{r}
|
||||
model_matrix(df, y ~ x1 + x2)
|
||||
```
|
||||
|
||||
This formula notation is sometimes called "Wilkinson-Rogers notation", and was initially described in *Symbolic Description of Factorial Models for Analysis of Variance*, by G.
|
||||
N. Wilkinson and C.
|
||||
E. Rogers <https://www.jstor.org/stable/2346786>.
|
||||
It's worth digging up and reading the original paper if you'd like to understand the full details of the modelling algebra.
|
||||
|
||||
The following sections expand on how this formula notation works for categorical variables, interactions, and transformation.
|
||||
|
||||
### Categorical variables
|
||||
|
||||
Generating a function from a formula is straight forward when the predictor is continuous, but things get a bit more complicated when the predictor is categorical.
|
||||
Imagine you have a formula like `y ~ sex`, where sex could either be male or female.
|
||||
It doesn't make sense to convert that to a formula like `y = a_0 + a_1 * sex` because `sex` isn't a number - you can't multiply it!
|
||||
Instead what R does is convert it to `y = a_0 + a_1 * sex_male` where `sex_male` is one if `sex` is male and zero otherwise:
|
||||
|
||||
```{r}
|
||||
df <- tribble(
|
||||
~ sex, ~ response,
|
||||
"male", 1,
|
||||
"female", 2,
|
||||
"male", 1
|
||||
)
|
||||
model_matrix(df, response ~ sex)
|
||||
```
|
||||
|
||||
You might wonder why R also doesn't create a `sexfemale` column.
|
||||
The problem is that would create a column that is perfectly predictable based on the other columns (i.e. `sexfemale = 1 - sexmale`).
|
||||
Unfortunately the exact details of why this is a problem is beyond the scope of this book, but basically it creates a model family that is too flexible, and will have infinitely many models that are equally close to the data.
|
||||
|
||||
Fortunately, however, if you focus on visualising predictions you don't need to worry about the exact parameterisation.
|
||||
Let's look at some data and models to make that concrete.
|
||||
Here's the `sim2` dataset from modelr:
|
||||
|
||||
```{r}
|
||||
ggplot(sim2) +
|
||||
geom_point(aes(x, y))
|
||||
```
|
||||
|
||||
We can fit a model to it, and generate predictions:
|
||||
|
||||
```{r}
|
||||
mod2 <- lm(y ~ x, data = sim2)
|
||||
|
||||
grid <- sim2 |>
|
||||
data_grid(x) |>
|
||||
add_predictions(mod2)
|
||||
grid
|
||||
```
|
||||
|
||||
Effectively, a model with a categorical `x` will predict the mean value for each category.
|
||||
(Why? Because the mean minimises the root-mean-squared distance.) That's easy to see if we overlay the predictions on top of the original data:
|
||||
|
||||
```{r}
|
||||
ggplot(sim2, aes(x)) +
|
||||
geom_point(aes(y = y)) +
|
||||
geom_point(data = grid, aes(y = pred), colour = "red", size = 4)
|
||||
```
|
||||
|
||||
You can't make predictions about levels that you didn't observe.
|
||||
Sometimes you'll do this by accident so it's good to recognise this error message:
|
||||
|
||||
```{r, error = TRUE}
|
||||
tibble(x = "e") |>
|
||||
add_predictions(mod2)
|
||||
```
|
||||
|
||||
### Interactions (continuous and categorical)
|
||||
|
||||
What happens when you combine a continuous and a categorical variable?
|
||||
`sim3` contains a categorical predictor and a continuous predictor.
|
||||
We can visualise it with a simple plot:
|
||||
|
||||
```{r}
|
||||
ggplot(sim3, aes(x1, y)) +
|
||||
geom_point(aes(colour = x2))
|
||||
```
|
||||
|
||||
There are two possible models you could fit to this data:
|
||||
|
||||
```{r}
|
||||
mod1 <- lm(y ~ x1 + x2, data = sim3)
|
||||
mod2 <- lm(y ~ x1 * x2, data = sim3)
|
||||
```
|
||||
|
||||
When you add variables with `+`, the model will estimate each effect independent of all the others.
|
||||
It's possible to fit the so-called interaction by using `*`.
|
||||
For example, `y ~ x1 * x2` is translated to `y = a_0 + a_1 * x1 + a_2 * x2 + a_12 * x1 * x2`.
|
||||
Note that whenever you use `*`, both the interaction and the individual components are included in the model.
|
||||
|
||||
To visualise these models we need two new tricks:
|
||||
|
||||
1. We have two predictors, so we need to give `data_grid()` both variables.
|
||||
It finds all the unique values of `x1` and `x2` and then generates all combinations.
|
||||
|
||||
2. To generate predictions from both models simultaneously, we can use `gather_predictions()` which adds each prediction as a row.
|
||||
The complement of `gather_predictions()` is `spread_predictions()` which adds each prediction to a new column.
|
||||
|
||||
Together this gives us:
|
||||
|
||||
```{r}
|
||||
grid <- sim3 |>
|
||||
data_grid(x1, x2) |>
|
||||
gather_predictions(mod1, mod2)
|
||||
grid
|
||||
```
|
||||
|
||||
We can visualise the results for both models on one plot using facetting:
|
||||
|
||||
```{r}
|
||||
ggplot(sim3, aes(x1, y, colour = x2)) +
|
||||
geom_point() +
|
||||
geom_line(data = grid, aes(y = pred)) +
|
||||
facet_wrap(~ model)
|
||||
```
|
||||
|
||||
Note that the model that uses `+` has the same slope for each line, but different intercepts.
|
||||
The model that uses `*` has a different slope and intercept for each line.
|
||||
|
||||
Which model is better for this data?
|
||||
We can take look at the residuals.
|
||||
Here I've facetted by both model and `x2` because it makes it easier to see the pattern within each group.
|
||||
|
||||
```{r}
|
||||
sim3 <- sim3 |>
|
||||
gather_residuals(mod1, mod2)
|
||||
|
||||
ggplot(sim3, aes(x1, resid, colour = x2)) +
|
||||
geom_point() +
|
||||
facet_grid(model ~ x2)
|
||||
```
|
||||
|
||||
There is little obvious pattern in the residuals for `mod2`.
|
||||
The residuals for `mod1` show that the model has clearly missed some pattern in `b`, and less so, but still present is pattern in `c`, and `d`.
|
||||
You might wonder if there's a precise way to tell which of `mod1` or `mod2` is better.
|
||||
There is, but it requires a lot of mathematical background, and we don't really care.
|
||||
Here, we're interested in a qualitative assessment of whether or not the model has captured the pattern that we're interested in.
|
||||
|
||||
### Interactions (two continuous)
|
||||
|
||||
Let's take a look at the equivalent model for two continuous variables.
|
||||
Initially things proceed almost identically to the previous example:
|
||||
|
||||
```{r}
|
||||
mod1 <- lm(y ~ x1 + x2, data = sim4)
|
||||
mod2 <- lm(y ~ x1 * x2, data = sim4)
|
||||
|
||||
grid <- sim4 |>
|
||||
data_grid(
|
||||
x1 = seq_range(x1, 5),
|
||||
x2 = seq_range(x2, 5)
|
||||
) |>
|
||||
gather_predictions(mod1, mod2)
|
||||
grid
|
||||
```
|
||||
|
||||
Note my use of `seq_range()` inside `data_grid()`.
|
||||
Instead of using every unique value of `x`, I'm going to use a regularly spaced grid of five values between the minimum and maximum numbers.
|
||||
It's probably not super important here, but it's a useful technique in general.
|
||||
There are two other useful arguments to `seq_range()`:
|
||||
|
||||
- `pretty = TRUE` will generate a "pretty" sequence, i.e. something that looks nice to the human eye.
|
||||
This is useful if you want to produce tables of output:
|
||||
|
||||
```{r}
|
||||
seq_range(c(0.0123, 0.923423), n = 5)
|
||||
seq_range(c(0.0123, 0.923423), n = 5, pretty = TRUE)
|
||||
```
|
||||
|
||||
- `trim = 0.1` will trim off 10% of the tail values.
|
||||
This is useful if the variables have a long tailed distribution and you want to focus on generating values near the center:
|
||||
|
||||
```{r}
|
||||
x1 <- rcauchy(100)
|
||||
seq_range(x1, n = 5)
|
||||
seq_range(x1, n = 5, trim = 0.10)
|
||||
seq_range(x1, n = 5, trim = 0.25)
|
||||
seq_range(x1, n = 5, trim = 0.50)
|
||||
```
|
||||
|
||||
- `expand = 0.1` is in some sense the opposite of `trim()` it expands the range by 10%.
|
||||
|
||||
```{r}
|
||||
x2 <- c(0, 1)
|
||||
seq_range(x2, n = 5)
|
||||
seq_range(x2, n = 5, expand = 0.10)
|
||||
seq_range(x2, n = 5, expand = 0.25)
|
||||
seq_range(x2, n = 5, expand = 0.50)
|
||||
```
|
||||
|
||||
Next let's try and visualise that model.
|
||||
We have two continuous predictors, so you can imagine the model like a 3d surface.
|
||||
We could display that using `geom_tile()`:
|
||||
|
||||
```{r}
|
||||
ggplot(grid, aes(x1, x2)) +
|
||||
geom_tile(aes(fill = pred)) +
|
||||
facet_wrap(~ model)
|
||||
```
|
||||
|
||||
That doesn't suggest that the models are very different!
|
||||
But that's partly an illusion: our eyes and brains are not very good at accurately comparing shades of colour.
|
||||
Instead of looking at the surface from the top, we could look at it from either side, showing multiple slices:
|
||||
|
||||
```{r, asp = 1/2}
|
||||
ggplot(grid, aes(x1, pred, colour = x2, group = x2)) +
|
||||
geom_line() +
|
||||
facet_wrap(~ model)
|
||||
ggplot(grid, aes(x2, pred, colour = x1, group = x1)) +
|
||||
geom_line() +
|
||||
facet_wrap(~ model)
|
||||
```
|
||||
|
||||
This shows you that interaction between two continuous variables works basically the same way as for a categorical and continuous variable.
|
||||
An interaction says that there's not a fixed offset: you need to consider both values of `x1` and `x2` simultaneously in order to predict `y`.
|
||||
|
||||
You can see that even with just two continuous variables, coming up with good visualisations are hard.
|
||||
But that's reasonable: you shouldn't expect it will be easy to understand how three or more variables simultaneously interact!
|
||||
But again, we're saved a little because we're using models for exploration, and you can gradually build up your model over time.
|
||||
The model doesn't have to be perfect, it just has to help you reveal a little more about your data.
|
||||
|
||||
I spent some time looking at the residuals to see if I could figure if `mod2` did better than `mod1`.
|
||||
I think it does, but it's pretty subtle.
|
||||
You'll have a chance to work on it in the exercises.
|
||||
|
||||
### Transformations
|
||||
|
||||
You can also perform transformations inside the model formula.
|
||||
For example, `log(y) ~ sqrt(x1) + x2` is transformed to `log(y) = a_1 + a_2 * sqrt(x1) + a_3 * x2`.
|
||||
If your transformation involves `+`, `*`, `^`, or `-`, you'll need to wrap it in `I()` so R doesn't treat it like part of the model specification.
|
||||
For example, `y ~ x + I(x ^ 2)` is translated to `y = a_1 + a_2 * x + a_3 * x^2`.
|
||||
If you forget the `I()` and specify `y ~ x ^ 2 + x`, R will compute `y ~ x * x + x`.
|
||||
`x * x` means the interaction of `x` with itself, which is the same as `x`.
|
||||
R automatically drops redundant variables so `x + x` become `x`, meaning that `y ~ x ^ 2 + x` specifies the function `y = a_1 + a_2 * x`.
|
||||
That's probably not what you intended!
|
||||
|
||||
Again, if you get confused about what your model is doing, you can always use `model_matrix()` to see exactly what equation `lm()` is fitting:
|
||||
|
||||
```{r}
|
||||
df <- tribble(
|
||||
~y, ~x,
|
||||
1, 1,
|
||||
2, 2,
|
||||
3, 3
|
||||
)
|
||||
model_matrix(df, y ~ x^2 + x)
|
||||
model_matrix(df, y ~ I(x^2) + x)
|
||||
```
|
||||
|
||||
Transformations are useful because you can use them to approximate non-linear functions.
|
||||
If you've taken a calculus class, you may have heard of Taylor's theorem which says you can approximate any smooth function with an infinite sum of polynomials.
|
||||
That means you can use a polynomial function to get arbitrarily close to a smooth function by fitting an equation like `y = a_1 + a_2 * x + a_3 * x^2 + a_4 * x ^ 3`.
|
||||
Typing that sequence by hand is tedious, so R provides a helper function: `poly()`:
|
||||
|
||||
```{r}
|
||||
model_matrix(df, y ~ poly(x, 2))
|
||||
```
|
||||
|
||||
However there's one major problem with using `poly()`: outside the range of the data, polynomials rapidly shoot off to positive or negative infinity.
|
||||
One safer alternative is to use the natural spline, `splines::ns()`.
|
||||
|
||||
```{r}
|
||||
library(splines)
|
||||
model_matrix(df, y ~ ns(x, 2))
|
||||
```
|
||||
|
||||
Let's see what that looks like when we try and approximate a non-linear function:
|
||||
|
||||
```{r}
|
||||
sim5 <- tibble(
|
||||
x = seq(0, 3.5 * pi, length = 50),
|
||||
y = 4 * sin(x) + rnorm(length(x))
|
||||
)
|
||||
|
||||
ggplot(sim5, aes(x, y)) +
|
||||
geom_point()
|
||||
```
|
||||
|
||||
I'm going to fit five models to this data.
|
||||
|
||||
```{r}
|
||||
mod1 <- lm(y ~ ns(x, 1), data = sim5)
|
||||
mod2 <- lm(y ~ ns(x, 2), data = sim5)
|
||||
mod3 <- lm(y ~ ns(x, 3), data = sim5)
|
||||
mod4 <- lm(y ~ ns(x, 4), data = sim5)
|
||||
mod5 <- lm(y ~ ns(x, 5), data = sim5)
|
||||
|
||||
grid <- sim5 |>
|
||||
data_grid(x = seq_range(x, n = 50, expand = 0.1)) |>
|
||||
gather_predictions(mod1, mod2, mod3, mod4, mod5, .pred = "y")
|
||||
|
||||
ggplot(sim5, aes(x, y)) +
|
||||
geom_point() +
|
||||
geom_line(data = grid, colour = "red") +
|
||||
facet_wrap(~ model)
|
||||
```
|
||||
|
||||
Notice that the extrapolation outside the range of the data is clearly bad.
|
||||
This is the downside to approximating a function with a polynomial.
|
||||
But this is a very real problem with every model: the model can never tell you if the behaviour is true when you start extrapolating outside the range of the data that you have seen.
|
||||
You must rely on theory and science.
|
||||
|
||||
### Exercises
|
||||
|
||||
1. What happens if you repeat the analysis of `sim2` using a model without an intercept.
|
||||
What happens to the model equation?
|
||||
What happens to the predictions?
|
||||
|
||||
2. Use `model_matrix()` to explore the equations generated for the models I fit to `sim3` and `sim4`.
|
||||
Why is `*` a good shorthand for interaction?
|
||||
|
||||
3. Using the basic principles, convert the formulas in the following two models into functions.
|
||||
(Hint: start by converting the categorical variable into 0-1 variables.)
|
||||
|
||||
```{r, eval = FALSE}
|
||||
mod1 <- lm(y ~ x1 + x2, data = sim3)
|
||||
mod2 <- lm(y ~ x1 * x2, data = sim3)
|
||||
```
|
||||
|
||||
4. For `sim4`, which of `mod1` and `mod2` is better?
|
||||
I think `mod2` does a slightly better job at removing patterns, but it's pretty subtle.
|
||||
Can you come up with a plot to support my claim?
|
||||
|
||||
## Missing values
|
||||
|
||||
Missing values obviously can not convey any information about the relationship between the variables, so modelling functions will drop any rows that contain missing values.
|
||||
R's default behaviour is to silently drop them, but `options(na.action = na.warn)` (run in the prerequisites), makes sure you get a warning.
|
||||
|
||||
```{r}
|
||||
df <- tribble(
|
||||
~x, ~y,
|
||||
1, 2.2,
|
||||
2, NA,
|
||||
3, 3.5,
|
||||
4, 8.3,
|
||||
NA, 10
|
||||
)
|
||||
|
||||
mod <- lm(y ~ x, data = df)
|
||||
```
|
||||
|
||||
To suppress the warning, set `na.action = na.exclude`:
|
||||
|
||||
```{r}
|
||||
mod <- lm(y ~ x, data = df, na.action = na.exclude)
|
||||
```
|
||||
|
||||
You can always see exactly how many observations were used with `nobs()`:
|
||||
|
||||
```{r}
|
||||
nobs(mod)
|
||||
```
|
||||
|
||||
## Other model families
|
||||
|
||||
This chapter has focussed exclusively on the class of linear models, which assume a relationship of the form `y = a_1 * x1 + a_2 * x2 + ... + a_n * xn`.
|
||||
Linear models additionally assume that the residuals have a normal distribution, which we haven't talked about.
|
||||
There are a large set of model classes that extend the linear model in various interesting ways.
|
||||
Some of them are:
|
||||
|
||||
- **Generalised linear models**, e.g. `stats::glm()`.
|
||||
Linear models assume that the response is continuous and the error has a normal distribution.
|
||||
Generalised linear models extend linear models to include non-continuous responses (e.g. binary data or counts).
|
||||
They work by defining a distance metric based on the statistical idea of likelihood.
|
||||
|
||||
- **Generalised additive models**, e.g. `mgcv::gam()`, extend generalised linear models to incorporate arbitrary smooth functions.
|
||||
That means you can write a formula like `y ~ s(x)` which becomes an equation like `y = f(x)` and let `gam()` estimate what that function is (subject to some smoothness constraints to make the problem tractable).
|
||||
|
||||
- **Penalised linear models**, e.g. `glmnet::glmnet()`, add a penalty term to the distance that penalises complex models (as defined by the distance between the parameter vector and the origin).
|
||||
This tends to make models that generalise better to new datasets from the same population.
|
||||
|
||||
- **Robust linear models**, e.g.
|
||||
`MASS::rlm()`, tweak the distance to downweight points that are very far away.
|
||||
This makes them less sensitive to the presence of outliers, at the cost of being not quite as good when there are no outliers.
|
||||
|
||||
- **Trees**, e.g. `rpart::rpart()`, attack the problem in a completely different way than linear models.
|
||||
They fit a piece-wise constant model, splitting the data into progressively smaller and smaller pieces.
|
||||
Trees aren't terribly effective by themselves, but they are very powerful when used in aggregate by models like **random forests** (e.g. `randomForest::randomForest()`) or **gradient boosting machines** (e.g. `xgboost::xgboost`.)
|
||||
|
||||
These models all work similarly from a programming perspective.
|
||||
Once you've mastered linear models, you should find it easy to master the mechanics of these other model classes.
|
||||
Being a skilled modeller is a mixture of some good general principles and having a big toolbox of techniques.
|
||||
Now that you've learned some general tools and one useful class of models, you can go on and learn more classes from other sources.
|
|
@ -1,495 +0,0 @@
|
|||
# Model building
|
||||
|
||||
## Introduction
|
||||
|
||||
In the previous chapter you learned how linear models work, and learned some basic tools for understanding what a model is telling you about your data.
|
||||
The previous chapter focussed on simulated datasets.
|
||||
This chapter will focus on real data, showing you how you can progressively build up a model to aid your understanding of the data.
|
||||
|
||||
We will take advantage of the fact that you can think about a model partitioning your data into pattern and residuals.
|
||||
We'll find patterns with visualisation, then make them concrete and precise with a model.
|
||||
We'll then repeat the process, but replace the old response variable with the residuals from the model.
|
||||
The goal is to transition from implicit knowledge in the data and your head to explicit knowledge in a quantitative model.
|
||||
This makes it easier to apply to new domains, and easier for others to use.
|
||||
|
||||
For very large and complex datasets this will be a lot of work.
|
||||
There are certainly alternative approaches - a more machine learning approach is simply to focus on the predictive ability of the model.
|
||||
These approaches tend to produce black boxes: the model does a really good job at generating predictions, but you don't know why.
|
||||
This is a totally reasonable approach, but it does make it hard to apply your real world knowledge to the model.
|
||||
That, in turn, makes it difficult to assess whether or not the model will continue to work in the long-term, as fundamentals change.
|
||||
For most real models, I'd expect you to use some combination of this approach and a more classic automated approach.
|
||||
|
||||
It's a challenge to know when to stop.
|
||||
You need to figure out when your model is good enough, and when additional investment is unlikely to pay off.
|
||||
I particularly like this quote from reddit user Broseidon241:
|
||||
|
||||
> A long time ago in art class, my teacher told me "An artist needs to know when a piece is done. You can't tweak something into perfection - wrap it up. If you don't like it, do it over again. Otherwise begin something new".
|
||||
> Later in life, I heard "A poor seamstress makes many mistakes. A good seamstress works hard to correct those mistakes. A great seamstress isn't afraid to throw out the garment and start over."
|
||||
>
|
||||
> -- Broseidon241, <https://www.reddit.com/r/datascience/comments/4irajq>
|
||||
|
||||
### Prerequisites
|
||||
|
||||
We'll use the same tools as in the previous chapter, but add in some real datasets: `diamonds` from ggplot2, and `flights` from nycflights13.
|
||||
We'll also need lubridate in order to work with the date/times in `flights`.
|
||||
|
||||
```{r setup, message = FALSE}
|
||||
library(tidyverse)
|
||||
library(modelr)
|
||||
options(na.action = na.warn)
|
||||
|
||||
library(nycflights13)
|
||||
library(lubridate)
|
||||
```
|
||||
|
||||
## Why are low quality diamonds more expensive? {#diamond-prices}
|
||||
|
||||
In previous chapters we've seen a surprising relationship between the quality of diamonds and their price: low quality diamonds (poor cuts, bad colours, and inferior clarity) have higher prices.
|
||||
|
||||
```{r dev = "png"}
|
||||
ggplot(diamonds, aes(cut, price)) + geom_boxplot()
|
||||
ggplot(diamonds, aes(color, price)) + geom_boxplot()
|
||||
ggplot(diamonds, aes(clarity, price)) + geom_boxplot()
|
||||
```
|
||||
|
||||
Note that the worst diamond color is J (slightly yellow), and the worst clarity is I1 (inclusions visible to the naked eye).
|
||||
|
||||
### Price and carat
|
||||
|
||||
It looks like lower quality diamonds have higher prices because there is an important confounding variable: the weight (`carat`) of the diamond.
|
||||
The weight of the diamond is the single most important factor for determining the price of the diamond, and lower quality diamonds tend to be larger.
|
||||
|
||||
```{r}
|
||||
ggplot(diamonds, aes(carat, price)) +
|
||||
geom_hex(bins = 50)
|
||||
```
|
||||
|
||||
We can make it easier to see how the other attributes of a diamond affect its relative `price` by fitting a model to separate out the effect of `carat`.
|
||||
But first, lets make a couple of tweaks to the diamonds dataset to make it easier to work with:
|
||||
|
||||
1. Focus on diamonds smaller than 2.5 carats (99.7% of the data)
|
||||
2. Log-transform the carat and price variables.
|
||||
|
||||
```{r}
|
||||
diamonds2 <- diamonds |>
|
||||
filter(carat <= 2.5) |>
|
||||
mutate(lprice = log2(price), lcarat = log2(carat))
|
||||
```
|
||||
|
||||
Together, these changes make it easier to see the relationship between `carat` and `price`:
|
||||
|
||||
```{r}
|
||||
ggplot(diamonds2, aes(lcarat, lprice)) +
|
||||
geom_hex(bins = 50)
|
||||
```
|
||||
|
||||
The log-transformation is particularly useful here because it makes the pattern linear, and linear patterns are the easiest to work with.
|
||||
Let's take the next step and remove that strong linear pattern.
|
||||
We first make the pattern explicit by fitting a model:
|
||||
|
||||
```{r}
|
||||
mod_diamond <- lm(lprice ~ lcarat, data = diamonds2)
|
||||
```
|
||||
|
||||
Then we look at what the model tells us about the data.
|
||||
Note that I back transform the predictions, undoing the log transformation, so I can overlay the predictions on the raw data:
|
||||
|
||||
```{r}
|
||||
grid <- diamonds2 |>
|
||||
data_grid(carat = seq_range(carat, 20)) |>
|
||||
mutate(lcarat = log2(carat)) |>
|
||||
add_predictions(mod_diamond, "lprice") |>
|
||||
mutate(price = 2 ^ lprice)
|
||||
|
||||
ggplot(diamonds2, aes(carat, price)) +
|
||||
geom_hex(bins = 50) +
|
||||
geom_line(data = grid, colour = "red", size = 1)
|
||||
```
|
||||
|
||||
That tells us something interesting about our data.
|
||||
If we believe our model, then the large diamonds are much cheaper than expected.
|
||||
This is probably because no diamond in this dataset costs more than \$19,000.
|
||||
|
||||
Now we can look at the residuals, which verifies that we've successfully removed the strong linear pattern:
|
||||
|
||||
```{r}
|
||||
diamonds2 <- diamonds2 |>
|
||||
add_residuals(mod_diamond, "lresid")
|
||||
|
||||
ggplot(diamonds2, aes(lcarat, lresid)) +
|
||||
geom_hex(bins = 50)
|
||||
```
|
||||
|
||||
Importantly, we can now re-do our motivating plots using those residuals instead of `price`.
|
||||
|
||||
```{r dev = "png"}
|
||||
ggplot(diamonds2, aes(cut, lresid)) + geom_boxplot()
|
||||
ggplot(diamonds2, aes(color, lresid)) + geom_boxplot()
|
||||
ggplot(diamonds2, aes(clarity, lresid)) + geom_boxplot()
|
||||
```
|
||||
|
||||
Now we see the relationship we expect: as the quality of the diamond increases, so too does its relative price.
|
||||
To interpret the `y` axis, we need to think about what the residuals are telling us, and what scale they are on.
|
||||
A residual of -1 indicates that `lprice` was 1 unit lower than a prediction based solely on its weight.
|
||||
$2^{-1}$ is 1/2, points with a value of -1 are half the expected price, and residuals with value 1 are twice the predicted price.
|
||||
|
||||
### A more complicated model
|
||||
|
||||
If we wanted to, we could continue to build up our model, moving the effects we've observed into the model to make them explicit.
|
||||
For example, we could include `color`, `cut`, and `clarity` into the model so that we also make explicit the effect of these three categorical variables:
|
||||
|
||||
```{r}
|
||||
mod_diamond2 <- lm(lprice ~ lcarat + color + cut + clarity, data = diamonds2)
|
||||
```
|
||||
|
||||
This model now includes four predictors, so it's getting harder to visualise.
|
||||
Fortunately, they're currently all independent which means that we can plot them individually in four plots.
|
||||
To make the process a little easier, we're going to use the `.model` argument to `data_grid`:
|
||||
|
||||
```{r}
|
||||
grid <- diamonds2 |>
|
||||
data_grid(cut, .model = mod_diamond2) |>
|
||||
add_predictions(mod_diamond2)
|
||||
grid
|
||||
|
||||
ggplot(grid, aes(cut, pred)) +
|
||||
geom_point()
|
||||
```
|
||||
|
||||
If the model needs variables that you haven't explicitly supplied, `data_grid()` will automatically fill them in with "typical" value.
|
||||
For continuous variables, it uses the median, and categorical variables it uses the most common value (or values, if there's a tie).
|
||||
|
||||
```{r}
|
||||
diamonds2 <- diamonds2 |>
|
||||
add_residuals(mod_diamond2, "lresid2")
|
||||
|
||||
ggplot(diamonds2, aes(lcarat, lresid2)) +
|
||||
geom_hex(bins = 50)
|
||||
```
|
||||
|
||||
This plot indicates that there are some diamonds with quite large residuals - remember a residual of 2 indicates that the diamond is 4x the price that we expected.
|
||||
It's often useful to look at unusual values individually:
|
||||
|
||||
```{r}
|
||||
diamonds2 |>
|
||||
filter(abs(lresid2) > 1) |>
|
||||
add_predictions(mod_diamond2) |>
|
||||
mutate(pred = round(2 ^ pred)) |>
|
||||
select(price, pred, carat:table, x:z) |>
|
||||
arrange(price)
|
||||
```
|
||||
|
||||
Nothing really jumps out at me here, but it's probably worth spending time considering if this indicates a problem with our model, or if there are errors in the data.
|
||||
If there are mistakes in the data, this could be an opportunity to buy diamonds that have been priced low incorrectly.
|
||||
|
||||
### Exercises
|
||||
|
||||
1. In the plot of `lcarat` vs. `lprice`, there are some bright vertical strips.
|
||||
What do they represent?
|
||||
|
||||
2. If `log(price) = a_0 + a_1 * log(carat)`, what does that say about the relationship between `price` and `carat`?
|
||||
|
||||
3. Extract the diamonds that have very high and very low residuals.
|
||||
Is there anything unusual about these diamonds?
|
||||
Are they particularly bad or good, or do you think these are pricing errors?
|
||||
|
||||
4. Does the final model, `mod_diamond2`, do a good job of predicting diamond prices?
|
||||
Would you trust it to tell you how much to spend if you were buying a diamond?
|
||||
|
||||
## What affects the number of daily flights?
|
||||
|
||||
Let's work through a similar process for a dataset that seems even simpler at first glance: the number of flights that leave NYC per day.
|
||||
This is a really small dataset --- only 365 rows and 2 columns --- and we're not going to end up with a fully realised model, but as you'll see, the steps along the way will help us better understand the data.
|
||||
Let's get started by counting the number of flights per day and visualising it with ggplot2.
|
||||
|
||||
```{r}
|
||||
daily <- flights |>
|
||||
mutate(date = make_date(year, month, day)) |>
|
||||
group_by(date) |>
|
||||
summarise(n = n())
|
||||
daily
|
||||
|
||||
ggplot(daily, aes(date, n)) +
|
||||
geom_line()
|
||||
```
|
||||
|
||||
### Day of week
|
||||
|
||||
Understanding the long-term trend is challenging because there's a very strong day-of-week effect that dominates the subtler patterns.
|
||||
Let's start by looking at the distribution of flight numbers by day-of-week:
|
||||
|
||||
```{r}
|
||||
daily <- daily |>
|
||||
mutate(wday = wday(date, label = TRUE))
|
||||
ggplot(daily, aes(wday, n)) +
|
||||
geom_boxplot()
|
||||
```
|
||||
|
||||
There are fewer flights on weekends because most travel is for business.
|
||||
The effect is particularly pronounced on Saturday: you might sometimes leave on Sunday for a Monday morning meeting, but it's very rare that you'd leave on Saturday as you'd much rather be at home with your family.
|
||||
|
||||
One way to remove this strong pattern is to use a model.
|
||||
First, we fit the model, and display its predictions overlaid on the original data:
|
||||
|
||||
```{r}
|
||||
mod <- lm(n ~ wday, data = daily)
|
||||
|
||||
grid <- daily |>
|
||||
data_grid(wday) |>
|
||||
add_predictions(mod, "n")
|
||||
|
||||
ggplot(daily, aes(wday, n)) +
|
||||
geom_boxplot() +
|
||||
geom_point(data = grid, colour = "red", size = 4)
|
||||
```
|
||||
|
||||
Next we compute and visualise the residuals:
|
||||
|
||||
```{r}
|
||||
daily <- daily |>
|
||||
add_residuals(mod)
|
||||
daily |>
|
||||
ggplot(aes(date, resid)) +
|
||||
geom_ref_line(h = 0) +
|
||||
geom_line()
|
||||
```
|
||||
|
||||
Note the change in the y-axis: now we are seeing the deviation from the expected number of flights, given the day of week.
|
||||
This plot is useful because now that we've removed much of the large day-of-week effect, we can see some of the subtler patterns that remain:
|
||||
|
||||
1. Our model seems to fail starting in June: you can still see a strong regular pattern that our model hasn't captured.
|
||||
Drawing a plot with one line for each day of the week makes the cause easier to see:
|
||||
|
||||
```{r}
|
||||
ggplot(daily, aes(date, resid, colour = wday)) +
|
||||
geom_ref_line(h = 0) +
|
||||
geom_line()
|
||||
```
|
||||
|
||||
Our model fails to accurately predict the number of flights on Saturday: during summer there are more flights than we expect, and during fall there are fewer.
|
||||
We'll see how we can do better to capture this pattern in the next section.
|
||||
|
||||
2. There are some days with far fewer flights than expected:
|
||||
|
||||
```{r}
|
||||
daily |>
|
||||
filter(resid < -100)
|
||||
```
|
||||
|
||||
If you're familiar with American public holidays, you might spot New Year's day, July 4th, Thanksgiving and Christmas.
|
||||
There are some others that don't seem to correspond to public holidays.
|
||||
You'll work on those in one of the exercises.
|
||||
|
||||
3. There seems to be some smoother long term trend over the course of a year.
|
||||
We can highlight that trend with `geom_smooth()`:
|
||||
|
||||
```{r}
|
||||
daily |>
|
||||
ggplot(aes(date, resid)) +
|
||||
geom_ref_line(h = 0) +
|
||||
geom_line(colour = "grey50") +
|
||||
geom_smooth(se = FALSE, span = 0.20)
|
||||
```
|
||||
|
||||
There are fewer flights in January (and December), and more in summer (May-Sep).
|
||||
We can't do much with this pattern quantitatively, because we only have a single year of data.
|
||||
But we can use our domain knowledge to brainstorm potential explanations.
|
||||
|
||||
### Seasonal Saturday effect
|
||||
|
||||
Let's first tackle our failure to accurately predict the number of flights on Saturday.
|
||||
A good place to start is to go back to the raw numbers, focussing on Saturdays:
|
||||
|
||||
```{r}
|
||||
daily |>
|
||||
filter(wday == "Sat") |>
|
||||
ggplot(aes(date, n)) +
|
||||
geom_point() +
|
||||
geom_line() +
|
||||
scale_x_date(NULL, date_breaks = "1 month", date_labels = "%b")
|
||||
```
|
||||
|
||||
(I've used both points and lines to make it more clear what is data and what is interpolation.)
|
||||
|
||||
I suspect this pattern is caused by summer holidays: many people go on holiday in the summer, and people don't mind travelling on Saturdays for vacation.
|
||||
Looking at this plot, we might guess that summer holidays are from early June to late August.
|
||||
That seems to line up fairly well with the [state's school terms](http://schools.nyc.gov/Calendar/2013-2014+School+Year+Calendars.htm): summer break in 2013 was Jun 26--Sep 9.
|
||||
|
||||
Why are there more Saturday flights in spring than fall?
|
||||
I asked some American friends and they suggested that it's less common to plan family vacations during fall because of the big Thanksgiving and Christmas holidays.
|
||||
We don't have the data to know for sure, but it seems like a plausible working hypothesis.
|
||||
|
||||
Lets create a "term" variable that roughly captures the three school terms, and check our work with a plot:
|
||||
|
||||
```{r}
|
||||
term <- function(date) {
|
||||
cut(date,
|
||||
breaks = ymd(20130101, 20130605, 20130825, 20140101),
|
||||
labels = c("spring", "summer", "fall")
|
||||
)
|
||||
}
|
||||
|
||||
daily <- daily |>
|
||||
mutate(term = term(date))
|
||||
|
||||
daily |>
|
||||
filter(wday == "Sat") |>
|
||||
ggplot(aes(date, n, colour = term)) +
|
||||
geom_point(alpha = 1/3) +
|
||||
geom_line() +
|
||||
scale_x_date(NULL, date_breaks = "1 month", date_labels = "%b")
|
||||
```
|
||||
|
||||
(I manually tweaked the dates to get nice breaks in the plot. Using a visualisation to help you understand what your function is doing is a really powerful and general technique.)
|
||||
|
||||
It's useful to see how this new variable affects the other days of the week:
|
||||
|
||||
```{r}
|
||||
daily |>
|
||||
ggplot(aes(wday, n, colour = term)) +
|
||||
geom_boxplot()
|
||||
```
|
||||
|
||||
It looks like there is significant variation across the terms, so fitting a separate day of week effect for each term is reasonable.
|
||||
This improves our model, but not as much as we might hope:
|
||||
|
||||
```{r}
|
||||
mod1 <- lm(n ~ wday, data = daily)
|
||||
mod2 <- lm(n ~ wday * term, data = daily)
|
||||
|
||||
daily |>
|
||||
gather_residuals(without_term = mod1, with_term = mod2) |>
|
||||
ggplot(aes(date, resid, colour = model)) +
|
||||
geom_line(alpha = 0.75)
|
||||
```
|
||||
|
||||
We can see the problem by overlaying the predictions from the model on to the raw data:
|
||||
|
||||
```{r}
|
||||
grid <- daily |>
|
||||
data_grid(wday, term) |>
|
||||
add_predictions(mod2, "n")
|
||||
|
||||
ggplot(daily, aes(wday, n)) +
|
||||
geom_boxplot() +
|
||||
geom_point(data = grid, colour = "red") +
|
||||
facet_wrap(~ term)
|
||||
```
|
||||
|
||||
Our model is finding the *mean* effect, but we have a lot of big outliers, so mean tends to be far away from the typical value.
|
||||
We can alleviate this problem by using a model that is robust to the effect of outliers: `MASS::rlm()`.
|
||||
This greatly reduces the impact of the outliers on our estimates, and gives a model that does a good job of removing the day of week pattern:
|
||||
|
||||
```{r, warn = FALSE}
|
||||
mod3 <- MASS::rlm(n ~ wday * term, data = daily)
|
||||
|
||||
daily |>
|
||||
add_residuals(mod3, "resid") |>
|
||||
ggplot(aes(date, resid)) +
|
||||
geom_hline(yintercept = 0, size = 2, colour = "white") +
|
||||
geom_line()
|
||||
```
|
||||
|
||||
It's now much easier to see the long-term trend, and the positive and negative outliers.
|
||||
|
||||
### Computed variables
|
||||
|
||||
If you're experimenting with many models and many visualisations, it's a good idea to bundle the creation of variables up into a function so there's no chance of accidentally applying a different transformation in different places.
|
||||
For example, we could write:
|
||||
|
||||
```{r}
|
||||
compute_vars <- function(data) {
|
||||
data |>
|
||||
mutate(
|
||||
term = term(date),
|
||||
wday = wday(date, label = TRUE)
|
||||
)
|
||||
}
|
||||
```
|
||||
|
||||
Another option is to put the transformations directly in the model formula:
|
||||
|
||||
```{r}
|
||||
wday2 <- function(x) wday(x, label = TRUE)
|
||||
mod3 <- lm(n ~ wday2(date) * term(date), data = daily)
|
||||
```
|
||||
|
||||
Either approach is reasonable.
|
||||
Making the transformed variable explicit is useful if you want to check your work, or use them in a visualisation.
|
||||
But you can't easily use transformations (like splines) that return multiple columns.
|
||||
Including the transformations in the model function makes life a little easier when you're working with many different datasets because the model is self contained.
|
||||
|
||||
### Time of year: an alternative approach
|
||||
|
||||
In the previous section we used our domain knowledge (how the US school term affects travel) to improve the model.
|
||||
An alternative to using our knowledge explicitly in the model is to give the data more room to speak.
|
||||
We could use a more flexible model and allow that to capture the pattern we're interested in.
|
||||
A simple linear trend isn't adequate, so we could try using a natural spline to fit a smooth curve across the year:
|
||||
|
||||
```{r}
|
||||
library(splines)
|
||||
mod <- MASS::rlm(n ~ wday * ns(date, 5), data = daily)
|
||||
|
||||
daily |>
|
||||
data_grid(wday, date = seq_range(date, n = 13)) |>
|
||||
add_predictions(mod) |>
|
||||
ggplot(aes(date, pred, colour = wday)) +
|
||||
geom_line() +
|
||||
geom_point()
|
||||
```
|
||||
|
||||
We see a strong pattern in the numbers of Saturday flights.
|
||||
This is reassuring, because we also saw that pattern in the raw data.
|
||||
It's a good sign when you get the same signal from different approaches.
|
||||
|
||||
### Exercises
|
||||
|
||||
1. Use your Google sleuthing skills to brainstorm why there were fewer than expected flights on Jan 20, May 26, and Sep 1.
|
||||
(Hint: they all have the same explanation.) How would these days generalise to another year?
|
||||
|
||||
2. What do the three days with high positive residuals represent?
|
||||
How would these days generalise to another year?
|
||||
|
||||
```{r}
|
||||
daily |>
|
||||
slice_max(n = 3, resid)
|
||||
```
|
||||
|
||||
3. Create a new variable that splits the `wday` variable into terms, but only for Saturdays, i.e. it should have `Thurs`, `Fri`, but `Sat-summer`, `Sat-spring`, `Sat-fall`.
|
||||
How does this model compare with the model with every combination of `wday` and `term`?
|
||||
|
||||
4. Create a new `wday` variable that combines the day of week, term (for Saturdays), and public holidays.
|
||||
What do the residuals of that model look like?
|
||||
|
||||
5. What happens if you fit a day of week effect that varies by month (i.e. `n ~ wday * month`)?
|
||||
Why is this not very helpful?
|
||||
|
||||
6. What would you expect the model `n ~ wday + ns(date, 5)` to look like?
|
||||
Knowing what you know about the data, why would you expect it to be not particularly effective?
|
||||
|
||||
7. We hypothesised that people leaving on Sundays are more likely to be business travellers who need to be somewhere on Monday.
|
||||
Explore that hypothesis by seeing how it breaks down based on distance and time: if it's true, you'd expect to see more Sunday evening flights to places that are far away.
|
||||
|
||||
8. It's a little frustrating that Sunday and Saturday are on separate ends of the plot.
|
||||
Write a small function to set the levels of the factor so that the week starts on Monday.
|
||||
|
||||
## Learning more about models
|
||||
|
||||
We have only scratched the absolute surface of modelling, but you have hopefully gained some simple, but general-purpose tools that you can use to improve your own data analyses.
|
||||
It's OK to start simple!
|
||||
As you've seen, even very simple models can make a dramatic difference in your ability to tease out interactions between variables.
|
||||
|
||||
These modelling chapters are even more opinionated than the rest of the book.
|
||||
I approach modelling from a somewhat different perspective to most others, and there is relatively little space devoted to it.
|
||||
Modelling really deserves a book on its own, so I'd highly recommend that you read at least one of these three books:
|
||||
|
||||
- *Statistical Modeling: A Fresh Approach* by Danny Kaplan, <https://dtkaplan.github.io/SM2-bookdown/>.
|
||||
This book provides a gentle introduction to modelling, where you build your intuition, mathematical tools, and R skills in parallel.
|
||||
The book replaces a traditional "introduction to statistics" course, providing a curriculum that is up-to-date and relevant to data science.
|
||||
|
||||
- *An Introduction to Statistical Learning* by Gareth James, Daniela Witten, Trevor Hastie, and Robert Tibshirani, <https://www.statlearning.com> (available online for free).
|
||||
This book presents a family of modern modelling techniques collectively known as statistical learning.
|
||||
For an even deeper understanding of the math behind the models, read the classic *Elements of Statistical Learning* by Trevor Hastie, Robert Tibshirani, and Jerome Friedman, <https://web.stanford.edu/~hastie/Papers/ESLII.pdf> (also available online for free).
|
||||
|
||||
- *Applied Predictive Modeling* by Max Kuhn and Kjell Johnson, <http://appliedpredictivemodeling.com>.
|
||||
This book is a companion to the **caret** package and provides practical tools for dealing with real-life predictive modelling challenges.
|
|
@ -1,622 +0,0 @@
|
|||
# Many models
|
||||
|
||||
## Introduction
|
||||
|
||||
In this chapter you're going to learn three powerful ideas that help you to work with large numbers of models with ease:
|
||||
|
||||
1. Using many simple models to better understand complex datasets.
|
||||
|
||||
2. Using list-columns to store arbitrary data structures in a data frame.
|
||||
For example, this will allow you to have a column that contains linear models.
|
||||
|
||||
3. Using the **broom** package, by David Robinson, to turn models into tidy data.
|
||||
This is a powerful technique for working with large numbers of models because once you have tidy data, you can apply all of the techniques that you've learned about earlier in the book.
|
||||
|
||||
We'll start by diving into a motivating example using data about life expectancy around the world.
|
||||
It's a small dataset but it illustrates how important modelling can be for improving your visualisations.
|
||||
We'll use a large number of simple models to partition out some of the strongest signals so we can see the subtler signals that remain.
|
||||
We'll also see how model summaries can help us pick out outliers and unusual trends.
|
||||
|
||||
The following sections will dive into more detail about the individual techniques:
|
||||
|
||||
1. In [list-columns](#list-columns-1), you'll learn more about the list-column data structure, and why it's valid to put lists in data frames.
|
||||
|
||||
2. In [creating list-columns], you'll learn the three main ways in which you'll create list-columns.
|
||||
|
||||
3. In [simplifying list-columns] you'll learn how to convert list-columns back to regular atomic vectors (or sets of atomic vectors) so you can work with them more easily.
|
||||
|
||||
4. In [making tidy data with broom], you'll learn about the full set of tools provided by broom, and see how they can be applied to other types of data structure.
|
||||
|
||||
This chapter is somewhat aspirational: if this book is your first introduction to R, this chapter is likely to be a struggle.
|
||||
It requires you to have deeply internalised ideas about modelling, data structures, and iteration.
|
||||
So don't worry if you don't get it --- just put this chapter aside for a few months, and come back when you want to stretch your brain.
|
||||
|
||||
### Prerequisites
|
||||
|
||||
Working with many models requires many of the packages of the tidyverse (for data exploration, wrangling, and programming) and modelr to facilitate modelling.
|
||||
|
||||
```{r setup, message = FALSE}
|
||||
library(modelr)
|
||||
library(tidyverse)
|
||||
```
|
||||
|
||||
## gapminder
|
||||
|
||||
To motivate the power of many simple models, we're going to look into the "gapminder" data.
|
||||
This data was popularised by Hans Rosling, a Swedish doctor and statistician.
|
||||
If you've never heard of him, stop reading this chapter right now and go watch one of his videos!
|
||||
He is a fantastic data presenter and illustrates how you can use data to present a compelling story.
|
||||
A good place to start is this short video filmed in conjunction with the BBC: <https://www.youtube.com/watch?v=jbkSRLYSojo>.
|
||||
|
||||
The gapminder data summarises the progression of countries over time, looking at statistics like life expectancy and GDP.
|
||||
The data is easy to access in R, thanks to Jenny Bryan who created the gapminder package:
|
||||
|
||||
```{r}
|
||||
library(gapminder)
|
||||
gapminder
|
||||
```
|
||||
|
||||
In this case study, we're going to focus on just three variables to answer the question "How does life expectancy (`lifeExp`) change over time (`year`) for each country (`country`)?".
|
||||
A good place to start is with a plot:
|
||||
|
||||
```{r}
|
||||
gapminder |>
|
||||
ggplot(aes(year, lifeExp, group = country)) +
|
||||
geom_line(alpha = 1/3)
|
||||
```
|
||||
|
||||
This is a small dataset: it only has \~1,700 observations and 3 variables.
|
||||
But it's still hard to see what's going on!
|
||||
Overall, it looks like life expectancy has been steadily improving.
|
||||
However, if you look closely, you might notice some countries that don't follow this pattern.
|
||||
How can we make those countries easier to see?
|
||||
|
||||
One way is to use the same approach as in the last chapter: there's a strong signal (overall linear growth) that makes it hard to see subtler trends.
|
||||
We'll tease these factors apart by fitting a model with a linear trend.
|
||||
The model captures steady growth over time, and the residuals will show what's left.
|
||||
|
||||
You already know how to do that if we had a single country:
|
||||
|
||||
```{r, out.width = "33%", fig.asp = 1, fig.width = 3, fig.align='default'}
|
||||
nz <- filter(gapminder, country == "New Zealand")
|
||||
nz |>
|
||||
ggplot(aes(year, lifeExp)) +
|
||||
geom_line() +
|
||||
ggtitle("Full data = ")
|
||||
|
||||
nz_mod <- lm(lifeExp ~ year, data = nz)
|
||||
nz |>
|
||||
add_predictions(nz_mod) |>
|
||||
ggplot(aes(year, pred)) +
|
||||
geom_line() +
|
||||
ggtitle("Linear trend + ")
|
||||
|
||||
nz |>
|
||||
add_residuals(nz_mod) |>
|
||||
ggplot(aes(year, resid)) +
|
||||
geom_hline(yintercept = 0, colour = "white", size = 3) +
|
||||
geom_line() +
|
||||
ggtitle("Remaining pattern")
|
||||
```
|
||||
|
||||
How can we easily fit that model to every country?
|
||||
|
||||
### Nested data
|
||||
|
||||
You could imagine copy and pasting that code multiple times; but you've already learned a better way!
|
||||
Extract out the common code with a function and repeat using a map function from purrr.
|
||||
This problem is structured a little differently to what you've seen before.
|
||||
Instead of repeating an action for each variable, we want to repeat an action for each country, a subset of rows.
|
||||
To do that, we need a new data structure: the **nested data frame**.
|
||||
To create a nested data frame we start with a grouped data frame, and "nest" it:
|
||||
|
||||
```{r}
|
||||
by_country <- gapminder |>
|
||||
group_by(country, continent) |>
|
||||
nest()
|
||||
|
||||
by_country
|
||||
```
|
||||
|
||||
(I'm cheating a little by grouping on both `continent` and `country`. Given `country`, `continent` is fixed, so this doesn't add any more groups, but it's an easy way to carry an extra variable along for the ride.)
|
||||
|
||||
This creates a data frame that has one row per group (per country), and a rather unusual column: `data`.
|
||||
`data` is a list of data frames (or tibbles, to be precise).
|
||||
This seems like a crazy idea: we have a data frame with a column that is a list of other data frames!
|
||||
I'll explain shortly why I think this is a good idea.
|
||||
|
||||
The `data` column is a little tricky to look at because it's a moderately complicated list, and we're still working on good tools to explore these objects.
|
||||
Unfortunately using `str()` is not recommended as it will often produce very long output.
|
||||
But if you pluck out a single element from the `data` column you'll see that it contains all the data for that country (in this case, Afghanistan).
|
||||
|
||||
```{r}
|
||||
by_country$data[[1]]
|
||||
```
|
||||
|
||||
Note the difference between a standard grouped data frame and a nested data frame: in a grouped data frame, each row is an observation; in a nested data frame, each row is a group.
|
||||
Another way to think about a nested dataset is we now have a meta-observation: a row that represents the complete time course for a country, rather than a single point in time.
|
||||
|
||||
### List-columns
|
||||
|
||||
Now that we have our nested data frame, we're in a good position to fit some models.
|
||||
We have a model-fitting function:
|
||||
|
||||
```{r}
|
||||
country_model <- function(df) {
|
||||
lm(lifeExp ~ year, data = df)
|
||||
}
|
||||
```
|
||||
|
||||
And we want to apply it to every data frame.
|
||||
The data frames are in a list, so we can use `purrr::map()` to apply `country_model` to each element:
|
||||
|
||||
```{r}
|
||||
models <- map(by_country$data, country_model)
|
||||
```
|
||||
|
||||
However, rather than leaving the list of models as a free-floating object, I think it's better to store it as a column in the `by_country` data frame.
|
||||
Storing related objects in columns is a key part of the value of data frames, and why I think list-columns are such a good idea.
|
||||
In the course of working with these countries, we are going to have lots of lists where we have one element per country.
|
||||
So why not store them all together in one data frame?
|
||||
|
||||
In other words, instead of creating a new object in the global environment, we're going to create a new variable in the `by_country` data frame.
|
||||
That's a job for `dplyr::mutate()`:
|
||||
|
||||
```{r}
|
||||
by_country <- by_country |>
|
||||
mutate(model = map(data, country_model))
|
||||
by_country
|
||||
```
|
||||
|
||||
This has a big advantage: because all the related objects are stored together, you don't need to manually keep them in sync when you filter or arrange.
|
||||
The semantics of the data frame takes care of that for you:
|
||||
|
||||
```{r}
|
||||
by_country |>
|
||||
filter(continent == "Europe")
|
||||
by_country |>
|
||||
arrange(continent, country)
|
||||
```
|
||||
|
||||
If your list of data frames and list of models were separate objects, you have to remember that whenever you re-order or subset one vector, you need to re-order or subset all the others in order to keep them in sync.
|
||||
If you forget, your code will continue to work, but it will give the wrong answer!
|
||||
|
||||
### Unnesting
|
||||
|
||||
Previously we computed the residuals of a single model with a single dataset.
|
||||
Now we have 142 data frames and 142 models.
|
||||
To compute the residuals, we need to call `add_residuals()` with each model-data pair:
|
||||
|
||||
```{r}
|
||||
by_country <- by_country |>
|
||||
mutate(
|
||||
resids = map2(data, model, add_residuals)
|
||||
)
|
||||
by_country
|
||||
```
|
||||
|
||||
But how can you plot a list of data frames?
|
||||
Instead of struggling to answer that question, let's turn the list of data frames back into a regular data frame.
|
||||
Previously we used `nest()` to turn a regular data frame into an nested data frame, and now we do the opposite with `unnest()`:
|
||||
|
||||
```{r}
|
||||
resids <- unnest(by_country, resids)
|
||||
resids
|
||||
```
|
||||
|
||||
Note that each regular column is repeated once for each row of the nested tibble.
|
||||
|
||||
Now we have regular data frame, we can plot the residuals:
|
||||
|
||||
```{r}
|
||||
resids |>
|
||||
ggplot(aes(year, resid)) +
|
||||
geom_line(aes(group = country), alpha = 1 / 3) +
|
||||
geom_smooth(se = FALSE)
|
||||
|
||||
```
|
||||
|
||||
Facetting by continent is particularly revealing:
|
||||
|
||||
```{r}
|
||||
resids |>
|
||||
ggplot(aes(year, resid, group = country)) +
|
||||
geom_line(alpha = 1 / 3) +
|
||||
facet_wrap(~continent)
|
||||
```
|
||||
|
||||
It looks like we've missed some mild patterns.
|
||||
There's also something interesting going on in Africa: we see some very large residuals which suggests our model isn't fitting so well there.
|
||||
We'll explore that more in the next section, attacking it from a slightly different angle.
|
||||
|
||||
### Model quality
|
||||
|
||||
Instead of looking at the residuals from the model, we could look at some general measurements of model quality.
|
||||
You learned how to compute some specific measures in the previous chapter.
|
||||
Here we'll show a different approach using the broom package.
|
||||
The broom package provides a general set of functions to turn models into tidy data.
|
||||
Here we'll use `broom::glance()` to extract some model quality metrics.
|
||||
If we apply it to a model, we get a data frame with a single row:
|
||||
|
||||
```{r}
|
||||
broom::glance(nz_mod)
|
||||
```
|
||||
|
||||
We can use `mutate()` and `unnest()` to create a data frame with a row for each country:
|
||||
|
||||
```{r}
|
||||
glance <- by_country |>
|
||||
mutate(glance = map(model, broom::glance)) |>
|
||||
select(country, continent, glance) |>
|
||||
unnest(glance)
|
||||
glance
|
||||
```
|
||||
|
||||
(Pay attention to the variables that aren't printed: there's a lot of useful stuff there.)
|
||||
|
||||
With this data frame in hand, we can start to look for models that don't fit well:
|
||||
|
||||
```{r}
|
||||
glance |>
|
||||
arrange(r.squared)
|
||||
```
|
||||
|
||||
The worst models all appear to be in Africa.
|
||||
Let's double check that with a plot.
|
||||
Here we have a relatively small number of observations and a discrete variable, so `geom_jitter()` is effective:
|
||||
|
||||
```{r}
|
||||
glance |>
|
||||
ggplot(aes(continent, r.squared)) +
|
||||
geom_jitter(width = 0.5)
|
||||
```
|
||||
|
||||
We could pull out the countries with particularly bad $R^2$ and plot the data:
|
||||
|
||||
```{r}
|
||||
bad_fit <- filter(glance, r.squared < 0.25)
|
||||
|
||||
gapminder |>
|
||||
semi_join(bad_fit, by = "country") |>
|
||||
ggplot(aes(year, lifeExp, colour = country)) +
|
||||
geom_line()
|
||||
```
|
||||
|
||||
We see two main effects here: the tragedies of the HIV/AIDS epidemic and the Rwandan genocide.
|
||||
|
||||
### Exercises
|
||||
|
||||
1. A linear trend seems to be slightly too simple for the overall trend.
|
||||
Can you do better with a quadratic polynomial?
|
||||
How can you interpret the coefficients of the quadratic?
|
||||
(Hint you might want to transform `year` so that it has mean zero.)
|
||||
|
||||
2. Explore other methods for visualising the distribution of $R^2$ per continent.
|
||||
You might want to try the ggbeeswarm package, which provides similar methods for avoiding overlaps as jitter, but uses deterministic methods.
|
||||
|
||||
3. To create the last plot (showing the data for the countries with the worst model fits), we needed two steps: we created a data frame with one row per country and then semi-joined it to the original dataset.
|
||||
It's possible to avoid this join if we use `unnest()` instead of `unnest(.drop = TRUE)`.
|
||||
How?
|
||||
|
||||
## List-columns {#list-columns-1}
|
||||
|
||||
Now that you've seen a basic workflow for managing many models, let's dive back into some of the details.
|
||||
In this section, we'll explore the list-column data structure in a little more detail.
|
||||
It's only recently that I've really appreciated the idea of the list-column.
|
||||
List-columns are implicit in the definition of the data frame: a data frame is a named list of equal length vectors.
|
||||
A list is a vector, so it's always been legitimate to use a list as a column of a data frame.
|
||||
However, base R doesn't make it easy to create list-columns, and `data.frame()` treats a list as a list of columns:.
|
||||
|
||||
```{r}
|
||||
data.frame(x = list(1:3, 3:5))
|
||||
```
|
||||
|
||||
You can prevent `data.frame()` from doing this with `I()`, but the result doesn't print particularly well:
|
||||
|
||||
```{r}
|
||||
data.frame(
|
||||
x = I(list(1:3, 3:5)),
|
||||
y = c("1, 2", "3, 4, 5")
|
||||
)
|
||||
```
|
||||
|
||||
Tibble alleviates this problem by being lazier (`tibble()` doesn't modify its inputs) and by providing a better print method:
|
||||
|
||||
```{r}
|
||||
tibble(
|
||||
x = list(1:3, 3:5),
|
||||
y = c("1, 2", "3, 4, 5")
|
||||
)
|
||||
```
|
||||
|
||||
It's even easier with `tribble()` as it can automatically work out that you need a list:
|
||||
|
||||
```{r}
|
||||
tribble(
|
||||
~x, ~y,
|
||||
1:3, "1, 2",
|
||||
3:5, "3, 4, 5"
|
||||
)
|
||||
```
|
||||
|
||||
List-columns are often most useful as intermediate data structure.
|
||||
They're hard to work with directly, because most R functions work with atomic vectors or data frames, but the advantage of keeping related items together in a data frame is worth a little hassle.
|
||||
|
||||
Generally there are three parts of an effective list-column pipeline:
|
||||
|
||||
1. You create the list-column using one of `nest()`, `summarise()` + `list()`, or `mutate()` + a map function, as described in [Creating list-columns].
|
||||
|
||||
2. You create other intermediate list-columns by transforming existing list columns with `map()`, `map2()` or `pmap()`.
|
||||
For example, in the case study above, we created a list-column of models by transforming a list-column of data frames.
|
||||
|
||||
3. You simplify the list-column back down to a data frame or atomic vector, as described in [Simplifying list-columns].
|
||||
|
||||
## Creating list-columns
|
||||
|
||||
Typically, you won't create list-columns with `tibble()`.
|
||||
Instead, you'll create them from regular columns, using one of three methods:
|
||||
|
||||
1. With `tidyr::nest()` to convert a grouped data frame into a nested data frame where you have list-column of data frames.
|
||||
|
||||
2. With `mutate()` and vectorised functions that return a list.
|
||||
|
||||
3. With `summarise()` and summary functions that return multiple results.
|
||||
|
||||
Alternatively, you might create them from a named list, using `tibble::enframe()`.
|
||||
|
||||
Generally, when creating list-columns, you should make sure they're homogeneous: each element should contain the same type of thing.
|
||||
There are no checks to make sure this is true, but if you use purrr and remember what you've learned about type-stable functions, you should find it happens naturally.
|
||||
|
||||
### With nesting
|
||||
|
||||
`nest()` creates a nested data frame, which is a data frame with a list-column of data frames.
|
||||
In a nested data frame each row is a meta-observation: the other columns give variables that define the observation (like country and continent above), and the list-column of data frames gives the individual observations that make up the meta-observation.
|
||||
|
||||
There are two ways to use `nest()`.
|
||||
So far you've seen how to use it with a grouped data frame.
|
||||
When applied to a grouped data frame, `nest()` keeps the grouping columns as is, and bundles everything else into the list-column:
|
||||
|
||||
```{r}
|
||||
gapminder |>
|
||||
group_by(country, continent) |>
|
||||
nest()
|
||||
```
|
||||
|
||||
You can also use it on an ungrouped data frame, specifying which columns you want to nest:
|
||||
|
||||
```{r}
|
||||
gapminder |>
|
||||
nest(data = c(year:gdpPercap))
|
||||
```
|
||||
|
||||
### From vectorised functions
|
||||
|
||||
Some useful functions take an atomic vector and return a list.
|
||||
For example, in [strings] you learned about `stringr::str_split()` which takes a character vector and returns a list of character vectors.
|
||||
If you use that inside mutate, you'll get a list-column:
|
||||
|
||||
```{r}
|
||||
df <- tribble(
|
||||
~x1,
|
||||
"a,b,c",
|
||||
"d,e,f,g"
|
||||
)
|
||||
|
||||
df |>
|
||||
mutate(x2 = stringr::str_split(x1, ","))
|
||||
```
|
||||
|
||||
`unnest()` knows how to handle these lists of vectors:
|
||||
|
||||
```{r}
|
||||
df |>
|
||||
mutate(x2 = stringr::str_split(x1, ",")) |>
|
||||
unnest(x2)
|
||||
```
|
||||
|
||||
(If you find yourself using this pattern a lot, make sure to check out `tidyr::separate_rows()` which is a wrapper around this common pattern).
|
||||
|
||||
Another example of this pattern is using the `map()`, `map2()`, `pmap()` from purrr.
|
||||
For example, we could take the final example from [Invoking different functions] and rewrite it to use `mutate()`:
|
||||
|
||||
```{r}
|
||||
sim <- tribble(
|
||||
~f, ~params,
|
||||
"runif", list(min = -1, max = 1),
|
||||
"rnorm", list(sd = 5),
|
||||
"rpois", list(lambda = 10)
|
||||
)
|
||||
|
||||
sim |>
|
||||
mutate(sims = invoke_map(f, params, n = 10))
|
||||
```
|
||||
|
||||
Note that technically `sims` isn't homogeneous because it contains both double and integer vectors.
|
||||
However, this is unlikely to cause many problems since integers and doubles are both numeric vectors.
|
||||
|
||||
### From multivalued summaries
|
||||
|
||||
One restriction of `summarise()` is that it only works with summary functions that return a single value.
|
||||
That means that you can't use it with functions like `quantile()` that return a vector of arbitrary length:
|
||||
|
||||
```{r, error = TRUE}
|
||||
mtcars |>
|
||||
group_by(cyl) |>
|
||||
summarise(q = quantile(mpg))
|
||||
```
|
||||
|
||||
You can however, wrap the result in a list!
|
||||
This obeys the contract of `summarise()`, because each summary is now a list (a vector) of length 1.
|
||||
|
||||
```{r}
|
||||
mtcars |>
|
||||
group_by(cyl) |>
|
||||
summarise(q = list(quantile(mpg)))
|
||||
```
|
||||
|
||||
To make useful results with unnest, you'll also need to capture the probabilities:
|
||||
|
||||
```{r}
|
||||
probs <- c(0.01, 0.25, 0.5, 0.75, 0.99)
|
||||
mtcars |>
|
||||
group_by(cyl) |>
|
||||
summarise(p = list(probs), q = list(quantile(mpg, probs))) |>
|
||||
unnest(c(p, q))
|
||||
```
|
||||
|
||||
### From a named list
|
||||
|
||||
Data frames with list-columns provide a solution to a common problem: what do you do if you want to iterate over both the contents of a list and its elements?
|
||||
Instead of trying to jam everything into one object, it's often easier to make a data frame: one column can contain the elements, and one column can contain the list.
|
||||
An easy way to create such a data frame from a list is `tibble::enframe()`.
|
||||
|
||||
```{r}
|
||||
x <- list(
|
||||
a = 1:5,
|
||||
b = 3:4,
|
||||
c = 5:6
|
||||
)
|
||||
|
||||
df <- enframe(x)
|
||||
df
|
||||
```
|
||||
|
||||
The advantage of this structure is that it generalises in a straightforward way - names are useful if you have a character vector of metadata but don't help if you have other types of data, or multiple vectors.
|
||||
|
||||
Now if you want to iterate over names and values in parallel, you can use `map2()`:
|
||||
|
||||
```{r}
|
||||
df |>
|
||||
mutate(
|
||||
smry = map2_chr(name, value, ~ stringr::str_c(.x, ": ", .y[1]))
|
||||
)
|
||||
```
|
||||
|
||||
### Exercises
|
||||
|
||||
1. List all the functions that you can think of that take a atomic vector and return a list.
|
||||
|
||||
2. Brainstorm useful summary functions that, like `quantile()`, return multiple values.
|
||||
|
||||
3. What's missing in the following data frame?
|
||||
How does `quantile()` return that missing piece?
|
||||
Why isn't that helpful here?
|
||||
|
||||
```{r}
|
||||
mtcars |>
|
||||
group_by(cyl) |>
|
||||
summarise(q = list(quantile(mpg))) |>
|
||||
unnest(q)
|
||||
```
|
||||
|
||||
4. What does this code do?
|
||||
Why might it be useful?
|
||||
|
||||
```{r, eval = FALSE}
|
||||
mtcars |>
|
||||
group_by(cyl) |>
|
||||
summarise_all(list(list))
|
||||
```
|
||||
|
||||
## Simplifying list-columns
|
||||
|
||||
To apply the techniques of data manipulation and visualisation you've learned in this book, you'll need to simplify the list-column back to a regular column (an atomic vector), or set of columns.
|
||||
The technique you'll use to collapse back down to a simpler structure depends on whether you want a single value per element, or multiple values:
|
||||
|
||||
1. If you want a single value, use `mutate()` with `map_lgl()`, `map_int()`, `map_dbl()`, and `map_chr()` to create an atomic vector.
|
||||
|
||||
2. If you want many values, use `unnest()` to convert list-columns back to regular columns, repeating the rows as many times as necessary.
|
||||
|
||||
These are described in more detail below.
|
||||
|
||||
### List to vector
|
||||
|
||||
If you can reduce your list column to an atomic vector then it will be a regular column.
|
||||
For example, you can always summarise an object with its type and length, so this code will work regardless of what sort of list-column you have:
|
||||
|
||||
```{r}
|
||||
df <- tribble(
|
||||
~x,
|
||||
letters[1:5],
|
||||
1:3,
|
||||
runif(5)
|
||||
)
|
||||
|
||||
df |> mutate(
|
||||
type = map_chr(x, typeof),
|
||||
length = map_int(x, length)
|
||||
)
|
||||
```
|
||||
|
||||
This is the same basic information that you get from the default tbl print method, but now you can use it for filtering.
|
||||
This is a useful technique if you have a heterogeneous list, and want to filter out the parts aren't working for you.
|
||||
|
||||
Don't forget about the `map_*()` shortcuts - you can use `map_chr(x, "apple")` to extract the string stored in `apple` for each element of `x`.
|
||||
This is useful for pulling apart nested lists into regular columns.
|
||||
Use the `.null` argument to provide a value to use if the element is missing (instead of returning `NULL`):
|
||||
|
||||
```{r}
|
||||
df <- tribble(
|
||||
~x,
|
||||
list(a = 1, b = 2),
|
||||
list(a = 2, c = 4)
|
||||
)
|
||||
df |> mutate(
|
||||
a = map_dbl(x, "a"),
|
||||
b = map_dbl(x, "b", .null = NA_real_)
|
||||
)
|
||||
```
|
||||
|
||||
### Unnesting
|
||||
|
||||
`unnest()` works by repeating the regular columns once for each element of the list-column.
|
||||
For example, in the following very simple example we repeat the first row 4 times (because there the first element of `y` has length four), and the second row once:
|
||||
|
||||
```{r}
|
||||
tibble(x = 1:2, y = list(1:4, 1)) |> unnest(y)
|
||||
```
|
||||
|
||||
This means that you can't simultaneously unnest two columns that contain different number of elements:
|
||||
|
||||
```{r, error = TRUE}
|
||||
# Ok, because y and z have the same number of elements in
|
||||
# every row
|
||||
df1 <- tribble(
|
||||
~x, ~y, ~z,
|
||||
1, c("a", "b"), 1:2,
|
||||
2, "c", 3
|
||||
)
|
||||
df1
|
||||
df1 |> unnest(c(y, z))
|
||||
|
||||
# Doesn't work because y and z have different number of elements
|
||||
df2 <- tribble(
|
||||
~x, ~y, ~z,
|
||||
1, "a", 1:2,
|
||||
2, c("b", "c"), 3
|
||||
)
|
||||
df2
|
||||
df2 |> unnest(c(y, z))
|
||||
```
|
||||
|
||||
The same principle applies when unnesting list-columns of data frames.
|
||||
You can unnest multiple list-cols as long as all the data frames in each row have the same number of rows.
|
||||
|
||||
### Exercises
|
||||
|
||||
1. Why might the `lengths()` function be useful for creating atomic vector columns from list-columns?
|
||||
|
||||
2. List the most common types of vector found in a data frame.
|
||||
What makes lists different?
|
||||
|
||||
## Making tidy data with broom
|
||||
|
||||
The broom package provides three general tools for turning models into tidy data frames:
|
||||
|
||||
1. `broom::glance(model)` returns a row for each model.
|
||||
Each column gives a model summary: either a measure of model quality, or complexity, or a combination of the two.
|
||||
|
||||
2. `broom::tidy(model)` returns a row for each coefficient in the model.
|
||||
Each column gives information about the estimate or its variability.
|
||||
|
||||
3. `broom::augment(model, data)` returns a row for each row in `data`, adding extra values like residuals, and influence statistics.
|
|
@ -1,69 +0,0 @@
|
|||
# (PART) Model {.unnumbered}
|
||||
|
||||
# Introduction {#model-intro}
|
||||
|
||||
Now that you are equipped with powerful programming tools we can finally return to modelling.
|
||||
You'll use your new tools of data wrangling and programming, to fit many models and understand how they work.
|
||||
The focus of this book is on exploration, not confirmation or formal inference.
|
||||
But you'll learn a few basic tools that help you understand the variation within your models.
|
||||
|
||||
```{r echo = FALSE, out.width = "75%"}
|
||||
knitr::include_graphics("diagrams/data-science-model.png")
|
||||
```
|
||||
|
||||
The goal of a model is to provide a simple low-dimensional summary of a dataset.
|
||||
Ideally, the model will capture true "signals" (i.e. patterns generated by the phenomenon of interest), and ignore "noise" (i.e. random variation that you're not interested in).
|
||||
Here we only cover "predictive" models, which, as the name suggests, generate predictions.
|
||||
There is another type of model that we're not going to discuss: "data discovery" models.
|
||||
These models don't make predictions, but instead help you discover interesting relationships within your data.
|
||||
(These two categories of models are sometimes called supervised and unsupervised, but I don't think that terminology is particularly illuminating.)
|
||||
|
||||
This book is not going to give you a deep understanding of the mathematical theory that underlies models.
|
||||
It will, however, build your intuition about how statistical models work, and give you a family of useful tools that allow you to use models to better understand your data:
|
||||
|
||||
- In [model basics], you'll learn how models work mechanistically, focussing on the important family of linear models.
|
||||
You'll learn general tools for gaining insight into what a predictive model tells you about your data, focussing on simple simulated datasets.
|
||||
|
||||
- In [model building], you'll learn how to use models to pull out known patterns in real data.
|
||||
Once you have recognised an important pattern it's useful to make it explicit in a model, because then you can more easily see the subtler signals that remain.
|
||||
|
||||
- In [many models], you'll learn how to use many simple models to help understand complex datasets.
|
||||
This is a powerful technique, but to access it you'll need to combine modelling and programming tools.
|
||||
|
||||
These topics are notable because of what they don't include: any tools for quantitatively assessing models.
|
||||
That is deliberate: precisely quantifying a model requires a couple of big ideas that we just don't have the space to cover here.
|
||||
For now, you'll rely on qualitative assessment and your natural scepticism.
|
||||
In [Learning more about models], we'll point you to other resources where you can learn more.
|
||||
|
||||
## Hypothesis generation vs. hypothesis confirmation
|
||||
|
||||
In this book, we are going to use models as a tool for exploration, completing the trifecta of the tools for EDA that were introduced in Part 1.
|
||||
This is not how models are usually taught, but as you will see, models are an important tool for exploration.
|
||||
Traditionally, the focus of modelling is on inference, or for confirming that an hypothesis is true.
|
||||
Doing this correctly is not complicated, but it is hard.
|
||||
There is a pair of ideas that you must understand in order to do inference correctly:
|
||||
|
||||
1. Each observation can either be used for exploration or confirmation, not both.
|
||||
|
||||
2. You can use an observation as many times as you like for exploration, but you can only use it once for confirmation.
|
||||
As soon as you use an observation twice, you've switched from confirmation to exploration.
|
||||
|
||||
This is necessary because to confirm a hypothesis you must use data independent of the data that you used to generate the hypothesis.
|
||||
Otherwise you will be over-optimistic.
|
||||
There is absolutely nothing wrong with exploration, but you should never sell an exploratory analysis as a confirmatory analysis because it is fundamentally misleading.
|
||||
|
||||
If you are serious about doing a confirmatory analysis, one approach is to split your data into three pieces before you begin the analysis:
|
||||
|
||||
1. 60% of your data goes into a **training** (or exploration) set.
|
||||
You're allowed to do anything you like with this data: visualise it and fit tons of models to it.
|
||||
|
||||
2. 20% goes into a **query** set.
|
||||
You can use this data to compare models or visualisations by hand, but you're not allowed to use it as part of an automated process.
|
||||
|
||||
3. 20% is held back for a **test** set.
|
||||
You can only use this data ONCE, to test your final model.
|
||||
|
||||
This partitioning allows you to explore the training data, occasionally generating candidate hypotheses that you check with the query set.
|
||||
When you are confident you have the right model, you can check it once with the test data.
|
||||
|
||||
(Note that even when doing confirmatory modelling, you will still need to do EDA. If you don't do any EDA you will remain blind to the quality problems with your data.)
|
|
@ -1,227 +0,0 @@
|
|||
```{r, include = FALSE}
|
||||
library(magrittr)
|
||||
```
|
||||
|
||||
# Robust code
|
||||
|
||||
(This is an advanced topic. You shouldn't worry too much about it when you first start writing functions. Instead you should focus on getting a function that works right for the easiest 80% of the problem. Then in time, you'll learn how to get to 99% with minimal extra effort. The defaults in this book should steer you in the right direction: we avoid teaching you functions with major surprises.)
|
||||
|
||||
In this section you'll learn an important principle that lends itself to reliable and readable code: favour code that can be understood with a minimum of context. On one extreme, take this code:
|
||||
|
||||
```{r, eval = FALSE}
|
||||
baz <- foo(bar, qux)
|
||||
```
|
||||
|
||||
What does it do? You can glean only a little from the context: `foo()` is a function that takes (at least) two arguments, and it returns a result we store in `baz`. But apart from that, you have no idea. To understand what this function does, you need to read the definitions of `foo()`, `bar`, and `qux`. Using better variable names helps a lot:
|
||||
|
||||
```{r, eval = FALSE}
|
||||
df2 <- arrange(df, qux)
|
||||
```
|
||||
|
||||
It's now much easier to see what's going on! Function and variable names are important because they tell you about (or at least jog your memory of) what the code does. That helps you understand code in isolation, even if you don't completely understand all the details. Unfortunately naming things is hard, and it's hard to give concrete advice apart from giving objects short but evocative names. As autocomplete in RStudio has gotten better, I've tended to use longer names that are more descriptive. Short names are faster to type, but you write code relatively infrequently compared to the number of times that you read it.
|
||||
|
||||
The idea of minimising the context needed to understand your code goes beyond just good naming. You also want to favour functions with predictable behaviour and few surprises. If a function does radically different things when its inputs differ slightly, you'll need to carefully read the surrounding context in order to predict what it will do. The goal of this section is to educate you about the most common ways R functions can be surprising and to provide you with unsurprising alternatives.
|
||||
|
||||
There are three common classes of surprises in R:
|
||||
|
||||
1. Unstable types: What will `df[, x]` return? You can assume that `df`
|
||||
is a data frame and `x` is a vector because of their names. But you don't
|
||||
know whether this code will return a data frame or a vector because the
|
||||
behaviour of `[` depends on the length of x.
|
||||
|
||||
1. Non-standard evaluation: What will `filter(df, x == y)` do? It depends on
|
||||
whether `x` or `y` or both are variable in `df` or variables in the current
|
||||
environment.
|
||||
|
||||
1. Hidden arguments: What sort of variable will `data.frame(x = "a")`
|
||||
create? It will be either a character vector or a factor depending on
|
||||
the value of the global `stringsAsFactors` option.
|
||||
|
||||
Avoiding these three types of functions helps you to write code that you is easily understand and fails obviously with unexpected input. If this behaviour is so important, why do any functions behave differently? It's because R is not just a programming language, but it's also an environment for interactive data analysis. Some things make sense for interactive use (where you quickly check the output and guessing what you want is ok) but don't make sense for programming (where you want errors to arise as quickly as possible).
|
||||
You might notice that these issues revolve around data frames. That's unfortunate because data frames are the data structure you'll use most commonly. It's ironic, the most frustrating things about programming in R are features that were originally designed to make your data analysis easier! Data frames try very hard to be helpful:
|
||||
|
||||
```{r}
|
||||
df <- data.frame(xy = c("x", "y"))
|
||||
# Character vectors were hard to work with for a long time, so R
|
||||
# helpfully converts to a factor for you:
|
||||
class(df$xy)
|
||||
|
||||
# If you're only selecting a single column, R tries to be helpful
|
||||
# and give you that column, rather than giving you a single column
|
||||
# data frame
|
||||
class(df[, "xy"])
|
||||
|
||||
# If you have long variable names, R is "helpful" and lets you select
|
||||
# them with a unique prefix
|
||||
df$x
|
||||
```
|
||||
|
||||
These features all made sense at the time they were added to R, but computing environments have changed a lot, and these features now tend to cause a lot of problems. tibble disables them for you:
|
||||
|
||||
```{r, error = TRUE}
|
||||
df <- tibble::tibble(xy = c("x", "y"))
|
||||
class(df$xy)
|
||||
class(df[, "xy"])
|
||||
df$x
|
||||
```
|
||||
|
||||
### Unpredictable types
|
||||
|
||||
One of the aspects most frustrating for programming is that `[` returns a vector if the result has a single column, and returns a data frame otherwise. In other words, if you see code like `df[x, ]` you can't predict what it will return without knowing the value of `x`. This can trip you up in surprising ways. For example, imagine you've written this function to return the last row of a data frame:
|
||||
|
||||
```{r}
|
||||
last_row <- function(df) {
|
||||
df[nrow(df), ]
|
||||
}
|
||||
```
|
||||
|
||||
It's not always going to return a row! If you give it a single column data frame, it will return a single number:
|
||||
|
||||
```{r}
|
||||
df <- data.frame(x = 1:3)
|
||||
last_row(df)
|
||||
```
|
||||
|
||||
There are two ways to avoid this problem:
|
||||
|
||||
* Use `drop = FALSE`: `df[x, , drop = FALSE]`.
|
||||
* Subset the data frame like a list: `df[x]`.
|
||||
|
||||
Using one of those techniques for `last_row()` makes it more predictable: you know it will always return a data frame.
|
||||
|
||||
```{r}
|
||||
last_row <- function(df) {
|
||||
df[nrow(df), , drop = FALSE]
|
||||
}
|
||||
last_row(df)
|
||||
```
|
||||
|
||||
Another common cause of problems is the `sapply()` function. If you've never heard of it before, feel free to skip this bit: just remember to avoid it! The problem with `sapply()` is that it tries to guess what the simplest form of output is, and it always succeeds.
|
||||
|
||||
The following code shows how `sapply()` can produce three different types of data depending on the input.
|
||||
|
||||
```{r}
|
||||
df <- data.frame(
|
||||
a = 1L,
|
||||
b = 1.5,
|
||||
y = Sys.time(),
|
||||
z = ordered(1)
|
||||
)
|
||||
|
||||
|
||||
df[1:4] |> sapply(class) |> str()
|
||||
df[1:2] |> sapply(class) |> str()
|
||||
df[3:4] |> sapply(class) |> str()
|
||||
```
|
||||
|
||||
In the next chapter, you'll learn about the purrr package which provides a variety of alternatives. In this case, you could use `map_chr()` which always returns a character vector: if it can't, it will throw an error. Another option is the base `vapply()` function which takes a third argument indicating what the output should look like.
|
||||
|
||||
This doesn't make `sapply()` bad and `vapply()` and `map_chr()` good. `sapply()` is nice because you can use it interactively without having to think about what `f` will return. 95% of the time it will do the right thing, and if it doesn't you can quickly fix it. `map_chr()` is more important when you're programming because a clear error message is more valuable when an operation is buried deep inside a tree of function calls. At this point it's worth thinking more about
|
||||
|
||||
### Non-standard evaluation
|
||||
|
||||
You've learned a number of functions that implement special lookup rules:
|
||||
|
||||
```{r, eval = FALSE}
|
||||
ggplot(mpg, aes(displ, cty)) + geom_point()
|
||||
filter(mpg, displ > 10)
|
||||
```
|
||||
|
||||
These are called "non-standard evaluation", or NSE for short, because the usual lookup rules don't apply. In both cases above neither `displ` nor `cty` are present in the global environment. Instead both ggplot2 and dplyr look for them first in a data frame. This is great for interactive use, but can cause problems inside a function because they'll fall back to the global environment if the variable isn't found.
|
||||
|
||||
[Talk a little bit about the standard scoping rules]
|
||||
|
||||
For example, take this function:
|
||||
|
||||
```{r}
|
||||
big_x <- function(df, threshold) {
|
||||
dplyr::filter(df, x > threshold)
|
||||
}
|
||||
```
|
||||
|
||||
There are two ways in which this function can fail:
|
||||
|
||||
1. `df$x` might not exist. There are two potential failure modes:
|
||||
|
||||
```{r, error = TRUE}
|
||||
big_x(mtcars, 10)
|
||||
|
||||
x <- 1
|
||||
big_x(mtcars, 10)
|
||||
```
|
||||
|
||||
The second failure mode is particularly pernicious because it doesn't
|
||||
throw an error, but instead silently returns an incorrect result. It
|
||||
works because by design `filter()` looks in both the data frame and
|
||||
the parent environment.
|
||||
|
||||
It is unlikely that the variable you care about will both be missing where
|
||||
you expect it, and present where you don't expect it. But I think it's
|
||||
worth weighing heavily in your analysis of potential failure modes because
|
||||
it's a failure that's easy to miss (since it just silently gives a bad
|
||||
result), and hard to track down (since you need to read a lot of context).
|
||||
|
||||
1. `df$threshold` might exist:
|
||||
|
||||
```{r}
|
||||
df <- tibble::tibble(x = 1:10, threshold = 100)
|
||||
big_x(df, 5)
|
||||
```
|
||||
|
||||
Again, this is bad because it silently gives an unexpected result.
|
||||
|
||||
How can you avoid this problem? Currently, you need to do this:
|
||||
|
||||
```{r}
|
||||
big_x <- function(df, threshold) {
|
||||
if (!"x" %in% names(df))
|
||||
stop("`df` must contain variable called `x`.", call. = FALSE)
|
||||
|
||||
if ("threshold" %in% names(df))
|
||||
stop("`df` must not contain variable called `threshold`.", call. = FALSE)
|
||||
|
||||
dplyr::filter(df, x > threshold)
|
||||
}
|
||||
```
|
||||
|
||||
Because dplyr currently has no way to force a name to be interpreted as either a local or parent variable, as I've only just realised, that's really why you should avoid NSE. In a future version you should be able to do:
|
||||
|
||||
```{r}
|
||||
big_x <- function(df, threshold) {
|
||||
dplyr::filter(df, local(x) > parent(threshold))
|
||||
}
|
||||
```
|
||||
|
||||
Another option is to implement it yourself using base subsetting:
|
||||
|
||||
```{r}
|
||||
big_x <- function(df, threshold) {
|
||||
rows <- df$x > threshold
|
||||
df[!is.na(rows) & rows, , drop = FALSE]
|
||||
}
|
||||
```
|
||||
|
||||
The challenge is remembering that `filter()` also drops missing values, and you also need to remember to use `drop = FALSE`!
|
||||
|
||||
### Relying on global options
|
||||
|
||||
Functions are easiest to reason about if they have two properties:
|
||||
|
||||
1. Their output only depends on their inputs.
|
||||
1. They don't affect the outside world except through their return value.
|
||||
|
||||
The first property is particularly important. If a function has hidden additional inputs, it's very difficult to even know where the important context is!
|
||||
|
||||
The biggest breakers of this rule in base R are functions that create data frames. Most of these functions have a `stringsAsFactors` argument that defaults to `getOption("stringsAsFactors")`. This means that a global option affects the operation of a very large number of functions, and you need to be aware that, depending on an external state, a function might produce either a character vector or a factor. In this book, we steer you away from that problem by recommending functions like `readr::read_csv()` and `tibble::tibble()` that don't rely on this option. But be aware of it! Generally if a function is affected by a global option, you should avoid setting it.
|
||||
|
||||
Only use `options()` to control side-effects of a function. The value of an option should never affect the return value of a function. There are only three violations of this rule in base R: `stringsAsFactors`, `encoding`, `na.action`. For example, base R lets you control the number of digits printed in default displays with (e.g.) `options(digits = 3)`. This is a good use of an option because it's something that people frequently want control over, but doesn't affect the computation of a result, just its display. Follow this principle with your own use of options.
|
||||
|
||||
### Trying too hard
|
||||
|
||||
Another class of problems is functions that try really really hard to always return a useful result. Unfortunately they try so hard that they never throw error messages so you never find out if the input is really really weird.
|
||||
|
||||
### Exercises
|
||||
|
||||
1. Look at the `encoding` argument to `file()`, `url()`, `gzfile()` etc.
|
||||
What is the default value? Why should you avoid setting the default
|
||||
value on a global level?
|
Loading…
Reference in New Issue