library(cmdstanr)
## This is cmdstanr version 0.8.1
## - CmdStanR documentation and vignettes: mc-stan.org/cmdstanr
## - CmdStan path: /usr/local/cmdstan
## - CmdStan version: 2.36.0
::knit_engines$set(stan = cmdstanr::eng_cmdstan)
knitrlibrary(ggplot2)
library(mvtnorm)
ある試験を2段階でおこなうとき、1次試験で一定の点数以上をとった人だけ2次試験を受けられるという状況を考えます。このとき、1次試験で落ちた人が2次試験も受けていたと想定した場合の2次試験の平均点と、1次試験と2次試験の点数の相関係数を求めます。
準備
パッケージを読み込みます。
データ
仮想データを生成します。1次試験と2次試験の点数は、平均がともに500点、分散共分散行列が\(\left(\begin{array}{cc} 15^2 & 100 \\ 100 & 10^2\end{array}\right)\)の2変量正規分布にしたがうとします。
set.seed(1234)
<- 400
N <- 500
m1 <- 500
m2 <- 15^2
var1 <- 10^2
var2 <- 100
cov12 <- matrix(c(var1, cov12, cov12, var2), 2, 2)
Sigma <- rmvnorm(N, mean = c(m1, m2), sigma = Sigma) |>
x round()
<- data.frame(x1 = x[, 1], x2 = x[, 2]) d
1次試験の点が490点未満のデータを打ち切ります。グラフの青い点が、2次試験の点がない打ち切りデータ点となります。
<- 490
cens
$cens <- d$x1 < cens
d<- d[!d$cens, ]
dobs <- d$x1[d$cens]
x1cens
ggplot(d) +
geom_point(aes(x = x1, y = x2, colour = cens))
平均と相関係数
打ち切りのない場合とある場合の2次試験の平均値を比較します。打ち切りデータで打ち切り前のデータを推定しようとすると、当然ながら値は過大評価になってしまいます。
# mean of censored data
mean(dobs$x2)
## [1] 502.9724
# true mean of x2
mean(d$x2)
## [1] 499.5825
一方、1次試験の点数と2次試験の点数との相関係数は過小評価になってしまいます。なお、相関係数の値は理論的には0.6666667になります。
# tru value of the correlation coefficient
cor(d$x1, d$x2)
## [1] 0.6977945
# value base on the censored data
cor(dobs$x1, dobs$x2)
## [1] 0.5483848
# theoretical value of the correlation coefficient
/ (sqrt(var1) * sqrt(var2))
cov12 ## [1] 0.6666667
Stanモデル
Stanで、1次試験の点数(x1
)と2次試験の点数(x2
)の回帰モデルをつくります。2次試験を受けられなかった人の2次試験の点数x2cens
はパラメータとして推定します。
data {
int<lower=0> Nobs; // number of observed data
int<lower=0> Ncens; // number of censored data
vector[Nobs] X1obs; // x1 of observed data
vector[Nobs] X2obs; // x2 of observed data
vector[Ncens] X1cens; // x1 of censored data
}
transformed data {
vector[Nobs + Ncens] x1 = append_row(X1obs, X1cens);
real mu_x1 = mean(x1);
}
parameters {
vector[Ncens] x2cens; // x2 of censored data
array[2] real beta; // intercept and coefficient
real<lower=0> sigma; // standard deviation
}
transformed parameters {
vector[Nobs + Ncens] x2 = append_row(X2obs, x2cens);
real mu_x2 = mean(x2);
}
model {
1] + beta[2] * X1obs, sigma);
X2obs ~ normal(beta[1] + beta[2] * X1cens, sigma);
x2cens ~ normal(beta[
}
generated quantities {
// calculate correlation coefficient including estimated x2
real cor = sum((x1 - mu_x1) .* (x2 - mu_x2)) / (Nobs + Ncens - 1)
/ (sd(x1) * sd(x2)); }
モデルをあてはめます。上のStanモデルをmodel
というオブジェクトに入れておきます。
<- list(Nobs = nrow(dobs),
stan_data Ncens = length(x1cens),
X1obs = dobs$x1,
X2obs = dobs$x2,
X1cens = x1cens)
<- model$sample(data = stan_data,
fit iter_warmup = 1000,
iter_sampling = 2000,
chains = 4,
parallel_chains = 4,
refresh = 0)
## Running MCMC with 4 parallel chains...
##
## Chain 2 finished in 10.1 seconds.
## Chain 4 finished in 10.3 seconds.
## Chain 3 finished in 10.9 seconds.
## Chain 1 finished in 13.4 seconds.
##
## All 4 chains finished successfully.
## Mean chain execution time: 11.2 seconds.
## Total execution time: 13.5 seconds.
結果
beta
とsigma
の事後分布の要約です。
$print(c("beta", "sigma"))
fit## variable mean median sd mad q5 q95 rhat ess_bulk ess_tail
## beta[1] 277.85 277.89 20.71 20.84 243.64 311.92 1.00 2065 3381
## beta[2] 0.44 0.44 0.04 0.04 0.38 0.51 1.00 2082 3381
## sigma 7.52 7.51 0.31 0.31 7.02 8.04 1.00 5644 5268
実際には観測されていないx2
の推定値です。
# estimated x2
$print(c("x2cens"))
fit## variable mean median sd mad q5 q95 rhat ess_bulk ess_tail
## x2cens[1] 492.85 492.82 7.68 7.56 480.37 505.52 1.00 9209 6317
## x2cens[2] 495.14 495.23 7.53 7.54 482.62 507.41 1.00 10961 6061
## x2cens[3] 494.54 494.55 7.63 7.56 482.07 507.04 1.00 10723 5891
## x2cens[4] 495.14 495.12 7.47 7.33 482.64 507.39 1.00 10119 6080
## x2cens[5] 495.12 495.08 7.58 7.71 482.55 507.69 1.00 9483 6335
## x2cens[6] 495.11 495.07 7.61 7.55 482.54 507.69 1.00 9594 5902
## x2cens[7] 492.91 492.96 7.78 7.87 480.17 505.61 1.00 9482 5679
## x2cens[8] 494.64 494.67 7.61 7.43 481.98 507.01 1.00 9374 5941
## x2cens[9] 487.49 487.50 7.63 7.65 474.91 499.89 1.00 7371 5158
## x2cens[10] 483.48 483.57 7.65 7.58 470.83 496.15 1.00 8485 6435
##
## # showing 10 of 110 rows (change via 'max_rows' argument or 'cmdstanr_max_rows' option)
打ち切り前の2次試験の点数の平均と、1次試験の点数と2次試験の点数との相関係数の推定値の事後分布の要約です。だいたい元の値に近い値が推定されました。
# mean of x2 and correlation coefficient
$print(c("mu_x2", "cor"))
fit## variable mean median sd mad q5 q95 rhat ess_bulk ess_tail
## mu_x2 499.84 499.85 0.37 0.37 499.22 500.45 1.00 1890 2785
## cor 0.67 0.67 0.03 0.03 0.62 0.72 1.00 1855 3098