R/provider-azure.R

Defines functions chat_azure

Documented in chat_azure

# roxygen2 collation directives: `@include` makes the listed files collate
# before this one, so objects they define (presumably including the
# `ProviderOpenAI` parent class used below) exist when this file is sourced.
# `NULL` gives roxygen an object to attach these tags to.
#' @include provider-openai.R
#' @include content.R
NULL

# https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#chat-completions

#' Chat with a model hosted on Azure OpenAI
#'
#' @description
#' The [Azure OpenAI server](https://azure.microsoft.com/en-us/products/ai-services/openai-service)
#' hosts a number of open source models as well as proprietary models
#' from OpenAI.
#'
#' ## Authentication
#'
#' `chat_azure()` supports API keys and the `credentials` parameter, but it also
#' makes use of:
#'
#' - Azure service principals (when the `AZURE_TENANT_ID`, `AZURE_CLIENT_ID`,
#'   and `AZURE_CLIENT_SECRET` environment variables are set).
#' - Interactive Entra ID authentication, like the Azure CLI.
#' - Viewer-based credentials on Posit Connect. Requires the \pkg{connectcreds}
#'   package.
#'
#' @param endpoint Azure OpenAI endpoint URL with protocol and hostname, e.g.
#'   `https://{your-resource-name}.openai.azure.com`. Defaults to using the
#'   value of the `AZURE_OPENAI_ENDPOINT` environment variable.
#' @param deployment_id Deployment id for the model you want to use.
#' @param api_version The API version to use. Defaults to `"2024-10-21"`.
#' @param api_key An API key to use for authentication. You generally should not
#'   supply this directly, but instead set the `AZURE_OPENAI_API_KEY`
#'   environment variable.
#' @param token `r lifecycle::badge("deprecated")` A literal Azure token to use
#'   for authentication. Deprecated in favour of ambient Azure credentials or
#'   an explicit `credentials` argument.
#' @param credentials A list of authentication headers to pass into
#'   [`httr2::req_headers()`], a function that returns them, or `NULL` to use
#'   `token` or `api_key` to generate these headers instead. This is an escape
#'   hatch that allows users to incorporate Azure credentials generated by other
#'   packages into \pkg{ellmer}, or to manage the lifetime of credentials that
#'   need to be refreshed.
#' @inheritParams chat_openai
#' @inherit chat_openai return
#' @export
#' @examples
#' \dontrun{
#' chat <- chat_azure(deployment_id = "gpt-4o-mini")
#' chat$chat("Tell me three jokes about statisticians")
#' }
chat_azure <- function(endpoint = azure_endpoint(),
                       deployment_id,
                       api_version = NULL,
                       system_prompt = NULL,
                       turns = NULL,
                       api_key = NULL,
                       token = deprecated(),
                       credentials = NULL,
                       api_args = list(),
                       echo = c("none", "text", "all")) {
  # `token` and `credentials` are alternative ways to authenticate; at most
  # one may be supplied.
  check_exclusive(token, credentials, .require = FALSE)
  if (lifecycle::is_present(token)) {
    lifecycle::deprecate_warn(
      when = "0.1.1",
      what = "chat_azure(token)",
      details = "Support for the static `token` argument (which quickly \
                 expires) will be dropped in next release. Use ambient Azure \
                 credentials instead, or pass an explicit `credentials` \
                 argument."
    )
  } else {
    # Normalise the unsupplied deprecated argument to NULL for the checks
    # and credential lookup below.
    token <- NULL
  }
  check_string(endpoint)
  check_string(deployment_id)
  api_version <- set_default(api_version, "2024-10-21")
  turns <- normalize_turns(turns, system_prompt)
  check_string(api_key, allow_null = TRUE)
  # `Sys.getenv()` returns "" when the variable is unset; an empty key is
  # treated as "no API key" by the request and credential code.
  api_key <- api_key %||% Sys.getenv("AZURE_OPENAI_API_KEY")
  check_string(token, allow_null = TRUE)
  echo <- check_echo(echo)

  # Accept a static list of headers by wrapping it in a function, so the
  # provider can always call `credentials()` when building a request.
  if (is_list(credentials)) {
    static_credentials <- force(credentials)
    credentials <- function() static_credentials
  }
  check_function(credentials, allow_null = TRUE)
  credentials <- credentials %||% default_azure_credentials(api_key, token)

  provider <- ProviderAzure(
    endpoint = endpoint,
    deployment_id = deployment_id,
    api_version = api_version,
    api_key = api_key,
    credentials = credentials,
    extra_args = api_args
  )
  Chat$new(provider = provider, turns = turns, echo = echo)
}

# Test helper: build a chat_azure() object that targets the package's test
# deployment. The API key is read up front (via key_get()) from the
# AZURE_OPENAI_API_KEY environment variable; other arguments are forwarded.
chat_azure_test <- function(system_prompt = NULL, ...) {
  key <- key_get("AZURE_OPENAI_API_KEY")
  test_endpoint <- "https://ai-hwickhamai260967855527.openai.azure.com"
  test_deployment <- "gpt-4o-mini"

  chat_azure(
    ...,
    system_prompt = system_prompt,
    api_key = key,
    endpoint = test_endpoint,
    deployment_id = test_deployment
  )
}

# Provider for Azure OpenAI deployments. It reuses ProviderOpenAI, pointing
# `base_url` at the deployment-specific path and using the deployment id as
# the model, and adds the Azure `api_version` and a `credentials` function
# that returns authentication headers.
ProviderAzure <- new_class(
  "ProviderAzure",
  parent = ProviderOpenAI,
  constructor = function(endpoint, deployment_id, api_version, api_key,
                         credentials, extra_args = list()) {
    # An endpoint given with a trailing slash (e.g. copied from a portal)
    # would otherwise produce `...azure.com//openai/deployments/...`.
    endpoint <- sub("/+$", "", endpoint)
    new_object(
      ProviderOpenAI(
        base_url = paste0(endpoint, "/openai/deployments/", deployment_id),
        model = deployment_id,
        api_key = api_key,
        extra_args = extra_args
      ),
      api_version = api_version,
      credentials = credentials
    )
  },
  properties = list(
    credentials = class_function,
    api_version = prop_string()
  )
)

# Default value for `chat_azure(endpoint)`: the AZURE_OPENAI_ENDPOINT
# environment variable, looked up via key_get().
# https://learn.microsoft.com/en-us/azure/ai-services/openai/how-to/switching-endpoints#api-key
azure_endpoint <- function() {
  env_var <- "AZURE_OPENAI_ENDPOINT"
  key_get(env_var)
}

# Build the httr2 request for an Azure OpenAI chat completion.
# https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#chat-completions
#
# The URL comes from the deployment-specific `base_url`, and the API version
# is sent as the `api-version` query parameter. Auth is an `api-key` header
# (when a key is set) plus whatever headers `provider@credentials()` returns.
method(chat_request, ProviderAzure) <- function(provider,
                                                stream = TRUE,
                                                turns = list(),
                                                tools = list(),
                                                type = NULL) {

  req <- request(provider@base_url)
  req <- req_url_path_append(req, "/chat/completions")
  req <- req_url_query(req, `api-version` = provider@api_version)
  # An empty key means authentication relies on `credentials` alone.
  if (nzchar(provider@api_key)) {
    req <- req_headers_redacted(req, `api-key` = provider@api_key)
  }
  req <- req_headers(req, !!!provider@credentials(), .redact = "Authorization")
  req <- req_retry(req, max_tries = 2)
  req <- ellmer_req_timeout(req, stream)
  req <- req_error(req, body = function(resp) {
    # Only JSON error payloads can be parsed. For anything else (e.g. an HTML
    # page from a proxy or gateway) return NULL so httr2 reports the plain
    # HTTP error instead of a JSON parsing failure.
    if (!identical(httr2::resp_content_type(resp), "application/json")) {
      return(NULL)
    }
    error <- resp_body_json(resp)$error
    if (is.null(error)) {
      return(NULL)
    }
    msg <- paste0(error$code, ": ", error$message)
    # Try to be helpful in the (common) case that the user or service
    # principal is missing the necessary role.
    # See: https://learn.microsoft.com/en-us/azure/ai-services/openai/how-to/role-based-access-control
    bad_rbac <- identical(
      error$message,
      "Principal does not have access to API/Operation."
    )
    if (bad_rbac) {
      msg <- c(
        "*" = msg,
        "i" = cli::format_inline(
          "Your user or service principal likely needs one of the following
        roles: {.emph Cognitive Services OpenAI User},
        {.emph Cognitive Services OpenAI Contributor}, or
        {.emph Cognitive Services Contributor}.",
          keep_whitespace = FALSE
        )
      )
    }
    msg
  })

  messages <- compact(unlist(as_json(provider, turns), recursive = FALSE))
  tools <- as_json(provider, unname(tools))

  # Structured output: ask for a strict JSON-schema response when `type` is set.
  if (!is.null(type)) {
    response_format <- list(
      type = "json_schema",
      json_schema = list(
        name = "structured_data",
        schema = as_json(provider, type),
        strict = TRUE
      )
    )
  } else {
    response_format <- NULL
  }

  # `compact()` drops NULL fields (e.g. no seed, no response_format).
  body <- compact(list2(
    messages = messages,
    model = provider@model,
    seed = provider@seed,
    stream = stream,
    stream_options = if (stream) list(include_usage = TRUE),
    tools = tools,
    response_format = response_format
  ))
  # User-supplied `api_args` override the defaults built above.
  body <- modify_list(body, provider@extra_args)
  req <- req_body_json(req, body)

  req
}

# Choose a source of Azure credentials and return a zero-argument function
# that yields a list of authentication headers.
#
# Sources are tried in order:
# 1. A static bearer `token` (deprecated).
# 2. Viewer-based credentials on Posit Connect.
# 3. A service principal from AZURE_TENANT_ID, AZURE_CLIENT_ID and
#    AZURE_CLIENT_SECRET (OAuth client-credentials flow, token cached).
# 4. A non-empty `api_key`: no extra headers are returned, since the request
#    code sends the `api-key` header itself.
# 5. Interactive Entra ID login, masquerading as the Azure CLI.
# If none apply, the test is skipped under testthat; otherwise this errors.
#
# @param api_key An API key string, or NULL.
# @param token A static bearer token string, or NULL.
default_azure_credentials <- function(api_key = NULL, token = NULL) {
  if (!is.null(token)) {
    return(function() list(Authorization = paste("Bearer", token)))
  }

  azure_openai_scope <- "https://cognitiveservices.azure.com/.default"

  # Detect viewer-based credentials from Posit Connect.
  if (has_connect_viewer_token(scope = azure_openai_scope)) {
    return(function() {
      viewer_token <- connectcreds::connect_viewer_token(
        scope = azure_openai_scope
      )
      list(Authorization = paste("Bearer", viewer_token$access_token))
    })
  }

  # Detect Azure service principals.
  tenant_id <- Sys.getenv("AZURE_TENANT_ID")
  client_id <- Sys.getenv("AZURE_CLIENT_ID")
  client_secret <- Sys.getenv("AZURE_CLIENT_SECRET")
  if (nzchar(tenant_id) && nzchar(client_id) && nzchar(client_secret)) {
    # Service principals use an OAuth client credentials flow. We cache the token
    # so we don't need to perform this flow before each turn.
    client <- oauth_client(
      client_id,
      token_url = paste0(
        "https://login.microsoftonline.com/",
        tenant_id,
        "/oauth2/v2.0/token"
      ),
      secret = client_secret,
      auth = "body",
      name = "ellmer-azure-sp"
    )
    return(function() {
      sp_token <- oauth_token_cached(
        client,
        oauth_flow_client_credentials,
        flow_params = list(scope = azure_openai_scope),
        # Don't use the cached token when testing.
        reauth = is_testing()
      )
      list(Authorization = paste("Bearer", sp_token$access_token))
    })
  }

  # If we have an API key, rely on that for credentials. Check for NULL
  # first: `nchar(NULL)` is `integer(0)`, which `if ()` rejects.
  if (!is.null(api_key) && nzchar(api_key)) {
    return(function() list())
  }

  # Masquerade as the Azure CLI.
  azure_cli_client_id <- "04b07795-8ddb-461a-bbee-02f9e1bf7b46"
  if (is_interactive() && !is_hosted_session()) {
    client <- oauth_client(
      azure_cli_client_id,
      token_url = "https://login.microsoftonline.com/common/oauth2/v2.0/token",
      secret = "",
      auth = "body",
      name = paste0("ellmer-", azure_cli_client_id)
    )
    return(function() {
      user_token <- oauth_token_cached(
        client,
        oauth_flow_auth_code,
        flow_params = list(
          auth_url = "https://login.microsoftonline.com/common/oauth2/v2.0/authorize",
          scope = paste(azure_openai_scope, "offline_access"),
          redirect_uri = "http://localhost:8400",
          auth_params = list(prompt = "select_account")
        )
      )
      list(Authorization = paste("Bearer", user_token$access_token))
    })
  }

  if (is_testing()) {
    testthat::skip("no Azure credentials available")
  }

  cli::cli_abort("No Azure credentials are available.")
}

Try the ellmer package in your browser

Any scripts or data that you put into this service are public.

ellmer documentation built on April 4, 2025, 3:53 a.m.