NIMBLEによる季節調整モデル

R
NIMBLE
作者

伊東宏樹

公開

2023年8月18日

前回の記事のモデルをNIMBLEでやってみました。

準備

パッケージ

NIMBLEのほか、事後分布の処理のためにposteriorやbayesplotなどのパッケージを読み込んでいます。

library(nimble)
## nimble version 1.0.1 is loaded.
## For more information on NIMBLE and a User Manual,
## please visit https://R-nimble.org.
## 
## Note for advanced users who have written their own MCMC samplers:
##   As of version 0.13.0, NIMBLE's protocol for handling posterior
##   predictive nodes has changed in a way that could affect user-defined
##   samplers in some situations. Please see Section 15.5.1 of the User Manual.
## 
##  次のパッケージを付け加えます: 'nimble'
##  以下のオブジェクトは 'package:stats' からマスクされています:
## 
##     simulate
library(posterior)
## This is posterior version 1.4.1
## 
##  次のパッケージを付け加えます: 'posterior'
##  以下のオブジェクトは 'package:stats' からマスクされています:
## 
##     mad, sd, var
##  以下のオブジェクトは 'package:base' からマスクされています:
## 
##     %in%, match
library(bayesplot)
## This is bayesplot version 1.10.0
## - Online documentation and vignettes at mc-stan.org/bayesplot
## - bayesplot theme set to bayesplot::theme_default()
##    * Does _not_ affect other ggplot2 plots
##    * See ?bayesplot_theme_set for details on theme setting
## 
##  次のパッケージを付け加えます: 'bayesplot'
##  以下のオブジェクトは 'package:posterior' からマスクされています:
## 
##     rhat
library(dplyr)
## 
##  次のパッケージを付け加えます: 'dplyr'
##  以下のオブジェクトは 'package:stats' からマスクされています:
## 
##     filter, lag
##  以下のオブジェクトは 'package:base' からマスクされています:
## 
##     intersect, setdiff, setequal, union
library(stringr)
library(ggplot2)

set.seed(1234)

データ

前回と同じく、データは、岩波データサイエンスVol.1の松浦さんの記事で使用されたものを利用します。

data_url <- "https://raw.githubusercontent.com/iwanami-datascience/vol1/master/matsuura/example2/input/data-season.txt"
data <- read.csv(data_url)
data$Time <- seq_along(data$Y)

プロットしてデータを確認します。

p <- ggplot(data) +
  geom_line(aes(x = Time, y = Y))
print(p)

NIMBLEによるモデリング

モデル

BUGS言語でモデルを記述します。ローカル線形トレンドモデルと季節調整モデルを組み合わせたモデルとなっています。また、T_new期先まで予測をおこないます。

NIMBLEでMCMC計算を実行しましたが、このBUGSコードはJAGSでも実行可能なはずです(分布の引数に式を使っているので、WinBUGSではエラーになります)。

code <- nimbleCode({
  # observation
  for (t in 1:T) {
    Y[t] ~ dnorm(mu[t] + season[t], tau[1])
  }

  # level
  for (t in 3:T) {
    mu[t] ~ dnorm(2 * mu[t - 1] - mu[t - 2], tau[2]) 
  }
  
  # seasonal
  for (t in 4:T) {
    season[t] ~ dnorm(-sum(season[(t - 3):(t - 1)]), tau[3])
  }
  
  # forecast
  for (t in (T + 1):(T + T_new)) {
    Y_new[t - T] ~ dnorm(mu[t] + season[t], tau[1])
    mu[t] ~ dnorm(2 * mu[t - 1] - mu[t - 2], tau[2]) 
    season[t] ~ dnorm(-sum(season[(t - 3):(t - 1)]), tau[3])
  }

  # priors
  for (t in 1:2) {
    mu[t] ~ dnorm(0, 1e-4)
  }
  for (t in 1:3) {
    season[t] ~ dnorm(0, 1e-4) 
  }
  for (i in 1:3) {
    tau[i] <- 1 / (sigma[i] * sigma[i])
    sigma[i] ~ dunif(0, 100)
  }
})

MCMC

nimbleMCMC関数でMCMC計算を実行します。samplesAsCodaMCMC = TRUEとすることで、codaパッケージのMCMC型として事後分布のサンプルを出力させ、それをposteriorパッケージのas_draws関数で、draws型に変換しています。

T_num <- nrow(data)
T_new <- 8
init_fun <- function() {
  list(mu = rnorm(T_num + T_new, 20, 2),
       season = rnorm(T_num + T_new, 0, 2),
       Y_new = rnorm(T_new, 20, 2),
       sigma= runif(3, 0, 2))
}
samp <- nimbleMCMC(code,
                   constants = list(T = T_num, T_new = T_new),
                   data = list(Y = data$Y),
                   inits = init_fun,
                   monitors = c("mu", "season", "sigma", "Y_new"),
                   niter = 55000, nburnin = 5000, nchains = 3,
                   thin = 50,
                   samplesAsCodaMCMC = TRUE) |>
  as_draws()
## Defining model
## Building model
## Setting data and initial values
## Running calculate on model
##   [Note] Any error reports that follow may simply reflect missing values in model variables.
## Checking model sizes and dimensions
##   [Note] This model is not fully initialized. This is not an error.
##          To see which variables are not initialized, use model$initializeInfo().
##          For more information on model initialization, see help(modelInitialization).
## Checking model calculations
## Compiling
##   [Note] This may take a minute.
##   [Note] Use 'showCompilerOutput = TRUE' to see C++ compilation details.
## running chain 1...
## |-------------|-------------|-------------|-------------|
## |-------------------------------------------------------|
## running chain 2...
## |-------------|-------------|-------------|-------------|
## |-------------------------------------------------------|
## running chain 3...
## |-------------|-------------|-------------|-------------|
## |-------------------------------------------------------|

結果

要約統計量

標準偏差のパラメータであるsigmaの要約統計量を表示します。

summay_stat <- summary(samp)
summay_stat |>
  filter(str_starts(variable, "sigma"))
## # A tibble: 3 × 10
##   variable  mean median     sd    mad     q5   q95  rhat ess_bulk ess_tail
##   <chr>    <num>  <num>  <num>  <num>  <num> <num> <num>    <num>    <num>
## 1 sigma[1] 0.283  0.254 0.195  0.203  0.0313 0.655  1.04     100.     86.8
## 2 sigma[2] 0.109  0.102 0.0353 0.0301 0.0653 0.176  1.00    1370.   1742. 
## 3 sigma[3] 0.712  0.710 0.133  0.125  0.501  0.932  1.01     557.    598.

軌跡のプロット

bayesplotパッケージのmcmc_trace関数でMCMCの軌跡を確認します。

mcmc_trace(samp, pars = c("sigma[1]", "sigma[2]", "sigma[3]"))

水準成分

状態の水準成分の推定値(事後中央値)と90%信用区間をプロットします。

mu <- summay_stat |>
  filter(str_starts(variable, "mu")) |>
  select(mean, median, q5, q95) |>
  slice_head(n = T_num) |>
  bind_cols(data)
ggplot(mu) +
  geom_ribbon(aes(x = Time, ymin = q5, ymax = q95),
              fill = "gray30", alpha = 0.7) +
  geom_line(aes(x = Time, y = median)) +
  labs(y = "Y")

季節調整成分

状態の季節調整成分の推定値(事後中央値)と90%信用区間をプロットします。

season <- summay_stat |>
  filter(str_starts(variable, "season")) |>
  select(mean, median, q5, q95) |>
  slice_head(n = T_num) |>
  bind_cols(data)
ggplot(season) +
  geom_ribbon(aes(x = Time, ymin = q5, ymax = q95),
              fill = "gray30", alpha = 0.7) +
  geom_line(aes(x = Time, y = median)) +
  labs(y = "Y")

予測

予測値の要約統計量を表示します。

summay_stat |>
  filter(str_starts(variable, "Y_new"))
## # A tibble: 8 × 10
##   variable  mean median    sd   mad    q5   q95  rhat ess_bulk ess_tail
##   <chr>    <num>  <num> <num> <num> <num> <num> <num>    <num>    <num>
## 1 Y_new[1]  24.7   24.7  1.32  1.33  22.5  26.8  1.00     802.    1601.
## 2 Y_new[2]  33.7   33.7  1.31  1.25  31.5  35.9  1.00    2075.    2667.
## 3 Y_new[3]  22.1   22.1  1.44  1.37  19.7  24.5  1.00    1675.    2629.
## 4 Y_new[4]  24.6   24.6  1.49  1.46  22.1  27.0  1.00    1888.    2343.
## 5 Y_new[5]  26.4   26.4  2.30  2.33  22.7  30.3  1.00     999.    1652.
## 6 Y_new[6]  35.4   35.4  2.27  2.22  31.8  39.1  1.00    1925.    2347.
## 7 Y_new[7]  23.9   23.8  2.51  2.56  19.8  28.0  1.00    1639.    2411.
## 8 Y_new[8]  26.3   26.3  2.67  2.58  22.0  30.6  1.00    1747.    2400.

最後に、元のデータと、8期先までの予測値(事後中央値)と90%信用区間をプロットします。

Y_new <- summay_stat |>
  filter(str_starts(variable, "Y_new")) |>
  select(mean, median, q5, q95) |>
  bind_cols(tibble(Time = T_num + 1:T_new))
ggplot(data) +
  geom_line(aes(x = Time, y = Y)) +
  geom_ribbon(data = Y_new, aes(x = Time, ymin = q5, ymax = q95),
              fill = "gray30", alpha = 0.7) +
  geom_line(data = Y_new, aes(x = Time, y = median))