Stanのhmm_marginalによる隠れマルコフモデルのあてはめ

R
Stan
NIMBLE
作者

伊東宏樹

公開

2023年6月17日

更新日

2023年6月17日

Stanには、隠れマルコフモデルの潜在状態を周辺化してパラメータを推定するhmm_marginalという関数が用意されています。これをつかった例をしめします。

モデル

潜在状態、観測値とも3値をとるモデルを考えます。

潜在状態の推移確率は、1→1: 0.5, 1→2: 0.4, 1→3: 0.1, 2→1: 0.3, 2→2: 0.6, 2→3: 0.1, 3→1: 0.1, 3→2: 0.1, 3→3: 0.8とします。行列にまとめると以下のようになります。

\[ \left(\begin{array}{lll}0.5 & 0.4 & 0.1 \\ 0.3 & 0.6 & 0.1 \\ 0.1 & 0.1 & 0.8 \end{array}\right) \]

潜在状態から観測値への出力確率は、1→1: 0.9, 1→2: 0.09, 1→3: 0.01, 2→1: 0.1, 2→2: 0.8, 2→3: 0.1, 3→1: 0.05, 3→2: 0.05, 3→3: 0.9とします。行列にまとめると以下のようになります。

\[ \left(\begin{array}{lll}0.9 & 0.09 & 0.01 \\ 0.1 & 0.8 & 0.1 \\ 0.05 & 0.05 & 0.9 \end{array}\right) \]

模擬データの生成

カテゴリカル分布の乱数発生関数rcatのため、extraDistrパッケージを利用します。

library(extraDistr)

上のモデルで模擬データを生成します。時点数Tを200、潜在状態をx、観測値をyとします。

set.seed(123)

transition <- matrix(c(0.5, 0.4, 0.1,
                       0.3, 0.6, 0.1,
                       0.1, 0.1, 0.8), ncol = 3, byrow = TRUE)
emission <-  matrix(c(0.9,  0.09, 0.01,
                      0.1,  0.8,  0.1,
                      0.05, 0.05, 0.9), ncol = 3, byrow = TRUE)
T <- 200

## prepare data
x <- rep(0, T)
y <- rep(0, T)

### latent state
x[1] <- 1
for (t in 2:T) {
  x[t] <- rcat(1, transition[x[t - 1], ])
}

### observation
for (t in 1:T) {
  y[t] <- rcat(1, emission[x[t], ])
}

Stanのモデル

CmdStanRをつかいます。

library(cmdstanr)
## This is cmdstanr version 0.8.1
## - CmdStanR documentation and vignettes: mc-stan.org/cmdstanr
## - CmdStan path: /usr/local/cmdstan
## - CmdStan version: 2.36.0
knitr::knit_engines$set(stan = cmdstanr::eng_cmdstan)

モデルです。コンパイルしてmodelというオブジェクトに格納しておきます。

このモデルでは出力確率は既知として、遷移確率を推定します。

// hidden Morkov model using hmm_marginal function

data {
  int<lower=0> T;                   // number of observations
  array[T] int<lower=1, upper=3> Y; // observations
  matrix[3, 3] p;                   // emission probability
}

transformed data {
  matrix[3, T] log_omega; // log density for each output

  for (t in 1:T)
    log_omega[1:3, t] = log(p[, Y[t]]);
}

parameters {
  array[3] simplex[3] Gamma_row; // transition matrix
  simplex[3] rho;                // initial state probability
}

transformed parameters {
  matrix[3, 3] Gamma = [Gamma_row[1]', Gamma_row[2]', Gamma_row[3]'];
}

model {
  target += hmm_marginal(log_omega, Gamma, rho);
}

あてはめます。

fit <- model$sample(data = list(T = T, Y = y, p = emission),
                    iter_warmup = 1000, iter_sampling = 1000,
                    chains = 4, parallel_chains = 4,
                    refresh = 0)
## Running MCMC with 4 parallel chains...
## 
## Chain 1 finished in 1.1 seconds.
## Chain 2 finished in 1.1 seconds.
## Chain 4 finished in 1.1 seconds.
## Chain 3 finished in 1.2 seconds.
## 
## All 4 chains finished successfully.
## Mean chain execution time: 1.1 seconds.
## Total execution time: 1.2 seconds.

推定された遷移確率の事後分布の要約です。だいたい元の値に近い値となっています。

fit$print("Gamma")
##    variable mean median   sd  mad   q5  q95 rhat ess_bulk ess_tail
##  Gamma[1,1] 0.52   0.52 0.08 0.08 0.38 0.64 1.00     4812     3106
##  Gamma[2,1] 0.31   0.31 0.07 0.07 0.20 0.44 1.00     5237     2943
##  Gamma[3,1] 0.10   0.09 0.05 0.05 0.03 0.20 1.00     4140     1551
##  Gamma[1,2] 0.37   0.36 0.08 0.08 0.24 0.50 1.00     5260     3083
##  Gamma[2,2] 0.60   0.61 0.08 0.08 0.47 0.73 1.00     5985     3162
##  Gamma[3,2] 0.15   0.14 0.06 0.06 0.05 0.26 1.00     4348     1839
##  Gamma[1,3] 0.12   0.11 0.05 0.05 0.04 0.21 1.00     4764     2317
##  Gamma[2,3] 0.08   0.08 0.04 0.04 0.02 0.17 1.00     3780     2400
##  Gamma[3,3] 0.75   0.76 0.07 0.07 0.64 0.86 1.00     5582     2694

推定された初期状態の事後分布の要約です。

fit$print("rho")
##  variable mean median   sd  mad   q5  q95 rhat ess_bulk ess_tail
##    rho[1] 0.48   0.48 0.23 0.27 0.11 0.86 1.00     5455     2575
##    rho[2] 0.26   0.22 0.21 0.21 0.02 0.67 1.00     4664     2464
##    rho[3] 0.25   0.21 0.19 0.20 0.02 0.63 1.00     4619     2967

NIMBLEのモデル

比較のため、 NIMBLEでも同じデータにあてはめてみます。

NIMBLEでは離散パラメータも推定できるので、潜在状態の推定もおこないつつ、遷移確率を推定します。

library(nimble)
## nimble version 1.3.0 is loaded.
## For more information on NIMBLE and a User Manual,
## please visit https://R-nimble.org.
## 
## Note for advanced users who have written their own MCMC samplers:
##   As of version 0.13.0, NIMBLE's protocol for handling posterior
##   predictive nodes has changed in a way that could affect user-defined
##   samplers in some situations. Please see Section 15.5.1 of the User Manual.
## 
## 次のパッケージを付け加えます: 'nimble'
## 以下のオブジェクトは 'package:extraDistr' からマスクされています:
## 
##     dcat, dinvgamma, pinvgamma, qinvgamma, rcat, rinvgamma
## 以下のオブジェクトは 'package:stats' からマスクされています:
## 
##     simulate
## 以下のオブジェクトは 'package:base' からマスクされています:
## 
##     declare

hmm_code <- nimbleCode({
  Gamma[1, 1:3] ~ ddirch(alpha0[])
  Gamma[2, 1:3] ~ ddirch(alpha0[])
  Gamma[3, 1:3] ~ ddirch(alpha0[])

  x[1] ~ dcat(alpha0[])
  for (t in 2:T) {
    x[t] ~ dcat(Gamma[x[t - 1], 1:3])
  }
  for (t in 1:T) {
    y[t] ~ dcat(p[x[t], 1:3])
  }
})

あてはめます。

fit <- nimbleMCMC(hmm_code,
                  constants = list(T = T,
                                   alpha0 = rep(1/3, 3)),
                  data = list(y = y, p = emission),
                  inits = list(x = y,
                               Gamma = matrix(rep(1/3, 9), 3, 3)),
                  niter = 11000,
                  nburnin = 1000,
                  nchains = 3,
                  summary = TRUE)
## Defining model
## Building model
## Setting data and initial values
## Running calculate on model
##   [Note] Any error reports that follow may simply reflect missing values in model variables.
## Checking model sizes and dimensions
## Checking model calculations
## Compiling
##   [Note] This may take a minute.
##   [Note] Use 'showCompilerOutput = TRUE' to see C++ compilation details.
## running chain 1...
## |-------------|-------------|-------------|-------------|
## |-------------------------------------------------------|
## running chain 2...
## |-------------|-------------|-------------|-------------|
## |-------------------------------------------------------|
## running chain 3...
## |-------------|-------------|-------------|-------------|
## |-------------------------------------------------------|

結果です。だいたい同様の結果が得られました。

fit$summary$all.chains[1:9, ]
##                   Mean     Median    St.Dev.    95%CI_low 95%CI_upp
## Gamma[1, 1] 0.52582044 0.52702830 0.08534777 0.3547100134 0.6895483
## Gamma[2, 1] 0.31533949 0.31225447 0.07838847 0.1697561657 0.4765504
## Gamma[3, 1] 0.07948616 0.07270419 0.05498936 0.0003139441 0.2049384
## Gamma[1, 2] 0.37219084 0.36955026 0.08626943 0.2129336009 0.5492764
## Gamma[2, 2] 0.61597780 0.61819887 0.07992799 0.4537925402 0.7655351
## Gamma[3, 2] 0.13700910 0.12924816 0.06707030 0.0280685462 0.2861290
## Gamma[1, 3] 0.10198871 0.09718695 0.05386609 0.0086566529 0.2196535
## Gamma[2, 3] 0.06868271 0.06218870 0.04260513 0.0043005176 0.1683020
## Gamma[3, 3] 0.78350475 0.78861219 0.06665531 0.6403528769 0.9002323