Skip to content

Commit

Permalink
Updates
Browse files Browse the repository at this point in the history
  • Loading branch information
ngreifer committed Feb 26, 2024
1 parent 8be75fe commit 12ebe2f
Show file tree
Hide file tree
Showing 5 changed files with 2 additions and 159 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ love.plot(m.out, stats = c("mean.diffs", "variance.ratios"),
Please remember to cite this package when using it to analyze data. For
example, in a manuscript, you could write: “Matching was performed using
the *Matching* package (Sekhon, 2011), and covariate balance was
assessed using *cobalt* (Greifer, 2023), both in R (R Core Team, 2023).”
assessed using *cobalt* (Greifer, 2024), both in R (R Core Team, 2023).”
Use `citation("cobalt")` to generate a bibliographic reference for the
`cobalt` package.

Expand Down
157 changes: 0 additions & 157 deletions do_not_include/Under_construction/under_construction.R
Original file line number Diff line number Diff line change
Expand Up @@ -27,163 +27,6 @@ skew.diff <- function(x, group, weights, var.type) {
return(skew.cols)
}

#Scalar balance functions - not implemented yet
# Compute scalar (whole-sample) balance summaries for a binary treatment.
#
# Arguments:
#   mat        Covariates: a numeric matrix, a numeric vector (treated as one
#              column), or a data frame (character/factor columns are passed
#              through splitfactor() before conversion to a matrix).
#   treat      Binary treatment vector.
#   weights    Optional unit weights; defaults to all 1s when `check = TRUE`.
#   type       Character vector of requested summaries. Allowed values are
#              "<stat>.<agg>" with stat in smd/ks/ovl and agg in max/mean/rms,
#              plus "mahalanobis", "gwd", "cstat", "wr2", "design.effect", and
#              the shortcuts "all" (everything) and "rec" (a fixed subset).
#   s.weights  Optional sampling weights; defaults to all 1s when needed.
#   check      If TRUE, validate/coerce `mat`, `treat`, and the weights.
#   ...        Extra options read by name: s.d.denom, estimand, smd.weights,
#              bin.vars, gwd.weights, center. All of `...` is also forwarded
#              to col_w_ovl().
#
# Returns a numeric vector with one element per requested type, named by type
# (after the "all"/"rec" shortcuts have been expanded).
bal.sum <- function(mat, treat, weights = NULL, type, s.weights = NULL, check = TRUE, ...) {
  uni.type <- c("smd", "ks", "ovl")
  agg.funs <- c("max", "mean", "rms")
  sample.type <- c("mahalanobis", "gwd", "cstat", "wr2", "design.effect")
  uni.type.expanded <- expand.grid_string(uni.type, agg.funs, collapse = ".")
  shortcuts <- c("all", "rec")
  allowable.type <- c(uni.type.expanded, sample.type, shortcuts)

  if (missing(type)) stop("type must be specified.", call. = FALSE)
  type <- match_arg(type, allowable.type, several.ok = TRUE)

  # Expand the shortcuts into the concrete summaries they stand for.
  if ("all" %in% type) type <- unique(c(type[type != "all"], uni.type.expanded, sample.type))
  if ("rec" %in% type) type <- unique(c(type[type != "rec"],
                                        "smd.mean", "smd.rms",
                                        "ks.mean", "ks.rms",
                                        "mahalanobis", "gwd", "wr2"))

  A <- list(...)

  if (check) {
    bad.mat <- FALSE
    if (missing(mat)) bad.mat <- TRUE
    else {
      if (is.data.frame(mat)) {
        if (any(vapply(mat, function(x) is_(x, c("character", "factor")), logical(1L))))
          mat <- splitfactor(mat)
        mat <- as.matrix.data.frame(mat)
      }
      else if (is.vector(mat, "numeric")) mat <- matrix(mat, ncol = 1)
      else if (!is.matrix(mat) || !is.numeric(mat)) bad.mat <- TRUE
    }
    if (bad.mat) stop("mat must be a numeric matrix.", call. = FALSE)

    if (missing(treat) || !(is.factor(treat) || is.atomic(treat)) || !is_binary(treat)) {
      stop("treat must be a binary variable.", call. = FALSE)
    }

    if (is_null(weights)) weights <- rep(1, NROW(mat))
    if (is_null(s.weights)) s.weights <- rep(1, NROW(mat))
    if (!all_the_same(c(NROW(mat), length(treat), length(weights), length(s.weights)))) {
      stop("mat, treat, weights, and s.weights (if supplied) must have the same number of units.",
           call. = FALSE)
    }
  }

  # Resolve the SMD denominator once, locally, for every summary that needs it.
  # (Previously guarded by exists("s.d.denom"), which also searches enclosing
  # and global environments and so could silently pick up an unrelated object.)
  needs.smd <- any(paste.("smd", agg.funs) %in% type)
  if (needs.smd || "gwd" %in% type) {
    s.d.denom <- get.s.d.denom(A[["s.d.denom"]], estimand = A[["estimand"]],
                               weights = data.frame(weights), treat = treat)
  }

  # Per-covariate statistics, computed only if some aggregate of them was asked for.
  if (needs.smd) {
    smd <- col_w_smd(mat, treat, weights, std = TRUE, s.d.denom, abs = TRUE,
                     s.weights = s.weights, check = FALSE)

    # Optional per-covariate importance weights; fall back to equal weights
    # (with a warning) when the supplied value is unusable.
    if (is_null(A[["smd.weights"]])) smd.weights <- rep(1, ncol(mat))
    else if (!is.vector(A[["smd.weights"]], "numeric")) {
      warning("smd.weights is not numeric. Ignoring smd.weights",
              call. = FALSE, immediate. = TRUE)
      smd.weights <- rep(1, ncol(mat))
    }
    else if (length(A[["smd.weights"]]) == ncol(mat)) {
      smd.weights <- A[["smd.weights"]]
    }
    else {
      warning("smd.weights should be of length ncol(mat). Ignoring smd.weights",
              call. = FALSE, immediate. = TRUE)
      smd.weights <- rep(1, ncol(mat))
    }
    smd.weights <- smd.weights/mean(smd.weights) #Make sum to ncol(mat)
    smd <- smd.weights*smd
  }
  if (any(paste.("ks", agg.funs) %in% type)) {
    ks <- col_w_ks(mat, treat, weights, bin.vars = A[["bin.vars"]],
                   check = is_null(A[["bin.vars"]]))
  }
  if (any(paste.("ovl", agg.funs) %in% type)) {
    ovl <- do.call(col_w_ovl, c(list(mat = mat, treat = treat, weights = weights,
                                     check = is_null(A[["bin.vars"]])), A))
  }

  bal <- setNames(vapply(type, function(m) {
    # Split "<stat>.<agg>" into the statistic name and its aggregation function.
    if (endsWith(m, ".mean")) {
      agg <- mean
      m <- substr(m, 1, nchar(m) - nchar(".mean"))
    }
    else if (endsWith(m, ".max")) {
      agg <- max
      m <- substr(m, 1, nchar(m) - nchar(".max"))
    }
    else if (endsWith(m, ".rms")) {
      agg <- function(x, ...) {sqrt(mean(x^2, ...))}
      m <- substr(m, 1, nchar(m) - nchar(".rms"))
    }

    if (m %in% uni.type) {
      # Explicit dispatch instead of get0(m): only the matching arm is
      # evaluated, and no name-based lookup through the scope chain is needed.
      stat <- switch(m, smd = smd, ks = ks, ovl = ovl)
      return(agg(stat, na.rm = TRUE))
    }
    else if (m == "mahalanobis") {
      if (is_null(s.weights)) s.weights <- rep(1, nrow(mat))
      mdiff <- matrix(col_w_smd(mat, treat, weights, std = FALSE, abs = FALSE, check = FALSE), ncol = 1)
      wcov <- cov.wt(mat, s.weights)$cov
      # Quadratic form mdiff' * wcov^{-1} * mdiff; drop() turns the 1x1 matrix into a scalar.
      mahal <- drop(crossprod(mdiff, solve(wcov)) %*% mdiff)
      return(mahal)
    }
    else if (m == "cstat") {
      tval1 <- treat[1]
      d <- data.frame(treat, mat)
      f <- formula(d)
      pred <- glm(f, data = d, family = quasibinomial(),
                  weights = weights)$fitted
      wi <- wilcox.test(pred ~ treat)
      # Rank-sum statistic scaled by the product of the two group sizes,
      # then folded around .5 and rescaled to [0, 1].
      cstat <- wi$statistic/(sum(treat==tval1)*sum(treat!=tval1))
      cstat <- 2*max(cstat, 1-cstat)-1
      return(cstat)
    }
    else if (m == "wr2") {
      d <- data.frame(treat, mat)
      f <- formula(d)
      fit <- glm(f, data = d, family = quasibinomial(),
                 weights = weights)
      r2 <- 1 - (pi^2/3)/(3*var(fit$linear.predictors) + pi^2/3)
      return(r2)
    }
    else if (m == "gwd") {
      co.names <- setNames(lapply(colnames(mat), function(x) setNames(list(x, "base"), c("component", "type"))), colnames(mat))
      # NOTE(review): int.poly.f2 presumably returns interaction and squared
      # terms of mat's columns -- confirm against its definition.
      new <- int.poly.f2(mat, int = TRUE, poly = 2, center = isTRUE(A[["center"]]),
                         sep = getOption("cobalt_int_sep", default = " * "), co.names = co.names)
      mat_ <- cbind(mat, new)

      gwd.smd <- col_w_smd(mat_, treat, weights, std = TRUE, s.d.denom, abs = TRUE, check = FALSE)

      # Default weighting: 1 for original columns, .5 for the added terms.
      if (is_null(A[["gwd.weights"]])) gwd.weights <- c(rep(1, ncol(mat)), rep(.5, ncol(new)))
      else if (!is.vector(A[["gwd.weights"]], "numeric")) {
        warning("gwd.weights is not numeric. Ignoring gwd.weights.",
                call. = FALSE, immediate. = TRUE)
        gwd.weights <- c(rep(1, ncol(mat)), rep(.5, ncol(new)))
      }
      else if (length(A[["gwd.weights"]]) == 1L) {
        gwd.weights <- rep(A[["gwd.weights"]], ncol(mat_))
      }
      else if (length(A[["gwd.weights"]]) == 2L) {
        gwd.weights <- c(rep(A[["gwd.weights"]][1], ncol(mat)), rep(A[["gwd.weights"]][2], ncol(new)))
      }
      else {
        warning("gwd.weights should be of length 1 or 2. Ignoring gwd.weights.",
                call. = FALSE, immediate. = TRUE)
        gwd.weights <- c(rep(1, ncol(mat)), rep(.5, ncol(new)))
      }

      gwd.weights <- gwd.weights/sum(gwd.weights) #Make sum to 1
      return(sum(gwd.weights*gwd.smd, na.rm = TRUE))
    }
    else if (m == "design.effect") {
      tval1 <- treat[1]
      q <- sum(treat == tval1)/length(treat)
      # With check = FALSE, weights may still be NULL; use unit weights so the
      # result is a design effect of 1 rather than NaN.
      w <- if (is_null(weights)) rep(1, length(treat)) else weights
      des.eff <- function(w) length(w)*sum(w^2)/sum(w)^2
      des <- c(des.eff(w[treat == tval1]), des.eff(w[treat != tval1]))
      de <- des[1]*(1-q) + des[2]*q
      return(de)
    }
  }, numeric(1L)), type)

  return(bal)
}

col_w_edist <- function(mat, treat, weights = NULL, s.weights = NULL, bin.vars, subset = NULL, na.rm = TRUE, ...) {
needs.splitting <- FALSE
if (!is.matrix(mat)) {
Expand Down
Binary file modified man/figures/README-unnamed-chunk-3-1.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified man/figures/README-unnamed-chunk-3-2.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
2 changes: 1 addition & 1 deletion vignettes/faq.Rmd
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ b <- bal.tab(treat ~ age + educ + race + married + re74,
data = lalonde, s.d.denom = "treated",
disp = "means", stats = c("m", "v"))
# View the structure of hte object
# View the structure of the object
str(b, give.attr = FALSE)
b$Balance
Expand Down

0 comments on commit 12ebe2f

Please sign in to comment.