Skip to content

Commit

Permalink
Updated to bal.plot to fix warnings due to gpplot2 3.5.0
Browse files Browse the repository at this point in the history
  • Loading branch information
ngreifer committed Mar 11, 2024
1 parent 12ebe2f commit 6856039
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 39 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: cobalt
Title: Covariate Balance Tables and Plots
Version: 4.5.4
Version: 4.5.4.9000
Authors@R: c(
person("Noah", "Greifer", role=c("aut", "cre"),
email = "noah.greifer@gmail.com",
Expand Down
4 changes: 4 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
`cobalt` News and Updates
======

# cobalt (development version)

* Minor updates to `bal.plot()` to prevent warnings due to `ggplot2` 3.5.0.

# cobalt 4.5.4

* Minor update to accommodate `ggplot2` 3.5.0. Thanks to @teunbrand. (#80)
Expand Down
85 changes: 49 additions & 36 deletions R/bal.plot.R
Original file line number Diff line number Diff line change
Expand Up @@ -68,14 +68,14 @@
#' type = "histogram", colors = c("white", "black"))
#'
#' @examplesIf requireNamespace("WeightIt", quietly = TRUE)
#' #PS weighting with a continuous treatment
#' #Entropy balancing with a continuous treatment
#' w.out <- WeightIt::weightit(re75 ~ age + I(age^2) + educ +
#' race + married + nodegree,
#' data = lalonde)
#' data = lalonde, method = "ebal")
#'
#' bal.plot(w.out, "age", which = "both")
#' bal.plot(w.out, "married", which = "both")

#' @rdname bal.plot
#' @export
bal.plot <- function(x, var.name, ..., which, which.sub = NULL, cluster = NULL, which.cluster = NULL,
Expand Down Expand Up @@ -107,7 +107,7 @@ bal.plot <- function(x, var.name, ..., which, which.sub = NULL, cluster = NULL,
if (is_null(X$covs.list)) {
#Point treatment
X$covs <- .get_C2(X$covs, addl = X$addl, distance = X$distance, cluster = X$cluster, treat = X$treat,
drop = FALSE)
drop = FALSE)
co.names <- attr(X$covs, "co.names")
if (missing(var.name)) {
var.name <- NULL; k = 1
Expand Down Expand Up @@ -150,7 +150,7 @@ bal.plot <- function(x, var.name, ..., which, which.sub = NULL, cluster = NULL,
#Longitudinal
X$covs.list <- lapply(seq_along(X$covs.list), function(i) {
.get_C2(X$covs.list[[i]], addl = X$addl.list[[i]], distance = X$distance.list[[i]], cluster = X$cluster,
treat = X$treat.list[[i]], drop = FALSE)
treat = X$treat.list[[i]], drop = FALSE)
})
co.names.list <- lapply(X$covs.list, attr, "co.names")
ntimes <- length(X$covs.list)
Expand Down Expand Up @@ -244,7 +244,7 @@ bal.plot <- function(x, var.name, ..., which, which.sub = NULL, cluster = NULL,
if (is_null(X$subclass) && NCOL(X$weights) != 1L) names(X$weights)
else "adjusted"
}

if (is_null(X$s.weights)) {
X$s.weights <- rep(1, length(X$treat))
}
Expand Down Expand Up @@ -638,11 +638,16 @@ bal.plot <- function(x, var.name, ..., which, which.sub = NULL, cluster = NULL,
D$var.mean <- ave_w.m(D[["var"]], D[facet], w = D[["s.weights"]])
D$treat.mean <- ave_w.m(D[["treat"]], D[facet], w = D[["s.weights"]])

bp <- ggplot2::ggplot(D, mapping = aes(x = .data$var, y = .data$treat, weight = .data$weights * .data$s.weights))
bp <- ggplot2::ggplot(D, mapping = aes(x = .data$var, y = .data$treat,
weight = .data$weights * .data$s.weights))

if (identical(which, "Unadjusted Sample") || isFALSE(alpha.weight)) bp <- bp + ggplot2::geom_point(alpha = .9)
else bp <- bp + ggplot2::geom_point(aes(alpha = .data$weights), show.legend = FALSE) +
ggplot2::scale_alpha(range = c(.04, 1))
if (identical(which, "Unadjusted Sample") || isFALSE(alpha.weight)) {
bp <- bp + ggplot2::geom_point(alpha = .9)
}
else {
bp <- bp + ggplot2::geom_point(aes(alpha = .data$weights), show.legend = FALSE) +
ggplot2::scale_alpha(range = c(.04, 1))
}

bp <- bp +
ggplot2::geom_smooth(method = "lm", formula = y ~ x, se = FALSE, color = "firebrick2",
Expand Down Expand Up @@ -688,9 +693,13 @@ bal.plot <- function(x, var.name, ..., which, which.sub = NULL, cluster = NULL,
which.treat <- character(0)
}

if (is_not_null(which.treat) && !anyNA(which.treat)) D <- D[D$treat %in% which.treat,]
if (is_not_null(which.treat) && !anyNA(which.treat)) {
D <- D[D$treat %in% which.treat,]
}

for (i in which(vapply(D, is.factor, logical(1L)))) D[[i]] <- factor(D[[i]])
for (i in which(vapply(D, is.factor, logical(1L)))) {
D[[i]] <- factor(D[[i]])
}

D$weights <- ave(D[["weights"]] * D[["s.weights"]],
D[c("treat", facet)],
Expand Down Expand Up @@ -724,9 +733,10 @@ bal.plot <- function(x, var.name, ..., which, which.sub = NULL, cluster = NULL,
D$var <- factor(D$var)
bp <- ggplot2::ggplot(D, mapping = aes(x = .data$var, fill = .data$treat, weight = .data$weights)) +
ggplot2::geom_bar(position = "dodge", alpha = .8, color = "black") +
ggplot2::labs(x = var.name, y = "Proportion", fill = "Treatment", title = title, subtitle = subtitle) +
ggplot2::scale_x_discrete(drop=FALSE) +
ggplot2::scale_fill_manual(drop=FALSE, values = colors) +
ggplot2::labs(x = var.name, y = "Proportion", fill = "Treatment",
title = title, subtitle = subtitle) +
ggplot2::scale_x_discrete(drop = FALSE) +
ggplot2::scale_fill_manual(drop = FALSE, values = colors) +
ggplot2::geom_hline(yintercept = 0) +
ggplot2::scale_y_continuous(expand = ggplot2::expansion(mult = c(0, .05)))
}
Expand Down Expand Up @@ -767,14 +777,18 @@ bal.plot <- function(x, var.name, ..., which, which.sub = NULL, cluster = NULL,
fill = names(colors)[t]),
alpha = alpha, bins = args$bins, color = "black"),
NULL)
if (isTRUE(disp.means)) out[[2]] <-
ggplot2::geom_segment(data = unique(D[D$treat == levels(D$treat)[t], c("var.mean", facet), drop = FALSE]),
mapping = aes(x = .data$var.mean, xend = .data$var.mean, y = 0, yend = posneg[t]*Inf),
color = if (isTRUE(mirror)) "black" else colors[[t]])
if (isTRUE(disp.means)) {
out[[2]] <- ggplot2::geom_segment(data = unique(D[D$treat == levels(D$treat)[t], c("var.mean", facet), drop = FALSE]),
mapping = aes(x = .data$var.mean, xend = .data$var.mean, y = 0, yend = posneg[t]*Inf),
color = if (isTRUE(mirror)) "black" else colors[[t]])
}
clear_null(out)
}
ylab <- "Proportion"

bp <- Reduce(init = ggplot2::ggplot(), "+", lapply(seq_len(nlevels.treat), geom_fun)) +
ggplot2::scale_fill_manual(values = colors)

}
else if (type %in% c("ecdf")) {

Expand All @@ -798,12 +812,13 @@ bal.plot <- function(x, var.name, ..., which, which.sub = NULL, cluster = NULL,
D <- rbind(extra[names(D)], D)

geom_fun <- function(t) {

ggplot2::geom_step(data = D[D$treat == levels(D$treat)[t],],
mapping = aes(x = .data$var, y = .data$cum.pt, color = names(colors)[t]))

}
ylab <- "Cumulative Proportion"

bp <- Reduce(init = ggplot2::ggplot(), "+", lapply(seq_len(nlevels.treat), geom_fun)) +
ggplot2::scale_color_manual(values = colors)
}
else {
#Density arguments supplied through ...
Expand All @@ -815,13 +830,11 @@ bal.plot <- function(x, var.name, ..., which, which.sub = NULL, cluster = NULL,
if (is.character(bw)) {
t.sizes <- tapply(rep(1, NROW(D)), D$treat, sum)
smallest.t <- names(t.sizes)[which.min(t.sizes)]
if (is.function(get0(paste0("bw.", bw)))) {
bw <- get0(paste0("bw.", bw))(D$var[D$treat == smallest.t])
}
else {
if (!is.function(get0(paste0("bw.", bw)))) {
.err(sprintf("%s is not an acceptable entry to `bw`. See `?stats::density` for allowable options",
add_quotes(bw, "`")))
}
bw <- get0(paste0("bw.", bw))(D$var[D$treat == smallest.t])
}

if (isTRUE(disp.means)) {
Expand All @@ -837,25 +850,25 @@ bal.plot <- function(x, var.name, ..., which, which.sub = NULL, cluster = NULL,
kernel = kernel, n = n, trim = TRUE,
outline.type = "full", stat = StatDensity2),
NULL)
if (isTRUE(disp.means)) out[[2]] <-
ggplot2::geom_segment(data = unique(D[D$treat == levels(D$treat)[t], c("var.mean", facet), drop = FALSE]),
mapping = aes(x = .data$var.mean, xend = .data$var.mean, y = 0, yend = posneg[t]*Inf),
color = if (isTRUE(mirror)) "black" else colors[[t]])
if (isTRUE(disp.means)) {
out[[2]] <- ggplot2::geom_segment(data = unique(D[D$treat == levels(D$treat)[t], c("var.mean", facet), drop = FALSE]),
mapping = aes(x = .data$var.mean, xend = .data$var.mean, y = 0, yend = posneg[t]*Inf),
color = if (isTRUE(mirror)) "black" else colors[[t]])
}
clear_null(out)
}
ylab <- "Density"

bp <- Reduce(init = ggplot2::ggplot(), "+", lapply(seq_len(nlevels.treat), geom_fun)) +
ggplot2::scale_fill_manual(values = colors)
}

bp <- Reduce("+", c(list(ggplot2::ggplot()),
lapply(seq_len(nlevels.treat), geom_fun))) +
ggplot2::scale_fill_manual(values = colors, guide = ggplot2::guide_legend(override.aes = list(alpha = legend.alpha))) +
ggplot2::scale_color_manual(values = colors) +
if (isTRUE(mirror)) bp <- bp + ggplot2::geom_hline(yintercept = 0)

bp <- bp +
ggplot2::labs(x = var.name, y = ylab, title = title, subtitle = subtitle,
fill = "Treatment", color = "Treatment") +
ggplot2::scale_y_continuous(expand = expandScale)

if (isTRUE(mirror)) bp <- bp + ggplot2::geom_hline(yintercept = 0)
}
}

Expand Down Expand Up @@ -916,7 +929,7 @@ bal.plot <- function(x, var.name, ..., which, which.sub = NULL, cluster = NULL,
scales = if ("subclass" %in% facet) "free_x" else "fixed")
}

bp
bp
}

# Helper functions
Expand Down
4 changes: 2 additions & 2 deletions man/bal.plot.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 6856039

Please sign in to comment.