Stanで打ち切りデータの解析

R
Stan
作者

伊東宏樹

公開

2023年7月16日

ある試験を2段階でおこなうとき、1次試験で一定の点数以上をとった人だけ2次試験を受けられるという状況を考えます。このとき、1次試験で落ちた人が2次試験も受けていたと想定した場合の2次試験の平均点と、1次試験と2次試験の点数の相関係数を求めます。

準備

パッケージを読み込みます。

library(cmdstanr)
## This is cmdstanr version 0.5.3
## - CmdStanR documentation and vignettes: mc-stan.org/cmdstanr
## - CmdStan path: /usr/local/cmdstan
## - CmdStan version: 2.32.2
knitr::knit_engines$set(stan = cmdstanr::eng_cmdstan)
library(ggplot2)
library(mvtnorm)

データ

仮想データを生成します。1次試験と2次試験の点数は、平均がともに500点、分散共分散行列が\(\left(\begin{array}{cc} 15^2 & 100 \\ 100 & 10^2\end{array}\right)\)の2変量正規分布にしたがうとします。

set.seed(1234)
N <- 400
m1 <- 500
m2 <- 500
var1 <- 15^2
var2 <- 10^2
cov12 <- 100
Sigma <- matrix(c(var1, cov12, cov12, var2), 2, 2)
x <- rmvnorm(N, mean = c(m1, m2), sigma = Sigma) |>
  round()
d <- data.frame(x1 = x[, 1], x2 = x[, 2])

1次試験の点が490点未満のデータを打ち切ります。グラフの青い点が、2次試験の点がない打ち切りデータ点となります。

cens <- 490

d$cens <- d$x1 < cens
dobs <- d[!d$cens, ]
x1cens <- d$x1[d$cens]

ggplot(d) +
  geom_point(aes(x = x1, y = x2, colour = cens))

平均と相関係数

打ち切りのない場合とある場合の2次試験の平均値を比較します。打ち切りデータで打ち切り前のデータを推定しようとすると、当然ながら値は過大評価になってしまいます。

# mean of censored data
mean(dobs$x2)
## [1] 502.9724

# true mean of x2
mean(d$x2)
## [1] 499.5825

一方、1次試験の点数と2次試験の点数との相関係数は過小評価になってしまいます。なお、相関係数の値は理論的には0.6666667になります。

# tru value of the correlation coefficient
cor(d$x1, d$x2)
## [1] 0.6977945

# value base on the censored data
cor(dobs$x1, dobs$x2)
## [1] 0.5483848

# theoretical value of the correlation coefficient
cov12 / (sqrt(var1) * sqrt(var2))
## [1] 0.6666667

Stanモデル

Stanで、1次試験の点数(x1)と2次試験の点数(x2)の回帰モデルをつくります。2次試験を受けられなかった人の2次試験の点数x2censはパラメータとして推定します。

data {
  int<lower=0> Nobs;    // number of observed data
  int<lower=0> Ncens;   // number of censored data
  vector[Nobs] X1obs;   // x1 of observed data
  vector[Nobs] X2obs;   // x2 of observed data
  vector[Ncens] X1cens; // x1 of censored data
}

transformed data {
  vector[Nobs + Ncens] x1 = append_row(X1obs, X1cens);
  real mu_x1 = mean(x1);
}

parameters {
  vector[Ncens] x2cens; // x2 of censored data
  array[2] real beta;   // intercept and coefficient
  real<lower=0> sigma;  // standard deviation
}

transformed parameters {
  vector[Nobs + Ncens] x2 = append_row(X2obs, x2cens);
  real mu_x2 = mean(x2);
}

model {
  X2obs ~ normal(beta[1] + beta[2] * X1obs, sigma);
  x2cens ~ normal(beta[1] + beta[2] * X1cens, sigma);
}

generated quantities {
  // calculate correlation coefficient including estimated x2
  real cor = sum((x1 - mu_x1) .* (x2 - mu_x2)) / (Nobs + Ncens - 1)
               / (sd(x1) * sd(x2));
}

モデルをあてはめます。上のStanモデルをmodelというオブジェクトに入れておきます。

stan_data <- list(Nobs = nrow(dobs),
                  Ncens = length(x1cens),
                  X1obs = dobs$x1,
                  X2obs = dobs$x2,
                  X1cens = x1cens)
fit <- model$sample(data = stan_data,
                    iter_warmup = 1000,
                    iter_sampling = 2000,
                    chains = 4,
                    parallel_chains = 4,
                    refresh = 0)
## Running MCMC with 4 parallel chains...
## 
## Chain 3 finished in 10.4 seconds.
## Chain 1 finished in 10.8 seconds.
## Chain 2 finished in 10.8 seconds.
## Chain 4 finished in 10.8 seconds.
## 
## All 4 chains finished successfully.
## Mean chain execution time: 10.7 seconds.
## Total execution time: 10.9 seconds.

結果

betasigmaの事後分布の要約です。

fit$print(c("beta", "sigma"))
##  variable   mean median    sd   mad     q5    q95 rhat ess_bulk ess_tail
##   beta[1] 277.09 276.88 20.23 20.29 243.78 310.69 1.00     1664     3413
##   beta[2]   0.45   0.45  0.04  0.04   0.38   0.51 1.00     1684     3509
##   sigma     7.51   7.50  0.31  0.30   7.02   8.04 1.00     5621     5332

実際には観測されていないx2の推定値です。

# estimated x2
fit$print(c("x2cens"))
##    variable   mean median   sd  mad     q5    q95 rhat ess_bulk ess_tail
##  x2cens[1]  492.79 492.84 7.57 7.61 480.44 505.17 1.00     9947     6024
##  x2cens[2]  495.10 495.19 7.51 7.53 482.84 507.29 1.00     9099     6465
##  x2cens[3]  494.61 494.54 7.51 7.38 482.16 506.90 1.00     9912     6467
##  x2cens[4]  495.03 495.16 7.62 7.55 482.31 507.60 1.00    10052     5713
##  x2cens[5]  495.01 495.07 7.56 7.62 482.38 507.44 1.00    11140     5828
##  x2cens[6]  495.09 495.13 7.66 7.66 482.61 507.66 1.00    10181     6073
##  x2cens[7]  492.89 492.87 7.72 7.60 480.00 505.52 1.00     8474     5690
##  x2cens[8]  494.67 494.67 7.63 7.57 482.06 507.27 1.00     8558     5594
##  x2cens[9]  487.45 487.52 7.74 7.81 474.56 500.06 1.00     7952     5338
##  x2cens[10] 483.47 483.58 7.48 7.54 471.03 495.58 1.00     7551     5637
## 
##  # showing 10 of 110 rows (change via 'max_rows' argument or 'cmdstanr_max_rows' option)

打ち切り前の2次試験の点数の平均と、1次試験の点数と2次試験の点数との相関係数の推定値の事後分布の要約です。だいたい元の値に近い値が推定されました。

# mean of x2 and correlation coefficient
fit$print(c("mu_x2", "cor"))
##  variable   mean median   sd  mad     q5    q95 rhat ess_bulk ess_tail
##     mu_x2 499.83 499.83 0.36 0.37 499.23 500.42 1.00     1490     3186
##     cor     0.67   0.68 0.03 0.03   0.63   0.72 1.00     1544     3148

参考文献

Stan User’s Guide: 4.3 Censored data