R/dbConnect.R

Defines functions dummyPrestoConnection check_tz

Documented in dummyPrestoConnection

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

#' @include PrestoDriver.R PrestoConnection.R
NULL

# Adapted from https://github.com/cran/RPostgres/blob/master/R/PqConnection.R#L56
#
# Validate a user-supplied time zone name.
#
# Args:
#   timezone: A single time zone name (e.g. "US/Eastern"), or "" to defer to
#     the Presto server's time zone.
#
# Returns:
#   `timezone` unchanged when it is "" or a name that lubridate::force_tz()
#   accepts. Otherwise a warning is emitted and "" is returned, so the caller
#   falls back to the Presto server time zone.
check_tz <- function(timezone) {
  # Capture the caller's expression (e.g. `session.timezone`) so warnings can
  # name the offending argument; collapse in case deparse() spans lines.
  arg_name <- paste(deparse(substitute(timezone)), collapse = "")

  # NULL, NA or multi-element input would otherwise make the `if` below fail
  # with a cryptic error; treat it like any other invalid time zone.
  if (!is.character(timezone) || length(timezone) != 1L || is.na(timezone)) {
    warning(
      "Invalid time zone supplied to `", arg_name, "`, ",
      "falling back to Presto server timezone.\n",
      "Set the `", arg_name, "` argument to a single time zone string.",
      call. = FALSE
    )
    return("")
  }
  if (timezone == "") {
    return(timezone)
  }
  tryCatch(
    {
      # Side effect: force_tz() errors if the time zone is not recognised
      lubridate::force_tz(as.POSIXct("2021-03-01 10:40"), timezone)
      timezone
    },
    error = function(e) {
      warning(
        "Invalid time zone '", timezone, "', ",
        "falling back to Presto server timezone.\n",
        "Set the `", arg_name, "` argument to a valid time zone.\n",
        conditionMessage(e),
        call. = FALSE
      )
      ""
    }
  )
}

#' @param drv A driver object generated by [Presto()]
#' @param catalog The catalog to be used
#' @param schema The schema to be used
#' @param user The current user
#' @param host The presto host to connect to
#' @param port Port to use for the connection. Must be coercible to a single
#'          integer between 1 and 65535.
#' @param source Source to specify for the connection
#' @param session.timezone Time zone of the Presto server. Presto returns
#'          timestamps without time zones with respect to this value. The time
#'          arithmetic (e.g. adding hours) will also be done in the given time
#'          zone. This value is passed to Presto server via the request headers.
#'          If empty (the default) or invalid, the connection queries the
#'          server's `current_timezone()` and uses that instead.
#' @param output.timezone The time zone in which TIME WITH TZ and TIMESTAMP
#'          values in the output should be represented. Defaults to the
#'          session time zone, i.e. `session.timezone` if given, otherwise the
#'          Presto server time zone (use `show(<PrestoConnection>)` to see).
#' @param parameters A [list()] of extra parameters to be passed in
#'          the \sQuote{X-Presto-Session} header
#' @param ctes  `r lifecycle::badge("experimental")`
#'          A list of common table expressions (CTEs) that can be used in the
#'          WITH clause. See `vignette("common-table-expressions")`.
#' @param request.config An optional config list, as returned by
#'          `httr::config()`, to be sent with every HTTP request.
#' @param use.trino.headers A boolean to indicate whether Trino request headers
#'          should be used. Defaults to FALSE.
#' @param extra.credentials Extra credentials to be passed in the
#'          X-Presto-Extra-Credential or X-Trino-Extra-Credential header
#'          (depending on the value of the use.trino.headers argument).
#'          Defaults to an empty string.
#' @param bigint The R type that Presto's 64-bit integer (`BIGINT`) class should
#'          be translated to. The default is `"integer"`, which returns R's
#'          `integer` type, but results in `NA` for values above/below
#'          +/-2147483647. `"integer64"` returns a [bit64::integer64], which
#'          allows the full range of 64 bit integers. `"numeric"` coerces into
#'          R's `double` type but might result in precision loss. Lastly,
#'          `"character"` casts into R's `character` type.
#' @param ... currently ignored
#' @return [dbConnect] A [PrestoConnection-class] object
#' @importMethodsFrom DBI dbConnect
#' @importFrom methods new
#' @export
#' @rdname Presto
#' @examples
#' \dontrun{
#' conn <- dbConnect(Presto(),
#'   catalog = "hive", schema = "default",
#'   user = "onur", host = "localhost", port = 8080,
#'   session.timezone = "US/Eastern", bigint = "character"
#' )
#' dbListTables(conn, "%_iris")
#' dbDisconnect(conn)
#' }
#' @md
setMethod(
  "dbConnect",
  "PrestoDriver",
  function(drv,
           catalog,
           schema,
           user,
           host = "localhost",
           port = 8080,
           source = methods::getPackageName(),
           session.timezone = "",
           output.timezone = "",
           parameters = list(),
           ctes = list(),
           request.config = httr::config(),
           use.trino.headers = FALSE,
           extra.credentials = "",
           bigint = c("integer", "integer64", "numeric", "character"),
           ...) {
    # Accept e.g. 8080 or "8080"; non-coercible input becomes NA (the
    # coercion warning is silenced) and is rejected with a friendly message.
    port <- suppressWarnings(as.integer(port))
    if (length(port) != 1L || is.na(port)) {
      stop("Please specify a port as an integer")
    }
    # Fail fast on values that can never be a TCP port instead of letting the
    # HTTP layer fail later with a less helpful error.
    if (port < 1L || port > 65535L) {
      stop("Please specify a port between 1 and 65535")
    }
    # Validate the BIGINT mapping before building the connection.
    bigint <- match.arg(bigint)
    # Invalid time zones are replaced by "" (with a warning) by check_tz().
    session.timezone <- check_tz(session.timezone)
    output.timezone <- check_tz(output.timezone)
    conn <- methods::new("PrestoConnection",
      catalog = catalog,
      schema = schema,
      user = user,
      host = host,
      port = port,
      source = source,
      session.timezone = session.timezone,
      output.timezone = output.timezone,
      request.config = request.config,
      use.trino.headers = use.trino.headers,
      session = PrestoSession$new(parameters, ctes),
      extra.credentials = extra.credentials,
      bigint = bigint
    )
    if (conn@session.timezone == "") {
      # No usable session time zone was given: retrieve the Presto server
      # time zone and store it in the connection.
      conn@session.timezone <-
        DBI::dbGetQuery(conn, "SELECT current_timezone() AS tz")$tz
    }
    if (conn@output.timezone == "") {
      # Represent output values in the session time zone by default.
      conn@output.timezone <- conn@session.timezone
    }
    conn
  }
)

#' A dummy PrestoConnection
#'
#' Builds a [PrestoConnection-class] object without issuing any query. Both the
#' session and output time zones are set to the local system time zone, or to
#' `"UTC"` when [Sys.timezone()] cannot determine it. Only the time zones, an
#' empty [httr::config()] and an empty session are supplied here.
#'
#' @return A [PrestoConnection-class] object.
#' @export
#' @keywords internal
#' @examples
#' dummyPrestoConnection()
dummyPrestoConnection <- function() {
  # Sys.timezone() is documented to return NA when the local time zone cannot
  # be determined (e.g. some minimal containers); fall back to UTC so the
  # connection always carries a usable time zone.
  local_tz <- Sys.timezone()
  if (length(local_tz) != 1L || is.na(local_tz) || !nzchar(local_tz)) {
    local_tz <- "UTC"
  }
  methods::new("PrestoConnection",
    session.timezone = local_tz,
    output.timezone = local_tz,
    request.config = httr::config(),
    session = PrestoSession$new(list(), list())
  )
}
prestodb/RPresto documentation built on Feb. 28, 2024, 11:13 a.m.