library(extraDistr)
Stanには、隠れマルコフモデルの潜在状態を周辺化してパラメータを推定するhmm_marginalという関数が用意されています。これをつかった例をしめします。
モデル
潜在状態、観測値とも3値をとるモデルを考えます。
潜在状態の推移確率は、1→1: 0.5, 1→2: 0.4, 1→3: 0.1, 2→1: 0.3, 2→2: 0.6, 2→3: 0.1, 3→1: 0.1, 3→2: 0.1, 3→3: 0.8とします。行列にまとめると以下のようになります。
\[ \left(\begin{array}{lll}0.5 & 0.4 & 0.1 \\ 0.3 & 0.6 & 0.1 \\ 0.1 & 0.1 & 0.8 \end{array}\right) \]
潜在状態から観測値への出力確率は、1→1: 0.9, 1→2: 0.09, 1→3: 0.01, 2→1: 0.1, 2→2: 0.8, 2→3: 0.1, 3→1: 0.05, 3→2: 0.05, 3→3: 0.9とします。行列にまとめると以下のようになります。
\[ \left(\begin{array}{lll}0.9 & 0.09 & 0.01 \\ 0.1 & 0.8 & 0.1 \\ 0.05 & 0.05 & 0.9 \end{array}\right) \]
模擬データの生成
カテゴリカル分布の乱数発生関数rcat
のため、extraDistr
パッケージを利用します。
上のモデルで模擬データを生成します。時点数T
を200、潜在状態をx
、観測値をy
とします。
set.seed(123)
<- matrix(c(0.5, 0.4, 0.1,
transition 0.3, 0.6, 0.1,
0.1, 0.1, 0.8), ncol = 3, byrow = TRUE)
<- matrix(c(0.9, 0.09, 0.01,
emission 0.1, 0.8, 0.1,
0.05, 0.05, 0.9), ncol = 3, byrow = TRUE)
<- 200
T
## prepare data
<- rep(0, T)
x <- rep(0, T)
y
### latent state
1] <- 1
x[for (t in 2:T) {
<- rcat(1, transition[x[t - 1], ])
x[t]
}
### observation
for (t in 1:T) {
<- rcat(1, emission[x[t], ])
y[t] }
Stanのモデル
CmdStanRをつかいます。
library(cmdstanr)
## This is cmdstanr version 0.5.3
## - CmdStanR documentation and vignettes: mc-stan.org/cmdstanr
## - CmdStan path: /usr/local/cmdstan
## - CmdStan version: 2.32.2
::knit_engines$set(stan = cmdstanr::eng_cmdstan) knitr
モデルです。コンパイルしてmodel
というオブジェクトに格納しておきます。
このモデルでは出力確率は既知として、遷移確率を推定します。
// hidden Morkov model using hmm_marginal function
data {
int<lower=0> T; // number of observations
array[T] int<lower=1, upper=3> Y; // observations
matrix[3, 3] p; // emission probability
}
transformed data {
matrix[3, T] log_omega; // log density for each output
for (t in 1:T)
1:3, t] = log(p[, Y[t]]);
log_omega[
}
parameters {
array[3] simplex[3] Gamma_row; // transition matrix
simplex[3] rho; // initial state probability
}
transformed parameters {
matrix[3, 3] Gamma = [Gamma_row[1]', Gamma_row[2]', Gamma_row[3]'];
}
model {
target += hmm_marginal(log_omega, Gamma, rho);
}
あてはめます。
<- model$sample(data = list(T = T, Y = y, p = emission),
fit iter_warmup = 1000, iter_sampling = 1000,
chains = 4, parallel_chains = 4,
refresh = 0)
## Running MCMC with 4 parallel chains...
##
## Chain 1 finished in 1.1 seconds.
## Chain 2 finished in 1.1 seconds.
## Chain 3 finished in 1.1 seconds.
## Chain 4 finished in 1.2 seconds.
##
## All 4 chains finished successfully.
## Mean chain execution time: 1.1 seconds.
## Total execution time: 1.2 seconds.
推定された遷移確率の事後分布の要約です。だいたい元の値に近い値となっています。
$print("Gamma")
fit## variable mean median sd mad q5 q95 rhat ess_bulk ess_tail
## Gamma[1,1] 0.52 0.52 0.08 0.08 0.38 0.65 1.00 5692 2978
## Gamma[2,1] 0.31 0.31 0.07 0.07 0.20 0.44 1.00 5204 2984
## Gamma[3,1] 0.10 0.10 0.05 0.05 0.03 0.20 1.00 3232 1543
## Gamma[1,2] 0.37 0.36 0.08 0.08 0.24 0.50 1.00 5638 2939
## Gamma[2,2] 0.60 0.61 0.08 0.08 0.47 0.73 1.00 5071 3004
## Gamma[3,2] 0.15 0.14 0.06 0.06 0.05 0.26 1.00 3989 2516
## Gamma[1,3] 0.12 0.11 0.05 0.05 0.04 0.21 1.00 4736 2350
## Gamma[2,3] 0.08 0.08 0.04 0.04 0.02 0.16 1.00 4699 2718
## Gamma[3,3] 0.75 0.76 0.07 0.07 0.64 0.85 1.00 5733 2810
推定された初期状態の事後分布の要約です。
$print("rho")
fit## variable mean median sd mad q5 q95 rhat ess_bulk ess_tail
## rho[1] 0.48 0.47 0.23 0.26 0.11 0.85 1.00 5381 2245
## rho[2] 0.27 0.22 0.20 0.21 0.02 0.66 1.00 4664 2548
## rho[3] 0.26 0.22 0.20 0.21 0.02 0.66 1.00 5206 2994
NIMBLEのモデル
比較のため、 NIMBLEでも同じデータにあてはめてみます。
NIMBLEでは離散パラメータも推定できるので、潜在状態の推定もおこないつつ、遷移確率を推定します。
library(nimble)
## nimble version 1.0.0 is loaded.
## For more information on NIMBLE and a User Manual,
## please visit https://R-nimble.org.
##
## Note for advanced users who have written their own MCMC samplers:
## As of version 0.13.0, NIMBLE's protocol for handling posterior
## predictive nodes has changed in a way that could affect user-defined
## samplers in some situations. Please see Section 15.5.1 of the User Manual.
##
## 次のパッケージを付け加えます: 'nimble'
## 以下のオブジェクトは 'package:extraDistr' からマスクされています:
##
## dcat, dinvgamma, pinvgamma, qinvgamma, rcat, rinvgamma
## 以下のオブジェクトは 'package:stats' からマスクされています:
##
## simulate
<- nimbleCode({
hmm_code 1, 1:3] ~ ddirch(alpha0[])
Gamma[2, 1:3] ~ ddirch(alpha0[])
Gamma[3, 1:3] ~ ddirch(alpha0[])
Gamma[
1] ~ dcat(alpha0[])
x[for (t in 2:T) {
~ dcat(Gamma[x[t - 1], 1:3])
x[t]
}for (t in 1:T) {
~ dcat(p[x[t], 1:3])
y[t]
} })
あてはめます。
<- nimbleMCMC(hmm_code,
fit constants = list(T = T,
alpha0 = rep(1/3, 3)),
data = list(y = y, p = emission),
inits = list(x = y,
Gamma = matrix(rep(1/3, 9), 3, 3)),
niter = 11000,
nburnin = 1000,
nchains = 3,
setSeed = 1,
summary = TRUE)
## Defining model
## Building model
## Setting data and initial values
## Running calculate on model
## [Note] Any error reports that follow may simply reflect missing values in model variables.
## Checking model sizes and dimensions
## Checking model calculations
## Compiling
## [Note] This may take a minute.
## [Note] Use 'showCompilerOutput = TRUE' to see C++ compilation details.
## running chain 1...
## |-------------|-------------|-------------|-------------|
## |-------------------------------------------------------|
## running chain 2...
## |-------------|-------------|-------------|-------------|
## |-------------------------------------------------------|
## running chain 3...
## |-------------|-------------|-------------|-------------|
## |-------------------------------------------------------|
結果です。だいたい同様の結果が得られました。
$summary$all.chains[1:9, ]
fit## Mean Median St.Dev. 95%CI_low 95%CI_upp
## Gamma[1, 1] 0.52578606 0.52641821 0.08486386 0.3547793026 0.6886325
## Gamma[2, 1] 0.31748336 0.31504616 0.07781868 0.1751151821 0.4743826
## Gamma[3, 1] 0.07920582 0.07237536 0.05552069 0.0003575595 0.2033248
## Gamma[1, 2] 0.37213848 0.36813942 0.08534688 0.2154313062 0.5511173
## Gamma[2, 2] 0.61454572 0.61658358 0.07916704 0.4545583877 0.7640889
## Gamma[3, 2] 0.13794169 0.13098895 0.06782486 0.0289910742 0.2886163
## Gamma[1, 3] 0.10207546 0.09823529 0.05434688 0.0057349910 0.2217010
## Gamma[2, 3] 0.06797092 0.06113455 0.04390016 0.0026524932 0.1686991
## Gamma[3, 3] 0.78285249 0.78806013 0.06676725 0.6425430657 0.9002796