### Tullia/split_v2.R

### Constructor for the informal "partysplit" class.
###
### Arguments:
###   varid      integer id referring to the split variable.
###   breaks     numeric vector of cut points (at least one element) or NULL;
###              stored as double.
###   index      integer vector (at least two elements, minimum 1, contiguous
###              values) or NULL.
###   right      a single non-missing logical.
###   prob       double vector of probabilities summing to one, or NULL.
###   info       arbitrary object, stored unchanged.
###   centroids  arbitrary object, stored unchanged (not validated here).
###   basid      whole-number numeric id or NULL, stored unchanged.
###
### At least one of `breaks` and `index` must be given.
###
### Returns a list of class "partysplit" with the elements "varid", "breaks",
### "index", "right", "prob", "info", "centroids" and "basid"; elements that
### were not supplied stay NULL.
partysplit <- function(varid, breaks = NULL, index = NULL, right = TRUE,
                       prob = NULL, info = NULL, centroids = NULL, basid = NULL) {

  ### informal class for splits
  split <- vector(mode = "list", length = 8)
  names(split) <- c("varid", "breaks", "index", "right", "prob", "info",
                    "centroids", "basid")

  ### split is an id referring to a variable
  if (!is.integer(varid))
    stop(sQuote("varid"), " ", "is not integer")
  split$varid <- varid

  if (is.null(breaks) && is.null(index))
    stop("either", " ", sQuote("breaks"), " ", "or", " ",
         sQuote("index"), " ", "must be given")

  ### cut points
  if (!is.null(breaks)) {
    if (!(is.numeric(breaks) && (length(breaks) >= 1)))
      stop(sQuote("breaks"), " ",
           "should be a numeric vector containing at least one element")
    ### FIXME: I think we need to make sure breaks are double in C
    split$breaks <- as.double(breaks)
  }

  if (!is.null(index)) {
    if (!is.integer(index))
      stop(sQuote("index"), " ", "is not a class", " ", sQuote("integer"))
    if (length(index) < 2)
      stop(sQuote("index"), " ", "has less than two elements")
    ### the all-NA guard avoids the warning min() emits on an empty set
    if (all(is.na(index)) || min(index, na.rm = TRUE) != 1)
      stop("minimum of", " ", sQuote("index"), " ", "is not equal to 1")
    ### all.equal() returns a character vector (not FALSE) on a mismatch and
    ### therefore cannot be negated; check the gaps between the sorted
    ### unique values directly (sort() drops NAs)
    if (any(diff(sort(unique(index))) != 1L))
      stop(sQuote("index"), " ", "is not a contiguous sequence")
    split$index <- index

    if (!is.null(breaks) && length(breaks) != (length(index) - 1))
      stop("length of", " ", sQuote("breaks"), " ",
           "does not match length of", " ", sQuote("index"))
  }

  ### scalar test: `&&` short-circuits and the length check rejects vectors
  if (!(is.logical(right) && length(right) == 1L && !is.na(right)))
    stop(sQuote("right"), " ", "is not a logical")
  split$right <- right

  if (!is.null(prob)) {
    if (!is.double(prob) || anyNA(prob) || any(prob < 0) || any(prob > 1) ||
        !isTRUE(all.equal(sum(prob), 1)))
      stop(sQuote("prob"), " ", "is not a vector of probabilities")
    if (!is.null(index)) {
      if (max(index, na.rm = TRUE) != length(prob))
        stop("incorrect", " ", sQuote("index"))
    } else if (!is.null(breaks)) {
      if (length(breaks) != (length(prob) - 1))
        stop("incorrect", " ", sQuote("breaks"))
    }
    split$prob <- prob
  }

  if (!is.null(info))
    split$info <- info

  if (!is.null(centroids))
    split$centroids <- centroids

  if (!is.null(basid)) {
    ### validate basid itself (whole numbers stored as double are accepted
    ### so that callers passing e.g. 2 instead of 2L keep working)
    basid_ok <- is.numeric(basid) && all(is.finite(basid)) &&
      all(basid == round(basid))
    if (!basid_ok)
      stop(sQuote("basid"), " ", "is not integer")
    split$basid <- basid
  }

  class(split) <- "partysplit"

  split
}

### Accessor: returns the "varid" element of a partysplit object.
### Stops if `split` does not inherit from class "partysplit".
varid_split <- function(split) {
  if (inherits(split, "partysplit")) {
    return(split$varid)
  }
  stop(sQuote("split"), " ", "is not an object of class",
       " ", sQuote("partysplit"))
}

### Accessor: returns the "breaks" element of a partysplit object
### (NULL when the split has no breaks).
### Stops if `split` does not inherit from class "partysplit".
breaks_split <- function(split) {
  if (inherits(split, "partysplit")) {
    return(split$breaks)
  }
  stop(sQuote("split"), " ", "is not an object of class",
       " ", sQuote("partysplit"))
}

### Accessor: returns the "index" element of a partysplit object
### (NULL when the split has no index).
### Stops if `split` does not inherit from class "partysplit".
index_split <- function(split) {
  if (inherits(split, "partysplit")) {
    return(split$index)
  }
  stop(sQuote("split"), " ", "is not an object of class",
       " ", sQuote("partysplit"))
}

### Accessor: returns the "right" element of a partysplit object.
### Stops if `split` does not inherit from class "partysplit".
right_split <- function(split) {
  if (inherits(split, "partysplit")) {
    return(split$right)
  }
  stop(sQuote("split"), " ", "is not an object of class",
       " ", sQuote("partysplit"))
}

### Returns the kid-node probabilities of a partysplit object.
###
### If the split stores a "prob" element, it is returned unchanged.
### Otherwise equal probabilities are returned: one per kid node, where the
### number of kids is max(index) when an index is present and
### length(breaks) + 1 when only breaks are present.
###
### Stops if `split` does not inherit from class "partysplit", or if none of
### prob, index and breaks is available.
prob_split <- function(split) {
  if (!(inherits(split, "partysplit")))
    stop(sQuote("split"), " ", "is not an object of class",
         " ", sQuote("partysplit"))
  prob <- split$prob
  if (!is.null(prob)) return(prob)

  ### either breaks or index must be there; index takes precedence
  index <- index_split(split)
  if (!is.null(index)) {
    nkids <- max(index, na.rm = TRUE)
  } else {
    breaks <- breaks_split(split)
    if (is.null(breaks))
      stop("neither", " ", sQuote("prob"), " ", "nor", " ",
           sQuote("index"), " ", "or", " ", sQuote("breaks"), " ",
           "given for", " ", sQuote("split"))
    nkids <- length(breaks) + 1
  }
  rep(1, nkids) / nkids
}

### Accessor: returns the "info" element of a partysplit object
### (NULL when no info was stored).
### Stops if `split` does not inherit from class "partysplit".
info_split <- function(split) {
  if (inherits(split, "partysplit")) {
    return(split$info)
  }
  stop(sQuote("split"), " ", "is not an object of class",
       " ", sQuote("partysplit"))
}

### Accessor: returns the "centroids" element of a partysplit object
### (NULL when no centroids were stored).
### Stops if `split` does not inherit from class "partysplit".
centroids_split <- function(split) {
  if (inherits(split, "partysplit")) {
    return(split$centroids)
  }
  stop(sQuote("split"), " ", "is not an object of class",
       " ", sQuote("partysplit"))
}

### Accessor: returns the "basid" element of a partysplit object
### (NULL when no basid was stored).
### Stops if `split` does not inherit from class "partysplit".
basid_split <- function(split) {
  if (inherits(split, "partysplit")) {
    return(split$basid)
  }
  stop(sQuote("split"), " ", "is not an object of class",
       " ", sQuote("partysplit"))
}


### Assigns observations to kid nodes according to a partysplit object.
###
### Arguments:
###   split   object of class "partysplit".
###   data    list-like object holding the variables; its class is replaced
###           by "list" internally.
###   vmatch  maps the split's varid to a position in `data`.
###   obs     optional subset of observations, applied as `x[obs]`.
###
### The split variable is data[[vmatch[varid]]]; when the split carries a
### basid, column(s) `basid` of that element are used instead.
### With breaks, the variable is binned by .bincode(); without breaks it must
### already have integer storage mode. If the split has an index, the bin /
### integer codes are mapped through it.
kidids_split <- function(split, data, vmatch = seq_along(data), obs = NULL) {

  varid <- varid_split(split)
  basid <- basid_split(split)
  breaks <- breaks_split(split)
  class(data) <- "list" ### speed up

  x <- data[[vmatch[varid]]]
  ### a basid means we are in the coeff case, otherwise in the cluster case
  if (!is.null(basid)) x <- x[, basid]
  if (!is.null(obs)) x <- x[obs]

  if (is.null(breaks)) {
    if (storage.mode(x) != "integer")
      stop("variable", " ", vmatch[varid], " ", "is not integer")
  } else {
    ### .bincode returns integer codes and is faster than cut
    ### <FIXME> use findInterval instead? </FIXME>
    ### breaks = Inf is possible (MIA), hence unique()
    x <- .bincode(as.numeric(x),
                  breaks = unique(c(-Inf, breaks, Inf)),
                  right = right_split(split))
  }

  index <- index_split(split)
  ### empty factor levels correspond to NA and return NA here
  ### and thus the corresponding observations will be treated
  ### as missing values (surrogate or random splits):
  if (!is.null(index))
    x <- index[x]
  x
}


### Prediction-time variant of kid-node assignment for a partysplit object.
###
### Arguments:
###   split   object of class "partysplit".
###   data    list-like object holding the variables; its class is replaced
###           by "list" internally.
###   vmatch  maps the split's varid to a position in `data`.
###   obs     optional subset of observations, applied as `x[obs]`.
###
### The split variable is data[[vmatch[varid]]]; when the split carries a
### basid, column(s) `basid` of that element are used instead.
### With breaks, the variable is binned by .bincode(). Without breaks, and if
### the split stores centroids, every row is assigned the position of the
### smallest value returned by compute.dissimilarity.cl() across centroids;
### the result must have integer storage mode. If the split has an index, the
### codes are mapped through it.
kidids_split_predict <- function(split, data, vmatch = seq_along(data), obs = NULL) {

  varid <- varid_split(split)
  basid <- basid_split(split)
  breaks <- breaks_split(split)
  class(data) <- "list" ### speed up

  x <- data[[vmatch[varid]]]
  ### a basid means we are in the coeff case, otherwise in the cluster case
  if (!is.null(basid)) x <- x[, basid]
  if (!is.null(obs)) x <- x[obs]

  if (is.null(breaks)) {

    centroids <- centroids_split(split)
    if (!is.null(centroids)) {
      ### compute.dissimilarity.cl is defined outside this file; presumably
      ### it returns one dissimilarity per observation -- TODO confirm
      dissim <- lapply(centroids, compute.dissimilarity.cl, x = x, lp = 2)
      ### NOTE(review): ncol = 2 hard-codes exactly two centroids
      dissim <- matrix(unlist(dissim), ncol = 2)
      x <- apply(dissim, 1, which.min)
    }

    if (storage.mode(x) != "integer")
      stop("variable", " ", vmatch[varid], " ", "is not integer")
  } else {
    ### .bincode returns integer codes and is faster than cut
    ### <FIXME> use findInterval instead? </FIXME>
    ### breaks = Inf is possible (MIA), hence unique()
    x <- .bincode(as.numeric(x),
                  breaks = unique(c(-Inf, breaks, Inf)),
                  right = right_split(split))
  }

  index <- index_split(split)
  ### empty factor levels correspond to NA and return NA here
  ### and thus the corresponding observations will be treated
  ### as missing values (surrogate or random splits):
  if (!is.null(index))
    x <- index[x]
  x
}

### Builds a character representation of a partysplit object.
###
### Arguments:
###   split   object of class "partysplit".
###   data    optional list-like object with the variables; used for the
###           variable name, the factor levels and the split type.
###   digits  number of digits numeric breaks are rounded to.
###
### Returns a list with elements `name` (label of the split variable) and
### `levels` (one label per kid node).
character_split <- function(split, data = NULL, digits = getOption("digits") - 2) {

  varid <- varid_split(split)

  if (is.null(data)) {
    ## (bad) default names and labels
    lev <- NULL
    mlab <- paste0("V", varid)
    type <- "numeric"
  } else {
    ## names and labels: look at the split variable only
    xvar <- data[[varid]]
    lev <- levels(xvar)
    mlab <- names(data)[varid]

    ## determine split type from the first class of the variable
    type <- class(xvar)[1]
    if (!(type %in% c("factor", "ordered"))) type <- "numeric"
  }

  ## process defaults for breaks and index
  breaks <- breaks_split(split)
  index <- index_split(split)
  right <- right_split(split)

  if (is.null(breaks)) breaks <- seq_len(length(index) - 1)
  if (is.null(index)) index <- seq_len(length(breaks) + 1)

  ## ordered splits with more than one break are formatted like factors
  if (type == "ordered" && length(breaks) > 1)
    type <- "factor"
  ### <FIXME> format ordered multiway splits? </FIXME>

  switch(type,
         "factor" = {
           ## map each level position to its kid node, then paste the
           ## levels belonging to the same kid together
           nindex <- index[cut(seq_along(lev), c(-Inf, breaks, Inf), right = right)]
           dlab <- as.vector(tapply(lev, nindex, paste, collapse = ", "))
         },
         "ordered" = {
           ## multiway ordered splits were redirected to "factor" above
           if (length(breaks) != 1)
             stop("cannot format an ordered split without exactly one break")
           ops <- if (right) c("<=", ">") else c("<", ">=")
           dlab <- paste(ops, lev[breaks], sep = " ")
           dlab <- dlab[index]
         },
         "numeric" = {
           breaks <- round(breaks, digits)
           if (length(breaks) == 1) {
             ops <- if (right) c("<=", ">") else c("<", ">=")
             dlab <- paste(ops, breaks, sep = " ")
           } else {
             ## let cut() generate the interval labels
             dlab <- levels(cut(0, breaks = c(-Inf, breaks, Inf),
                                right = right))
           }
           dlab <- as.vector(tapply(dlab, index, paste, collapse = " | "))
         }
  )

  list(name = mlab, levels = dlab)
}
### tulliapadellini/energytree documentation built on May 14, 2020, 8:06 p.m.