# R/R6_Postgres.R

# R6 wrapper around a DBI/RPostgres connection to a PostgreSQL database on
# localhost. Provides helpers to inspect schemas, run SQL (optionally inside a
# transaction), bulk-load data frames with COPY, build INSERT/UPDATE/DELETE
# statements from R values, and dump the database with pg_dump or COPY.
#
# NOTE(review): validation failures are reported through echo(), which is not
# defined in this file (presumably a package helper that stops when
# abort = TRUE -- confirm).
Postgres <- R6::R6Class(classname = "Postgres", public = list(
  # Database role used for the connection (character scalar).
  user = NA,
  # Name of the database to connect to (character scalar).
  dbname = NA,
  # DBI connection object returned by connect().
  con = NA,

  # Check that RPostgres is available, validate arguments and connect.
  #
  # Arguments:
  #   user   {char} -- database role to connect as
  #   dbname {char} -- database name
  initialize = function(user, dbname) {
    if (!requireNamespace("RPostgres", quietly = TRUE)) {
      stop("Package \"RPostgres\" needed for Postgres class to work. Please install it.",
        call. = FALSE)
    }
    # The original class attached RPostgres here; keep attaching it so code
    # written against that behaviour keeps working. The namespace check above
    # guarantees library() cannot fail.
    library(RPostgres)
    stopifnot(is.character(user), length(user) == 1)
    stopifnot(is.character(dbname), length(dbname) == 1)
    self$user <- user
    self$dbname <- dbname
    self$con <- self$connect()
  },

  # Open and return a new connection to self$dbname as self$user.
  #
  # Arguments:
  #   host {char} -- server host (default "localhost")
  #   port {char} -- server port (default "5432")
  connect = function(host = "localhost", port = "5432") {
    DBI::dbConnect(
      RPostgres::Postgres(),
      host = host, port = port,
      dbname = self$dbname, user = self$user
    )
  },

  # Return a character vector of the names listed in
  # information_schema.tables for `schema`.
  list_tables = function(schema) {
    sql <- sprintf(
      "SELECT table_name
      FROM INFORMATION_SCHEMA.\"tables\"
      WHERE table_schema = %s",
      self$.__handle_single_quote__(schema)
    )
    self$read_sql(sql)[["table_name"]]
  },

  # TRUE if `table` is among list_tables(schema), FALSE otherwise.
  table_exists = function(schema, table) {
    table %in% self$list_tables(schema)
  },

  # Return the column names of schema.table as a character vector, in the
  # table's column order (ordinal_position).
  colnames = function(schema, table) {
    sql <- sprintf(
      "SELECT column_name FROM information_schema.columns
      WHERE table_schema = %s AND table_name = %s
      ORDER BY ordinal_position",
      self$.__handle_single_quote__(schema),
      self$.__handle_single_quote__(table)
    )
    as.character(self$read_sql(sql)[["column_name"]])
  },

  # Column datatypes of schema.table as reported by information_schema.
  #
  # Returns:
  #   column = NULL : data.frame with columns column_name, data_type (one row
  #                   per table column, in ordinal order)
  #   column given  : character vector with the data_type of that column
  #                   (length 0 if the column does not exist)
  coldtypes = function(schema, table, column = NULL) {
    sql <- sprintf(
      "SELECT column_name, data_type
      FROM INFORMATION_SCHEMA.COLUMNS
      WHERE table_schema = %s
      AND table_name = %s
      ORDER BY ordinal_position",
      self$.__handle_single_quote__(schema),
      self$.__handle_single_quote__(table)
    )
    out <- self$read_sql(sql)
    if (!is.null(column)) {
      out <- out[["data_type"]][out[["column_name"]] == column]
    }
    out
  },

  # Run a query and return its result set as a data.frame (DBI::dbGetQuery).
  read_sql = function(sql) {
    DBI::dbGetQuery(self$con, sql)
  },

  # Execute one or more SQL statements.
  #
  # Arguments:
  #   sql         {char}    -- statement(s), executed in order
  #   transaction {logical} -- wrap all statements in BEGIN/COMMIT and roll
  #                            back (re-raising the error) if any fails
  #   logfile     {char}    -- if given, the statements are appended to this
  #                            file (one element per line) after they succeed
  #
  # Returns self invisibly so calls can be chained.
  execute = function(sql, transaction = TRUE, logfile = NULL) {
    stopifnot(is.character(sql))
    run_statements <- function() {
      for (stmt in sql) {
        DBI::dbExecute(self$con, stmt)
      }
    }
    if (transaction) {
      DBI::dbBegin(self$con)
      tryCatch({
        run_statements()
        DBI::dbCommit(self$con)
      }, error = function(e) {
        DBI::dbRollback(self$con)
        stop(e)
      })
    } else {
      run_statements()
    }
    # Log only statements that actually ran (previously the log was written
    # before COMMIT and skipped entirely when transaction = FALSE).
    if (!is.null(logfile)) {
      write.table(sql, file = logfile, append = TRUE, row.names = FALSE,
        col.names = FALSE, quote = FALSE)
    }
    invisible(self)
  },

  # Return every row of schema.table as a data.frame.
  read_table = function(schema, table) {
    self$read_sql(sprintf("SELECT * FROM %s.%s", schema, table))
  },

  # Replace the contents of schema.table with the rows of `df`: TRUNCATE, then
  # COPY from a temporary CSV written to the home directory.
  #
  # NOTE: COPY ... FROM '<file>' is performed by the database server process,
  # so the file must be readable by it (server on this host with sufficient
  # privileges).
  #
  # Returns self invisibly.
  write_table = function(schema, table, df) {
    stopifnot(is.data.frame(df), ncol(df) > 0)
    tmpfile <- path.expand(file.path("~", "Postgres-dbWriteTable-tmp.csv"))
    on.exit(if (file.exists(tmpfile)) file.remove(tmpfile), add = TRUE)

    # na = "" writes missing values as empty unquoted fields, which COPY's CSV
    # format loads as NULL.
    write.table(df, tmpfile, sep = ",", row.names = FALSE, na = "")
    copy_sql <- sprintf(
      "COPY %s.%s (%s)
      FROM %s DELIMITER ',' CSV HEADER",
      schema,
      table,
      paste0(base::colnames(df), collapse = ", "),
      self$.__handle_single_quote__(tmpfile)
    )
    truncate_sql <- sprintf("TRUNCATE TABLE %s.%s;", schema, table)

    # One transaction: a failed COPY must not leave the table truncated.
    self$execute(c(truncate_sql, copy_sql), transaction = TRUE)
  },

  # Check that an R value is compatible with the SQL datatype of a column.
  #
  # Arguments:
  #   schema {char} -- schema name
  #   table  {char} -- table name
  #   col    {char} -- column name
  #   val           -- value to check against schema.table.col's datatype
  #
  # Returns:
  #   logical -- TRUE if `val` has, or can be coerced to, the R type mapped
  #              from the column's datatype. Unknown columns or datatypes
  #              missing from the map are reported via echo(abort = TRUE).
  validate_dtype = function(schema, table, col, val) {
    coltypes <- self$coldtypes(schema, table)
    dtype <- coltypes[["data_type"]][coltypes[["column_name"]] == col]
    if (length(dtype) == 0) {
      echo(sprintf("Column %s.%s.%s does not exist!", schema, table, col), abort = TRUE)
    }

    # PostgreSQL datatype (as named by information_schema or as an alias) ->
    # R type the inserted value should have.
    map <- list(
      "bigint"                      = "numeric",
      "int8"                        = "numeric",
      "bigserial"                   = "numeric",
      "serial8"                     = "numeric",
      "integer"                     = "numeric",
      "int"                         = "numeric",
      "int4"                        = "numeric",
      "smallint"                    = "numeric",
      "int2"                        = "numeric",
      "float"                       = "numeric",
      "float4"                      = "numeric",
      "float8"                      = "numeric",
      "real"                        = "numeric",
      "double precision"            = "numeric",
      "numeric"                     = "numeric",
      "decimal"                     = "numeric",
      "character"                   = "character",
      "char"                        = "character",
      "character varying"           = "character",
      "varchar"                     = "character",
      "text"                        = "character",
      "date"                        = "Date",
      "timestamp"                   = "character",
      "timestamp with time zone"    = "character",
      "timestamp without time zone" = "character",
      "boolean"                     = "logical",
      "bool"                        = "logical"
    )

    r_dtype_ideal <- map[[dtype]]
    if (is.null(r_dtype_ideal)) {
      echo(gsub("\\s+", " ", sprintf("Column %s.%s.%s is datatype '%s' which is not
        in datatype map in class method Postgres.validate_dtype",
        schema, table, col, dtype)), abort = TRUE)
    }

    switch(r_dtype_ideal,
      logical = is.logical(val) ||
        (is.character(val) &&
          all(tolower(val) %in% c("t", "true", "f", "false"))),
      # as.numeric() returns NA (with a warning) instead of erroring on text
      # such as "abc", so an NA result means "not coercible".
      numeric = is.numeric(val) ||
        tryCatch(!anyNA(suppressWarnings(as.numeric(val))),
          error = function(e) FALSE),
      Date = inherits(val, "Date") ||
        tryCatch(!anyNA(as.Date(val)), error = function(e) FALSE),
      # Character-like columns (text, varchar, timestamps).
      tryCatch({
        as.character(val)
        TRUE
      }, error = function(e) FALSE)
    )
  },

  # Construct a SQL UPDATE statement for the row(s) whose primary key equals
  # `pkey_value`. Each value is rendered by .__format_value__():
  #   * NULL, NA, NaN and the strings "nan", "n/a", "null", "" -> NULL
  #   * logical, numeric and "true"/"false" strings           -> unquoted
  #   * Date/POSIXt values                                     -> quoted text
  #   * anything else                                          -> quoted text
  #     (single quotes escaped), including numeric-looking strings like '5'
  #   * The primary key value is quoted only if it is a character value.
  #
  # Arguments:
  #   schema     : char    : schema name
  #   table      : char    : table name
  #   pkey_name  : char    : primary key column name(s)
  #   pkey_value : value   : primary key value(s); several name/value pairs
  #                          are combined with AND
  #   columns    : list    : columns to SET
  #   values     : list    : values to SET (same length as columns)
  #   validate   : logical : check each value against the column's datatype
  #                          in the database via validate_dtype() [OPTIONAL]
  #
  # Returns:
  #   char
  build_update = function(schema, table, pkey_name, pkey_value, columns, values,
                          validate = TRUE) {
    stopifnot(length(columns) == length(values), length(columns) > 0)
    pkey_value <- self$.__handle_single_quote__(pkey_value)

    set_pairs <- vapply(seq_along(columns), function(i) {
      col <- columns[[i]]
      literal <- self$.__format_value__(schema, table, col, values[[i]], validate)
      sprintf("%s=%s", col, literal)
    }, character(1))

    where_clause <- paste(paste(pkey_name, "=", pkey_value), collapse = " AND ")
    sprintf("UPDATE %s.%s SET %s WHERE %s;",
      schema, table, paste(set_pairs, collapse = ", "), where_clause)
  },

  # Construct a SQL INSERT statement. Values are rendered exactly as in
  # build_update() (see .__format_value__()).
  #
  # Arguments:
  #   schema   : char    : schema name
  #   table    : char    : table name
  #   columns  : list    : columns to insert into
  #   values   : list    : values to insert (same length as columns)
  #   validate : logical : check each value against the column's datatype in
  #                        the database via validate_dtype() [OPTIONAL]
  #
  # Returns:
  #   char
  build_insert = function(schema, table, columns, values, validate = FALSE) {
    stopifnot(length(columns) == length(values), length(columns) > 0)

    literals <- vapply(seq_along(columns), function(i) {
      self$.__format_value__(schema, table, columns[[i]], values[[i]], validate)
    }, character(1))

    sprintf("INSERT INTO %s.%s (%s) VALUES (%s);",
      schema, table,
      paste0(unlist(columns), collapse = ", "),
      paste0(literals, collapse = ", "))
  },

  # Construct a SQL DELETE statement for the row whose primary key equals
  # `pkey_value`. Several key name/value pairs are combined with AND (the
  # same WHERE clause build_update() produces). Character key values are
  # quoted and escaped.
  build_delete = function(schema, table, pkey_name, pkey_value) {
    pkey_value <- self$.__handle_single_quote__(pkey_value)
    where_clause <- paste(paste(pkey_name, "=", pkey_value), collapse = " AND ")
    sprintf("DELETE FROM %s.%s WHERE %s;", schema, table, where_clause)
  },

  # Run pg_dump on the connected database, writing <dbname>.sql into
  # `backup_dir_abspath`. Returns the shell exit status (from system()).
  dump = function(backup_dir_abspath) {
    backup_dir_abspath <- path.expand(backup_dir_abspath)
    outfile <- file.path(backup_dir_abspath, paste0(self$dbname, ".sql"))
    cmd <- sprintf("pg_dump %s > %s", shQuote(self$dbname), shQuote(outfile))
    system(cmd)
  },

  # Dump every non-view table outside pg_catalog/information_schema to
  # <backup_dir_abspath>/<schema>.<table>.csv, using `sep` as the delimiter.
  # The files are written by the database server (COPY ... TO), so the
  # directory must be writable by it.
  #
  # If coerce_csv is TRUE and sep is not ",", each file is read back and
  # rewritten in place as a standard comma-separated CSV (write.csv quoting).
  dump_tables = function(backup_dir_abspath, sep, coerce_csv = TRUE) {
    # PL/pgSQL function that COPYs each table to '<path>/<schema.table>.csv'.
    # https://stackoverflow.com/questions/17463299/export-database-into-csv-file?answertab=oldest#tab-top
    db_to_csv <- trimws(sprintf("
      CREATE OR REPLACE FUNCTION db_to_csv(path TEXT) RETURNS void AS $$
      DECLARE
      tables RECORD;
      statement TEXT;
      BEGIN
      FOR tables IN 
      SELECT (table_schema || '.' || table_name) AS schema_table
      FROM information_schema.tables t
      INNER JOIN information_schema.schemata s ON s.schema_name = t.table_schema 
      WHERE t.table_schema NOT IN ('pg_catalog', 'information_schema')
      AND t.table_type NOT IN ('VIEW')
      ORDER BY schema_table
      LOOP
      statement := 'COPY ' || tables.schema_table || ' TO ''' || path || '/' || tables.schema_table || '.csv' ||''' DELIMITER ''%s'' CSV HEADER';
      EXECUTE statement;
      END LOOP;
      RETURN;  
      END;
      $$ LANGUAGE plpgsql;", sep))
    self$execute(db_to_csv)

    backup_dir_abspath <- path.expand(backup_dir_abspath)
    self$execute(sprintf("SELECT db_to_csv(%s)",
      self$.__handle_single_quote__(backup_dir_abspath)))

    if (coerce_csv && sep != ",") {
      # Same table list the PL/pgSQL function iterated over.
      get_dumped_tables <- "
        SELECT (table_schema || '.' || table_name) AS schema_table
        FROM information_schema.tables t
        INNER JOIN information_schema.schemata s ON s.schema_name = t.table_schema 
        WHERE t.table_schema NOT IN ('pg_catalog', 'information_schema')
        AND t.table_type NOT IN ('VIEW')
        ORDER BY schema_table"
      dumped_tables <- self$read_sql(get_dumped_tables)[["schema_table"]]
      csvfiles <- file.path(backup_dir_abspath, paste0(dumped_tables, ".csv"))

      # Rewrite each file in place; absolute paths avoid changing getwd().
      # check.names = FALSE keeps the original column headers untouched.
      for (csvfile in csvfiles) {
        dumped <- read.csv(csvfile, sep = sep, check.names = FALSE)
        write.csv(dumped, file = csvfile, row.names = FALSE)
      }
    }
    invisible(self)
  },

  # Render one R value as a SQL literal for column `col` (helper shared by
  # build_update() and build_insert()).
  #
  # Returns {char}:
  #   "NULL"                for NULL, NA, NaN or "nan"/"n/a"/"null"/""
  #                         (case-insensitive)
  #   the value's text      for logical, numeric, or "true"/"false" strings
  #   quoted, escaped text  otherwise (Date/POSIXt values are format()ed
  #                         first so they are quoted rather than emitted as
  #                         bare 2019-02-08 arithmetic)
  # When `validate` is TRUE, an incompatible value is reported via echo().
  .__format_value__ = function(schema, table, col, val, validate = FALSE) {
    null_eq <- c(NA, "nan", "n/a", "null", "")
    if (is.null(val) || tolower(val) %in% null_eq) {
      return("NULL")
    }

    if (validate && !self$validate_dtype(schema, table, col, val)) {
      col_dtype <- self$coldtypes(schema, table, col)
      echo(gsub("\\s+", " ", sprintf("SQL column %s.%s.%s is type %s
        but R value %s is type %s",
        schema, table, col, crayon::red(col_dtype),
        crayon::red(val), crayon::red(paste(class(val), collapse = ", ")))),
        abort = TRUE)
    }

    if (is.logical(val) || is.numeric(val) ||
        tolower(val) %in% c("true", "false")) {
      return(as.character(val))
    }
    if (inherits(val, c("Date", "POSIXt"))) {
      val <- format(val)
    }
    self$.__handle_single_quote__(as.character(val))
  },

  # For character input, escape embedded single quotes (replacing each with
  # `esc_char`) and wrap the value in single quotes. Non-character input is
  # returned unchanged. Vectorised over `val`.
  .__handle_single_quote__ = function(val, esc_char = "''") {
    if (!is.character(val)) {
      return(val)
    }
    paste0("'", gsub("'", esc_char, val), "'")
  }
))
# Source: tsouchlarakis/rdoni documentation built on Sept. 16, 2019, 8:53 p.m.