library(cmdstanr)
## This is cmdstanr version 0.5.3
## - CmdStanR documentation and vignettes: mc-stan.org/cmdstanr
## - CmdStan path: /usr/local/cmdstan
## - CmdStan version: 2.32.2
::knit_engines$set(stan = cmdstanr::eng_cmdstan)
knitrlibrary(ggplot2)
library(mvtnorm)
ある試験を2段階でおこなうとき、1次試験で一定の点数以上をとった人だけ2次試験を受けられるという状況を考えます。このとき、1次試験で落ちた人が2次試験も受けていたと想定した場合の2次試験の平均点と、1次試験と2次試験の点数の相関係数を求めます。
準備
パッケージを読み込みます。
データ
仮想データを生成します。1次試験と2次試験の点数は、平均がともに500点、分散共分散行列が\(\left(\begin{array}{cc} 15^2 & 100 \\ 100 & 10^2\end{array}\right)\)の2変量正規分布にしたがうとします。
set.seed(1234)
<- 400
N <- 500
m1 <- 500
m2 <- 15^2
var1 <- 10^2
var2 <- 100
cov12 <- matrix(c(var1, cov12, cov12, var2), 2, 2)
Sigma <- rmvnorm(N, mean = c(m1, m2), sigma = Sigma) |>
x round()
<- data.frame(x1 = x[, 1], x2 = x[, 2]) d
1次試験の点が490点未満のデータを打ち切ります。グラフの青い点が、2次試験の点がない打ち切りデータ点となります。
<- 490
cens
$cens <- d$x1 < cens
d<- d[!d$cens, ]
dobs <- d$x1[d$cens]
x1cens
ggplot(d) +
geom_point(aes(x = x1, y = x2, colour = cens))
平均と相関係数
打ち切りのない場合とある場合の2次試験の平均値を比較します。打ち切りデータで打ち切り前のデータを推定しようとすると、当然ながら値は過大評価になってしまいます。
# mean of censored data
mean(dobs$x2)
## [1] 502.9724
# true mean of x2
mean(d$x2)
## [1] 499.5825
一方、1次試験の点数と2次試験の点数との相関係数は過小評価になってしまいます。なお、相関係数の値は理論的には0.6666667になります。
# tru value of the correlation coefficient
cor(d$x1, d$x2)
## [1] 0.6977945
# value base on the censored data
cor(dobs$x1, dobs$x2)
## [1] 0.5483848
# theoretical value of the correlation coefficient
/ (sqrt(var1) * sqrt(var2))
cov12 ## [1] 0.6666667
Stanモデル
Stanで、1次試験の点数(x1
)と2次試験の点数(x2
)の回帰モデルをつくります。2次試験を受けられなかった人の2次試験の点数x2cens
はパラメータとして推定します。
data {
int<lower=0> Nobs; // number of observed data
int<lower=0> Ncens; // number of censored data
vector[Nobs] X1obs; // x1 of observed data
vector[Nobs] X2obs; // x2 of observed data
vector[Ncens] X1cens; // x1 of censored data
}
transformed data {
vector[Nobs + Ncens] x1 = append_row(X1obs, X1cens);
real mu_x1 = mean(x1);
}
parameters {
vector[Ncens] x2cens; // x2 of censored data
array[2] real beta; // intercept and coefficient
real<lower=0> sigma; // standard deviation
}
transformed parameters {
vector[Nobs + Ncens] x2 = append_row(X2obs, x2cens);
real mu_x2 = mean(x2);
}
model {
1] + beta[2] * X1obs, sigma);
X2obs ~ normal(beta[1] + beta[2] * X1cens, sigma);
x2cens ~ normal(beta[
}
generated quantities {
// calculate correlation coefficient including estimated x2
real cor = sum((x1 - mu_x1) .* (x2 - mu_x2)) / (Nobs + Ncens - 1)
/ (sd(x1) * sd(x2)); }
モデルをあてはめます。上のStanモデルをmodel
というオブジェクトに入れておきます。
<- list(Nobs = nrow(dobs),
stan_data Ncens = length(x1cens),
X1obs = dobs$x1,
X2obs = dobs$x2,
X1cens = x1cens)
<- model$sample(data = stan_data,
fit iter_warmup = 1000,
iter_sampling = 2000,
chains = 4,
parallel_chains = 4,
refresh = 0)
## Running MCMC with 4 parallel chains...
##
## Chain 3 finished in 10.4 seconds.
## Chain 1 finished in 10.8 seconds.
## Chain 2 finished in 10.8 seconds.
## Chain 4 finished in 10.8 seconds.
##
## All 4 chains finished successfully.
## Mean chain execution time: 10.7 seconds.
## Total execution time: 10.9 seconds.
結果
beta
とsigma
の事後分布の要約です。
$print(c("beta", "sigma"))
fit## variable mean median sd mad q5 q95 rhat ess_bulk ess_tail
## beta[1] 277.09 276.88 20.23 20.29 243.78 310.69 1.00 1664 3413
## beta[2] 0.45 0.45 0.04 0.04 0.38 0.51 1.00 1684 3509
## sigma 7.51 7.50 0.31 0.30 7.02 8.04 1.00 5621 5332
実際には観測されていないx2
の推定値です。
# estimated x2
$print(c("x2cens"))
fit## variable mean median sd mad q5 q95 rhat ess_bulk ess_tail
## x2cens[1] 492.79 492.84 7.57 7.61 480.44 505.17 1.00 9947 6024
## x2cens[2] 495.10 495.19 7.51 7.53 482.84 507.29 1.00 9099 6465
## x2cens[3] 494.61 494.54 7.51 7.38 482.16 506.90 1.00 9912 6467
## x2cens[4] 495.03 495.16 7.62 7.55 482.31 507.60 1.00 10052 5713
## x2cens[5] 495.01 495.07 7.56 7.62 482.38 507.44 1.00 11140 5828
## x2cens[6] 495.09 495.13 7.66 7.66 482.61 507.66 1.00 10181 6073
## x2cens[7] 492.89 492.87 7.72 7.60 480.00 505.52 1.00 8474 5690
## x2cens[8] 494.67 494.67 7.63 7.57 482.06 507.27 1.00 8558 5594
## x2cens[9] 487.45 487.52 7.74 7.81 474.56 500.06 1.00 7952 5338
## x2cens[10] 483.47 483.58 7.48 7.54 471.03 495.58 1.00 7551 5637
##
## # showing 10 of 110 rows (change via 'max_rows' argument or 'cmdstanr_max_rows' option)
打ち切り前の2次試験の点数の平均と、1次試験の点数と2次試験の点数との相関係数の推定値の事後分布の要約です。だいたい元の値に近い値が推定されました。
# mean of x2 and correlation coefficient
$print(c("mu_x2", "cor"))
fit## variable mean median sd mad q5 q95 rhat ess_bulk ess_tail
## mu_x2 499.83 499.83 0.36 0.37 499.23 500.42 1.00 1490 3186
## cor 0.67 0.68 0.03 0.03 0.63 0.72 1.00 1544 3148