library(nimble)
## nimble version 1.0.1 is loaded.
## For more information on NIMBLE and a User Manual,
## please visit https://R-nimble.org.
##
## Note for advanced users who have written their own MCMC samplers:
## As of version 0.13.0, NIMBLE's protocol for handling posterior
## predictive nodes has changed in a way that could affect user-defined
## samplers in some situations. Please see Section 15.5.1 of the User Manual.
##
## 次のパッケージを付け加えます: 'nimble'
## 以下のオブジェクトは 'package:stats' からマスクされています:
##
## simulate
library(ggplot2)
library(posterior)
## This is posterior version 1.4.1
##
## 次のパッケージを付け加えます: 'posterior'
## 以下のオブジェクトは 'package:stats' からマスクされています:
##
## mad, sd, var
## 以下のオブジェクトは 'package:base' からマスクされています:
##
## %in%, match
set.seed(1234)
<- 30
N <- c(1, 6, 12)
lambda <- rpois(N * 3, rep(lambda, each = N)) Y
複数の個体群が混ざっている個体数データを想定して、平均が異なるポアソン分布が混合しているデータをディリクレ過程混合モデルでクラスター分けすることを考えてみます。
データ
ここでは、平均が1, 6, 12のポアソン分布にしたがう計測値が混ざっている個体数データを想定して、データを生成します。
以下のようなデータになりました。
ggplot(data.frame(Y = Y)) +
geom_bar(aes(x = Y))
NIMBLEモデル
NIMBLEによるモデルコードは以下のようになります。stick_breaking
関数で折れ棒過程を使っています。
クラスタの数をK
に格納するようにしています。
<- nimbleCode({
code for(n in 1:N) {
~ dcat(pi[])
z[n] ~ dpois(exp(phi[z[n]]))
Y[n]
}for (m in 1:(M - 1)) {
~ dbeta(1, alpha)
q[m]
}1:M] <- stick_breaking(q[1:(M - 1)])
pi[~ dgamma(1, 1)
alpha for(m in 1:M) {
~ dnorm(0, sd = 10)
phi[m]
}for (n in 1:N) {
for (m in 1:M) {
<- equals(m, z[n])
zind[n, m]
}
}for (m in 1:M) {
<- sum(zind[, m])
sumind[m] <- step(sumind[m] - 1 + 0.001)
cl[m]
}<- sum(cl[])
K })
実行
定数、データ、初期値を定義して、サンプリングをおこないます。最後にposteriorパッケージのdraws
型のオブジェクトにしています。
<- list(N = N * 3, M = 20)
constants <- list(Y = Y)
data <- list(z = sample(1:5, size = constants$N, replace = TRUE),
inits zind = matrix(0, ncol = constants$M, nrow = constants$N),
cl = rep(1, constants$M),
q = runif(constants$M, 0, 1),
pi = rep(1 / constants$M, constants$M),
phi = runif(constants$M, -2, 2),
alpha = 1)
<- nimbleMCMC(code, constants = constants,
samp data = data, inits = inits,
monitors = c("phi", "alpha", "pi", "K"),
nchains = 3,
niter = 10000, nburnin = 2000,
samplesAsCodaMCMC = TRUE) |>
as_draws()
## Defining model
## Building model
## Setting data and initial values
## Running calculate on model
## [Note] Any error reports that follow may simply reflect missing values in model variables.
## Checking model sizes and dimensions
## Checking model calculations
## Compiling
## [Note] This may take a minute.
## [Note] Use 'showCompilerOutput = TRUE' to see C++ compilation details.
## running chain 1...
## |-------------|-------------|-------------|-------------|
## |-------------------------------------------------------|
## running chain 2...
## |-------------|-------------|-------------|-------------|
## |-------------------------------------------------------|
## running chain 3...
## |-------------|-------------|-------------|-------------|
## |-------------------------------------------------------|
結果
結果です。まずはalpha
の事後分布の要約です。
<- samp |>
summ summary()
|>
summ ::filter(variable == "alpha")
dplyr## # A tibble: 1 × 10
## variable mean median sd mad q5 q95 rhat ess_bulk ess_tail
## <chr> <num> <num> <num> <num> <num> <num> <num> <num> <num>
## 1 alpha 1.01 0.875 0.613 0.537 0.270 2.22 1.03 236. 603.
クラスタ数K
の事後分布の要約です。
|>
summ ::filter(startsWith(variable, "K"))
dplyr## # A tibble: 1 × 10
## variable mean median sd mad q5 q95 rhat ess_bulk ess_tail
## <chr> <num> <num> <num> <num> <num> <num> <num> <num> <num>
## 1 K 4.95 5 1.63 1.48 3 8 1.04 96.9 558.
中央値は5、90%信用区間は3–8と推定されました。
カテゴリカル分布のパラメーターpi
の事後分布の要約を図示します。
<- summ |>
pi_est ::filter(startsWith(variable, "pi")) |>
dplyr::mutate(index = stringr::str_extract(variable, "[0-9]+") |>
dplyras.numeric() |> as.factor())
ggplot(pi_est) +
geom_point(aes(x = index, y = mean)) +
geom_segment(aes(x = index, xend = index, y = q5, yend = q95)) +
labs(x = "Component", y = "Probability")
真のカテゴリー数は3ですが、そのあたりまでの確率が高くなっています。
各カテゴリーの平均の対数phi
は、ラベルスイッチングが起きるので、そのままでは多峰になってしまいます。Stanだと、ordered_vector
型で解決できそうなのですが、NIMBLEにはありません。dconstraintで、大小関係等の制約をかけることはできるのですが、この場合は引数の数が不定ですので、使えません。このあたりはもう少し考えてみます。
【追記】Stanでもやってみましたが、単純にordered vectorを使えばよいというものでもありませんでした。