#' @include Rdbtools.R
# MoJAthenaConnection is an S4 subclass of the AthenaConnection class
# defined by noctua, for use in MoJ. Subclassing gives us two things:
#   * an extra slot, MoJdetails, holding the additional details we need;
#   * the ability to define methods specific to MoJAthenaConnection
#     objects. Those methods can still call the AthenaConnection methods
#     (available through inheritance) after doing some pre-processing.
setClass(
  Class = "MoJAthenaConnection",
  contains = "AthenaConnection",
  slots = list(MoJdetails = "environment")
)
#' connect_athena
#'
#' Creates a connection object which permits the user to interact with the
#' Athena databases that are hosted on the MoJ's Analytical Platform (AP).
#' It uses the [noctua package](https://dyfanjones.github.io/noctua/), with
#' the MoJ's authentication.
#' This returns an object with class MoJAthenaConnection, which inherits
#' methods from noctua's AthenaConnection class, which in turn are DBI
#' methods.
#' In general the expected usage is to run the function with no arguments to
#' get a standard database connection, which should work for most basic data
#' access purposes.
#'
#' If both the `AWS_ROLE_ARN` and `AWS_WEB_IDENTITY_TOKEN_FILE` environment
#' variables are set, temporary credentials are requested and passed to the
#' connection explicitly. Otherwise the connection is made without explicit
#' credentials and the caller identity is looked up through STS.
#'
#' @param aws_region This is the region where the database is held. If unset
#'   or NULL then it will default to the AP's region.
#' @param staging_dir This is the S3 location where outputs of queries can be
#'   held. If unset or NULL then it will default to a session specific
#'   temporary dir.
#' @param rstudio_conn_tab Set this to TRUE to show this connection in your
#'   RStudio connections frame (warning: this takes a long time to load
#'   because of the number of databases in the AP's Athena).
#' @param session_duration The number of seconds which the session should
#'   last before needing new authentication. Minimum of 900.
#' @param role_session_name This is a parameter for authentication, and
#'   should be left as NULL in normal operation.
#' @param schema_name This is the default database in which tables that do
#'   not specify a database will be looked up. If this is set to the string
#'   `__temp__` then it will use (and create if required) the temporary
#'   database based on your username - this is useful for using dbplyr, which
#'   does not understand the `__temp__` keyword, alongside the DBI commands.
#' @param ... Other arguments passed to `dbConnect`.
#'
#' @return A `MoJAthenaConnection` object. Its `MoJdetails` slot is an
#'   environment holding `user_id`, `role_session_name`, `aws_region`,
#'   `staging_dir`, `authentication_expiry`, `session_duration_set`,
#'   `temp_db_name` and `temp_db_exists`.
#'
#' @examples
#' con <- connect_athena() # creates a connection with sensible defaults
#' data <- dbGetQuery(con, "SELECT * FROM database.table") # queries and puts data in R environment
#' dbDisconnect(con) # disconnects the connection
#'
#' @seealso See also noctua's documentation for connecting to a database [noctua::dbConnect,AthenaDriver-method]
#' @export
connect_athena <- function(aws_region = NULL,
                           staging_dir = NULL,
                           rstudio_conn_tab = FALSE,
                           session_duration = 3600L,
                           role_session_name = NULL,
                           schema_name = "default",
                           ...
                           ) {
  if (is.null(aws_region)) {
    aws_region <- get_region()
  }

  # Temporary (web identity) credentials are only used when both of these
  # environment variables are set.
  aws_role_arn <- Sys.getenv("AWS_ROLE_ARN")
  aws_web_identity_token_file <- Sys.getenv("AWS_WEB_IDENTITY_TOKEN_FILE")
  use_web_identity <- nzchar(aws_role_arn) &&
    nzchar(aws_web_identity_token_file)

  # Step 1: work out who we are (and, if applicable, get credentials).
  if (use_web_identity) {
    credentials <- get_aws_credentials(aws_region,
                                       session_duration,
                                       role_session_name)
    authentication_expiry <- as.POSIXct(credentials$Credentials$Expiration,
                                        origin = "1970-01-01",
                                        tz = "UTC")
    # The assumed role id has the form "<role id>:<session name>"
    user_id <- credentials$AssumedRoleUser$AssumedRoleId
    role_session_name <- strsplit(user_id, ":")[[1]][[2]]
  } else {
    # get the athena user id, needed for staging dir and temp db name
    svc <- paws::sts(config = list(region = aws_region))
    user_id <- svc$get_caller_identity()$UserId
    authentication_expiry <- NULL
    role_session_name <- NULL
  }

  # Step 2: defaults derived from the user id, shared by both branches.
  # work out what your staging dir should be on the AP if unset
  if (is.null(staging_dir)) {
    staging_dir <- get_staging_dir_from_userid(user_id)
  }
  # this works out the temp db name from the user id
  temp_db_name <- get_database_name_from_userid(user_id)
  # "__temp__" is a keyword which is translated to the real temp db name
  # before it reaches Athena
  use_temp_db <- isTRUE(schema_name == "__temp__")
  schema_name_set <- if (use_temp_db) temp_db_name else schema_name

  # Step 3: connect to athena
  # returns an AthenaConnection object, see noctua docs for details
  if (use_web_identity) {
    con <- dbConnect(noctua::athena(),
                     region_name = aws_region,
                     s3_staging_dir = staging_dir,
                     rstudio_conn_tab = rstudio_conn_tab,
                     aws_access_key_id = credentials$Credentials$AccessKeyId,
                     aws_secret_access_key = credentials$Credentials$SecretAccessKey,
                     aws_session_token = credentials$Credentials$SessionToken,
                     schema_name = schema_name_set,
                     ...)
  } else {
    # The credential arguments are deliberately not supplied here, so that
    # they can still be passed through `...` if required.
    con <- dbConnect(noctua::athena(),
                     region_name = aws_region,
                     s3_staging_dir = staging_dir,
                     rstudio_conn_tab = rstudio_conn_tab,
                     schema_name = schema_name_set,
                     ...)
  }

  # coerce the AthenaConnection object to be a MoJAthenaConnection object
  # this just adds the slot MoJdetails, as defined in setClass
  con <- as(con, "MoJAthenaConnection")
  # Environments are reference objects: give this connection its own one
  # rather than relying on whatever the class prototype supplies, so that
  # details can never be shared between separate connections.
  con@MoJdetails <- new.env(parent = emptyenv())

  # then we can set the extra details we need in MoJ in the new slot
  con@MoJdetails$user_id <- user_id
  con@MoJdetails$role_session_name <- role_session_name
  con@MoJdetails$aws_region <- aws_region
  con@MoJdetails$staging_dir <- staging_dir
  con@MoJdetails$authentication_expiry <- authentication_expiry
  con@MoJdetails$session_duration_set <- session_duration
  con@MoJdetails$temp_db_name <- temp_db_name
  con@MoJdetails$temp_db_exists <- NA # Don't know if the temp db exists yet

  # this checks that the temp database exists if it is set as the default db
  if (use_temp_db) {
    athena_temp_db(con, check_exists = TRUE)
  }

  con
}
#' refresh_athena_connection
#'
#' Refreshes an athena connection to the AP (e.g. if the credentials have
#' expired).
#'
#' A new connection is created with `connect_athena()`, re-using the
#' `aws_region`, `staging_dir`, `session_duration` and `role_session_name`
#' recorded on `conn`. Every slot of `conn` which is an environment is then
#' updated in place with the contents of the matching slot of the new
#' connection, so that `conn` itself can be used for further queries.
#'
#' Note that only the four settings listed above are carried over: any other
#' argument originally given to `connect_athena()` (`schema_name`,
#' `rstudio_conn_tab` or anything passed through `...`) takes its default
#' value in the refreshed connection.
#'
#' @param conn This is the connection which will be refreshed.
#'
#' @return The newly created connection, invisibly.
#'
#' @examples
#' con <- connect_athena() # creates a connection with sensible defaults
#' data <- dbGetQuery(con, "SELECT * FROM database.table") # queries and puts data in R environment
#' # Some time later...
#' refresh_athena_connection(con) # refresh the connection for any further queries on the same session
#' dbDisconnect(con) # disconnects the connection
#' @export
refresh_athena_connection <- function(conn) {
  # Read the stored settings up front, before any new connection is made
  role_session_name <- conn@MoJdetails$role_session_name
  aws_region <- conn@MoJdetails$aws_region
  staging_dir <- conn@MoJdetails$staging_dir
  session_duration <- conn@MoJdetails$session_duration_set

  conn_refreshed <- connect_athena(aws_region = aws_region,
                                   staging_dir = staging_dir,
                                   session_duration = session_duration,
                                   role_session_name = role_session_name)

  # updates the conn slots which are environments to the new refreshed
  # versions; other slots are copied by value so cannot be changed in place
  for (slot_name in methods::slotNames(conn)) {
    slot_old <- methods::slot(conn, slot_name)
    # is.environment() always gives a single TRUE/FALSE, whereas comparing
    # class() to a string fails for values with more than one class
    if (!is.environment(slot_old)) {
      next
    }
    slot_new <- methods::slot(conn_refreshed, slot_name)
    for (entry in ls(slot_new, all.names = TRUE)) {
      assign(entry,
             get(entry, envir = slot_new, inherits = FALSE),
             envir = slot_old)
    }
  }

  invisible(conn_refreshed)
}
#' refresh_if_expired
#'
#' Refreshes an athena connection to the AP, but only when
#' `is_auth_within_expiry()` reports that its credentials are no longer
#' within their expiry. Nothing is done otherwise.
#'
#' @param conn The connection to check, which is refreshed if it has expired.
#' @param window How many seconds ahead of the actual expiry a refresh
#'   should already be triggered (default 5 mins).
#'
#' @export
refresh_if_expired <- function(conn, window = 5 * 60) {
  still_valid <- is_auth_within_expiry(conn, window)
  if (still_valid) {
    # credentials are fine, so leave the connection untouched
    return(invisible(NULL))
  }
  refresh_athena_connection(conn)
  message("Refreshed credentials")
}
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.