### Header ---------------------------
###
### Title: OTfunctions.R
###
### Description: Functions used to construct cost matrix and compute optimal transport costs
###
### Author: Omkar A. Katta
###
### Notes:
###
###
### Cost Matrix ---------------------------
build_costmatrix <- function(support, bandwidth = 0){
# create cost matrix using the common support provided
costmatrix <- matrix(NA_real_, nrow = length(support), ncol = length(support))
for (i in seq_along(support)){
dist <- abs(support[i] - support)
dist <- ifelse(dist > bandwidth, 1, 0)
costmatrix[i, ] <- dist
}
costmatrix
}
build_costmatrix2 <- function(support_pre, support_post, bandwidth = 0){
# create cost matrix given the minimal support of both the pre-data and post-data
costmatrix <- matrix(NA_real_, nrow = length(support_pre), ncol = length(support_post))
for (i in seq_along(support_pre)){
dist <- abs(support_pre[i] - support_post)
dist <- ifelse(dist > bandwidth, 1, 0)
costmatrix[i, ] <- dist
}
costmatrix
}
### Compute Transport Cost ---------------------------
#' @importFrom transport transport
get_OTcost <- function(pre_df, post_df, support = NULL, bandwidth = 0, var = MSRP, costmat = NULL){
# given pre-data and post-data, compute optimal transport cost given bandwidth
pre <- pre_df$count
post <- post_df$count
pre_support <- pre_df %>% dplyr::select({{var}}) %>% unlist()
post_support <- post_df %>% dplyr::select({{var}}) %>% unlist()
names(pre_support) <- c()
names(post_support) <- c()
if (!identical(pre_support, post_support)){
stop("`pre_df` and `post_df` need to have the same support")
}
if (is.null(support)){
support <- pre_df %>% dplyr::select({{var}}) %>% unlist()
names(support) <- c()
}
if (!identical(support, pre_support)){
stop("`support` is different from `pre_support` and `post_support`")
}
if (!identical(support, post_support)){
stop("`support` is different from `pre_support` and `post_support`")
}
if (is.null(costmat)) {
costm <- build_costmatrix(support, bandwidth)
} else {
costm <- costmat
}
OT <- transport(
as.numeric(sum(post) / sum(pre) * pre),
as.numeric(post),
costm
)
temp <- as.data.frame(OT) %>%
dplyr::rowwise() %>%
dplyr::mutate(cost = costm[from, to]) %>%
dplyr::ungroup() %>%
dplyr::mutate(cost = mass * cost)
tot_cost <- sum(temp$cost)
prop_cost <- tot_cost / sum(post)
list("num_bribe" = tot_cost, "prop_bribe" = prop_cost, "bandwidth" = bandwidth)
}
### Compute Results ---------------------------
#' Obtain Transport Costs and Differences-in-Transports Estimator
#'
#' Given the pre and post probability mass functions as well as a vector of
#' bandwidths, this function returns the associated transport costs.
#' If another set of pre and post probability mass functions are given for the
#' control group, then the differences-in-transports estimator is returned.
#'
#' The \code{pre_main}, \code{post_main}, \code{pre_control}, and
#' \code{post_control} variables are all probability mass functions.
#' That is, they are a tibble with two columns:
#' \itemize{
#' \item column 1 contains the full support of \code{var}, and
#' \item column 2, which should be titled "count", contains the corresponding
#' mass of each value in the support.
#' }
#' Since column 1 contains the full support of \code{var} and all these distributions
#' are of \code{var}, column 1 must be the same for all distributions.
#'
#' @param pre_main probability mass function (see "Details") for \code{var} of the
#' treated group before treatment occurs
#' @param post_main probability mass function (see "Details") for \code{var} of the
#' treated group after treatment occurs
#' @param pre_control probability mass function (see "Details") for \code{var} of the
#' control group before treatment occurs; only required for the computing the
#' differences-in-transports estimator
#' @param post_control probability mass function (see "Details") for \code{var} of the
#' treated group after treatment occurs; only required for the computing the
#' differences-in-transports estimator
#' @param var the title of the first column of \code{pre_main}, \code{post_main},
#' \code{pre_control}, and \code{post_control}; default is \code{MSRP}
#' (see Daljord et al. (2021))
#' @param bandwidth_seq a vector of bandwidth values to try; default is \code{seq(0, 40000, 1000)}
#' @param estimator a string that takes on the value of "dit" for
#' differences-in-transports estimator or "tc" for the transport cost;
#' if \code{pre_control} and \code{post_control} are specified, default is "dit";
#' otherwise, default is "tc"
#' @param conservative if \code{TRUE}, then the bandwidth sequence will be
#' multiplied by 2 to provide a conservative estimate of the transport costs/
#' difference-in-transports estimator; default is \code{FALSE}
#' @param quietly if \code{TRUE}, some results and will be suppressed from printing; default is \code{FALSE}
#' @param suppress_progress_bar if \code{TRUE}, the progress bar will be suppressed; default is \code{FALSE}
#' @param save_dit if \code{TRUE}, the differences-in-transports estimator as
#' well as the associated bandwidth will be returned
#' @param costm_main if \code{NULL}, the cost matrix with common support will be such that if the transport
#' distance is greater than what is specified in \code{bandwidth_seq}, cost is 1 and 0 otherwise.
#' @param costm_control if \code{NULL}, the cost matrix with common support will be such that if the transport
#' distance is greater than what is specified in \code{bandwidth_seq}, cost is 1 and 0 otherwise.
#'
#' @return a data.frame with the transport costs associated with each value of \code{bandwidth_seq}.
#' \itemize{
#' \item \code{bandwidth}: same as \code{bandwidth_seq}
#' \item \code{main}: transport costs associated with main distributions
#' \item \code{main2d}: transport costs associated with main distributions using twice the bandwidth;
#' appears only if \code{conservative = TRUE}
#' \item \code{control}: transport costs associated with the control distributions;
#' appears only if \code{pre_control} and \code{post_control}
#' are specified
#' \item \code{diff}: \code{main - control}
#' \item \code{diff2d}: \code{main2d - control}
#' }
#'
#' If \code{save_dit = TRUE}, then a list is returned, with the first element
#' (labeled \code{out}) being the data.frame described above.
#' The second element (labeled \code{dit}) is the differences-in-transports
#' estimator, and the third and final element (labeled \code{optimal_bandwidth})
#' is the bandwidth associated with the estimator.
#'
#' @export
#' @importFrom rlang enquo
#' @examples
#' # Find conservative transport cost of MSRP in Beijing between 2010 and 2011 using bandwidth = 0
#' # # step 1: find support
#' support_Beijing <- Beijing_sample %>%
#' dplyr::filter(ym >= as.Date("2010-01-01") & ym < "2012-01-01") %>%
#' dplyr::select(MSRP) %>%
#' dplyr::distinct() %>%
#' dplyr::arrange(MSRP) %>%
#' dplyr::filter(!is.na(MSRP)) %>%
#' unlist()
#' temp <- data.frame(MSRP = support_Beijing)
#' # # step 2: prepare probability mass functions
#' pre_Beijing <- Beijing_sample %>%
#' dplyr::filter(ym >= as.Date("2010-01-01") & ym < "2011-01-01") %>%
#' dplyr::group_by(dplyr::across(c(MSRP))) %>%
#' dplyr::summarise(count = sum(sales)) %>%
#' dplyr::filter(!is.na(MSRP)) %>%
#' dplyr::left_join(temp, .) %>%
#' dplyr::select(MSRP, count) %>%
#' tidyr::replace_na(list(count = 0)) %>%
#' tibble::as_tibble()
#' post_Beijing <- Beijing_sample %>%
#' dplyr::filter(ym >= as.Date("2011-01-01") & ym < "2012-01-01") %>%
#' dplyr::group_by(dplyr::across(c(MSRP))) %>%
#' dplyr::summarise(count = sum(sales)) %>%
#' dplyr::filter(!is.na(MSRP)) %>%
#' dplyr::left_join(temp, .) %>%
#' dplyr::select(MSRP, count) %>%
#' tidyr::replace_na(list(count = 0)) %>%
#' tibble::as_tibble()
#' # # step 3: compute results
#' tc <- diftrans(pre_Beijing, post_Beijing, conservative = TRUE, bandwidth = 0)
#' tc$main2d
#'
#' # Find transport cost of MSRP in Beijing between 2010 and 2011 using bandwidth = 10000
#' # tc_10000 <- diftrans(pre_Beijing, post_Beijing, bandwidth = 10000)# tc_10000$main
#' # Find conservative differences-in-transport estimator using Tianjin as a control
#' # # step 1: find support
#' support_Tianjin <- Tianjin_sample %>%
#' dplyr::filter(ym >= as.Date("2010-01-01") & ym < "2012-01-01") %>%
#' dplyr::select(MSRP) %>%
#' dplyr::distinct() %>%
#' dplyr::arrange(MSRP) %>%
#' dplyr::filter(!is.na(MSRP)) %>%
#' unlist()
#' temp <- data.frame(MSRP = support_Tianjin)
#' # # step 2: prepare probability mass functions
#' pre_Tianjin <- Tianjin_sample %>%
#' dplyr::filter(ym >= as.Date("2010-01-01") & ym < "2011-01-01") %>%
#' dplyr::group_by(dplyr::across(c(MSRP))) %>%
#' dplyr::summarise(count = sum(sales)) %>%
#' dplyr::filter(!is.na(MSRP)) %>%
#' dplyr::left_join(temp, .) %>%
#' dplyr::select(MSRP, count) %>%
#' tidyr::replace_na(list(count = 0)) %>%
#' tibble::as_tibble()
#' post_Tianjin <- Tianjin_sample %>%
#' dplyr::filter(ym >= as.Date("2011-01-01") & ym < "2012-01-01") %>%
#' dplyr::group_by(dplyr::across(c(MSRP))) %>%
#' dplyr::summarise(count = sum(sales)) %>%
#' dplyr::filter(!is.na(MSRP)) %>%
#' dplyr::left_join(temp, .) %>%
#' dplyr::select(MSRP, count) %>%
#' tidyr::replace_na(list(count = 0)) %>%
#' tibble::as_tibble()
#' # # step 3: compute results
#' dit <- diftrans(pre_Beijing, post_Beijing, pre_Tianjin, post_Tianjin,
#' conservative = TRUE, bandwidth = seq(0, 40000, 1000),
#' save_dit = TRUE)
#' dit$optimal_bandwidth
#' dit$dit
diftrans <- function(pre_main = NULL, post_main = NULL,
pre_control = NULL, post_control = NULL,
var = MSRP,
bandwidth_seq = seq(0, 40000, 1000),
estimator = ifelse(!is.null(pre_control) & !is.null(post_control), "dit", "tc"),
conservative = F,
quietly = F,
suppress_progress_bar = F,
save_dit = F,
costm_main = NULL,
costm_control = NULL){
var <- rlang::enquo(var)
# error checking
if (is.null(pre_main) | is.null(post_main)){
message("`pre_main` and/or `post_main` is mising.")
}
estimator <- tolower(estimator)
if (estimator == "tc"){
if (!is.null(pre_control) | !is.null(post_control)){
message("`pre_control` and/or `post_control` will be ignored.")
}
est_message <- "Computing Transport Costs..."
est <- "tc"
if (!suppress_progress_bar) message(est_message)
} else if (estimator == "dit" | estimator == "differences-in-transports"){
if (is.null(pre_control) | is.null(post_control)){
message("`pre_control` and/or `post_control` is mising.")
}
est_message <- "Computing Differences-in-Transports Estimator..."
est <- "dit"
if (!suppress_progress_bar) message(est_message)
} else {
stop("Invalid estimator. Choose 'tc' or 'dit' or double check inputs.")
}
if (conservative & !quietly){
message("Note: you are using `conservative = T`.")
}
if (est != "dit" & save_dit){
warning("The differences-in-transports estimator is not being computed so `save_dit = TRUE` is being ignored.")
}
# initialization
main_prop <- rep(NA_real_, length(bandwidth_seq))
if (conservative) maincons_prop <- rep(NA_real_, length(bandwidth_seq))
if (est == "dit") control_prop <- rep(NA_real_, length(bandwidth_seq))
# computation
if (!suppress_progress_bar) pb <- utils::txtProgressBar(min = 0, max = length(bandwidth_seq), initial = 0)
for (i in seq_along(bandwidth_seq)){
if (!suppress_progress_bar) utils::setTxtProgressBar(pb, i)
bandwidth <- bandwidth_seq[i]
main_cost <- get_OTcost(pre_main, post_main, bandwidth = bandwidth, var = !!var, costmat = costm_main)
if (conservative) maincons_cost <- get_OTcost(pre_main, post_main, bandwidth = 2*bandwidth, var = !!var, costmat = costm_main)
if (est == "dit") control_cost <- get_OTcost(pre_control, post_control, bandwidth = bandwidth, var = !!var, costmat = costm_control)
main_prop[i] <- main_cost$prop_bribe
if (conservative) maincons_prop[i] <- maincons_cost$prop_bribe
if (est == "dit") control_prop[i] <- control_cost$prop_bribe
}
if (!suppress_progress_bar) cat("\n")
# compile results
if (est == "dit"){
diffprop <- main_prop - control_prop
if (conservative){
diffprop2d <- maincons_prop - control_prop
out <- data.frame(bandwidth = bandwidth_seq,
main = main_prop,
main2d = maincons_prop,
control = control_prop,
diff = diffprop,
diff2d = diffprop2d)
whichmax <- which.max(diffprop2d)
dit <- diffprop2d[whichmax]
dstar <- bandwidth_seq[whichmax]
if (!quietly) message(paste("The conservative diff-in-transports estimator is ", dit, " at d = ", dstar, sep = ""))
} else {
out <- data.frame(bandwidth = bandwidth_seq,
main = main_prop,
control = control_prop,
diff = diffprop)
whichmax <- which.max(diffprop)
dit <- diffprop[whichmax]
dstar <- bandwidth_seq[whichmax]
if (!quietly) message(paste("The non-conservative diff-in-transports estimator is ", dit, " at d = ", dstar, sep = ""))
}
if (save_dit) out <- list(out = out, dit = dit, optimal_bandwidth = dstar)
}
if (est == "tc"){
out <- data.frame(bandwidth = bandwidth_seq,
main = main_prop)
if (conservative) out <- cbind(out, main2d = maincons_prop)
if (!quietly) message(paste("The transport cost for the specified bandwidths have been computed."))
}
return(invisible(out))
}
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.