R/complete_chat.R

Defines functions complete_chat

Documented in complete_chat

#' Complete an LLM Chat
#'
#' @description
#' Submits one or more prompts to OpenAI's "Chat" API endpoint and formats the
#' response into a string or tidy dataframe.
#'
#' @param prompt The prompt: chat messages in the form sent as the API's
#'   `messages` field (for example the output of `format_chat()`), or a list
#'   of such prompts to submit several requests in one call.
#' @param model Which OpenAI model to use. Defaults to 'gpt-3.5-turbo'.
#' @param openai_api_key Your API key. By default, looks for a system
#'   environment variable called "OPENAI_API_KEY" (recommended option).
#'   Otherwise, supply the key through this argument. An error is raised when
#'   the key is empty, `NULL`, or `NA`.
#' @param max_tokens How many tokens (roughly 4 characters of text) should the
#'   model return? Defaults to a single token (next word prediction).
#' @param temperature A numeric between 0 and 2. When set to zero, the model
#'   will always return the most probable next token. For values greater than
#'   zero, the model selects the next word probabilistically.
#' @param seed An integer. If specified, the OpenAI API will "make a best
#'   effort to sample deterministically".
#' @param parallel TRUE to submit API requests in parallel. Setting to FALSE
#'   (the default) can reduce rate limit errors at the expense of longer
#'   runtime.
#'
#' @return If max_tokens = 1, returns a dataframe with columns `token` and
#'   `probability` holding the most likely next-token responses (up to 20 are
#'   requested from the API); when several prompts are supplied, a list with
#'   one such dataframe per prompt. If max_tokens > 1, returns the text
#'   generated by the model as a character vector (a single string for a
#'   single prompt).
#' @export
#'
#' @examples \dontrun{
#' format_chat('Are frogs sentient? Yes or No.') |> complete_chat()
#' format_chat('Write a haiku about frogs.') |> complete_chat(max_tokens = 100)
#' }
complete_chat <- function(prompt,
                          model = 'gpt-3.5-turbo',
                          openai_api_key = Sys.getenv('OPENAI_API_KEY'),
                          max_tokens = 1,
                          temperature = 0,
                          seed = NULL,
                          parallel = FALSE) {

  # Reject a missing key up front. NULL, character(0) and NA are treated the
  # same as "" so they reach this message instead of an opaque `if` error.
  if (length(openai_api_key) != 1L ||
      is.na(openai_api_key) ||
      !nzchar(openai_api_key)) {
    stop("No API key detected in system environment. You can enter it manually using the 'openai_api_key' argument.")
  }

  # Token probabilities are only requested for single-token completions.
  want_logprobs <- max_tokens == 1
  n_top_logprobs <- if (want_logprobs) 20 else NULL

  # Build the httr2 request for one prompt.
  build_request <- function(messages) {
    httr2::request("https://api.openai.com/v1/chat/completions") |>
      httr2::req_headers(
        "Authorization" = paste("Bearer", openai_api_key),
        "Content-Type" = "application/json"
      ) |>
      httr2::req_body_json(list(model = model,
                                messages = messages,
                                temperature = temperature,
                                max_tokens = max_tokens,
                                logprobs = want_logprobs,
                                top_logprobs = n_top_logprobs,
                                seed = seed))
  }

  # A single prompt has character data two levels down; wrap it so the rest of
  # the function can always iterate over a list of prompts.
  if (is.character(prompt[[1]][[1]])) {
    prompt <- list(prompt)
  }
  requests <- lapply(prompt, build_request)

  # Submit the requests sequentially or in parallel.
  responses <- if (parallel) {
    # 20 concurrent requests per host seems to be the optimum
    httr2::req_perform_parallel(requests, pool = curl::new_pool(host_con = 20))
  } else {
    httr2::req_perform_sequential(requests)
  }

  # Parse each JSON response body into a (flattened) R structure.
  parsed <- lapply(responses, function(resp) {
    jsonlite::fromJSON(httr2::resp_body_string(resp), flatten = TRUE)
  })

  # max_tokens > 1: return the generated text.
  if (!want_logprobs) {
    return(unlist(lapply(parsed, function(x) x$choices$message.content)))
  }

  # max_tokens == 1: one tidy dataframe of token probabilities per prompt.
  token_tables <- lapply(parsed, function(x) {
    top <- x$choices$logprobs.content[[1]]$top_logprobs[[1]]
    top <- cbind(top, probability = exp(top[["logprob"]]))
    top[, c("token", "probability")]
  })

  # Don't return a list if there's only one prompt in the input.
  if (length(prompt) == 1) {
    return(token_tables[[1]])
  }

  token_tables
}

Try the promptr package in your browser

Any scripts or data that you put into this service are public.

promptr documentation built on Sept. 11, 2024, 8:15 p.m.