R/foot_round_robin.R

Defines functions foot_round_robin

Documented in foot_round_robin

#' Round-robin for football leagues
#'
#' Posterior predictive probabilities for a football season in a round-robin format
#'
#' @param object An object either of class \code{stanFoot}, \code{CmdStanFit}, \code{stanfit}.
#' @param data A data frame containing match data with columns:
#'   \itemize{
#'     \item \code{periods}:  Time point of each observation (integer >= 1).
#'     \item \code{home_team}: Home team's name (character string).
#'     \item \code{away_team}: Away team's name (character string).
#'     \item \code{home_goals}: Goals scored by the home team (integer >= 0).
#'     \item \code{away_goals}: Goals scored by the away team (integer >= 0).
#'   }
#' @param teams An optional character vector specifying team names to include. If \code{NULL}, all teams are included.
#' @param output An optional character string specifying the type of output to return. One of \code{"both"}, \code{"table"},
#'   or \code{"plot"}. Default is \code{"both"}.
#' @details
#'
#' For Bayesian models fitted via \code{stan_foot} the round-robin table is computed according to the
#' simulation from the posterior predictive distribution of future (out-of-sample) matches.
#' The dataset should refer to one or more seasons from a given national football league (Premier League, Serie A, La Liga, etc.).
#'
#' @return
#'
#' If \code{output = "both"} a list with:
#' \itemize{
#'   \item{\code{round_table}}: A data frame of matchups (\code{Home}, \code{Away}), observed scores, and \code{Home_prob} (median posterior probability of a home win).
#'   \item{\code{round_plot}}: A \code{ggplot} heatmap of home‑win probabilities with observed scores overlaid.
#' }
#' If \code{output = "table"} or \code{"plot"}, returns only that component.
#'
#'
#' @author Leonardo Egidi \email{legidi@units.it} and Roberto Macrì Demartino \email{roberto.macridemartino@deams.units.it}
#'
#' @examples
#' \dontrun{
#' if (instantiate::stan_cmdstan_exists()) {
#'   library(dplyr)
#'
#'   data("italy")
#'   italy_1999_2000 <- italy %>%
#'     dplyr::select(Season, home, visitor, hgoal, vgoal) %>%
#'     dplyr::filter(Season == "1999" | Season == "2000")
#'
#'   colnames(italy_1999_2000) <- c("periods", "home_team", "away_team", "home_goals", "away_goals")
#'
#'   fit <- stan_foot(italy_1999_2000, "double_pois", predict = 45, iter_sampling = 200)
#'
#'   foot_round_robin(fit, italy_1999_2000)
#'   foot_round_robin(fit, italy_1999_2000, c("Parma AC", "AS Roma"))
#' }
#' }
#' @importFrom ggplot2 ggplot aes geom_tile geom_text geom_rect scale_fill_gradient
#'   scale_x_continuous scale_y_continuous theme_bw theme element_text ggtitle rel
#' @export


foot_round_robin <- function(object, data, teams = NULL, output = "both") {
  #   ____________________________________________________________________________
  #   Data and arguments checks                                               ####

  # Check required columns in the data
  required_cols <- c("periods", "home_team", "away_team", "home_goals", "away_goals")
  missing_cols <- setdiff(required_cols, names(data))
  if (length(missing_cols) > 0) {
    stop(paste("data is missing required columns:", paste(missing_cols, collapse = ", ")))
  }
  # Check required output
  match.arg(output, c("both", "table", "plot"))


  # Extract simulation draws based on the object's class
  if (inherits(object, c("stanFoot", "CmdStanFit"))) {
    draws <- if (inherits(object, "stanFoot")) {
      object$fit$draws()
    } else {
      object$draws() # for CmdStanFit objects
    }
    draws <- posterior::as_draws_rvars(draws)
    if (!("y_prev" %in% names(draws) && "y_rep" %in% names(draws) || "diff_y_prev" %in% names(draws) && "diff_y_rep" %in% names(draws))) {
      stop("Model does not contain at least one valid pair between 'y_prev' and 'y_rep' or 'diff_y_prev' and 'diff_y_rep' in its samples.")
    }
    sims <- list()
    if ("y_prev" %in% names(draws)) {
      sims$y_prev <- posterior::draws_of(draws[["y_prev"]])
    }
    if ("diff_y_prev" %in% names(draws)) {
      sims$diff_y_prev <- posterior::draws_of(draws[["diff_y_prev"]])
    }
    if ("y_rep" %in% names(draws)) {
      sims$y_rep <- posterior::draws_of(draws[["y_rep"]])
    }
    if ("diff_y_rep" %in% names(draws)) {
      sims$diff_y_rep <- posterior::draws_of(draws[["diff_y_rep"]])
    }
  } else if (inherits(object, "stanfit")) {
    sims <- rstan::extract(object)

    if (!("y_prev" %in% names(sims) && "y_rep" %in% names(sims) || "diff_y_prev" %in% names(sims) && "diff_y_rep" %in% names(sims))) {
      stop("Model does not contain at least one valid pair between 'y_prev' and 'y_rep' or 'diff_y_prev' and 'diff_y_rep' in its samples.")
    }
  } else {
    stop("Provide one among these three model fit classes: 'stanfit', 'CmdStanFit', 'stanFoot'.")
  }

  #   ____________________________________________________________________________
  #   Dataframes and plots                                                    ####

  # Prepare team and season information
  y <- as.matrix(data[, 4:5])
  teams_all <- unique(c(data$home_team, data$away_team))
  team_home <- match(data$home_team, teams_all)
  team_away <- match(data$away_team, teams_all)

  # Determine the model
  if (!is.null(sims$diff_y_prev) && is.null(sims$y_prev)) {
    # t-student case
    N_prev <- dim(sims$diff_y_prev)[2]
    N <- dim(sims$diff_y_rep)[2]
    y_rep1 <- round(abs(sims$diff_y_prev) * (sims$diff_y_prev > 0))
    y_rep2 <- round(abs(sims$diff_y_prev) * (sims$diff_y_prev < 0))
    team1_prev <- team_home[(N + 1):(N + N_prev)]
    team2_prev <- team_away[(N + 1):(N + N_prev)]
  } else if (!is.null(sims$y_prev)) {
    # Skellam, double Poisson, or bivariate Poisson cases
    N_prev <- dim(sims$y_prev)[2]
    N <- dim(sims$y_rep)[2]
    y_rep1 <- sims$y_prev[, , 1]
    y_rep2 <- sims$y_prev[, , 2]
    team1_prev <- team_home[(N + 1):(N + N_prev)]
    team2_prev <- team_away[(N + 1):(N + N_prev)]
  }


  # Condition to ensure that when only the last matchday is predicted, all teams are considered
  if (length(unique(team1_prev)) !=
    length(unique(c(team1_prev, team2_prev)))) {
    team1_prev <- c(team1_prev, team2_prev)
    team2_prev <- c(team2_prev, team1_prev)
  }

  # Select teams to include in output
  if (is.null(teams)) {
    teams <- teams_all[unique(team1_prev)]
  }
  team_index <- match(teams, teams_all)


  if (anyNA(team_index)) {
    stop(paste(
      teams[is.na(team_index)],
      "is not in the test set. Please provide a valid team name. "
    ))
    # team_index <- team_index[!is.na(team_index)]
  }

  # Initialize matrices
  team_names <- teams_all[team_index]
  nteams <- length(unique(team_home))
  nteams_new <- length(team_index)
  M <- dim(sims$diff_y_rep)[1]
  counts_mix <- matrix(0, nteams, nteams)
  number_match_days <- length(unique(team1_prev)) * 2 - 2
  punt <- matrix("-", nteams, nteams)

  suppressWarnings(
    # This condition means that we are "within" the season
    # and that the training set has the same teams as the test set
    cond_1 <- all(sort(unique(team_home)) == sort(unique(team1_prev))) && N < length(unique(team1_prev)) * (length(unique(team1_prev)) - 1)
  )
  # This condition means that the training set does NOT have
  # the same teams as the test set and that we are considering
  # training data from multiple seasons
  suppressWarnings(
    cond_2 <- N > length(unique(team1_prev)) * (length(unique(team1_prev)) - 1) &&
      all(sort(unique(team_home)) == sort(unique(team1_prev))) == FALSE &&
      N %% (length(unique(team1_prev)) * (length(unique(team1_prev)) - 1)) != 0
  )

  suppressWarnings(
    # This condition means that we are at the end of a season
    cond_3 <- N %% (length(unique(team1_prev)) * (length(unique(team1_prev)) - 1)) == 0
  )



  # Fill in the 'punt' matrix with observed match scores based on the condition
  if (cond_1 == TRUE) {
    for (n in 1:N) {
      punt[team_home[n], team_away[n]] <-
        paste(y[n, 1], "-", y[n, 2], sep = "")
    }
  } else if (cond_2 == TRUE) {
    mod <- floor((N / (length(unique(team1_prev)) / 2)) / number_match_days)
    old_matches <- number_match_days * mod * length(unique(team1_prev)) / 2
    new_N <- seq(1 + old_matches, N)

    for (n in new_N) {
      punt[team_home[n], team_away[n]] <-
        paste(y[n, 1], "-", y[n, 2], sep = "")
    }
  }

  # Compute posterior probabilities for home wins in the predicted matches
  for (n in seq_len(N_prev)) {
    prob <- sum(y_rep1[, n] > y_rep2[, n]) / M
    counts_mix[
      unique(team1_prev[n]),
      unique(team2_prev[n])
    ] <- prob
  }

  x1 <- seq(0.5, nteams_new - 1 + 0.5)
  x2 <- seq(1.5, nteams_new - 1 + 1.5)
  x1_x2 <- matrix(0, nteams_new, nteams_new)
  x2_x1 <- matrix(0, nteams_new, nteams_new)
  y1_y2 <- matrix(0, nteams_new, nteams_new)
  y2_y1 <- matrix(0, nteams_new, nteams_new)
  for (j in 1:nteams_new) {
    x1_x2[j, j] <- x1[j]
    x2_x1[j, j] <- x2[j]
    y1_y2[j, j] <- x1[j]
    y2_y1[j, j] <- x2[j]
  }
  x_ex <- seq(1, nteams_new, length.out = nteams_new)
  y_ex <- seq(1, nteams_new, length.out = nteams_new)
  data_ex <- expand.grid(Home = x_ex, Away = y_ex)
  data_ex$prob <- as.double(counts_mix[1:nteams, 1:nteams][team_index, team_index])

  # Pre-compute the rectangle boundaries into a data frame
  rect_df <- data.frame(
    xmin = as.vector(x1_x2),
    xmax = as.vector(x2_x1),
    ymin = as.vector(x1_x2),
    ymax = as.vector(x2_x1)
  )

  # Create the plot
  round_plot <- ggplot(data_ex, aes(x = Home, y = Away)) +
    geom_tile(aes(fill = prob)) +
    geom_text(aes(label = as.vector(punt[team_index, team_index])), size = 4.5) +
    geom_rect(
      data = rect_df,
      aes(xmin = .data$xmin, xmax = .data$xmax, ymin = .data$ymin, ymax = .data$ymax),
      inherit.aes = FALSE,
      fill = "black", color = "black", linewidth = 1
    ) +
    scale_fill_gradient(low = "white", high = "red3", name = "Prob") +
    scale_x_continuous(breaks = x_ex, labels = team_names, name = "Home Team") +
    scale_y_continuous(breaks = y_ex, labels = team_names, name = "Away Team") +
    theme_bw() +
    theme(
      axis.text.x = element_text(angle = 45, hjust = 1, size = rel(1.2)),
      axis.text.y = element_text(size = rel(1.2)),
      legend.text = element_text(size = 11),
      legend.title = element_text(size = 12)
    ) +
    ggtitle("Home win posterior probabilities")

  if (sum(data_ex$prob) == 0) {
    # build a data.frame with Home, Away and Observed
    tbl <- data.frame(
      Home     = teams[data_ex$Home],
      Away     = teams[data_ex$Away],
      Observed = as.vector(punt[team_index, team_index]),
      stringsAsFactors = FALSE
    )

    # keep only rows where Home ≠ Away
    tbl <- tbl[tbl$Home != tbl$Away, ]
    rownames(tbl) <- NULL
  } else {
    # build a data.frame with Home, Away, Home_prob and Observed
    tbl <- data.frame(
      Home      = teams[data_ex$Home],
      Away      = teams[data_ex$Away],
      Home_prob = round(data_ex$prob, 3),
      Observed  = as.vector(punt[team_index, team_index]),
      stringsAsFactors = FALSE
    )

    # keep only rows where Home ≠ Away AND Home_prob ≠ 0
    tbl <- tbl[tbl$Home != tbl$Away & tbl$Home_prob != 0, ]
    rownames(tbl) <- NULL
  }

  result <- list(round_table = tbl, round_plot = round_plot)
  if (output == "both") {
    return(result)
  } else if (output == "plot") {
    return(round_plot)
  } else if (output == "table") {
    return(tbl)
  }
}
LeoEgidi/footBayes documentation built on June 2, 2025, 11:32 a.m.