Stanのhmm_marginalによる隠れマルコフモデルのあてはめ

R
Stan
NIMBLE
作者

伊東宏樹

公開

2023年6月17日

更新日

2023年6月17日

Stanには、隠れマルコフモデルの潜在状態を周辺化してパラメータを推定するhmm_marginalという関数が用意されています。これをつかった例をしめします。

モデル

潜在状態、観測値とも3値をとるモデルを考えます。

潜在状態の推移確率は、1→1: 0.5, 1→2: 0.4, 1→3: 0.1, 2→1: 0.3, 2→2: 0.6, 2→3: 0.1, 3→1: 0.1, 3→2: 0.1, 3→3: 0.8とします。行列にまとめると以下のようになります。

\[ \left(\begin{array}{lll}0.5 & 0.4 & 0.1 \\ 0.3 & 0.6 & 0.1 \\ 0.1 & 0.1 & 0.8 \end{array}\right) \]

潜在状態から観測値への出力確率は、1→1: 0.9, 1→2: 0.09, 1→3: 0.01, 2→1: 0.1, 2→2: 0.8, 2→3: 0.1, 3→1: 0.05, 3→2: 0.05, 3→3: 0.9とします。行列にまとめると以下のようになります。

\[ \left(\begin{array}{lll}0.9 & 0.09 & 0.01 \\ 0.1 & 0.8 & 0.1 \\ 0.05 & 0.05 & 0.9 \end{array}\right) \]

模擬データの生成

カテゴリカル分布の乱数発生関数rcatのため、extraDistrパッケージを利用します。

library(extraDistr)

上のモデルで模擬データを生成します。時点数Tを200、潜在状態をx、観測値をyとします。

set.seed(123)

transition <- matrix(c(0.5, 0.4, 0.1,
                       0.3, 0.6, 0.1,
                       0.1, 0.1, 0.8), ncol = 3, byrow = TRUE)
emission <-  matrix(c(0.9,  0.09, 0.01,
                      0.1,  0.8,  0.1,
                      0.05, 0.05, 0.9), ncol = 3, byrow = TRUE)
T <- 200

## prepare data
x <- rep(0, T)
y <- rep(0, T)

### latent state
x[1] <- 1
for (t in 2:T) {
  x[t] <- rcat(1, transition[x[t - 1], ])
}

### observation
for (t in 1:T) {
  y[t] <- rcat(1, emission[x[t], ])
}

Stanのモデル

CmdStanRをつかいます。

library(cmdstanr)
## This is cmdstanr version 0.5.3
## - CmdStanR documentation and vignettes: mc-stan.org/cmdstanr
## - CmdStan path: /usr/local/cmdstan
## - CmdStan version: 2.32.2
knitr::knit_engines$set(stan = cmdstanr::eng_cmdstan)

モデルです。コンパイルしてmodelというオブジェクトに格納しておきます。

このモデルでは出力確率は既知として、遷移確率を推定します。

// hidden Morkov model using hmm_marginal function

data {
  int<lower=0> T;                   // number of observations
  array[T] int<lower=1, upper=3> Y; // observations
  matrix[3, 3] p;                   // emission probability
}

transformed data {
  matrix[3, T] log_omega; // log density for each output

  for (t in 1:T)
    log_omega[1:3, t] = log(p[, Y[t]]);
}

parameters {
  array[3] simplex[3] Gamma_row; // transition matrix
  simplex[3] rho;                // initial state probability
}

transformed parameters {
  matrix[3, 3] Gamma = [Gamma_row[1]', Gamma_row[2]', Gamma_row[3]'];
}

model {
  target += hmm_marginal(log_omega, Gamma, rho);
}

あてはめます。

fit <- model$sample(data = list(T = T, Y = y, p = emission),
                    iter_warmup = 1000, iter_sampling = 1000,
                    chains = 4, parallel_chains = 4,
                    refresh = 0)
## Running MCMC with 4 parallel chains...
## 
## Chain 1 finished in 1.1 seconds.
## Chain 2 finished in 1.1 seconds.
## Chain 3 finished in 1.1 seconds.
## Chain 4 finished in 1.2 seconds.
## 
## All 4 chains finished successfully.
## Mean chain execution time: 1.1 seconds.
## Total execution time: 1.2 seconds.

推定された遷移確率の事後分布の要約です。だいたい元の値に近い値となっています。

fit$print("Gamma")
##    variable mean median   sd  mad   q5  q95 rhat ess_bulk ess_tail
##  Gamma[1,1] 0.52   0.52 0.08 0.08 0.38 0.65 1.00     5692     2978
##  Gamma[2,1] 0.31   0.31 0.07 0.07 0.20 0.44 1.00     5204     2984
##  Gamma[3,1] 0.10   0.10 0.05 0.05 0.03 0.20 1.00     3232     1543
##  Gamma[1,2] 0.37   0.36 0.08 0.08 0.24 0.50 1.00     5638     2939
##  Gamma[2,2] 0.60   0.61 0.08 0.08 0.47 0.73 1.00     5071     3004
##  Gamma[3,2] 0.15   0.14 0.06 0.06 0.05 0.26 1.00     3989     2516
##  Gamma[1,3] 0.12   0.11 0.05 0.05 0.04 0.21 1.00     4736     2350
##  Gamma[2,3] 0.08   0.08 0.04 0.04 0.02 0.16 1.00     4699     2718
##  Gamma[3,3] 0.75   0.76 0.07 0.07 0.64 0.85 1.00     5733     2810

推定された初期状態の事後分布の要約です。

fit$print("rho")
##  variable mean median   sd  mad   q5  q95 rhat ess_bulk ess_tail
##    rho[1] 0.48   0.47 0.23 0.26 0.11 0.85 1.00     5381     2245
##    rho[2] 0.27   0.22 0.20 0.21 0.02 0.66 1.00     4664     2548
##    rho[3] 0.26   0.22 0.20 0.21 0.02 0.66 1.00     5206     2994

NIMBLEのモデル

比較のため、 NIMBLEでも同じデータにあてはめてみます。

NIMBLEでは離散パラメータも推定できるので、潜在状態の推定もおこないつつ、遷移確率を推定します。

library(nimble)
## nimble version 1.0.0 is loaded.
## For more information on NIMBLE and a User Manual,
## please visit https://R-nimble.org.
## 
## Note for advanced users who have written their own MCMC samplers:
##   As of version 0.13.0, NIMBLE's protocol for handling posterior
##   predictive nodes has changed in a way that could affect user-defined
##   samplers in some situations. Please see Section 15.5.1 of the User Manual.
## 
##  次のパッケージを付け加えます: 'nimble'
##  以下のオブジェクトは 'package:extraDistr' からマスクされています:
## 
##     dcat, dinvgamma, pinvgamma, qinvgamma, rcat, rinvgamma
##  以下のオブジェクトは 'package:stats' からマスクされています:
## 
##     simulate

hmm_code <- nimbleCode({
  Gamma[1, 1:3] ~ ddirch(alpha0[])
  Gamma[2, 1:3] ~ ddirch(alpha0[])
  Gamma[3, 1:3] ~ ddirch(alpha0[])

  x[1] ~ dcat(alpha0[])
  for (t in 2:T) {
    x[t] ~ dcat(Gamma[x[t - 1], 1:3])
  }
  for (t in 1:T) {
    y[t] ~ dcat(p[x[t], 1:3])
  }
})

あてはめます。

fit <- nimbleMCMC(hmm_code,
                  constants = list(T = T,
                                   alpha0 = rep(1/3, 3)),
                  data = list(y = y, p = emission),
                  inits = list(x = y,
                               Gamma = matrix(rep(1/3, 9), 3, 3)),
                  niter = 11000,
                  nburnin = 1000,
                  nchains = 3,
                  setSeed = 1,
                  summary = TRUE)
## Defining model
## Building model
## Setting data and initial values
## Running calculate on model
##   [Note] Any error reports that follow may simply reflect missing values in model variables.
## Checking model sizes and dimensions
## Checking model calculations
## Compiling
##   [Note] This may take a minute.
##   [Note] Use 'showCompilerOutput = TRUE' to see C++ compilation details.
## running chain 1...
## |-------------|-------------|-------------|-------------|
## |-------------------------------------------------------|
## running chain 2...
## |-------------|-------------|-------------|-------------|
## |-------------------------------------------------------|
## running chain 3...
## |-------------|-------------|-------------|-------------|
## |-------------------------------------------------------|

結果です。だいたい同様の結果が得られました。

fit$summary$all.chains[1:9, ]
##                   Mean     Median    St.Dev.    95%CI_low 95%CI_upp
## Gamma[1, 1] 0.52578606 0.52641821 0.08486386 0.3547793026 0.6886325
## Gamma[2, 1] 0.31748336 0.31504616 0.07781868 0.1751151821 0.4743826
## Gamma[3, 1] 0.07920582 0.07237536 0.05552069 0.0003575595 0.2033248
## Gamma[1, 2] 0.37213848 0.36813942 0.08534688 0.2154313062 0.5511173
## Gamma[2, 2] 0.61454572 0.61658358 0.07916704 0.4545583877 0.7640889
## Gamma[3, 2] 0.13794169 0.13098895 0.06782486 0.0289910742 0.2886163
## Gamma[1, 3] 0.10207546 0.09823529 0.05434688 0.0057349910 0.2217010
## Gamma[2, 3] 0.06797092 0.06113455 0.04390016 0.0026524932 0.1686991
## Gamma[3, 3] 0.78285249 0.78806013 0.06676725 0.6425430657 0.9002796