---
title: "Model Building"
subtitle: 《区域水环境污染数据分析实践》<br>Data analysis practice of regional water environment pollution
author: 苏命、王为东<br>中国科学院大学资源与环境学院<br>中国科学院生态环境研究中心
date: today
lang: zh
format:
  revealjs:
    theme: dark
    slide-number: true
    chalkboard:
      buttons: true
    preview-links: auto
    lang: zh
    toc: true
    toc-depth: 1
    toc-title: Outline
    logo: ./_extensions/inst/img/ucaslogo.png
    css: ./_extensions/inst/css/revealjs.css
    pointer:
      key: "p"
      color: "#32cd32"
      pointerSize: 18
revealjs-plugins:
  - pointer
filters:
  - d2
knitr:
  opts_chunk:
    dev: "svg"
    retina: 3
execute:
  freeze: auto
  cache: true
  echo: true
  fig-width: 5
  fig-height: 6
---
# Main steps of tidymodels
```{r}
#| echo: false
hexes <- function(..., size = 64) {
  x <- c(...)
  x <- sort(unique(x), decreasing = TRUE)
  right <- (seq_along(x) - 1) * size
  res <- glue::glue(
    '![](hexes/<x>.png){.absolute top=-20 right=<right> width="<size>" height="<size * 1.16>"}',
    .open = "<", .close = ">"
  )
  paste0(res, collapse = " ")
}
knitr::opts_chunk$set(
  digits = 3,
  comment = "#>",
  dev = 'svglite'
)
# devtools::install_github("gadenbuie/countdown")
# library(countdown)
library(ggplot2)
theme_set(theme_bw())
options(cli.width = 70, ggplot2.discrete.fill = c("#7e96d5", "#de6c4e"))
train_color <- "#1a162d"
test_color <- "#cd4173"
data_color <- "#767381"
assess_color <- "#84cae1"
splits_pal <- c(data_color, train_color, test_color)
```
## What is tidymodels? {background-image="images/tm-org.png" background-size="80%"}
```{r load-tm}
#| message: true
#| echo: true
#| warning: true
library(tidymodels)
```
## The whole game
```{r diagram-split, echo = FALSE}
#| fig-align: "center"
knitr::include_graphics("images/whole-game-split.jpg")
```
## The whole game
```{r diagram-model-1, echo = FALSE}
#| fig-align: "center"
knitr::include_graphics("images/whole-game-model-1.jpg")
```
:::notes
Stress that we are **not** fitting a model on the entire training set other than for illustrative purposes in deck 2.
:::
## The whole game
```{r diagram-model-n, echo = FALSE}
#| fig-align: "center"
knitr::include_graphics("images/whole-game-model-n.jpg")
```
## The whole game
```{r, echo = FALSE}
#| fig-align: "center"
knitr::include_graphics("images/whole-game-resamples.jpg")
```
## The whole game
```{r, echo = FALSE}
#| fig-align: "center"
knitr::include_graphics("images/whole-game-select.jpg")
```
## The whole game
```{r diagram-final-fit, echo = FALSE}
#| fig-align: "center"
knitr::include_graphics("images/whole-game-final-fit.jpg")
```
## The whole game
```{r diagram-final-performance, echo = FALSE}
#| fig-align: "center"
knitr::include_graphics("images/whole-game-final-performance.jpg")
```
## Installing the packages
```{r load-pkgs}
#| eval: false
# Install the packages for the workshop
# (xgboost and skimr are also used in the hands-on part)
pkgs <-
  c("bonsai", "doParallel", "embed", "finetune", "lightgbm", "lme4",
    "plumber", "probably", "ranger", "rpart", "rpart.plot", "rules",
    "splines2", "stacks", "text2vec", "textrecipes", "tidymodels",
    "vetiver", "remotes", "xgboost", "skimr")
install.packages(pkgs)
```
. . .
<br></br>
## Data on Chicago taxi trips
```{r taxi-print}
library(tidymodels)
taxi
```
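## Data splitting and spending
For machine learning, we typically split the data into a training set and a test set:
. . .
- The training set is used to estimate model parameters.
- The test set is used to evaluate model performance independently.
. . .
Do not use the test set during training.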
. . .
```{r test-train-split}
#| echo: false
#| fig-width: 12
#| fig-height: 3
set.seed(123)
library(forcats)
one_split <- slice(taxi, 1:30) %>%
  initial_split() %>%
  tidy() %>%
  add_row(Row = 1:30, Data = "Original") %>%
  mutate(Data = case_when(
    Data == "Analysis" ~ "Training",
    Data == "Assessment" ~ "Testing",
    TRUE ~ Data
  )) %>%
  mutate(Data = factor(Data, levels = c("Original", "Training", "Testing")))
all_split <-
  ggplot(one_split, aes(x = Row, y = fct_rev(Data), fill = Data)) +
  geom_tile(color = "white",
            linewidth = 1) +
  scale_fill_manual(values = splits_pal, guide = "none") +
  theme_minimal() +
  theme(axis.text.y = element_text(size = rel(2)),
        axis.text.x = element_blank(),
        legend.position = "top",
        panel.grid = element_blank()) +
  coord_equal(ratio = 1) +
  labs(x = NULL, y = NULL)
all_split
```
## The initial split
```{r taxi-split}
set.seed(123)
taxi_split <- initial_split(taxi)
taxi_split
```
## Accessing the data
```{r taxi-train-test}
taxi_train <- training(taxi_split)
taxi_test <- testing(taxi_split)
```
## The training set
```{r taxi-train}
taxi_train
```
## Exercise
```{r taxi-split-prop}
set.seed(123)
taxi_split <- initial_split(taxi, prop = 0.8)
taxi_train <- training(taxi_split)
taxi_test <- testing(taxi_split)
nrow(taxi_train)
nrow(taxi_test)
```
## Stratification
Use `strata = tip`
```{r taxi-split-prop-strata}
set.seed(123)
taxi_split <- initial_split(taxi, prop = 0.8, strata = tip)
taxi_split
```
## Stratification
Stratification often helps, with very little downside
```{r taxi-tip-pct-by-split, echo = FALSE}
# Use the stratified split created on the previous slide
bind_rows(
  training(taxi_split) %>% mutate(split = "train"),
  testing(taxi_split) %>% mutate(split = "test")
) %>%
  ggplot(aes(x = split, fill = tip)) +
  geom_bar(position = "fill")
```
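## Stratified splitting: a base-R sketch
To make the idea behind `strata = tip` concrete, here is a minimal base-R sketch, assuming a factor outcome. `stratified_split()` is a hypothetical helper written for this slide, not rsample's implementation (rsample also handles numeric strata and pools very small strata), and the `tip` vector is simulated.

```r
# Hypothetical helper (not rsample code): sample the same fraction of rows
# within each level of the outcome, so both sets keep its class balance
stratified_split <- function(y, prop = 0.8) {
  y <- as.factor(y)
  train_idx <- unlist(lapply(split(seq_along(y), y), function(idx) {
    n_train <- floor(length(idx) * prop)
    idx[sample.int(length(idx), n_train)]
  }), use.names = FALSE)
  list(train = sort(train_idx), test = setdiff(seq_along(y), train_idx))
}

set.seed(123)
# Simulated outcome with roughly 90% "yes", standing in for `tip`
tip <- factor(sample(c("yes", "no"), 1000, replace = TRUE, prob = c(0.9, 0.1)),
              levels = c("yes", "no"))
idx <- stratified_split(tip, prop = 0.8)
# Class proportions are nearly identical in the two sets
prop.table(table(tip[idx$train]))
prop.table(table(tip[idx$test]))
```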
## Model types
R offers many different ways (functions and packages) to fit a model, for example:
- `lm` for linear model
- `glm` for generalized linear model (e.g. logistic regression)
- `glmnet` for regularized regression
- `keras` for regression using TensorFlow
- `stan` for Bayesian regression
- `spark` for large data sets
## To specify a model
```{r}
#| echo: false
library(tidymodels)
set.seed(123)
taxi_split <- initial_split(taxi, prop = 0.8, strata = tip)
taxi_train <- training(taxi_split)
taxi_test <- testing(taxi_split)
```
```{r logistic-reg}
logistic_reg()
```
:::notes
Models have default engines
:::
## To specify a model
```{r logistic-reg-glmnet}
logistic_reg() %>%
  set_engine("glmnet")
```
. . .
```{r logistic-reg-stan}
logistic_reg() %>%
  set_engine("stan")
```
::: columns
::: {.column width="40%"}
- Choose a model
- Specify an engine
- Set the [mode]{.underline}
:::
::: {.column width="60%"}
![](images/taxi_spinning.svg)
:::
:::
## To specify a model
```{r decision-tree}
decision_tree()
```
:::notes
Some models have a default mode
:::
## To specify a model
```{r decision-tree-classification}
decision_tree() %>%
  set_mode("classification")
```
. . .
<br></br>
::: r-fit-text
All available models are listed at <https://www.tidymodels.org/find/parsnip/>
:::
## Workflows
```{r good-workflow}
#| echo: false
#| out-width: '70%'
#| fig-align: 'center'
knitr::include_graphics("images/good_workflow.png")
```
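## Why a `workflow()`?
- Workflows handle new factor levels better than base R tools
. . .
- You can use other preprocessors besides formulas (more on feature engineering in advanced tidymodels!)
. . .
- They help keep your work organized when you are working with multiple models
. . .
- [Most importantly]{.underline}, a workflow captures the entire modeling process: `fit()` and `predict()` apply to the preprocessing steps as well as to the actual model fit
::: notes
Two ways in which workflows handle factor levels better than base R:
- They enforce that new levels are not allowed at prediction time (this is an optional check that can be turned off)
- They restore levels that were present at fit time but are missing at prediction time (e.g., when the "new" data has no instance of that level)
:::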
## A model workflow
```{r tree-spec}
tree_spec <-
  decision_tree(cost_complexity = 0.002) %>%
  set_mode("classification")
tree_spec %>%
  fit(tip ~ ., data = taxi_train)
```
## A model workflow
```{r tree-wflow}
tree_spec <-
  decision_tree(cost_complexity = 0.002) %>%
  set_mode("classification")
workflow() %>%
  add_formula(tip ~ .) %>%
  add_model(tree_spec) %>%
  fit(data = taxi_train)
```
## A model workflow
```{r tree-wflow-fit}
tree_spec <-
  decision_tree(cost_complexity = 0.002) %>%
  set_mode("classification")
workflow(tip ~ ., tree_spec) %>%
  fit(data = taxi_train)
```
## Predict with your model
How do you use your new `tree_fit` model?
```{r tree-wflow-fit-2}
tree_spec <-
  decision_tree(cost_complexity = 0.002) %>%
  set_mode("classification")
tree_fit <-
  workflow(tip ~ ., tree_spec) %>%
  fit(data = taxi_train)
```
## Exercise
*Run:*
`predict(tree_fit, new_data = taxi_test)`
. . .
*Run:*
`augment(tree_fit, new_data = taxi_test)`
*What do you get?*
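## Predictions in tidymodels
- Predictions are always inside a **tibble**
- Column names and types are predictable and easy to read
- The number of rows in `new_data` and in the output **are the same**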
## Understand your model
How do you **understand** your `tree_fit` model?
```{r plot-tree-fit-4}
#| echo: false
#| fig-align: center
#| fig-width: 8
#| fig-height: 5
#| out-width: 100%
library(rpart.plot)
tree_fit %>%
  extract_fit_engine() %>%
  rpart.plot(roundint = FALSE)
```
## Evaluating models: predicted values
```{r}
#| echo: false
library(tidymodels)
set.seed(123)
taxi_split <- initial_split(taxi, prop = 0.8, strata = tip)
taxi_train <- training(taxi_split)
taxi_test <- testing(taxi_split)
tree_spec <- decision_tree(cost_complexity = 0.0001, mode = "classification")
taxi_wflow <- workflow(tip ~ ., tree_spec)
taxi_fit <- fit(taxi_wflow, taxi_train)
```
```{r taxi-fit-augment}
augment(taxi_fit, new_data = taxi_train) %>%
  relocate(tip, .pred_class, .pred_yes, .pred_no)
```
## Confusion matrix
![](images/confusion-matrix.png)
## Confusion matrix
```{r conf-mat}
augment(taxi_fit, new_data = taxi_train) %>%
  conf_mat(truth = tip, estimate = .pred_class)
```
## Confusion matrix
```{r conf-mat-plot}
augment(taxi_fit, new_data = taxi_train) %>%
  conf_mat(truth = tip, estimate = .pred_class) %>%
  autoplot(type = "heatmap")
```
## Metrics for model performance
::: columns
::: {.column width="60%"}
```{r acc}
augment(taxi_fit, new_data = taxi_train) %>%
  accuracy(truth = tip, estimate = .pred_class)
```
:::
::: {.column width="40%"}
![](images/confusion-matrix-accuracy.png)
:::
:::
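## Confusion matrix and accuracy by hand
As a sanity check on what `conf_mat()` and `accuracy()` compute, here is a base-R sketch on a tiny example. The `truth` and `estimate` vectors are invented for illustration, and, as in the taxi data, the first factor level (`"yes"`) is assumed to be the event.

```r
# Invented truth and predictions; "yes" (the first level) is the event
truth    <- factor(c("yes", "yes", "yes", "no", "no", "yes", "no", "yes"),
                   levels = c("yes", "no"))
estimate <- factor(c("yes", "no", "yes", "no", "yes", "yes", "no", "yes"),
                   levels = c("yes", "no"))

# Rows = predicted class, columns = true class (the layout conf_mat() prints)
cm <- table(Prediction = estimate, Truth = truth)
cm
# Accuracy: correctly classified observations (the diagonal) / all observations
acc <- sum(diag(cm)) / sum(cm)
acc
```

Here 6 of the 8 predictions match the truth, so the accuracy is 0.75.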
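## Evaluating binary classifiers
Sensitivity and specificity are two key metrics for evaluating the performance of a binary classification model.
- **Sensitivity**, also called the true positive rate, measures the model's ability to correctly identify samples of the positive class. It is the number of true positives divided by the sum of true positives and false negatives:
$$
\text{Sensitivity} = \frac{\text{True Positives}}{\text{True Positives} + \text{False Negatives}}
$$
- **Specificity**, also called the true negative rate, measures the model's ability to correctly identify samples of the negative class. It is the number of true negatives divided by the sum of true negatives and false positives:
$$
\text{Specificity} = \frac{\text{True Negatives}}{\text{True Negatives} + \text{False Positives}}
$$
When evaluating a model we want both sensitivity and specificity to be high: high sensitivity means the model captures the truly positive samples, and high specificity means it correctly rules out the negative samples. In practice, raising one usually lowers the other as the classification threshold changes.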
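## Sensitivity and specificity by hand
The two formulas above can be checked directly in base R. This sketch uses a small invented example (not taxi data) and assumes `"yes"` is the positive class.

```r
# Invented truth and predictions; "yes" is the positive class
truth    <- factor(c("yes", "yes", "yes", "no", "no", "yes", "no", "yes"),
                   levels = c("yes", "no"))
estimate <- factor(c("yes", "no", "yes", "no", "yes", "yes", "no", "yes"),
                   levels = c("yes", "no"))

n_tp <- sum(estimate == "yes" & truth == "yes")  # true positives
n_fn <- sum(estimate == "no"  & truth == "yes")  # false negatives
n_tn <- sum(estimate == "no"  & truth == "no")   # true negatives
n_fp <- sum(estimate == "yes" & truth == "no")   # false positives

sens <- n_tp / (n_tp + n_fn)  # 4 / (4 + 1)
spec <- n_tn / (n_tn + n_fp)  # 2 / (2 + 1)
c(sensitivity = sens, specificity = spec)
```

With the default `event_level = "first"`, yardstick's `sensitivity()` and `specificity()` should return the same values for these vectors.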
## Metrics for model performance
::: columns
::: {.column width="60%"}
```{r sens}
augment(taxi_fit, new_data = taxi_train) %>%
  sensitivity(truth = tip, estimate = .pred_class)
```
:::
::: {.column width="40%"}
![](images/confusion-matrix-sensitivity.png)
:::
:::
## Metrics for model performance
::: columns
::: {.column width="60%"}
```{r sens-2}
augment(taxi_fit, new_data = taxi_train) %>%
  sensitivity(truth = tip, estimate = .pred_class)
```
<br>
```{r spec}
augment(taxi_fit, new_data = taxi_train) %>%
  specificity(truth = tip, estimate = .pred_class)
```
:::
::: {.column width="40%"}
![](images/confusion-matrix-specificity.png)
:::
:::
## Metrics for model performance
We can use `metric_set()` to combine multiple calculations into one
```{r taxi-metrics}
taxi_metrics <- metric_set(accuracy, specificity, sensitivity)
augment(taxi_fit, new_data = taxi_train) %>%
  taxi_metrics(truth = tip, estimate = .pred_class)
```
## Metrics for model performance
```{r taxi-metrics-grouped}
taxi_metrics <- metric_set(accuracy, specificity, sensitivity)
augment(taxi_fit, new_data = taxi_train) %>%
  group_by(local) %>%
  taxi_metrics(truth = tip, estimate = .pred_class)
```
## Varying the threshold
```{r}
#| label: thresholds
#| echo: false
augment(taxi_fit, new_data = taxi_train) %>%
  roc_curve(truth = tip, .pred_yes) %>%
  filter(is.finite(.threshold)) %>%
  pivot_longer(c(specificity, sensitivity), names_to = "statistic", values_to = "value") %>%
  rename(`event threshold` = .threshold) %>%
  ggplot(aes(x = `event threshold`, y = value, col = statistic, group = statistic)) +
  geom_line() +
  scale_color_brewer(palette = "Dark2") +
  labs(y = NULL) +
  coord_equal() +
  theme(legend.position = "top")
```
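## How the threshold changes the metrics
The plot above sweeps the probability threshold used to turn `.pred_yes` into a class. A minimal base-R sketch of the same idea on invented probabilities (not model output): predict `"yes"` when the event probability is at or above the threshold, then recompute sensitivity and specificity. `sens_spec_at()` is a hypothetical helper.

```r
# Invented event probabilities and true classes
truth    <- factor(c("yes", "yes", "no", "yes", "no", "no", "yes", "no"),
                   levels = c("yes", "no"))
pred_yes <- c(0.9, 0.8, 0.7, 0.6, 0.4, 0.3, 0.2, 0.1)

sens_spec_at <- function(threshold) {
  # Predict the event when its probability reaches the threshold
  est <- ifelse(pred_yes >= threshold, "yes", "no")
  c(threshold   = threshold,
    sensitivity = mean(est[truth == "yes"] == "yes"),
    specificity = mean(est[truth == "no"] == "no"))
}
res <- t(sapply(c(0.25, 0.5, 0.75), sens_spec_at))
res
```

Raising the threshold from 0.25 to 0.75 lowers sensitivity (0.75, 0.75, 0.5) and raises specificity (0.25, 0.75, 1).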
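## ROC curve
- The ROC (receiver operating characteristic) curve evaluates a binary classifier by comparing its sensitivity and specificity across different classification thresholds.
- The x-axis is the false positive rate (FPR, equal to 1 − specificity) and the y-axis is the true positive rate (TPR, equal to sensitivity). Each point on the curve corresponds to one threshold, so moving the threshold shows how the model behaves under different conditions.
- The closer the curve comes to the top-left corner (0, 1), the better the model, because it reaches a high true positive rate at a low false positive rate. The area under the ROC curve (AUC) summarizes this in a single number: larger is better, with 0.5 corresponding to random guessing and 1 to a perfect classifier.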
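## ROC curve and AUC by hand
The description above can be turned into a short base-R sketch: lower the threshold through every observed probability, record (FPR, TPR) at each step, and integrate with the trapezoidal rule. The data are invented; in practice `roc_curve()` and `roc_auc()` do this for you and also take care of the event level.

```r
# Invented event probabilities and true classes
truth    <- factor(c("yes", "yes", "no", "yes", "no", "no", "yes", "no"),
                   levels = c("yes", "no"))
pred_yes <- c(0.9, 0.8, 0.7, 0.6, 0.4, 0.3, 0.2, 0.1)

# Start above the largest score (nothing predicted "yes"), then lower the threshold
thresholds <- c(Inf, sort(unique(pred_yes), decreasing = TRUE))
roc <- t(sapply(thresholds, function(th) {
  est_yes <- pred_yes >= th
  c(fpr = mean(est_yes[truth == "no"]),   # 1 - specificity
    tpr = mean(est_yes[truth == "yes"]))  # sensitivity
}))

# Trapezoidal rule over consecutive ROC points
auc <- sum(diff(roc[, "fpr"]) *
             (head(roc[, "tpr"], -1) + tail(roc[, "tpr"], -1)) / 2)
auc
```

Here the AUC is 0.75, which equals the fraction of (yes, no) pairs in which the "yes" case received the higher probability (12 of 16).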
## ROC curve plot
```{r roc-curve}
#| fig-width: 6
#| fig-height: 6
#| output-location: "column"
augment(taxi_fit, new_data = taxi_train) %>%
  roc_curve(truth = tip, .pred_yes) %>%
  autoplot()
```
## Overfitting
![](./images/tuning-overfitting-train-1.svg)
## Overfitting
![](images/tuning-overfitting-test-1.svg)
## Cross-validation {background-color="white" background-image="https://www.tmwr.org/premade/resampling.svg" background-size="80%"}
## Cross-validation
![](https://www.tmwr.org/premade/three-CV.svg)
## Cross-validation
![](https://www.tmwr.org/premade/three-CV-iter.svg)
## Cross-validation
```{r vfold-cv}
vfold_cv(taxi_train) # v = 10 is default
```
## Cross-validation
What is in this?
```{r taxi-splits}
taxi_folds <- vfold_cv(taxi_train)
taxi_folds$splits[1:3]
```
::: notes
Talk about a list column, storing non-atomic types in dataframe
:::
## Cross-validation
```{r vfold-cv-v}
vfold_cv(taxi_train, v = 5)
```
## Cross-validation
```{r vfold-cv-strata}
vfold_cv(taxi_train, strata = tip)
```
. . .
Stratification often helps, with very little downside
## Cross-validation
We'll use this setup:
```{r taxi-folds}
set.seed(123)
taxi_folds <- vfold_cv(taxi_train, v = 10, strata = tip)
taxi_folds
```
. . .
Set the seed when creating resamples
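## What V-fold assignment does
A minimal base-R sketch of stratified V-fold assignment, assuming a factor outcome: within each class, shuffle the rows and deal them out to folds 1 to V in turn. `assign_folds()` is a hypothetical helper, not rsample's code, and `tip` is simulated. For fold k, the rows with `folds == k` form the assessment set and all other rows the analysis set, so every row is assessed exactly once.

```r
# Hypothetical helper (not rsample code)
assign_folds <- function(y, v = 10) {
  folds <- integer(length(y))
  for (idx in split(seq_along(y), as.factor(y))) {
    # Shuffle this class's rows, then deal them round-robin across folds 1..v
    folds[idx[sample.int(length(idx))]] <- rep_len(seq_len(v), length(idx))
  }
  folds
}

set.seed(123)
# Simulated outcome standing in for `tip`
tip <- factor(sample(c("yes", "no"), 200, replace = TRUE, prob = c(0.9, 0.1)))
folds <- assign_folds(tip, v = 10)
table(folds, tip)  # similar class mix in every fold
```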
## Fit our model to the resamples
```{r fit-resamples}
taxi_res <- fit_resamples(taxi_wflow, taxi_folds)
taxi_res
```
## Evaluating model performance
```{r collect-metrics}
taxi_res %>%
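  collect_metrics()
```
::: notes
`collect_metrics()` is one of a suite of `collect_*()` functions for working with the columns of tuning results. Most columns in a tuning result whose names start with `.` have a corresponding `collect_*()` function that offers common summary options.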
:::
. . .
We can reliably measure performance using only the **training** data 🎉
## Comparing metrics
How do the metrics from resampling compare to the metrics from training and testing?
```{r calc-roc-auc}
#| echo: false
taxi_training_roc_auc <-
  taxi_fit %>%
  augment(taxi_train) %>%
  roc_auc(tip, .pred_yes) %>%
  pull(.estimate) %>%
  round(digits = 2)
taxi_testing_roc_auc <-
  taxi_fit %>%
  augment(taxi_test) %>%
  roc_auc(tip, .pred_yes) %>%
  pull(.estimate) %>%
  round(digits = 2)
```
::: columns
::: {.column width="50%"}
```{r collect-metrics-2}
taxi_res %>%
  collect_metrics() %>%
  select(.metric, mean, n)
```
:::
::: {.column width="50%"}
The ROC AUC previously was
- `r taxi_training_roc_auc` for the training set
- `r taxi_testing_roc_auc` for the test set
:::
:::
. . .
Remember that:
⚠️ the training set gives you overly optimistic metrics
⚠️ the test set is precious
## Evaluating model performance
```{r save-predictions}
# Save the assessment set results
ctrl_taxi <- control_resamples(save_pred = TRUE)
taxi_res <- fit_resamples(taxi_wflow, taxi_folds, control = ctrl_taxi)
taxi_res
```
## Evaluating model performance
```{r collect-predictions}
# Save the assessment set results
taxi_preds <- collect_predictions(taxi_res)
taxi_preds
```
## Evaluating model performance
```{r taxi-metrics-by-id}
taxi_preds %>%
  group_by(id) %>%
  taxi_metrics(truth = tip, estimate = .pred_class)
```
## Where are the fitted models?
```{r taxi-res}
taxi_res
```
## Bootstrapping
![](https://www.tmwr.org/premade/bootstraps.svg)
## Bootstrapping
```{r bootstraps}
set.seed(3214)
bootstraps(taxi_train)
```
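## Bootstrap resamples by hand
A bootstrap resample draws as many rows as the data has, with replacement, for the analysis set; the rows that were never drawn ("out-of-bag") form the assessment set. A base-R sketch on row indices only (no real data):

```r
set.seed(3214)
n <- 100
in_bag     <- sample.int(n, n, replace = TRUE)  # analysis rows (with duplicates)
out_of_bag <- setdiff(seq_len(n), in_bag)       # assessment rows
length(unique(in_bag)) / n                      # share of distinct rows drawn
```

On average roughly 63% of the distinct rows end up in the analysis set (1 − (1 − 1/n)^n approaches 1 − 1/e), so the size of the assessment set varies from resample to resample.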
## Monte Carlo Cross-Validation
```{r mc-cv}
set.seed(322)
mc_cv(taxi_train, times = 10)
```
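## Monte Carlo splits by hand
Monte Carlo cross-validation repeats an independent random split `times` times. A base-R sketch with a hypothetical helper `mc_splits()`, assuming the 3/4 analysis proportion that `mc_cv()` uses by default; unlike V-fold CV, the assessment sets of different resamples can overlap.

```r
# Hypothetical helper (not rsample code)
mc_splits <- function(n, times = 10, prop = 3/4) {
  lapply(seq_len(times), function(i) {
    analysis <- sort(sample.int(n, floor(n * prop)))
    list(analysis = analysis, assessment = setdiff(seq_len(n), analysis))
  })
}

set.seed(322)
splits <- mc_splits(100, times = 10)
sapply(splits, function(s) length(s$assessment))  # 25 rows held out each time
```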
## Validation set
```{r validation-split}
set.seed(853)
taxi_val_split <- initial_validation_split(taxi, strata = tip)
validation_set(taxi_val_split)
```
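## Three-way split by hand
`initial_validation_split()` splits the data into training, validation, and test sets in one step. A base-R sketch on row indices, assuming 60% / 20% / 20% proportions (which we believe match the function's default `prop = c(0.6, 0.2)`); `three_way_split()` is a hypothetical helper and, unlike the call above, does not stratify.

```r
# Hypothetical helper (not rsample code); no stratification
three_way_split <- function(n, prop = c(0.6, 0.2)) {
  idx     <- sample.int(n)  # random permutation of the row indices
  n_train <- floor(n * prop[1])
  n_val   <- floor(n * prop[2])
  list(train      = idx[seq_len(n_train)],
       validation = idx[n_train + seq_len(n_val)],
       test       = idx[-seq_len(n_train + n_val)])
}

set.seed(853)
s <- three_way_split(1000)
lengths(s)  # train 600, validation 200, test 200
```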
## Create a random forest model
```{r rf-spec}
rf_spec <- rand_forest(trees = 1000, mode = "classification")
rf_spec
```
## Create a random forest model
```{r rf-wflow}
rf_wflow <- workflow(tip ~ ., rf_spec)
rf_wflow
```
## Evaluating model performance
```{r collect-metrics-rf}
ctrl_taxi <- control_resamples(save_pred = TRUE)
# Random forest uses random numbers so set the seed first
set.seed(2)
rf_res <- fit_resamples(rf_wflow, taxi_folds, control = ctrl_taxi)
collect_metrics(rf_res)
```
## The whole game - status update
```{r diagram-select, echo = FALSE}
#| fig-align: "center"
knitr::include_graphics("images/whole-game-transparent-select.jpg")
```
## The final fit
```{r final-fit}
# taxi_split has train + test info
final_fit <- last_fit(rf_wflow, taxi_split)
final_fit
```
## What is in `final_fit`?
```{r collect-metrics-final-fit}
collect_metrics(final_fit)
```
. . .
These are metrics computed with the **test** set
## What is in `final_fit`?
```{r collect-predictions-final-fit}
collect_predictions(final_fit)
```
## What is in `final_fit`?
```{r extract-workflow}
extract_workflow(final_fit)
```
. . .
Use this for **prediction** on new data, like for deploying
## Tuning models - Specifying tuning parameters
```{r}
#| label: tag-for-tuning
#| code-line-numbers: "1|"
rf_spec <- rand_forest(min_n = tune()) %>%
  set_mode("classification")
rf_wflow <- workflow(tip ~ ., rf_spec)
rf_wflow
```
## Try out multiple values
`tune_grid()` works similarly to `fit_resamples()` but covers multiple parameter values:
```{r}
#| label: rf-tune_grid
#| code-line-numbers: "2|3-4|5|"
set.seed(22)
rf_res <- tune_grid(
  rf_wflow,
  taxi_folds,
  grid = 5
)
```
## Compare results
Inspecting results and selecting the best-performing hyperparameter(s):
```{r}
#| label: rf-results
show_best(rf_res, metric = "roc_auc")
best_parameter <- select_best(rf_res, metric = "roc_auc")
best_parameter
```
`collect_metrics()` and `autoplot()` are also available.
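## Grid search with resampling by hand
`tune_grid()` fits every candidate value on every resample and averages the metric. The same logic in base R, on simulated data and with a polynomial degree standing in for `min_n` (both are assumptions for illustration; this is not how tune is implemented internally):

```r
set.seed(22)
n <- 200
x <- runif(n, -2, 2)
dat <- data.frame(x = x, y = sin(x) + rnorm(n, sd = 0.3))  # simulated data
folds <- sample(rep_len(1:5, n))                          # 5-fold assignment

# Mean assessment-set RMSE of a polynomial fit of the given degree
cv_rmse <- function(degree) {
  errs <- sapply(1:5, function(k) {
    fit  <- lm(y ~ poly(x, degree), data = dat[folds != k, ])
    pred <- predict(fit, newdata = dat[folds == k, ])
    sqrt(mean((dat$y[folds == k] - pred)^2))
  })
  mean(errs)
}

grid    <- 1:6                      # candidate values
results <- data.frame(degree = grid, rmse = sapply(grid, cv_rmse))
best    <- results[which.min(results$rmse), ]
best
```

`best` plays the role of `select_best()`: the candidate with the lowest mean RMSE across the five assessment sets.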
## The final fit
```{r}
#| label: rf-finalize
rf_wflow <- finalize_workflow(rf_wflow, best_parameter)
final_fit <- last_fit(rf_wflow, taxi_split)
collect_metrics(final_fit)
```
# Hands-on practice
## Data
```{r}
library(tidyverse)
# Site information: location, lake name, area, and maximum depth
sitedf <- readr::read_csv("https://www.epa.gov/sites/default/files/2014-01/nla2007_sampledlakeinformation_20091113.csv") |>
  select(SITE_ID,
         lon = LON_DD,
         lat = LAT_DD,
         name = LAKENAME,
         area = LAKEAREA,
         zmax = DEPTHMAX
  ) |>
  group_by(SITE_ID) |>
  summarize(lon = mean(lon, na.rm = TRUE),
            lat = mean(lat, na.rm = TRUE),
            name = first(name),
            area = mean(area, na.rm = TRUE),
            zmax = mean(zmax, na.rm = TRUE))
# Sampling dates and visit numbers
visitdf <- readr::read_csv("https://www.epa.gov/sites/default/files/2013-09/nla2007_profile_20091008.csv") |>
  select(SITE_ID,
         date = DATE_PROFILE,
         year = YEAR,
         visit = VISIT_NO
  ) |>
  distinct()
# Depth profiles of temperature, dissolved oxygen, pH, and conductivity
waterchemdf <- readr::read_csv("https://www.epa.gov/sites/default/files/2013-09/nla2007_profile_20091008.csv") |>
  select(SITE_ID,
         date = DATE_PROFILE,
         depth = DEPTH,
         temp = TEMP_FIELD,
         do = DO_FIELD,
         ph = PH_FIELD,
         cond = COND_FIELD
  )
# Secchi depth
sddf <- readr::read_csv("https://www.epa.gov/sites/default/files/2014-10/nla2007_secchi_20091008.csv") |>
  select(SITE_ID,
         date = DATE_SECCHI,
         sd = SECMEAN,
         clear_to_bottom = CLEAR_TO_BOTTOM
  )
# Total phosphorus, total nitrogen, and chlorophyll a, averaged per site and date
trophicdf <- readr::read_csv("https://www.epa.gov/sites/default/files/2014-10/nla2007_trophic_conditionestimate_20091123.csv") |>
  select(SITE_ID,
         visit = VISIT_NO,
         tp = PTL,
         tn = NTL,
         chla = CHLA) |>
  left_join(visitdf, by = c("SITE_ID", "visit")) |>
  select(-year, -visit) |>
  group_by(SITE_ID, date) |>
  summarize(tp = mean(tp, na.rm = TRUE),
            tn = mean(tn, na.rm = TRUE),
            chla = mean(chla, na.rm = TRUE),
            .groups = "drop"
  )
# Phytoplankton abundance by division and genus, nested per site and date
phytodf <- readr::read_csv("https://www.epa.gov/sites/default/files/2014-10/nla2007_phytoplankton_softalgaecount_20091023.csv") |>
  select(SITE_ID,
         date = DATEPHYT,
         depth = SAMPLE_DEPTH,
         phyta = DIVISION,
         genus = GENUS,
         species = SPECIES,
         tax = TAXANAME,
         abund = ABUND) |>
  mutate(phyta = gsub(" .*$", "", phyta)) |>
  filter(!is.na(genus)) |>
  group_by(SITE_ID, date, depth, phyta, genus) |>
  summarize(abund = sum(abund, na.rm = TRUE), .groups = "drop") |>
  nest(phytodf = -c(SITE_ID, date))
# Surface (< 2 m) water chemistry per site and date, plus Secchi and trophic data
envdf <- waterchemdf |>
  filter(depth < 2) |>
  select(-depth) |>
  group_by(SITE_ID, date) |>
  summarise_all(~mean(., na.rm = TRUE)) |>
  ungroup() |>
  left_join(sddf, by = c("SITE_ID", "date")) |>
  left_join(trophicdf, by = c("SITE_ID", "date"))
# Combine everything and compute total cyanobacteria (Cyanophyta) abundance
nla <- envdf |>
  left_join(phytodf, by = c("SITE_ID", "date")) |>
  left_join(sitedf, by = "SITE_ID") |>
  filter(!purrr::map_lgl(phytodf, is.null)) |>
  mutate(cyanophyta = purrr::map(phytodf, ~ .x |>
    dplyr::filter(phyta == "Cyanophyta") |>
    summarize(cyanophyta = sum(abund, na.rm = TRUE))
  )) |>
  unnest(cyanophyta) |>
  # drop the nested phytoplankton table (`phyta` only exists inside it)
  select(-phytodf) |>
  mutate(clear_to_bottom = ifelse(is.na(clear_to_bottom), TRUE, FALSE))
# library(rmdify)
# library(dwfun)
# dwfun::init()
```
## Data
```{r}
skimr::skim(nla)
```
## A simple model
```{r}
nla |>
  filter(tp > 1) |>
  ggplot(aes(tn, tp)) +
  geom_point() +
  geom_smooth(method = "lm") +
  scale_x_log10(breaks = scales::trans_breaks("log10", function(x) 10^x),
                labels = scales::trans_format("log10", scales::math_format(10^.x))) +
  scale_y_log10(breaks = scales::trans_breaks("log10", function(x) 10^x),
                labels = scales::trans_format("log10", scales::math_format(10^.x)))
m1 <- lm(log10(tp) ~ log10(tn), data = nla)
summary(m1)
```
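## Reading a log-log model
Because `m1` is fit on log10 scales, its coefficients describe a power law: $\log_{10}(TP) = a + b\log_{10}(TN)$ is equivalent to $TP = 10^{a}\,TN^{b}$, so a 1% increase in TN is associated with roughly a b% increase in TP. A base-R sketch on simulated values (not NLA data; the coefficients -1.5 and 1.2 are arbitrary):

```r
set.seed(1)
# Simulated TN and TP following a power law with noise on the log10 scale
tn <- 10^runif(100, 2, 4)
tp <- 10^(-1.5 + 1.2 * log10(tn) + rnorm(100, sd = 0.2))
m  <- lm(log10(tp) ~ log10(tn))
b  <- coef(m)
# Back-transform the fitted line: TP = 10^a * TN^b
pred_tp <- 10^b[[1]] * tn^b[[2]]
# Identical to back-transforming the fitted values on the log10 scale
all.equal(pred_tp, unname(10^fitted(m)))
b
```

When the errors are symmetric on the log scale, such back-transformed values estimate the median (geometric mean) of TP rather than its arithmetic mean.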
## A more complex indicator
```{r}
nla |>
  filter(tp > 1) |>
  ggplot(aes(tp, cyanophyta)) +
  geom_point() +
  geom_smooth(method = "lm") +
  scale_x_log10(breaks = scales::trans_breaks("log10", function(x) 10^x),
                labels = scales::trans_format("log10", scales::math_format(10^.x))) +
  scale_y_log10(breaks = scales::trans_breaks("log10", function(x) 10^x),
                labels = scales::trans_format("log10", scales::math_format(10^.x)))
m2 <- lm(log10(cyanophyta) ~ log10(tp), data = nla)
summary(m2)
```
## tidymodels - Data split
```{r}
set.seed(123)
(nla_split <- rsample::initial_split(nla, prop = 0.7, strata = zmax))
(nla_train <- training(nla_split))
(nla_test <- testing(nla_split))
```
## tidymodels - recipe
```{r}
nla_formula <- as.formula("cyanophyta ~ temp + do + ph + cond + sd + tp + tn + chla + clear_to_bottom")
# nla_formula <- as.formula("cyanophyta ~ temp + do + ph + cond + sd + tp + tn")
nla_recipe <- recipes::recipe(nla_formula, data = nla_train) |>
  recipes::step_string2factor(all_nominal()) |>
  recipes::step_nzv(all_nominal()) |>
  recipes::step_log(chla, cyanophyta, base = 10) |>
  recipes::step_normalize(all_numeric_predictors()) |>
  prep()
nla_recipe
```
## tidymodels - cross validation
```{r}
set.seed(123)
nla_cv <- recipes::bake(
  nla_recipe,
  new_data = training(nla_split)
) |>
  rsample::vfold_cv(v = 10)
nla_cv
```
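## Preprocessing inside each fold
In the chunk above the recipe is prepped on the whole training set before `vfold_cv()`, so the normalization statistics also see each assessment fold. To keep every assessment fold untouched, preprocessing can be re-estimated on each analysis set only (in tidymodels, typically by adding the recipe to the workflow with `add_recipe()` so that tuning re-preps it within each resample). A base-R sketch of the principle on simulated numbers; `normalize_with()` is a hypothetical helper:

```r
# Hypothetical helper: center/scale `new` with means and SDs from `train` only
normalize_with <- function(train, new) {
  mu <- colMeans(train)
  s  <- apply(train, 2, sd)
  sweep(sweep(new, 2, mu, "-"), 2, s, "/")
}

set.seed(10)
X <- matrix(rnorm(100 * 3, mean = 5, sd = 2), ncol = 3)  # simulated predictors
folds <- sample(rep_len(1:5, nrow(X)))
k <- 1
analysis   <- X[folds != k, , drop = FALSE]
assessment <- X[folds == k, , drop = FALSE]
analysis_n   <- normalize_with(analysis, analysis)
assessment_n <- normalize_with(analysis, assessment)  # analysis statistics only
round(colMeans(analysis_n), 10)   # 0 by construction
round(colMeans(assessment_n), 2)  # close to, but not exactly, 0
```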
## tidymodels - Model specification
```{r}
xgboost_model <- parsnip::boost_tree(
  mode = "regression",
  trees = 1000,
  min_n = tune(),
  tree_depth = tune(),
  learn_rate = tune(),
  loss_reduction = tune()
) |>
  set_engine("xgboost", objective = "reg:squarederror")
xgboost_model
```
## tidymodels - Grid specification
```{r}
# grid specification
xgboost_params <- dials::parameters(
  min_n(),
  tree_depth(),
  learn_rate(),
  loss_reduction()
)
xgboost_params
```
## tidymodels - Grid specification
```{r}
xgboost_grid <- dials::grid_max_entropy(
  xgboost_params,
  size = 60
)
knitr::kable(head(xgboost_grid))
```
## tidymodels - Workflow
```{r}
xgboost_wf <- workflows::workflow() |>
  add_model(xgboost_model) |>
  add_formula(nla_formula)
xgboost_wf
```
## tidymodels - Tune
```{r}
#| cache: true
# Hyperparameter tuning is slow: change FALSE to TRUE to re-run it;
# otherwise the saved tuning results are loaded
if (FALSE) {
  xgboost_tuned <- tune::tune_grid(
    object = xgboost_wf,
    resamples = nla_cv,
    grid = xgboost_grid,
    metrics = yardstick::metric_set(rmse, rsq, mae),
    control = tune::control_grid(verbose = TRUE)
  )
  saveRDS(xgboost_tuned, "./xgboost_tuned.RDS")
}
xgboost_tuned <- readRDS("./xgboost_tuned.RDS")
```
## tidymodels - Best model
```{r}
xgboost_tuned |>
  tune::show_best(metric = "rmse") |>
  knitr::kable()
```
## tidymodels - Best model
```{r}
xgboost_tuned |>
  collect_metrics()
```
## tidymodels - Best model
```{r}
#| fig-width: 9
#| fig-height: 5
#| out-width: "100%"
xgboost_tuned |>
  autoplot()
```
## tidymodels - Best model
```{r}
xgboost_best_params <- xgboost_tuned |>
  tune::select_best(metric = "rmse")
knitr::kable(xgboost_best_params)
```
## tidymodels - Final model
```{r}
xgboost_model_final <- xgboost_model |>
  finalize_model(xgboost_best_params)
xgboost_model_final
```
## tidymodels - Train evaluation
```{r}
(train_processed <- bake(nla_recipe, new_data = nla_train))
```
## tidymodels - Train data
```{r}
train_prediction <- xgboost_model_final |>
  # fit the model on all the training data
  fit(
    formula = nla_formula,
    data = train_processed
  ) |>
  # predict log10(cyanophyta) for the training data
  predict(new_data = train_processed) |>
  bind_cols(nla_train |>
              mutate(.obs = log10(cyanophyta)))
xgboost_score_train <-
  train_prediction |>
  yardstick::metrics(.obs, .pred) |>
  mutate(.estimate = format(round(.estimate, 2), big.mark = ","))
knitr::kable(xgboost_score_train)
```
## tidymodels - train evaluation
```{r}
#| fig-width: 5
#| fig-height: 3
#| out-width: "80%"
train_prediction |>
  ggplot(aes(.pred, .obs)) +
  geom_point() +
  geom_smooth(method = "lm")
```
## tidymodels - test data
```{r}
test_processed <- bake(nla_recipe, new_data = nla_test)
test_prediction <- xgboost_model_final |>
  # fit the model on all the training data
  fit(
    formula = nla_formula,
    data = train_processed
  ) |>
  # use the training model fit to predict the test data
  predict(new_data = test_processed) |>
  bind_cols(nla_test |>
              mutate(.obs = log10(cyanophyta)))
# measure model performance (RMSE, R-squared, MAE) with `yardstick`
xgboost_score <- test_prediction |>
  yardstick::metrics(.obs, .pred) |>
  mutate(.estimate = format(round(.estimate, 2), big.mark = ","))
knitr::kable(xgboost_score)
```
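## What `metrics()` reports
For a numeric outcome, `yardstick::metrics()` returns RMSE, R² and MAE. A base-R sketch of the three formulas on invented numbers (not model output); note that yardstick's `rsq()` is the squared correlation between observed and predicted values, while the traditional 1 − SSE/SST version is `rsq_trad()`.

```r
obs  <- c(2.1, 3.0, 3.8, 5.2, 6.1)  # invented observed values (e.g. log10 scale)
pred <- c(2.0, 3.3, 3.6, 5.0, 6.4)  # invented predictions
rmse <- sqrt(mean((obs - pred)^2))   # root mean squared error
rsq  <- cor(obs, pred)^2             # squared correlation, as in yardstick::rsq()
mae  <- mean(abs(obs - pred))        # mean absolute error
c(rmse = rmse, rsq = rsq, mae = mae)
```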
## tidymodels - evaluation
```{r}
#| fig-width: 5
#| fig-height: 3
#| out-width: "80%"
cyanophyta_prediction_residual <- test_prediction |>
  arrange(.pred) |>
  # residual relative to the prediction, in percent (both on the log10 scale)
  mutate(residual_pct = 100 * (.obs - .pred) / .pred) |>
  select(.pred, residual_pct)
cyanophyta_prediction_residual |>
  ggplot(aes(x = .pred, y = residual_pct)) +
  geom_point() +
  xlab("Predicted log10(Cyanophyta)") +
  ylab("Residual (%)")
```
## tidymodels - test evaluation
```{r}
#| fig-width: 5
#| fig-height: 3
#| out-width: "80%"
test_prediction |>
  ggplot(aes(.pred, .obs)) +
  geom_point() +
  geom_smooth(method = "lm", colour = "black")
```
## Questions and discussion welcome! {.center}
`r rmdify::slideend(wechat = FALSE, type = "public", tel = FALSE, thislink = "https://drwater.rcees.ac.cn/course/public/RWEP/@PUB/SD/")`