R/replicateAPI4R.R

Defines functions replicatellmAPI4R

Documented in replicatellmAPI4R

#' replicatellmAPI4R: Interact with Replicate API for LLM models in R
#'
#' @description This function interacts with the Replicate API (v1) to utilize language models (LLM) such as Llama. It sends a POST request with the provided input to create a prediction, then either polls the prediction until it finishes (non-streaming mode) or prints the prediction's event stream to the console (streaming mode).
#'
#' @param input A list containing the API request body, typically \code{list(input = list(...))} with parameters including prompt, max_tokens, top_k, top_p, min_tokens, temperature, system_prompt, presence_penalty, and frequency_penalty.
#' @param model_url A character string specifying the model endpoint path relative to \code{https://api.replicate.com/v1/} (e.g., "/models/meta/meta-llama-3.1-405b-instruct/predictions"). A leading slash is optional.
#' @param simple A logical value used in non-streaming mode: if TRUE, return only the \code{output} field of the finished prediction; if FALSE, return the full parsed prediction. Ignored in streaming mode. Default is TRUE.
#' @param fetch_stream A logical value indicating whether to fetch a streaming response. Default is FALSE.
#' @param api_key A character string representing the Replicate API key. Defaults to the environment variable "Replicate_API_KEY". Must be non-empty.
#'
#' @importFrom httr add_headers POST GET content http_error status_code
#' @importFrom jsonlite toJSON fromJSON
#' @importFrom curl new_handle handle_setopt handle_setheaders curl_fetch_stream
#' @importFrom assertthat assert_that is.string is.flag noNA
#'
#' @return If fetch_stream is FALSE, polls until the prediction has status "succeeded" and returns either a simplified output (the \code{output} field, if simple is TRUE) or the full parsed prediction (if simple is FALSE). In streaming mode, writes the raw response stream directly to the console and invisibly returns the value of \code{curl::curl_fetch_stream}.
#' An error is raised if an HTTP request fails, if the API response lacks the URL needed for the chosen mode, or if the prediction ends with status "failed" or "canceled".
#'
#' @examples
#' \dontrun{
#'   Sys.setenv(Replicate_API_KEY = "Your API key")
#'   input <- list(
#'     input = list(
#'       prompt = "What is the capital of France?",
#'       max_tokens = 1024,
#'       top_k = 50,
#'       top_p = 0.9,
#'       min_tokens = 0,
#'       temperature = 0.6,
#'       system_prompt = "You are a helpful assistant.",
#'       presence_penalty = 0,
#'       frequency_penalty = 0
#'     )
#'   )
#'   model_url <- "/models/meta/meta-llama-3.1-405b-instruct/predictions"
#'   response <- replicatellmAPI4R(input, model_url)
#'   print(response)
#' }
#'
#' @export replicatellmAPI4R
#' @author Satoshi Kume
replicatellmAPI4R <- function(input,
                              model_url,
                              simple = TRUE,
                              fetch_stream = FALSE,
                              api_key = Sys.getenv("Replicate_API_KEY")) {

  # Validate input arguments. is.flag() accepts NA, so noNA() is also checked
  # for the two flags that are later used as if() conditions.
  assertthat::assert_that(
    is.list(input),
    assertthat::is.string(model_url),
    assertthat::is.flag(simple),
    assertthat::noNA(simple),
    assertthat::is.flag(fetch_stream),
    assertthat::noNA(fetch_stream),
    assertthat::is.string(api_key),
    assertthat::noNA(api_key)
  )
  # Sys.getenv() returns "" for an unset variable; fail early with a clear
  # message instead of sending an unauthenticated request.
  assertthat::assert_that(
    nzchar(api_key),
    msg = "api_key is empty. Set the Replicate_API_KEY environment variable or pass api_key."
  )

  # Join the base URL and the model path with exactly one slash. Only the
  # model path is normalized, so the "//" of "https://" is never touched.
  model_path <- gsub("/{2,}", "/", sub("^/+", "", model_url))
  api_url0 <- paste0("https://api.replicate.com/v1/", model_path)

  # Configure HTTP headers for the API request, including content type and authorization
  auth_header <- paste("Bearer", api_key)
  headers <- httr::add_headers(
    `Content-Type` = "application/json",
    `Authorization` = auth_header
  )

  # Create the prediction: POST the JSON-encoded request body
  response <- httr::POST(
    url = api_url0,
    body = jsonlite::toJSON(input, auto_unbox = TRUE),
    encode = "json",
    config = headers
  )
  # Surface API errors (bad key, invalid input, ...) with the server's message
  # rather than failing later on a missing URL.
  if (httr::http_error(response)) {
    stop(
      "Replicate API request failed (HTTP ", httr::status_code(response), "): ",
      httr::content(response, "text", encoding = "UTF-8"),
      call. = FALSE
    )
  }
  created <- httr::content(response, "parsed")

  # Streaming mode: print the prediction's event stream to the console
  if (fetch_stream) {
    stream_url <- created$urls$stream
    # NOTE(review): depending on the Replicate API version, urls$stream may
    # only be returned when the prediction is created with `stream = TRUE`
    # in the request body — confirm against the current API docs.
    if (is.null(stream_url)) {
      stop(
        "The Replicate API response contains no streaming URL (urls$stream); ",
        "the model may not support streaming.",
        call. = FALSE
      )
    }

    # Chunks arrive at arbitrary byte boundaries, so they are written verbatim;
    # the event stream already carries its own line breaks.
    streaming_callback <- function(data) {
      cat(rawToChar(data))
      TRUE
    }

    stream_handle <- curl::new_handle()
    curl::handle_setheaders(stream_handle,
      Authorization = auth_header,
      Accept = "text/event-stream",
      `Cache-Control` = "no-store"
    )

    return(invisible(
      curl::curl_fetch_stream(
        url = stream_url,
        fun = streaming_callback,
        handle = stream_handle
      )
    ))
  }

  # Non-streaming mode: poll the prediction until it reaches a final status
  get_url <- created$urls$get
  if (is.null(get_url)) {
    stop(
      "The Replicate API response contains no polling URL (urls$get).",
      call. = FALSE
    )
  }

  repeat {
    poll_response <- httr::GET(get_url, headers)
    if (httr::http_error(poll_response)) {
      stop(
        "Polling the prediction failed (HTTP ",
        httr::status_code(poll_response), "): ",
        httr::content(poll_response, "text", encoding = "UTF-8"),
        call. = FALSE
      )
    }
    prediction <- jsonlite::fromJSON(
      httr::content(poll_response, "text", encoding = "UTF-8")
    )
    status <- prediction$status

    if (is.null(status)) {
      stop("Unexpected polling response: prediction status is missing.",
           call. = FALSE)
    }
    if (identical(status, "succeeded")) {
      break
    }
    # "canceled" is also terminal; without it the loop would never end.
    if (status %in% c("failed", "canceled")) {
      detail <- if (is.null(prediction$error)) {
        ""
      } else {
        paste0(": ", paste(prediction$error, collapse = " "))
      }
      stop("Prediction ", status, detail, call. = FALSE)
    }

    # Still starting/processing: wait 1 second before polling again
    Sys.sleep(1)
  }

  # Return only the model output if simple is TRUE, otherwise the full prediction
  if (simple) prediction$output else prediction
}

Try the chatAI4R package in your browser

Any scripts or data that you put into this service are public.

chatAI4R documentation built on April 4, 2025, 1:06 a.m.