Skip to content

Commit

Permalink
Bug fixes for wrong-ish method names to satisfy revdep checks for mvG…
Browse files Browse the repository at this point in the history
…PS; changed "multi" to "multi-category"
  • Loading branch information
ngreifer committed Aug 19, 2024
1 parent 7f2ce32 commit 6cdb3c8
Show file tree
Hide file tree
Showing 12 changed files with 56 additions and 50 deletions.
40 changes: 23 additions & 17 deletions R/functions_for_processing.R
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
.method_to_proper_method <- function(method) {
if (is_null(method)) return(NULL)

if (!is.character(method) || method %nin% unlist(grab(.weightit_methods, "alias"))) {
if (!is.character(method)) {
return(method)
}

method <- tolower(method)

if (method %nin% unlist(grab(.weightit_methods, "alias"))) {
return(method)
}

Expand All @@ -10,26 +16,25 @@
setNames(rep(m, length(alias)), alias)
}))

method <- tolower(method)
unname(.allowable.methods[method])
}

.check_acceptable_method <- function(method, msm = FALSE, force = FALSE) {
bad.method <- FALSE

if (missing(method)) method <- "glm"
else if (is_null(method)) return(invisible(NULL))
else if (length(method) > 1L) bad.method <- TRUE
else if (is.character(method)) {
if (tolower(method) %nin% unlist(grab(.weightit_methods, "alias"))) bad.method <- TRUE
if (missing(method)) {
method <- "glm"
}
else if (is_null(method)) {
return(invisible(NULL))
}
else if (!is.function(method)) bad.method <- TRUE

if (bad.method) {
if (identical(method, "twang")) {
.err('"twang" is no longer an acceptable argument to `method`. Please use "gbm" for generalized boosted modeling')
}
if (identical(method, "twang")) {
.err('"twang" is no longer an acceptable argument to `method`. Please use "gbm" for generalized boosted modeling')
}

if ((!is.character(method) && !is.function(method)) ||
(is.character(method) && (length(method) > 1L ||
.method_to_proper_method(method) %nin% names(.weightit_methods)))) {
.err(sprintf("`method` must be a string of length 1 containing the name of an acceptable weighting method or a function that produces weights. Allowable methods:\n%s",
word_list(names(.weightit_methods), and.or = FALSE, quotes = 2)),
tidy = FALSE)
Expand All @@ -46,7 +51,8 @@

.check_method_treat.type <- function(method, treat.type) {
if (is_not_null(method) && is.character(method) &&
treat.type %nin% .weightit_methods[[method]]$treat_type) {
(method %in% names(.weightit_methods)) &&
(treat.type %nin% .weightit_methods[[method]]$treat_type)) {
.err(sprintf("%s can only be used with a %s treatment",
.method_to_phrase(method),
word_list(.weightit_methods[[method]]$treat_type, and.or = "or")))
Expand Down Expand Up @@ -122,7 +128,7 @@
else .weightit_methods[[method]]$estimand
}

if (treat.type == "multi") {
if (treat.type == "multi-category") {
allowable_estimands <- setdiff(allowable_estimands, "ATOS")
}

Expand Down Expand Up @@ -330,14 +336,14 @@
if (!has_treat_type(treat)) treat <- assign_treat_type(treat)
treat.type <- get_treat_type(treat)

unique.treat <- unique(treat, nmax = switch(treat.type, "binary" = 2, "multi" = length(treat)/4))
unique.treat <- unique(treat, nmax = switch(treat.type, "binary" = 2, "multi-category" = length(treat)/4))

#Check focal
if (is_not_null(focal) && (length(focal) > 1L || focal %nin% unique.treat)) {
.err("the argument supplied to `focal` must be the name of a level of treatment")
}

if (treat.type == "multi") {
if (treat.type == "multi-category") {

if (estimand %nin% c("ATT", "ATC") && is_not_null(focal)) {
.wrn(sprintf("`estimand = %s` is not compatible with `focal`. Setting `estimand` to \"ATT\"",
Expand Down
2 changes: 1 addition & 1 deletion R/get_w_from_ps.R
Original file line number Diff line number Diff line change
Expand Up @@ -378,7 +378,7 @@ get_w_from_ps <- function(ps, treat, estimand = "ATE", focal = NULL, treated = N
}

}
else if (treat.type == "multi") {
else if (treat.type == "multi-category") {
if (is.matrix(ps)) {
if (!is.numeric(ps)) {
.err("`ps` must be numeric when supplied as a matrix")
Expand Down
6 changes: 3 additions & 3 deletions R/sbps.R
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ sbps <- function(obj, obj2 = NULL, moderator = NULL, formula = NULL, data = NULL
abs = TRUE, s.weights = s.weights[moderator.factor == g],
bin.vars = bin.vars)))
}
else if (treat.type == "multi") {
else if (treat.type == "multi-category") {
if (is_not_null(focal)) {
bin.treat <- as.numeric(treat == focal)
s.d.denom <- switch(estimand, ATT = "treated", ATC = "control", "all")
Expand Down Expand Up @@ -337,7 +337,7 @@ sbps <- function(obj, obj2 = NULL, moderator = NULL, formula = NULL, data = NULL
abs = TRUE, s.weights = s.weights[moderator.factor == g],
bin.vars = bin.vars)))
}
else if (treat.type == "multi") {
else if (treat.type == "multi-category") {
if (is_not_null(focal)) {
bin.treat <- as.numeric(treat == focal)
s.d.denom <- switch(estimand, ATT = "treated", ATC = "control", "all")
Expand Down Expand Up @@ -547,7 +547,7 @@ summary.weightit.sbps <- function(object, top = 5, ignore.s.weights = FALSE, ...
nn["Weighted", ] <- c(ESS(w[t==0]),
ESS(w[t==1]))
}
else if (treat.type == "multi") {
else if (treat.type == "multi-category") {
out$weight.range <- setNames(lapply(levels(t), function(x) c(min(w[w > 0 & t == x]),
max(w[w > 0 & t == x]))),
levels(t))
Expand Down
2 changes: 1 addition & 1 deletion R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -743,7 +743,7 @@ assign_treat_type <- function(treat, use.multi = FALSE) {
treat.type <- "binary"
}
else if (use.multi || chk::vld_character_or_factor(treat)) {
treat.type <- "multi"
treat.type <- "multi-category"
if (!inherits(treat, "processed.treat")) treat <- factor(treat)
}
else {
Expand Down
4 changes: 2 additions & 2 deletions R/weightit.R
Original file line number Diff line number Diff line change
Expand Up @@ -229,8 +229,8 @@ weightit <- function(formula, data = NULL, method = "glm", estimand = "ATE", sta
}
else if (is.character(method)) {
method <- .method_to_proper_method(method)
attr(method, "name") <- method
.check_method_treat.type(method, treat.type)
attr(method, "name") <- method
}
else { #function
method.name <- deparse1(substitute(method))
Expand Down Expand Up @@ -414,7 +414,7 @@ print.weightit <- function(x, ...) {
cat(sprintf(" - treatment: %s\n",
switch(treat.type,
"continuous" = "continuous",
"multi" = sprintf("%s-category (%s)",
"multi-category" = sprintf("%s-category (%s)",
nunique(x[["treat"]]),
word_list(levels(x[["treat"]]), and.or = FALSE)),
"binary" = "2-category")))
Expand Down
2 changes: 1 addition & 1 deletion R/weightit.fit.R
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,7 @@ weightit.fit <- function(covs, treat, method = "glm", s.weights = NULL, by.facto
}

fun <- switch(treat.type,
"multi" = paste.(fun, "multi"),
"multi-category" = paste.(fun, "multi"),
"continuous" = paste.(fun, "cont"),
fun)
}
Expand Down
16 changes: 8 additions & 8 deletions R/weightit2cbps.R
Original file line number Diff line number Diff line change
Expand Up @@ -780,7 +780,7 @@ weightitMSM2cbps <- function(covs.list, treat.list, s.weights, subset, missing,
covs.list[[i]] <- add_missing_indicators(covs.list[[i]])
}

if (treat.types[i] %in% c("binary", "multi")) {
if (treat.types[i] %in% c("binary", "multi-category")) {
covs.list[[i]] <- cbind(.int_poly_f(covs.list[[i]], poly = moments,
int = int, center = TRUE),
.quantile_f(covs.list[[i]], qu = A[["quantile"]],
Expand All @@ -797,7 +797,7 @@ weightitMSM2cbps <- function(covs.list, treat.list, s.weights, subset, missing,
treat.list[[i]] <- switch(
treat.types[i],
"binary" = binarize(treat.list[[i]], one = get_treated_level(treat.list[[i]])),
"multi" = factor(treat.list[[i]]),
"multi-category" = factor(treat.list[[i]]),
"continuous" = scale_w(treat.list[[i]], s.weights)
)

Expand Down Expand Up @@ -828,7 +828,7 @@ weightitMSM2cbps <- function(covs.list, treat.list, s.weights, subset, missing,
coef_ind[[i]] <- length(unlist(coef_ind)) + switch(
treat.types[i],
"binary" = seq_col(covs.list[[i]]),
"multi" = seq_len((nlevels(treat.list[[i]]) - 1L) * ncol(covs.list[[i]])),
"multi-category" = seq_len((nlevels(treat.list[[i]]) - 1L) * ncol(covs.list[[i]])),
"continuous" = seq_len(3L + ncol(covs.list[[i]]))
)
}
Expand All @@ -838,7 +838,7 @@ weightitMSM2cbps <- function(covs.list, treat.list, s.weights, subset, missing,
"binary" = function(B, X, A) {
plogis(drop(X %*% B))
},
"multi" = function(B, X, A) {
"multi-category" = function(B, X, A) {
qq <- exp(X %*% matrix(B, nrow = ncol(X)))

pp <- cbind(1, qq) / (1 + rowSums(qq))
Expand All @@ -859,7 +859,7 @@ weightitMSM2cbps <- function(covs.list, treat.list, s.weights, subset, missing,
"binary" = function(p, A, B) {
A / p + (1 - A) / (1 - p)
},
"multi" = function(p, A, B) {
"multi-category" = function(p, A, B) {
w <- numeric(length(A))
for (a in levels(A)) {
w[A == a] <- 1 / p[A == a, a]
Expand All @@ -886,7 +886,7 @@ weightitMSM2cbps <- function(covs.list, treat.list, s.weights, subset, missing,
"binary" = function(w, B, X, A, SW) {
SW * w * (A - (1 - A)) * X
},
"multi" = function(w, B, X, A, SW) {
"multi-category" = function(w, B, X, A, SW) {
do.call("cbind", lapply(utils::combn(levels(treat.list[[i]]), 2, simplify = FALSE), function(co) {
SW * w * ((A == co[1]) - (A == co[2])) * X
}))
Expand Down Expand Up @@ -928,7 +928,7 @@ weightitMSM2cbps <- function(covs.list, treat.list, s.weights, subset, missing,
switch(treat.types[i],
"binary" = glm.fit(covs.list[[i]], treat.list[[i]], family = binomial(),
weights = s.weights)$coefficients,
"multi" = .multinom_weightit.fit(covs.list[[i]], treat.list[[i]], hess = FALSE,
"multi-category" = .multinom_weightit.fit(covs.list[[i]], treat.list[[i]], hess = FALSE,
weights = s.weights)$coefficients,
"continuous" = {
init.fit <- lm.wfit(covs.list[[i]], treat.list[[i]], w = s.weights)
Expand Down Expand Up @@ -961,7 +961,7 @@ weightitMSM2cbps <- function(covs.list, treat.list, s.weights, subset, missing,
"binary" = function(p, X, A, SW) {
SW * (A - p) * X
},
"multi" = function(p, X, A, SW) {
"multi-category" = function(p, X, A, SW) {
do.call("cbind", lapply(levels(A), function(i) {
SW * ((A == i) - p[,i]) * X
}))
Expand Down
6 changes: 3 additions & 3 deletions R/weightit2gbm.R
Original file line number Diff line number Diff line change
Expand Up @@ -339,7 +339,7 @@ weightit2gbm <- function(covs, treat, s.weights, estimand, focal, subset,
}

if (any(A[["use.offset"]])) {
if (treat.type == "multi") {
if (treat.type == "multi-category") {
.err("`use.offset` cannot be used with multi-category treatments")
}

Expand Down Expand Up @@ -515,11 +515,11 @@ weightit2gbm <- function(covs, treat, s.weights, estimand, focal, subset,
info <- list(best.tree = best.tree,
tree.val = tree.val)

if (treat.type == "multi") best.ps <- NULL
if (treat.type == "multi-category") best.ps <- NULL
}
}

if (treat.type == "multi") ps <- NULL
if (treat.type == "multi-category") ps <- NULL
}

if (nrow(tune) > 1) {
Expand Down
2 changes: 1 addition & 1 deletion R/weightit2optweight.R
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ weightitMSM2optweight <- function(covs.list, treat.list, s.weights, subset, miss
covs.list[[i]] <- cbind(covs.list[[i]],
.int_poly_f(covs.list[[i]], poly = moments, int = int))

if (treat.types[i] %in% c("binary", "multi")) {
if (treat.types[i] %in% c("binary", "multi-category")) {
covs.list[[i]] <- cbind(.int_poly_f(covs.list[[i]], poly = moments, int = int, center = TRUE),
.quantile_f(covs.list[[i]], qu = A[["quantile"]], s.weights = s.weights,
treat = treat.list[[i]]))
Expand Down
2 changes: 1 addition & 1 deletion R/weightitMSM.R
Original file line number Diff line number Diff line change
Expand Up @@ -507,7 +507,7 @@ print.weightitMSM <- function(x, ...) {
i,
switch(treat.types[i],
"continuous" = "continuous",
"multi" = sprintf("%s-category (%s)",
"multi-category" = sprintf("%s-category (%s)",
nunique(x[["treat.list"]][[i]]),
word_list(levels(x[["treat.list"]][[i]]), and.or = FALSE)),
"binary" = "2-category")))
Expand Down
22 changes: 11 additions & 11 deletions R/weightit_methods.R
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
#' @details
#' Each component is itself a list containing the following components:
#' \describe{
#' \item{`treat_type`}{at least one of `"binary"`, `"multi"`, or `"continuous"` indicating which treatment types are available for this method.}
#' \item{`treat_type`}{at least one of `"binary"`, `"multi-category"`, or `"continuous"` indicating which treatment types are available for this method.}
#' \item{`estimand`}{which estimands are available for this method. All methods that support binary and multi-category treatments accept `"ATE"`, `"ATT"`, and `"ATC"`, as well as some other estimands depending on the method. See [get_w_from_ps()] for more details about what each estimand means.}
#' \item{`alias`}{a character vector of aliases for the method. When an alias is supplied, the corresponding method will still be dispatched. For example, the canonical method to request entropy balancing is `"ebal"`, but `"ebalance"` and `"entropy"` also work. The first value is the canonical name.}
#' \item{`description`}{a string containing the description of the name in English.}
Expand Down Expand Up @@ -51,7 +51,7 @@
#' @export
.weightit_methods <- {list(
"glm" = list(
treat_type = c("binary", "multi", "continuous"),
treat_type = c("binary", "multi-category", "continuous"),
estimand = c("ATE", "ATT", "ATC", "ATO", "ATM", "ATOS"),
alias = c("glm", "ps"),
description = "propensity score weighting with GLM",
Expand All @@ -69,7 +69,7 @@
plot.weightit_ok = FALSE
),
"bart" = list(
treat_type = c("binary", "multi", "continuous"),
treat_type = c("binary", "multi-category", "continuous"),
estimand = c("ATE", "ATT", "ATC", "ATO", "ATM", "ATOS"),
alias = c("bart"),
description = "propensity score weighting with BART",
Expand All @@ -87,7 +87,7 @@
plot.weightit_ok = FALSE
),
"cbps" = list(
treat_type = c("binary", "multi", "continuous"),
treat_type = c("binary", "multi-category", "continuous"),
estimand = c("ATE", "ATT", "ATC"),
alias = c("cbps", "cbgps"),
description = "covariate balancing propensity score weighting",
Expand All @@ -105,7 +105,7 @@
plot.weightit_ok = FALSE
),
"ebal" = list(
treat_type = c("binary", "multi", "continuous"),
treat_type = c("binary", "multi-category", "continuous"),
estimand = c("ATE", "ATT", "ATC"),
alias = c("ebal", "ebalance", "entropy"),
description = "entropy balancing",
Expand All @@ -123,7 +123,7 @@
plot.weightit_ok = FALSE
),
"energy" = list(
treat_type = c("binary", "multi", "continuous"),
treat_type = c("binary", "multi-category", "continuous"),
estimand = c("ATE", "ATT", "ATC"),
alias = c("energy", "dcows"),
description = "energy balancing",
Expand All @@ -141,7 +141,7 @@
plot.weightit_ok = FALSE
),
"gbm" = list(
treat_type = c("binary", "multi", "continuous"),
treat_type = c("binary", "multi-category", "continuous"),
estimand = c("ATE", "ATT", "ATC", "ATO", "ATM", "ATOS"),
alias = c("gbm", "gbr"),
description = "propensity score weighting with GBM",
Expand All @@ -159,7 +159,7 @@
plot.weightit_ok = TRUE
),
"ipt" = list(
treat_type = c("binary", "multi"),
treat_type = c("binary", "multi-category"),
estimand = c("ATE", "ATT", "ATC"),
alias = c("ipt"),
description = "inverse probability tilting",
Expand All @@ -177,7 +177,7 @@
plot.weightit_ok = FALSE
),
"npcbps" = list(
treat_type = c("binary", "multi", "continuous"),
treat_type = c("binary", "multi-category", "continuous"),
estimand = c("ATE"),
alias = c("npcbps", "npcbgps"),
description = "non-parametric covariate balancing propensity score weighting",
Expand All @@ -195,7 +195,7 @@
plot.weightit_ok = FALSE
),
"optweight" = list(
treat_type = c("binary", "multi", "continuous"),
treat_type = c("binary", "multi-category", "continuous"),
estimand = c("ATE", "ATT", "ATC"),
alias = c("optweight", "sbw"),
description = "stable balancing weights",
Expand All @@ -213,7 +213,7 @@
plot.weightit_ok = TRUE
),
"super" = list(
treat_type = c("binary", "multi", "continuous"),
treat_type = c("binary", "multi-category", "continuous"),
estimand = c("ATE", "ATT", "ATC", "ATO", "ATM", "ATOS"),
alias = c("super", "superlearner"),
description = "propensity score weighting with SuperLearner",
Expand Down
2 changes: 1 addition & 1 deletion man/dot-weightit_methods.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 6cdb3c8

Please sign in to comment.