Subgroup stability diagnostics
evaluate_subgroup_stability.Rd
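Evaluates the stability of the estimated subgroups. The subgroup estimator is refit on B bootstrap samples at each tree depth up to max_depth, and agreement between the resulting subgroups is summarized with the Jaccard similarity index.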
Usage
evaluate_subgroup_stability(
estimator,
fit,
X,
y,
Z = NULL,
rpart_control = NULL,
B = 100,
max_depth = NULL
)
Arguments
- estimator
Function used to estimate subgroups of individuals and their corresponding estimated treatment effects. The function should take in X, y, and optionally Z (if Z is not NULL). It should return a model fit (e.g., the output of rpart) that can be coerced into a party object via partykit::as.party(). Typically, student_rpart will be used as the estimator.
- fit
Fitted subgroup model (often, the output of estimator()). Mainly used to determine an appropriate max_depth for the stability diagnostics. If fit is not an rpart object, stability diagnostics will be skipped.
- X
A tibble, data.frame, or matrix of covariates.
- y
A vector of responses to predict.
- Z
A vector of treatments.
- rpart_control
A list of control parameters for the rpart algorithm. See ?rpart.control for details.
- B
Number of bootstrap samples to use in evaluating stability diagnostics. Default is 100.
- max_depth
Maximum depth of the tree to consider when evaluating stability diagnostics. If NULL, the default is max(4, maximum depth of fit).
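The estimator contract above can be met by a custom function. The following is a minimal sketch, assuming the contract as stated on this page. The name my_estimator is hypothetical. The assumption that evaluate_subgroup_stability() may forward an rpart_control argument (and possibly other arguments, absorbed here by ...) is not confirmed by this page. If the function instead expects the list structure returned by student_rpart() (compare out$student_fit$fit in the Examples), the return value would need to be wrapped to match.

```r
# Hypothetical custom estimator following the documented contract:
# accepts X, y, and an optional Z, and returns an rpart fit.
# ASSUMPTION: evaluate_subgroup_stability() may pass rpart_control (forwarded
# to rpart() below) and possibly other arguments (absorbed by `...`).
# Z is accepted but ignored in this sketch.
my_estimator <- function(X, y, Z = NULL, rpart_control = NULL, ...) {
  # Combine the response and covariates into one data frame for the formula
  df <- data.frame(.y = y, as.data.frame(X))
  if (is.null(rpart_control)) {
    rpart_control <- rpart::rpart.control()
  }
  # Regression tree of y on all covariates
  rpart::rpart(.y ~ ., data = df, method = "anova", control = rpart_control)
}
```

Such a function could then be supplied as estimator = my_estimator in place of student_rpart. Calling partykit::as.party() on its output is a quick check that the fit is coercible.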
Value
A list with the following elements:
- jaccard_mean
Vector of mean Jaccard similarity indices, one per tree depth; the tree depth is given by the vector index.
- jaccard_distribution
List of Jaccard similarity indices across all bootstraps for each tree depth.
- bootstrap_predictions
List of mean student model predictions for the training (non-holdout) data across all bootstrap samples, one element per tree depth.
- bootstrap_predictions_var
List of variances of student model predictions for the training (non-holdout) data across all bootstrap samples, one element per tree depth.
- leaf_ids
List of leaf node identifiers, indicating the leaf membership of each training sample in the (original) fitted student model.
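For reference, the Jaccard similarity index of two sets A and B is |A ∩ B| / |A ∪ B|. It ranges from 0 (disjoint sets) to 1 (identical sets). The R sketch below computes it for two sets of sample indices, for example the members of one leaf in the original tree (as given by leaf_ids) and the members of one leaf in a bootstrap tree. This page does not describe how evaluate_subgroup_stability() matches and aggregates subgroups across trees. The helper jaccard_index() and the leaf vectors are therefore hypothetical and illustrate only the index itself.

```r
# Hypothetical helper: Jaccard similarity between two sets of sample indices,
# |A intersect B| / |A union B|. Two empty sets are treated as identical
# (returns 1), a convention assumed here.
jaccard_index <- function(a, b) {
  u <- length(union(a, b))
  if (u == 0) return(1)
  length(intersect(a, b)) / u
}

# Illustration with made-up leaf assignments for 8 samples
leaf_orig <- c(1, 1, 1, 2, 2, 3, 3, 3)
leaf_boot <- c(1, 1, 2, 2, 2, 3, 3, 1)
a <- which(leaf_orig == 1)  # samples 1, 2, 3
b <- which(leaf_boot == 1)  # samples 1, 2, 8
jaccard_index(a, b)         # 2 shared / 4 total = 0.5
```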
Examples
library(causalDT)
set.seed(1)  # for reproducibility
n <- 200
p <- 10
X <- matrix(rnorm(n * p), nrow = n, ncol = p)
Z <- rbinom(n, 1, 0.5)
Y <- 2 * Z * (X[, 1] > 0) + X[, 2] + rnorm(n, sd = 0.1)
# run causal distillation trees without stability diagnostics
out <- causalDT(X, Y, Z, B_stability = 0)
# run stability diagnostics manually
stability_out <- evaluate_subgroup_stability(
estimator = student_rpart,
fit = out$student_fit$fit,
X = X[-out$holdout_idxs, , drop = FALSE],
y = out$student_fit$predictions
)
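The returned list can be tabulated by tree depth to compare stability across depths. The helper summarize_stability() below is hypothetical. It relies on the documented return value, where jaccard_mean and jaccard_distribution each have one entry per tree depth. It also assumes that each element of jaccard_distribution is a numeric vector. The 10th-percentile column is an arbitrary illustrative choice.

```r
# Hypothetical helper: one row per tree depth with the mean Jaccard index
# and a lower quantile of its bootstrap distribution.
# ASSUMPTION: each element of jaccard_distribution is a numeric vector.
summarize_stability <- function(stability_out) {
  jm <- stability_out$jaccard_mean
  jd <- stability_out$jaccard_distribution
  data.frame(
    depth = seq_along(jm),  # tree depth is the vector index
    jaccard_mean = jm,
    # 10th percentile of the bootstrap Jaccard indices at each depth
    jaccard_q10 = vapply(
      jd,
      function(x) unname(stats::quantile(x, 0.1, na.rm = TRUE)),
      numeric(1)
    )
  )
}
```

With the objects from the example above, summarize_stability(stability_out) would return one row per depth.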