Skip to content

Commit

Permalink
Updates to support no covariates being supplied
Browse files Browse the repository at this point in the history
  • Loading branch information
ngreifer committed Mar 13, 2024
1 parent 6856039 commit dd0a26b
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 22 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: cobalt
Title: Covariate Balance Tables and Plots
Version: 4.5.4.9000
Version: 4.5.4.9001
Authors@R: c(
person("Noah", "Greifer", role=c("aut", "cre"),
email = "noah.greifer@gmail.com",
Expand Down
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

* Minor updates to `bal.plot()` to prevent warnings due to `ggplot2` 3.5.0.

* Improved processing when no covariates are specified.

# cobalt 4.5.4

* Minor update to accommodate `ggplot2` 3.5.0. Thanks to @teunbrand. (#80)
Expand Down
51 changes: 32 additions & 19 deletions R/functions_for_processing.R
Original file line number Diff line number Diff line change
Expand Up @@ -673,7 +673,7 @@ strata2weights <- function(strata, treat, estimand = NULL, focal = NULL) {
if (is_not_null(weights) && length(s.d.denom) == 1 && NCOL(weights) > 1) {
s.d.denom <- rep.int(s.d.denom, NCOL(weights))
}

if (is_not_null(weights) && length(s.d.denom) != NCOL(weights)) {
.err(sprintf("valid inputs to `s.d.denom` or `estimand` must have length 1 or equal to the number of valid sets of weights, which is %s",
NCOL(weights)))
Expand Down Expand Up @@ -2074,7 +2074,9 @@ get_covs_from_formula <- function(f, data = NULL, factor_sep = "_", int_sep = "
}
}

names(co_list[["C"]]) <- vapply(co_list[["C"]], function(x) paste0(x[["component"]], collapse = ""), character(1L))
if (is_not_null(co_list[["C"]])) {
names(co_list[["C"]]) <- vapply(co_list[["C"]], function(x) paste0(x[["component"]], collapse = ""), character(1L))
}

if (is_not_null(distance)) {
if (anyNA(distance, recursive = TRUE)) .err("missing values are not allowed in the distance measure")
Expand Down Expand Up @@ -2130,25 +2132,36 @@ get_covs_from_formula <- function(f, data = NULL, factor_sep = "_", int_sep = "
}

C <- do.call("cbind", clear_null(C_list[c("distance", "C", "int.poly")]))
co.names <- do.call("c", co_list[c("distance", "C", "int.poly")])

for (i in seq_along(co.names)) {
co.names[[i]]$component[co.names[[i]]$type == "fsep"] <- factor_sep
co.names[[i]]$component[co.names[[i]]$type == "isep"] <- int_sep
}
seps["factor"] <- factor_sep
seps["int"] <- int_sep

colnames(C) <- names(co.names) <- vapply(co.names, function(x) paste0(x[["component"]], collapse = ""), character(1L))

attr(co.names, "seps") <- seps

attr(C, "co.names") <- co.names

attr(C, "missing.ind") <- colnames(C)[vapply(co.names, function(x) "na" %in% x[["type"]], logical(1L))]
if ("distance" %in% names(C_list)) attr(C, "distance.names") <- names(co_list[["distance"]])
if (is_null(C)) {
C <- matrix(0, nrow = length(treat), ncol = 0,
dimnames = list(rownames(covs), NULL))
}
else {
co.names <- do.call("c", co_list[c("distance", "C", "int.poly")])

for (i in seq_along(co.names)) {
co.names[[i]]$component[co.names[[i]]$type == "fsep"] <- factor_sep
co.names[[i]]$component[co.names[[i]]$type == "isep"] <- int_sep
}

seps["factor"] <- factor_sep
seps["int"] <- int_sep

colnames(C) <- names(co.names) <- vapply(co.names, function(x) paste0(x[["component"]], collapse = ""), character(1L))


attr(co.names, "seps") <- seps

attr(C, "co.names") <- co.names

attr(C, "missing.ind") <- colnames(C)[vapply(co.names, function(x) "na" %in% x[["type"]], logical(1L))]
if ("distance" %in% names(C_list)) attr(C, "distance.names") <- names(co_list[["distance"]])

attr(C, "var_types") <- .get_types(C)
}

attr(C, "var_types") <- .get_types(C)
class(C) <- c(class(C), "processed_C")

C
Expand Down Expand Up @@ -2613,7 +2626,7 @@ balance.table <- function(C, type, weights = NULL, treat, continuous, binary, s.
abs = abs, s.weights = s.weights, bin.vars = bin.vars,
weighted.weights = weights[[1]], ...)
}

if (!no.adj && (!quick || s %in% disp)) {
for (i in weight.names) {
B[[paste.(STATS[[s]]$bal.tab_column_prefix, i)]] <- STATS[[s]]$fun(C, treat = treat, weights = weights[[i]],
Expand Down
5 changes: 3 additions & 2 deletions R/x2base.R
Original file line number Diff line number Diff line change
Expand Up @@ -2059,8 +2059,9 @@ x2base.weightit <- function(weightit, ...) {
treat <- process_treat(weightit[["treat"]], datalist = list(data, weightit.data))

#Process covs
if (is_null(covs <- weightit[["covs"]])) .err("No covariates were specified in the weightit object")
covs <- get_covs_from_formula(data = covs)
if (is_not_null(covs <- weightit[["covs"]])) {
covs <- get_covs_from_formula(data = covs)
}

#Get estimand
estimand <- weightit[["estimand"]]
Expand Down

0 comments on commit dd0a26b

Please sign in to comment.