# R/predict.R

# Defines functions predict_probabilities

# Documented in predict_probabilities

# Default for the `temperature` argument of predict_probabilities(). A
# temperature of 1 reproduces the standard (unscaled) probability calculation.
DEFAULT_TEMPERATURE <- 1.0

#' Predict probabilities of OT candidates
#'
#' Predict probabilities of candidates based on their violation profiles and
#' constraint weights.
#'
#' For each input/output pair in the provided input this function
#' will calculate the probability of that output given the input form and the
#' provided weights. This probability is defined as
#'
#' \deqn{P(y|x; w) = \frac{1}{Z_w(x)}\exp(-\sum_{k=1}^{m}{w_k f_k(y, x)})}
#'
#' where \eqn{f_k(y, x)} is the number of violations of constraint \eqn{k}
#' incurred by mapping underlying \eqn{x} to surface \eqn{y}, \eqn{w_k} is the
#' weight associated with constraint \eqn{k}, and \eqn{Z_w(x)} is a
#' normalization term defined as
#'
#' \deqn{Z_w(x) = \sum_{y\in\mathcal{Y}(x)}{\exp(-\sum_{k=1}^{m}{w_k f_k(y, x)})}}
#'
#' where \eqn{\mathcal{Y}(x)} is the set of all output candidates for input
#' \eqn{x}.
#'
#' The resulting probabilities will be appended to a data frame object
#' representing the input tableaux. This data frame can also be saved to a file
#' if the `output_path` argument is provided.
#'
#' @section Using temperature:
#'
#' If the temperature parameter \eqn{T} is specified, \eqn{P(y|x; w)} is
#' calculated as
#'
#' \deqn{\frac{1}{Z_w(x)}\exp(-\sum_{k=1}^{m}{(w_k f_k(y, x))/T})} and
#' \eqn{Z_w(x)} is similarly calculated as
#'
#' \deqn{\sum_{y\in \mathcal{Y}(x)}{\exp(-\sum_{k=1}^{m}{(w_k f_k(y, x))/T})}}
#'
#' Larger values of \eqn{T} move the predicted probabilities of output
#' candidates for a particular input towards equality with one another. For
#' example, if a particular input has two candidate outputs, higher values of
#' \eqn{T} will move the probability of each towards `0.5`.
#'
#' The temperature parameter can be used to generate less categorical
#' predictions in a way that is independent of the constraint weights. See
#' Ackley, Hinton, and Sejnowski (1985, pp. 150-152) for more detail, and Hayes
#' et al. (2009) and Mayer (2021, Ch. 4) for examples of temperature used in
#' practice. By default this parameter is set to `1`, which renders the
#' equations in this section equivalent to the standard calculations of
#' probability.
#'
#' @param test_input The input data frame/data table/tibble. This should
#'   contain one or more OT tableaux consisting of mappings between underlying
#'   and surface forms with observed frequency and violation profiles.
#'   Constraint violations must be numeric.
#'
#'   For an example of the data frame format, see
#'   inst/extdata/sample_data_frame.csv. You can read this file into a data
#'   frame using read.csv or into a tibble using readr::read_csv.
#'
#'   This function also supports the legacy OTSoft file format. You can use this
#'   format by passing in a file path string to the OTSoft file rather than a
#'   data frame.
#'
#'   For examples of OTSoft format, see inst/extdata/sample_data_file.txt.
#'
#' @param constraint_weights A vector of constraint weights to use. These are
#'   typically generated by the \code{\link{optimize_weights}} function.
#' @param output_path (optional) A string specifying the path to a file to
#'   which the predictions will be saved. If the file exists it will be
#'   overwritten. If this argument isn't provided (or is `NA` or `NULL`), the
#'   output will not be written to a file.
#' @param out_sep (optional) The delimiter used in the output files.
#'   Defaults to commas.
#' @param encoding (optional) The character encoding of the input file.
#'   Defaults to "unknown".
#' @param temperature (optional) The temperature parameter, which should be a
#'   real number \eqn{>= 1}. Defaults to 1.
#' @return A list with the following named elements:
#' \itemize{
#'         \item `loglik`: the log likelihood of the data under the provided
#'           weights
#'         \item `predictions`: A data frame containing all the tableaux, with
#'           columns `Input`, `Output`, `Freq`, one column per constraint,
#'           `Predicted`, `Observed`, and `Error` (`Predicted` minus
#'           `Observed`).
#' }
#' @examples
#'   # Get paths to toy data file
#'   df_file <- system.file(
#'       "extdata", "sample_data_frame.csv", package = "maxent.ot"
#'   )
#'   # Fit weights to dataframe with no biases
#'   tableaux_df <- read.csv(df_file)
#'   fit_model <- optimize_weights(tableaux_df)
#'   predict_probabilities(tableaux_df, fit_model$weights)
#'
#'   # Do so with a temperature parameter
#'   predict_probabilities(tableaux_df, fit_model$weights, temperature = 2)
#'
#'   # Save predictions to a file
#'   tmp_output <- tempfile()
#'   predict_probabilities(tableaux_df, fit_model$weights, output_path=tmp_output)
#' @export
predict_probabilities <- function(test_input, constraint_weights,
                                  output_path = NA, out_sep = ",",
                                  encoding = "unknown",
                                  temperature = DEFAULT_TEMPERATURE) {
  processed_input <- load_input(test_input, encoding = encoding)
  long_names <- processed_input$long_names
  data <- processed_input$data

  # Build a numeric matrix for efficient computation. Layout (m constraints):
  #   column 1            integer code for each input (UR)
  #   column 2            frequency (data column 3, reported as "Freq")
  #   columns 3..(m + 2)  constraint violations (data columns 4 onwards)
  #   last three columns  zero-initialized; the second-to-last and last are
  #                       filled below, the third-to-last is discarded
  data_matrix <- matrix(0L, nrow = nrow(data), ncol = ncol(data) + 2)
  # `[[` yields a plain vector for data frames, tibbles and data tables alike,
  # whereas `data[, 1]` stays a one-column table for the latter two.
  data_matrix[, 1] <- as.integer(as.factor(data[[1]]))
  # Frequencies and violation profiles
  data_matrix[, 2:(ncol(data_matrix) - 3)] <- apply(
    as.matrix(data[, 3:ncol(data)]), 2, as.numeric
  )
  # Treat empty cells as 0
  data_matrix[is.na(data_matrix)] <- 0

  loglik <- calculate_log_likelihood(
    constraint_weights, data_matrix, temperature
  )

  data_matrix <- calculate_probabilities(
    constraint_weights, data_matrix, temperature
  )
  n_cols <- ncol(data_matrix)
  # The second-to-last column holds log probabilities; unlog them
  data_matrix[, n_cols - 1] <- exp(data_matrix[, n_cols - 1])
  # Fill the last column (reported as "Observed") by applying normalize_row
  # to each row with respect to column 2, the frequencies
  data_matrix[, n_cols] <- apply(
    data_matrix, 1, normalize_row, data_matrix, 2
  )
  # Discard the unused working column. `drop = FALSE` keeps a single-row input
  # as a matrix so that the column indexing below keeps working.
  data_matrix <- data_matrix[, -(n_cols - 2), drop = FALSE]
  n_cols <- ncol(data_matrix)

  predicted <- data_matrix[, n_cols - 1]
  observed <- data_matrix[, n_cols]
  output <- cbind(
    data[, 1:2],
    data_matrix[, 2:n_cols, drop = FALSE],
    predicted - observed
  )
  names(output) <- c(
    "Input", "Output", "Freq", unlist(long_names),
    "Predicted", "Observed", "Error"
  )

  # NA (the default) or NULL means "do not write a file"
  if (!is.null(output_path) && !is.na(output_path)) {
    utils::write.table(
      output, file = output_path, sep = out_sep, row.names = FALSE
    )
  }

  # NOTE(review): data.frame() defaults to check.names = TRUE, so constraint
  # names that are not syntactically valid R names are rewritten in the
  # returned frame (the file written above keeps them verbatim). Kept as-is
  # for backward compatibility with existing callers.
  list(
    loglik = loglik,
    predictions = data.frame(output)
  )
}
# connormayer/maxent.ot documentation built on Nov. 24, 2024, 1:21 p.m.