Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions r-package/balnet/NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,12 @@ S3method(predict,balnet)
S3method(predict,balnet.fit)
S3method(predict,cv.balnet)
S3method(print,balnet)
S3method(print,balweights.contrast)
S3method(print,coef.balnet.contrast)
S3method(print,cv.balnet)
S3method(summary,balnet)
S3method(summary,balweights.contrast)
S3method(summary,coef.balnet.contrast)
S3method(summary,cv.balnet)
export(balnet)
export(balweights)
Expand Down
54 changes: 52 additions & 2 deletions r-package/balnet/R/balnet.R
Original file line number Diff line number Diff line change
Expand Up @@ -282,12 +282,37 @@ coef.balnet <- function(
out.nn <- out[!vapply(out, is.null, logical(1))]

if (length(out.nn) > 1) {
return(out.nn)
return(structure(out.nn, class = "coef.balnet.contrast"))
} else {
return(out.nn[[1]])
}
}

#' @rdname coef.balnet
#' @param x A `coef.balnet.contrast` object.
#' @method print coef.balnet.contrast
#' @export
print.coef.balnet.contrast <- function(x, ...) {
cat("Control arm coefficients:", "\n")
print(x[["control"]], ...)
cat("\n")
cat("Treated arm coefficients:", "\n")
print(x[["treated"]], ...)

invisible(x)
}

#' @rdname coef.balnet
#' @method summary coef.balnet.contrast
#' @export
summary.coef.balnet.contrast <- function(object, ...) {
cat("Control arm coefficients:", "\n")
print(summary(object[["control"]], ...))
cat("\n")
cat("Treated arm coefficients:", "\n")
print(summary(object[["treated"]], ...))
}

#' Predict using a balnet object.
#'
#' @param object A `balnet` object.
Expand Down Expand Up @@ -657,12 +682,37 @@ balweights.balnet <- function(
out.nn <- out[!vapply(out, is.null, logical(1))]

if (length(out.nn) > 1) {
return(out.nn)
return(structure(out.nn, class = "balweights.contrast"))
} else {
return(out.nn[[1]])
}
}

#' @rdname balweights
#' @param x A `balweights.contrast` object.
#' @method print balweights.contrast
#' @export
print.balweights.contrast <- function(x, ...) {
cat("Control arm weights:", "\n")
print(x[["control"]], ...)
cat("\n")
cat("Treated arm weights:", "\n")
print(x[["treated"]], ...)

invisible(x)
}

#' @rdname balweights
#' @method summary balweights.contrast
#' @export
summary.balweights.contrast <- function(object, ...) {
cat("Control arm weights:", "\n")
print(summary(object[["control"]], ...))
cat("\n")
cat("Treated arm weights:", "\n")
print(summary(object[["treated"]], ...))
}

get_path <- function(fit, W.hat, W, ..., lambda, devs) {
target <- fit[["target"]]
sample.weights <- fit[["sample.weights"]]
Expand Down
8 changes: 8 additions & 0 deletions r-package/balnet/man/balweights.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 8 additions & 0 deletions r-package/balnet/man/coef.balnet.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

18 changes: 12 additions & 6 deletions r-package/balnet/tests/testthat/test_balnet.R
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,17 @@ test_that("basic balnet runs", {

fit <- balnet(X, W)
capture.output(print(fit))
capture.output(summary(fit))
plot(fit)
plot(fit, lambda = 0)
coef(fit)
cf <- coef(fit)
capture.output(print(cf))
capture.output(summary(cf))
coef(fit, lambda = list(0, 1))
predict(fit, X)
wts <- balweights(fit)
capture.output(print(wts))
capture.output(summary(wts))

fit.gr <- balnet(X, W, groups = list(age = 10:15, 3:7))

Expand Down Expand Up @@ -176,10 +182,10 @@ test_that("balnet has not changed", {

expect_equal(
coef(fit, lambda = fit[["_lambda"]]),
list(control = new("dgCMatrix", i = c(0L, 0L, 1L, 2L, 3L, 0L,
1L, 2L, 3L, 0L, 1L, 2L, 3L, 0L, 1L, 2L, 3L), p = c(0L, 1L, 5L,
9L, 13L, 17L), Dim = 4:5, Dimnames = list(c("(Intercept)", "X1",
"X2", "X3"), NULL), x = c(0.652873281422005, 0.66802481765921,
structure(list(control = new("dgCMatrix", i = c(0L, 0L, 1L, 2L,
3L, 0L, 1L, 2L, 3L, 0L, 1L, 2L, 3L, 0L, 1L, 2L, 3L), p = c(0L,
1L, 5L, 9L, 13L, 17L), Dim = 4:5, Dimnames = list(c("(Intercept)",
"X1", "X2", "X3"), NULL), x = c(0.652873281422005, 0.66802481765921,
0.0507573767431381, 0.158769917286064, 0.103686341052803, 0.680645954386779,
0.0964925117386255, 0.210324237163201, 0.145107291953369, 0.685941154728131,
0.111004789201444, 0.227049604256553, 0.157882123136175, 0.687747421060082,
Expand All @@ -193,7 +199,7 @@ test_that("balnet has not changed", {
-0.218193176446944, -0.681102012267433, -0.0496225138747977,
-0.149449789859248, -0.237897822417655, -0.680577640484851,
-0.0533492030403526, -0.151711346250273, -0.244157630108172
), factors = list()))
), factors = list())), class = "coef.balnet.contrast")
)

expect_equal(
Expand Down
18 changes: 12 additions & 6 deletions r-package/balnet/tests/testthat/test_cv.balnet.R
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,15 @@ test_that("basic cv.balnet runs", {

fit <- cv.balnet(X, W)
capture.output(print(fit))
capture.output(summary(fit))
plot(fit)
coef(fit)
cf <- coef(fit)
capture.output(print(cf))
capture.output(summary(cf))
predict(fit, X)
wts <- balweights(fit)
capture.output(print(wts))
capture.output(summary(wts))

expect_true(TRUE)
})
Expand Down Expand Up @@ -80,14 +86,14 @@ test_that("cv.balnet has not changed", {

expect_equal(
coef(fit),
list(control = new("dgCMatrix", i = 0:3, p = c(0L, 4L), Dim = c(4L,
1L), Dimnames = list(c("(Intercept)", "X1", "X2", "X3"), NULL),
x = c(0.680645954386779, 0.0964925117386255, 0.210324237163201,
0.145107291953369), factors = list()), treated = new("dgCMatrix",
structure(list(control = new("dgCMatrix", i = 0:3, p = c(0L,
4L), Dim = c(4L, 1L), Dimnames = list(c("(Intercept)", "X1",
"X2", "X3"), NULL), x = c(0.680645954386779, 0.0964925117386255,
0.210324237163201, 0.145107291953369), factors = list()), treated = new("dgCMatrix",
i = 0:3, p = c(0L, 4L), Dim = c(4L, 1L), Dimnames = list(
c("(Intercept)", "X1", "X2", "X3"), NULL), x = c(-0.681102012267433,
-0.0496225138747977, -0.149449789859248, -0.237897822417655
), factors = list()))
), factors = list())), class = "coef.balnet.contrast")
)

expect_equal(
Expand Down
Loading