R/db.R

CLASS_MAP <- list(
  integer = 'bigint', numeric = 'double precision', factor = 'text',
  double = 'double precision', character = 'text', logical = 'text',
  POSIXct = 'timestamp'
)

# Columns that store metadata for shards
META_COLS <- c("last_cached_at")
META_COLS_TYPE <- list(last_cached_at = 'text')

#' Database table name for a given prefix and salt.
#'
#' @param prefix character. Prefix.
#' @param salt list. Salt for the table name.
#' @return the table name. This is just \code{"prefix_"}
#'   followed by the MD5 digest of the \code{salt}.
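#' @examples
#' \dontrun{
#' # Illustrative sketch: the salt is typically a list of the arguments
#' # that identify a cached call. The digest shown is a placeholder.
#' table_name("fetch_users", list(version = 2))
#' # => "fetch_users_<md5 of the salt>"
#' }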
table_name <- function(prefix, salt) {
  tolower(paste0(prefix, "_", digest::digest(salt)))
}

#' Fetch the map of column names.
#'
#' @param dbconn SQLConnection. A database connection.
column_names_map <- function(dbconn) {
  DBI::dbGetQuery(dbconn, "SELECT * FROM column_names")
}

#' Fetch all the shards for the given table name.
#'
#' @param dbconn SQLConnection. A database connection.
#' @param tbl_name character. The calculated table name for the function.
#' @return one or many names of the shard tables.
get_shards_for_table <- function(dbconn, tbl_name) {
  if (!DBI::dbExistsTable(dbconn, "table_shard_map")) create_shards_table(dbconn, "table_shard_map")
  DBI::dbGetQuery(dbconn, paste0("SELECT shard_name FROM table_shard_map where table_name='", tbl_name, "'"))$shard_name
}

#' Generate names for new shards
#'
#' @param tbl_name character. The calculated table name for the function.
#' @param numshards numeric. The number of shards to generate names for.
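#' @examples
#' \dontrun{
#' # A minimal sketch; the suffix is the MD5 digest of the table name,
#' # shown here as a placeholder.
#' generate_new_shard_names("fetch_users_ab12", 2)
#' # => c("shard1_<md5>", "shard2_<md5>")
#' }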
generate_new_shard_names <- function(tbl_name, numshards) {
  paste0("shard", seq(numshards), "_", digest::digest(tbl_name))
}

#' Create the table <=> shards map.
#'
#' @rdname create_table
#' @param dbconn SQLConnection. A database connection.
#' @param tblname character. The name of the table to be created.
create_shards_table <- function(dbconn, tblname) {
  if (DBI::dbExistsTable(dbconn, tblname)) return(TRUE)
  sql <- paste0("CREATE TABLE ", tblname, " (table_name varchar(255) NOT NULL, shard_name varchar(255) NOT NULL);")
  DBI::dbGetQuery(dbconn, sql)
  TRUE
}

#' MD5 digest of column names.
#'
#' @param raw_names character. A character vector of column names.
#' @return the character vector of hashed names.
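#' @examples
#' \dontrun{
#' # Sketch: each name is MD5-hashed and prefixed with "c" so the result
#' # never starts with a digit and is safe as a column identifier.
#' get_hashed_names(c("id", "a_very_long_column_name"))
#' # => c("c<md5>", "c<md5>")
#' }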
get_hashed_names <- function(raw_names) {
  paste0("c", vapply(raw_names, digest::digest, character(1)))
}

#' Translate column names using the column_names table from MD5 to raw.
#'
#' @param names character. A character vector of column names.
#' @param dbconn SQLConnection. A database connection.
translate_column_names <- function(names, dbconn) {
  name_map <- column_names_map(dbconn)
  name_map <- setNames(as.list(name_map$raw_name), name_map$hashed_name)
  vapply(names, function(name) name_map[[name]] %||% name, character(1))
}

#' Convert the raw fetched database table to a readable data frame.
#'
#' @param df dataframe. Raw fetched database table.
#' @param dbconn SQLConnection. A database connection.
#' @param key character. Identifier of database table.
db2df <- function(df, dbconn, key) {
  df[[key]] <- NULL
  for (meta in META_COLS) df[meta] <- NULL
  colnames(df) <- translate_column_names(colnames(df), dbconn)
  df
}

#' Create index on a table
#'
#' @param dbconn SQLConnection. A database connection.
#' @param tblname character. The name of the table to add an index to.
#' @param key character. Identifier of database table.
#' @param idx_name character. Name of the index.
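#' @examples
#' \dontrun{
#' # Sketch, assuming `con` is an open DBI connection:
#' add_index(con, "shard1_ab12", key = "id")
#' # issues CREATE INDEX idx_<md5 of "shard1_ab12"> ON shard1_ab12(id)
#' }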
add_index <- function(dbconn, tblname, key, idx_name = paste0("idx_", digest::digest(tblname))) {
  if (!tolower(substring(idx_name, 1, 1)) %in% letters) {
    stop(sprintf("Invalid index name '%s': must begin with an alphabetic character", idx_name))
  }
  DBI::dbGetQuery(dbconn, paste0("CREATE INDEX ", idx_name, " ON ", tblname, "(", key, ")"))
  TRUE
}

#' Retry \code{dbWriteTable} until the write succeeds
#'
#' @param dbconn SQLConnection. A database connection.
#' @param tblname character. Database table name.
#' @param df dataframe. The data frame to insert.
#' @param append logical. If \code{FALSE}, the prior table is overwritten.
#' @param row.names logical or character. Passed through to \code{DBI::dbWriteTable}.
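#' @examples
#' \dontrun{
#' # Sketch, assuming `con` is an open PostgreSQL connection:
#' dbWriteTableUntilSuccess(con, "scores", data.frame(id = 1:3, x = runif(3)))
#' # The table is (re)written until SELECT COUNT(*) matches nrow(df).
#' }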
dbWriteTableUntilSuccess <- function(dbconn, tblname, df, append = FALSE, row.names = NA) {
  if (DBI::dbExistsTable(dbconn, tblname) && !isTRUE(append)) {
    DBI::dbRemoveTable(dbconn, tblname)
  }
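  ## Columns consisting entirely of NAs are coerced to character NA so that
  ## they get written with a concrete (text) type.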
  if (any(is.na(df))) {
    df[, vapply(df, function(x) all(is.na(x)), logical(1))] <- as.character(NA)
  }

  repeat {
    field_classes <- vapply(df, function(col) class(col)[1L], character(1))
    field_types <- vapply(field_classes, function(klass) CLASS_MAP[[klass]], character(1))
    DBI::dbWriteTable(dbconn, tblname, df, append = append,
                      row.names = row.names, field.types = field_types)
    #TODO(kirill): repeat maximum of N times
    if (!isTRUE(append)) {
      num_rows <- DBI::dbGetQuery(dbconn, paste0("SELECT COUNT(*) FROM ", tblname))[[1]]
      if (num_rows == nrow(df)) break
    } else break
  }
}

## Helper utility for safe IO of a data.frame to a database connection.
##
## This function will be mindful of three problems: non-existent columns,
## long column names, and sharding *data.frame*s with too many columns.
##
## Since this is meant to be used as a helper function for caching
## data, we must take a few precautions. If certain variables are not
## available for older data but are introduced for newer data, we
## must be careful to create those columns first.
##
## Furthermore, certain column names may be longer than PostgreSQL supports.
## To circumvent this problem, this helper function stores an MD5
## digest of each column name and maps them using the `column_names`
## helper table.
##
## By default, this function assumes any data to be written is not
## already present in the table and should be appended. If the table does
## not exist, it will be created.
##
#' Write data.frames to DB addressing pitfalls
#'
#' @param dbconn PostgreSQLConnection. The database connection.
#' @param tblname character. The table name to write the data into.
#' @param df data.frame. The data to write.
#' @param key character. The identifier column name.
#' @param safe_columns logical. Whether an error should be thrown if the columns
#'   being written differ from what is already in the column cache.
#' @param blacklist list. A list of values to not cache.
#' @inheritParams cache
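#' @examples
#' \dontrun{
#' # Sketch, assuming `con` is an open PostgreSQL connection:
#' df <- data.frame(user_id = 1:3, score = c(0.1, 0.5, 0.9))
#' write_data_safely(con, "fetch_scores_ab12", df, key = "user_id",
#'                   safe_columns = FALSE, blacklist = list())
#' }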
write_data_safely <- function(dbconn, tblname, df, key, safe_columns, blacklist) {
  if (is.null(df)) return(FALSE)
  if (!is.data.frame(df)) return(FALSE)
  if (nrow(df) == 0) return(FALSE)

  if (missing(key)) {
    id_cols <- grep("(_|^)id$", colnames(df), value = TRUE)
    if (length(id_cols) == 0)
      stop("The data you are writing to the database must contain at least one ",
           "column named 'id' or ending with '_id'")
  } else {
    id_cols <- key
    if (!is.integer(df[[key]]) && is.numeric(df[[key]])) {
      # TODO: (RK) Check if coercion is possible.
      df[[key]] <- as.integer(df[[key]])
    }
  }

  write_column_names_map <- function(raw_names) {
    hashed_names <- get_hashed_names(raw_names)
    column_map <- data.frame(raw_name = raw_names, hashed_name = hashed_names)
    column_map <- column_map[!duplicated(column_map), ]

    ## If we don't do this, we will get really weird bugs with numeric things stored as character
    ## For example, a row with ID 100000 will be stored as 10e+5, which is wrong.
    old_options <- options(scipen = 20, digits = 20)
    on.exit(options(old_options))

    ## Store the map of raw to MD5'ed column names in the column_names table.
    if (!DBI::dbExistsTable(dbconn, "column_names")) {
      dbWriteTableUntilSuccess(dbconn, "column_names", column_map, append = FALSE)
    } else {
      raw_names <- DBI::dbGetQuery(dbconn, "SELECT raw_name FROM column_names")[[1]]
      column_map <- column_map[!is.element(column_map$raw_name, raw_names), ]
      if (NROW(column_map) > 0) {
        DBI::dbWriteTable(dbconn, "column_names", column_map, append = TRUE, row.names = FALSE)
      }
    }
    TRUE
  }

  get_shard_names <- function(df, tbl_name) {
    ## Two cases: the shards already exist - or they don't
    ##
    ## Fetch existing shards
    shards <- get_shards_for_table(dbconn, tbl_name)

    ## come up with new shards if needed
    numcols <- NCOL(df)
    if (numcols == 0) return(NULL)
    numshards <- ceiling(numcols / MAX_COLUMNS_PER_SHARD)
    ## All data-containing tables will start with prefix *shard#{n}_*
    newshards <- generate_new_shard_names(tbl_name, numshards)
    if (length(shards) > 0) {
      ## only generate new shard names for shards that don't exist!
      unique(c(shards, newshards[-seq(length(shards))]))
    } else newshards
  }

  df2shards <- function(dbconn, df, shard_names, key) {
    ## Here comes the hard part. Sharding strategies!
    ##
    ## Here is how we're going to do it.
    ## We sort the shardnames, to ensure that the first shard is the biggest in size
    ## This way appending to a shard is trivial: if we have any columns in the
    ## dataframe that are not yet stored in the cache - just append them to the
    ## last shard!
    ## Since we've done the calculation of number of shards beforehand we
    ## don't even have to worry about creating new shards if something won't fit.
    ##
    ## Because it will.
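    ##
    ## For example (illustrative, with MAX_COLUMNS_PER_SHARD = 3): a dataframe
    ## with columns (a, b, c, d, id) and key "id" splits as
    ##   shard1 -> df[c("a", "b", "id")]  # MAX - 1 data columns plus the key
    ##   shard2 -> df[c("c", "d", "id")]  # whatever is left goes to the last shard
    ## The key column is repeated in every shard so rows can be joined back.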

    ## Make sure we don't store `key` in the used_columns! Need it in every dataframe
    used_columns <- c()

    ## We want to sort our shards prior to writing.
    ## Unfortunately, lexicographic sorting puts "shard10" and "shard11" before
    ## "shard2", so instead we rebuild the shard names in numeric order.
    suffix <- strsplit(shard_names[1], '_')[[1]][2]
    lapply(paste0("shard", seq(length(shard_names)), "_", suffix), function (shard, last, key) {
      ## We need to create a map in the form of
      ## ```list(df = dataframe, shard_name = shard_names)```, where the dataframe is a subset
      ## of the original dataframe that contains less columns than
      ## **MAX_COLUMNS_PER_SHARD**.
      ## This is what we should do for each shard:
      ##
      ## 1. Determine which columns are already being stored in the shard
      ## 2. Take the subset of the dataframe that has these columns, assign it to a shard
      ## 3. See which columns are left unsaved, and add those to the last shard
      if (shard == last) {
        ## Write out the rest of the dataframe into the last shard
        list(df = df[setdiff(colnames(df), used_columns)], shard_name = shard)
      } else {
        ## If the response is empty, write the first N columns of the dataframe
        ## Otherwise, only write out those columns that already exist in this shard
        shard_exists <- DBI::dbExistsTable(dbconn, shard)
        if (isTRUE(shard_exists)) {
          one_row <- DBI::dbGetQuery(dbconn, paste0("SELECT * FROM ", shard, " LIMIT 1"))
        } else one_row <- NULL
        ## Here we abuse the fact that ```NROW(NULL) == 0```
        if (NROW(one_row) == 0 || NCOL(one_row) == 2) {
          ## TODO: This is very hacky...
          ## If we see only two columns in a shard, it means that we only stored
          ## the id and the hashed id. So basically this shard is useless!
          ## In this case we should drop it, and pretend this table doesn't exist
          if (NCOL(one_row) == 2) {
            DBI::dbGetQuery(dbconn, paste0("DROP TABLE ", shard))
          }
          columns <- colnames(df)
          columns <- columns[columns != key]
          columns <- setdiff(columns, used_columns)
          columns <- c(columns[seq_len(MAX_COLUMNS_PER_SHARD - 1)], key)
          used_columns <<- append(used_columns, columns[columns != key])
          list(df = df[columns], shard_name = shard)
        } else {
          columns <- unique(translate_column_names(colnames(one_row), dbconn))
          used_columns <<- append(used_columns, columns[columns != key])
          list(df = df[colnames(df) %in% columns], shard_name = shard)
        }
      }
    }, last = shard_names[length(shard_names)], key = key)
  }

  write_column_hashed_data <- function(df, tblname, append = TRUE) {
    ## Don't cache anything that is in the blacklist.
    if (length(blacklist) > 0L) {
      df_without_id <- df[-which(names(df) == key)]  # Don't blacklist on id.
      if (any(is.na(blacklist))) {
        df <- df[!apply(is.na(df_without_id), 1, all), , drop = FALSE]
        blacklist <- blacklist[!is.na(blacklist)]
      }
      if (length(blacklist) > 0L) {
        df <- df[apply(df_without_id, 1, function(x) all(!(x %in% blacklist))), , drop = FALSE]
      }
    }
    if (NROW(df) == 0 || NCOL(df) == 0) { return(NULL) }

    ## Create the mapping between original column names and their MD5 companions
    write_column_names_map(colnames(df))

    ## Store a copy of the ID columns (ending with '_id')
    id_cols_ix <- which(is.element(colnames(df), id_cols))
    colnames(df) <- get_hashed_names(colnames(df))
    df[, id_cols] <- df[id_cols_ix]

    ## Convert some types to character so they go in the DB properly.
    to_chars <- unname(vapply(df, function(x) is.factor(x) || is.ordered(x) || is.logical(x), logical(1)))
    df[to_chars] <- lapply(df[to_chars], as.character)

    ## Add meta data
    df$last_cached_at <- format(Sys.time(), tz = "UTC")

    ## Write out to postgres
    dbWriteTableUntilSuccess(dbconn, tblname, df, row.names = FALSE, append = append)
  }

  ## Use transactions!
  DBI::dbGetQuery(dbconn, "BEGIN")
  tryCatch({
    ## Find the appropriate shards for this dataframe and tablename
    shard_names <- get_shard_names(df, tblname)
    ## Create references for these shards if needed
    write_table_shard_map(dbconn, tblname, shard_names)
    ## Split the dataframe into the appropriate shards
    df_shard_map <- df2shards(dbconn, df, shard_names, key)

    ## Actually write the data to the database
    actually_write_data <- function(lst) {
      tblname <- lst$shard_name
      df <- lst$df
      create_and_index_table <- function() {
        if (is.null(write_column_hashed_data(df, tblname, append = FALSE))) {
          return(invisible(TRUE))  # Nothing was cached
        } else {
          add_index(dbconn, tblname, key, paste0("idx_", digest::digest(tblname)))
        }
      }

      if (!DBI::dbExistsTable(dbconn, tblname)) {
        ## The shard doesn't exist yet. Let's create it and index it by key!
        create_and_index_table()
        return(invisible(TRUE))
      }

      ## The table is guaranteed to exist here (we returned above otherwise).
      one_row <- DBI::dbGetQuery(dbconn, paste0("SELECT * FROM ", tblname, " LIMIT 1"))

      if (NROW(one_row) == 0) {
        ## The shard is empty! Delete it and write to it, finally
        ## Also, it's a great opportunity to enforce indexes on this table!
        if (DBI::dbExistsTable(dbconn, tblname)) DBI::dbRemoveTable(dbconn, tblname)
        create_and_index_table()
        return(invisible(TRUE))
      }

      ## Columns that are missing in database need to be created
      new_names <- get_hashed_names(colnames(df))
      ## We also keep non-hashed versions of ID columns around for convenience.
      new_names <- c(new_names, id_cols, META_COLS)
      new_names_raw <- c(colnames(df), id_cols, META_COLS)
      missing_cols <- setNames(!is.element(new_names, colnames(one_row)), new_names_raw)

      missing_user_cols <- missing_cols[setdiff(new_names_raw, META_COLS)]

      if (any(missing_user_cols)) {
        if (isTRUE(safe_columns)) {
          stop("Safe Columns Error: Your function call is adding additional ",
            "columns to a cache that already has pre-existing columns. This ",
            "would suggest your cache is invalid and you should wipe the cache ",
            "and start over.")
        } else if (is.function(safe_columns)) {
          safe_columns(missing_user_cols)
        }
      }

      # TODO: (RK) Check reverse, that we're not missing any already-present columns
      removes <- integer(0)

      if (any(missing_cols) && isTRUE(getOption("cachemeifyoucan.debug"))) {
        message("Add columns to", tblname, ":", paste(names(which(missing_cols)), collapse = ", "))
      }

      for (index in which(missing_cols)) {
        col <- new_names[index]
        if (!all(vapply(col, nchar, integer(1)) > 0))
          stop("Failed to retrieve MD5 hashed column names in write_data_safely")
        # TODO: (RK) Figure out how to filter all NA columns without wrecking
        # the tables.
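        ## Indices past ncol(df) refer to the raw id/meta columns appended to
        ## new_names, so fall back to indexing df by name instead.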
        if (index > length(df)) index <- col

        if (col %in% META_COLS) col_type <- META_COLS_TYPE[[col]]
        else col_type <- CLASS_MAP[[class(df[[index]])[1]]]

        sql <- paste("ALTER TABLE", tblname, "ADD COLUMN", col, col_type)
        suppressWarnings(DBI::dbGetQuery(dbconn, sql))
      }

      ## Columns that are missing in data need to be set to NA
      missing_cols <- !is.element(colnames(one_row), new_names)
      if (sum(missing_cols) > 0) {
        raw_names <- translate_column_names(colnames(one_row)[missing_cols], dbconn)
        stopifnot(is.character(raw_names))
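        ## Since `object = NA` is named, lapply's element binds to `Class`,
        ## i.e. this calls as(NA, <class>) to build a typed NA for each column.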
        df[, raw_names] <- lapply(sapply(one_row[missing_cols], class), as, object = NA)
      }

      write_column_hashed_data(df, tblname, append = TRUE)
    }

    lapply(df_shard_map, actually_write_data)
  },
  warning = function(w) {
    message("A warning occurred: ", conditionMessage(w))
    message("Rollback!")
    DBI::dbRollback(dbconn)
  },
  error = function(e) {
    message("An error occurred: ", conditionMessage(e))
    message("Rollback!")
    DBI::dbRollback(dbconn)
    stop(e)
  },
  finally = {
    DBI::dbGetQuery(dbconn, "COMMIT")
  })

  invisible(TRUE)
}

#' Register shards in table_shard_map
#'
#' @param dbconn SQLConnection. The database connection.
#' @param tblname character. Table name.
#' @param shard_names character. Calculated shard names given the table name.
write_table_shard_map <- function(dbconn, tblname, shard_names) {
  ## Example:
  ##
  ## |   | table_name  | shard_name  |
  ## |---|-------------|-------------|
  ## | 1 | tblname_1   | shard_1     |
  ## | 2 | tblname_1   | shard_2     |
  ## | 3 | tblname_2   | shard_3     |
  table_shard_map <- data.frame(table_name = rep(tblname, length(shard_names)), shard_name = shard_names)
  ## If we don't do this, we will get really weird bugs with numeric things stored as character
  ## For example, a row with ID 100000 will be stored as 10e+5, which is wrong.
  old_options <- options(scipen = 20, digits = 20)
  on.exit(options(old_options))

  ## Store the map of logical table names to physical shards in the table_shard_map table.
  if (!DBI::dbExistsTable(dbconn, 'table_shard_map')) {
    dbWriteTableUntilSuccess(dbconn, 'table_shard_map', table_shard_map, append = FALSE)
  } else {
    shards <- get_shards_for_table(dbconn, tblname)
    if (length(shards) > 0) {
      table_shard_map <- table_shard_map[table_shard_map$shard_name %nin% shards, ]
    }
    if (NROW(table_shard_map) > 0) {
      DBI::dbWriteTable(dbconn, 'table_shard_map', table_shard_map, append = TRUE, row.names = FALSE)
    }
  }
  TRUE
}

#' Setdiff the given ids against those already in the database table.
#'
#' @param dbconn SQLConnection. The database connection.
#' @param tbl_name character. Database table name.
#' @param ids vector. A vector of ids.
#' @param key character. Identifier of database table.
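#' @examples
#' \dontrun{
#' # Sketch, assuming `con` is an open DBI connection:
#' get_new_key(con, "fetch_users_ab12", ids = 1:10, key = "id")
#' # returns only the ids not yet present in the table's first shard
#' }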
get_new_key <- function(dbconn, tbl_name, ids, key) {
  if (length(ids) == 0) return(integer(0))
  shards <- get_shards_for_table(dbconn, tbl_name)
  ## If there are no existing shards - then nothing is cached yet
  if (length(shards) == 0) return(ids)

  if (!DBI::dbExistsTable(dbconn, shards[1])) return(ids)
  id_column_name <- get_hashed_names(key)
  ## We can check only the first shard because all shards have the same keys
  present_ids <- DBI::dbGetQuery(dbconn, paste0(
    "SELECT ", id_column_name, " FROM ", shards[1]))
  ## If the table is empty, a 0-by-0 dataframe will be returned, so
  ## we must be careful.
  present_ids <- if (NROW(present_ids)) present_ids[[1]] else integer(0)
  setdiff(ids, present_ids)
}

#' Remove old keys to maintain uniqueness of "id" when force pushing
#'
#' @param dbconn SQLConnection. The database connection.
#' @param tbl_name character. Database table name.
#' @param ids vector. A vector of ids.
#' @param key character. Identifier of database table.
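#' @examples
#' \dontrun{
#' # Sketch, assuming `con` is an open DBI connection:
#' remove_old_key(con, "fetch_users_ab12", ids = c(3, 7), key = "id")
#' # deletes rows with those ids from every shard of the table
#' }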
remove_old_key <- function(dbconn, tbl_name, ids, key) {
  if (length(ids) == 0) return(invisible(NULL))
  id_column_name <- get_hashed_names(key)
  shards <- get_shards_for_table(dbconn, tbl_name)
  if (length(shards) == 0) return(invisible(NULL))
  ## In this case though, we need to delete from all shards to keep them consistent
  sapply(shards, function(shard) {
    DBI::dbGetQuery(dbconn, paste0(
      "DELETE FROM ", shard, " WHERE ", id_column_name, " IN (",
      paste(ids, collapse = ","), ")"))
  })
  invisible(NULL)
}