R/prophet.R

Defines functions do_prophet_ add_country_holidays do_prophet is_na_rm_func trim_future

# Wrapper functions around prophet.

# Trim the future part of pre-aggregation data so that only the history part remains.
# This is needed when the data comes with external regressors or holidays, whose rows can extend past the history.
#
# df        - Data frame to trim.
# time_col  - Name of the time column.
# value_col - Name of the value column, or NULL when the forecast is about the number of rows.
# periods   - Number of time_unit periods at the end of the data to treat as future.
#             Used only when value_col is NULL.
# time_unit - "second"/"sec", "minute"/"min", "hour", "day", "week", "month" or "quarter".
#             Any other value is treated as "year". Used only when value_col is NULL.
# Returns df with the future rows filtered out.
trim_future <- function(df, time_col, value_col, periods, time_unit) {
  time_sym <- rlang::sym(time_col)

  if (is.null(value_col)) {
    # Without a value column, the history/future boundary is `periods` units before the latest time.
    shift_back <- switch(time_unit,
      second = ,
      sec = lubridate::seconds,
      minute = ,
      min = lubridate::minutes,
      hour = lubridate::hours,
      day = lubridate::days,
      week = lubridate::weeks,
      month = base::months,
      quarter = function(x) {
        base::months(3 * x)
      },
      lubridate::years # assuming it is year.
    )
    # Keep the rows at or before the history/future boundary as the history data.
    return(dplyr::filter(df, !!time_sym <= (max(!!time_sym) - shift_back(!!periods))))
  }

  # With a value column, rows that have values are considered history data.
  # The latest time with a non-NA value becomes the boundary. We do this instead of dropping every row with NA,
  # so that functions like na_count keep working and as much extra regressor info as possible is preserved.
  # NAs are later handled by the na.rm=TRUE option of the aggregate function.
  rows_with_value <- dplyr::filter(df, !is.na(!!rlang::sym(value_col)))
  boundary_time <- max(rows_with_value[[time_col]], na.rm = TRUE)
  dplyr::filter(df, !!time_sym <= !!boundary_time)
}

# Returns TRUE when func is exactly one of the common aggregate functions listed below,
# all of which accept an na.rm argument. Callers use this to decide whether to append na.rm=TRUE.
# Only the listed function objects match; a wrapper such as function(x) sum(x) returns FALSE.
is_na_rm_func <- function(func) {
  na_rm_funcs <- list(sum, mean, median, min, max, sd, var, IQR, mad)
  is_match <- vapply(na_rm_funcs, function(candidate) identical(candidate, func), logical(1))
  any(is_match)
}

#' NSE version of do_prophet_
#'
#' Takes the time, value and holiday columns as unquoted expressions.
#' Converts each to a column name with col_name() and forwards everything to do_prophet_.
#' @param df - Data frame
#' @param time - Column that has time data (unquoted).
#' @param value - Column that has value data (unquoted). Optional.
#' @param periods - Number of time periods to forecast. Passed through to do_prophet_.
#' @param holiday - Column that has holiday data (unquoted). Passed to do_prophet_ as holiday_col. Optional.
#' @param ... - Other arguments passed to do_prophet_.
#' @export
do_prophet <- function(df, time, value = NULL, periods = 10, holiday = NULL, ...){
  # substitute() has to run in this function's frame to capture the caller's expressions,
  # so the names are resolved here, in order, before forwarding.
  resolved_cols <- list(
    time = col_name(substitute(time)),
    value = col_name(substitute(value)),
    holiday = col_name(substitute(holiday))
  )
  do_prophet_(df, resolved_cols[["time"]], resolved_cols[["value"]], periods,
              holiday_col = resolved_cols[["holiday"]], ...)
}

# Modified version of prophet::add_country_holidays that allows multiple country names.
# R 4.2 and later require an `if` condition to be a length-1 vector.
# The supported-country check is therefore reduced to a single length check instead of an `if` on a vector.
#
# m             - prophet model object. It must not be fitted yet (m$history must be NULL).
# country_names - Character vector of country names/codes to look up in generated_holidays$country.
# Returns m with m$country_holidays set to country_names.
add_country_holidays <- function(m, country_names) {
  if (!is.null(m$history)) {
    stop("Country holidays must be added prior to model fitting.")
  }
  # Report every unsupported name at once, in the order given.
  unsupported <- setdiff(country_names, generated_holidays$country)
  if (length(unsupported) > 0) {
    stop("Holidays in ", paste(unsupported, collapse = ", "), " are not currently supported!")
  }
  # Validate names.
  for (name in prophet:::get_holiday_names(country_names)) {
    # Allow merging with existing holidays
    prophet:::validate_column_name(m, name, check_holidays = FALSE)
  }
  # Set the holidays.
  if (!is.null(m$country_holidays)) {
    # message() pastes its arguments with no separator, so join the vectors explicitly.
    # Otherwise multiple countries would run together (e.g. "USJP").
    message(
      "Changing country holidays from ", paste(m$country_holidays, collapse = ", "), " to ",
      paste(country_names, collapse = ", ")
    )
  }
  m$country_holidays <- country_names
  m
}

#' Forecast time series data
#' @param df - Data frame
#' @param time_col - Column that has time data
#' @param value_col - Column that has value data
#' @param periods - Number of time periods to forecast (e.g. days; the unit is determined by time_unit).
#' @param time_unit - "second"/"sec", "minute"/"min", "hour", "day", "week", "month", "quarter", or "year".
#' @param include_history - Whether to include history data in forecast or not.
#' @param fun.aggregate - Function to aggregate values.
#' @param na_fill_type - Type of NA fill:
#'                       NULL - Skip NA fill. Default behavior.
#'                       "previous" - Fill with previous non-NA value.
#'                       "value" - Fill with the value of na_fill_value.
#'                       "interpolate" - Linear interpolation.
#'                       "spline" - Spline interpolation.
#' @param na_fill_value - Value to fill NA when na_fill_type is "value"
#' @param ... - extra values to be passed to prophet::prophet. listed below.
#' @param growth - Formerly used to specify the type of trend, "linear" or "logistic".
#'        The trend type is now determined automatically from cap. This parameter is kept only so that old callers
#'        that still pass it do not get a prophet error about growth being specified twice (once by our code and once via "...").
#' @param cap - Achievable Maximum Capacity of the value to forecast.
#'        https://facebookincubator.github.io/prophet/docs/forecasting_growth.html
#'        It can be numeric or data frame. When numeric, the value is used as cap for both modeling and forecasting.
#'        When it is a data frame, it should be a future data frame with cap column for forecasting.
#'        When this is specified, the original data frame (df) should also have cap column for modeling.
#'        When either a numeric or a data frame is specified, growth argument for prophet becomes "logistic",
#'        as opposed to default "linear".
#' @param seasonality.prior.scale - Strength of seasonality. Default is 10.
#' @param yearly.seasonality - Whether to return yearly seasonality data.
#' @param weekly.seasonality - Whether to return weekly seasonality data.
#' @param n.changepoints - Number of potential changepoints. Default is 25.
#' @param changepoint.prior.scale - Flexibility of automatic changepoint selection. Default is 0.05.
#' @param changepoints - list of potential changepoints.
#' @param holidays.prior.scale - Strength of holiday effect. Default is 10.
#' @param holidays - Holiday definition data frame.
#' @param mcmc.samples - MCMC samples for full bayesian inference. Default is 0.
#' @param interval.width - Width of uncertainty intervals.
#' @param uncertainty.samples - Number of simulations made for calculating uncertainty intervals. Default is 1000.
#' @export
do_prophet_ <- function(df, time_col, value_col = NULL, periods = 10, time_unit = "day", include_history = TRUE, test_mode = FALSE,
                        fun.aggregate = sum, na_fill_type = NULL, na_fill_value = 0,
                        cap = NULL, floor = NULL, growth = NULL, weekly.seasonality = "auto", yearly.seasonality = "auto",
                        quarterly.seasonality = FALSE, monthly.seasonality = FALSE,
                        daily.seasonality = "auto",
                        holiday_col = NULL, holidays = NULL, holiday_country_names = NULL,
                        regressors = NULL, funs.aggregate.regressors = NULL, regressors_na_fill_type = NULL, regressors_na_fill_value = 0,
                        output = "data", ...) {
  validate_empty_data(df)

  # Pseudo code of preprocessing:
  # ----
  # floor_date
  # if (!is.null(regressors)) {
  #   separate df into history and future
  #   aggregate future df
  # }
  # aggregate history df
  # if (test_mode) {
  #   separate history df into training df and test df based on periods
  # } else {
  #   training df is history df as is
  # }

  # we are making default for weekly/yearly.seasonality TRUE since 'auto' does not behave well.
  # it seems that there are cases that weekly.seasonality is turned off as a side-effect of yearly.seasonality turned off.
  # if that happens, since no seasonality is on, prophet forecast result becomes just a linear trend line,
  # which does not look convincing.
  # since there seems to be cases where 'auto' on yearly.seasonality triggers this situation, we are using TRUE as default.
  # we have not seen any issue on 'auto' on weekly.seasonality, but are not using it for now just to be careful.
  # Update on 2021/03 - We have been using Analytics View with "auto" for yearly as well as weekly, set by JS command generator layer,
  # but haven't seen any obvious issues. Setting them back to auto, since now we rather see problem with yearly seasonality enabled
  # for less than 2 years of data.

  loadNamespace("dplyr")
  # For some reason this needs to be library() instead of loadNamespace() to avoid error.
  # Bug in prophet?
  library("prophet")

  grouped_col <- grouped_by(df)

  if (time_unit == "min") {
    time_unit <- "minute"
  }
  else if (time_unit == "sec") {
    time_unit <- "second"
  }

  # column name validation
  if(!time_col %in% colnames(df)){
    stop(paste0(time_col, " is not in column names"))
  }

  if(time_col %in% grouped_col){
    stop(paste0(time_col, " is grouped. Please ungroup it."))
  }

  if(!is.null(value_col)){
    if (!value_col %in% colnames(df)){
      stop(paste0(value_col, " is not in column names"))
    }
    if(value_col %in% grouped_col){
      stop(paste0(value_col, " is grouped. Please ungroup it."))
    }
  }

  if (!is.null(cap) && !is.data.frame(cap) && !is.null(floor) && cap <= floor) {
    # validate this case. otherwise, the error will be misterious "missing value where TRUE/FALSE needed".
    stop("cap must be greater than floor.")
  }

  if (!is.null(holiday_country_names)) {
    # For ISO2C codes, make it upper case.
    holiday_country_names <- dplyr::if_else(stringr::str_length(holiday_country_names) == 2, stringr::str_to_upper(holiday_country_names), holiday_country_names)
    # Mapping to support some ISO2C codes, that are actually supported but with different names.
    holiday_country_names <- dplyr::recode(holiday_country_names, `UnitedKingdom`="GB", `Turkey`="TR", `France`="FR")
  }

  if (!is.null(holidays)) {
    if (!("ds" %in% colnames(holidays))) {
      stop("The holiday data frame needs to have the ds column.")
    }
    if (!("holiday" %in% colnames(holidays))) {
      stop("The holiday data frame needs to have the holiday column.")
    }
    if (!lubridate::is.Date(holidays$ds) && !lubridate::is.POSIXct(holidays$ds)) {
      stop("The type of the ds column of the holiday data frame needs to be Date or POSIXct.")
    }
    if (!is.character(holidays$holiday)) {
      holidays$holiday <- as.character(holidays$holiday)
    }
  }

  # To filter NAs on regressor columns
  filter_args <- list() # default empty list
  if (!is.null(regressors)) {
    filter_args <- purrr::map(regressors, function(cname) {
      quo(!is.na(UQ(rlang::sym(cname))))
    })
    names(filter_args) <- NULL
  }

  # Compose arguments to pass to dplyr::summarise.
  summarise_args <- list() # default empty list
  regressor_output_cols <- NULL # Just declaring variable
  regressor_final_output_cols <- NULL # Just declaring variable
  if (!is.null(regressors) && !is.null(funs.aggregate.regressors)) {
    summarise_args <- purrr::map2(funs.aggregate.regressors, regressors, function(func, cname) {
      # For common functions that require na.rm=TRUE to handle NA, add it.
      if (is_na_rm_func(func)) {
        quo(UQ(func)(UQ(rlang::sym(cname)), na.rm=TRUE))
      }
      else {
        quo(UQ(func)(UQ(rlang::sym(cname))))
      }
    })

    # Keep final output column names.
    if (!is.null(names(regressors))) {
      regressor_final_output_cols <- names(regressors)
    }
    else {
      regressor_final_output_cols <- regressors
    }

    # But use temporary output column names like r1, r2... We will rename them back before final output.
    # We need this because prophet would garble those column names in the output.
    regressor_output_cols <- paste0("r", 1:length(regressors))

    names(summarise_args) <- regressor_output_cols
  }

  # remove rows with NA time
  df <- df[!is.na(df[[time_col]]), ]

  do_prophet_each <- function(df){
    tryCatch({
      # filter the part of external holidays df for this group.
      holidays_df <- NULL
      if (!is.null(holidays)) {
        holidays_df <- holidays
        for (a_grouped_col in grouped_col) {
          if (!is.null(holidays_df[[a_grouped_col]])) {
            holidays_df <- holidays_df[holidays_df[[a_grouped_col]] == df[[a_grouped_col]][[1]],]
          }
        }
      }
      # filter the part of external cap df (future df) for this group.
      cap_df <- NULL
      if (!is.null(cap) && is.data.frame(cap)) {
        cap_df <- cap
        for (a_grouped_col in grouped_col) {
          if (!is.null(cap_df[[a_grouped_col]])) {
            cap_df <- cap_df[cap_df[[a_grouped_col]] == df[[a_grouped_col]][[1]],]
          }
        }
      }
  
      orig_tz <- lubridate::tz(df[[time_col]]) # Keep original time zone for units smaller than day to set it back right before final output.
      df[[time_col]] <- if (time_unit %in% c("day", "week", "month", "quarter", "year")) {
        # Take care of issue that happened in anomaly detection here for prophet too.
        # In this case, convert (possibly) from POSIXct to Date first.
        # If we did this without converting POSIXct to Date, floor_date works, but later at complete stage,
        # data on day-light-saving days would be skipped, since the times seq.POSIXt gives and floor_date does not match.
        # We give the time column's timezone to as.Date, so that the POSIXct to Date conversion is done
        # based on that timezone.
        lubridate::floor_date(as.Date(df[[time_col]], tz = lubridate::tz(df[[time_col]])), unit = time_unit)
      } else {
        # Set timezone to GMT once.
        # prophet (at 1.0) seems to force timezone info on the time info in the model into GMT.
        # Without this, processing after creating model, such as inner_join with external regressor data are messed up while dealing with different timezones.
        lubridate::with_tz(lubridate::floor_date(df[[time_col]], unit = time_unit), tz="GMT")
      }
  
      # extract holiday df from main df
      if (is.null(holidays_df) && !is.null(holiday_col)) {
        holidays_df <- df %>%
          dplyr::transmute(
            ds = UQ(rlang::sym(time_col)),
            holiday = UQ(rlang::sym(holiday_col))
          ) %>%
          dplyr::group_by(ds) %>%
          dplyr::summarise(holiday = first(holiday[!is.na(holiday)])) %>% # take first non-NA value for aggregation.
          dplyr::filter(!is.na(holiday))
        # If holiday column is logical, create holiday df with only TRUE rows, with single value "Holiday".
        # If it is numeric, first convert to logical (0:FALSE, Others:TRUE), then do the above.
        if (is.logical(holidays_df$holiday) || is.numeric(holidays_df$holiday)) {
          holidays_df <- holidays_df %>% dplyr::mutate(holiday = as.logical(holiday)) %>% dplyr::filter(holiday) %>% dplyr::mutate(holiday = "Holiday")
        }
        # Passing empty dataframe causes prophet error: "Column `ds` is of unsupported type NULL". Set it back to NULL if empty.
        if (nrow(holidays_df) == 0) {
          holidays_df <- NULL
        }
      }
  
      if(!is.null(grouped_col)){
        # drop grouping columns
        df <- df[, !colnames(df) %in% grouped_col]
      }
  
      aggregated_future_data <- NULL
      # Extra regressor case. separate the df into history and future based on the value is filled or not.
      # When value column is not specified (forecast is about number of rows.), we do the history/future separation
      # based on the specified period.
      # Exception is when value column is not specified (forecast is about number of rows.) AND it is test mode.
      # In this case, we treat entire data as history data, and just let test mode logic to separate it into training and test.
      if (!is.null(regressors) && !(is.null(value_col) && test_mode)) {
        # We used to filter NAs on regressor columns here, but now we don't and instead add na.rm to funs.aggregate.regressors.
        # This should pick up more info, and works with function like na_count too.
        #df <- df %>% dplyr::filter(!!!filter_args)
        future_df <- df # keep all rows before df is filtered out to become history data.
        df <- trim_future(df, time_col, value_col, periods, time_unit)
        max_floored_date <- max(df[[time_col]])
        future_df <- future_df %>% dplyr::filter(UQ(rlang::sym(time_col)) > max_floored_date)
  
        if (nrow(future_df) > 0) {
          # TODO: in test mode, this is not really necessary. optimize.
          aggregated_future_data <- future_df %>%
            dplyr::transmute( # Keep only time column and regressor columns in future data frame.
              ds = UQ(rlang::sym(time_col)),
              !!!rlang::syms(unname(regressors)) # unname is necessary to avoid error when regressors is named vector.
            ) %>%
            dplyr::group_by(ds) %>%
            dplyr::summarise(!!!summarise_args)

          # It seems prophet internally removes the rows with NA regressor values for future data, unlike for history data, but to avoid being dependent on that behavior,
          # let's remove them here for future data too.
          if (!is.null(regressor_output_cols)) {
            for (regressor_col in regressor_output_cols) {
              aggregated_future_data <- aggregated_future_data %>% dplyr::filter(!is.na(!!sym(regressor_col)))
            }
          }
        }
      }
      # Even if there is no extra regressor, if holiday column is there, we need to strip future holiday rows.
      # Again, exception is when value column is not specified (forecast is about number of rows.) AND it is test mode.
      # In this case, we treat entire data as history data, and just let test mode logic to separate it into training and test.
      else if (!is.null(holiday_col) && !(is.null(value_col) && test_mode)) {
        df <- trim_future(df, time_col, value_col, periods, time_unit)
      }
      else if(!is.null(value_col)) { # no-extra regressor case. if value column is specified (i.e. value is not number of rows), filter NA rows.
        # df <- df[!is.na(df[[value_col]]), ] # Now we handle NAs with na.rm=TRUE on aggregate function.
      }
  
      # note that prophet only takes columns with predetermined names like ds, y, cap, as input
      aggregated_data <- if (!is.null(value_col) && ("cap" %in% colnames(df))) {
        # preserve cap column if it is there, so that cap argument as future data frame works.
        # apply same aggregation as value to cap.
        grouped_df <- df %>%
          dplyr::transmute(
            ds = UQ(rlang::sym(time_col)),
            value = UQ(rlang::sym(value_col)),
            cap_col = cap,
            !!!rlang::syms(unname(regressors)) # this should be able to handle regressor=NULL case fine.
          ) %>%
          # remove NA so that we do not pass data with NA, NaN, or 0 to prophet, which we are not very sure what would happen.
          # we saw a case where rstan crashes with the last row with 0 y value.
          # dplyr::filter(!is.na(value)) %>% # Commented out, since now we handle NAs with na.rm option of fun.aggregate. This way, extra regressor info for each period is preserved better.
          dplyr::group_by(ds)
        # For common functions that require na.rm=TRUE to handle NA, add it.
        if (is_na_rm_func(fun.aggregate)) {
          grouped_df %>% 
            dplyr::summarise(y = fun.aggregate(value), cap = fun.aggregate(cap_col, na.rm=TRUE), !!!summarise_args)
        }
        else {
          grouped_df %>% 
            dplyr::summarise(y = fun.aggregate(value), cap = fun.aggregate(cap_col), !!!summarise_args)
        }
      } else if (!is.null(value_col)){
        grouped_df <- df %>%
          dplyr::transmute(
            ds = UQ(rlang::sym(time_col)),
            value = UQ(rlang::sym(value_col)),
            !!!rlang::syms(unname(regressors)) # this should be able to handle regressor=NULL case fine.
          ) %>%
          # dplyr::filter(!is.na(value)) %>% # Commented out, since now we handle NAs with na.rm option of fun.aggregate. This way, extra regressor info for each period is preserved better.
          dplyr::group_by(ds)
        if (is_na_rm_func(fun.aggregate)) {
          grouped_df %>% 
            dplyr::summarise(y = fun.aggregate(value, na.rm=TRUE), !!!summarise_args)
        }
        else {
          grouped_df %>% 
            dplyr::summarise(y = fun.aggregate(value), !!!summarise_args)
        }
      } else { # value_col is not specified. The forecast is about number of rows.
        # Note: We ignore cap column in this case for now.
        df %>%
          dplyr::transmute(
            ds = UQ(rlang::sym(time_col)),
            !!!rlang::syms(unname(regressors)) # this should be able to handle regressor=NULL case fine.
          ) %>%
          dplyr::group_by(ds) %>%
          dplyr::summarise(y = n(), !!!summarise_args)
      }
  
      # Fill time column and/or regressor columns as specified by arguments.
      # TODO: Check if this would not have daylight saving days issue we had with anomaly detection.
      if (!is.null(na_fill_type) || !is.null(regressors_na_fill_type)) {
        # complete the date time with NA
        aggregated_data <- aggregated_data %>% complete_date("ds", time_unit = time_unit)
        # fill NAs in y with zoo
        aggregated_data <- aggregated_data %>% dplyr::mutate(y = fill_ts_na(y, ds, type = na_fill_type, val = na_fill_value))
        for (regressor_col in regressor_output_cols) {
          aggregated_data <- aggregated_data %>% dplyr::mutate(!!sym(regressor_col) := fill_ts_na(!!sym(regressor_col), ds, type = !!regressors_na_fill_type, val = !!regressors_na_fill_value))
        }
      }

      # If there is still NAs in regressor columns at this point after aggregation and possible fill, they have to be filtered out to avoid error.
      if (!is.null(regressor_output_cols)) {
        for (regressor_col in regressor_output_cols) {
          aggregated_data <- aggregated_data %>% dplyr::filter(!is.na(!!sym(regressor_col)))
        }
      }
  
      # For example, if time_unit is month or larger, having monthly.seasonality or seasonalities of shorter periods does not make sense.
      if (time_unit %in% c("year")) {
        yearly.seasonality <- FALSE
        quarterly.seasonality <- FALSE
        monthly.seasonality <- FALSE
        weekly.seasonality <- FALSE
        daily.seasonality <- FALSE
      }
      else if (time_unit %in% c("quarter")) {
        quarterly.seasonality <- FALSE
        monthly.seasonality <- FALSE
        weekly.seasonality <- FALSE
        daily.seasonality <- FALSE
      }
      else if (time_unit %in% c("month")) {
        monthly.seasonality <- FALSE
        weekly.seasonality <- FALSE
        daily.seasonality <- FALSE
      }
      else if (time_unit %in% c("week")) {
        weekly.seasonality <- FALSE
        daily.seasonality <- FALSE
      }
      else if (time_unit %in% c("day")) {
        daily.seasonality <- FALSE
      }
      # disabling this logic for now, since setting yearly.seasonality FALSE disables weekly.seasonality too.
      # if (time_unit == "year") { # if time_unit is year (the largest unit), having yearly.seasonality does not make sense.
      #   yearly.seasonality = FALSE
      # }
  
      if (test_mode) {
        # Remove end of aggregated_data as test data to make training data.
  
        # Fill aggregated_data$ds with missing data/time.
        # This is necessary to make forecast period correspond with test period in test mode when there is missing date/time in original aggregated_data$ds.
        # Note that this is only for the purpose of correctly determine where to start test period, and we remove those filled data once that purpose is met.
  
        # Create periodical sequence of time to fill missing date/time
        if (time_unit %in% c("hour", "minute", "second")) { # Use seq.POSIXt for unit smaller than day.
          ts <- seq.POSIXt(as.POSIXct(min(aggregated_data$ds)), as.POSIXct(max(aggregated_data$ds)), by=to_time_unit_for_seq(time_unit))
          if (lubridate::is.Date(aggregated_data$ds)) {
            ts <- as.Date(ts)
          }
        }
        else { # Use seq.Date for unit of day or larger. Using seq.POSIXct for month does not always give first day of month.
          ts <- seq.Date(as.Date(min(aggregated_data$ds)), as.Date(max(aggregated_data$ds)), by=to_time_unit_for_seq(time_unit))
          if (!lubridate::is.Date(aggregated_data$ds)) {
            ts <- as.POSIXct(ts)
          }
        }
        ts_df <- data.frame(ds=ts)
        # ts_df has to be the left-hand side to keep the row order according to time order.
        filled_aggregated_data <- dplyr::full_join(ts_df, aggregated_data, by = c("ds" = "ds"))
        
        training_data <- filled_aggregated_data
        if (periods > nrow(training_data)) {
          stop("The time period set for the Test period is longer than the entire data.")
        }
        training_data <- training_data %>% head(-periods)
  
        # We got correct set of training data by filling missing date/time,
        # but now, filter them out again.
        # By doing so, we affect future table, and skip prediction (interpolation)
        # for all missing date/time, which could be expensive if the training data is sparse.
        # keep the last row even if it does not have training data, to mark the end of training period, which is the start of test period.
        training_data <- training_data %>% dplyr::filter(!is.na(y) | row_number() == n())
      }
      else {
        training_data <- aggregated_data
      }

      if (nrow(training_data) < 2) {
        stop("The aggregated training data has less than 2 rows.")
      }

      # Since we aggregate input time series data with time unit, duration for 2 years worth of training data
      # is a little shorter than 2 years, which makes "auto" yearly seasonality of prophet to decide to turn off yearly seasonality.
      # For this reason, we enable yearly seasonality a little more leniently.
      if (yearly.seasonality == "auto") {
        training_days <- as.numeric(difftime(max(training_data$ds), min(training_data$ds), unit="days"))
        if (time_unit %in% c("quarter")) {
          if (training_days + 31*3 >= 365*2) {
            yearly.seasonality <- TRUE
          }
        }
        else if (time_unit %in% c("month")) {
          if (training_days + 31 >= 365*2) {
            yearly.seasonality <- TRUE
          }
        }
        else if (time_unit %in% c("week")) {
          if (training_days + 7 >= 365*2) {
            yearly.seasonality <- TRUE
          }
        }
        else if (time_unit %in% c("day")) {
          if (training_days + 1 >= 365*2) {
            yearly.seasonality <- TRUE
          }
        }
      }
  
      if (!is.null(cap) && is.data.frame(cap)) {
        # in this case, cap is the future data frame with cap, specified by user.
        # this is a back door to allow user to specify cap column.
        if (!is.null(cap$cap)) {
          growth <- "logistic"
        }
        else {
          growth <- "linear"
          # if future data frame is without cap, use it just as a future data frame.
        }
        m <- prophet::prophet(training_data, fit = FALSE, growth = growth,
                              daily.seasonality = daily.seasonality, weekly.seasonality = weekly.seasonality,
                              yearly.seasonality = yearly.seasonality, holidays = holidays_df, ...)
        if (quarterly.seasonality) {
          m <- prophet::add_seasonality(m,
                          "quarterly",
                          365.25/4,
                          fourier.order=8
                          )
        }
        if (monthly.seasonality) {
          m <- prophet::add_seasonality(m,
                          "monthly",
                          365.25/12,
                          fourier.order=6
                          )
        }
        # add regressors to the model.
        if (!is.null(regressor_output_cols)) {
          for (regressor in regressor_output_cols) {
            m <- add_regressor(m, regressor)
          }
        }
        m <- fit.prophet(m, training_data)
        forecast <- stats::predict(m, cap_df)
      }
      else {
        if (!is.null(cap)) { # set cap if it is there
          training_data[["cap"]] <- cap
          if (!is.null(floor)) { # set floor if it is there
            training_data[["floor"]] <- floor
          }
        }
        if (!is.null(cap)) { # if cap is set, use logistic. otherwise use linear.
          growth <- "logistic"
        }
        else {
          growth <- "linear"
        }
        m <- prophet::prophet(training_data, fit = FALSE, growth = growth,
                              daily.seasonality = daily.seasonality, weekly.seasonality = weekly.seasonality,
                              yearly.seasonality = yearly.seasonality, holidays = holidays_df, ...)
        # Default Fourier order for yearly is 10, and weekly is 3. Picked 8 for quarterly.
        # Picked 8 and 6 for quarterly and monthly so that they are inline with the above,
        # in that roughly the square of the Fourier order is within the same order as the days in the period.
        if (quarterly.seasonality) {
          m <- prophet::add_seasonality(m,
                          "quarterly",
                          365.25/4,
                          fourier.order=8
                          )
        }
        if (monthly.seasonality) {
          m <- prophet::add_seasonality(m,
                          "monthly",
                          365.25/12,
                          fourier.order=6
                          )
        }
        if (!is.null(regressor_output_cols)) {
          for (regressor in regressor_output_cols) {
            m <- add_regressor(m, regressor)
          }
        }
        if (!is.null(holiday_country_names)) {
          m <- add_country_holidays(m, country_name = holiday_country_names)
        }
        m <- fit.prophet(m, training_data)
        if (time_unit == "hour") {
          time_unit_for_future_dataframe = 3600
        }
        else if (time_unit == "minute") {
          time_unit_for_future_dataframe = 60
        }
        else if (time_unit == "second") {
          time_unit_for_future_dataframe = 1
        }
        else {
          time_unit_for_future_dataframe = time_unit
        }

        # make_future_dataframe can't handle periods=0. Work it around.
        if (periods == 0) {
          periods_ <- 1
        }
        else {
          periods_ <- periods
        }

        future <- prophet::make_future_dataframe(m, periods = periods_, freq = time_unit_for_future_dataframe, include_history = include_history) #includes past dates

        if (periods == 0) { # Remove the last extra row in case we passed period 1 instead of 0.
          future <- head(future, -1)
        }

        if (!is.null(regressor_output_cols)) {
          regressor_data <- aggregated_data %>%
            dplyr::select(-y) %>%
            dplyr::bind_rows(aggregated_future_data)
          if (lubridate::is.Date(regressor_data$ds)) { # make ds POSIXct so that inner_join works. TODO: is is possible that future$ds is POSIXlt??
            regressor_data$ds <- as.POSIXct(regressor_data$ds)
          }
          future <- future %>%
            # inner_join to keep only rows with regressor values.
            # this works for test mode too, since aggregated_future_data part is ignored by inner_join.
            dplyr::inner_join(regressor_data, by=c('ds'='ds'))
        }
        if (!is.null(cap)) { # set cap to future table too, if it is there
          future[["cap"]] <- cap
          if (!is.null(floor)) { # set floor if it is there
            future[["floor"]] <- floor
          }
        }
        forecast <- stats::predict(m, future)
      }
      # with prophet 0.2.1, now forecast$ds is POSIXct. Cast it to Date when necessary so that full_join works.
      if (lubridate::is.Date(aggregated_data$ds)) {
        forecast$ds <- as.Date(forecast$ds)
      }
  
      # Add is_test_data column before joining aggregated original data so that the end of forecast is correctly marked as test data.
      if (test_mode) {
        ret <- forecast %>% dplyr::mutate(is_test_data = dplyr::row_number() > n() - periods) # FALSE for training period, TRUE for test period.
      }
      else {
        ret <- forecast
      }
  
      # Join original aggregated dataframe to forecast dataframe.
      # Extra regressor columns will conflinct in names. We add _effect to the ones from forecast.
      # TODO: Can we safely assume that all conflicts are from extra regressors?
      # If there is future part of aggregated data, bind it too so that extra regressor values for future are also in the output.
      if (!is.null(aggregated_future_data)) {
        ret <- ret %>% dplyr::full_join(dplyr::bind_rows(aggregated_data, aggregated_future_data), by = c("ds" = "ds"), suffix = c("_effect", ""))
      }
      else {
        ret <- ret %>% dplyr::full_join(aggregated_data, by = c("ds" = "ds"), suffix = c("_effect", ""))
      }
      # drop cap_scaled column, which is just scaled capacity, which does not seem informative.
      if ("cap_scaled" %in% colnames(ret)) {
        ret <- ret %>% dplyr::select(-cap_scaled)
      }
      # TODO: Maybe we should take average when MCMC is used and there are multiple delta values for each channge point.
      if (!is.numeric(m$changepoints)) { # m$changepoints seems to become numeric single 0 when empty.
        changepoints_df <- data.frame(ds = m$changepoints, trend_change = m$params$delta[1,])
        # m$changepoints is POSIXct. Cast it to Date when original data (aggregated_data$ds) is Date so that left_join works.
        if (lubridate::is.Date(aggregated_data$ds)) {
          changepoints_df$ds <- as.Date(changepoints_df$ds)
        }
        ret <- ret %>% dplyr::left_join(changepoints_df, by = c("ds" = "ds"))
      }
      else {
        # there is no changepoint.
        ret <- ret %>% dplyr::mutate(trend_change = NA_real_)
      }

      if (time_unit %in% c("hour", "minute", "second")) { # Set original timezone back.
        ret <- ret %>% dplyr::mutate(ds = lubridate::with_tz(ds, tz = orig_tz))
      }
  
      # adjust order of output columns
      ret <- ret %>% dplyr::select(ds, y, yhat, yhat_lower, yhat_upper, trend, trend_lower, trend_upper, 
                                   matches('^yearly$'), matches('^yearly_lower$'), matches('^yearly_upper$'),
                                   matches('^quarterly$'), matches('^quarterly_lower$'), matches('^quarterly_upper$'),
                                   matches('^monthly$'), matches('^monthly_lower$'), matches('^monthly_upper$'),
                                   matches('^weekly$'), matches('^weekly_lower$'), matches('^weekly_upper$'),
                                   matches('^daily$'), matches('^daily_lower$'), matches('^daily_upper$'),
                                   matches('^cap.y$'), matches('^cap.x$'),
                                   dplyr::everything())
      if (test_mode) { # Bring is_test_data column to the last
        ret <- ret %>% dplyr::select(-is_test_data, is_test_data)
      }
  
      # revive original column names (time_col, value_col)
      if (time_col != "ds") { # if time_col happens to be "ds", do not do this, since it will make the column name "ds.new".
        colnames(ret)[colnames(ret) == "ds"] <- avoid_conflict(colnames(ret), time_col)
      }
      if (is.null(value_col)) {
        value_col <- "count"
      }
      if (value_col != "y") { # if value_col happens to be "y", do not do this, since it will make the column name "y.new".
        colnames(ret)[colnames(ret) == "y"] <- avoid_conflict(colnames(ret), value_col)
      }

      # Replace temporary regressor column names (r1, r2, ...) with the name that includes original column names.
      if (!is.null(regressor_final_output_cols)) {
        i <- 1
        for (output_col in regressor_final_output_cols) {
          tmp_col <- paste0("r", i)
          tmp_effect_col <- paste0("r", i, "_effect")
          tmp_upper_col <- paste0("r", i, "_upper")
          tmp_lower_col <- paste0("r", i, "_lower")
          output_effect_col <- paste0(output_col, "_effect")
          output_upper_col <- paste0(output_col, "_upper")
          output_lower_col <- paste0(output_col, "_lower")
          colnames(ret)[colnames(ret) == tmp_col] <- avoid_conflict(colnames(ret), output_col)
          colnames(ret)[colnames(ret) == tmp_effect_col] <- avoid_conflict(colnames(ret), output_effect_col)
          colnames(ret)[colnames(ret) == tmp_upper_col] <- avoid_conflict(colnames(ret), output_upper_col)
          colnames(ret)[colnames(ret) == tmp_lower_col] <- avoid_conflict(colnames(ret), output_lower_col)
          i <- i + 1
        }
      }
  
      # adjust column name style
      colnames(ret)[colnames(ret) == "yhat"] <- avoid_conflict(colnames(ret), "forecasted_value")
      colnames(ret)[colnames(ret) == "yhat_upper"] <- avoid_conflict(colnames(ret), "forecasted_value_high")
      colnames(ret)[colnames(ret) == "yhat_lower"] <- avoid_conflict(colnames(ret), "forecasted_value_low")
      colnames(ret)[colnames(ret) == "trend_upper"] <- avoid_conflict(colnames(ret), "trend_high")
      colnames(ret)[colnames(ret) == "trend_lower"] <- avoid_conflict(colnames(ret), "trend_low")
      colnames(ret)[colnames(ret) == "yearly_upper"] <- avoid_conflict(colnames(ret), "yearly_high")
      colnames(ret)[colnames(ret) == "yearly_lower"] <- avoid_conflict(colnames(ret), "yearly_low")
      colnames(ret)[colnames(ret) == "weekly_upper"] <- avoid_conflict(colnames(ret), "weekly_high")
      colnames(ret)[colnames(ret) == "weekly_lower"] <- avoid_conflict(colnames(ret), "weekly_low")
      colnames(ret)[colnames(ret) == "cap.x"] <- avoid_conflict(colnames(ret), "cap_forecast")
      colnames(ret)[colnames(ret) == "cap.y"] <- avoid_conflict(colnames(ret), "cap_model")
      if (output == "data") { # Pre-5.5 backward compatibility mode.
        ret
      }
      else {
        regressor_name_map <- regressor_final_output_cols
        names(regressor_name_map) <- regressor_output_cols
        model <- list(result=ret, model=m, test_mode=test_mode, value_col=value_col, regressor_name_map=regressor_name_map)
        class(model) <- c("prophet_exploratory", class(model))
        model
      }
    }, error = function(e){
      if(length(grouped_col) > 0) {
        # keep going if the error is caused by subset of
        # grouped data frame, to show result of data frames that succeed.
        # For debugging purpose, return one row with error message in note column.
        if (output == "data") { # Pre-5.5 backward compatibility mode.
          data.frame(note = e$message)
        }
        else {
          class(e) <- c("prophet_exploratory", class(e))
          e
        }
      } else {
        stop(e)
      }
    })
  }

  # Calculation is executed in each group.
  # Storing the result in this name_col and
  # unnesting the result.
  # name_col is not conflicting with grouping columns
  # thanks to avoid_conflict that is used before,
  # this doesn't overwrite grouping columns.
  if (output == "data") { # Pre-5.5 backward compatibility mode.
    tmp_col <- avoid_conflict(colnames(df), "tmp_col")
    ret <- df %>%
      dplyr::do_(.dots=setNames(list(~do_prophet_each(.)), tmp_col)) %>%
      dplyr::ungroup()
    ret <- ret %>% unnest_with_drop(!!rlang::sym(tmp_col))
    if (length(grouped_col) > 0) {
      ret <- ret %>% dplyr::group_by(!!!rlang::syms(grouped_col))
    }
    ret
  }
  else {
    do_on_each_group(df, do_prophet_each, name = "model", with_unnest = FALSE)
  }
}

# Summarise one prophet_exploratory result as a single-row data frame of
# accuracy metrics.
#
# x: a "prophet_exploratory" object created per group by do_prophet_. It is
#    either the model list (result, model, test_mode, value_col,
#    regressor_name_map) or a captured error condition.
# Returns:
#   - for an error, a data frame whose Note column holds the error message;
#   - in test mode, RMSE / MAE / MAPE / MASE / R Squared computed against the
#     is_test_data flag, plus training and test row counts;
#   - otherwise, RMSE / MAE / MAPE / R Squared over rows with an observed
#     value, plus the number of such rows.
#' @export
glance.prophet_exploratory <- function(x) {
  if (inherits(x, "error")) {
    # A failed group: report why instead of metrics.
    return(data.frame(Note = x$message))
  }
  if (x$test_mode) {
    actual <- rlang::sym(x$value_col)
    x$result %>%
      dplyr::summarize(
        RMSE = exploratory::rmse(!!actual, forecasted_value, is_test_data),
        MAE = exploratory::mae(!!actual, forecasted_value, is_test_data),
        `MAPE (Ratio)` = exploratory::mape(!!actual, forecasted_value, is_test_data),
        MASE = exploratory::mase(!!actual, forecasted_value, is_test_data),
        `R Squared` = r_squared(!!actual, forecasted_value, is_test_data = is_test_data),
        `Number of Rows for Training` = sum(!is_test_data),
        `Number of Rows for Test` = sum(is_test_data)
      )
  } else {
    # No hold-out period: score every row that has an observed value.
    actual <- rlang::sym(x$value_col)
    x$result %>%
      dplyr::summarize(
        RMSE = exploratory::rmse(!!actual, forecasted_value, !is.na(!!actual)),
        MAE = exploratory::mae(!!actual, forecasted_value, !is.na(!!actual)),
        `MAPE (Ratio)` = exploratory::mape(!!actual, forecasted_value, !is.na(!!actual)),
        `R Squared` = r_squared(!!actual, forecasted_value, is_test_data = !is.na(!!actual)),
        `Number of Rows` = sum(!is.na(!!actual))
      )
  }
}

# Tidy one prophet_exploratory result.
#
# x:    a "prophet_exploratory" object created per group by do_prophet_. It is
#       either the model list or a captured error condition.
# type: "result" returns the joined forecast data frame stored in x$result.
#       "coef" returns one row per regressor / seasonality / holiday component
#       (columns Variable and Importance). Importance is the standard
#       deviation of that component's effect over the training rows. Any other
#       value returns NULL invisibly.
# An error object yields an empty data frame, so a failed group adds no rows.
#' @export
tidy.prophet_exploratory <- function(x, type="result") {
  if (inherits(x, "error")) {
    # Filter error case. We might need to add glance to display the error message.
    return(data.frame())
  }
  if (type == "result") {
    return(x$result)
  }
  if (type == "coef") {
    # Importance is estimated on training rows only. In test mode these are
    # rows not flagged as test data; otherwise they are rows with an observed
    # value (the value column falls back to "count" when unset).
    if (x$test_mode) {
      training <- x$result %>% dplyr::filter(!is_test_data)
    } else {
      observed <- rlang::sym(if (is.null(x$value_col)) "count" else x$value_col)
      training <- x$result %>% dplyr::filter(!is.na(!!observed))
    }

    # Keep only per-component effect columns: regressor effects (suffixed
    # with _effect) and the seasonality / holiday components. For a regressor,
    # the SD of its effect equals |beta| by definition.
    # Reference: https://github.com/facebook/prophet/issues/928
    components <- training %>%
      dplyr::select(matches('(_effect$|^yearly$|^quarterly$|^monthly$|^weekly$|^daily$|^hourly$|^holidays$)'))

    # With no columns the summary would fail, and a ranking of a single
    # column is pointless. The chart renders an empty data frame as "no data".
    if (ncol(components) < 2) {
      return(data.frame())
    }

    components %>%
      dplyr::summarise_all(.funs = ~sd(., na.rm = TRUE)) %>%
      tidyr::pivot_longer(everything(), names_to = 'Variable', values_to = 'Importance') %>%
      dplyr::mutate(Variable = dplyr::recode(Variable, yearly = 'Yearly', quarterly = 'Quarterly', monthly = 'Monthly', weekly = 'Weekly', daily = 'Daily', hourly = 'Hourly', holidays = 'Holidays')) %>%
      dplyr::mutate(Variable = stringr::str_remove(Variable, '_effect$'))
  }
}
# exploratory-io/exploratory_func documentation built on April 23, 2024, 9:15 p.m.