R/computeKmeans.R


#' Perform k-means clustering on the table.
#' 
#' K-means clustering algorithm runs in-database, returns an object compatible with \code{\link{kmeans}}, and 
#' includes arbitrary aggregate metrics computed on the resulting clusters.
#' 
#' The function first scales non-null data (if \code{scale=TRUE}) or simply removes rows with \code{NULL}s without scaling. 
#' After that the data (table \code{tableName}, optionally filtered with \code{where}) are clustered by 
#' k-means in Aster. Next, all standard k-means cluster metrics plus the additional aggregates given in
#' \code{aggregates} are calculated, again in-database.
#' 
#' @param channel connection object as returned by \code{\link{odbcConnect}}.
#' @param tableName Aster table name. This argument is ignored if \code{centers} is a canopy object.
#' @param tableInfo pre-built summary of data to use (required when \code{test=TRUE}). See \code{\link{getTableSummary}}.
#' @param id column name or SQL expression containing unique table key. This argument is ignored if \code{centers} 
#'   is a canopy object.
#' @param idAlias SQL alias for table id. This is required when SQL expression is given for \code{id}. 
#'   This argument is ignored if \code{centers} is a canopy object.
#' @param include a vector of column names with variables (must be numeric). The model never contains variables other than those in the list.
#'   This argument is ignored if \code{centers} is a canopy object.
#' @param except a vector of column names to exclude from variables. The model never contains variables from the list.
#'   This argument is ignored if \code{centers} is a canopy object.
#' @param centers either the number of clusters, say \code{k}, a matrix of initial (distinct) cluster centres, 
#'   or an object of class \code{"toacanopy"} obtained with \code{computeCanopy}. 
#'   If a number, a random set of (distinct) rows in the data is chosen as the initial centers. 
#'   If a matrix, then the number of rows determines the number of clusters, as each row is an initial center.
#'   If a canopy object, then the number of centers it contains determines the number of clusters, plus it provides
#'   (and overrides) the following arguments: \code{tableName}, \code{id}, \code{idAlias}, \code{include},
#'   \code{except}, \code{scale}, \code{where}, \code{scaledTableName}, \code{schema}.
#' @param threshold the convergence threshold. When the centroids move by less than this amount, 
#'   the algorithm has converged.
#' @param iterMax the maximum number of iterations the algorithm will run before quitting if the convergence 
#'   threshold has not been met.
#' @param aggregates vector of SQL aggregates that define arbitrary aggregate metrics to be computed on each cluster 
#'   after running k-means. Aggregates may have optional aliases as in \code{"AVG(era) avg_era"}. 
#'   They are subsequently used in \code{\link{createClusterPlot}} as cluster properties.
#' @param scale logical: if \code{TRUE} then scale each variable in-database before clustering. Scaling results in zero mean and unit
#'   standard deviation for each input variable. When \code{FALSE}, the function only removes incomplete
#'   data (containing \code{NULL}s) before clustering. This argument is ignored if \code{centers} is a canopy object.
#' @param persist logical: if \code{TRUE} then the function saves clustered data, with cluster ids assigned, in the 
#'   table \code{clusteredTableName} (when defined). Aster Analytics Foundation 6.20 or earlier 
#'   does not support this option, so \code{persist=FALSE} must be used.
#' @param where specifies criteria the table rows must satisfy before the
#'   computation applies. The criteria are expressed in the form of SQL predicates (inside a
#'   \code{WHERE} clause). This argument is ignored if \code{centers} is a canopy object.
#' @param scaledTableName the name of the Aster table with results of scaling. This argument is ignored if \code{centers} is a canopy object.
#' @param centroidTableName the name of the Aster table with centroids found by kmeans.
#' @param clusteredTableName the name of the Aster table in which to store the clustered output. If omitted 
#'   and \code{persist = TRUE}, a random table name is generated (and always saved in the 
#'   resulting \code{toakmeans} object). If \code{persist = FALSE} then the name is ignored and the
#'   function does not generate a table of clustered output.
#' @param tempTableName name of the temporary Aster table to use to store intermediate results. This table
#'   always gets dropped when the function executes successfully.
#' @param schema name of the Aster schema that tables \code{scaledTableName}, \code{centroidTableName}, and
#'   \code{clusteredTableName} belong to. Make sure that when this argument is supplied no table name already
#'   contains a schema in its name.
#' @param test logical: if TRUE show what would be done, only (similar to parameter \code{test} in \pkg{RODBC} 
#'   functions: \link{sqlQuery} and \link{sqlSave}).
#' @param version version of Aster Analytics Foundation functions applicable when \code{test=TRUE}, ignored otherwise.
#' @return \code{computeKmeans} returns an object of class \code{"toakmeans"} (compatible with class \code{"kmeans"}).
#' It is a list with at least the following components:
#' \describe{
#'   \item{\code{cluster}}{A vector of integers (from 0:K-1) indicating the cluster to which each point is allocated. 
#'     \code{\link{computeKmeans}} leaves this component empty. Use function \code{\link{computeClusterSample}} to set this component.}
#'   \item{\code{centers}}{A matrix of cluster centres.}
#'   \item{\code{totss}}{The total sum of squares.}
#'   \item{\code{withinss}}{Vector of within-cluster sum of squares, one component per cluster.}
#'   \item{\code{tot.withinss}}{Total within-cluster sum of squares, i.e. \code{sum(withinss)}.}
#'   \item{\code{betweenss}}{The between-cluster sum of squares, i.e. \code{totss-tot.withinss}.}
#'   \item{\code{size}}{The number of points in each cluster. This includes all points in the specified Aster table that 
#'     satisfy the optional \code{where} condition.}
#'   \item{\code{iter}}{The number of (outer) iterations.}
#'   \item{\code{ifault}}{integer: indicator of a possible algorithm problem (always 0).}
#'   \item{\code{scale}}{logical: indicates if variable scaling was performed before clustering.}
#'   \item{\code{persist}}{logical: indicates if clustered data was saved in the table.}
#'   \item{\code{aggregates}}{A data frame of aggregates computed on each cluster.}
#'   \item{\code{tableName}}{Aster table name containing data for clustering.}
#'   \item{\code{columns}}{Vector of column names with variables used for clustering.}
#'   \item{\code{scaledTableName}}{Aster table containing scaled data for clustering.}
#'   \item{\code{centroidTableName}}{Aster table containing cluster centroids.}
#'   \item{\code{clusteredTableName}}{Aster table containing clustered output.}
#'   \item{\code{id}}{Column name or SQL expression containing unique table key.}
#'   \item{\code{idAlias}}{SQL alias for table id.}
#'   \item{\code{whereClause}}{SQL \code{WHERE} clause expression used (if any).}
#'   \item{\code{time}}{An object of class \code{proc_time} with user, system, and total elapsed times
#'     for the \code{computeKmeans} function call.} 
#' }
#' 
#' @export
#' @seealso \code{\link{computeClusterSample}}, \code{\link{computeSilhouette}}, \code{\link{computeCanopy}}
#' @examples 
#' if(interactive()){
#' # initialize connection to Lahman baseball database in Aster 
#' conn = odbcDriverConnect(connection="driver={Aster ODBC Driver};
#'                          server=<dbhost>;port=2406;database=<dbname>;uid=<user>;pwd=<pw>")
#'                          
#' km = computeKmeans(conn, "batting", centers=5, iterMax = 25,
#'                    aggregates = c("COUNT(*) cnt", "AVG(g) avg_g", "AVG(r) avg_r", "AVG(h) avg_h"),
#'                    id="playerid || '-' || stint || '-' || teamid || '-' || yearid", 
#'                    include=c('g','r','h'), scaledTableName='kmeans_test_scaled', 
#'                    centroidTableName='kmeans_test_centroids',
#'                    where="yearid > 2000")
#' km
#' createCentroidPlot(km)
#' createClusterPlot(km)
#' 
#' # persist clustered data
#' kmc = computeKmeans(conn, "batting", centers=5, iterMax = 250,
#'                    aggregates = c("COUNT(*) cnt", "AVG(g) avg_g", "AVG(r) avg_r", "AVG(h) avg_h"),
#'                    id="playerid || '-' || stint || '-' || teamid || '-' || yearid", 
#'                    include=c('g','r','h'), 
#'                    persist = TRUE, 
#'                    scaledTableName='kmeans_test_scaled', 
#'                    centroidTableName='kmeans_test_centroids', 
#'                    clusteredTableName = 'kmeans_test_clustered',
#'                    tempTableName = 'kmeans_test_temp',
#'                    where="yearid > 2000")
#' createCentroidPlot(kmc)
#' createCentroidPlot(kmc, format="bar_dodge")
#' createCentroidPlot(kmc, format="heatmap", coordFlip=TRUE)
#' 
#' createClusterPlot(kmc)
#' 
#' kmc = computeClusterSample(conn, kmc, 0.01)
#' createClusterPairsPlot(kmc, title="Batters Clustered by G, H, R", ticks=FALSE)
#' 
#' kmc = computeSilhouette(conn, kmc)
#' createSilhouetteProfile(kmc, title="Cluster Silhouette Histograms (Profiles)")
#' 
#' }
computeKmeans <- function(channel, tableName, centers, threshold=0.0395, iterMax=10, 
                          tableInfo, id, include=NULL, except=NULL, 
                          aggregates="COUNT(*) cnt", scale=TRUE, persist=FALSE,
                          idAlias=gsub("[^0-9a-zA-Z]+", "_", id), where=NULL,
                          scaledTableName=NULL, centroidTableName=NULL, 
                          clusteredTableName=NULL, tempTableName=NULL,
                          schema=NULL, test=FALSE, version="6.21") {
  
  ptm = proc.time()
  
  if (test && missing(tableInfo)) {
    stop("Must provide tableInfo when test==TRUE")
  }
  
  if (persist && compareVersion(version, "6.21") < 0)
    stop("Persisting clustered data with versions before AAF 6.21 is not supported.")
  
  isValidConnection(channel, test)
  
  # validate centers (initial clusters)
  useCanopy = FALSE
  if (is.matrix(centers)) 
    K = nrow(centers)
  else if (is.numeric(centers)) 
    K = as.integer(centers)
  else if (inherits(centers, 'toacanopy')) {
    useCanopy = TRUE
    canopy = centers
    centers = canopy$centers
    K = nrow(canopy$centers)
  } else 
    stop("Parameter centers must be one of following: number of clusters, numeric matrix of initial centroids, or canopy object.")
  
  if (K < 1) 
    stop("Number of clusters must be greater or equal to 1.")
  
  if (useCanopy) {
    tableName=canopy$tableName
    columns=canopy$columns
    id=canopy$id
    idAlias=canopy$idAlias
    scale=canopy$scale
    scaledTableName=canopy$scaledTableName
    if (is.null(schema))
      schema=canopy$schema
    where_clause=canopy$whereClause
  }else {
    tableName = normalizeTableName(tableName)
    
    if (missing(tableInfo)) {
      tableInfo = sqlColumns(channel, tableName)
    }
    
    columns = getNumericColumns(tableInfo, names.only=TRUE, include=include, except=except)
    columns = sort(setdiff(columns, id))
    
    if (is.null(columns) || length(columns) < 1) {
      stop("Kmeans operates on one or more numeric variables.")
    }
    
    # check if id alias is not one of independent variables
    if(idAlias %in% columns)
      stop(paste0("Id alias '", idAlias, "' can't be one of variable names."))
    
    # adjust id alias if it's exactly one of the table columns
    if(idAlias %in% tableInfo$COLUMN_NAME)
      idAlias = paste("_", idAlias, "_", sep="_")
    
    # scale data table name
    if (is.null(scaledTableName))
      scaledTableName = makeTempTableName('scaled', 30, schema)
    else if (!is.null(schema))
      scaledTableName = paste0(schema, ".", scaledTableName)
    
    where_clause = makeWhereClause(where)
  }
  
  if (is.matrix(centers)) {
    if (length(columns) != ncol(centers))
      stop(paste0("Kmeans received incompatible parameters: dimension of initial cluster centers doesn't match variables: '", 
                  paste0(columns, collapse = "', '"), "'"))
  }
  
  aggregates = makeAggregatesAlwaysContainCount(aggregates)
  
  # centroids table name
  if(is.null(centroidTableName))
    centroidTableName = makeTempTableName('centroids', 30, schema)
  else if (!is.null(schema))
    centroidTableName = paste0(schema, ".", centroidTableName)
  
  # clustered data table name
  if(is.null(clusteredTableName) && persist) {
    clusteredTableName = makeTempTableName('clustered', 30, schema)
  }else if (!is.null(schema) && persist) {
    clusteredTableName = paste0(schema, ".", clusteredTableName)
  }else if(!persist) {
    clusteredTableName = NULL
  }
  
  if(is.null(tempTableName))
    tempTableName = makeTempTableName('clustered_remove', 30, schema)
  else if (!is.null(schema))
    tempTableName = paste0(schema, ".", tempTableName)
    
  emptyLine = "--"
  
  if(test)
    sqlText = ""
  
  if (!useCanopy) {
    # scale data or just eliminate incomplete observations (if not scaling)
    sqlComment = paste("-- Data Prep:", ifelse(scale, "scale", "omit nulls"))
    sqlDrop = paste("DROP TABLE IF EXISTS", scaledTableName)
    sql = getDataPrepSql(scale, tableName, scaledTableName, columns, id, idAlias, where_clause)
    if(test) {
      sqlText = paste(sqlComment, sqlDrop, sep='\n')
      sqlText = paste(sqlText, sql, sep=';\n')
    }else {
      toaSqlQuery(channel, sqlDrop)
      toaSqlQuery(channel, sql)
    }
  }
  
  # run kmeans
  sqlComment = "-- Run k-means"
  sqlDrop = paste("DROP TABLE IF EXISTS", centroidTableName)
  if (persist)
    sqlDrop2 = paste("DROP TABLE IF EXISTS", tempTableName)
  sql = getKmeansSql(persist, 
                     scaledTableName, centroidTableName, tempTableName, centers, threshold, iterMax)
  if(test) {
    sqlText = paste(sqlText, 
                    emptyLine, 
                    sqlComment, 
                    paste0(sqlDrop, ifelse(persist, paste0(';\n', sqlDrop2), "")),
                    sql, sep=';\n')
    kmeans_version = version
  }else {
    toaSqlQuery(channel, sqlDrop)
    if (persist)
      toaSqlQuery(channel, sqlDrop2)
    kmeansResultStr = toaSqlQuery(channel, sql, stringsAsFactors=FALSE)
    # kmeans output prior to AAF 6.21
    if ('message' %in% names(kmeansResultStr)) {
      warning("This version of kmeans is no longer supported. Please, upgrade to AAF 6.21 or higher, otherwise, running at your own risk.")
      kmeans_version = "6.20"
      if (kmeansResultStr[2,'message'] == "Successful!" &&
          kmeansResultStr[3,'message'] == "Algorithm converged.") {
        iter = as.integer(gsub("[^0-9]", "", kmeansResultStr[[4,'message']]))
      }else {
        msg = paste(kmeansResultStr[,'message'], collapse="\n")
        stop(msg)
      }  
    # kmeans output since AAF 6.21
    }else {
      kmeans_version = "6.21"
      line = K + 2
      iter = as.integer(gsub("[^0-9]", "", kmeansResultStr[[line+1, 2]]))
      if (kmeansResultStr[[line, 2]] == "Converged : False") {
        msg = paste("Kmeans failed to converge after", iter, "iterations.")
        stop(msg)
      }
    }
  }
  
  # persist clustered data
  if(persist){
    if(compareVersion(kmeans_version, "6.20") <= 0) {
      sqlComment = "-- Cluster and persist data"
      sqlDrop = paste("DROP TABLE IF EXISTS", clusteredTableName)
      sqlKmeansScore = getKmeansPlotSql(idAlias, scaledTableName, centroidTableName, clusteredTableName)
      if(test) {
        sqlText = paste(sqlText, 
                        emptyLine,
                        sqlComment,
                        sqlDrop,
                        sqlKmeansScore, sep=';\n')
      }else {
        toaSqlQuery(channel, sqlDrop)
        toaSqlQuery(channel, sqlKmeansScore)
      }
    }else {
      sqlComment = "-- Combine clustered ids with data"
      sqlDrop = paste("DROP TABLE IF EXISTS", clusteredTableName)
      sqlTempDrop = paste("DROP TABLE IF EXISTS", tempTableName)
      sqlKmeansClusteredData = paste0(
        "CREATE FACT TABLE ", clusteredTableName, " DISTRIBUTE BY HASH(", idAlias, ") AS 
         SELECT d.*, c.clusterid 
           FROM ",tempTableName," c JOIN 
                ",scaledTableName," d ON (c.",idAlias," = d.",idAlias,")")
      if(test) {
        sqlText = paste(sqlText,
                        emptyLine,
                        sqlComment,
                        sqlDrop,
                        sqlKmeansClusteredData,
                        sqlTempDrop, sep=';\n')
      }else {
        toaSqlQuery(channel, sqlDrop)
        toaSqlQuery(channel, sqlKmeansClusteredData)
        toaSqlQuery(channel, sqlTempDrop)
      }
    }
  } 
  
  # compute cluster stats
  sqlComment = "-- Run cluster assignment, cluster stats, and within-cluster sum of squares"
  sql = getKmeansStatsSql(persist, tableName, scaledTableName, centroidTableName, clusteredTableName,
                          columns,  K, id, idAlias, aggregates, where_clause, kmeans_version)
  if(test)
    sqlText = paste(sqlText, emptyLine, sqlComment, sql, sep=';\n')
  else
    kmeansstats = toaSqlQuery(channel, sql, stringsAsFactors=FALSE)
  
  
  # compute total sum of squares
  sqlComment = "-- Compute Total Sum of Squares"
  sql = getTotalSumOfSquaresSql(scaledTableName, columns, idAlias, scale)
  
  if(test)
    sqlText = paste(sqlText, emptyLine, sqlComment, sql, sep=';\n')
  else {
    rs = toaSqlQuery(channel, sql)
    totss = rs$totss[[1]]
  }

  # return sql
  if(test) {
    sqlText = paste0(sqlText, ';')
    return(sqlText)
  }
  
  result = makeKmeansResult(kmeansstats, K, totss, iter, tableName, columns, scale,
                            scaledTableName, centroidTableName, clusteredTableName, id, idAlias, 
                            persist, where_clause, kmeansResultStr, ptm, kmeans_version)
  
  return(result)
}


# Phase 1: Data Prep
getDataPrepSql <- function(scale, tableName, tempTableName, columns, id, idAlias, whereClause) {
  
  # use if/else (not vectorized ifelse) so only the needed SQL is built
  if (scale)
    dataPrepSql = getDataScaledSql(tableName, columns, id, idAlias, whereClause)
  else
    dataPrepSql = getDataNoNullsSql(tableName, columns, id, idAlias, whereClause)
  
  paste0(
    "CREATE FACT TABLE ", tempTableName, " DISTRIBUTE BY HASH(", idAlias, ") AS 
       ", dataPrepSql
  )
}

getDataScaledSql <- function(tableName, columns, id, idAlias, whereClause) {
  
  sqlmr_column_list = makeSqlMrColumnList(columns)
  query_as_table = getDataSql(tableName, columns, id, idAlias, whereClause)
  
  scaleMapSql = paste0(
    "SELECT * FROM ScaleMap (
       ON (", query_as_table, ")
    InputColumns (", sqlmr_column_list, ")
    -- MissValue ('OMIT')
    )"
  )
  
  scaleSql = paste0(
    "SELECT * FROM Scale(
       ON (", query_as_table, ") AS input PARTITION BY ANY
       ON (", scaleMapSql, ") AS STATISTIC DIMENSION
       Method ('STD')
       Accumulate('", idAlias, "')
       GlobalScale ('false')
       InputColumns (", sqlmr_column_list, ")
     )"
  )
  
  return(scaleSql)
}

getDataNoNullsSql <- function(tableName, columns, id, idAlias, whereClause) {
  
  not_null_clause = paste0(c(idAlias, columns), " IS NOT NULL", collapse=" AND ")
  query_as_table = getDataSql(tableName, columns, id, idAlias, whereClause)
  
  noNullsSql = paste0(
    "SELECT * FROM (", query_as_table, ") d
      WHERE ", not_null_clause)
  
  return(noNullsSql)
}
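
# For illustration (hypothetical names; the exact output of makeSqlColumnList
# may differ slightly), a call along the lines of
#   getDataNoNullsSql("batting", c("g","h"), "playerid", "playerid", " WHERE yearid > 2000")
# renders roughly:
#
#   SELECT * FROM (SELECT playerid playerid, g, h FROM batting WHERE yearid > 2000) d
#     WHERE playerid IS NOT NULL AND g IS NOT NULL AND h IS NOT NULL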

# Phase: kmeans 
getKmeansSql <- function(persist, scaledTableName, centroidTableName, clusteredTableName, centers, threshold, maxiternum) {
  
  if (is.matrix(centers)) {
    initCenters = paste0("MEANS(", paste0("'", paste0(apply(centers, 1, paste0, collapse='_'), collapse="', '"), "'"), ")")
  }else
    initCenters = paste0("NUMBERK('", centers, "')")
  
  kmeansSql = paste0(
    "SELECT * FROM kmeans(
      ON (SELECT 1)
      PARTITION BY 1
      INPUTTABLE('", scaledTableName, "')
      OUTPUTTABLE('", centroidTableName, "')
", ifelse(persist, paste0("
      ClusteredOutput('",clusteredTableName,"')"), ""), "
   ", initCenters, "
      THRESHOLD('", threshold, "')
      MAXITERNUM('", maxiternum, "')
    )")
}
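
# A minimal sketch of the SQL this renders (not executed; table names are
# hypothetical). A call such as
#   getKmeansSql(FALSE, "km_scaled", "km_centroids", NULL, 3, 0.0395, 10)
# produces approximately:
#
#   SELECT * FROM kmeans(
#     ON (SELECT 1)
#     PARTITION BY 1
#     INPUTTABLE('km_scaled')
#     OUTPUTTABLE('km_centroids')
#     NUMBERK('3')
#     THRESHOLD('0.0395')
#     MAXITERNUM('10')
#   )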

getKmeansPlotSql <- function(idAlias, scaledTableName, centroidTableName, clusteredTableName) {
  
  kmeansPlotSql = paste0(
    "CREATE FACT TABLE ", clusteredTableName, " DISTRIBUTE BY HASH(", idAlias, ") AS SELECT * FROM KMeansPlot(
       ON ", scaledTableName, " PARTITION BY ANY
       ON ", centroidTableName, " DIMENSION
       CentroidsTable('", centroidTableName, "')
       PrintDistance('true')
     )"
  )
}

getKmeansStatsSql <- function(persist, tableName, scaledTableName, centroidTableName, clusteredTableName,
                              columns, K, id, idAlias, aggregates, whereClause, kmeans_version) {
  
  if(compareVersion(kmeans_version, "6.20") <= 0) 
    means_column_name = "means"
  else 
    means_column_name = paste0('"', paste0(columns, collapse = ' '), '"')

  clustersWithValuesSql = paste0(
    "SELECT c1.*, c2.withinss  
       FROM (SELECT clusterid, means, ", paste(aggregates, collapse=", "), 
    "          FROM (", getClusteredDataSql(persist, tableName, scaledTableName, centroidTableName, 
                                            clusteredTableName, means_column_name, id, idAlias, whereClause), "
                    ) clustered_data
              GROUP BY clusterid, means
            ) c1 JOIN ( 
            ", paste(sapply(1:K, FUN=getClusterSumOfSquaresSql, persist, scaledTableName, centroidTableName, 
                                                                clusteredTableName, columns, means_column_name, idAlias),
                     collapse="\nUNION ALL\n"),
    "
            ) c2 ON (c1.clusterid = c2.clusterid)
      ORDER BY clusterid"
  )
}

getDataSql <- function(tableName, columns, id, idAlias, whereClause) {
  
  paste0("SELECT ", id, " ", idAlias, ", ", makeSqlColumnList(columns), " FROM ", tableName, whereClause)
}


getTableDataSql <- function(tableName, id, idAlias, whereClause) {
  
  paste0("SELECT ", id, " ", idAlias, ", * FROM ", tableName, whereClause)
}


getClusteredDataSql <- function(persist, tableName, scaledTableName, centroidTableName, 
                                clusteredTableName, means_column_name, id, idAlias, whereClause) {
  
  query_as_table = getTableDataSql(tableName, id, idAlias, whereClause)
  
  paste0(
    "SELECT c.clusterid, c.",means_column_name," means, d.* 
      FROM ", centroidTableName, " c JOIN 
    ",ifelse(persist, clusteredTableName, paste0(
    "kmeansplot (
      ON ", scaledTableName, " PARTITION BY ANY
      ON ", centroidTableName, " DIMENSION
      centroidsTable('",centroidTableName,"')
    )"))," kmp ON (c.clusterid = kmp.clusterid) JOIN 
    (", query_as_table, ") d on (kmp.", idAlias, " = d.", idAlias, ")"
  )
}


getClusterSumOfSquaresSql <- function(clusterid, persist, scaledTableName, centroidTableName, clusteredTableName,
                                      columns, means_column_name, idAlias) {
  
  clusterid = as.character(clusterid - 1)
  
  sql = paste0(
    "SELECT ", clusterid, " clusterid, SUM(distance::double ^ 2) withinss FROM ", 
    getUnpivotedClusterSql(clusterid, persist, scaledTableName, centroidTableName, clusteredTableName, 
                           columns, means_column_name, idAlias)
  )
}
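
# The query above sums squared Euclidean distances from each point in a
# cluster to its centroid. A local R equivalent, assuming xk is a numeric
# matrix of one cluster's points and ck its centroid vector (both hypothetical):
#
#   withinss_k <- sum(sweep(xk, 2, ck)^2)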


getUnpivotedClusterSql <- function(clusterid, persist, scaledTableName, centroidTableName, clusteredTableName, 
                                   columns, means_column_name, idAlias) {
  
  sql_column_list = makeSqlColumnList(columns)
  sqlmr_column_list = makeSqlMrColumnList(columns)
  
  sql = paste0(
    
    "VectorDistance(
       ON (
         SELECT clusterid, ", idAlias, ", variable, coalesce(value_double, value_long, value_str::double) value
           FROM unpivot(
                  ON (SELECT d.* 
                        FROM ", ifelse(persist, clusteredTableName, paste0(
                            "kmeansplot (
                               ON ", scaledTableName, " PARTITION BY ANY
                               ON ", centroidTableName, " DIMENSION
                               centroidsTable('",centroidTableName,"')
                             )"))," d 
                       WHERE clusterid = ", clusterid, "
                  )
                  COLSTOUNPIVOT(", sqlmr_column_list, ")
                  COLSTOACCUMULATE('",idAlias,"','clusterid')
                  ATTRIBUTECOLUMNNAME('variable')
                  VALUECOLUMNNAME('value')
                  KEEPINPUTCOLUMNTYPES('true')
                )
       ) AS target PARTITION BY ", idAlias, "
       ON (
         ", getCentroidTableSql(centroidTableName, sql_column_list, means_column_name, clusterid), "
       ) AS ref DIMENSION
       TARGETIDCOLUMNS('",idAlias,"')
       TARGETFEATURECOLUMN('variable')
       TARGETVALUECOLUMN('value')
       REFIDCOLUMNS('clusterid')
       REFFEATURECOLUMN('variable')
       REFVALUECOLUMN('value')
       MEASURE('Euclidean')
     )"
  )
  
}


getCentroidTableSql <- function(centroidTableName, sql_column_list, means_column_name, clusterid = NULL) {
  
  whereClause = ifelse(is.null(clusterid), 
                       "", 
                       paste0(" WHERE clusterid = ", clusterid))

  paste0(
    "SELECT *, regexp_split_to_table(",means_column_name,", ' ')::numeric value, regexp_split_to_table('", sql_column_list, "', ', ') variable 
           FROM ", centroidTableName, whereClause
    
  )
}


getTotalSumOfSquaresSql <- function(scaledTableName, columns, idAlias, scale) {
  
  sqlmr_column_list = makeSqlMrColumnList(columns)
  
  if (scale) {
    sqlagg_column_list = paste0("0.0::double ", columns, collapse = ", ")
    globalCenterSql = paste0("SELECT 1 id, ", sqlagg_column_list)
  }else {
    sqlagg_column_list = makeSqlAggregateColumnList(columns, "avg", FALSE, cast="::double")
    globalCenterSql = paste0("SELECT 1 id, ", sqlagg_column_list, " FROM ", scaledTableName)
  }

  sql = paste0(
    "SELECT SUM(distance::double ^ 2) totss FROM VectorDistance(
       ON (SELECT ", idAlias, ", variable, coalesce(value_double, value_long, value_str::double) value
             FROM unpivot(
               ON ", scaledTableName , "
               COLSTOUNPIVOT(", sqlmr_column_list, ")
               COLSTOACCUMULATE('",idAlias,"')
               ATTRIBUTECOLUMNNAME('variable')
               VALUECOLUMNNAME('value')
               KEEPINPUTCOLUMNTYPES('true')
             ) 
       ) AS target PARTITION BY ", idAlias, "
       ON (SELECT id, variable, value_double
             FROM unpivot(
               ON (", globalCenterSql, ")
               COLSTOUNPIVOT(", sqlmr_column_list, ")
               COLSTOACCUMULATE('id')
               ATTRIBUTECOLUMNNAME('variable')
               VALUECOLUMNNAME('value')
               KEEPINPUTCOLUMNTYPES('true')
             )
       ) AS ref DIMENSION
       TARGETIDCOLUMNS('",idAlias,"')
       TARGETFEATURECOLUMN('variable')
       TARGETVALUECOLUMN('value')
       REFIDCOLUMNS('id')
       REFFEATURECOLUMN('variable')
       REFVALUECOLUMN('value_double')
       MEASURE('Euclidean')
     )"
  )
}
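
# The query above computes the total sum of squared distances from each point
# to the global center: the origin for scaled data, the column means otherwise.
# A minimal local sketch, assuming x is a numeric matrix of the variables:
#
#   totss_local <- function(x, scaled = TRUE) {
#     center = if (scaled) rep(0, ncol(x)) else colMeans(x)
#     sum(sweep(x, 2, center)^2)
#   }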


makeKmeansResult <- function(data, K, totss, iter, tableName, columns, scale,
                             scaledTableName, centroidTableName, clusteredTableName, id, idAlias, 
                             persist, whereClause, output, ptm, kmeans_version) {
  
  if (compareVersion(kmeans_version, "6.20") <= 0) {
    centers_str = data$means
    tot_withinss = sum(data$withinss)
    withinss = data$withinss
    sizes = data$cnt  # 'cnt' aggregate is guaranteed by makeAggregatesAlwaysContainCount
  }else {
    if (persist) {
      delta = 1
    }else {
      clusteredTableName = NULL
      delta = 0
    }
    centers_str = output[1:K, 2]
    tot_withinss = as.numeric(gsub("[^0-9\\.]", "", output[K + 6 + delta, 2]))
    withinss = as.numeric(output$withinss[1:K])
    sizes = as.integer(output$size[1:K])
  }
  
  centers = matrix(as.numeric(unlist(strsplit(centers_str, split = " "))), 
                   ncol=length(columns), nrow=K, byrow=TRUE)
  colnames(centers) = columns
  rownames(centers) = data$clusterid
  
  aggregates = data[, -2]
  
  z <- structure(list(cluster=integer(0),
                      centers=centers,
                      totss=totss,
                      withinss = withinss,
                      tot.withinss = tot_withinss,
                      betweenss = totss - tot_withinss,
                      size = sizes,
                      iter=iter,
                      ifault = 0,
                      
                      scale=scale,
                      persist=persist,
                      aggregates=aggregates,
                      tableName=tableName,
                      columns=columns,
                      scaledTableName=scaledTableName,
                      centroidTableName=centroidTableName,
                      clusteredTableName=clusteredTableName,
                      id=id,
                      idAlias=idAlias,
                      whereClause=whereClause,
                      version=kmeans_version,
                      output=output,
                      time=proc.time() - ptm
  ),
  class = c("toakmeans", "kmeans"))
  
  return (z)
}


makeAggregatesAlwaysContainCount <- function(aggregates){
  
  # empty or NULL list 
  if (length(aggregates) == 0)
    return("COUNT(*) cnt")
 
  # parse aggregates into tuples of function and alias
  aggFun = unlist(sapply(strsplit(aggregates, '[[:space:]]'), 
                 FUN=function(x) {
                   # all tokens but the last (the alias)
                   s = paste0(x[-length(x)], collapse = ' ')
                   if (nchar(s)==0) NULL else s
                 }))
 
  aggAlias = unlist(sapply(strsplit(aggregates, '[[:space:]]'), 
                   FUN=function(x) {
                     s = x[length(x)]
                     if (nchar(s)==0) NULL else s
                  }))
 
  # detect missing alias
  if (length(aggFun) != length(aggAlias))
   stop("Check aggregates: at least one missing alias found.")
 
  # detect if 'COUNT(*) cnt' present
  missingCount = TRUE
  for(i in 1:length(aggregates)) {
    fun = aggFun[[i]]
    alias = aggAlias[[i]]
   
    if(tolower(fun) == 'count(*)' && alias == 'cnt')
      missingCount = FALSE
 }
 
  # form final aggregates
  aggregates = paste(aggFun, aggAlias)
  if (missingCount)
    aggregates = c(aggregates, "COUNT(*) cnt")
 
  return(aggregates)
}
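
# Quick sanity checks of the behavior above (hypothetical inputs):
#   makeAggregatesAlwaysContainCount(c("AVG(g) avg_g"))
#   ## c("AVG(g) avg_g", "COUNT(*) cnt")   -- count aggregate appended
#   makeAggregatesAlwaysContainCount(c("COUNT(*) cnt", "MAX(h) max_h"))
#   ## unchanged: 'COUNT(*) cnt' already present
#   makeAggregatesAlwaysContainCount(NULL)
#   ## "COUNT(*) cnt"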


#' Random sample of clustered data
#' 
#' @param channel connection object as returned by \code{\link{odbcConnect}}.
#' @param km an object of class \code{"toakmeans"} obtained with \code{computeKmeans}.
#' @param sampleFraction vector with one or more sample fractions to use in the sampling of data.
#'   Multiple fractions define sampling for each cluster in kmeans \code{km} object where
#'   vector length must be equal to the number of clusters.
#' @param sampleSize vector with sample size (applies only when \code{sampleFraction} is missing).
#'   Multiple sizes define sampling for each cluster in kmeans \code{km} object where
#'   vector length must be equal to the number of clusters.
#' @param scaled logical: indicates if original (default) or scaled data is returned.
#' @param includeId logical: indicates if the sample should include the key attribute identifying
#'   each data point.
#' @param test logical: if TRUE show what would be done, only (similar to parameter \code{test} in \pkg{RODBC} 
#'   functions: \link{sqlQuery} and \link{sqlSave}).
#' @return \code{computeClusterSample} returns an object of class \code{"toakmeans"} (compatible with class \code{"kmeans"}).
#' @seealso \code{\link{computeKmeans}}
#' 
#' @export
#' @examples 
#' if(interactive()){
#' # initialize connection to Lahman baseball database in Aster 
#' conn = odbcDriverConnect(connection="driver={Aster ODBC Driver};
#'                          server=<dbhost>;port=2406;database=<dbname>;uid=<user>;pwd=<pw>")
#'                          
#' km = computeKmeans(conn, "batting", centers=5, iterMax = 25,
#'                    aggregates = c("COUNT(*) cnt", "AVG(g) avg_g", "AVG(r) avg_r", "AVG(h) avg_h"),
#'                    id="playerid || '-' || stint || '-' || teamid || '-' || yearid", 
#'                    include=c('g','r','h'), scaledTableName='kmeans_test_scaled', 
#'                    centroidTableName='kmeans_test_centroids',
#'                    where="yearid > 2000")
#' km = computeClusterSample(conn, km, 0.01)
#' km
#' createClusterPairsPlot(km, title="Batters Clustered by G, H, R", ticks=FALSE)
#' 
#' # per cluster sample fractions
#' km = computeClusterSample(conn, km, c(0.01, 0.02, 0.03, 0.02, 0.01))
#' }
computeClusterSample <- function(channel, km, sampleFraction, sampleSize, scaled=FALSE, includeId=TRUE, test=FALSE) {
  
  isValidConnection(channel, test)
  
  if (missing(km) || !is.object(km) || !inherits(km, "toakmeans")) {
    stop("Kmeans object must be specified.")
  }
  
  if ((missing(sampleFraction) || is.null(sampleFraction)) && 
      (missing(sampleSize) || is.null(sampleSize))) {
    stop("Sample fraction or sample size must be specified.")
  }
  
  # validate sample fraction
  if (!missing(sampleFraction) && !is.null(sampleFraction)) {
    sampleFraction = as.numeric(sampleFraction)
    
    # using sample fraction
    if (!all(sampleFraction >= 0) || !all(sampleFraction <= 1))
      stop("All sample fractions must be between 0 and 1 inclusively.")
    
    if (length(sampleFraction) > 1 && length(sampleFraction) != nrow(km$centers))
      stop("Fraction vector length must be either 1 or equal to the number of clusters.")
  }else {
    sampleSize = as.integer(sampleSize)
    
    if (!all(sampleSize >= 0.0))
      stop("All sample sizes must be equal to or greater than 0.")
    
    if (length(sampleSize) > 1 && length(sampleSize) != nrow(km$centers))
      stop("Size vector length must be either 1 or equal to the number of clusters.")
  }
  
  persist = ifelse(is.null(km$persist), FALSE, km$persist)
  table_name = km$tableName
  columns = km$columns
  scaled_table_name = km$scaledTableName
  centroid_table_name = km$centroidTableName
  clustered_table_name = km$clusteredTableName
  id = km$id
  idAlias = km$idAlias
  where_clause = km$whereClause
  centers = nrow(km$centers)
  
  conditionOnSql = paste0("'", paste0(0:(centers-1), collapse="','"), "'")
  query_as_table = getDataSql(table_name, columns, id, idAlias, where_clause)
  
  
  if (!missing(sampleFraction) && !is.null(sampleFraction)) {
    
    fractionStr = paste0("'", as.character(sampleFraction), "'", collapse = ",")

    sql = paste0(
      "SELECT * FROM sample(
             ON (", getKmeansplotDataSql(scaled_table_name, centroid_table_name, clustered_table_name, scaled, persist, query_as_table, idAlias), "
             )
             CONDITIONONCOLUMN('clusterid')
             CONDITIONON(",conditionOnSql,")
             SAMPLEFRACTION(", fractionStr, ")
       )")
  }else {
    # using sample size
    
    sizeStr = paste0("'", as.character(sampleSize), "'", collapse = ",")
    
    sql = paste0(
      "WITH stratum_counts AS (
         SELECT clusterid stratum, count(*) stratum_count 
           FROM ", ifelse(persist, clustered_table_name, paste0(
                "kmeansplot(
             ON ", scaled_table_name, " PARTITION BY ANY
             ON ", centroid_table_name, " DIMENSION
             centroidsTable('",centroid_table_name,"')
           )"))," 
          WHERE clusterid != -1
         GROUP BY 1
       )
       SELECT * FROM sample (
         ON (", getKmeansplotDataSql(scaled_table_name, centroid_table_name, clustered_table_name, scaled, persist, query_as_table, idAlias), "
            ) AS data PARTITION BY ANY
         ON stratum_counts AS summary DIMENSION
         CONDITIONONCOLUMN('clusterid')
         CONDITIONON(",conditionOnSql,")
         ApproximateSampleSize(", sizeStr, ")
      )"
    )
  }
  
  if (!includeId) {
    sql = paste0(
      "SELECT * FROM antiselect(
         ON 
           (",sql,"
           )
         EXCLUDE('",idAlias,"')
       )")
  }
  
  if(test) {
    return(sql)
  }else {
    data = toaSqlQuery(channel, sql)
    km$cluster = data$clusterid
    km$data = data
    return(km)
  }
}


getKmeansplotDataSql <- function(scaled_table_name, centroid_table_name, clustered_table_name, 
                                 scaled, persist, query_as_table, idAlias) {
  
  sql = paste0(
                "SELECT ", ifelse(scaled,  " d.* ", " clusterid, d.* "), "
                   FROM ", ifelse(persist, clustered_table_name, paste0(
                        "kmeansplot(
                     ON ", scaled_table_name, " PARTITION BY ANY
                     ON ", centroid_table_name, " DIMENSION
                     centroidsTable('",centroid_table_name,"')
                   )")), ifelse(scaled, " d ", 
                               paste0( " kmp JOIN (", query_as_table, ") d ON (kmp.", idAlias, " = d.", idAlias, ")")), "
                  WHERE clusterid != -1"
  )
}


#' Compute Silhouette (k-means clustering).
#' 
#' @param channel connection object as returned by \code{\link{odbcConnect}}.
#' @param km an object of class \code{"toakmeans"} obtained with \code{computeKmeans}.
#' @param scaled logical: indicates if the computation is performed on original (default) or scaled values.
#' @param silhouetteTableName name of the Aster table to hold silhouette scores. The table persists silhouette scores 
#'   for all clustered elements. Set parameter \code{drop=FALSE} to keep the table.
#' @param drop logical: indicates if the table \code{silhouetteTableName} is dropped after the results are computed.
#' @param test logical: if TRUE show what would be done, only (similar to parameter \code{test} in \pkg{RODBC} 
#'   functions: \link{sqlQuery} and \link{sqlSave}).
#' @return \code{computeSilhouette} returns an object of class \code{"toakmeans"} (compatible with class \code{"kmeans"}).
#'   It adds to \code{km} a named list \code{sil} containing two elements: the average silhouette \code{value} and the 
#'   silhouette \code{profile} (distribution of silhouette values in each cluster).
#' @seealso \code{\link{computeKmeans}}
#'
#' @export
#' @examples 
#' if(interactive()){
#' # initialize connection to Lahman baseball database in Aster 
#' conn = odbcDriverConnect(connection="driver={Aster ODBC Driver};
#'                          server=<dbhost>;port=2406;database=<dbname>;uid=<user>;pwd=<pw>")
#'                          
#' km = computeKmeans(conn, "batting", centers=5, iterMax = 25,
#'                    aggregates = c("COUNT(*) cnt", "AVG(g) avg_g", "AVG(r) avg_r", "AVG(h) avg_h"),
#'                    id="playerid || '-' || stint || '-' || teamid || '-' || yearid", 
#'                    include=c('g','r','h'), scaledTableName='kmeans_test_scaled', 
#'                    centroidTableName='kmeans_test_centroids',
#'                    where="yearid > 2000")
#' km = computeSilhouette(conn, km)
#' km$sil
#' createSilhouetteProfile(km, title="Cluster Silhouette Histograms (Profiles)")
#' }
computeSilhouette <- function(channel, km, scaled=TRUE, silhouetteTableName=NULL, drop=TRUE, test=FALSE) {
  
  ptm = proc.time()
  
  isValidConnection(channel, test)
  
  if(test && is.null(silhouetteTableName)){
    stop("Silhouette table name is required when test=TRUE.")
  }
    
  if (missing(km) || !is.object(km) || !inherits(km, "toakmeans")) {
    stop("Kmeans object must be specified.")
  }
  
  if (nrow(km$centers) == 1)
    stop("Silhouette values are trivial in case of single cluster model.")
  
  if (is.null(silhouetteTableName))
    silhouetteTableName = makeTempTableName('silhouette', 30)
  
  table_name = km$tableName
  columns = km$columns
  persist = ifelse(is.null(km$persist), FALSE, km$persist)
  scaled_table_name = km$scaledTableName
  centroid_table_name = km$centroidTableName
  clustered_table_name = km$clusteredTableName
  id = km$id
  idAlias = km$idAlias
  where_clause = km$whereClause
  
  emptyLine = "--"
  
  if(test)
    sqlText = ""
  
  # make silhouette data
  sqlComment = "-- Create Analytical Table with Silhouette Data"
  sqlDrop = paste("DROP TABLE IF EXISTS", silhouetteTableName)
  sql = makeSilhouetteDataSql(table_name, silhouetteTableName, columns, id, idAlias, where_clause, 
                              scaled_table_name, centroid_table_name, clustered_table_name, scaled, persist)
  if(test) {
    sqlText = paste(sqlComment, sqlDrop, sep='\n')
    sqlText = paste(sqlText, sql, sep=';\n')
  }else {
    toaSqlQuery(channel, sqlDrop)
    toaSqlQuery(channel, sql)
  }
  
  # Compute overall silhouette value
  sqlComment = "-- Compute overall silhouette value"
  sql = paste0("SELECT AVG((b-a)/greatest(a,b)) silhouette_value FROM ", silhouetteTableName)
  if(test) {
    sqlText = paste(sqlText, emptyLine, sqlComment, sql, sep=';\n')
  }else {
    data = toaSqlQuery(channel, sql)
    sil_value = data[[1,'silhouette_value']]
  }
  
  # Compute silhouette profiles (histograms by cluster)
  sqlComment = "-- Compute silhouette cluster profiles"
  sql = paste0(
    "SELECT * FROM Hist_Reduce(
       ON Hist_Map(
         ON (SELECT clusterid::varchar clusterid, (b-a)/greatest(a,b) silhouette_value FROM ", silhouetteTableName, "
         )
         STARTVALUE('-1')
         BINSIZE('0.05')
         ENDVALUE('1')
         VALUE_COLUMN('silhouette_value')
         GROUP_COLUMNS('clusterid')
       ) PARTITION BY clusterid
     )"
  )
  if(test){
    sqlText = paste(sqlText, emptyLine, sqlComment, sql, sep=';\n')
  }else {
    sil_profile = toaSqlQuery(channel, sql, stringsAsFactors=TRUE)
  }
  
  # Drop Analytical Table with Silhouette Data 
  if(drop) {
    sqlComment = "-- Drop Analytical Table with Silhouette Data"
    sql = paste0("DROP TABLE IF EXISTS ", silhouetteTableName)
    if(test){
      sqlText = paste(sqlText, emptyLine, sqlComment, sql, sep=';\n')
    }else {
      toaSqlQuery(channel, sql)
    }
  }
  
  # get result back
  if(test){
    sqlText = paste0(sqlText, ';')
    return(sqlText)
  }
  
  sil = list(value=sil_value, profile=sil_profile)
  if(!drop) {
    sil = c(sil, tableName=silhouetteTableName)
  }
  sil$time = proc.time() - ptm
  km$sil = sil
  
  return(km)
}


makeSilhouetteDataSql <- function(table_name, temp_table_name, columns, id, idAlias, 
                                  where_clause, scaled_table_name, centroid_table_name, 
                                  clustered_table_name, scaled, persist) {
  
  sqlmr_column_list = makeSqlMrColumnList(columns)
  query_as_table = getDataSql(table_name, columns, id, idAlias, where_clause)
  
  sql = paste0(
    "CREATE ANALYTIC TABLE ", temp_table_name, "
     DISTRIBUTE BY HASH(clusterid)
     AS
     WITH kmeansplotresult AS (
         SELECT clusterid, ", idAlias, ", variable, coalesce(value_double, value_long, value_str::double) value
           FROM unpivot(
                  ON (", getKmeansplotDataSql(scaled_table_name, centroid_table_name, clustered_table_name, 
                                              scaled, persist, query_as_table, idAlias), "
                  )
                  COLSTOUNPIVOT(", sqlmr_column_list, ")
                  COLSTOACCUMULATE('",idAlias,"','clusterid')
                  ATTRIBUTECOLUMNNAME('variable')
                  VALUECOLUMNNAME('value')
                  KEEPINPUTCOLUMNTYPES('true')
                ) 
     )
     SELECT target_clusterid clusterid, target_", idAlias, " ", idAlias, ", a, b 
       FROM (
         SELECT target_clusterid, target_", idAlias, ",
                MAX(CASE WHEN target_clusterid = ref_clusterid THEN dissimilarity ELSE 0 END) a,
                MIN(CASE WHEN target_clusterid = ref_clusterid THEN 'Infinity' ELSE dissimilarity END) b 
           FROM
             (SELECT target_clusterid, target_", idAlias, ", ref_clusterid, avg(distance) dissimilarity 
                FROM VectorDistance(
                  ON kmeansplotresult AS target PARTITION BY ", idAlias, "
                  ON kmeansplotresult AS ref DIMENSION
                  TARGETIDCOLUMNS('clusterid','", idAlias, "')
                  TARGETFEATURECOLUMN('variable')
                  TARGETVALUECOLUMN('value')
                  REFIDCOLUMNS('clusterid','", idAlias, "')
                  REFFEATURECOLUMN('variable')
                  REFVALUECOLUMN('value')
                  MEASURE('Euclidean')
                ) 
          WHERE target_", idAlias," != ref_", idAlias, " 
          GROUP BY 1,2,3
         ) agg  
       GROUP BY 1,2
     ) sil"
  )
}
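
# Given the a and b columns produced above (a = average dissimilarity to the
# point's own cluster, b = smallest average dissimilarity to any other cluster),
# each point's silhouette is s = (b - a) / max(a, b), which is exactly what the
# SQL in computeSilhouette evaluates with (b-a)/greatest(a,b). A local
# equivalent, assuming numeric vectors a and b:
#
#   silhouette_values <- function(a, b) (b - a) / pmax(a, b)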


#' Perform canopy clustering on the table to determine cluster centers.
#' 
#' Canopy clustering algorithm runs in-database, returns centroids compatible with \code{\link{computeKmeans}}, and 
#' pre-processes data for k-means and other clustering algorithms.
#' 
#' Canopy clustering often precedes the k-means algorithm (see \code{\link{computeKmeans}}) 
#' or other clustering algorithms. The goal is to speed up clustering by choosing initial centroids more efficiently
#' than randomly or naively, especially for big data applications. Important notes:
#' \itemize{
#'   \item the function does not let you specify the number of canopies (clusters); instead it controls them with a pair of
#'     threshold arguments, \code{looseDistance} and \code{tightDistance}. By adjusting them one tunes 
#'     \code{computeCanopy} to produce more or fewer canopies as desired.
#'   \item individual data points may be part of several canopies, and cluster membership is not available as a result 
#'     of the operation.
#'   \item the resulting \code{toacanopy} object should be passed to \code{computeKmeans} as the \code{centers} argument, 
#'     effectively overriding the corresponding arguments of the k-means function.
#' }
#' 
#' The function first scales non-null data (if \code{scale=TRUE}) or simply eliminates rows with \code{NULL}s without scaling. After 
#' that the data (table \code{tableName}, optionally filtered with \code{where}) are clustered using the canopy 
#' algorithm in Aster. This results in 
#' \enumerate{
#'   \item a set of centroids to use as initial cluster centers in k-means, and
#'   \item pre-processed data persisted and ready for clustering with the k-means function \code{computeKmeans}.
#' }
#' 
#' @param channel connection object as returned by \code{\link{odbcConnect}}.
#' @param canopy an object of class \code{"toacanopy"} obtained with \code{computeCanopy}.
#' @param tableName Aster table name.
#' @param looseDistance specifies the maximum distance that any point can be from a canopy center to be considered 
#'   part of that canopy.
#' @param tightDistance specifies the minimum distance that separates two canopy centers.
#' @param tableInfo pre-built summary of data to use (required when \code{test=TRUE}). See \code{\link{getTableSummary}}.
#' @param id column name or SQL expression containing unique table key.
#' @param idAlias SQL alias for table id. This is required when SQL expression is given for \code{id}.
#' @param include a vector of column names with variables (must be numeric). The model never contains variables other than those in the list.
#' @param except a vector of column names to exclude from variables. The model never contains variables from the list.
#' @param scale logical: if \code{TRUE} then scale each variable in-database before clustering. Scaling results in zero mean and unit
#'   standard deviation for each input variable. When \code{FALSE}, the function only removes incomplete
#'   data (containing \code{NULL}s) before clustering.
#' @param where specifies criteria the table rows must satisfy before the
#'   computation applies. The criteria are expressed in the form of SQL predicates (inside a
#'   \code{WHERE} clause).
#' @param scaledTableName the name of the Aster table with results of scaling.
#' @param schema name of the Aster schema that tables \code{scaledTableName}, \code{centroidTableName}, and
#'   \code{clusteredTableName} belong to. Make sure that when this argument is supplied no table name already
#'   contains a schema in its name.
#' @param test logical: if TRUE show what would be done, only (similar to parameter \code{test} in \pkg{RODBC} 
#'   functions: \link{sqlQuery} and \link{sqlSave}).
#'   
#' @export
#' @seealso \code{\link{computeClusterSample}}, \code{\link{computeSilhouette}}, \code{\link{computeCanopy}}
#' @examples 
#' if(interactive()){
#' # initialize connection to Lahman baseball database in Aster 
#' conn = odbcDriverConnect(connection="driver={Aster ODBC Driver};
#'                          server=<dbhost>;port=2406;database=<dbname>;uid=<user>;pwd=<pw>")
#' can = computeCanopy(conn, "batting", looseDistance = 1, tightDistance = 0.5,
#'                     id="playerid || '-' || stint || '-' || teamid || '-' || yearid", 
#'                     include=c('g','r','h'), 
#'                     scaledTableName='test_canopy_scaled', 
#'                     where="yearid > 2000")
#' createCentroidPlot(can)
#' 
#' can = computeCanopy(conn, canopy = can, looseDistance = 2, tightDistance = 0.5)
#' createCentroidPlot(can)
#' 
#' can = computeCanopy(conn, canopy = can, looseDistance = 4, tightDistance = 1)
#' createCentroidPlot(can)
#'
#' km = computeKmeans(conn, centers=can, iterMax = 1000, persist = TRUE, 
#'                    aggregates = c("COUNT(*) cnt", "AVG(g) avg_g", "AVG(r) avg_r", "AVG(h) avg_h"),
#'                    centroidTableName = "kmeans_test_centroids",
#'                    tempTableName = "kmeans_test_temp",
#'                    clusteredTableName = "kmeans_test_clustered") 
#' createCentroidPlot(km)
#' 
#' }
computeCanopy <- function(channel, tableName, looseDistance, tightDistance,
                          canopy, 
                          tableInfo, id, include=NULL, except=NULL, 
                          scale=TRUE, 
                          idAlias=gsub("[^0-9a-zA-Z]+", "_", id), where=NULL,
                          scaledTableName=NULL, schema=NULL, test=FALSE) {
  
  ptm = proc.time()
  
  if (test && missing(tableInfo) && missing(canopy)) {
    stop("Must provide tableInfo when test==TRUE")
  }
  
  if (tightDistance >= looseDistance) 
    stop("The loose distance must be greater than the tight distance.")
  
  isValidConnection(channel, test)
  
  if(missing(canopy) || is.null(canopy)) {
    canopy = NULL
    
    tableName = normalizeTableName(tableName)
    
    if (missing(tableInfo)) {
      tableInfo = sqlColumns(channel, tableName)
    }
    
    columns = getNumericColumns(tableInfo, names.only=TRUE, include=include, except=except)
    columns = sort(setdiff(columns, id))
    
    if (is.null(columns) || length(columns) < 1) {
      stop("Kmeans operates on one or more numeric variables.")
    }
    
    # check if id alias is not one of independent variables
    if(idAlias %in% columns)
      stop(paste0("Id alias '", idAlias, "' can't be one of variable names."))
    
    # adjust id alias if it's exactly one of the table columns
    if(idAlias %in% tableInfo$COLUMN_NAME)
      idAlias = paste("_", idAlias, "_", sep="_")
    
    # scale data table name
    if (is.null(scaledTableName))
      scaledTableName = makeTempTableName('scaled', 30, schema)
    else if (!is.null(schema))
      scaledTableName = paste0(schema, ".", scaledTableName)
    
    where_clause = makeWhereClause(where)
  }else {
    tableName = canopy$tableName
    columns = canopy$columns
    id = canopy$id
    idAlias = canopy$idAlias
    scale = canopy$scale
    scaledTableName = canopy$scaledTableName
    schema = canopy$schema
    where_clause = canopy$whereClause
  }
  
  emptyLine = "--"
  
  if(test)
    sqlText = ""
  
  # scale data or just eliminate incomplete observations (if not scaling)
  if (is.null(canopy)) {
    sqlComment = paste("-- Data Prep:", ifelse(scale, "scale", "omit nulls"))
    sqlDrop = paste("DROP TABLE IF EXISTS", scaledTableName)
    sql = getDataPrepSql(scale, tableName, scaledTableName, columns, id, idAlias, where_clause)
    if(test) {
      sqlText = paste(sqlComment, sqlDrop, sep='\n')
      sqlText = paste(sqlText, sql, sep=';\n')
    }else {
      toaSqlQuery(channel, sqlDrop)
      toaSqlQuery(channel, sql)
    }
  }
  
  # run canopy
  sqlComment = "-- Run canopy"
  sql = getCanopySql(scaledTableName, looseDistance, tightDistance)
  if(test) {
    sqlText = paste(sqlText, sqlComment, sql, sep=';\n')
  }else {
    centers = toaSqlQuery(channel, sql)
  }
  
  # return sql
  if(test) {
    sqlText = paste0(sqlText, ';')
    return(sqlText)
  }
  
  result = makeCanopyResult(centers, tableName, columns, looseDistance, tightDistance, 
                            scale, scaledTableName, id, idAlias, schema,
                            where_clause, ptm)
  
  return(result)
}


getCanopySql <- function(scaledTableName, looseDistance, tightDistance) {
  
  sql = paste0(
    "SELECT * FROM Canopy(
       ON (SELECT 1) PARTITION BY 1
       InputTable('",scaledTableName,"')
       LooseDistance('",looseDistance,"')
       TightDistance('",tightDistance,"')
     )"
  )
}
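
# A minimal sketch of the SQL this renders (not executed; the table name is
# hypothetical). getCanopySql("canopy_scaled", 1, 0.5) produces, modulo whitespace:
#
#   SELECT * FROM Canopy(
#     ON (SELECT 1) PARTITION BY 1
#     InputTable('canopy_scaled')
#     LooseDistance('1')
#     TightDistance('0.5')
#   )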

makeCanopyResult <- function(centers, tableName, columns, looseDistance, tightDistance,
                             scale, scaledTableName, id, idAlias, schema,
                             whereClause, ptm) {
  
  centers = as.matrix(centers[,-1])
  
  z <- structure(list(centers=centers,
                      looseDistance=looseDistance, 
                      tightDistance=tightDistance,
                      
                      tableName=tableName,
                      columns=columns,
                      scale=scale,
                      scaledTableName=scaledTableName,
                      schema=schema,
                      id=id,
                      idAlias=idAlias,
                      whereClause=whereClause,
                      time=proc.time() - ptm
  ),
  class = c("toacanopy"))
  
  return(z)
}
