library(TMB)
library(ggplot2)
set.seed(123)TMBを混合効果モデルを書いてみました。まだ不慣れなので、いろいろ試行錯誤しましたが、なんとかできました。
準備
前回と同様、TMBとggplot2を読み込みます。擬似乱数も固定します。
データ
ゼロ過剰ポアソン分布にしたがう模擬データを生成します。今回は群ごとの変量効果も加えています。
N <- 400
N_group <- 10
group <- rep(1:N_group, each = N / N_group)
p <- 0.3
alpha <- -1
beta <- 0.8
log_sigma <- -1.6 # sigma = 0.202
ranef <- rnorm(N_group, 0, exp(log_sigma))
x <- runif(N, 0, 5)
y <- rbinom(N, 1, 1 - p) *
rpois(N, exp(alpha + beta * x + ranef[group]))
example <- data.frame(x = x, y = y, group = group)データを確認します。
ggplot(example, aes(x = x, y = y)) +
geom_point()
モデル
TMBのC++のコードです。"models/zipmix.cpp"に保存しておきます。
はじめ、9行目のDATA_IVECTOR(group);をDATA_VECTOR(group);にしてしまっていて(Iがぬけていた)、コンパイルエラーとなっていました。エラーメッセージはC++のものなので、なかなか原因が特定できず、ちょっと苦労しました。
zipmix.cpp
// Zero-Inflated Poisson
#include <TMB.hpp>
template<class Type>
Type objective_function<Type>::operator() ()
{
DATA_VECTOR(Y);
DATA_VECTOR(X);
DATA_IVECTOR(group);
PARAMETER(p);
PARAMETER(alpha);
PARAMETER(beta);
PARAMETER_VECTOR(epsilon);
PARAMETER(log_sigma);
Type lp = 0;
lp += -sum(dnorm(epsilon, Type(0.0), exp(log_sigma), true));
for (int i = 0; i < Y.size(); i++) {
Type lambda = exp(alpha + beta * X(i) + epsilon(group(i)));
lp += -dzipois(Y(i), lambda, p, true);
}
return lp;
}コンパイルと最適化
モデルをコンパイルして、できたライブラリをロードします。
model_name <- "zipmix"
file.path("models", paste(model_name, "cpp", sep = ".")) |>
compile()
file.path("models", dynlib(model_name)) |>
dyn.load()MakeADFun関数で、最適化関数に渡すオブジェクトを作成して、nlminb関数で最適化します。MakeADFun関数のrandom引数に、モデル中の変量効果であるepsilonを指定します。
data <- list(Y = example$y, X = example$x,
group = example$group - 1)
parameters <- list(p = 0.5, alpha = 1, beta = 1,
epsilon = rep(0, N_group), log_sigma = 0)
obj_zipmix <- MakeADFun(data, parameters, DLL = model_name,
random = "epsilon")
opt_zipmix <- nlminb(obj_zipmix$par, obj_zipmix$fn, obj_zipmix$gr)結果
結果です。パラメータは、だいたいデータを生成した元の値を再現できていました。
print(opt_zipmix)
## $par
## p alpha beta log_sigma
## 0.2992974 -1.0778807 0.8271015 -1.5772257
##
## $objective
## [1] 678.5374
##
## $convergence
## [1] 0
##
## $iterations
## [1] 24
##
## $evaluations
## function gradient
## 32 25
##
## $message
## [1] "relative convergence (4)"