diff --git a/r-package/balnet/NAMESPACE b/r-package/balnet/NAMESPACE index f9b5817..8fe32d0 100644 --- a/r-package/balnet/NAMESPACE +++ b/r-package/balnet/NAMESPACE @@ -11,8 +11,12 @@ S3method(predict,balnet) S3method(predict,balnet.fit) S3method(predict,cv.balnet) S3method(print,balnet) +S3method(print,balweights.contrast) +S3method(print,coef.balnet.contrast) S3method(print,cv.balnet) S3method(summary,balnet) +S3method(summary,balweights.contrast) +S3method(summary,coef.balnet.contrast) S3method(summary,cv.balnet) export(balnet) export(balweights) diff --git a/r-package/balnet/R/balnet.R b/r-package/balnet/R/balnet.R index a97e8d0..d20b6a8 100644 --- a/r-package/balnet/R/balnet.R +++ b/r-package/balnet/R/balnet.R @@ -282,12 +282,37 @@ coef.balnet <- function( out.nn <- out[!vapply(out, is.null, logical(1))] if (length(out.nn) > 1) { - return(out.nn) + return(structure(out.nn, class = "coef.balnet.contrast")) } else { return(out.nn[[1]]) } } +#' @rdname coef.balnet +#' @param x A `coef.balnet.contrast` object. +#' @method print coef.balnet.contrast +#' @export +print.coef.balnet.contrast <- function(x, ...) { + cat("Control arm coefficients:", "\n") + print(x[["control"]], ...) + cat("\n") + cat("Treated arm coefficients:", "\n") + print(x[["treated"]], ...) + + invisible(x) +} + +#' @rdname coef.balnet +#' @method summary coef.balnet.contrast +#' @export +summary.coef.balnet.contrast <- function(object, ...) { + cat("Control arm coefficients:", "\n") + print(summary(object[["control"]], ...)) + cat("\n") + cat("Treated arm coefficients:", "\n") + print(summary(object[["treated"]], ...)) +} + #' Predict using a balnet object. #' #' @param object A `balnet` object. @@ -657,12 +682,37 @@ balweights.balnet <- function( out.nn <- out[!vapply(out, is.null, logical(1))] if (length(out.nn) > 1) { - return(out.nn) + return(structure(out.nn, class = "balweights.contrast")) } else { return(out.nn[[1]]) } } +#' @rdname balweights +#' @param x A `balweights.contrast` object. +#' @method print balweights.contrast +#' @export +print.balweights.contrast <- function(x, ...) { + cat("Control arm weights:", "\n") + print(x[["control"]], ...) + cat("\n") + cat("Treated arm weights:", "\n") + print(x[["treated"]], ...) + + invisible(x) +} + +#' @rdname balweights +#' @method summary balweights.contrast +#' @export +summary.balweights.contrast <- function(object, ...) { + cat("Control arm weights:", "\n") + print(summary(object[["control"]], ...)) + cat("\n") + cat("Treated arm weights:", "\n") + print(summary(object[["treated"]], ...)) +} + get_path <- function(fit, W.hat, W, ..., lambda, devs) { target <- fit[["target"]] sample.weights <- fit[["sample.weights"]] diff --git a/r-package/balnet/man/balweights.Rd b/r-package/balnet/man/balweights.Rd index c611577..0c874fe 100644 --- a/r-package/balnet/man/balweights.Rd +++ b/r-package/balnet/man/balweights.Rd @@ -3,6 +3,8 @@ \name{balweights} \alias{balweights} \alias{balweights.balnet} +\alias{print.balweights.contrast} +\alias{summary.balweights.contrast} \alias{balweights.cv.balnet} \title{Extract balancing weights from a balnet object.} \usage{ @@ -10,6 +12,10 @@ balweights(object, lambda = NULL, ...) \method{balweights}{balnet}(object, lambda = NULL, ...) +\method{print}{balweights.contrast}(x, ...) + +\method{summary}{balweights.contrast}(object, ...) + \method{balweights}{cv.balnet}(object, lambda = "lambda.min", ...) } \arguments{ @@ -26,6 +32,8 @@ arm and the second to the treatment. }} \item{...}{Additional arguments (currently ignored).} + +\item{x}{A \code{balweights.contrast} object.} } \value{ Estimated balancing weights diff --git a/r-package/balnet/man/coef.balnet.Rd b/r-package/balnet/man/coef.balnet.Rd index bc2e3ad..a8f9a72 100644 --- a/r-package/balnet/man/coef.balnet.Rd +++ b/r-package/balnet/man/coef.balnet.Rd @@ -2,9 +2,15 @@ % Please edit documentation in R/balnet.R \name{coef.balnet} \alias{coef.balnet} +\alias{print.coef.balnet.contrast} +\alias{summary.coef.balnet.contrast} \title{Extract coefficients from a balnet object.} \usage{ \method{coef}{balnet}(object, lambda = NULL, ...) + +\method{print}{coef.balnet.contrast}(x, ...) + +\method{summary}{coef.balnet.contrast}(object, ...) } \arguments{ \item{object}{A \code{balnet} object.} @@ -20,6 +26,8 @@ arm and the second to the treatment. }} \item{...}{Additional arguments (currently ignored).} + +\item{x}{A \code{coef.balnet.contrast} object.} } \value{ Estimated logistic coefficients diff --git a/r-package/balnet/tests/testthat/test_balnet.R b/r-package/balnet/tests/testthat/test_balnet.R index 15326f5..72a834d 100644 --- a/r-package/balnet/tests/testthat/test_balnet.R +++ b/r-package/balnet/tests/testthat/test_balnet.R @@ -8,11 +8,17 @@ test_that("basic balnet runs", { fit <- balnet(X, W) capture.output(print(fit)) + capture.output(summary(fit)) plot(fit) plot(fit, lambda = 0) - coef(fit) + cf <- coef(fit) + capture.output(print(cf)) + capture.output(summary(cf)) coef(fit, lambda = list(0, 1)) predict(fit, X) + wts <- balweights(fit) + capture.output(print(wts)) + capture.output(summary(wts)) fit.gr <- balnet(X, W, groups = list(age = 10:15, 3:7)) @@ -176,10 +182,10 @@ test_that("balnet has not changed", { expect_equal( coef(fit, lambda = fit[["_lambda"]]), - list(control = new("dgCMatrix", i = c(0L, 0L, 1L, 2L, 3L, 0L, - 1L, 2L, 3L, 0L, 1L, 2L, 3L, 0L, 1L, 2L, 3L), p = c(0L, 1L, 5L, - 9L, 13L, 17L), Dim = 4:5, Dimnames = list(c("(Intercept)", "X1", - "X2", "X3"), NULL), x = c(0.652873281422005, 0.66802481765921, + structure(list(control = new("dgCMatrix", i = c(0L, 0L, 1L, 2L, + 3L, 0L, 1L, 2L, 3L, 0L, 1L, 2L, 3L, 0L, 1L, 2L, 3L), p = c(0L, + 1L, 5L, 9L, 13L, 17L), Dim = 4:5, Dimnames = list(c("(Intercept)", + "X1", "X2", "X3"), NULL), x = c(0.652873281422005, 0.66802481765921, 0.0507573767431381, 0.158769917286064, 0.103686341052803, 0.680645954386779, 0.0964925117386255, 0.210324237163201, 0.145107291953369, 0.685941154728131, 0.111004789201444, 0.227049604256553, 0.157882123136175, 0.687747421060082, @@ -193,7 +199,7 @@ test_that("balnet has not changed", { -0.218193176446944, -0.681102012267433, -0.0496225138747977, -0.149449789859248, -0.237897822417655, -0.680577640484851, -0.0533492030403526, -0.151711346250273, -0.244157630108172 - ), factors = list())) + ), factors = list())), class = "coef.balnet.contrast") ) expect_equal( diff --git a/r-package/balnet/tests/testthat/test_cv.balnet.R b/r-package/balnet/tests/testthat/test_cv.balnet.R index 1b6c78d..9940fc3 100644 --- a/r-package/balnet/tests/testthat/test_cv.balnet.R +++ b/r-package/balnet/tests/testthat/test_cv.balnet.R @@ -8,9 +8,15 @@ test_that("basic cv.balnet runs", { fit <- cv.balnet(X, W) capture.output(print(fit)) + capture.output(summary(fit)) plot(fit) - coef(fit) + cf <- coef(fit) + capture.output(print(cf)) + capture.output(summary(cf)) predict(fit, X) + wts <- balweights(fit) + capture.output(print(wts)) + capture.output(summary(wts)) expect_true(TRUE) }) @@ -80,14 +86,14 @@ test_that("cv.balnet has not changed", { expect_equal( coef(fit), - list(control = new("dgCMatrix", i = 0:3, p = c(0L, 4L), Dim = c(4L, - 1L), Dimnames = list(c("(Intercept)", "X1", "X2", "X3"), NULL), - x = c(0.680645954386779, 0.0964925117386255, 0.210324237163201, - 0.145107291953369), factors = list()), treated = new("dgCMatrix", + structure(list(control = new("dgCMatrix", i = 0:3, p = c(0L, + 4L), Dim = c(4L, 1L), Dimnames = list(c("(Intercept)", "X1", + "X2", "X3"), NULL), x = c(0.680645954386779, 0.0964925117386255, + 0.210324237163201, 0.145107291953369), factors = list()), treated = new("dgCMatrix", i = 0:3, p = c(0L, 4L), Dim = c(4L, 1L), Dimnames = list( c("(Intercept)", "X1", "X2", "X3"), NULL), x = c(-0.681102012267433, -0.0496225138747977, -0.149449789859248, -0.237897822417655 - ), factors = list())) + ), factors = list())), class = "coef.balnet.contrast") ) expect_equal(