Stanで打ち切りデータの解析

R
Stan
作者

伊東宏樹

公開

2023年7月16日

ある試験を2段階でおこなうとき、1次試験で一定の点数以上をとった人だけ2次試験を受けられるという状況を考えます。このとき、1次試験で落ちた人が2次試験も受けていたと想定した場合の2次試験の平均点と、1次試験と2次試験の点数の相関係数を求めます。

準備

パッケージを読み込みます。

library(cmdstanr)
## This is cmdstanr version 0.8.1
## - CmdStanR documentation and vignettes: mc-stan.org/cmdstanr
## - CmdStan path: /usr/local/cmdstan
## - CmdStan version: 2.36.0
knitr::knit_engines$set(stan = cmdstanr::eng_cmdstan)
library(ggplot2)
library(mvtnorm)

データ

仮想データを生成します。1次試験と2次試験の点数は、平均がともに500点、分散共分散行列が\(\left(\begin{array}{cc} 15^2 & 100 \\ 100 & 10^2\end{array}\right)\)の2変量正規分布にしたがうとします。

set.seed(1234)
N <- 400
m1 <- 500
m2 <- 500
var1 <- 15^2
var2 <- 10^2
cov12 <- 100
Sigma <- matrix(c(var1, cov12, cov12, var2), 2, 2)
x <- rmvnorm(N, mean = c(m1, m2), sigma = Sigma) |>
  round()
d <- data.frame(x1 = x[, 1], x2 = x[, 2])

1次試験の点が490点未満のデータを打ち切ります。グラフの青い点が、2次試験の点がない打ち切りデータ点となります。

cens <- 490

d$cens <- d$x1 < cens
dobs <- d[!d$cens, ]
x1cens <- d$x1[d$cens]

ggplot(d) +
  geom_point(aes(x = x1, y = x2, colour = cens))

平均と相関係数

打ち切りのない場合とある場合の2次試験の平均値を比較します。打ち切りデータで打ち切り前のデータを推定しようとすると、当然ながら値は過大評価になってしまいます。

# mean of censored data
mean(dobs$x2)
## [1] 502.9724

# true mean of x2
mean(d$x2)
## [1] 499.5825

一方、1次試験の点数と2次試験の点数との相関係数は過小評価になってしまいます。なお、相関係数の値は理論的には0.6666667になります。

# tru value of the correlation coefficient
cor(d$x1, d$x2)
## [1] 0.6977945

# value base on the censored data
cor(dobs$x1, dobs$x2)
## [1] 0.5483848

# theoretical value of the correlation coefficient
cov12 / (sqrt(var1) * sqrt(var2))
## [1] 0.6666667

Stanモデル

Stanで、1次試験の点数(x1)と2次試験の点数(x2)の回帰モデルをつくります。2次試験を受けられなかった人の2次試験の点数x2censはパラメータとして推定します。

data {
  int<lower=0> Nobs;    // number of observed data
  int<lower=0> Ncens;   // number of censored data
  vector[Nobs] X1obs;   // x1 of observed data
  vector[Nobs] X2obs;   // x2 of observed data
  vector[Ncens] X1cens; // x1 of censored data
}

transformed data {
  vector[Nobs + Ncens] x1 = append_row(X1obs, X1cens);
  real mu_x1 = mean(x1);
}

parameters {
  vector[Ncens] x2cens; // x2 of censored data
  array[2] real beta;   // intercept and coefficient
  real<lower=0> sigma;  // standard deviation
}

transformed parameters {
  vector[Nobs + Ncens] x2 = append_row(X2obs, x2cens);
  real mu_x2 = mean(x2);
}

model {
  X2obs ~ normal(beta[1] + beta[2] * X1obs, sigma);
  x2cens ~ normal(beta[1] + beta[2] * X1cens, sigma);
}

generated quantities {
  // calculate correlation coefficient including estimated x2
  real cor = sum((x1 - mu_x1) .* (x2 - mu_x2)) / (Nobs + Ncens - 1)
               / (sd(x1) * sd(x2));
}

モデルをあてはめます。上のStanモデルをmodelというオブジェクトに入れておきます。

stan_data <- list(Nobs = nrow(dobs),
                  Ncens = length(x1cens),
                  X1obs = dobs$x1,
                  X2obs = dobs$x2,
                  X1cens = x1cens)
fit <- model$sample(data = stan_data,
                    iter_warmup = 1000,
                    iter_sampling = 2000,
                    chains = 4,
                    parallel_chains = 4,
                    refresh = 0)
## Running MCMC with 4 parallel chains...
## 
## Chain 2 finished in 10.1 seconds.
## Chain 4 finished in 10.3 seconds.
## Chain 3 finished in 10.9 seconds.
## Chain 1 finished in 13.4 seconds.
## 
## All 4 chains finished successfully.
## Mean chain execution time: 11.2 seconds.
## Total execution time: 13.5 seconds.

結果

betasigmaの事後分布の要約です。

fit$print(c("beta", "sigma"))
##  variable   mean median    sd   mad     q5    q95 rhat ess_bulk ess_tail
##   beta[1] 277.85 277.89 20.71 20.84 243.64 311.92 1.00     2065     3381
##   beta[2]   0.44   0.44  0.04  0.04   0.38   0.51 1.00     2082     3381
##   sigma     7.52   7.51  0.31  0.31   7.02   8.04 1.00     5644     5268

実際には観測されていないx2の推定値です。

# estimated x2
fit$print(c("x2cens"))
##    variable   mean median   sd  mad     q5    q95 rhat ess_bulk ess_tail
##  x2cens[1]  492.85 492.82 7.68 7.56 480.37 505.52 1.00     9209     6317
##  x2cens[2]  495.14 495.23 7.53 7.54 482.62 507.41 1.00    10961     6061
##  x2cens[3]  494.54 494.55 7.63 7.56 482.07 507.04 1.00    10723     5891
##  x2cens[4]  495.14 495.12 7.47 7.33 482.64 507.39 1.00    10119     6080
##  x2cens[5]  495.12 495.08 7.58 7.71 482.55 507.69 1.00     9483     6335
##  x2cens[6]  495.11 495.07 7.61 7.55 482.54 507.69 1.00     9594     5902
##  x2cens[7]  492.91 492.96 7.78 7.87 480.17 505.61 1.00     9482     5679
##  x2cens[8]  494.64 494.67 7.61 7.43 481.98 507.01 1.00     9374     5941
##  x2cens[9]  487.49 487.50 7.63 7.65 474.91 499.89 1.00     7371     5158
##  x2cens[10] 483.48 483.57 7.65 7.58 470.83 496.15 1.00     8485     6435
## 
##  # showing 10 of 110 rows (change via 'max_rows' argument or 'cmdstanr_max_rows' option)

打ち切り前の2次試験の点数の平均と、1次試験の点数と2次試験の点数との相関係数の推定値の事後分布の要約です。だいたい元の値に近い値が推定されました。

# mean of x2 and correlation coefficient
fit$print(c("mu_x2", "cor"))
##  variable   mean median   sd  mad     q5    q95 rhat ess_bulk ess_tail
##     mu_x2 499.84 499.85 0.37 0.37 499.22 500.45 1.00     1890     2785
##     cor     0.67   0.67 0.03 0.03   0.62   0.72 1.00     1855     3098

参考文献

Stan User’s Guide: 4.3 Censored data