
These functions are wrappers around several heterogeneous treatment effect learners (and their associated predict methods) that can be used directly as teacher models in the causal distillation tree (CDT) framework.

  • causal_forest(): wrapper around grf::causal_forest()

  • predict_causal_forest(): wrapper around predict() for causal_forest() models.

  • bcf(): wrapper around bcf::bcf()

  • predict_bcf(): wrapper around predict() for bcf() models.

  • rlearner_teacher(): converts model functions from the rlearner package into teacher models for CDT.

  • rboost(): (defunct) wrapper around rlearner::rboost().

  • rlasso(): (defunct) wrapper around rlearner::rlasso().

  • rkern(): (defunct) wrapper around rlearner::rkern().
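The wrappers listed above share one pattern: a fitting function with the signature (X, Y, Z, W = NULL, ...) paired with a predict function that returns estimated conditional average treatment effects (CATEs). As a rough sketch of that pattern only, the hypothetical toy teacher below fits one base-R linear regression per treatment arm (a T-learner) and predicts the CATE as the difference of the two fits. The names tlearner_lm() and predict_tlearner_lm(), and the (object, X) predict signature, are illustrative assumptions and are not part of causalDT.

```r
# Hypothetical toy teacher (NOT part of causalDT): a T-learner built from two
# lm() fits, mirroring the (X, Y, Z, W = NULL, ...) signature of the wrappers.
tlearner_lm <- function(X, Y, Z, W = NULL, ...) {
  X <- as.data.frame(X)
  if (is.null(W)) W <- rep(1, nrow(X))  # unit weights when none are supplied
  treated <- Z == 1
  w1 <- W[treated]
  w0 <- W[!treated]
  # Separate outcome models for treated and control units
  fit1 <- lm(Y ~ ., data = cbind(X[treated, , drop = FALSE], Y = Y[treated]),
             weights = w1)
  fit0 <- lm(Y ~ ., data = cbind(X[!treated, , drop = FALSE], Y = Y[!treated]),
             weights = w0)
  structure(list(fit1 = fit1, fit0 = fit0), class = "tlearner_lm")
}

# Hypothetical predict function: CATE = predicted treated outcome minus
# predicted control outcome, one value per row of X.
predict_tlearner_lm <- function(object, X, ...) {
  X <- as.data.frame(X)
  unname(predict(object$fit1, newdata = X) - predict(object$fit0, newdata = X))
}

# Usage on simulated data with a known heterogeneous effect tau = 1 + x1
set.seed(1)
n <- 500
X <- data.frame(x1 = rnorm(n), x2 = rnorm(n))
Z <- rbinom(n, 1, 0.5)
tau <- 1 + X$x1
Y <- X$x2 + Z * tau + rnorm(n, sd = 0.1)
fit <- tlearner_lm(X, Y, Z)
cate <- predict_tlearner_lm(fit, X)
```

A real teacher would replace the two lm() fits with a flexible learner such as a causal forest; the linear toy is used only so the sketch runs with base R.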

Warning: The rboost(), rlasso(), and rkern() functions are defunct as of version 1.0.0. Use rlearner_teacher() instead (e.g., rlearner_teacher(rlearner::rboost)) to convert rlearner functions into the correct format for use as a teacher model in CDT.
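One reason a converter is needed is an argument mismatch: as far as we know, rlearner's learners take (x, w, y), where w is the treatment vector, whereas CDT teachers take (X, Y, Z, W), where Z is the treatment and W is a weight vector. The sketch below shows a minimal adapter under that assumption. It is not causalDT's implementation of rlearner_teacher(); as_cdt_teacher() and toy_rlearner() are hypothetical names, and the toy learner stands in for an rlearner function so the example runs without rlearner installed.

```r
# Hypothetical adapter (NOT causalDT's rlearner_teacher()): wraps a learner with
# an rlearner-style (x, w, y) signature, where w is the TREATMENT vector, into a
# CDT-style (X, Y, Z, W = NULL, ...) teacher, where Z is the treatment.
as_cdt_teacher <- function(learner_fun, ...) {
  extra_args <- list(...)  # arguments fixed when the teacher is created
  function(X, Y, Z, W = NULL, ...) {
    # Assumption: this sketch ignores the weights W, since the (x, w, y)
    # signature has no slot for them.
    do.call(learner_fun,
            c(list(x = as.matrix(X), w = Z, y = Y), extra_args, list(...)))
  }
}

# Toy stand-in with an rlearner-style signature: difference in means.
toy_rlearner <- function(x, w, y, ...) {
  list(tau_hat = mean(y[w == 1]) - mean(y[w == 0]))
}

teacher <- as_cdt_teacher(toy_rlearner)
fit <- teacher(X = matrix(0, 4, 1), Y = c(1, 3, 5, 9), Z = c(0, 0, 1, 1))
fit$tau_hat
# 5  (treated mean 7 minus control mean 2)
```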

Usage

causal_forest(X, Y, Z, W = NULL, ...)

predict_causal_forest(...)

rlearner_teacher(rlearner_fun, ...)

rboost(X, Y, Z, W = NULL, ...)

rlasso(X, Y, Z, W = NULL, ...)

rkern(X, Y, Z, W = NULL, ...)

bcf(
  X,
  Y,
  Z,
  W = NULL,
  pihat = "default",
  w = NULL,
  nburn = 2000,
  nsim = 1000,
  n_threads = 1,
  ...
)

predict_bcf(...)

Arguments

X

A tibble, data.frame, or matrix of covariates.

Y

A vector of outcomes.

Z

A vector of treatments.

W

A vector of weights corresponding to treatment propensities.

...

Additional arguments to pass to the base model functions.

rlearner_fun

One of rlearner::rboost, rlearner::rlasso, or rlearner::rkern to be transformed to teacher model format for CDT.

pihat

Length-n vector of estimated propensity scores.
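For users who want to supply their own pihat, one common way to obtain length-n propensity estimates P(Z = 1 | X) is a logistic regression of the treatment on the covariates. The sketch below assumes that choice for illustration only; it is not necessarily what pihat = "default" does.

```r
# Sketch: estimate propensity scores P(Z = 1 | X) with logistic regression.
# Assumed modelling choice for illustration; not necessarily the default used
# by bcf() when pihat = "default".
set.seed(2)
n <- 200
X <- data.frame(x1 = rnorm(n), x2 = rnorm(n))
Z <- rbinom(n, 1, plogis(0.5 * X$x1))  # treatment depends on x1
ps_fit <- glm(Z ~ ., data = cbind(X, Z = Z), family = binomial())
pihat <- unname(fitted(ps_fit))        # length-n vector of probabilities
```

The resulting vector could then be passed as pihat = pihat.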

w

An optional vector of weights. When present, BCF fits the model \(y \mid x \sim N(f(x), \sigma^2 / w)\), where \(f(x)\) is the unknown function.
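In this model an observation with weight \(w_i\) has noise variance \(\sigma^2 / w_i\), so larger weights mark more precise observations. BCF itself is a Bayesian tree model, but the same Gaussian likelihood underlies weighted least squares, which base R's lm() fits through its weights argument. The sketch below assumes a linear \(f(x)\) purely for illustration and checks that lm()'s weighted fit matches the closed-form maximizer \((X^\top W X)^{-1} X^\top W y\) of that likelihood.

```r
set.seed(3)
n <- 100
x <- runif(n)
w <- runif(n, 0.5, 2)  # per-observation weights
sigma <- 0.3
# Simulate y | x ~ N(f(x), sigma^2 / w) with an assumed linear f(x) = 1 + 2x;
# the standard deviation is therefore sigma / sqrt(w).
y <- 1 + 2 * x + rnorm(n, sd = sigma / sqrt(w))

# Weighted least squares via lm()
beta_lm <- unname(coef(lm(y ~ x, weights = w)))

# Closed form (X' W X)^{-1} X' W y; `w * Xmat` scales each row of Xmat by w
Xmat <- cbind(1, x)
beta_closed <- unname(drop(solve(t(Xmat) %*% (w * Xmat), t(Xmat) %*% (w * y))))
```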

nburn

Number of burn-in MCMC iterations.

nsim

Number of MCMC iterations to save after burn-in. The chain runs for nsim * nthin iterations after burn-in, where nthin is the thinning interval of bcf::bcf() (which can be supplied via ...).
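To see what these settings imply for run time, the arithmetic can be written out: the chain runs nburn burn-in iterations plus nsim * nthin post-burn-in iterations, keeping nsim draws. The helper mcmc_iterations() below is hypothetical and only performs this calculation; nthin is assumed to be the thinning interval of bcf::bcf().

```r
# Hypothetical helper: total and saved MCMC iterations for given settings.
mcmc_iterations <- function(nburn = 2000, nsim = 1000, nthin = 1) {
  c(total = nburn + nsim * nthin, saved = nsim)
}

mcmc_iterations()           # wrapper defaults: total 3000, saved 1000
mcmc_iterations(nthin = 5)  # total 7000, saved 1000
```

Thinning therefore multiplies post-burn-in cost without increasing the number of saved draws.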

n_threads

An optional integer giving the number of threads used to parallelize within-chain bcf operations.