R/utils.R

#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

# Utilities and Helpers

# Given a JList<T>, returns an R list containing the same elements, the number
# of which is optionally upper bounded by `logicalUpperBound` (by default,
# return all elements).  Takes care of deserializations and type conversions.
#
# param
#   jList             Java object reference to a list, queried via callJMethod "size"/"get".
#   flatten           If TRUE, per-item results are concatenated one level deep
#                     (unlist(recursive = FALSE)). Forced to FALSE when any item is
#                     decoded in "row" mode.
#   logicalUpperBound Optional cap on the number of elements (NULL = no cap).
#   serializedMode    "byte", "string" or "row"; selects how raw items are decoded.
# return value
#   A plain R list.
convertJListToRList <- function(jList, flatten, logicalUpperBound = NULL,
  serializedMode = "byte") {
  arrSize <- callJMethod(jList, "size")

  # Datasets with serializedMode == "string" (such as an RDD directly generated by textFile()):
  # each partition is not dense-packed into one Array[Byte], and `arrSize`
  # here corresponds to number of logical elements. Thus we can prune here.
  if (serializedMode == "string" && !is.null(logicalUpperBound)) {
    arrSize <- min(arrSize, logicalUpperBound)
  }

  results <- if (arrSize > 0) {
    lapply(0 : (arrSize - 1),
          function(index) {
            # Java list indices are 0-based, hence the 0:(arrSize - 1) range above.
            obj <- callJMethod(jList, "get", as.integer(index))

            # Assume it is either an R object or a Java obj ref.
            if (inherits(obj, "jobj")) {
              if (isInstanceOf(obj, "scala.Tuple2")) {
                # JavaPairRDD[Array[Byte], Array[Byte]]: unserialize key and value
                # separately and return them as a 2-element list.

                keyBytes <- callJMethod(obj, "_1")
                valBytes <- callJMethod(obj, "_2")
                res <- list(unserialize(keyBytes),
                  unserialize(valBytes))
              } else {
                stop(paste("utils.R: convertJListToRList only supports",
                  "RDD[Array[Byte]] and",
                  "JavaPairRDD[Array[Byte], Array[Byte]] for now"))
              }
            } else {
              if (inherits(obj, "raw")) {
                if (serializedMode == "byte") {
                  # RDD[Array[Byte]]. `obj` is a whole partition.
                  # For serialized datasets, `obj` here corresponds to one whole
                  # partition dense-packed together. We deserialize the whole
                  # partition first, then cap the number of elements to be returned.
                  res <- unserialize(obj)
                } else if (serializedMode == "row") {
                  # For DataFrames that have been converted to RRDDs, we call readRowList
                  # which will read in each row of the RRDD as a list and deserialize
                  # each element.
                  res <- readRowList(obj)
                  # `<<-` rebinds `flatten` in convertJListToRList's own frame (the
                  # nearest enclosing scope that defines it), so the final
                  # `if (flatten)` below sees FALSE. This means we don't have to worry
                  # about the default argument in other functions, e.g. collect.
                  flatten <<- FALSE
                }
                # NOTE(review): for any other serializedMode (e.g. "string"), `res` is
                # never assigned on this path and the next use of it errors.
                # TODO: is it possible to distinguish element boundary so that we can
                # unserialize only what we need?
                # NOTE(review): this cap applies to each item (partition) on its own,
                # not to the combined result across items.
                if (!is.null(logicalUpperBound)) {
                  res <- head(res, n = logicalUpperBound)
                }
              } else {
                # obj is of a primitive Java type, is simplified to R's
                # corresponding type.
                res <- list(obj)
              }
            }
            res
          })
  } else {
    list()
  }

  # Each item contributed a list; optionally splice them into one flat list.
  if (flatten) {
    as.list(unlist(results, recursive = FALSE))
  } else {
    as.list(results)
  }
}

# Look up `name` in the environment `env` and report whether the object bound
# to it inherits from class "RDD".
isRDD <- function(name, env) {
  inherits(get(name, envir = env), "RDD")
}

#' Compute the hashCode of an object
#'
#' Java-style function to compute the hashCode for the given object. Returns
#' an integer value.
#'
#' @details
#' This only works for integer, numeric and character types right now; any
#' other input produces a warning and returns 0. Only the first element of
#' \code{key} is hashed. Character strings are hashed byte-by-byte over their
#' encoded representation.
#'
#' @param key the object to be hashed
#' @return the hash code as an integer
#' @export
#' @examples
#'\dontrun{
#' hashCode(1L) # 1
#' hashCode(1.0) # 1072693248
#' hashCode("1") # 49
#'}
#' @note hashCode since 1.4.0
hashCode <- function(key) {
  # inherits() instead of `class(key) == ...`: objects with several classes would
  # otherwise hand `if` a length > 1 condition, which is an error in R >= 4.2.
  if (inherits(key, "integer")) {
    as.integer(key[[1]])
  } else if (inherits(key, "numeric")) {
    # Reinterpret the 64 bits of the double as two 32-bit integers and XOR them
    # (intended to mirror Java's Double.hashCode).
    rawVec <- writeBin(key[[1]], con = raw())
    intBits <- packBits(rawToBits(rawVec), "integer")
    as.integer(bitwXor(intBits[2], intBits[1]))
  } else if (inherits(key, "character")) {
    # TODO: SPARK-7839 means we might not have the native library available
    key <- key[[1]]
    if (nchar(key) == 0) {
      0L
    } else {
      # String-style hash: h = 31 * h + byte, folded left to right starting at 0.
      byteVals <- as.integer(charToRaw(key))
      as.integer(Reduce(mult31AndAdd, byteVals, 0))
    }
  } else {
    warning("Could not hash object, returning 0")
    0L
  }
}

# Helper function used to wrap a 'numeric' value into the signed 32-bit integer
# range [-2^31, 2^31 - 1], i.e. C/Java-like two's-complement overflow.
# Useful for implementing C-like integer arithmetic.
#
# The value is reduced modulo 2^32, so inputs that overflow by more than one
# "turn" are also wrapped correctly, and -2^31 is kept as-is. This is exact for
# any integral value representable as a double (|value| < 2^53). Vectorized.
wrapInt <- function(value) {
  # 2^31 as a double; adding the double literal 1 avoids integer overflow.
  twoPow31 <- as.numeric(.Machine$integer.max) + 1
  (value + twoPow31) %% (2 * twoPow31) - twoPow31
}

# Multiply `val` by 31 and add `addVal` to the result, with 32-bit signed integer
# overflow semantics: the result is reduced modulo 2^32 into [-2^31, 2^31 - 1].
#
# `val` is expected to already lie in the 32-bit signed range and `addVal` to be
# small (e.g. a byte value). Then 31 * val + addVal stays far below 2^53, is
# computed exactly in double precision, and is wrapped once at the end. The
# previous bit-shift approach turned any shifted term equal to -2^31 (which R
# represents as NA) into 0, giving wrong results.
mult31AndAdd <- function(val, addVal) {
  twoPow31 <- 2147483648
  exact <- 31 * as.numeric(val) + as.numeric(addVal)
  (exact + twoPow31) %% (2 * twoPow31) - twoPow31
}

# Return an RDD whose serializedMode is "byte". An RDD that is already in "byte"
# mode is returned unchanged. Otherwise it is passed through an identity lapply,
# which creates a new RDD in "byte" format. Errors if `rdd` is not an RDD.
serializeToBytes <- function(rdd) {
  if (!inherits(rdd, "RDD")) {
    stop("Argument 'rdd' is not an RDD type.")
  }
  if (getSerializedMode(rdd) == "byte") {
    rdd
  } else {
    lapply(rdd, function(x) { x })
  }
}

# Return an RDD whose serializedMode is "string". An RDD that is already in
# "string" mode is returned unchanged. Otherwise every element is converted with
# toString(), and getJRDD() is called with serializedMode = "string" to force the
# JVM-side RDD to be built in that mode. Errors if `rdd` is not an RDD.
serializeToString <- function(rdd) {
  if (!inherits(rdd, "RDD")) {
    stop("Argument 'rdd' is not an RDD type.")
  }
  if (getSerializedMode(rdd) == "string") {
    return(rdd)
  }
  stringRdd <- lapply(rdd, function(x) { toString(x) })
  # Called only for its side effect of creating the jrdd in "string" mode.
  getJRDD(stringRdd, serializedMode = "string")
  stringRdd
}

# Fast append to list by using an accumulator.
# http://stackoverflow.com/questions/17046336/here-we-go-again-append-an-element-to-a-list-in-r
#
# The accumulator should have three fields:
#   size    the allocated length of `data`
#   counter the number of items stored so far
#   data    a list holding the stored items in positions 1..counter
# This function amortizes the allocation cost by doubling the allocated length
# of `data` whenever it is full. The accumulator is expected to be an
# environment (as created by initAccumulator()), so it is updated in place.
addItemToAccumulator <- function(acc, item) {
  if (acc$counter == acc$size) {
    grownSize <- 2 * acc$size
    length(acc$data) <- grownSize
    acc$size <- grownSize
  }
  nextSlot <- acc$counter + 1
  acc$counter <- nextSlot
  acc$data[[nextSlot]] <- item
}

# Create an empty accumulator for addItemToAccumulator(): an environment holding
# counter = 0, data = list(NULL) (one pre-allocated slot) and size = 1.
initAccumulator <- function() {
  acc <- new.env()
  list2env(list(counter = 0, data = list(NULL), size = 1), envir = acc)
  acc
}

# Order a list of key-value pairs (each pair a list whose first element is the
# key) by key: ascending by default, descending when `decreasing = TRUE`.
# Used in unit tests.
sortKeyValueList <- function(kv_list, decreasing = FALSE) {
  ord <- order(sapply(kv_list, `[[`, 1), decreasing = decreasing)
  kv_list[ord]
}

# Utility function to split a tagged list into two compact R lists.
# Used in Join-family functions.
# param:
#   tagged_list R list generated via groupByKey with tags (1L, 2L, ...); each element is
#               list(tag, value) and the value goes to the list selected by the tag.
#   cnull Boolean list where each element determines whether the corresponding list should
#         be converted to list(NULL) when it ends up empty
# return value
#   list(values tagged 1, values tagged 2), each trimmed to the number of values it received.
genCompactLists <- function(tagged_list, cnull) {
  n <- length(tagged_list)
  compact <- list(vector("list", n), vector("list", n))
  counts <- c(0, 0)

  for (entry in tagged_list) {
    side <- entry[[1]]
    slot <- counts[[side]] + 1
    compact[[side]][[slot]] <- entry[[2]]
    counts[[side]] <- slot
  }

  for (i in 1:2) {
    if (cnull[[i]] && counts[[i]] == 0) {
      compact[[i]] <- list(NULL)
    } else {
      length(compact[[i]]) <- counts[[i]]
    }
  }

  compact
}

# Utility function to merge compact R lists into their Cartesian product.
# Used in Join-family functions.
# param:
#   left/right Two compact lists ready for Cartesian product
# return value
#   A list of list(l, r) pairs, with `left` varying slowest.
mergeCompactLists <- function(left, right) {
  nRight <- length(right)
  pairs <- vector("list", length(left) * nRight)
  for (a in seq_along(left)) {
    for (b in seq_along(right)) {
      pairs[[(a - 1) * nRight + b]] <- list(left[[a]], right[[b]])
    }
  }
  pairs
}

# Utility function that combines genCompactLists() and mergeCompactLists():
# splits the tagged values into two compact lists and returns their Cartesian product.
# Used in Join-family functions.
# param (same as genCompactLists):
#   tagged_list R list generated via groupByKey with tags(1L, 2L, ...)
#   cnull Boolean list where each element determines whether the corresponding list should
#         be converted to list(NULL)
joinTaggedList <- function(tagged_list, cnull) {
  do.call(mergeCompactLists, genCompactLists(tagged_list, cnull))
}

# Utility function to fold one key-value pair into key/value environments.
# Used in *ByKey functions.
# param
#   pair key-value pair: pair[[1]] is the key, pair[[2]] the value, pair$hash the env name
#   keys/vals env of key/value with hashes
#   updateOrCreatePred predicate function; TRUE means "merge into the existing entry"
#   updateFn update or merge function for existing pair, similar with `mergeVal` @combineByKey
#   createFn create function for new pair, similar with `createCombiner` @combinebykey
# return value (invisible)
#   The merged value when updating, otherwise the key that was stored.
updateOrCreatePair <- function(pair, keys, vals, updateOrCreatePred, updateFn, createFn) {
  h <- pair$hash
  k <- pair[[1]]
  v <- pair[[2]]
  if (!updateOrCreatePred(pair)) {
    # New entry: create the combined value, then remember the original key.
    assign(h, do.call(createFn, list(v)), envir = vals)
    return(invisible(assign(h, k, envir = keys)))
  }
  merged <- do.call(updateFn, list(get(h, envir = vals), v))
  assign(h, merged, envir = vals)
}

# Utility function to convert parallel key and value environments (bound under
# the same names) into a list of list(key, value) pairs, ordered by the sorted
# binding names of `keys`.
convertEnvsToList <- function(keys, vals) {
  ids <- ls(keys)
  out <- vector("list", length(ids))
  for (i in seq_along(ids)) {
    id <- ids[[i]]
    out[[i]] <- list(keys[[id]], vals[[id]])
  }
  out
}

# Utility function to merge 2 environments, with bindings in env2 overriding
# those in env1. env1 is changed in place. Returns the list of copied values,
# ordered by the sorted binding names of env2.
overrideEnvs <- function(env1, env2) {
  copyOne <- function(key) {
    value <- env2[[key]]
    assign(key, value, envir = env1)
    value
  }
  lapply(ls(env2), copyOne)
}

# Utility function to capture the named varargs into a new environment object.
# Unnamed arguments are not captured when no argument is named.
varargsToEnv <- function(...) {
  # Based on http://stackoverflow.com/a/3057419/4577954
  args <- list(...)
  out <- new.env()
  argNames <- names(args)
  for (nm in argNames) {
    out[[nm]] <- args[[nm]]
  }
  out
}

# Utility function to capture the varargs into an environment object, with every value
# converted to a string. Logical values become "true"/"false" (lower case), numeric and
# character values go through as.character(), and NULL is stored as NULL.
# Unnamed arguments are ignored with a single warning listing them. Any other value type
# raises an error that names the argument and its class(es).
varargsToStrEnv <- function(...) {
  pairs <- list(...)
  nameList <- names(pairs)
  env <- new.env()
  ignoredNames <- list()

  if (is.null(nameList)) {
    # When all arguments are not named, names(..) returns NULL.
    ignoredNames <- pairs
  } else {
    for (i in seq_along(pairs)) {
      name <- nameList[i]
      if (identical(name, "")) {
        # When some of arguments are not named, name is "".
        ignoredNames <- append(ignoredNames, pairs[i])
      } else {
        value <- pairs[[name]]
        if (!(is.logical(value) || is.numeric(value) || is.character(value) || is.null(value))) {
          # Collapse the class vector so that multi-class objects still yield a single,
          # readable message instead of several concatenated ones.
          stop(paste0("Unsupported type for ", name, " : ", paste(class(value), collapse = ", "),
               ". Supported types are logical, numeric, character and NULL."), call. = FALSE)
        }
        if (is.logical(value)) {
          env[[name]] <- tolower(as.character(value))
        } else if (is.null(value)) {
          env[[name]] <- value
        } else {
          env[[name]] <- as.character(value)
        }
      }
    }
  }

  if (length(ignoredNames) != 0) {
    warning(paste0("Unnamed arguments ignored: ", paste(ignoredNames, collapse = ", "), "."),
            call. = FALSE)
  }
  env
}

# Map a storage level name to the JVM org.apache.spark.storage.StorageLevel
# constant of the same name.
# param
#   newLevel One of the listed storage level names. It is resolved by match.arg(), so
#            unambiguous partial names are accepted and omitting it selects "DISK_ONLY".
#            Any other value is an error.
# return value
#   Whatever callJStatic() returns for that StorageLevel field (a JVM object reference).
getStorageLevel <- function(newLevel = c("DISK_ONLY",
                                         "DISK_ONLY_2",
                                         "MEMORY_AND_DISK",
                                         "MEMORY_AND_DISK_2",
                                         "MEMORY_AND_DISK_SER",
                                         "MEMORY_AND_DISK_SER_2",
                                         "MEMORY_ONLY",
                                         "MEMORY_ONLY_2",
                                         "MEMORY_ONLY_SER",
                                         "MEMORY_ONLY_SER_2",
                                         "OFF_HEAP")) {
  # Keep match.arg()'s result. If it were discarded, a missing argument would leave the
  # whole choice vector in `newLevel`, and partial names would never be expanded.
  newLevel <- match.arg(newLevel)
  # Every supported name is also the name of a static field on StorageLevel.
  callJStatic("org.apache.spark.storage.StorageLevel", newLevel)
}

# Render a JVM StorageLevel object as a string. If its flags (useDisk, useMemory,
# useOffHeap, deserialized) and replication match one of the named StorageLevel
# constants, the result is "<NAME> - <toString()>". Otherwise it is just the
# object's toString() output.
storageLevelToString <- function(levelObj) {
  # Flags are queried in the order useDisk, useMemory, useOffHeap, deserialized.
  flagBits <- c(callJMethod(levelObj, "useDisk"),
                callJMethod(levelObj, "useMemory"),
                callJMethod(levelObj, "useOffHeap"),
                callJMethod(levelObj, "deserialized"))
  replication <- callJMethod(levelObj, "replication")

  # Key format: one 0/1 digit per flag (same order as above), "x", replication count.
  knownLevels <- c("0000x1" = "NONE",
                   "1000x1" = "DISK_ONLY",
                   "1000x2" = "DISK_ONLY_2",
                   "0101x1" = "MEMORY_ONLY",
                   "0101x2" = "MEMORY_ONLY_2",
                   "0100x1" = "MEMORY_ONLY_SER",
                   "0100x2" = "MEMORY_ONLY_SER_2",
                   "1101x1" = "MEMORY_AND_DISK",
                   "1101x2" = "MEMORY_AND_DISK_2",
                   "1100x1" = "MEMORY_AND_DISK_SER",
                   "1100x2" = "MEMORY_AND_DISK_SER_2",
                   "1110x1" = "OFF_HEAP")
  key <- paste0(paste(as.integer(flagBits), collapse = ""), "x", replication)
  shortName <- unname(knownLevels[key])

  fullInfo <- callJMethod(levelObj, "toString")
  if (is.na(shortName)) fullInfo else paste(shortName, "-", fullInfo)
}

# Utility function for functions where an argument needs to be integer but we want to allow
# the user to type (for example) `5` instead of `5L` to avoid a confusing error message.
# Returns as.integer(num). Warns, naming the caller's expression, when the coercion changes a
# value (e.g. 5.5). NA inputs and vectors are handled without an `if` error.
numToInt <- function(num) {
  asInt <- as.integer(num)
  if (any(asInt != num, na.rm = TRUE)) {
    # deparse(substitute()) yields the caller's expression as text, even for calls such as
    # `n / 2` (pasting the call object itself would split it into several strings).
    expr <- paste(deparse(substitute(num)), collapse = " ")
    warning(paste("Coercing", expr, "to integer."))
  }
  asInt
}

# Utility function to recursively traverse the Abstract Syntax Tree (AST) of a
# user defined function (UDF), and to examine variables in the UDF to decide
# if their values should be included in the new function environment.
# param
#   node The current AST node in the traversal.
#   oldEnv The original function environment.
#   defVars An Accumulator of variables names defined in the function's calling environment,
#           including function argument and local variable names.
#   checkedFuncs An environment of function objects examined during cleanClosure. It can
#                be considered as a "name"-to-"list of functions" mapping.
#   newEnv A new function environment to store necessary function dependencies, an output argument.
# return value
#   None meaningful; results are written into defVars, checkedFuncs and newEnv.
processClosure <- function(node, oldEnv, defVars, checkedFuncs, newEnv) {
  nodeLen <- length(node)

  if (nodeLen > 1 && typeof(node) == "language") {
    # Recursive case: current AST node is an internal node, check for its children.
    if (length(node[[1]]) > 1) {
      # The callee is itself an expression (e.g. f(a)(b)); visit every element.
      for (i in 1:nodeLen) {
        processClosure(node[[i]], oldEnv, defVars, checkedFuncs, newEnv)
      }
    } else {
      # if node[[1]] is length of 1, check for some R special functions.
      nodeChar <- as.character(node[[1]])
      if (nodeChar == "{" || nodeChar == "(") {
        # Skip start symbol.
        for (i in 2:nodeLen) {
          processClosure(node[[i]], oldEnv, defVars, checkedFuncs, newEnv)
        }
      } else if (nodeChar == "<-" || nodeChar == "=" ||
                   nodeChar == "<<-") {
        # Assignment Ops.
        defVar <- node[[2]]
        if (length(defVar) == 1 && typeof(defVar) == "symbol") {
          # Add the defined variable name into defVars.
          addItemToAccumulator(defVars, as.character(defVar))
        } else {
          # Complex left-hand side (e.g. x$a <- ...); traverse it like any expression.
          processClosure(node[[2]], oldEnv, defVars, checkedFuncs, newEnv)
        }
        for (i in 3:nodeLen) {
          processClosure(node[[i]], oldEnv, defVars, checkedFuncs, newEnv)
        }
      } else if (nodeChar == "function") {
        # Function definition.
        # Add parameter names.
        newArgs <- names(node[[2]])
        lapply(newArgs, function(arg) { addItemToAccumulator(defVars, arg) })
        # Elements from position 3 onward (the body, and the source reference when present).
        for (i in 3:nodeLen) {
          processClosure(node[[i]], oldEnv, defVars, checkedFuncs, newEnv)
        }
      } else if (nodeChar == "$") {
        # Skip the field.
        processClosure(node[[2]], oldEnv, defVars, checkedFuncs, newEnv)
      } else if (nodeChar == "::" || nodeChar == ":::") {
        # Only the name after pkg:: / pkg::: is examined; the package name is skipped.
        processClosure(node[[3]], oldEnv, defVars, checkedFuncs, newEnv)
      } else {
        # Ordinary call: visit the callee and every argument.
        for (i in 1:nodeLen) {
          processClosure(node[[i]], oldEnv, defVars, checkedFuncs, newEnv)
        }
      }
    }
  } else if (nodeLen == 1 &&
               (typeof(node) == "symbol" || typeof(node) == "language")) {
    # Base case: current AST node is a leaf node and a symbol or a function call.
    nodeChar <- as.character(node)
    if (!nodeChar %in% defVars$data) {
      # Not a function parameter or local variable.
      func.env <- oldEnv
      topEnv <- parent.env(.GlobalEnv)
      # Search in function environment, and function's enclosing environments
      # up to global environment. There is no need to look into package environments
      # above the global or namespace environment that is not SparkR below the global,
      # as they are assumed to be loaded on workers.
      while (!identical(func.env, topEnv)) {
        # Namespaces other than "SparkR" will not be searched.
        if (!isNamespace(func.env) ||
            (getNamespaceName(func.env) == "SparkR" &&
               !(nodeChar %in% getNamespaceExports("SparkR")))) {
          # Only include SparkR internals.

          # Set parameter 'inherits' to FALSE since we do not need to search in
          # attached package environments.
          if (tryCatch(exists(nodeChar, envir = func.env, inherits = FALSE),
                       error = function(e) { FALSE })) {
            obj <- get(nodeChar, envir = func.env, inherits = FALSE)
            if (is.function(obj)) {
              # If the node is a function call.
              # Look up the functions already examined under this name; the default
              # list(NULL) marks "none yet".
              funcList <- mget(nodeChar, envir = checkedFuncs, inherits = F,
                               ifnotfound = list(list(NULL)))[[1]]
              found <- sapply(funcList, function(func) {
                ifelse(identical(func, obj), TRUE, FALSE)
              })
              if (sum(found) > 0) {
                # If function has been examined, ignore.
                # NOTE(review): this `break` leaves the enclosure walk without assigning
                # obj into this newEnv, so a function already examined elsewhere in the
                # same cleanClosure tree is not copied here. Confirm that is intended.
                break
              }
              # Function has not been examined, record it and recursively clean its closure.
              assign(nodeChar,
                     if (is.null(funcList[[1]])) {
                       list(obj)
                     } else {
                       append(funcList, obj)
                     },
                     envir = checkedFuncs)
              obj <- cleanClosure(obj, checkedFuncs)
            }
            # Capture the value (function or data) in the new closure environment and
            # stop at the innermost enclosure that defines the name.
            assign(nodeChar, obj, envir = newEnv)
            break
          }
        }

        # Continue to search in enclosure.
        func.env <- parent.env(func.env)
      }
    }
  }
}

# Utility function to get user defined function (UDF) dependencies (closure).
# More specifically, this function captures the values of free variables defined
# outside a UDF, and stores them in the function's environment.
# param
#   func A function whose closure needs to be captured. Non-functions are returned unchanged.
#   checkedFuncs An environment of function objects examined during cleanClosure. It can be
#                considered as a "name"-to-"list of functions" mapping.
# return value
#   a new version of func that has a correct environment (closure): a fresh environment
#   whose parent is .GlobalEnv, populated by processClosure().
cleanClosure <- function(func, checkedFuncs = new.env()) {
  if (is.function(func)) {
    newEnv <- new.env(parent = .GlobalEnv)
    func.body <- body(func)
    oldEnv <- environment(func)
    # defVars is an Accumulator of variables names defined in the function's calling
    # environment. First, function's arguments are added to defVars.
    defVars <- initAccumulator()
    # as.list(args(func)) is the formals followed by a NULL body, so its last name is the
    # body's empty name, which is dropped here. For functions without formals, names() is
    # NULL and this loop does nothing. The previous 1:(length - 1) iterated over 1, 0, -1
    # in that case and added spurious NULL entries.
    argNames <- names(as.list(args(func)))
    for (argName in argNames[-length(argNames)]) {
      addItemToAccumulator(defVars, argName)
    }
    # Recursively examine variables in the function body.
    processClosure(func.body, oldEnv, defVars, checkedFuncs, newEnv)
    environment(func) <- newEnv
  }
  func
}

# Append partition lengths to each partition in two input RDDs if needed.
# Lengths are appended when the two RDDs have different serialized modes, or
# when both are in "byte" mode. Each partition then ends with its element count
# plus one (i.e. its new length including the appended marker), which lets a
# later step find the boundary between elements from x and from other. The
# lapplyPartition() pass also re-serializes the partitions: for example,
# partitions of an RDD generated from partitionBy() may be encoded as multiple
# byte arrays, and they need to become a single byte array each.
# param
#   x An RDD.
#   other An RDD.
# return value
#   A list of two result RDDs.
appendPartitionLengths <- function(x, other) {
  xMode <- getSerializedMode(x)
  needsLengths <- xMode != getSerializedMode(other) || xMode == "byte"
  if (needsLengths) {
    appendLength <- function(part) {
      n <- length(part)
      part[[n + 1]] <- n + 1
      part
    }
    x <- lapplyPartition(x, appendLength)
    other <- lapplyPartition(other, appendLength)
  }
  list(x, other)
}

# Perform zip or cartesian between elements from two RDDs in each partition.
# param
#   rdd An RDD whose partitions hold the elements of both inputs.
#       In "byte" mode each partition is expected to be laid out as
#         x_1, ..., x_(K-1), K, y_1, ..., y_(V-1), V
#       where K and V are the length markers appended by appendPartitionLengths()
#       (element count + 1). NOTE(review): presumably the caller concatenates the two
#       appended partitions before calling this; confirm.
#       In other modes, x and y elements are expected to alternate (x at odd positions,
#       y at even positions).
#   zip A boolean flag indicating this call is for zip operation or not.
# return value
#   A result RDD whose elements are list(x, y) pairs: element-wise for zip, the Cartesian
#   product otherwise ("byte" mode only; the non-byte path always pairs element-wise).
mergePartitions <- function(rdd, zip) {
  serializerMode <- getSerializedMode(rdd)
  partitionFunc <- function(partIndex, part) {
    len <- length(part)
    if (len > 0) {
      if (serializerMode == "byte") {
        # The last element is V; counting V positions back from the end lands on K.
        lengthOfValues <- part[[len]]
        lengthOfKeys <- part[[len - lengthOfValues]]
        stopifnot(len == lengthOfKeys + lengthOfValues)

        # For zip operation, check if corresponding partitions
        # of both RDDs have the same number of elements.
        if (zip && lengthOfKeys != lengthOfValues) {
          stop(paste("Can only zip RDDs with same number of elements",
                     "in each pair of corresponding partitions."))
        }

        # A marker of 1 means that side contributed no elements.
        if (lengthOfKeys > 1) {
          keys <- part[1 : (lengthOfKeys - 1)]
        } else {
          keys <- list()
        }
        if (lengthOfValues > 1) {
          values <- part[(lengthOfKeys + 1) : (len - 1)]
        } else {
          values <- list()
        }

        if (!zip) {
          return(mergeCompactLists(keys, values))
        }
      } else {
        # Interleaved layout: odd positions are keys, even positions are values.
        keys <- part[c(TRUE, FALSE)]
        values <- part[c(FALSE, TRUE)]
      }
      # Element-wise pairing of keys and values.
      mapply(
        function(k, v) { list(k, v) },
        keys,
        values,
        SIMPLIFY = FALSE,
        USE.NAMES = FALSE)
    } else {
      part
    }
  }

  PipelinedRDD(rdd, partitionFunc)
}

# Tag a named list with class "struct" so that SerDe can tell it apart from a
# normal named list. The input must be a plain list with names.
listToStruct <- function(list) {
  stopifnot(class(list) == "list")
  stopifnot(!is.null(names(list)))
  structure(list, class = "struct")
}

# Convert a struct (see listToStruct) back to a plain named list.
# Errors unless `struct` has class "struct". The check previously inspected
# `list` (the base function, whose class is "function") instead of the argument,
# so it always failed.
structToList <- function(struct) {
  stopifnot(class(struct) == "struct")

  # Assigning the basic type name "list" as the class drops the explicit class attribute.
  class(struct) <- "list"
  struct
}

# Convert a named list to a new environment (one binding per name) so that it can
# be passed to the JVM. An empty, unnamed list is allowed; otherwise no name may be NA.
convertNamedListToEnv <- function(namedList) {
  # Make sure each item in the list has a name
  names <- names(namedList)
  stopifnot(
    if (is.null(names)) {
      length(namedList) == 0
    } else {
      !any(is.na(names))
    })

  env <- new.env()
  for (key in names) {
    assign(key, namedList[[key]], envir = env)
  }
  env
}

# Build a new environment for attach() and with() methods: one binding per column
# of the SparkDataFrame `data`, each holding data[, col, drop = FALSE].
# `data` must be a SparkDataFrame with at least one column.
assignNewEnv <- function(data) {
  stopifnot(class(data) == "SparkDataFrame")
  cols <- columns(data)
  stopifnot(length(cols) > 0)

  env <- new.env()
  for (col in cols) {
    assign(col, data[, col, drop = FALSE], envir = env)
  }
  env
}

# Utility function to split `input` on commas and whitespace, dropping the
# empty tokens that adjacent separators produce.
splitString <- function(input) {
  tokens <- unlist(strsplit(input, ",|\\s"))
  tokens[nzchar(tokens)]
}

# Build a java.util.Properties object from named varargs. Each name becomes a
# property key and each value is converted with as.character(). Keys are taken
# from ls() on the argument list, so they are visited in sorted order.
varargsToJProperties <- function(...) {
  pairs <- list(...)
  props <- newJObject("java.util.Properties")
  if (length(pairs) == 0) {
    return(props)
  }
  for (key in ls(pairs)) {
    callJMethod(props, "setProperty", as.character(key), as.character(pairs[[key]]))
  }
  props
}

# Run `script` with `combinedArgs`, waiting for it to finish only when `wait` is TRUE.
# On Windows the command line goes through shell(); elsewhere through system2().
launchScript <- function(script, combinedArgs, wait = FALSE) {
  if (.Platform$OS.type != "windows") {
    # http://stat.ethz.ch/R-manual/R-devel/library/base/html/system2.html
    # stdout = F means discard output
    # stdout = "" means to its console (default)
    # Note that the console of this child process might not be the same as the running R process.
    system2(script, combinedArgs, stdout = "", wait = wait)
  } else {
    commandLine <- paste(script, combinedArgs, sep = " ")
    # on Windows, intern = F seems to mean output to the console. (documentation on this is missing)
    shell(commandLine, translate = TRUE, wait = wait, intern = wait) # nolint
  }
}

# Return the Java SparkContext stored as ".sparkRjsc" in .sparkREnv.
# Errors if SparkR has not been initialized yet.
getSparkContext <- function() {
  if (exists(".sparkRjsc", envir = .sparkREnv)) {
    return(get(".sparkRjsc", envir = .sparkREnv))
  }
  stop("SparkR has not been initialized. Please call sparkR.session()")
}

# TRUE for local master URLs: "local", "local[N]" or "local[*]". Vectorized over `master`.
isMasterLocal <- function(master) {
  localPattern <- "^local(\\[([0-9]+|\\*)\\])?$"
  grepl(localPattern, master, perl = TRUE)
}

# TRUE when the master string ends in "<letters>-client" (e.g. "yarn-client").
isClientMode <- function(master) {
  clientPattern <- "([a-z]+)-client$"
  grepl(clientPattern, master, perl = TRUE)
}

# TRUE when the R_PROFILE_USER environment variable points at a file whose name
# ends in "shell.R".
isSparkRShell <- function() {
  profile <- Sys.getenv("R_PROFILE_USER")
  grepl(".*shell\\.R$", profile, perl = TRUE)
}

# Same as `callJStatic(...)`, but an error is routed through captureJVMException()
# so that the user sees a shorter, formatted message.
handledCallJStatic <- function(cls, method, ...) {
  out <- tryCatch(
    callJStatic(cls, method, ...),
    error = function(e) captureJVMException(e, method))
  out
}

# Same as `callJMethod(...)`, but an error is routed through captureJVMException()
# so that the user sees a shorter, formatted message.
handledCallJMethod <- function(obj, method, ...) {
  out <- tryCatch(
    callJMethod(obj, method, ...),
    error = function(e) captureJVMException(e, method))
  out
}

# Convert an error raised while invoking a JVM method into a concise R error.
#
# First, a leading "Error in <call>: " prefix (typically from invokeJava(...)) is
# rewritten to "Error in <method> :" so that the message names the JVM-side method.
# Then, if the text contains one of a few well-known JVM exception class names, a
# new error is raised whose message is: the text before that class name, a short
# label, and the first line of the exception message (the part before the first
# "\n\tat" stack frame). Otherwise the (possibly rewritten) text is re-raised as-is.
# This function always signals an error and never returns normally.
# param
#   e       The condition caught from callJStatic()/callJMethod().
#   method  Name of the JVM method that was being called.
captureJVMException <- function(e, method) {
  rawmsg <- as.character(e)
  stacktrace <- if (any(grep("^Error in .*?: ", rawmsg))) {
    # Replace the R-side call in "Error in ...: " with the JVM method name.
    afterPrefix <- strsplit(rawmsg, "Error in .*?: ")[[1]][2]
    paste(paste("Error in", method, ":")[1], afterPrefix)
  } else {
    # Otherwise, do not convert the error message just in case.
    rawmsg
  }

  # Checked in order and the first marker found wins. StreamingQueryException can
  # wrap an IllegalArgumentException, so it has to be looked for first.
  knownExceptions <- list(
    c("org.apache.spark.sql.streaming.StreamingQueryException: ", "streaming query error - "),
    c("java.lang.IllegalArgumentException: ", "illegal argument - "),
    c("org.apache.spark.sql.AnalysisException: ", "analysis error - "),
    c("org.apache.spark.sql.catalyst.analysis.NoSuchDatabaseException: ", "no such database - "),
    c("org.apache.spark.sql.catalyst.analysis.NoSuchTableException: ", "no such table - "),
    c("org.apache.spark.sql.catalyst.parser.ParseException: ", "parse error - "))

  for (known in knownExceptions) {
    marker <- known[[1]]
    if (any(grep(marker, stacktrace))) {
      # parts[1] is the "Error in ..." prefix; parts[2] starts with the JVM message.
      parts <- strsplit(stacktrace, marker, fixed = TRUE)[[1]]
      firstLine <- strsplit(parts[2], "\r?\n\tat")[[1]][1]
      stop(paste0(parts[1], known[[2]], firstLine), call. = FALSE)
    }
  }
  stop(stacktrace, call. = FALSE)
}

# rbind a list of rows with raw (binary) columns
#
# Columns whose value in the first row has class "raw" are kept as list columns,
# since each cell is a whole raw vector. Every other column is flattened with unlist().
#
# @param inputData a list of rows, with each row a list
# @return data.frame with raw columns as lists
rbindRaws <- function(inputData) {
  isRawCol <- sapply(inputData[[1]], class) == "raw"
  # Binding list rows yields a list-matrix, which becomes a data.frame of list columns.
  out <- as.data.frame(do.call(rbind, inputData))
  out[!isRawCol] <- lapply(out[!isRawCol], unlist)
  out
}

# Get the last path component of a URL with its extension removed. A trailing
# compression suffix (.gz, .bz2 or .xz) is dropped first, then one more extension
# after the last '.' (logic borrowed from tools::file_path_sans_ext).
basenameSansExtFromUrl <- function(url) {
  # Everything up to the final '/' acts as the separator, leaving the last component.
  lastPart <- tail(unlist(strsplit(url, "^.+/")), 1)
  noCompression <- sub("[.](gz|bz2|xz)$", "", lastPart)
  sub("([^.]+)\\.[[:alnum:]]+$", "\\1", noCompression)
}

# TRUE when `x` is an atomic vector with exactly one element.
isAtomicLengthOne <- function(x) {
  length(x) == 1 && is.atomic(x)
}

# TRUE when R is running on Windows.
is_windows <- function() {
  identical(.Platform$OS.type, "windows")
}

# TRUE when the HADOOP_HOME environment variable is set to a non-empty value.
hadoop_home_set <- function() {
  nzchar(Sys.getenv("HADOOP_HOME"))
}

# TRUE on non-Windows platforms; on Windows, TRUE only when HADOOP_HOME is set.
windows_with_hadoop <- function() {
  if (is_windows()) hadoop_home_set() else TRUE
}

# Backport of get0() (not available before R 3.2.0): return the value bound to
# the first name in `x` in `envir`, or `ifnotfound` when there is no such binding.
getOne <- function(x, envir, inherits = TRUE, ifnotfound = NULL) {
  found <- mget(x[1L], envir = envir, inherits = inherits, ifnotfound = list(ifnotfound))
  found[[1L]]
}

# Returns a vector of parent directories, traversing up count times, starting with a full path
# eg. traverseParentDirs("/Users/user/Library/Caches/spark/spark2.2", 1) should return
# this "/Users/user/Library/Caches/spark/spark2.2"
# and  "/Users/user/Library/Caches/spark"
# The walk also stops early at a root, i.e. a path that is its own dirname().
traverseParentDirs <- function(x, count) {
  dirs <- x
  current <- x
  remaining <- count
  while (dirname(current) != current && remaining > 0) {
    current <- dirname(current)
    dirs <- c(dirs, current)
    remaining <- remaining - 1
  }
  dirs
}
vkapartzianis/SparkR documentation built on May 18, 2019, 8:10 p.m.