R/gen_CART.R

Defines functions compare_cart gen_cart

Documented in compare_cart gen_cart

#' Generate synthetic data using CART.
#'
#' \code{gen_cart} uses Classification and Regression Trees (CART)
#' to generate synthetic data by sequentially predicting the value of
#' each variable depending on the value of other variables. Details can
#' be found in \code{\link[synthpop:syn]{syn}}.
#'
#' @param training_set A data frame of the training data. The generated data will
#'     have the same number of rows as the \code{training_set}.
#' @param structure A string of the relationships between variables from
#'    \code{\link[bnlearn:model string utilities]{modelstring}}. If structure is
#'    not a character value (the default is NA), the variables are synthesised
#'    using the sequence of the variables in the \code{training_set} data frame.
#'    Every variable that appears in an arc of the structure must be a column of
#'    \code{training_set}; otherwise an error is raised.
#' @return The output is a list of three objects: i) structure: the predictor
#' matrix describing the dependency/relationship between the variables (a square
#' 0/1 matrix; when \code{structure} is supplied its rows and columns are named
#' after the columns of \code{training_set}); ii) fit_model: the fitted CART
#' model (a \code{\link[synthpop:syn]{syn}} object) and iii) gen_data: the
#' generated synthetic data.
#' @examples
#' adult_data <- split_data(adult[1:100,], 70)
#' cart <- gen_cart(adult_data$training_set)
#' bn_structure <- "[native_country][income][age|marital_status:education]"
#' bn_structure <- paste0(bn_structure, "[sex][race|native_country][marital_status|race:sex]")
#' bn_structure <- paste0(bn_structure, "[relationship|marital_status][education|sex:race]")
#' bn_structure <- paste0(bn_structure, "[occupation|education][workclass|occupation]")
#' bn_structure <- paste0(bn_structure, "[hours_per_week|occupation:workclass]")
#' bn_structure <- paste0(bn_structure, "[capital_gain|occupation:workclass:income]")
#' bn_structure <- paste0(bn_structure, "[capital_loss|occupation:workclass:income]")
#' cart_elicit <- gen_cart(adult_data$training_set, bn_structure)
#'
#' @export
gen_cart <- function(training_set, structure = NA)
{
  n_rows <- nrow(training_set)

  if (!is.character(structure))
  {
    message("generating data using sequence of variables")
    gen_synth_synthpop <- synthpop::syn(data = training_set, m = 1,
                                        k = n_rows, drop.not.used = FALSE)
    # Report the predictor matrix synthpop chose by default.
    m <- gen_synth_synthpop$predictor.matrix
  } else
  {
    message("generating data using defined relationships")
    var_names <- colnames(training_set)
    n_vars <- length(training_set)
    m <- matrix(0, nrow = n_vars, ncol = n_vars)
    rownames(m) <- var_names
    colnames(m) <- var_names

    # Two-column character matrix of arcs: column 1 = "from", column 2 = "to".
    arcs <- bnlearn::model2network(structure)$arcs

    # Fail early with a readable message instead of "subscript out of bounds".
    unknown <- setdiff(unique(as.vector(arcs)), var_names)
    if (length(unknown) > 0)
    {
      stop("variables in 'structure' not found in 'training_set': ",
           paste(unknown, collapse = ", "), call. = FALSE)
    }

    # A structure without arcs (e.g. "[a][b]") leaves the matrix all zero;
    # the guard avoids indexing with an empty arc set.
    # NOTE(review): each arc sets m[from, to]. synthpop documents
    # predictor.matrix rows as targets and columns as predictors, which
    # suggests m[to, from] -- confirm the intended orientation before changing.
    if (nrow(arcs) > 0)
    {
      m[arcs] <- 1
    }

    gen_synth_synthpop <- synthpop::syn(data = training_set, m = 1,
                                        k = n_rows, predictor.matrix = m)
  }

  list(structure = m, fit_model = gen_synth_synthpop,
       gen_data = gen_synth_synthpop$syn)
}


#' Compare the synthetic data generated by CART with the real data.
#'
#' \code{compare_cart} compares the synthetic data held in a fitted CART model
#' with the real data, laying out one panel per requested variable in a
#' single row.
#'
#' @param training_set A data frame of the real (training) data that the
#'     synthetic data is compared against.
#' @param fit_model A \code{\link[synthpop:syn]{syn}} object, e.g. the
#'     \code{fit_model} element returned by \code{gen_cart}.
#' @param var_list A string vector of the names of variables that we want to compare.
#' @return A plot of the comparison of the distribution of
#'     synthetic data vs real data.
#' @examples
#' adult_data <- split_data(adult[1:100,], 70)
#' cart <- gen_cart(adult_data$training_set)
#' compare_cart(adult_data$training_set, cart$fit_model, c("age", "workclass", "sex"))
#'
#' @export
compare_cart <- function(training_set, fit_model, var_list)
{
  # One column per variable so all panels sit side by side on one row.
  n_panels <- length(var_list)
  comparison <- synthpop::compare(fit_model, training_set, vars = var_list,
                                  nrow = 1, ncol = n_panels)
  comparison$plot
}

Try the sdglinkage package in your browser

Any scripts or data that you put into this service are public.

sdglinkage documentation built on April 27, 2020, 5:09 p.m.