# R/build_forecasting_agent.R

###############################################################################
## 1) SIMPLE STATE GRAPH IMPLEMENTATION
###############################################################################

make_node <- function(func, name = NULL) {
  list(func = func, name = name)
}

make_edge <- function(from, to, condition = NULL, label = NULL) {
  list(from = from, to = to, condition = condition, label = label)
}

StateGraph <- function() {
  graph_env <- new.env(parent = emptyenv())
  graph_env$nodes <- list()
  graph_env$edges <- list()
  graph_env$entry_point <- NULL

  graph_env$add_node <- function(name, func) {
    graph_env$nodes[[name]] <- make_node(func, name)
  }

  graph_env$add_edge <- function(from, to) {
    edge <- make_edge(from, to)
    graph_env$edges <- c(graph_env$edges, list(edge))
  }

  graph_env$add_conditional_edges <- function(node_name, condition_fun, mapping_list) {
    for (lbl in names(mapping_list)) {
      e <- make_edge(
        from = node_name,
        to = mapping_list[[lbl]],
        condition = condition_fun,
        label = lbl
      )
      graph_env$edges <- c(graph_env$edges, list(e))
    }
  }

  graph_env$set_entry_point <- function(node_name) {
    graph_env$entry_point <- node_name
  }

  END_NODE_NAME <- "__end__"

  graph_env$compile <- function(checkpointer = NULL) {
    function(state, verbose = FALSE) {
      current_node <- if (!is.null(state$current_node)) {
        state$current_node
      } else {
        graph_env$entry_point
      }

      while (!identical(current_node, END_NODE_NAME)) {
        node_obj <- graph_env$nodes[[current_node]]
        if (is.null(node_obj)) {
          stop(sprintf("Node '%s' not found in graph.", current_node))
        }

        # Execute node with verbose parameter if the function accepts it
        if ("verbose" %in% names(formals(node_obj$func))) {
          result <- node_obj$func(state, verbose = verbose)
        } else {
          result <- node_obj$func(state)
        }

        # Merge returned list elements
        if (is.list(result)) {
          for (n in names(result)) {
            state[[n]] <- result[[n]]
          }
        }

        # Handle Command-like objects with goto & update
        if (!is.null(result$goto)) {
          next_node <- result$goto
          if (is.list(result$update)) {
            for (k in names(result$update)) {
              state[[k]] <- result$update[[k]]
            }
          }

          if (identical(next_node, END_NODE_NAME)) {
            current_node <- END_NODE_NAME
            break
          } else {
            current_node <- next_node
            if (!is.null(checkpointer)) {
              checkpointer(state, current_node)
            }
            next
          }
        }

        # Look for edges if no direct goto
        edges_from_node <- Filter(function(e) e$from == current_node, graph_env$edges)
        if (length(edges_from_node) == 0) {
          current_node <- END_NODE_NAME
          break
        }

        # Handle single unconditional edge
        if (length(edges_from_node) == 1 && is.null(edges_from_node[[1]]$condition)) {
          current_node <- edges_from_node[[1]]$to
          if (identical(current_node, END_NODE_NAME)) break
          if (!is.null(checkpointer)) checkpointer(state, current_node)
          next
        }

        # Handle conditional edges
        chosen_label <- edges_from_node[[1]]$condition(state)

        edge_matched <- NULL
        for (e in edges_from_node) {
          if (!is.null(e$label) && identical(e$label, chosen_label)) {
            edge_matched <- e
            break
          }
        }
        if (is.null(edge_matched)) {
          stop("No matching edge label found!")
        }

        current_node <- edge_matched$to
        if (identical(current_node, END_NODE_NAME)) break
        if (!is.null(checkpointer)) checkpointer(state, current_node)
      }

      state$current_node <- END_NODE_NAME
      invisible(state)
    }
  }

  list(
    add_node              = graph_env$add_node,
    add_edge              = graph_env$add_edge,
    add_conditional_edges = graph_env$add_conditional_edges,
    set_entry_point       = graph_env$set_entry_point,
    compile               = graph_env$compile,
    END_NODE_NAME         = END_NODE_NAME
  )
}
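
# A minimal usage sketch of the graph API above (illustrative only; the node
# names and state fields here are hypothetical):
#
#   g <- StateGraph()
#   g$add_node("greet", function(state) list(msg = paste("Hello,", state$name)))
#   g$add_node("shout", function(state) list(msg = toupper(state$msg)))
#   g$add_edge("greet", "shout")
#   g$add_edge("shout", g$END_NODE_NAME)
#   g$set_entry_point("greet")
#   run <- g$compile()
#   run(list(name = "world"))$msg   # "HELLO, WORLD"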

###############################################################################
## 2) HELPER FUNCTIONS
###############################################################################

interrupt <- function(value) {
  message("\n", value, "\n")
  readline("Enter your response: ")
}

make_command <- function(goto = NULL, update = list()) {
  list(goto = goto, update = update)
}
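
# Example of how a node can bypass the declared edges (hypothetical node):
# returning a command-like list makes the compiled graph jump straight to
# `goto` and merge `update` into the state.
#
#   node_check <- function(state) {
#     if (isTRUE(state$skip_review)) {
#       make_command(goto = "__end__", update = list(status = "skipped"))
#     } else {
#       make_command(goto = "human_review")
#     }
#   }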

###############################################################################
## 3) NODE FUNCTIONS
###############################################################################

# A simple summary function for a data frame
get_dataframe_summary <- function(df, n_sample = 30, skip_stats = FALSE) {
  info <- capture.output(str(df))
  summ <- capture.output(summary(df))
  head_txt <- capture.output(head(df, n_sample))
  types <- paste(sapply(df, function(col) paste(class(col), collapse = ", ")), collapse = ", ")

  if (!skip_stats) {
    summary_text <- paste(
      sprintf("Shape: %d rows x %d columns", nrow(df), ncol(df)),
      paste("Column types:", types),
      "Head:",
      paste(head_txt, collapse = "\n"),
      "Summary:",
      paste(summ, collapse = "\n"),
      "Structure:",
      paste(info, collapse = "\n"),
      sep = "\n\n"
    )
  } else {
    summary_text <- paste(
      sprintf("Shape: %d rows x %d columns", nrow(df), ncol(df)),
      paste("Column types:", types),
      "Head:",
      paste(head_txt, collapse = "\n"),
      sep = "\n\n"
    )
  }
  summary_text
}
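
# Usage sketch (any data frame works; mtcars is just a stand-in):
#
#   cat(get_dataframe_summary(mtcars, n_sample = 5, skip_stats = TRUE))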

###############################################################################
## PLOTTING FUNCTION
###############################################################################

plot_forecast <- function(data, mode = "light", line_width = 2) {

  # Check for required packages (plotly, ggplot2)
  get_suggested("plotly")
  get_suggested("ggplot2")

  # Convert input to data frame if needed
  if (!is.data.frame(data)) {
    data <- as.data.frame(data)
  }

  # Check for 'id_col' in the data frame
  id_col <- if ("item_id" %in% names(data)) "item_id" else "id"
  if (!id_col %in% names(data)) {
    data[[id_col]] <- seq_len(nrow(data))
  }

  # Define columns
  date_col <- "date"
  value_col <- "value"
  conf_lo_col <- "conf_lo"
  conf_hi_col <- "conf_hi"

  # Define colors
  if (mode == "dark") {
    background_color <- "black"
    text_color <- "white"
    fill_color <- "rgba(255, 165, 0, 0.3)"
    line_color <- "orange"
    actual_color <- "gray"
  } else {
    background_color <- "white"
    text_color <- "black"
    fill_color <- "rgba(173, 216, 230, 0.7)"
    line_color <- "blue"
    actual_color <- "black"
  }

  # Create base plot
  p <- plotly::plot_ly()

  unique_ids <- unique(data[[id_col]])

  for (i in seq_along(unique_ids)) {
    item_id <- unique_ids[i]
    group <- data[data[[id_col]] == item_id, ]

    # Infer actual vs forecast based on NA in confidence intervals
    actual_data <- group[is.na(group[[conf_lo_col]]) & is.na(group[[conf_hi_col]]), ]
    forecast_data <- group[!is.na(group[[conf_lo_col]]) | !is.na(group[[conf_hi_col]]), ]

    # Add actual line
    p <- p %>% plotly::add_trace(
      data = actual_data,
      x = ~get(date_col),
      y = ~get(value_col),
      type = 'scatter',
      mode = 'lines',
      name = paste('Actual -', item_id),
      line = list(color = actual_color, width = line_width),
      visible = i == 1
    )

    # Add forecast line
    p <- p %>% plotly::add_trace(
      data = forecast_data,
      x = ~get(date_col),
      y = ~get(value_col),
      type = 'scatter',
      mode = 'lines',
      name = paste('Forecast -', item_id),
      line = list(color = line_color, width = line_width),
      visible = i == 1
    )

    # Add upper CI
    p <- p %>% plotly::add_trace(
      data = forecast_data,
      x = ~get(date_col),
      y = ~get(conf_hi_col),
      type = 'scatter',
      mode = 'lines',
      name = paste('Conf_High -', item_id),
      line = list(width = 0),
      showlegend = FALSE,
      visible = i == 1
    )

    # Add lower CI
    p <- p %>% plotly::add_trace(
      data = forecast_data,
      x = ~get(date_col),
      y = ~get(conf_lo_col),
      type = 'scatter',
      mode = 'lines',
      name = paste('Conf_Low -', item_id),
      fill = 'tonexty',
      fillcolor = fill_color,
      line = list(width = 0),
      showlegend = FALSE,
      visible = i == 1
    )
  }

  # Dropdown (4 traces per ID)
  dropdown_buttons <- lapply(seq_along(unique_ids), function(i) {
    vis <- rep(FALSE, length(unique_ids) * 4)
    vis[((i - 1) * 4 + 1):((i - 1) * 4 + 4)] <- TRUE
    list(
      label = as.character(unique_ids[i]),
      method = "update",
      args = list(
        list(visible = vis),
        list(title = paste("Forecast -", unique_ids[i]))
      )
    )
  })

  # Layout
  p %>% plotly::layout(
    title = paste("Forecast -", unique_ids[1]),
    xaxis = list(title = date_col),
    yaxis = list(title = value_col),
    plot_bgcolor = background_color,
    paper_bgcolor = background_color,
    font = list(color = text_color),
    updatemenus = list(
      list(
        active = 0,
        buttons = dropdown_buttons,
        x = 0,
        xanchor = "left",
        y = 1.15,
        yanchor = "top"
      )
    ),
    showlegend = FALSE
  )
}
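
# Usage sketch (assumes plotly and ggplot2 are installed). The input needs
# `date`, `value`, `conf_lo`, `conf_hi` and an `id` (or `item_id`) column;
# rows with NA confidence bounds are treated as historical actuals:
#
#   toy <- data.frame(
#     id      = "series_1",
#     date    = seq(as.Date("2024-01-01"), by = "month", length.out = 15),
#     value   = c(rnorm(12, 100, 5), 101, 102, 103),
#     conf_lo = c(rep(NA, 12), 95, 96, 97),
#     conf_hi = c(rep(NA, 12), 107, 108, 109)
#   )
#   plot_forecast(toy, mode = "light", line_width = 2)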

###############################################################################
## FORECASTING FUNCTION
###############################################################################

forecast_ts <- function(
    data,
    value,
    date,
    group      = NULL,  # ID column (e.g., product_id, store_id)
    horizon    = NULL,  # Determined automatically if not specified
    conf_level   = 0.95,  # Confidence level for prediction intervals
    ...
) {

  # 1. Ensure required packages (Suggests:) are available
  check_forecasting_dependencies()

  # -- 3. Column symbols -------------------------------------------------------
  value_sym <- rlang::sym(value)
  date_sym  <- rlang::sym(date)
  group_sym <- if (!is.null(group)) rlang::sym(group) else NULL

  # -- 4. Horizon detection ----------------------------------------------------
  if (is.null(horizon)) {
    dates <- tryCatch(
      if (is.character(data[[date]]))
        lubridate::parse_date_time(data[[date]],
                                   orders = c("ymd", "dmy", "mdy", "Ymd"))
      else data[[date]],
      error = function(e) NULL
    )

    if (!is.null(dates) && length(dates) > 1) {
      time_diff <- as.numeric(difftime(dates[2], dates[1], units = "days"))
      horizon <- dplyr::case_when(
        time_diff <= 1                    ~ 30,   # daily
        time_diff >= 6  & time_diff <= 8  ~ 13,   # weekly
        time_diff >= 28 & time_diff <= 31 ~ 12,   # monthly
        time_diff >= 89 & time_diff <= 92 ~ 8,    # quarterly
        time_diff >= 360 & time_diff <= 370 ~ 5,  # yearly
        nrow(data) < 100                 ~ ceiling(nrow(data) * 0.20),
        TRUE                              ~ ceiling(nrow(data) * 0.10)
      )
    } else horizon <- 12
  }

  # -- 5. Standardised naming & aggregation -------------------------------------

  if (!is.null(group)) {
    # Warn if group column is numeric
    if (is.numeric(data[[group]])) {
      warning("Group column is numeric. Converting to character for proper grouping.")
    }

    data_renamed <- data %>%
      rename(
        value_col = !!value_sym,
        date_col  = !!date_sym,
        group_col = !!group_sym
      ) %>%
      mutate(group_col = as.character(group_col))
  } else {
    data_renamed <- data %>%
      rename(
        value_col = !!value_sym,
        date_col  = !!date_sym
      ) %>%
      mutate(group_col = "ALL_GROUPS")
  }

  # -- 5b. Aggregate to ensure one row per group-date ---------------------------
  data_renamed <- data_renamed %>%
    group_by(group_col, date_col) %>%
    summarise(value_col = sum(value_col, na.rm = TRUE), .groups = "drop")

  # -- 6. Future frame & split -------------------------------------------------
  full_data_tbl <- data_renamed %>%
    select(group_col, date_col, value_col) %>%
    group_by(group_col) %>%
    future_frame(.date_var = date_col, .length_out = horizon, .bind_data = TRUE) %>%
    ungroup() %>%
    mutate(id = forcats::fct_drop(group_col))

  data_prepared_tbl <- full_data_tbl %>% filter(!is.na(value_col))
  future_tbl        <- full_data_tbl %>% filter(is.na(value_col))

  splits <- data_prepared_tbl %>%
    timetk::time_series_split(
      date_var   = date_col,
      assess     = min(horizon, floor(nrow(data_prepared_tbl) * 0.20)),
      cumulative = TRUE
    )

  # -- 7. Recipes --------------------------------------------------------------
  recipe_spec_1 <- recipes::recipe(value_col ~ ., training(splits)) %>%
    recipes::step_rm(matches("^$")) %>%  # Remove unnamed columns
    timetk::step_timeseries_signature(date_col) %>%
    recipes::step_rm(matches("(.iso$)|(.xts$)|(day)|(hour)|(minute)|(second)|(am.pm)")) %>%
    recipes::step_zv(all_predictors()) %>%  # This removes zero-variance predictors
    recipes::step_normalize(all_numeric_predictors(), -all_outcomes()) %>%
    recipes::step_mutate(date_col_week = factor(date_col_week, ordered = TRUE)) %>%
    recipes::step_dummy(all_nominal(), one_hot = TRUE)

  recipe_spec_2 <- recipe_spec_1 %>% update_role(date_col, new_role = "ID")

  # -- 8. Models & workflows ---------------------------------------------------
  model_list <- list()

  ## 8.1 Prophet
  wflw_fit_prophet <- workflow() %>%
    workflows::add_model(
      prophet_reg(
        seasonality_daily  = FALSE,
        seasonality_weekly = ifelse(horizon >= 7, TRUE, FALSE),
        seasonality_yearly = TRUE
      ) %>% set_engine("prophet") %>% set_mode("regression")
    ) %>%
    workflows::add_recipe(recipe_spec_1) %>%
    fit(training(splits))
  model_list <- c(model_list, list(wflw_fit_prophet))

  ## 8.2 XGBoost
  wflw_fit_xgboost <- workflow() %>%
    workflows::add_model(
      boost_tree() %>% set_engine("xgboost") %>% set_mode("regression")
    ) %>%
    workflows::add_recipe(recipe_spec_2) %>%
    fit(training(splits))
  model_list <- c(model_list, list(wflw_fit_xgboost))

  ## 8.3 Random Forest
  wflw_fit_rf <- workflow() %>%
    workflows::add_model(
      rand_forest() %>% set_engine("ranger") %>% set_mode("regression")
    ) %>%
    workflows::add_recipe(recipe_spec_2) %>%
    fit(training(splits))
  model_list <- c(model_list, list(wflw_fit_rf))

  ## 8.4 SVM
  wflw_fit_svm <- workflow() %>%
    workflows::add_model(
      svm_rbf() %>% set_engine("kernlab") %>% set_mode("regression")
    ) %>%
    workflows::add_recipe(recipe_spec_2) %>%
    fit(training(splits))
  model_list <- c(model_list, list(wflw_fit_svm))

  ## 8.5 Prophet Boost
  wflw_fit_prophet_boost <- workflow() %>%
    workflows::add_model(
      prophet_boost(
        seasonality_daily  = FALSE,
        seasonality_weekly = ifelse(horizon >= 7, TRUE, FALSE),
        seasonality_yearly = TRUE
      ) %>% set_engine("prophet_xgboost") %>% set_mode("regression")
    ) %>%
    workflows::add_recipe(recipe_spec_1) %>%
    fit(training(splits))
  model_list <- c(model_list, list(wflw_fit_prophet_boost))

  # -- 9. Ensemble (mean) ------------------------------------------------------
  submodels_tbl        <- do.call(modeltime_table, model_list)
  ensemble_fit_mean    <- submodels_tbl %>% ensemble_average(type = "mean")
  ensemble_tbl         <- modeltime_table(ensemble_fit_mean)
  ensemble_calibrated  <- ensemble_tbl %>% modeltime_calibrate(testing(splits))
  ensemble_refit       <- ensemble_calibrated %>% modeltime_refit(data_prepared_tbl)

  # -- 10. Forecast ------------------------------------------------------------
  forecast_tbl <- ensemble_refit %>%
    modeltime::modeltime_forecast(
      new_data    = future_tbl,
      actual_data = data_prepared_tbl,
      keep_data   = TRUE,
      conf_level  = conf_level
    ) %>%
    select(id, .index, .value, .conf_lo, .conf_hi) %>%
    rename(
      date    = .index,
      value   = .value,
      conf_lo = .conf_lo,
      conf_hi = .conf_hi
    )

  return(forecast_tbl)
}
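
# Usage sketch (assumes the modeltime/timetk stack listed above is installed;
# `walmart_sales_weekly` ships with the timetk package):
#
#   sales <- timetk::walmart_sales_weekly
#   fcst  <- forecast_ts(
#     data       = sales,
#     value      = "Weekly_Sales",
#     date       = "Date",
#     group      = "id",
#     horizon    = 13,      # 13 weeks ahead
#     conf_level = 0.95
#   )
#   head(fcst)              # columns: id, date, value, conf_lo, conf_hi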

###############################################################################
## PACKAGES FUNCTION
###############################################################################

check_forecasting_dependencies <- function() {
  # Base / utils
  capture.output <- get_suggested("utils", "capture.output")
  head <- get_suggested("utils", "head")
  na.omit <- get_suggested("stats", "na.omit")

  # dplyr
  mutate <- get_suggested("dplyr", "mutate")
  filter <- get_suggested("dplyr", "filter")
  group_by <- get_suggested("dplyr", "group_by")
  ungroup <- get_suggested("dplyr", "ungroup")
  select <- get_suggested("dplyr", "select")
  rename <- get_suggested("dplyr", "rename")
  summarise <- get_suggested("dplyr", "summarise")
  bind_rows <- get_suggested("dplyr", "bind_rows")

  # tidyr, purrr, magrittr, tibble, forcats
  matches <- get_suggested("tidyr", "matches")
  map <- get_suggested("purrr", "map")
  `%>%` <- get_suggested("magrittr", "%>%")
  tibble <- get_suggested("tibble", "tibble")
  fct_drop <- get_suggested("forcats", "fct_drop")

  # lubridate
  parse_date_time <- get_suggested("lubridate", "parse_date_time")
  year <- get_suggested("lubridate", "year")
  month <- get_suggested("lubridate", "month")
  day <- get_suggested("lubridate", "day")

  # timetk
  future_frame <- get_suggested("timetk", "future_frame")

  # rlang
  sym <- get_suggested("rlang", "sym")

  # jsonlite
  fromJSON <- get_suggested("jsonlite", "fromJSON")
  toJSON <- get_suggested("jsonlite", "toJSON")

  # recipes
  recipe <- get_suggested("recipes", "recipe")
  step_rm <- get_suggested("recipes", "step_rm")
  step_timeseries_signature <- get_suggested("timetk", "step_timeseries_signature")
  step_zv <- get_suggested("recipes", "step_zv")
  step_normalize <- get_suggested("recipes", "step_normalize")
  step_mutate <- get_suggested("recipes", "step_mutate")
  step_dummy <- get_suggested("recipes", "step_dummy")
  update_role <- get_suggested("recipes", "update_role")
  all_predictors <- get_suggested("recipes", "all_predictors")
  all_outcomes <- get_suggested("recipes", "all_outcomes")
  all_numeric_predictors <- get_suggested("recipes", "all_numeric_predictors")
  all_nominal <- get_suggested("recipes", "all_nominal")

  # parsnip
  boost_tree <- get_suggested("parsnip", "boost_tree")
  rand_forest <- get_suggested("parsnip", "rand_forest")
  svm_rbf <- get_suggested("parsnip", "svm_rbf")
  set_engine <- get_suggested("parsnip", "set_engine")
  set_mode <- get_suggested("parsnip", "set_mode")

  # modeltime model specs (prophet_reg / prophet_boost are exported by modeltime)
  prophet_reg <- get_suggested("modeltime", "prophet_reg")
  prophet_boost <- get_suggested("modeltime", "prophet_boost")

  # workflows
  workflow <- get_suggested("workflows", "workflow")
  add_model <- get_suggested("workflows", "add_model")
  add_recipe <- get_suggested("workflows", "add_recipe")

  # rsample / timetk (time_series_split is exported by timetk)
  training <- get_suggested("rsample", "training")
  testing <- get_suggested("rsample", "testing")
  time_series_split <- get_suggested("timetk", "time_series_split")

  # yardstick
  rmse <- get_suggested("yardstick", "rmse")

  # modeltime
  modeltime_table <- get_suggested("modeltime", "modeltime_table")
  modeltime_calibrate <- get_suggested("modeltime", "modeltime_calibrate")
  modeltime_refit <- get_suggested("modeltime", "modeltime_refit")
  modeltime_forecast <- get_suggested("modeltime", "modeltime_forecast")

  # modeltime.ensemble
  ensemble_average <- get_suggested("modeltime.ensemble", "ensemble_average")

  # plotly
  plot_ly <- get_suggested("plotly", "plot_ly")
  add_trace <- get_suggested("plotly", "add_trace")
  layout <- get_suggested("plotly", "layout")

  # Model engines
  xgb_train <- get_suggested("xgboost", "xgb.train")
  ranger <- get_suggested("ranger", "ranger")
  ksvm <- get_suggested("kernlab", "ksvm")

  invisible(TRUE)
}

###############################################################################
## GENERIC GRAPH BUILDER (R equivalent of the Python create_coding_agent_graph)
###############################################################################

create_coding_agent_graph <- function(
    node_functions,
    recommended_steps_node_name,
    create_code_node_name,
    execute_code_node_name,
    fix_code_node_name,
    explain_code_node_name,
    error_key,
    max_retries_key = "max_retries",
    retry_count_key = "retry_count",
    human_validation = FALSE,
    human_review_node_name = "human_review",
    checkpointer = NULL,
    bypass_recommended_steps = FALSE,
    bypass_explain_code = FALSE
) {

  workflow <- StateGraph()

  # Always add create, execute, and fix nodes
  workflow$add_node(create_code_node_name, node_functions[[create_code_node_name]])
  workflow$add_node(execute_code_node_name, node_functions[[execute_code_node_name]])
  workflow$add_node(fix_code_node_name, node_functions[[fix_code_node_name]])

  # Conditionally add the recommended-steps node
  if (!bypass_recommended_steps) {
    workflow$add_node(recommended_steps_node_name, node_functions[[recommended_steps_node_name]])
  }

  # Conditionally add the human review node
  if (human_validation) {
    workflow$add_node(human_review_node_name, node_functions[[human_review_node_name]])
  }

  # Conditionally add the explanation node
  if (!bypass_explain_code) {
    workflow$add_node(explain_code_node_name, node_functions[[explain_code_node_name]])
  }

  # Set the entry point
  entry_point <- if (bypass_recommended_steps) create_code_node_name else recommended_steps_node_name
  workflow$set_entry_point(entry_point)

  if (!bypass_recommended_steps) {
    workflow$add_edge(recommended_steps_node_name, create_code_node_name)
  }

  workflow$add_edge(create_code_node_name, execute_code_node_name)
  workflow$add_edge(fix_code_node_name, execute_code_node_name)

  # Helper to check for error and retry possibility
  error_and_can_retry <- function(s) {
    err <- s[[error_key]]
    retr <- s[[retry_count_key]]
    maxr <- s[[max_retries_key]]
    !is.null(err) && !is.null(retr) && !is.null(maxr) && (retr < maxr)
  }

  if (human_validation) {
    workflow$add_conditional_edges(
      execute_code_node_name,
      function(s) {
        if (error_and_can_retry(s)) "fix_code" else "human_review"
      },
      list(
        human_review = human_review_node_name,
        fix_code = fix_code_node_name
      )
    )
  } else {
    if (!bypass_explain_code) {
      workflow$add_conditional_edges(
        execute_code_node_name,
        function(s) {
          if (error_and_can_retry(s)) "fix_code" else "explain_code"
        },
        list(
          fix_code = fix_code_node_name,
          explain_code = explain_code_node_name
        )
      )
    } else {
      workflow$add_conditional_edges(
        execute_code_node_name,
        function(s) {
          if (error_and_can_retry(s)) "fix_code" else "END"
        },
        list(
          fix_code = fix_code_node_name,
          END = workflow$END_NODE_NAME
        )
      )
    }
  }

  if (!bypass_explain_code) {
    workflow$add_edge(explain_code_node_name, workflow$END_NODE_NAME)
  }

  # Compile the workflow
  if (human_validation && !is.null(checkpointer)) {
    app <- workflow$compile(checkpointer = checkpointer)
  } else {
    app <- workflow$compile()
  }
  app
}
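
# A minimal wiring sketch with stub nodes (hypothetical node implementations;
# the real agent plugs in the LLM-backed nodes defined below):
#
#   stub_nodes <- list(
#     recommend_steps = function(state) list(recommended_steps = "..."),
#     create_code     = function(state) list(code = "..."),
#     execute_code    = function(state) list(result = "ok"),
#     fix_code        = function(state) list(retry_count = state$retry_count + 1),
#     explain_code    = function(state) list(report = "done")
#   )
#   app <- create_coding_agent_graph(
#     node_functions              = stub_nodes,
#     recommended_steps_node_name = "recommend_steps",
#     create_code_node_name       = "create_code",
#     execute_code_node_name      = "execute_code",
#     fix_code_node_name          = "fix_code",
#     explain_code_node_name      = "explain_code",
#     error_key                   = "error"
#   )
#   final_state <- app(list(retry_count = 0, max_retries = 3))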

###############################################################################
## NODE FUNCTIONS FOR TIME SERIES FORECASTING
###############################################################################

node_recommend_forecasting_steps <- function(model, verbose = FALSE) {
  function(state) {

    # -- 1. Helper Functions
    `%||%` <- function(a, b) {
      # -- 1. Input validation
      if (missing(a) || missing(b)) {
        stop("Both arguments must be provided to the %||% operator")
      }

      # -- 2. NULL or zero-length values fall back to `b`
      if (is.null(a) || length(a) == 0) return(b)

      # -- 3. Empty data frames / matrices fall back to `b`
      if (is.data.frame(a) || is.matrix(a)) {
        if (nrow(a) == 0 || ncol(a) == 0) return(b)
        return(a)
      }

      # -- 4. Lists: fall back when empty or all-NULL
      if (is.list(a)) {
        if (all(vapply(a, is.null, logical(1)))) return(b)
        return(a)
      }

      # -- 5. Character vectors: fall back when entirely NA or blank
      if (is.character(a) && all(is.na(a) | trimws(a) == "")) return(b)

      # -- 6. Length-one NA (but not NaN) falls back to `b`
      if (length(a) == 1 && is.na(a) && !is.nan(a)) return(b)

      # -- 7. Default: keep `a`
      a
    }
    # -- 2. Console Output -
    if (verbose) message("--- TIME SERIES FORECASTING AGENT ----")
    if (verbose) message("    * RECOMMEND FORECASTING STEPS\n")

    # -- 4. Data Preparation ---------
    if (is.data.frame(state$data_raw)) {
      df <- state$data_raw
    } else if (is.list(state$data_raw)) {
      df <- as.data.frame(state$data_raw)
    } else {
      stop("state$data_raw must be a data.frame or list convertible to data.frame")
    }

    # -- 5. Input Collection ---------
    user_instructions <- state$user_instructions %||% ""
    previous_steps <- state$recommended_steps %||% ""
    all_datasets_summary <- get_dataframe_summary(df, skip_stats = TRUE)

    # -- 6. Prompt Construction ------
    prompt <- sprintf(
      "You are the **Chief Forecasting Supervisor** overseeing the initial scoping of a time-series forecasting project. Your job is to define a **clear, rigorous forecasting blueprint**, referencing two pre-implemented utility functions:

  * `forecast_ts()` - accepts (data, value, date, group, horizon, conf_level)
    and returns a forecast table with columns: id, date, value, conf_lo, conf_hi

  * `plot_forecast()` - takes the forecast table and produces an interactive Plotly visualization.

Write a structured, professional-grade **FORECASTING BLUEPRINT** based on the user's intent, any prior recommendations, and the dataset characteristics provided.

#######################################################################
# USER REQUEST
%s

# PRIOR RECOMMENDATIONS (if available)
%s

# DATA OVERVIEW
%s

#######################################################################
# DELIVERABLE - FORECASTING BLUEPRINT

1. DATA PROFILE
   * Date Column      : <col_name> (<format>)
   * Target Variable  : <col_name> (<numeric type>)
   * Grouping Column  : <list_> (e.g. id, store_id, product_id)
   * Frequency        : <daily | weekly | monthly | other>

2. FORECASTING SPECIFICATIONS
   * Horizon          : <N periods> + rationale
   * Confidence Level : <e.g., 95%% unless overridden>

3. MODELING STRATEGY
   All five models below will be trained on the same training window.
   Their forecasts will be **averaged equally** to form the ensemble:

   | # | Model Family           | Engine / Function        |
   |---|------------------------|--------------------------|
   | 1 | Prophet w/ Regressors  | `prophet_reg()`          |
   | 2 | Gradient Boosting      | `boost_tree()`           |
   | 3 | Random Forest          | `rand_forest()` (ranger) |
   | 4 | Support Vector Machine | `svm_rbf()` (kernlab)    |
   | 5 | Prophet + Boost        | `prophet_boost()`        |

   * Recipes Used:
     - `recipe_spec_1` for Prophet-based models
     - `recipe_spec_2` for tree/kernel models
   * Ensemble Rule: Equal-weighted average across models
   * Rationale   : Increases robustness via model diversity

4. FUNCTION ARGUMENT MAP
   | `forecast_ts()` Argument | Value                        |
   |--------------------------|------------------------------|
   | data                     | <data frame name>            |
   | value                    | <Target variable column>     |
   | date                     | <Date column>                |
   | group                    | <Group column>               |
   | horizon                  | <Int / NULL>                 |
   | conf_level               | <Decimal between 0.80-0.99>  |

   * Output schema verified: id, date, value, conf_lo, conf_hi

** IMPORTANT: The `group` variable will appear in **_90%% of user inputs_**.
Look for it carefully. If not explicitly provided, default to `'ALL GROUPS'`.**

#######################################################################
# STRICT CONSTRAINTS
1. No code or implementation syntax
2. No filesystem paths or data I/O
3. Time-series focus only
4. Avoid vague language - be specific and actionable"
      ,
      user_instructions,
      previous_steps,
      all_datasets_summary
    )

    # -- 7. Model Query
    steps <- model(prompt)

    # -- 8. Diagnostic Logging
    if (!is.null(state$log)) {
      state$log[[length(state$log) + 1]] <- list(
        step = "recommend_steps",
        prompt = prompt,
        response = steps,
        timestamp = Sys.time()
      )
    }

    # -- 9. Return Updated State -----
    list(
      recommended_steps = paste0("\nRecommended Forecasting Steps:\n", trimws(steps)),
      all_datasets_summary = all_datasets_summary,
      previous_recommendations = if (nchar(previous_steps) > 0) {
        c(state$previous_recommendations %||% list(), list(previous_steps))
      } else {
        state$previous_recommendations %||% list()
      }
    )
  }
}
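
# Usage sketch (assumes `my_llm_wrapper` is a function that takes a prompt
# string and returns the model's text response):
#
#   recommend_node <- node_recommend_forecasting_steps(my_llm_wrapper, verbose = TRUE)
#   out <- recommend_node(list(
#     user_instructions = "Forecast weekly sales by store for the next quarter",
#     data_raw          = timetk::walmart_sales_weekly
#   ))
#   cat(out$recommended_steps)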

node_create_forecasting_code <- function(model, mode, line_width, bypass_recommended_steps = FALSE) {
  function(state) {

    # Define safe_as.integer function
    safe_as.integer <- function(x) {
      tryCatch({
        x <- as.character(x)
        if (grepl("^\\d+$", x)) as.integer(x) else NA_integer_
      }, warning = function(w) NA_integer_, error = function(e) NA_integer_)
    }

    # Null-coalescing utility
    `%||%` <- function(a, b) if (!is.null(a) && !identical(a, "")) a else b

    # Console banner
    if (bypass_recommended_steps) message("---TIME SERIES FORECASTING AGENT----")
    message("    * CREATE FORECASTING CODE")

    # 1 # Gather data summary & instructions -
    if (bypass_recommended_steps) {
      df  <- if (is.data.frame(state$data_raw)) state$data_raw else as.data.frame(state$data_raw)
      all_datasets_summary     <- get_dataframe_summary(df, skip_stats = FALSE)
      forecasting_instructions <- state$user_instructions
    } else {
      all_datasets_summary     <- state$all_datasets_summary
      forecasting_instructions <- state$recommended_steps
    }

    # 2 # Compose elite prompt
    prompt <- sprintf(
      "You are an elite time-series forecasting assistant supporting high-stakes enterprise projects.

Your task: **extract five forecasting parameters** from the user's input and data summary. Return only a **valid JSON object** - no commentary, no markdown, no extra formatting.

#######################################################################
# REQUIRED PARAMETERS
1. `params_value`       - Numeric column to forecast
2. `params_date`        - Date/time column
3. `params_group`       - Grouping column, e.g. id, store_id, region (or NULL for ungrouped)
4. `params_horizon`     - Integer number of future periods (>= 1)
5. `params_conf_level`  - Confidence level between 0.80 and 0.99 (default 0.95)

# EXTRACTION RULES
* Use exact column names from the data summary. **Never invent new names.**
* Prioritize numeric columns matching 'sales', 'revenue', 'value', 'demand', 'units', 'traffic'.
* For grouping: if the user says **'by / per / for each' + <column>**, assign that as `params_group`.
* Horizon rules (in order of precedence):
  1. Explicit span: \"next 18 months\" -> detect unit & convert
  2. Calendar phrases: \"through 2027\", \"Q4\", etc.
  3. Defaults: monthly=12, weekly=13, daily=30
  4. Fallback: 20%% of N if N < 100; 10%% of N otherwise
* Confidence: if absent or invalid, default to **0.95**

* Return **all five keys**, even if some are `null`.

# EXAMPLE OUTPUT
{
  \"params_value\":       \"sales\",
  \"params_date\":        \"date\",
  \"params_group\":       \"product_id\",  # Example: id, store_id, region
  \"params_horizon\":     12,
  \"params_conf_level\":  0.95
}

# GOOD USER EXAMPLES (partial list)
* 'Forecast the next 24 months of sales.'
* 'Show revenue forecast **per region** for FY-2026.'
* 'Give me a weekly forecast of visits by channel for Q2.'
* 'Estimate monthly churn through 2027.'
* 'Predict ride-hailing trips per driver next Friday.'

#######################################################################
# DATA SUMMARY
%s

#######################################################################
# USER INSTRUCTIONS
%s"
      , all_datasets_summary, forecasting_instructions)

    #-- 3 # Invoke model -----
    raw_response <- model(prompt)

    # Optional: append to diagnostic log
    if (!is.null(state$log)) state$log[[length(state$log)+1]] <-
      list(step = "extract_params", prompt = prompt, response = raw_response)

    #-- 4 # Robust JSON parse
    if (!requireNamespace("jsonlite", quietly = TRUE)) {
      stop("Package 'jsonlite' is required for JSON parsing.")
    }
    # Improved JSON parsing
    parsed <- tryCatch({
      # First clean the response
      cleaned_response <- gsub("```json|```", "", raw_response)
      jsonlite::fromJSON(cleaned_response)
    }, error = function(e) {
      warning("Failed to parse JSON response: ", e$message)
      NULL
    })

    #-- 5. Parameter Validation
    if (is.null(parsed)) {
      stop("Could not parse valid forecasting parameters from model response")
    }

    if (is.null(parsed$params_value) || is.null(parsed$params_date)) {
      warning("Missing required parameters: params_value and params_date must be specified")
      return(list(forecasting_error = "Missing required parameters: params_value and params_date"))
    }

    params <- list(
      params_value = as.character(parsed$params_value),
      params_date = as.character(parsed$params_date),
      params_group = if (!is.null(parsed$params_group)) as.character(parsed$params_group) else NULL,
      params_horizon = if (!is.null(parsed$params_horizon)) {
        horizon <- safe_as.integer(parsed$params_horizon)
        if (is.na(horizon) || horizon <= 0) NULL else horizon
      } else NULL,
      params_conf_level = if (!is.null(parsed$params_conf_level)) {
        conf <- suppressWarnings(as.numeric(parsed$params_conf_level))
        if (is.na(conf) || conf < 0.8 || conf > 0.99) 0.95 else round(conf/0.05)*0.05
      } else 0.95
    )

    #-- 6. Return parameters & plot settings in the state
    list(
      plot_mode = mode,
      plot_line_width = line_width,
      forecasting_params = params,
      retry_count = 0  # Initialize retry counter
    )
  }
}

node_execute_forecasting_code <- function(state, verbose = FALSE) {
  if (verbose) message("    * EXECUTING FORECASTING CODE")

  # 1. Dependency check: define required packages
  required_packages <- c(
    "dplyr", "rlang", "lubridate", "tibble", "tidyr", "tidymodels",
    "modeltime", "modeltime.ensemble", "timetk", "forcats", "recipes",
    "parsnip", "workflows", "rsample", "prophet", "ranger", "kernlab",
    "xgboost", "jsonlite"
  )

  # Check for missing packages using requireNamespace
  missing_pkgs <- required_packages[!vapply(required_packages, requireNamespace, logical(1), quietly = TRUE)]

  if (length(missing_pkgs) > 0) {
    if (verbose) message("Missing packages: ", paste(missing_pkgs, collapse = ", "))
    return(list(forecasting_error = paste("Missing required packages:", paste(missing_pkgs, collapse = ", "))))
  }

  # Load packages silently
  suppressPackageStartupMessages({
    invisible(lapply(required_packages, require, character.only = TRUE))
  })

  # 2. Data Preparation
  if (is.null(state$data_raw)) {
    return(list(forecasting_error = "No data available in state$data_raw"))
  }

  df <- if (is.data.frame(state$data_raw)) {
    state$data_raw
  } else if (is.list(state$data_raw)) {
    as.data.frame(state$data_raw)
  } else {
    return(list(forecasting_error = "Unsupported data format in state$data_raw"))
  }

  # 3. Parameter Validation
  params <- state$forecasting_params
  if (is.null(params)) {
    return(list(forecasting_error = "Missing forecasting parameters in state$forecasting_params"))
  }

  validate_parameters <- function(params, df) {
    # Core parameter checks
    if (!params$params_value %in% names(df)) {
      stop("Target column '", params$params_value, "' not found in data")
    }
    if (!params$params_date %in% names(df)) {
      stop("Date column '", params$params_date, "' not found in data")
    }
    if (!is.null(params$params_group) && !params$params_group %in% names(df)) {
      stop("Group column '", params$params_group, "' not found in data")
    }

    # Type checks
    if (!inherits(df[[params$params_date]], c("Date", "POSIXt"))) {
      stop("Date column '", params$params_date, "' must be Date/Datetime type")
    }
    if (!is.numeric(df[[params$params_value]])) {
      stop("Target column '", params$params_value, "' must be numeric")
    }

    # Horizon validation (accept any positive whole number)
    if (!is.null(params$params_horizon)) {
      h <- params$params_horizon
      if (!is.numeric(h) || length(h) != 1 || is.na(h) || h < 1 || h %% 1 != 0) {
        stop("Horizon must be a positive integer, got ", params$params_horizon)
      }
    }
  }

  # Note: a `return()` inside a tryCatch error handler only exits the handler,
  # so capture the message and return from the node explicitly.
  validation_error <- tryCatch({
    validate_parameters(params, df)
    NULL
  }, error = function(e) paste("Parameter validation failed:", e$message))

  if (!is.null(validation_error)) {
    return(list(forecasting_error = validation_error))
  }

  # 4. Execute Forecast
  forecast_args <- list(
    data      = df,
    value     = params$params_value,
    date      = params$params_date,
    group     = params$params_group,
    horizon   = params$params_horizon,
    conf_level= params$params_conf_level
  )
  forecast_args <- forecast_args[!sapply(forecast_args, is.null)]

  result <- suppressWarnings(
    tryCatch({
      fcst <- do.call(forecast_ts, forecast_args)

      # Validate output structure
      required_cols <- c("id", "date", "value", "conf_lo", "conf_hi")
      if (!all(required_cols %in% names(fcst))) {
        stop("Invalid forecast output - missing required columns: ",
             paste(setdiff(required_cols, names(fcst)), collapse = ", "))
      }

      # Generate plot if available
      mode <- state$plot_mode
      line_width <- state$plot_line_width

      plot_obj <- tryCatch(
        plot_forecast(
          data = fcst,
          mode = mode,
          line_width = line_width
        ),
        error = function(e) {
          warning("Plot generation failed: ", e$message)
          NULL
        }
      )

      list(forecast_data = fcst, forecast_plot = plot_obj)
    }, error = function(e) {
      return(list(forecasting_error = paste("Forecast execution failed:", e$message)))
    })
  )

  # 5. Return Updated State
  if (!is.null(result$forecasting_error)) {
    return(list(forecasting_error = result$forecasting_error))
  }

  list(
    forecasting_data = result$forecast_data,
    forecasting_result = result$forecast_plot,
    execution_success = TRUE,
    timestamp = Sys.time()
  )
}

node_fix_forecasting_code <- function(model) {
  function(state) {
    message("    * FIX FORECASTING PARAMETERS")
    message("      retry_count:", state$retry_count)

    # Get context
    error_message <- state$forecasting_error
    prev_params <- state$forecasting_params
    data_summary <- state$all_datasets_summary
    user_instructions <- state$user_instructions

    prompt <- sprintf(
      "You are a forecasting quality assurance specialist. Analyze the failed forecast attempt and suggest corrected parameters.

      # CRITICAL REQUIREMENTS
      1. PRESERVE USER INTENT from original instructions
      2. Only modify parameters that caused the error
      3. Maintain consistency with the data structure

      # CONTEXT
      USER REQUEST: '%s'

      ERROR MESSAGE: '%s'

      DATA STRUCTURE:
      %s

      PREVIOUS PARAMETERS:
      %s

      # FIXING RULES
      - Keep value/date columns UNLESS they caused the error
      - Only remove grouping if absolutely necessary
      - Adjust horizon algorithmically:
        * If date error: set to 1 temporarily
        * If memory error: reduce by 50%%
        * Else: keep original
      - Confidence: Only adjust if outside 0.8-0.99 range

      # OUTPUT FORMAT
      Return ONLY this JSON structure with your fixes:
      {
        \"params_value\": \"<exact_column_name>\",
        \"params_date\": \"<exact_date_column>\",
        \"params_group\": \"<column_name_or_null>\",
        \"params_horizon\": <integer_based_on_rules>,
        \"params_conf_level\": <0.80_to_0.99>,
        \"fix_reason\": \"<brief_explanation>\"
      }",

      user_instructions,
      error_message,
      data_summary,
      jsonlite::toJSON(prev_params, auto_unbox = TRUE)
    )
    # Get LLM response
    raw_response <- model(prompt)

    # Parse and validate (strip any markdown code fences first)
    parsed <- tryCatch(
      jsonlite::fromJSON(gsub("```json|```", "", raw_response)),
      error = function(e) NULL
    )

    # Fallback to original params if parse fails
    if(is.null(parsed)) {
      warning("Failed to parse fixed parameters. Using original values.")
      parsed <- prev_params
    }

    # Ensure required fields exist
    if(is.null(parsed$params_value)) parsed$params_value <- prev_params$params_value
    if(is.null(parsed$params_date)) parsed$params_date <- prev_params$params_date

    # Update retry counter
    new_retry_val <- state$retry_count + 1

    list(
      forecasting_params = parsed,
      forecasting_error = NULL,
      retry_count = new_retry_val
    )
  }
}

node_explain_forecasting_code <- function(model) {
  function(state) {
    summary <- if (!is.null(state$forecasting_error)) {
      paste("Error occurred:", state$forecasting_error)
    } else {
      "Forecasting created successfully"
    }

    prompt <- sprintf(
      "Explain these Forecasting transformations:\nSteps: %s\n\nResult:\n%s",
      state$recommended_steps, summary
    )

    explanation <- model(prompt)

    list(
      forecasting_report = explanation,
      forecasting_summary = summary
    )
  }
}

node_func_human_review <- function(
    prompt_text,
    yes_goto,
    no_goto,
    user_instructions_key = "user_instructions",
    recommended_steps_key = "recommended_steps") {
  function(state) {
    message(" * HUMAN REVIEW")
    steps <- if (!is.null(state[[recommended_steps_key]])) state[[recommended_steps_key]] else ""
    prompt_filled <- sprintf(prompt_text, steps)
    user_input <- interrupt(prompt_filled)
    if (tolower(trimws(user_input)) == "yes") {
      return(list(goto = yes_goto, update = list()))
    } else {
      modifications <- paste("Modifications:", user_input, sep = "\n")
      old_val <- state[[user_instructions_key]]
      if (is.null(old_val)) old_val <- ""
      new_val <- paste(old_val, modifications, sep = "\n")
      return(list(goto = no_goto, update = list(user_instructions = new_val)))
    }
  }
}

###############################################################################
## TIME SERIES FORECASTING AGENT IMPLEMENTATION
###############################################################################
#' Build a Time Series Forecasting Agent
#'
#' Constructs a state graph-based forecasting agent that:
#' recommends forecasting steps, extracts parameters, generates code,
#' executes the forecast using `modeltime`, fixes errors if needed,
#' and explains the result. It leverages multiple models including
#' Prophet, XGBoost, Random Forest, SVM, and Prophet Boost, and
#' combines them in an ensemble.
#'
#' @name build_forecasting_agent
#' @param model A function that takes a prompt and returns an LLM-generated result.
#' @param bypass_recommended_steps Logical; skip initial step recommendation.
#' @param bypass_explain_code Logical; skip the final explanation step.
#' @param mode Visualization mode for forecast plots. One of `"light"` or `"dark"`.
#' @param line_width Line width used in plotly forecast visualization.
#' @param verbose Logical; whether to print progress messages.
#'
#' @return A callable agent function that takes a `state` list and returns the updated state.
#'
#' @examples
#' \dontrun{
#' # 1) Define an LLM wrapper: a function that takes a prompt string and
#' #    returns the LLM-generated text (implementation not shown here)
#'
#' # 2) Prepare the dataset
#' my_data <- walmart_sales_weekly
#'
#' # 3) Create the forecasting agent
#' forecasting_agent <- build_forecasting_agent(
#'   model = my_llm_wrapper,
#'   bypass_recommended_steps = FALSE,
#'   bypass_explain_code = FALSE,
#'   mode = "dark", # dark or light
#'   line_width = 3,
#'   verbose = FALSE
#' )
#'
#' # 4) Define the initial state
#' initial_state <- list(
#'   user_instructions = "Forecast sales for the next 30 days, using `id` as the grouping variable,
#'   a forecasting horizon of 30, and a confidence level of 90%.",
#'   data_raw = my_data
#' )
#'
#' # 5) Run the agent
#' final_state <- forecasting_agent(initial_state)
#' }
#'
#' @export
NULL

build_forecasting_agent <- function(
    model,
    bypass_recommended_steps = FALSE,
    bypass_explain_code = FALSE,
    mode = "light",
    line_width = 3,
    verbose = FALSE) {

  # Human validation is not used by this agent
  human_validation <- FALSE

  # Define node functions list
  node_functions <- list(
    recommend_forecasting_steps = node_recommend_forecasting_steps(model, verbose),
    human_review = node_func_human_review(
      prompt_text = "Are the following forecasting instructions correct# (Answer 'yes' or provide modifications)\n%s",
      yes_goto = if (!bypass_explain_code) "explain_forecasting_code" else "__end__",
      no_goto = "recommend_forecasting_steps",
      user_instructions_key = "user_instructions",
      recommended_steps_key = "recommended_steps"
    ),
    create_forecasting_code = node_create_forecasting_code(
      model = model,
      mode = mode,
      line_width = line_width,
      bypass_recommended_steps = bypass_recommended_steps
    ),
    execute_forecasting_code = function(state) node_execute_forecasting_code(state, verbose),
    fix_forecasting_code = node_fix_forecasting_code(model),
    explain_forecasting_code = node_explain_forecasting_code(model)
  )

  # Create the agent graph
  app <- create_coding_agent_graph(
    node_functions = node_functions,
    recommended_steps_node_name = "recommend_forecasting_steps",
    create_code_node_name = "create_forecasting_code",
    execute_code_node_name = "execute_forecasting_code",
    fix_code_node_name = "fix_forecasting_code",
    explain_code_node_name = "explain_forecasting_code",
    error_key = "forecasting_error",
    max_retries_key = "max_retries",
    retry_count_key = "retry_count",
    human_validation = human_validation,
    human_review_node_name = "human_review",
    checkpointer = NULL,
    bypass_recommended_steps = bypass_recommended_steps,
    bypass_explain_code = bypass_explain_code
  )

  # Return a function that can be invoked with state
  function(state) {
    if (is.null(state$retry_count)) state$retry_count <- 0
    if (is.null(state$max_retries)) state$max_retries <- 3

    app(state)
  }
}
