# R/plotting_functions.R
#
# Defines functions: get_date_ticks, get_base_plot, get_plot_data, statenames, get_state_name, convert_to_plot_data, std_model_col_names, prepare_empirical_data, prepare_model_data, prepare_plot_datasets
#
# Documented in: convert_to_plot_data, get_base_plot, get_plot_data, get_state_name, prepare_empirical_data, prepare_model_data, prepare_plot_datasets, statenames, std_model_col_names

#' Prepare the Plotting Datasets
#'
#' This function prepares the model and empirical datasets and binds them together.
#' The empirical data are loaded once via \code{prepare_empirical_data()}; each model
#' file that is supplied is then read with \code{prepare_model_data()} and merged with
#' the matching empirical table via \code{convert_to_plot_data()}.
#' @param us_model_file_name default is NULL
#' @param state_model_file_name default is NULL
#' @param county_model_file_name default is NULL
#' @param model_filter_dates default is NULL; passed to \code{prepare_model_data()} as \code{filterdates}
#' @param empirical_data_type the source of the empirical data, either csse or usafacts or a file name/path to custom csv.
#' The default is NULL, but a value must be supplied (an error is raised otherwise)
#' @param csse_repo_path location of the csse repo (default NULL)
#' @param hospscaler hosp scaler to get infection estimate, defaults to NULL
#' @param updategit default is FALSE; should the local csse git repo be queried for possible update?
#' @return Returns a list of dataframes, one for each model file name provided (named "us", "state", "county").
#' Each frame will contain the combined data from both the model and the empirical data
#' @export
#' @examples
#' prepare_plot_datasets("usfile.csv", "statefile.csv",
#'                       model_filter_dates = c("2020-03-01", "2020-06-15"),
#'                       empirical_data_type = "usafacts",
#'                       csse_repo_path = "repo_dir/")
prepare_plot_datasets <- function(us_model_file_name = NULL,
                                  state_model_file_name = NULL,
                                  county_model_file_name = NULL,
                                  model_filter_dates = NULL,
                                  empirical_data_type = NULL,
                                  csse_repo_path = NULL,
                                  hospscaler = NULL,
                                  updategit = FALSE
) {

  # Without a source the empirical loader cannot decide what to read; fail early
  # with a clear message instead of a cryptic error further down the call chain.
  if (is.null(empirical_data_type)) {
    stop("empirical_data_type must be 'csse', 'usafacts', or the path to a custom csv file",
         call. = FALSE)
  }

  #################################
  ### Get the  Empirical Data
  #################################
  empirical_data <- prepare_empirical_data(empirical_data_type, csse_repo_path,
                                           updategit = updategit)

  #################################
  ### Get the  Model Data
  #################################
  # One entry per geographic level: the model file to read and the join columns
  # (model column name = empirical column name).
  geo_specs <- list(
    us = list(file = us_model_file_name,
              bycols = c("date" = "Date")),
    state = list(file = state_model_file_name,
                 bycols = c("date" = "Date", "USPS" = "USPS")),
    county = list(file = county_model_file_name,
                  bycols = c("date" = "Date", "FIPS" = "FIPS"))
  )

  result <- list()
  for (geo in names(geo_specs)) {
    spec <- geo_specs[[geo]]
    if (is.null(spec[["file"]])) {
      next
    }
    model_data <- prepare_model_data(spec[["file"]],
                                     filterdates = model_filter_dates,
                                     hospscaler = hospscaler)
    result[[geo]] <- convert_to_plot_data(model_data,
                                          empirical_data[[geo]],
                                          bycols = spec[["bycols"]])
  }

  return(result)

}


#' Prepare the Model Data
#'
#' Reads a model output file (csv) and returns it with standardized column names
#' (inf, cuminf, deaths, cumdeaths), a USPS column without missing values, and,
#' optionally, rows restricted to a date window and infections rescaled from
#' hospitalizations.
#' @param model_path filename for the model input
#' @param filterdates a vector of two dates, indicating the start and end (both inclusive); NULL keeps all rows
#' @param hospscaler a factor by which to scale infections to equal a multiple of hospitalizations
#' @return Returns a dataframe
#' @export
#' @examples
#' prepare_model_data(usfname)
#' prepare_model_data(usfname, c("2020-03-01", "2020-06-15"))
prepare_model_data <- function(model_path, filterdates=NULL, hospscaler=NULL) {

  model_dt <- data.table::fread(model_path)
  model_dt[, date := as.Date(date, "%Y-%m-%d")]

  # Accepted spellings for each standardized outcome column
  inf_aliases <- c("inf", "incidI", "infection", "infections")
  cuminf_aliases <- c("cuminf", "cumincidI", "cuminfection", "cuminfections",
                      "cum_inf", "cum_incidI", "cum_infection", "cum_infections")
  deaths_aliases <- c("deaths", "incidD", "death")
  cumdeaths_aliases <- c("cum_deaths", "cum_death", "cum_incidD",
                         "cumdeaths", "cumdeath", "cumincidD")

  model_dt <- std_model_col_names(model_dt, inf, inf_aliases)
  model_dt <- std_model_col_names(model_dt, cuminf, cuminf_aliases)
  model_dt <- std_model_col_names(model_dt, deaths, deaths_aliases)
  model_dt <- std_model_col_names(model_dt, cumdeaths, cumdeaths_aliases)

  # Guarantee a character USPS column in which missing values become ""
  if (!("USPS" %in% colnames(model_dt))) {
    model_dt[, USPS := ""]
  }
  model_dt[, USPS := dplyr::if_else(is.na(USPS), as.character(""), as.character(USPS))]

  # Optional date window
  if (!is.null(filterdates)) {
    model_dt <- model_dt[date >= filterdates[1] & date <= filterdates[2], ]
  }

  # Nothing more to do unless infections are to be derived from hospitalizations
  if (is.null(hospscaler)) {
    return(model_dt)
  }

  # infections = hospitalizations * hospscaler; both columns must be identifiable
  model_dt <- std_model_col_names(model_dt, hosps, c("incidH", "hosps", "hospitalizations"))
  model_dt <- std_model_col_names(model_dt, inf, inf_aliases)
  model_dt[, inf := hosps * hospscaler]

  # cumulative infections are rescaled only when a cumulative hospitalization column exists
  cumhosp_aliases <- c("cumhosps", "cum_hosps", "cumincidH",
                       "cum_incidH", "cum_hospitalizations", "cumhospitalizations")
  if (any(cumhosp_aliases %in% colnames(model_dt))) {
    model_dt <- std_model_col_names(model_dt, cumhosps, cumhosp_aliases)
    model_dt[, cuminf := cumhosps * hospscaler]
  }

  return(model_dt)
}

#' Prepare the Empirical Data
#'
#' This function prepares the empirical data. Depending on source it will call functions to prepare the csse or usafacts data,
#' or read a custom file. Rows without a FIPS code are dropped, and the data are returned at county, state and US level.
#' For the csse and usafacts sources the daily counts (Confirmed, Deaths) are derived from the cumulative counts
#' (cumConfirmed, cumDeaths); for a custom source all four columns are expected to be supplied already.
#' @param empirical_source source of empirical data.. must be 'csse','usafacts' or a custom file path
#' @param csse_repo_path defaults to NULL, otherwise a path to local clone of JHU CSSE repo
#' @param filterdates defaults to all dates forward from March 1st, 2020, but can be specified with a vector of two dates
#' (start, end; both inclusive). NULL disables date filtering
#' @param updategit defaults to FALSE; should the local csse git repo be queried for possible update?
#' @return a list with three elements, \code{us}, \code{state} and \code{county}, each holding the empirical data
#' at that geographic level
#' @export
#' @examples
#' prepare_empirical_data("usafacts", "jhudata/", c("2020-03-01", "2020-06-01"))
#' prepare_empirical_data("my_empirical_data.csv", filterdates = c("2020-03-01", "2020-06-01"))
prepare_empirical_data <- function(empirical_source, csse_repo_path=NULL, filterdates=c("2020-03-01", "2030-03-01"),updategit=FALSE) {

  if (!is.character(empirical_source) || length(empirical_source) != 1L || is.na(empirical_source)) {
    stop("empirical_source must be 'csse', 'usafacts', or the path to a custom csv file", call. = FALSE)
  }
  if (!is.null(filterdates) && length(filterdates) < 2L) {
    stop("filterdates must be NULL or a vector of two dates (start, end)", call. = FALSE)
  }

  # Turn a cumulative series into increments: the first value is kept and every
  # later value becomes the difference from its predecessor. Safe for series of
  # length 0 or 1.
  reverse_cumul <- function(x) {
    if (length(x) == 0L) {
      return(x)
    }
    c(x[1], diff(x))
  }

  # Keep only rows whose Date lies in the requested window (inclusive)
  within_dates <- function(dt) {
    if (is.null(filterdates)) {
      return(dt)
    }
    dt[Date >= filterdates[1] & Date <= filterdates[2], ]
  }

  states <- data.table::setnames(statenames(), old=c("state_abbreviation", "state_name"), new=c("USPS","Province_State"))

  emp_type <- if (empirical_source %in% c("usafacts", "csse")) empirical_source else "custom"

  if (emp_type == "usafacts") {
    empirical_base <- usafactsdata()[, list(Province_State, Admin2, FIPS, Date, Confirmed, Deaths)]
    # the Province_State column is treated as the state abbreviation here
    data.table::setnames(empirical_base, old="Province_State", new="USPS")
    empirical_base <- data.table::merge.data.table(empirical_base, states[, list(USPS, Province_State)], by="USPS")

    #As of 2020-05-05, usafacts does not have data for the territories, so
    #we are going to supplement by pulling the csse data, and adding in the
    #rows from the territories
    supplement_with_csse_territory_data <- TRUE

    if (supplement_with_csse_territory_data) {
      csse_territory_data <- cssedata(csse_repo_path, updategit = updategit)[, list(Province_State, Admin2, FIPS, Date, Confirmed, Deaths)]
      csse_territory_data <- data.table::merge.data.table(csse_territory_data, states[USPS %in% c("AS","GU","MP","PR","VI"), ])

      empirical_base <- data.table::setorder(rbind(empirical_base, csse_territory_data), USPS, Admin2, Date)
    }
  }

  if (emp_type == "csse") {
    empirical_base <- data.table::merge.data.table(cssedata(csse_repo_path, updategit=updategit)[, list(Province_State, Admin2, FIPS, Date, Confirmed, Deaths)],
                                                   states,
                                                   by="Province_State")
  }

  if (emp_type == "custom") {
    empirical_base <- prepare_custom_repo(empirical_source)
  }

  if (emp_type %in% c("usafacts", "csse")) {
    #lets drop any rows without FIPS, because we are going to plot across these
    empirical_base <- empirical_base[!is.na(FIPS), list(USPS, Province_State, Admin2, FIPS, Date, Confirmed, Deaths)]
    data.table::setnames(empirical_base, old=c("Confirmed", "Deaths"), new=c("cumConfirmed","cumDeaths"))

    # The increments are only meaningful when each series is in date order, so
    # every table is sorted before reverse_cumul() is applied.
    county <- data.table::copy(empirical_base)
    data.table::setorder(county, FIPS, Date)
    county[, `:=`(Confirmed=reverse_cumul(cumConfirmed), Deaths=reverse_cumul(cumDeaths)), by="FIPS"]
    county <- within_dates(county)

    state <- empirical_base[, lapply(.SD, sum), by=c("USPS","Date"), .SDcols=c("cumConfirmed","cumDeaths")]
    data.table::setorder(state, USPS, Date)
    state[, `:=`(Confirmed=reverse_cumul(cumConfirmed), Deaths=reverse_cumul(cumDeaths)), by="USPS"]
    state <- within_dates(state)

    us <- empirical_base[, lapply(.SD, sum), by="Date", .SDcols=c("cumConfirmed","cumDeaths")]
    data.table::setorder(us, Date)
    us[, `:=`(Confirmed=reverse_cumul(cumConfirmed), Deaths=reverse_cumul(cumDeaths))]
    # the filtered table must be assigned, otherwise the date window is not applied
    us <- within_dates(us)
  }

  if (emp_type == "custom") {
    empirical_base <- empirical_base[!is.na(FIPS), ]
    outcome_cols <- c("cumConfirmed", "cumDeaths", "Confirmed", "Deaths")

    county <- within_dates(data.table::copy(empirical_base))

    state <- empirical_base[, lapply(.SD, sum), by=c("USPS","Date"), .SDcols=outcome_cols]
    state <- within_dates(state)

    us <- empirical_base[, lapply(.SD, sum), by="Date", .SDcols=outcome_cols]
    # the filtered table must be assigned, otherwise the date window is not applied
    us <- within_dates(us)
  }

  return(list(us = us, state = state, county = county))

}

#' Standardize model names
#'
#' This function allows for accepting model data frames with alternative
#' names for a column (e.g. inf, deaths).  A dataframe, a column type, and its possible alternative names
#' are provided; if one and only one of the alternatives is found, that column is renamed
#' to the column type
#' @param dt data frame of model data
#' @param coltype an unquoted column type, e.g. inf or deaths; this is the name the matching column receives
#' @param valid_names possible names
#' @param ignore_error In general, one and only one valid column name should be found; if this is not the case, an error
#' is raised by default. Set this parameter to TRUE to ignore the problem and return the data unchanged
#' @return the data frame, with the matching column renamed to \code{coltype}
#' @export
#' @examples
#' std_model_col_names(model_data, inf, c("inf","infection","infections","incidI"))
#' std_model_col_names(model_data, deaths, c("death","deaths","incidD"))
std_model_col_names <- function(dt, coltype, valid_names, ignore_error=FALSE) {
  coltype <- rlang::enquo(coltype)
  col_index <- which(colnames(dt) %in% valid_names)

  if (length(col_index) != 1) {
    if (!ignore_error) {
      # say whether nothing matched or the match was ambiguous
      reason <- if (length(col_index) == 0) {
        "no such column was found"
      } else {
        paste0("several were found: ", paste(colnames(dt)[col_index], collapse = ", "))
      }
      # collapse the alternatives so that a single, readable message is produced
      stop(paste0("Invalid column name for ", rlang::as_label(coltype),
                  "; must be one of ", paste(valid_names, collapse = ", "),
                  " (", reason, ")"),
           call. = FALSE)
    }
    return(dt)
  }

  dt <- dplyr::rename(dt, !!coltype := col_index[1])
  return(dt)
}

#' Convert the model and empirical data to plot data
#'
#' This function combines model and empirical data, selecting quantiles, swinging to wide format
#' and merging with the empirical data. The rows of \code{model_data} whose \code{q} is one of
#' p025, p250, p500, p750, p975 are cast to one column per outcome and quantile
#' (e.g. \code{infections_p500}), and the empirical data are left-joined onto the result.
#' The input \code{model_data} is not modified.
#' @param model_data data frame of model data; must contain the columns date, q, inf, deaths, cuminf and cumdeaths.
#' A USPS column is optional (an empty string is used when it is absent)
#' @param empirical_data data frame of empirical data
#' @param bycols columns to join on, as a named character vector: names are the model columns and
#' values are the matching empirical columns. An unnamed element is taken to have the same name on both sides
#' @return a data.table ordered by USPS and date, with the date column named \code{Date}, the model
#' columns named \code{infections_*}, \code{deaths_*}, \code{cuminfections_*}, \code{cumdeaths_*}, followed by
#' the remaining empirical columns. If a non-key column exists on both sides, the model side is kept
#' @export
#' @examples
#' convert_to_plot_data(idd_data_frame, csse_data_frame,bycols=c("date"="Date","USPS"="USPS"))

convert_to_plot_data <- function(model_data, empirical_data, bycols) {

  quantile_levels <- c("p025", "p250", "p500", "p750", "p975")
  measures <- c("inf", "deaths", "cuminf", "cumdeaths")

  # Fill in missing names so that names(bycols) are always the model-side keys
  if (is.null(names(bycols))) {
    names(bycols) <- bycols
  } else {
    blank <- is.na(names(bycols)) | names(bycols) == ""
    names(bycols)[blank] <- bycols[blank]
  }
  model_keys <- names(bycols)
  empirical_keys <- unname(bycols)

  missing_cols <- setdiff(c("date", "q", measures, model_keys),
                          c(colnames(model_data), "USPS"))
  if (length(missing_cols) > 0) {
    stop(paste0("model_data is missing required column(s): ",
                paste(missing_cols, collapse = ", ")),
         call. = FALSE)
  }

  # Identifier columns of the wide table; the join keys are included so that
  # keys other than date/USPS (e.g. FIPS) survive the cast.
  id_cols <- union(c("date", "USPS"), model_keys)
  keep_cols <- intersect(c(id_cols, "q", measures), colnames(model_data))

  # Subsetting creates a new table, so adding USPS below leaves the input untouched
  df <- data.table::as.data.table(model_data)[q %in% quantile_levels, keep_cols, with = FALSE]
  if (!"USPS" %in% colnames(df)) df[, USPS := ""]

  cast_formula <- stats::as.formula(paste(paste(id_cols, collapse = " + "), "~ q"))
  df <- data.table::dcast(df, cast_formula, value.var = measures)

  # Drop empirical columns that would clash with a model column, otherwise the
  # merge would suffix both (.x/.y) and the expected names would disappear.
  empirical_data <- data.table::as.data.table(empirical_data)
  clashing <- setdiff(intersect(colnames(df), colnames(empirical_data)), empirical_keys)
  if (length(clashing) > 0) {
    empirical_data <- empirical_data[, setdiff(colnames(empirical_data), clashing), with = FALSE]
  }

  df <- data.table::merge.data.table(df, empirical_data,
                                     by.x = model_keys, by.y = empirical_keys,
                                     all.x = TRUE)
  data.table::setorder(df, USPS, date)

  # Rename by name (not by position) to the names the plotting functions expect
  old_names <- colnames(df)
  new_names <- sub("^inf_", "infections_", old_names)
  new_names <- sub("^cuminf_", "cuminfections_", new_names)
  new_names[old_names == "date"] <- "Date"
  data.table::setnames(df, old = old_names, new = new_names)

  return(df)
}

#' Retrieve a state name, given an abbreviation
#'
#' Looks the abbreviation up in the table returned by \code{statenames()} and
#' returns the matching full name
#' @param state_abb two-character string abbreviation for state or territory
#' @export
#' @examples
#' get_state_name("MD")
get_state_name <- function(state_abb) {
  lookup <- statenames()
  matched <- lookup[state_abbreviation == state_abb]
  matched[["state_name"]]
}

#' Prepare a utility dataframe for state names and abbreviations
#'
#' Builds a two-column data.table (state_name, state_abbreviation) covering the
#' 50 states, the District of Columbia and five territories
#' @export
#' @examples
#' statenames()
statenames <- function() {
  # abbreviation = full name, kept together so the pairs cannot drift apart
  name_by_abbreviation <- c(
    AL = "Alabama", AK = "Alaska", AZ = "Arizona", AR = "Arkansas",
    CA = "California", CO = "Colorado", CT = "Connecticut", DE = "Delaware",
    DC = "District of Columbia", FL = "Florida", GA = "Georgia", HI = "Hawaii",
    ID = "Idaho", IL = "Illinois", IN = "Indiana", IA = "Iowa",
    KS = "Kansas", KY = "Kentucky", LA = "Louisiana", ME = "Maine",
    MD = "Maryland", MA = "Massachusetts", MI = "Michigan", MN = "Minnesota",
    MS = "Mississippi", MO = "Missouri", MT = "Montana", NE = "Nebraska",
    NV = "Nevada", NH = "New Hampshire", NJ = "New Jersey", NM = "New Mexico",
    NY = "New York", NC = "North Carolina", ND = "North Dakota", OH = "Ohio",
    OK = "Oklahoma", OR = "Oregon", PA = "Pennsylvania", RI = "Rhode Island",
    SC = "South Carolina", SD = "South Dakota", TN = "Tennessee", TX = "Texas",
    UT = "Utah", VT = "Vermont", VA = "Virginia", WA = "Washington",
    WV = "West Virginia", WI = "Wisconsin", WY = "Wyoming", PR = "Puerto Rico",
    AS = "American Samoa", VI = "Virgin Islands", GU = "Guam",
    MP = "Northern Mariana Islands"
  )

  data.table::data.table(state_name = unname(name_by_abbreviation),
                         state_abbreviation = names(name_by_abbreviation))

}

#' Returns the subset of data for a specific plot
#'
#' This function prepares the subset of data that will be fed to the base plotting function.
#' It selects USPS, Date, the five model quantile columns for the requested outcome and the
#' matching empirical column, and renames them to USPS, Date, p025, p250, p500, p750, p975 and empirical.
#' @param dt dataframe (data.table) prepared using the prepare_plot_datasets() function
#' @param outcome character string, one of "Cases" or "Deaths"
#' @param geo_level character string, one of "us", "state", "county". Only "state" triggers filtering
#' to a single location; any other value is labelled "United States" and is not filtered
#' @param state defaults to NULL, but if geo_level is "state", must be a two-character state abbreviation
#' @param cumulative defaults to FALSE (and daily outcome is plotted), but can be TRUE for cumulative results
#' @param per100K defaults to FALSE; if TRUE, all numeric columns are rescaled to a rate per 100,000 using \code{popsize}
#' @param popsize population size of the location; required when \code{per100K} is TRUE (default NULL)
#' @return Returns a list including the subset of data (plot_data), the location (location), the outcome,
#' and whether or not the data are cumulative or not
#' @export
#' @examples
#' get_plot_data(dt=prepared_data_frame, outcome="Cases", geo_level="us", state=NULL, cumulative=F)
#' get_plot_data(dt=prepared_data_frame, outcome="Deaths", geo_level="state", state="TX", cumulative=T)
get_plot_data <- function(dt, outcome, geo_level, state=NULL, cumulative=FALSE, per100K=FALSE, popsize=NULL) {

  if (!is.character(outcome) || length(outcome) != 1L || !outcome %in% c("Cases", "Deaths")) {
    stop("outcome must be one of 'Cases' or 'Deaths'", call. = FALSE)
  }

  #First, set up location name
  if (geo_level == "state") {
    if (is.null(state) || length(state) != 1L || is.na(state)) {
      stop("a single state abbreviation must be supplied when geo_level is 'state'", call. = FALSE)
    }
    loc_name <- get_state_name(state)
  } else {
    loc_name <- "United States"
    state <- ""
  }

  #Get the target variables, based on the outcome
  cum_prefix <- if (cumulative) "cum" else ""
  model_prefix <- paste0(cum_prefix, if (outcome == "Cases") "infections" else "deaths")
  empirical_col <- paste0(cum_prefix, if (outcome == "Cases") "Confirmed" else "Deaths")

  all_cols <- colnames(dt)
  model_cols <- all_cols[startsWith(all_cols, model_prefix)]
  # the renaming below relies on exactly five quantile columns, in ascending order
  if (length(model_cols) != 5L) {
    stop(paste0("expected five quantile columns starting with '", model_prefix,
                "', found ", length(model_cols)),
         call. = FALSE)
  }
  targ_vars <- c("USPS", "Date", model_cols, empirical_col)

  #Filter the data and select the variables
  plot_data <- dt[, targ_vars, with = FALSE]

  if (geo_level == "state") plot_data <- plot_data[USPS == state, ]

  data.table::setnames(plot_data, c("USPS","Date","p025","p250","p500","p750","p975","empirical"))

  #if per 100K, then we must have population data for this location
  if (per100K) {
    if (is.null(popsize)) {
      stop("There is no poulation data for this location")
    } else {
      targcols <- which(vapply(plot_data, is.numeric, logical(1)))
      plot_data[, (targcols) := lapply(.SD, function(x) x*100000/popsize), .SDcols=targcols]
    }
  }

  return(list(plot_data=plot_data, location=loc_name, cumulative=cumulative, outcome=outcome))

}

#' Generates the plot
#'
#' This function generates the base plot, upon receiving a list of plot data information.
#' Empirical values are drawn as points, the model median (p500) as a line, and the
#' p025-p975 and p250-p750 ranges as ribbons.
#' @param dt plot data, from the get_plot_data function (a list with plot_data, location, cumulative and outcome)
#' @param ytransform defaults to "sqrt" but can be "identity", "log", "log10", etc
#' @param empirical_label label for the plot legend for the empirical data (default is "Reported" followed by the outcome)
#' @param model_label label for the plot legend for the modelled data (default is "Modelled" followed by the outcome)
#' @param x_label label for the x-axis (default is "Date")
#' @param fill_label label for the legend denoting the ribbon/range fills (default is "Range Estimates")
#' @param color_label label for the legend denoting the outcome color
#' @param title_string default is to use the location from the dt list, but will be the title of plot
#' @param dateticks a vector of dates for the x-axis; if NULL they are generated from the range of the data
#' @param caption_below_plot string value for the caption below (default is no caption)
#' @param xlimits limits passed to the x (date) scale (default NULL)
#' @param yticklabels a vector of ticks specified for the yaxis. When left at its default and
#' \code{ytransform} is "sqrt", ticks at 1, 2.5 and 5 times powers of ten are generated to cover the data
#' @param smooth add a generalized additive model cubic regression spline smoother to empirical data? (default is FALSE)
#' @param smooth_ci add CI to the empirical data smoother? (default is FALSE)
#' @param legend_location specify the location of the legend, either 'belowplot' (default), 'topleft', or 'bottomright'
#' @param labsize specify the text size for the axis tick labels and the legend (default is 8)
#' @return a ggplot object
#' @export
#' @examples
#' get_base_plot(plt_df, "sqrt", caption_below_plot="My Caption")
get_base_plot <- function(dt,
                          ytransform="sqrt",
                          empirical_label=paste0("Reported ", dt$outcome),
                          model_label = paste0("Modelled ", dt$outcome),
                          x_label="Date",
                          fill_label = "Range Estimates",
                          color_label = "Outcomes",
                          title_string=paste0(dt$location),
                          dateticks=NULL,
                          caption_below_plot=NULL,
                          xlimits=NULL,
                          yticklabels = ggplot2::waiver(),
                          smooth=FALSE,
                          smooth_ci=FALSE,
                          legend_location=c("belowplot","bottomright","topleft"),
                          labsize=8) {

  # Translate the keyword into what ggplot2 expects: coordinates inside the
  # panel, or "bottom" for a legend below the plot.
  legend_location <- match.arg(legend_location)
  legend_position <- switch(legend_location,
                            topleft = c(.1, .95),
                            bottomright = c(.90, .4),
                            belowplot = "bottom")

  g_input <- dt$plot_data

  #get the dateticks if dateticks is null
  if (is.null(dateticks)) {
    dateticks <- get_date_ticks(min(g_input$Date), max(g_input$Date))
  }
  #angle the xticks if more than 6 dates
  if (length(dateticks) > 6) {
    xtickspecs <- list(angle=45, vjust=1, hjust=1)
  } else {
    xtickspecs <- list(angle=0, vjust=1, hjust=0.5)
  }

  y_label <- paste0(if (isTRUE(dt$cumulative)) "Cumulative " else "", "Number of ", dt$outcome)

  # For the sqrt transform, generate readable ticks unless the caller supplied their own
  if (identical(ytransform, "sqrt") && inherits(yticklabels, "waiver")) {
    numeric_values <- unlist(Filter(is.numeric, as.list(g_input)), use.names = FALSE)
    maxy <- max(numeric_values, na.rm=TRUE)
    candidates <- unlist(lapply(seq(1, 10), function(x) c(1, 2.5, 5)*10^x))
    # keep every tick below the maximum plus the first one at or above it (if any)
    first_above <- which(candidates >= maxy)[1]
    keep <- c(which(candidates < maxy), if (!is.na(first_above)) first_above)
    yticklabels <- candidates[keep]
  }
  colorvalues <- c("red", "blue")
  names(colorvalues) <- c(model_label, empirical_label)

  plt <- ggplot2::ggplot(g_input)
  plt <- plt + ggplot2::geom_point(ggplot2::aes(Date, empirical, color=empirical_label), size=1.7)
  if (smooth) {
    plt <- plt + ggplot2::geom_smooth(method="gam", mapping=ggplot2::aes(Date, empirical, color=empirical_label), size=1.7, se = smooth_ci)
  }
  plt <- plt +
    ggplot2::geom_line(ggplot2::aes(Date, p500, color=model_label), size=1.2) +
    ggplot2::geom_ribbon(ggplot2::aes(x=Date, ymin=p025, ymax=p975, fill="2.5% - 97.5%"), alpha=.1) +
    ggplot2::geom_ribbon(ggplot2::aes(x=Date, ymin=p250, ymax=p750, fill="25.0% - 75.0%"), alpha=.2) +
    ggplot2::scale_y_continuous(labels=scales::comma, trans=ytransform, breaks=yticklabels) +
    ggplot2::scale_x_date(breaks=dateticks, limits = xlimits) +
    ggplot2::scale_fill_manual(values = c("2.5% - 97.5%"="black", "25.0% - 75.0%"="red")) +
    ggplot2::scale_color_manual(values = colorvalues) +

    ggplot2::ylab(y_label) +
    ggplot2::xlab(x_label) +
    ggplot2::ggtitle(title_string) +
    ggplot2::labs(fill=fill_label) +
    ggplot2::labs(color=color_label) +
    ggplot2::labs(caption=caption_below_plot) +
    ggplot2::theme(
      axis.text.x = ggplot2::element_text(angle = xtickspecs[["angle"]], hjust=xtickspecs[["hjust"]], vjust=xtickspecs[["vjust"]], size=labsize),
      axis.text.y = ggplot2::element_text(size=labsize),
      legend.text = ggplot2::element_text(size=labsize),
      legend.position = legend_position,
      legend.justification = c("center", "top"),
      legend.box.just = "center",
      legend.margin = ggplot2::margin(3, 3, 3, 3),
      legend.background = ggplot2::element_rect(fill = "transparent", color = NA)) +
    ggplot2::guides(color = ggplot2::guide_legend(order = 1), fill = ggplot2::guide_legend(order = 2))

  return(plt)

}

#' Generates date ticks 1st, and 15th of each month
#'
#' This function gets date ticks for the 1st and 15th of the month.. is the default if dates are not provided to
#' get_base_plot function. Ticks are generated for every month from the month containing \code{mind}
#' through the month containing \code{maxd}, so the first and last ticks may fall outside the data range
#' @param mind The minimum date
#' @param maxd The maximum date
#' @return a sorted vector of dates
#' @noRd
#' @examples
#' get_date_ticks(mymindate, mymaxdate)

get_date_ticks <- function(mind,maxd) {
  mind <- as.Date(mind)
  maxd <- as.Date(maxd)
  # Start from the first day of mind's month; stepping by month from the 1st
  # also avoids end-of-month overflow (e.g. starting on the 31st).
  month_starts <- seq(as.Date(format(mind, "%Y-%m-01")), maxd, by="month")
  return(sort(c(month_starts, month_starts + 14)))
}
# lmullany/iddplotting documentation built on July 26, 2020, 8:05 p.m.