library(nimble)
## nimble version 1.0.1 is loaded.
## For more information on NIMBLE and a User Manual,
## please visit https://R-nimble.org.
##
## Note for advanced users who have written their own MCMC samplers:
## As of version 0.13.0, NIMBLE's protocol for handling posterior
## predictive nodes has changed in a way that could affect user-defined
## samplers in some situations. Please see Section 15.5.1 of the User Manual.
##
## 次のパッケージを付け加えます: 'nimble'
## 以下のオブジェクトは 'package:stats' からマスクされています:
##
## simulate
library(posterior)
## This is posterior version 1.4.1
##
## 次のパッケージを付け加えます: 'posterior'
## 以下のオブジェクトは 'package:stats' からマスクされています:
##
## mad, sd, var
## 以下のオブジェクトは 'package:base' からマスクされています:
##
## %in%, match
library(bayesplot)
## This is bayesplot version 1.10.0
## - Online documentation and vignettes at mc-stan.org/bayesplot
## - bayesplot theme set to bayesplot::theme_default()
## * Does _not_ affect other ggplot2 plots
## * See ?bayesplot_theme_set for details on theme setting
##
## 次のパッケージを付け加えます: 'bayesplot'
## 以下のオブジェクトは 'package:posterior' からマスクされています:
##
## rhat
library(dplyr)
##
## 次のパッケージを付け加えます: 'dplyr'
## 以下のオブジェクトは 'package:stats' からマスクされています:
##
## filter, lag
## 以下のオブジェクトは 'package:base' からマスクされています:
##
## intersect, setdiff, setequal, union
library(stringr)
library(ggplot2)
set.seed(1234)
前回の記事のモデルをNIMBLEでやってみました。
準備
パッケージ
NIMBLEのほか、事後分布の処理のためにposteriorやbayesplotなどのパッケージを読み込んでいます。
データ
前回と同じく、データは、岩波データサイエンスVol.1の松浦さんの記事で使用されたものを利用します。
<- "https://raw.githubusercontent.com/iwanami-datascience/vol1/master/matsuura/example2/input/data-season.txt"
data_url <- read.csv(data_url)
data $Time <- seq_along(data$Y) data
プロットしてデータを確認します。
<- ggplot(data) +
p geom_line(aes(x = Time, y = Y))
print(p)
NIMBLEによるモデリング
モデル
BUGS言語でモデルを記述します。ローカル線形トレンドモデルと季節調整モデルを組み合わせたモデルとなっています。また、T_new
期先まで予測をおこないます。
NIMBLEでMCMC計算を実行しましたが、このBUGSコードはJAGSでも実行可能なはずです(分布の引数に式を使っているので、WinBUGSではエラーになります)。
<- nimbleCode({
code # observation
for (t in 1:T) {
~ dnorm(mu[t] + season[t], tau[1])
Y[t]
}
# level
for (t in 3:T) {
~ dnorm(2 * mu[t - 1] - mu[t - 2], tau[2])
mu[t]
}
# seasonal
for (t in 4:T) {
~ dnorm(-sum(season[(t - 3):(t - 1)]), tau[3])
season[t]
}
# forecast
for (t in (T + 1):(T + T_new)) {
- T] ~ dnorm(mu[t] + season[t], tau[1])
Y_new[t ~ dnorm(2 * mu[t - 1] - mu[t - 2], tau[2])
mu[t] ~ dnorm(-sum(season[(t - 3):(t - 1)]), tau[3])
season[t]
}
# priors
for (t in 1:2) {
~ dnorm(0, 1e-4)
mu[t]
}for (t in 1:3) {
~ dnorm(0, 1e-4)
season[t]
}for (i in 1:3) {
<- 1 / (sigma[i] * sigma[i])
tau[i] ~ dunif(0, 100)
sigma[i]
} })
MCMC
nimbleMCMC
関数でMCMC計算を実行します。samplesAsCodaMCMC = TRUE
とすることで、codaパッケージのMCMC
型として事後分布のサンプルを出力させ、それをposteriorパッケージのas_draws
関数で、draws
型に変換しています。
<- nrow(data)
T_num <- 8
T_new <- function() {
init_fun list(mu = rnorm(T_num + T_new, 20, 2),
season = rnorm(T_num + T_new, 0, 2),
Y_new = rnorm(T_new, 20, 2),
sigma= runif(3, 0, 2))
}<- nimbleMCMC(code,
samp constants = list(T = T_num, T_new = T_new),
data = list(Y = data$Y),
inits = init_fun,
monitors = c("mu", "season", "sigma", "Y_new"),
niter = 55000, nburnin = 5000, nchains = 3,
thin = 50,
samplesAsCodaMCMC = TRUE) |>
as_draws()
## Defining model
## Building model
## Setting data and initial values
## Running calculate on model
## [Note] Any error reports that follow may simply reflect missing values in model variables.
## Checking model sizes and dimensions
## [Note] This model is not fully initialized. This is not an error.
## To see which variables are not initialized, use model$initializeInfo().
## For more information on model initialization, see help(modelInitialization).
## Checking model calculations
## Compiling
## [Note] This may take a minute.
## [Note] Use 'showCompilerOutput = TRUE' to see C++ compilation details.
## running chain 1...
## |-------------|-------------|-------------|-------------|
## |-------------------------------------------------------|
## running chain 2...
## |-------------|-------------|-------------|-------------|
## |-------------------------------------------------------|
## running chain 3...
## |-------------|-------------|-------------|-------------|
## |-------------------------------------------------------|
結果
要約統計量
標準偏差のパラメータであるsigma
の要約統計量を表示します。
<- summary(samp)
summay_stat |>
summay_stat filter(str_starts(variable, "sigma"))
## # A tibble: 3 × 10
## variable mean median sd mad q5 q95 rhat ess_bulk ess_tail
## <chr> <num> <num> <num> <num> <num> <num> <num> <num> <num>
## 1 sigma[1] 0.283 0.254 0.195 0.203 0.0313 0.655 1.04 100. 86.8
## 2 sigma[2] 0.109 0.102 0.0353 0.0301 0.0653 0.176 1.00 1370. 1742.
## 3 sigma[3] 0.712 0.710 0.133 0.125 0.501 0.932 1.01 557. 598.
軌跡のプロット
bayesplotパッケージのmcmc_trace
関数でMCMCの軌跡を確認します。
mcmc_trace(samp, pars = c("sigma[1]", "sigma[2]", "sigma[3]"))
水準成分
状態の水準成分の推定値(事後中央値)と90%信用区間をプロットします。
<- summay_stat |>
mu filter(str_starts(variable, "mu")) |>
select(mean, median, q5, q95) |>
slice_head(n = T_num) |>
bind_cols(data)
ggplot(mu) +
geom_ribbon(aes(x = Time, ymin = q5, ymax = q95),
fill = "gray30", alpha = 0.7) +
geom_line(aes(x = Time, y = median)) +
labs(y = "Y")
季節調整成分
状態の季節調整成分の推定値(事後中央値)と90%信用区間をプロットします。
<- summay_stat |>
season filter(str_starts(variable, "season")) |>
select(mean, median, q5, q95) |>
slice_head(n = T_num) |>
bind_cols(data)
ggplot(season) +
geom_ribbon(aes(x = Time, ymin = q5, ymax = q95),
fill = "gray30", alpha = 0.7) +
geom_line(aes(x = Time, y = median)) +
labs(y = "Y")
予測
予測値の要約統計量を表示します。
|>
summay_stat filter(str_starts(variable, "Y_new"))
## # A tibble: 8 × 10
## variable mean median sd mad q5 q95 rhat ess_bulk ess_tail
## <chr> <num> <num> <num> <num> <num> <num> <num> <num> <num>
## 1 Y_new[1] 24.7 24.7 1.32 1.33 22.5 26.8 1.00 802. 1601.
## 2 Y_new[2] 33.7 33.7 1.31 1.25 31.5 35.9 1.00 2075. 2667.
## 3 Y_new[3] 22.1 22.1 1.44 1.37 19.7 24.5 1.00 1675. 2629.
## 4 Y_new[4] 24.6 24.6 1.49 1.46 22.1 27.0 1.00 1888. 2343.
## 5 Y_new[5] 26.4 26.4 2.30 2.33 22.7 30.3 1.00 999. 1652.
## 6 Y_new[6] 35.4 35.4 2.27 2.22 31.8 39.1 1.00 1925. 2347.
## 7 Y_new[7] 23.9 23.8 2.51 2.56 19.8 28.0 1.00 1639. 2411.
## 8 Y_new[8] 26.3 26.3 2.67 2.58 22.0 30.6 1.00 1747. 2400.
最後に、元のデータと、8期先までの予測値(事後中央値)と90%信用区間をプロットします。
<- summay_stat |>
Y_new filter(str_starts(variable, "Y_new")) |>
select(mean, median, q5, q95) |>
bind_cols(tibble(Time = T_num + 1:T_new))
ggplot(data) +
geom_line(aes(x = Time, y = Y)) +
geom_ribbon(data = Y_new, aes(x = Time, ymin = q5, ymax = q95),
fill = "gray30", alpha = 0.7) +
geom_line(data = Y_new, aes(x = Time, y = median))