Skip to content

Commit

Permalink
Merge pull request #47 from nhejazi/46-tlme-tilting-in-two-phase-samp…
Browse files Browse the repository at this point in the history
…ling-dgps

tmle tilting for two-phase sampling
  • Loading branch information
nhejazi committed Mar 4, 2024
2 parents bf151d2 + f10857b commit ed73beb
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 67 deletions.
54 changes: 6 additions & 48 deletions R/estimators.R
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
utils::globalVariables(c("..w_names", "A", "Z", "Y", "R"))
utils::globalVariables(c("..w_names", "A", "Z", "Y", "R", "v_star"))

#' EIF for natural and interventional (in)direct effects
#'
Expand Down Expand Up @@ -501,6 +501,7 @@ est_onestep <- function(data,
)

# get estimated efficient influence function
v_star <- do.call(rbind, cv_eif_results[[1]])$v_star
obs_valid_idx <- do.call(c, lapply(folds, `[[`, "validation_set"))
cv_eif_est <- unlist(cv_eif_results$D_star)[order(obs_valid_idx)]

Expand All @@ -522,6 +523,7 @@ est_onestep <- function(data,
# output
os_est_out <- list(
theta = os_est,
theta_plugin = est_plugin(v_star),
var = os_var,
eif = (eif_est_out - os_est),
type = "onestep"
Expand Down Expand Up @@ -695,56 +697,13 @@ est_tml <- function(data,
n_obs <- nrow(data)
se_eif <- sqrt(var(cv_eif_est$D_star) / n_obs)
tilt_stop_crit <- se_eif / log(n_obs)
r_score <- b_score <- q_score <- Inf
b_score <- q_score <- Inf
tilt_two_phase_weights <- sum(data$R) != nrow(data)
d_pred <- unlist(cv_eif_results$D_pred)[order(obs_valid_idx)]

# perform iterative targeting
while (!eif_stop_crit && n_iter <= max_iter) {

# tilt the two-phase sampling weights if necessary
if (tilt_two_phase_weights && mean(r_score) > tilt_stop_crit) {

# tilting model for known weights using weighting approah
two_phase_prob_logit <- (1 / data$two_phase_weights) %>%
bound_precision() %>%
stats::qlogis()
suppressWarnings(
tilted_two_phase_fit <- stats::glm(
stats::as.formula(
"R ~ -1 + offset(two_phase_prob_logit) + weighted_d_pred"
),
data = data.table::data.table(
R = data$R,
two_phase_prob_logit = two_phase_prob_logit,
weighted_d_pred = d_pred * data$two_phase_weights
),
family = "binomial",
start = 0
)
)

# housekeeping for the tilting coefficient
if (is.na(stats::coef(tilted_two_phase_fit))) {
tilted_two_phase_fit$coefficients <- 0
} else if (abs(max(stats::coef(tilted_two_phase_fit))) > tiltmod_tol) {
tilted_two_phase_fit$coefficients <- 0
}

# tilt the two-phase sampling probs
tilted_two_phase_prob <- predict(tilted_two_phase_fit, type = "response")

# update the two-phase sampling weights
data$two_phase_weights <- 1 / tilted_two_phase_prob

# record the two-phase sampling score
r_score <- d_pred * data$two_phase_weights *
(data$R - tilted_two_phase_prob)

} else {
r_score <- 0
}

if (mean(b_score) > tilt_stop_crit) {

# compute auxiliary covariates from updated estimates
Expand Down Expand Up @@ -880,9 +839,7 @@ est_tml <- function(data,
}

# check convergence and iterate the iterator
eif_stop_crit <- all(
abs(c(mean(b_score), mean(q_score), mean(r_score))) < tilt_stop_crit
)
eif_stop_crit <- all(abs(c(mean(b_score), mean(q_score))) < tilt_stop_crit)
n_iter <- n_iter + 1
}

Expand Down Expand Up @@ -953,6 +910,7 @@ est_tml <- function(data,
# output
tmle_out <- list(
theta = tml_est,
theta_plugin = est_plugin(cv_eif_est$v_star),
var = tmle_var,
eif = (eif_est_out - tml_est),
n_iter = n_iter,
Expand Down
23 changes: 14 additions & 9 deletions R/fit_mechanisms.R
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,8 @@ fit_treat_mech <- function(train_data,
# update observation weights with two-phase sampling weights, if necessary
# NOTE: importantly, re-weighting the propensity score (g) estimator is
# not necessary under two-phase sampling of the mediators
train_data[, obs_weights := R * two_phase_weights * obs_weights]
valid_data[, obs_weights := R * two_phase_weights * obs_weights]
train_data[, obs_weights := two_phase_weights * obs_weights]
valid_data[, obs_weights := two_phase_weights * obs_weights]

# remove observations that were not sampled in second stage
train_data <- train_data[R == 1, ]
Expand Down Expand Up @@ -185,8 +185,8 @@ fit_out_mech <- function(train_data,
m_names,
w_names) {
# update observation weights with two-phase sampling weights, if necessary
train_data[, obs_weights := R * two_phase_weights * obs_weights]
valid_data[, obs_weights := R * two_phase_weights * obs_weights]
train_data[, obs_weights := two_phase_weights * obs_weights]
valid_data[, obs_weights := two_phase_weights * obs_weights]

# remove observations that were not sampled in second stage
train_data <- train_data[R == 1, ]
Expand Down Expand Up @@ -372,13 +372,18 @@ fit_moc_mech <- function(train_data,
## construct task for nuisance parameter fit
if (type == "q") {
cov_names <- w_names

# update observation weights with two-phase sampling weights, if necessary
# NOTE: Might not be necessary, check with Nima
train_data[, obs_weights := two_phase_weights * obs_weights]
valid_data[, obs_weights := two_phase_weights * obs_weights]

} else if (type == "r") {
cov_names <- c(m_names, w_names)

# update observation weights with two-phase sampling weights, if necessary
# NOTE: Should this be applied to q as well? Probably.
train_data[, obs_weights := R * two_phase_weights * obs_weights]
valid_data[, obs_weights := R * two_phase_weights * obs_weights]
train_data[, obs_weights := two_phase_weights * obs_weights]
valid_data[, obs_weights := two_phase_weights * obs_weights]

# remove observations that were not sampled in second stage
train_data <- train_data[R == 1, ]
Expand Down Expand Up @@ -577,8 +582,8 @@ fit_nuisance_u <- function(train_data,
h_out,
w_names) {
# update observation weights with two-phase sampling weights, if necessary
train_data[, obs_weights := R * obs_weights]
valid_data[, obs_weights := R * obs_weights]
train_data[, obs_weights := two_phase_weights * obs_weights]
valid_data[, obs_weights := two_phase_weights * obs_weights]

## extract nuisance estimates necessary for constructing pseudo-outcome
b_prime <- b_out$b_est_train$b_pred_A_prime
Expand Down
21 changes: 11 additions & 10 deletions R/medoutcon.R
Original file line number Diff line number Diff line change
Expand Up @@ -299,17 +299,18 @@ medoutcon <- function(W,
pm_theta_est <- 1 - log(est_params[[3]]$theta / est_params[[2]]$theta) /
log(est_params[[1]]$theta / est_params[[2]]$theta)
pm_eif_est <- -est_params[[3]]$eif /
(est_params[[3]]$theta * log(est_params[[1]]$theta /
est_params[[2]]$theta)) +
(est_params[[3]]$theta_plugin * log(est_params[[1]]$theta_plugin /
est_params[[2]]$theta_plugin)) +
est_params[[2]]$eif * (
(log(est_params[[1]]$theta / est_params[[2]]$theta) -
log(est_params[[3]]$theta / est_params[[2]]$theta)) /
(est_params[[2]]$theta *
(log(est_params[[1]]$theta / est_params[[2]]$theta))^2)) +
est_params[[1]]$eif * log(est_params[[3]]$theta /
est_params[[2]]$theta) /
(est_params[[1]]$theta * (log(est_params[[1]]$theta /
est_params[[2]]$theta))^2)
(log(est_params[[1]]$theta_plugin / est_params[[2]]$theta_plugin) -
log(est_params[[3]]$theta_plugin / est_params[[2]]$theta_plugin)) /
(est_params[[2]]$theta_plugin *
(log(est_params[[1]]$theta_plugin /
est_params[[2]]$theta_plugin))^2)) +
est_params[[1]]$eif * log(est_params[[3]]$theta_plugin /
est_params[[2]]$theta_plugin) /
(est_params[[1]]$theta_plugin * (log(est_params[[1]]$theta_plugin /
est_params[[2]]$theta_plugin))^2)
pm_var_est <- stats::var(pm_eif_est) / nrow(data)

# construct output in same style as for contrast-specific parameter
Expand Down

0 comments on commit ed73beb

Please sign in to comment.