




The BUGS Book 11.8節 Bayesian nonparametric models の Example 11.8.1 Galaxy clustering: Dirichlet process mixture models にあるモデルをNIMBLEに実装してみました。




The BUGS Bookのサポートページで配布されているExample-11.8.1-galaxy.odcを、odcreadでテキストに変換して取り出しました。なお、odcreadはてもとのmacOS環境ではそのままではコンパイルが完了しなかったので、Makefileを修正する必要がありました。

データは、velocityが銀河の速度(×1000 km/s)、dens.xが速度の分布範囲、ndensがdens.xの要素数、Cがクラスター数の最大、nがvelocityの要素数です(でいいかな)。

# Observed galaxy velocities (already scaled down by a factor of 1000).
velocity <- c( 9.172,  9.35,   9.483,  9.558,  9.775,
              10.227, 10.406, 16.084, 16.17,  18.419,
              18.552, 18.6,   18.927, 19.052, 19.07,
              19.33,  19.343, 19.349, 19.44,  19.473,
              19.529, 19.541, 19.547, 19.663, 19.846,
              19.856, 19.863, 19.914, 19.918, 19.973,
              19.989, 20.166, 20.175, 20.179, 20.196,
              20.215, 20.221, 20.415, 20.629, 20.795,
              20.821, 20.846, 20.875, 20.986, 21.137,
              21.492, 21.701, 21.814, 21.921, 21.96,
              22.185, 22.209, 22.242, 22.249, 22.314,
              22.374, 22.495, 22.746, 22.747, 22.888,
              22.914, 23.206, 23.241, 23.263, 23.484,
              23.538, 23.542, 23.666, 23.706, 23.711,
              24.129, 24.285, 24.289, 24.368, 24.717,
              24.99,  25.633, 26.96,  26.995, 32.065,
              32.789, 34.279)
# Grid of velocity values at which the mixture density is evaluated.
dens.x <- seq(8, 35, 0.1)
# Sizes are derived from the vectors themselves instead of being hard-coded,
# so they cannot drift out of sync if the data or the grid is edited.
ndens <- length(dens.x)
# Maximum number of mixture components (truncation level of the mixture).
C <- 20
n <- length(velocity)


NIMBLEのコードです。もとのBUGSコードはなぜかサンプリングがはじまらなかったのですが、NIMBLEマニュアルの10.3節 Stick-breaking modelを参考に、stick_breaking関数を使うように修正したら、うまく動作しました。

# Truncated stick-breaking mixture of normals with at most C components.
# All index ranges are written explicitly (e.g. pi[1:C] rather than pi[]) so
# the model does not rely on dimensions being inferred from initial values.
code <- nimbleCode({
  # Likelihood: each observation takes the mean and precision of the
  # component it is currently allocated to.
  for (i in 1:n) {
    velocity[i] ~ dnorm(mu[i], tau[i])
    mu[i] <- mu.mix[group[i]]
    tau[i] <- tau.mix[group[i]]
    group[i] ~ dcat(pi[1:C])
    for (j in 1:C) {
      # Indicator: 1 if observation i is allocated to component j, else 0.
      gind[i, j] <- equals(j, group[i])
    }
  }

  # Mixture weights built from C - 1 Beta(1, alpha) stick-breaking fractions.
  for (j in 1:(C - 1)) {
    q[j] ~ dbeta(1, alpha)
  }
  pi[1:C] <- stick_breaking(q[1:(C - 1)])

  # Component-specific means and precisions; the prior precision of each
  # mean is proportional to that component's own precision.
  for (j in 1:C) {
    mu.mix[j] ~ dnorm(amu, mu.prec[j])
    mu.prec[j] <- bmu * tau.mix[j]
    tau.mix[j] ~ dgamma(aprec, bprec)
  }

  # Fixed constants and hyperpriors.
  alpha <- 1
  amu ~ dnorm(0, 0.001)
  bmu ~ dgamma(0.5, 50)
  aprec <- 2
  bprec ~ dgamma(2, 1)

  # K counts the components that have at least one observation allocated.
  K <- sum(cl[1:C])
  for (j in 1:C) {
    sumind[j] <- sum(gind[1:n, j])
    cl[j] <- step(sumind[j] - 1 + 0.001) # cluster j used in
                                         # this iteration
  }

  # Mixture density on the grid dens.x: weighted normal density of each
  # component at each grid point, summed over components.
  for (j in 1:ndens) {
    for (i in 1:C) {
      dens.cpt[i, j] <- pi[i] * sqrt(tau.mix[i] / (2 * 3.141592654)) *
                        exp(-0.5 * tau.mix[i] * (mu.mix[i] - dens.x[j])
                                              * (mu.mix[i] - dens.x[j]))
    }
    dens[j] <- sum(dens.cpt[1:C, j])
  }
})


モデルをあてはめます。JAGSでThe BUGS BookのBUGSコードをあてはめるときにはとくに初期値を設定する必要はなかったのですが、NIMBLEではわりと細かく設定しないとうまく動作しません。


# Fit the model: 3 chains of 21000 iterations, discarding the first 1000 and
# keeping every 10th draw. Several deterministic nodes (gind, dens.cpt, cl,
# pi) are given initial values as well as the stochastic ones.
samp <- nimbleMCMC(
  code,
  constants = list(C = C, n = n, ndens = ndens),
  data = list(velocity = velocity, dens.x = dens.x),
  inits = list(
    amu = 0,
    bmu = 1,
    group = rep(1, n),
    gind = matrix(0, ncol = C, nrow = n),
    dens.cpt = matrix(1.0e-4, ncol = ndens, nrow = C),
    cl = rep(1, C),
    # The model defines q[1:(C - 1)], so the initial vector has C - 1 elements.
    q = rep(0.5, C - 1),
    pi = rep(1 / C, C)
  ),
  monitors = c("K", "amu", "bmu", "cl", "dens",
               "mu.mix", "pi", "tau.mix"),
  niter = 21000,
  nburnin = 1000,
  thin = 10,
  nchains = 3,
  samplesAsCodaMCMC = TRUE
) |>
  # The final pipeline stage was missing after the trailing `|>`. The summary
  # table used below (variable / rhat / ess_bulk columns) needs a posterior
  # draws object. NOTE(review): confirm this matches the original post.
  posterior::as_draws()
## Defining model
## Building model
## Setting data and initial values
## Running calculate on model
##   [Note] Any error reports that follow may simply reflect missing values in model variables.
## Checking model sizes and dimensions
##   [Note] This model is not fully initialized. This is not an error.
##          To see which variables are not initialized, use model$initializeInfo().
##          For more information on model initialization, see help(modelInitialization).
## Checking model calculations
## [Note] NAs were detected in model variables: bprec, logProb_bprec, lifted_d1_over_bprec, tau.mix, logProb_tau.mix, mu.prec, lifted_d1_over_sqrt_oPmu_dot_prec_oBj_cB_cP_L12, tau, mu.mix, logProb_mu.mix, lifted_d1_over_sqrt_oPtau_oBi_cB_cP_L2, dens.cpt, dens, mu, logProb_velocity.
## Compiling
##   [Note] This may take a minute.
##   [Note] Use 'showCompilerOutput = TRUE' to see C++ compilation details.
## running chain 1...
## |-------------|-------------|-------------|-------------|
## |-------------------------------------------------------|
## running chain 2...
## |-------------|-------------|-------------|-------------|
## |-------------------------------------------------------|
## running chain 3...
## |-------------|-------------|-------------|-------------|
## |-------------------------------------------------------|
# Summarise the posterior draws. The table printed further down has one row
# per monitored variable, with columns such as mean, q5, q95 and rhat.
summary <- summary(samp)




# Show the variables with the largest rhat first. The sort step was missing
# after the trailing `|>`; the printed table is ordered by decreasing rhat.
summary |>
  dplyr::arrange(dplyr::desc(rhat))
## # A tibble: 354 × 10
##    variable    mean  median    sd     mad      q5    q95  rhat ess_bulk ess_tail
##    <chr>      <num>   <num> <num>   <num>   <num>  <num> <num>    <num>    <num>
##  1 mu.mix[… 19.8    20.0    5.90   3.91   9.59e+0 27.6    1.05     98.0    240. 
##  2 pi[3]     0.157   0.0927 0.150  0.0919 1.18e-2  0.449  1.04    267.     689. 
##  3 mu.mix[… 21.2    21.6    4.21   2.38   1.01e+1 26.9    1.04     60.8     60.3
##  4 tau.mix…  1.61    1.37   1.26   1.13   2.42e-1  3.85   1.03    133.     744. 
##  5 mu.mix[… 19.6    19.9    7.97   6.23   9.43e+0 33.2    1.03     89.0    733. 
##  6 tau.mix…  1.29    0.988  1.10   0.998  1.95e-1  3.29   1.03    116.    1387. 
##  7 pi[2]     0.267   0.277  0.180  0.229  2.54e-2  0.548  1.03    158.     356. 
##  8 pi[1]     0.355   0.372  0.202  0.181  3.32e-2  0.710  1.02    199.     313. 
##  9 mu.mix[… 20.5    20.1    9.15  10.1    9.43e+0 33.5    1.02    123.    1238. 
## 10 pi[4]     0.0943  0.0558 0.107  0.0513 5.74e-3  0.363  1.01    297.     478. 
## # ℹ 344 more rows




# Trace plots for two hyperparameters and the first two mixture weights,
# as a visual check of how the chains behave.
mcmc_trace(samp, pars = c("amu", "bmu", "pi[1]", "pi[2]"))



# Posterior summary of K, the number of components with at least one
# observation allocated to them.
dplyr::filter(summary, variable == "K")
## # A tibble: 1 × 10
##   variable  mean median    sd   mad    q5   q95  rhat ess_bulk ess_tail
##   <chr>    <num>  <num> <num> <num> <num> <num> <num>    <num>    <num>
## 1 K         6.82      7  1.59  1.48     4    10  1.00     507.     631.


# Posterior summaries of the mixture weights pi[1], ..., pi[C]. The component
# number is parsed out of the variable name and stored as a factor so that
# it is plotted on a discrete axis.
pi_est <- dplyr::mutate(
  dplyr::filter(summary, str_starts(variable, "pi")),
  index = as.factor(as.numeric(str_extract(variable, "[0-9]+")))
)
# Points mark the posterior mean; vertical bars span q5 to q95.
ggplot(pi_est, aes(x = index)) +
  geom_point(aes(y = mean)) +
  geom_segment(aes(xend = index, y = q5, yend = q95)) +
  labs(x = "Component", y = "Probability")


# Posterior summaries of the mixture density dens[1], ..., dens[ndens].
# The rows are paired with the grid dens.x by position, which relies on the
# summary keeping dens[j] in index order.
dens_est <- summary |>
  dplyr::filter(str_starts(variable, "dens")) |>
  dplyr::mutate(x = dens.x)
# Histogram of the data on the density scale, overlaid with the posterior
# mean density (line) and the q5-q95 band (ribbon).
ggplot() +
  geom_histogram(data = data.frame(velocity = velocity),
                 mapping = aes(x = velocity, y = after_stat(density)),
                 bins = 30, fill = "gray") +
  geom_line(data = dens_est,
            mapping = aes(x = x, y = mean),
            colour = "red", linewidth = 1) +
  geom_ribbon(data = dens_est,
              mapping = aes(x = x, ymin = q5, ymax = q95),
              fill = "red", alpha = 0.3) +
  # Axis unit corrected from km/h: the galaxy velocities are in km/s.
  labs(x = "Galaxy velocity (x 1000 km/s)",
       y = "Probability density")

いずれも、The BUGS Bookにある結果(FIGURE 11.11)を再現できているようです。