library(extraDistr)Stanには、隠れマルコフモデルの潜在状態を周辺化してパラメータを推定するhmm_marginalという関数が用意されています。これをつかった例をしめします。
モデル
潜在状態、観測値とも3値をとるモデルを考えます。
潜在状態の推移確率は、1→1: 0.5, 1→2: 0.4, 1→3: 0.1, 2→1: 0.3, 2→2: 0.6, 2→3: 0.1, 3→1: 0.1, 3→2: 0.1, 3→3: 0.8とします。行列にまとめると以下のようになります。
\[ \left(\begin{array}{lll}0.5 & 0.4 & 0.1 \\ 0.3 & 0.6 & 0.1 \\ 0.1 & 0.1 & 0.8 \end{array}\right) \]
潜在状態から観測値への出力確率は、1→1: 0.9, 1→2: 0.09, 1→3: 0.01, 2→1: 0.1, 2→2: 0.8, 2→3: 0.1, 3→1: 0.05, 3→2: 0.05, 3→3: 0.9とします。行列にまとめると以下のようになります。
\[ \left(\begin{array}{lll}0.9 & 0.09 & 0.01 \\ 0.1 & 0.8 & 0.1 \\ 0.05 & 0.05 & 0.9 \end{array}\right) \]
模擬データの生成
カテゴリカル分布の乱数発生関数rcatのため、extraDistrパッケージを利用します。
上のモデルで模擬データを生成します。時点数Tを200、潜在状態をx、観測値をyとします。
set.seed(123)
transition <- matrix(c(0.5, 0.4, 0.1,
0.3, 0.6, 0.1,
0.1, 0.1, 0.8), ncol = 3, byrow = TRUE)
emission <- matrix(c(0.9, 0.09, 0.01,
0.1, 0.8, 0.1,
0.05, 0.05, 0.9), ncol = 3, byrow = TRUE)
T <- 200
## prepare data
x <- rep(0, T)
y <- rep(0, T)
### latent state
x[1] <- 1
for (t in 2:T) {
x[t] <- rcat(1, transition[x[t - 1], ])
}
### observation
for (t in 1:T) {
y[t] <- rcat(1, emission[x[t], ])
}Stanのモデル
CmdStanRをつかいます。
library(cmdstanr)
## This is cmdstanr version 0.8.1
## - CmdStanR documentation and vignettes: mc-stan.org/cmdstanr
## - CmdStan path: /usr/local/cmdstan
## - CmdStan version: 2.36.0
knitr::knit_engines$set(stan = cmdstanr::eng_cmdstan)モデルです。コンパイルしてmodelというオブジェクトに格納しておきます。
このモデルでは出力確率は既知として、遷移確率を推定します。
// hidden Morkov model using hmm_marginal function
data {
int<lower=0> T; // number of observations
array[T] int<lower=1, upper=3> Y; // observations
matrix[3, 3] p; // emission probability
}
transformed data {
matrix[3, T] log_omega; // log density for each output
for (t in 1:T)
log_omega[1:3, t] = log(p[, Y[t]]);
}
parameters {
array[3] simplex[3] Gamma_row; // transition matrix
simplex[3] rho; // initial state probability
}
transformed parameters {
matrix[3, 3] Gamma = [Gamma_row[1]', Gamma_row[2]', Gamma_row[3]'];
}
model {
target += hmm_marginal(log_omega, Gamma, rho);
}あてはめます。
fit <- model$sample(data = list(T = T, Y = y, p = emission),
iter_warmup = 1000, iter_sampling = 1000,
chains = 4, parallel_chains = 4,
refresh = 0)
## Running MCMC with 4 parallel chains...
##
## Chain 1 finished in 1.1 seconds.
## Chain 2 finished in 1.1 seconds.
## Chain 4 finished in 1.1 seconds.
## Chain 3 finished in 1.2 seconds.
##
## All 4 chains finished successfully.
## Mean chain execution time: 1.1 seconds.
## Total execution time: 1.2 seconds.推定された遷移確率の事後分布の要約です。だいたい元の値に近い値となっています。
fit$print("Gamma")
## variable mean median sd mad q5 q95 rhat ess_bulk ess_tail
## Gamma[1,1] 0.52 0.52 0.08 0.08 0.38 0.64 1.00 4812 3106
## Gamma[2,1] 0.31 0.31 0.07 0.07 0.20 0.44 1.00 5237 2943
## Gamma[3,1] 0.10 0.09 0.05 0.05 0.03 0.20 1.00 4140 1551
## Gamma[1,2] 0.37 0.36 0.08 0.08 0.24 0.50 1.00 5260 3083
## Gamma[2,2] 0.60 0.61 0.08 0.08 0.47 0.73 1.00 5985 3162
## Gamma[3,2] 0.15 0.14 0.06 0.06 0.05 0.26 1.00 4348 1839
## Gamma[1,3] 0.12 0.11 0.05 0.05 0.04 0.21 1.00 4764 2317
## Gamma[2,3] 0.08 0.08 0.04 0.04 0.02 0.17 1.00 3780 2400
## Gamma[3,3] 0.75 0.76 0.07 0.07 0.64 0.86 1.00 5582 2694推定された初期状態の事後分布の要約です。
fit$print("rho")
## variable mean median sd mad q5 q95 rhat ess_bulk ess_tail
## rho[1] 0.48 0.48 0.23 0.27 0.11 0.86 1.00 5455 2575
## rho[2] 0.26 0.22 0.21 0.21 0.02 0.67 1.00 4664 2464
## rho[3] 0.25 0.21 0.19 0.20 0.02 0.63 1.00 4619 2967NIMBLEのモデル
比較のため、 NIMBLEでも同じデータにあてはめてみます。
NIMBLEでは離散パラメータも推定できるので、潜在状態の推定もおこないつつ、遷移確率を推定します。
library(nimble)
## nimble version 1.3.0 is loaded.
## For more information on NIMBLE and a User Manual,
## please visit https://R-nimble.org.
##
## Note for advanced users who have written their own MCMC samplers:
## As of version 0.13.0, NIMBLE's protocol for handling posterior
## predictive nodes has changed in a way that could affect user-defined
## samplers in some situations. Please see Section 15.5.1 of the User Manual.
##
## 次のパッケージを付け加えます: 'nimble'
## 以下のオブジェクトは 'package:extraDistr' からマスクされています:
##
## dcat, dinvgamma, pinvgamma, qinvgamma, rcat, rinvgamma
## 以下のオブジェクトは 'package:stats' からマスクされています:
##
## simulate
## 以下のオブジェクトは 'package:base' からマスクされています:
##
## declare
hmm_code <- nimbleCode({
Gamma[1, 1:3] ~ ddirch(alpha0[])
Gamma[2, 1:3] ~ ddirch(alpha0[])
Gamma[3, 1:3] ~ ddirch(alpha0[])
x[1] ~ dcat(alpha0[])
for (t in 2:T) {
x[t] ~ dcat(Gamma[x[t - 1], 1:3])
}
for (t in 1:T) {
y[t] ~ dcat(p[x[t], 1:3])
}
})あてはめます。
fit <- nimbleMCMC(hmm_code,
constants = list(T = T,
alpha0 = rep(1/3, 3)),
data = list(y = y, p = emission),
inits = list(x = y,
Gamma = matrix(rep(1/3, 9), 3, 3)),
niter = 11000,
nburnin = 1000,
nchains = 3,
summary = TRUE)
## Defining model
## Building model
## Setting data and initial values
## Running calculate on model
## [Note] Any error reports that follow may simply reflect missing values in model variables.
## Checking model sizes and dimensions
## Checking model calculations
## Compiling
## [Note] This may take a minute.
## [Note] Use 'showCompilerOutput = TRUE' to see C++ compilation details.
## running chain 1...
## |-------------|-------------|-------------|-------------|
## |-------------------------------------------------------|
## running chain 2...
## |-------------|-------------|-------------|-------------|
## |-------------------------------------------------------|
## running chain 3...
## |-------------|-------------|-------------|-------------|
## |-------------------------------------------------------|結果です。だいたい同様の結果が得られました。
fit$summary$all.chains[1:9, ]
## Mean Median St.Dev. 95%CI_low 95%CI_upp
## Gamma[1, 1] 0.52582044 0.52702830 0.08534777 0.3547100134 0.6895483
## Gamma[2, 1] 0.31533949 0.31225447 0.07838847 0.1697561657 0.4765504
## Gamma[3, 1] 0.07948616 0.07270419 0.05498936 0.0003139441 0.2049384
## Gamma[1, 2] 0.37219084 0.36955026 0.08626943 0.2129336009 0.5492764
## Gamma[2, 2] 0.61597780 0.61819887 0.07992799 0.4537925402 0.7655351
## Gamma[3, 2] 0.13700910 0.12924816 0.06707030 0.0280685462 0.2861290
## Gamma[1, 3] 0.10198871 0.09718695 0.05386609 0.0086566529 0.2196535
## Gamma[2, 3] 0.06868271 0.06218870 0.04260513 0.0043005176 0.1683020
## Gamma[3, 3] 0.78350475 0.78861219 0.06665531 0.6403528769 0.9002323