Teacher models for causal distillation trees
These functions are wrappers around various heterogeneous treatment effect learners (and their associated predict methods) that can be used directly as teacher models in the causal distillation tree (CDT) framework.
- causal_forest(): wrapper around grf::causal_forest().
- predict_causal_forest(): wrapper around predict() for causal_forest() models.
- bcf(): wrapper around bcf::bcf().
- predict_bcf(): wrapper around predict() for bcf() models.
- rlearner_teacher(): wrapper around model functions from the rlearner package, converting them to teacher models for CDT.
- rboost(): (defunct) wrapper around rlearner::rboost().
- rlasso(): (defunct) wrapper around rlearner::rlasso().
- rkern(): (defunct) wrapper around rlearner::rkern().
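The wrappers above share one shape: a fitting function that takes covariates X, outcomes Y, treatments Z, and optional weights W, plus a companion predict function that forwards its arguments to predict(). The sketch below imitates that shape with a hypothetical toy teacher. toy_teacher(), predict.toy_teacher(), and predict_toy_teacher() are illustrative names, not part of the package. The toy teacher is a simple weighted T-learner (one linear regression per treatment arm). It assumes a binary 0/1 treatment, numeric covariates, and a predict function that returns a numeric vector of effect estimates; the exact way CDT calls a teacher's predict function is not specified on this page.

```r
# Hypothetical toy teacher model (NOT part of the package): a weighted
# T-learner that fits one linear regression per treatment arm and
# estimates the CATE as the difference between the two fitted regressions.
# Assumes Z is a binary 0/1 treatment vector and X is numeric.
toy_teacher <- function(X, Y, Z, W = NULL, ...) {
  Xm <- cbind(1, as.matrix(X))              # design matrix with intercept
  if (is.null(W)) W <- rep(1, nrow(Xm))     # unit weights by default
  # Weighted least-squares coefficients for the units selected by idx.
  arm_coef <- function(idx) {
    lm.wfit(Xm[idx, , drop = FALSE], Y[idx], W[idx])$coefficients
  }
  fit <- list(beta1 = arm_coef(Z == 1), beta0 = arm_coef(Z == 0))
  class(fit) <- "toy_teacher"
  fit                                       # `...` is accepted but unused
}

# S3 predict method: CATE estimates for new covariates.
predict.toy_teacher <- function(object, newdata, ...) {
  Xm <- cbind(1, as.matrix(newdata))
  drop(Xm %*% (object$beta1 - object$beta0))
}

# Companion wrapper in the style of predict_causal_forest(...):
# it simply forwards all of its arguments to predict().
predict_toy_teacher <- function(...) {
  predict(...)
}

# Usage on simulated data with true CATE 1 + 2 * X[, 2].
set.seed(1)
n <- 200
X <- matrix(rnorm(n * 2), ncol = 2)
Z <- rbinom(n, 1, 0.5)
Y <- X[, 1] + Z * (1 + 2 * X[, 2]) + rnorm(n, sd = 0.1)
fit <- toy_teacher(X, Y, Z)
tau_hat <- predict_toy_teacher(fit, X)
```

Keeping the predict function a thin pass-through to predict() means the model class, not the wrapper, decides how predictions are computed.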
Warning: The rboost(), rlasso(), and rkern() functions are defunct as of version 1.0.0. Use rlearner_teacher() instead (e.g., rlearner_teacher(rlearner::rboost)) to convert rlearner functions into the correct format for use as a teacher model in CDT.
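The migration path above is an adapter: rlearner_teacher() takes an rlearner function and returns something usable as a CDT teacher. The sketch below shows one way such an adapter could be written. It is not the package's implementation. It assumes, as in the rlearner package, that the wrapped learner takes arguments (x, w, y, ...) with w as the treatment and returns an object with a predict() method. make_teacher(), toy_rlearner(), and predict.toy_rlearner() are hypothetical names, and the sketch ignores W.

```r
# Hypothetical adapter (NOT the package's rlearner_teacher()): converts a
# learner with an rlearner-style signature f(x, w, y, ...), where w is the
# treatment, into a function with the teacher signature (X, Y, Z, W, ...).
make_teacher <- function(learner_fun, ...) {
  force(learner_fun)
  extra_args <- list(...)                   # arguments fixed at adapter time
  function(X, Y, Z, W = NULL, ...) {
    # Map teacher arguments to learner arguments; W is ignored here.
    args <- c(list(x = as.matrix(X), w = Z, y = Y), extra_args, list(...))
    do.call(learner_fun, args)
  }
}

# Hypothetical stand-in for an rlearner function: a difference-in-means
# estimator that returns a constant treatment effect.
toy_rlearner <- function(x, w, y, ...) {
  fit <- list(tau = mean(y[w == 1]) - mean(y[w == 0]))
  class(fit) <- "toy_rlearner"
  fit
}
predict.toy_rlearner <- function(object, newx, ...) {
  rep(object$tau, nrow(newx))
}

# Usage: true constant treatment effect of 2.
teacher <- make_teacher(toy_rlearner)
set.seed(2)
n <- 100
X <- matrix(rnorm(n * 3), ncol = 3)
Z <- rep(c(0, 1), length.out = n)
Y <- 2 * Z + rnorm(n, sd = 0.1)
fit <- teacher(X, Y, Z)
tau_hat <- predict(fit, newx = X)
```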
Usage
causal_forest(X, Y, Z, W = NULL, ...)
predict_causal_forest(...)
rlearner_teacher(rlearner_fun, ...)
rboost(X, Y, Z, W = NULL, ...)
rlasso(X, Y, Z, W = NULL, ...)
rkern(X, Y, Z, W = NULL, ...)
bcf(
  X,
  Y,
  Z,
  W = NULL,
  pihat = "default",
  w = NULL,
  nburn = 2000,
  nsim = 1000,
  n_threads = 1,
  ...
)
predict_bcf(...)
Arguments
- X
A tibble, data.frame, or matrix of covariates.
- Y
A vector of outcomes.
- Z
A vector of treatments.
- W
A vector of weights corresponding to treatment propensities.
- ...
Additional arguments to pass to the base model functions.
- rlearner_fun
One of rlearner::rboost, rlearner::rlasso, or rlearner::rkern, to be transformed into the teacher model format for CDT.
- pihat
A length-n vector of propensity score estimates.
- w
An optional vector of weights. When present, BCF fits a model \(y \mid x \sim N(f(x), \sigma^2 / w)\), where \(f(x)\) is the unknown function.
- nburn
Number of burn-in MCMC iterations.
- nsim
Number of MCMC iterations to save after burn-in. The chain will run for nsim * nthin iterations after burn-in, where nthin is the thinning interval.
- n_threads
An optional integer giving the number of threads used to parallelize within-chain bcf operations.
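To see what the weights w do under the model \(y \mid x \sim N(f(x), \sigma^2 / w)\), write out the Gaussian log-likelihood with per-observation variance \(\sigma^2 / w_i\). Up to terms that do not involve \(f\) (collected in \(C\)), each squared residual is scaled by \(w_i\), so observations with larger weights pull the fit toward them more strongly. This is standard Gaussian algebra, not a description of BCF's internal sampler.

```latex
\log p(y_i \mid x_i) = -\tfrac{1}{2}\log\!\left(\frac{2\pi\sigma^2}{w_i}\right) - \frac{w_i\,\bigl(y_i - f(x_i)\bigr)^2}{2\sigma^2},
\qquad
\sum_{i=1}^{n} \log p(y_i \mid x_i) = -\frac{1}{2\sigma^2}\sum_{i=1}^{n} w_i\,\bigl(y_i - f(x_i)\bigr)^2 + C.
```

Separately, the MCMC arguments fix the chain length: nburn + nsim × nthin iterations in total, of which nsim are saved. With the defaults nburn = 2000 and nsim = 1000, and taking nthin = 1 for illustration, that is 2000 + 1000 × 1 = 3000 iterations.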