R/xgb_mat.R

Defines functions xgb_mat xgb_mat.data.frame xgb_mat.default xgb_mat.matrix xgb_mat.dgCMatrix xgb_mat.dgCMatrix

Documented in xgb_mat

#' xgb matrix
#'
#' Simple wrapper for creating an xgboost matrix (`xgb.DMatrix`), optionally
#' attaching labels and randomly splitting the rows into train and test sets.
#'
#' @param x Input data. Data frames and other objects are converted with
#'   `as.matrix()`; base matrices and sparse `dgCMatrix` objects are used
#'   directly.
#' @param ... Other data to `cbind()` onto `x` as additional columns.
#' @param y Optional label vector, aligned with the rows of `x`.
#' @param split Optional number between 0 and 1 giving the fraction of rows
#'   randomly assigned to the training set (`floor(nrow(x) * split)` rows);
#'   the remaining rows form the test set.
#' @return An `xgb.DMatrix` when `split` is `NULL`; otherwise a list with
#'   elements `train` and `test`, each an `xgb.DMatrix`.
#' @examples
#'
#' xgb_mat(data.frame(x = rnorm(20), y = rnorm(20)))
#'
#' ## labelled data with an 80/20 train/test split
#' xgb_mat(data.frame(x = rnorm(20)), y = rnorm(20), split = 0.8)
#'
#' @export
xgb_mat <- function(x, ..., y = NULL, split = NULL) {
  UseMethod("xgb_mat")
}

#' @export
xgb_mat.data.frame <- function(x, ..., y = NULL, split = NULL) {
  # Flatten the data frame into a plain matrix without row names, then
  # re-dispatch so the matrix method builds the DMatrix.
  xgb_mat(
    as.matrix.data.frame(x, rownames.force = FALSE),
    ...,
    y = y,
    split = split
  )
}

#' @export
xgb_mat.default <- function(x, ..., y = NULL, split = NULL) {
  x <- as.matrix(x)
  # Re-dispatching an object that still is not a matrix (e.g. a 2-D `table`,
  # which as.matrix() hands back unchanged) would land in this method again
  # and recurse until the stack overflows, so fail with a clear error.
  if (!inherits(x, "matrix")) {
    stop(
      "`x` could not be converted to a matrix (class: ",
      paste(class(x), collapse = "/"), ").",
      call. = FALSE
    )
  }
  xgb_mat(x, ..., y = y, split = split)
}

# Randomly partition the row indices 1..n into a training set of
# floor(n * split) rows and a test set holding every remaining row.
# Internal helper for the xgb_mat() methods; not exported.
xgb_mat_split_rows <- function(n, split) {
  if (!is.numeric(split) || length(split) != 1L || is.na(split) ||
    split < 0 || split > 1) {
    stop("`split` must be a single number between 0 and 1.", call. = FALSE)
  }
  # sample.int(n, k) draws the same rows as sample(seq_len(n), k), so
  # results under a given set.seed() are unchanged.
  train <- sample.int(n, floor(n * split))
  # setdiff() instead of negative indexing: `x[-integer(0)]` selects nothing,
  # so an empty training sample would otherwise empty the test set as well.
  list(train = train, test = setdiff(seq_len(n), train))
}

#' @export
xgb_mat.matrix <- function(x, ..., y = NULL, split = NULL) {
  x <- cbind(x, ...)
  # Catch misaligned labels up front: subsetting `y` by row index below would
  # otherwise silently pad a short `y` with NA or drop surplus labels.
  if (!is.null(y) && length(y) != nrow(x)) {
    stop(
      "`y` must have one label per row of `x` (got ", length(y),
      " labels for ", nrow(x), " rows).",
      call. = FALSE
    )
  }
  # Pass `label` only when labels were supplied, mirroring the original
  # unlabelled / labelled call forms.
  to_dmatrix <- function(data, label) {
    if (is.null(label)) {
      xgboost::xgb.DMatrix(data)
    } else {
      xgboost::xgb.DMatrix(data, label = label)
    }
  }
  if (is.null(split)) {
    return(to_dmatrix(x, y))
  }
  rows <- xgb_mat_split_rows(nrow(x), split)
  # `y[idx]` is NULL when `y` is NULL, so one code path covers both cases.
  list(
    train = to_dmatrix(x[rows$train, , drop = FALSE], y[rows$train]),
    test = to_dmatrix(x[rows$test, , drop = FALSE], y[rows$test])
  )
}

#' @export
xgb_mat.dgCMatrix <- function(x, ..., y = NULL, split = NULL) {
  # The matrix method only needs cbind(), nrow() and `[` with drop = FALSE,
  # all of which work on sparse dgCMatrix objects, so reuse it instead of
  # maintaining a second, verbatim copy of the same logic.
  xgb_mat.matrix(x, ..., y = y, split = split)
}



# xgb_mat.dgCMatrix() is defined once, earlier in this file; a second
# verbatim definition (which silently overrode the first) was removed here.

Try the wactor package in your browser

Any scripts or data that you put into this service are public.

wactor documentation built on Dec. 18, 2019, 5:07 p.m.