library(TMB)
library(ggplot2)
set.seed(123)
TMBを混合効果モデルを書いてみました。まだ不慣れなので、いろいろ試行錯誤しましたが、なんとかできました。
準備
前回と同様、TMBとggplot2を読み込みます。擬似乱数も固定します。
データ
ゼロ過剰ポアソン分布にしたがう模擬データを生成します。今回は群ごとの変量効果も加えています。
<- 400
N <- 10
N_group <- rep(1:N_group, each = N / N_group)
group <- 0.3
p <- -1
alpha <- 0.8
beta <- -1.6 # sigma = 0.202
log_sigma <- rnorm(N_group, 0, exp(log_sigma))
ranef
<- runif(N, 0, 5)
x <- rbinom(N, 1, 1 - p) *
y rpois(N, exp(alpha + beta * x + ranef[group]))
<- data.frame(x = x, y = y, group = group) example
データを確認します。
ggplot(example, aes(x = x, y = y)) +
geom_point()
モデル
TMBのC++のコードです。"models/zipmix.cpp"
に保存しておきます。
はじめ、9行目のDATA_IVECTOR(group);
をDATA_VECTOR(group);
にしてしまっていて(I
がぬけていた)、コンパイルエラーとなっていました。エラーメッセージはC++のものなので、なかなか原因が特定できず、ちょっと苦労しました。
zipmix.cpp
// Zero-Inflated Poisson
#include <TMB.hpp>
template<class Type>
Type objective_function<Type>::operator() ()
{
DATA_VECTOR(Y);
DATA_VECTOR(X);
DATA_IVECTOR(group);
PARAMETER(p);
PARAMETER(alpha);
PARAMETER(beta);
PARAMETER_VECTOR(epsilon);
PARAMETER(log_sigma);
Type lp = 0;
lp += -sum(dnorm(epsilon, Type(0.0), exp(log_sigma), true));
for (int i = 0; i < Y.size(); i++) {
Type lambda = exp(alpha + beta * X(i) + epsilon(group(i)));
lp += -dzipois(Y(i), lambda, p, true);
}
return lp;
}
コンパイルと最適化
モデルをコンパイルして、できたライブラリをロードします。
<- "zipmix"
model_name file.path("models", paste(model_name, "cpp", sep = ".")) |>
compile()
file.path("models", dynlib(model_name)) |>
dyn.load()
MakeADFun
関数で、最適化関数に渡すオブジェクトを作成して、nlminb
関数で最適化します。MakeADFun
関数のrandom
引数に、モデル中の変量効果であるepsilon
を指定します。
<- list(Y = example$y, X = example$x,
data group = example$group - 1)
<- list(p = 0.5, alpha = 1, beta = 1,
parameters epsilon = rep(0, N_group), log_sigma = 0)
<- MakeADFun(data, parameters, DLL = model_name,
obj_zipmix random = "epsilon")
<- nlminb(obj_zipmix$par, obj_zipmix$fn, obj_zipmix$gr) opt_zipmix
結果
結果です。パラメータは、だいたいデータを生成した元の値を再現できていました。
print(opt_zipmix)
## $par
## p alpha beta log_sigma
## 0.2992974 -1.0778807 0.8271015 -1.5772257
##
## $objective
## [1] 678.5374
##
## $convergence
## [1] 0
##
## $iterations
## [1] 24
##
## $evaluations
## function gradient
## 32 25
##
## $message
## [1] "relative convergence (4)"