library(tidymodels)
## ── Attaching packages ────────────────────────────────────── tidymodels 1.2.0 ──
## ✔ broom 1.0.5 ✔ recipes 1.0.10
## ✔ dials 1.2.1 ✔ rsample 1.2.1
## ✔ dplyr 1.1.4 ✔ tibble 3.2.1
## ✔ ggplot2 3.5.1 ✔ tidyr 1.3.1
## ✔ infer 1.0.7 ✔ tune 1.2.1
## ✔ modeldata 1.3.0 ✔ workflows 1.1.4
## ✔ parsnip 1.2.1 ✔ workflowsets 1.1.0
## ✔ purrr 1.0.2 ✔ yardstick 1.3.1
## ── Conflicts ───────────────────────────────────────── tidymodels_conflicts() ──
## ✖ purrr::discard() masks scales::discard()
## ✖ dplyr::filter() masks stats::filter()
## ✖ dplyr::lag() masks stats::lag()
## ✖ recipes::step() masks stats::step()
## • Use suppressPackageStartupMessages() to eliminate package startup messages
library(palmerpenguins)
##
## 次のパッケージを付け加えます: 'palmerpenguins'
## 以下のオブジェクトは 'package:modeldata' からマスクされています:
##
## penguins
set.seed(123)
Rで、tidymodelsを利用してサポートベクターマシンによる分類をためしてみました。tidymodelsの勉強には、著者のお一人の瓜生さんからいただいた下記書籍を参考にしました。
松村優哉・瓜生真也・吉村広志 Rユーザのためのtidymodels[実践]入門—モダンな統計・機械学習モデリングの世界 技術評論社
準備
tidymodelsパッケージを読み込みます。データにはpalmerpenguinsを使用しました。
データの加工
palmerpenguinsの3種のペンギンのうち、今回はアデリー (Adelie) とジェンツー (Gentoo) のみを使用して、bill_length_mm
とbill_depth_mm
の2つの測定値から両者の分類を試します。使用するデータだけを残してpenguins2
というオブジェクトに代入しました。
<- penguins |>
penguins2 ::filter(species %in% c("Adelie", "Gentoo")) |>
dplyr::mutate(species = factor(species,
dplyrlevels = c("Adelie", "Gentoo"))) |>
::select(species, bill_length_mm, bill_depth_mm) dplyr
データの確認
プロットしてデータを確認してみます。
%>%
penguins2 ::filter(!is.na(bill_length_mm) &
dplyr!is.na(bill_depth_mm)) %>%
ggplot(aes(x = bill_length_mm, y = bill_depth_mm,
color = species)) +
geom_point()
レシピ
レシピの定義と前処理です。今回はデータ全体を使用することにして、学習用とテスト用とに分割するということはしていません。欠測値の補間はおこなっています。
<- recipe(species ~ bill_length_mm + bill_depth_mm,
penguins_recipe data = penguins2) |>
step_impute_mean(bill_length_mm, bill_depth_mm)
モデル
線形サポートベクターマシーンを使用します。modeは分類として、engineにはkernlabを使用します。costには、デフォルトと同じ1を与えています。
<- svm_linear(mode = "classification",
svm_model engine = "kernlab",
cost = 1)
ワークフロー
ワークフローのオブジェクトをつくって、レシピとモデルを追加します。
<- workflow() |>
penguins_wf add_recipe(penguins_recipe) |>
add_model(svm_model)
あてはめ
あてはめを実行します。
<- penguins_wf |>
penguins_fit fit(data = penguins2)
## Setting default kernel parameters
結果
結果を表示します。
|>
penguins_fit extract_fit_parsnip()
## parsnip model object
##
## Support Vector Machine object of class "ksvm"
##
## SV type: C-svc (classification)
## parameter : cost C = 1
##
## Linear (vanilla) kernel function.
##
## Number of Support Vectors : 10
##
## Objective Function Value : -6.17
## Training error : 0.007246
## Probability model included.
判別境界の図示
グリッドで新データを生成して予測をおこない、判別境界を図示してみます。
<- expand_grid(bill_length_mm = seq(30, 60, length = 101),
new_data bill_depth_mm = seq(12.5, 22.5, length = 101))
<- augment(penguins_fit,
penguins_predict new_data = new_data) |>
::rename(species = .pred_class)
dplyrggplot(penguins_predict,
aes(x = bill_length_mm, y = bill_depth_mm, fill = species)) +
geom_raster(alpha = 0.3) +
geom_point(data = dplyr::filter(penguins2,
!is.na(bill_length_mm) &
!is.na(bill_depth_mm)),
mapping = aes(x = bill_length_mm, y = bill_depth_mm,
colour = species))