## R/party.R
##
## Defines functions .list.rules.party nodeprune.default nodeprune.partynode nodeprune model.frame.party formula.party getCall.constparties getCall.party depth.party width.party data_party.default data_party .simplify_pred .predict_party_constparty .pred_density .pred_quantile .pred_ecdf .pred_numeric_response .pred_factor_response .pred_factor .pred_Surv_response .pred_Surv predict_party.constparty predict_party.default predict_party predict.party nodeapply.partynode nodeapply.party nodeapply nodeids.party nodeids.partynode nodeids as.constparty is.constparty node_party .names_party names.party length.party party
##
## Documented in as.constparty data_party data_party.default depth.party getCall.party is.constparty length.party names.party nodeapply nodeapply.party nodeapply.partynode nodeids nodeids.party nodeids.partynode node_party nodeprune nodeprune.partynode party predict_party predict.party predict_party.constparty predict_party.default width.party

## FIXME: data in party
##   - currently assumed to be a data.frame
##   - potentially empty
##   - the following are all assumed to work:
##     dim(data), names(data)
##     sapply(data, class), lapply(data, levels)
##   - these may need to be modified if data/terms
##     should be able to handle databases

## Construct an object of class "party" from a recursive "partynode"
## tree and the data it was grown on.
##
## Arguments:
##   node   - the tree, inheriting from "partynode"; it is stored after
##            renumbering via as.partynode(node, from = 1L).
##   data   - a data.frame (possibly with zero rows) whose columns are
##            indexed by the split variable ids of the inner nodes.
##   fitted - optional data.frame; its "(fitted)" column holds the terminal
##            node id of each observation. If the argument is missing and
##            data has rows, it is computed with fitted_node().
##   terms  - optional object inheriting from "terms".
##   names  - optional node names, one for every node in the tree.
##   info   - arbitrary additional information, stored unchanged.
##
## Value: a list with elements node, data, fitted, terms, names and info,
## of class "party".
party <- function(node, data, fitted = NULL, terms = NULL, names = NULL, info = NULL) {

    stopifnot(inherits(node, "partynode"))
    stopifnot(inherits(data, "data.frame"))
    ### make sure all split variables are there 
    ### (inner node ids = all ids minus the terminal ones; the split
    ### variable id of each inner node must be a column index of data)
    ids <- nodeids(node)[!nodeids(node) %in% nodeids(node, terminal = TRUE)]
    varids <- unique(unlist(nodeapply(node, ids = ids, FUN = function(x) 
        varid_split(split_node(x)))))
    stopifnot(varids %in% 1:ncol(data))
    
    if(!is.null(fitted)) {
        stopifnot(inherits(fitted, "data.frame"))
        ## data is either empty or has one row per row of fitted
        stopifnot(nrow(data) == 0L | nrow(data) == nrow(fitted))
        
	# try to provide default variable "(fitted)"
	if(nrow(data) > 0L) {
          if(!("(fitted)" %in% names(fitted))) 
            fitted[["(fitted)"]] <- fitted_node(node, data = data)
	} else {
	  # without data the node ids cannot be recomputed, so they must
	  # already be supplied as the first column of fitted
	  stopifnot("(fitted)" == names(fitted)[1L])
	}

        ## every fitted id must refer to a terminal node
        nt <- nodeids(node, terminal = TRUE)
        stopifnot(all(fitted[["(fitted)"]] %in% nt))

        ## renumber the nodes starting from 1 and translate the fitted
        ## terminal ids from the old to the new numbering (terminal nodes
        ## are matched by their position in nodeids(..., terminal = TRUE))
        node <- as.partynode(node, from = 1L)
        nt2 <- nodeids(node, terminal = TRUE)
        fitted[["(fitted)"]] <- nt2[match(fitted[["(fitted)"]], nt)]
    } else {
        node <- as.partynode(node, from = 1L)
	# default "(fitted)"
	# (only if the fitted argument was not supplied at all; an explicit
	# fitted = NULL makes missing(fitted) FALSE and leaves fitted NULL)
	if(nrow(data) > 0L & missing(fitted)) 
          fitted <- data.frame("(fitted)" = fitted_node(node, 
            data = data), check.names = FALSE)
    }
    
    ## terms and names are filled in below, after validation
    party <- list(node = node, data = data, fitted = fitted, 
                  terms = NULL, names = NULL, info = info)
    class(party) <- "party"

    if(!is.null(terms)) {
        stopifnot(inherits(terms, "terms"))
	party$terms <- terms
    }

    ## one name per node (inner and terminal)
    if (!is.null(names)) {
        n <- length(nodeids(party, terminal = FALSE))
        if (length(names) != n)
            stop("invalid", " ", sQuote("names"), " ", "argument")
        party$names <- names
    }

    party
}

## The length of a party is its total number of nodes (inner + terminal).
length.party <- function(x) {
    all_ids <- nodeids(x)
    length(all_ids)
}

## Node names of a party; falls back to the node ids (as character)
## when no names were stored.
names.party <- function(x) {
    .names_party(x)
}

## Replace the node names of a party. value must be NULL (drop the names)
## or contain exactly one name per node.
"names<-.party" <- function(x, value) {
    n_nodes <- length(nodeids(x, terminal = FALSE))
    valid <- is.null(value) || length(value) == n_nodes
    if (!valid)
        stop("invalid", " ", sQuote("names"), " ", "argument")
    x$names <- value
    x
}

## Stored node names of a party, or the node ids as character strings
## if none were stored.
.names_party <- function(party) {
    stored <- party$names
    if (!is.null(stored))
        return(stored)
    as.character(nodeids(party, terminal = FALSE))
}

## Root node (the recursive partynode tree) of a party object.
node_party <- function(party) {
    stopifnot(inherits(party, "party"))
    ## `$` on purpose: `[[` would dispatch to the `[[.party` subsetting method
    root <- party$node
    root
}

## TRUE if the party stores both the fitted node ids ("(fitted)") and the
## response ("(response)") in its fitted slot, FALSE otherwise.
is.constparty <- function(party) {
    stopifnot(inherits(party, "party"))
    fit <- party$fitted
    !is.null(fit) && all(c("(fitted)", "(response)") %in% names(fit))
}

## Coerce obj to a "constparty": a party whose fitted slot carries the
## terminal node ids ("(fitted)") and the response ("(response)"), plus
## case weights ("(weights)") when these are not all equal to 1.
##
## Objects not inheriting from "party" are first converted with as.party().
## Missing node ids are computed with predict(obj, type = "node"); missing
## response and weights are taken from model.frame(obj).
## Stops if the result still does not qualify as a constparty.
as.constparty <- function(obj, ...) {
    if(!inherits(obj, "party")) obj <- as.party(obj)
    if (!is.constparty(obj)) {
        if(is.null(obj$fitted))
            obj$fitted <- data.frame("(fitted)" = predict(obj, type = "node"),
                                     check.names = FALSE)
        if(!("(fitted)" %in% names(obj$fitted)))
            obj$fitted["(fitted)"] <- predict(obj, type = "node")
        if(!("(response)" %in% names(obj$fitted)))
            obj$fitted["(response)"] <- model.response(model.frame(obj))
        if(!("(weights)" %in% names(obj$fitted))) {
            w <- model.weights(model.frame(obj))
            ## weights are optional: store them only when present and
            ## non-trivial (unit weights are the default downstream)
            if(!is.null(w) && any(w != 1L)) obj$fitted["(weights)"] <- w
        }
    }
    if (is.constparty(obj)) {
        ret <- obj
        class(ret) <- c("constparty", class(obj))
        return(ret)
    }
    stop("cannot coerce object of class", " ", sQuote(class(obj)), 
          " ", "to", " ", sQuote("constparty"))
}

## Subset a party object to the subtree rooted at node i; the same
## function serves both x[i] and x[[i]].
##
## i is a single node id or a node name (matched against names(x)).
## Returns a new object of the same class as x, built with party() from
## the subtree, the rows of the observations falling into that subtree,
## and the names of the subtree's nodes; terms and info are copied.
"[.party" <- "[[.party" <- function(x, i, ...) {
    ## a node name is translated into its position in names(x)
    if (is.character(i) && !is.null(names(x)))
        i <- which(names(x) %in% i)
    stopifnot(length(i) == 1 & is.numeric(i))
    stopifnot(i <= length(x) & i >= 1)
    i <- as.integer(i)
    ## rows of node i; data_party.default returns cbind(data, fitted), so
    ## everything from the "(fitted)" column onwards is fitted information
    dat <- data_party(x, i)
    if (!is.null(x$fitted)) {
        findx <- which("(fitted)" == names(dat))[1]
        fit <- dat[,findx:ncol(dat), drop = FALSE]
        dat <- dat[,-(findx:ncol(dat)), drop = FALSE]
        ## empty data (only fitted columns were returned): keep x$data
        if (ncol(dat) == 0)
            dat <- x$data
    } else {
        fit <- NULL
        dat <- x$data
    }
    ## names of all nodes in the subtree rooted at i
    nam <- names(x)[nodeids(x, from = i, terminal = FALSE)]

    ## walk down to node i by always descending into the last kid whose
    ## id does not exceed i (assumes ids increase in depth-first order, as
    ## produced by as.partynode(..., from = 1L) -- TODO confirm)
    recFun <- function(node) {
        if (id_node(node) == i) return(node)
        kid <- sapply(kids_node(node), id_node)
        return(recFun(node[[max(which(kid <= i))]]))
    }
    node <- recFun(node_party(x))

    ## party() renumbers the subtree starting from 1
    ret <- party(node = node, data = dat, fitted = fit, 
                 terms = x$terms, names = nam, info = x$info)
    class(ret) <- class(x)
    ret
}

## S3 generic: ids of the nodes of a tree-like object.
nodeids <- function(obj, ...) {
    UseMethod("nodeids")
}

## Node ids of the subtree of obj rooted at the node with id `from`
## (default: the root of obj), in pre-order (a node before its kids).
## With terminal = TRUE only the ids of terminal nodes are returned.
## Returns NULL if no node with id `from` is found.
nodeids.partynode <- function(obj, from = NULL, terminal = FALSE, ...) {

    if(is.null(from)) from <- id_node(obj)

    ## id of a single node, or NULL if it is not to be recorded
    ## (not yet inside the requested subtree, or inner node while only
    ## terminal ids are requested)
    id <- function(node, record = TRUE, terminal = FALSE) {
      if(!record) return(NULL)
      if(!terminal)
          return(id_node(node))
      else
          if(is.terminal(node)) return(id_node(node)) else return(NULL)
    }

    ## recursive traversal: once `record` is TRUE all descendants are
    ## recorded; before that, recording starts at the kid whose id is `from`
    rid <- function(node, record = TRUE, terminal = FALSE) {  
        myid <- id(node, record = record, terminal = terminal)
        if(is.terminal(node)) return(myid)
        kids <- kids_node(node)    
        kids_record <- if(record)  
            rep(TRUE, length(kids))
        else
            sapply(kids, id_node) == from
        return(c(myid,
            unlist(lapply(1:length(kids), function(i)
                rid(kids[[i]], record = kids_record[i], terminal = terminal)))
        ))
    }

    ## start recording right away if `from` is the root itself
    return(rid(obj, from == id_node(obj), terminal))
}

## Node ids of a party: delegates to the method for its root partynode.
nodeids.party <- function(obj, from = NULL, terminal = FALSE, ...) {
    root <- node_party(obj)
    nodeids(root, from = from, terminal = terminal, ...)
}

## S3 generic: apply FUN to selected nodes (by id) of a tree-like object.
nodeapply <- function(obj, ids = 1, FUN = NULL, ...) {
    UseMethod("nodeapply")
}

## Apply FUN to the nodes with the given ids of a party.
##
## by_node = TRUE:  FUN receives the partynode objects of the tree.
## by_node = FALSE: FUN receives the party subsets obj[[id]].
## Further arguments in ... are passed along. Returns NULL for empty ids,
## otherwise a list named by the node names.
nodeapply.party <- function(obj, ids = 1, FUN = NULL, by_node = TRUE, ...) {

    stopifnot(isTRUE(all.equal(ids, round(ids))))
    ids <- as.integer(ids)

    if (is.null(FUN))
        FUN <- function(x, ...) x

    if (length(ids) == 0)
        return(NULL)

    out <- if (by_node) {
        nodeapply(node_party(obj), ids = ids, FUN = FUN, ...)
    } else {
        lapply(ids, function(node_id) FUN(obj[[node_id]], ...))
    }

    names(out) <- names(obj)[ids]
    out
}

## Apply FUN to the nodes with the given ids of a partynode tree.
##
## obj: a "partynode"; ids: integer-valued node ids; FUN: a function whose
## first argument is a node (default: identity). Additional arguments in
## ... are passed on to FUN.
## Returns NULL for empty ids, otherwise a list with one element per entry
## of ids (in the order of ids); ids not present in the tree give NULL.
nodeapply.partynode <- function(obj, ids = 1, FUN = NULL, ...) {

    stopifnot(isTRUE(all.equal(ids, round(ids))))
    ids <- as.integer(ids)

    if(is.null(FUN)) FUN <- function(x, ...) x

    if (length(ids) == 0)
        return(NULL)

    rval <- vector(mode = "list", length = length(ids))
    rval_id <- rep(0, length(ids))
    i <- 1

    ## depth-first traversal; results are stored in visiting order
    ## together with the id of the node that produced them
    recFUN <- function(node, ...) {
        if(id_node(node) %in% ids) {
            rval_id[i] <<- id_node(node)
            ## single-bracket assignment keeps a NULL result in place
            ## instead of deleting the list element
            rval[i] <<- list(FUN(node, ...))
            i <<- i + 1
        }
        ## forward ... so that extra arguments reach FUN at every depth
        kids <- kids_node(node)
        for(j in seq_along(kids)) recFUN(kids[[j]], ...)
        invisible(TRUE)
    }
    recFUN(obj, ...)

    ## reorder the results to follow the order of ids
    rval[match(ids, rval_id)]
}

## Predict from a "party" object.
##
## First determines for each observation the id of the terminal node it
## falls into, then passes these ids to predict_party(), which computes the
## actual predictions (dispatching on the class of object).
##
## object:  a "party" object.
## newdata: optional data.frame; if NULL (and perm is NULL) the stored
##          fitted node ids are reused.
## perm:    optional; either a character vector of variable names or a
##          named list of factors coding strata (conditional variable
##          importance); passed on to fitted_node().
## ...:     further arguments passed to predict_party().
predict.party <- function(object, newdata = NULL, perm = NULL, ...)
{
    ### compute fitted node ids first
    fitted <- if(is.null(newdata) && is.null(perm)) {
        ## nothing new: reuse the stored fitted node ids
        object$fitted[["(fitted)"]]
    } else {
        if (is.null(newdata)) newdata <- model.frame(object)
        terminal <- nodeids(object, terminal = TRUE)

        if(max(terminal) == 1L) {
            ## the tree consists of the root node only
            rep.int(1L, NROW(newdata))
        } else {

            ## ids run from 1 to max(terminal); all non-terminal ids are
            ## inner nodes
            inner <- 1L:max(terminal)
            inner <- inner[-terminal]

            primary_vars <- nodeapply(object, ids = inner, by_node = TRUE, FUN = function(node) {
                varid_split(split_node(node))
            })
            surrogate_vars <- nodeapply(object, ids = inner, by_node = TRUE, FUN = function(node) {
                surr <- surrogates_node(node)
                if(is.null(surr)) return(NULL) else return(sapply(surr, varid_split))
            })
            vnames <- names(object$data)

            ### the splits of nodes with a primary split in perm
            ### will be permuted
            if (!is.null(perm)) {
                if (is.character(perm)) {
                    stopifnot(all(perm %in% vnames))
                    perm <- match(perm, vnames)
                } else {
                    ### perm is a named list of factors coding strata
                    ### (for varimp(..., conditional = TRUE)
                    stopifnot(all(names(perm) %in% vnames))
                    stopifnot(all(sapply(perm, is.factor)))
                    tmp <- vector(mode = "list", length = length(vnames))
                    tmp[match(names(perm), vnames)] <- perm
                    perm <- tmp
                }
            }

            ## ## FIXME: the is.na() call takes loooong on large data sets
            ## unames <- if(any(sapply(newdata, is.na))) 
            ##     vnames[unique(unlist(c(primary_vars, surrogate_vars)))]
            ## else 
            ##     vnames[unique(unlist(primary_vars))]
            ## names of all variables used in primary or surrogate splits
            unames <- vnames[unique(unlist(c(primary_vars, surrogate_vars)))]

            ## newdata can be used directly only if it contains all split
            ## variables with the classes and factor levels of the
            ## learning data
            vclass <- structure(lapply(object$data, class), .Names = vnames)
            ndnames <- names(newdata)
            ndclass <- structure(lapply(newdata, class), .Names = ndnames)
            checkclass <- all(sapply(unames, function(x) 
              isTRUE(all.equal(vclass[[x]], ndclass[[x]]))))
            factors <- sapply(unames, function(x) inherits(object$data[[x]], "factor"))
            checkfactors <- all(sapply(unames[factors], function(x) 
              isTRUE(all.equal(levels(object$data[[x]]), levels(newdata[[x]])))))
            ## FIXME: inform about wrong classes / factor levels?
            if(all(unames %in% ndnames) && checkclass && checkfactors) {
                vmatch <- match(vnames, ndnames)
                fitted_node(node_party(object), data = newdata, 
                            vmatch = vmatch, perm = perm)
            } else {
                if (!is.null(object$terms)) {
                    ### <FIXME> this won't work for multivariate responses
                    ### </FIXME>
                    ## rebuild the split variables from the stored terms,
                    ## using the factor levels of the learning data
                    xlev <- lapply(unames[factors], 
                                   function(x) levels(object$data[[x]]))
                    names(xlev) <- unames[factors]
                    mf <- model.frame(delete.response(object$terms), newdata,
                                      xlev = xlev)
                    fitted_node(node_party(object), data = mf, 
                                vmatch = match(vnames, names(mf)), perm = perm)
                } else
                    stop("cannot compute predictions: the split variables in ",
                         sQuote("newdata"), " do not match the learning data ",
                         "(names, classes, or factor levels) and no ",
                         sQuote("terms"), " are available to rebuild them")
            }
        }
    }
    ### compute predictions
    predict_party(object, fitted, newdata, ...)
}

## S3 generic: compute predictions for a party given the node ids `id`
## of the observations.
predict_party <- function(party, id, newdata = NULL, ...) {
    UseMethod("predict_party")
}

### default method: without FUN, do nothing except returning the fitted
### node ids (named by observation); with FUN, apply FUN to every
### terminal node and map the per-node results onto the ids
predict_party.default <- function(party, id, newdata = NULL, FUN = NULL, ...) {

    dots <- list(...)
    if (length(dots) > 1) 
        warning("argument(s)", " ", sQuote(names(dots)), " ", "have been ignored")

    ## observation names: from newdata if given, otherwise from the
    ## learning data (falling back to the node names)
    if (is.null(newdata)) {
        rnam <- rownames(data_party(party))
        nam <- if (is.null(rnam)) names(party)[id] else rnam
    } else {
        nam <- rownames(newdata)
    }
    ## names that do not line up with the ids are dropped
    if (length(nam) != length(id))
        nam <- NULL

    ## special case: fitted ids
    if (is.null(FUN))
        return(structure(id, .Names = nam))

    node_preds <- nodeapply(party, nodeids(party, terminal = TRUE), FUN,
                            by_node = TRUE)
    .simplify_pred(node_preds, id, nam)
}

## Predictions for a "constparty" given the terminal node ids `id`.
##
## type: "response" (weighted mean, majority class, or median survival
##       time, depending on the response), "prob" (class probabilities,
##       ecdf, or survfit object), "quantile", "density", or "node" (the
##       ids themselves).
## at:   evaluation points for function-valued predictions; its default
##       refers to `type` and is only forced after match.arg(), so it is
##       c(0.1, 0.5, 0.9) for type = "quantile" and NULL otherwise.
## FUN:  optional function(y, w) replacing the built-in node summaries.
## simplify: if TRUE, results are collapsed by .simplify_pred().
## A data.frame response is treated as multivariate: one prediction per
## response column (a data.frame if all results are atomic).
predict_party.constparty <- function(party, id, newdata = NULL,
    type = c("response", "prob", "quantile", "density", "node"),
    at = if (type == "quantile") c(0.1, 0.5, 0.9),
    FUN = NULL, simplify = TRUE, ...)
{
    ## extract fitted information
    response <- party$fitted[["(response)"]]
    weights <- party$fitted[["(weights)"]]
    fitted <- party$fitted[["(fitted)"]]
    ## missing weights mean unit case weights
    if (is.null(weights)) weights <- rep(1, NROW(response))

    ## get observation names: either node names or
    ## observation names from newdata
    nam <- if(is.null(newdata)) names(party)[id] else rownames(newdata)
    if(length(nam) != length(id)) nam <- NULL

    ## match type
    type <- match.arg(type)

    ## special case: fitted ids
    if(type == "node")
      return(structure(id, .Names = nam))

    ### multivariate response
    ## NOTE(review): with simplify = FALSE the per-node tables are returned
    ## without indexing by id, unlike the univariate case below -- confirm
    ## this is intended
    if (is.data.frame(response)) {
        ret <- lapply(response, function(r) {
            ret <- .predict_party_constparty(node_party(party), fitted = fitted, 
                response = r, weights, id = id, type = type, at = at, FUN = FUN, ...)
            if (simplify) .simplify_pred(ret, id, nam) else ret
        })
        if (all(sapply(ret, is.atomic)))
            ret <- as.data.frame(ret)
        names(ret) <- colnames(response)
        return(ret)
    }

    ### univariate response
    ret <- .predict_party_constparty(node_party(party), fitted = fitted, response = response, 
        weights = weights, id = id, type = type, at = at, FUN = FUN, ...)
    if (simplify) .simplify_pred(ret, id, nam) else ret[as.character(id)]
}

### functions for node prediction based on fitted / response

## Kaplan-Meier curve of a censored response y with case weights w
## (observations with zero weight are excluded); NA for an empty node.
.pred_Surv <- function(y, w) {
    if (length(y) > 0)
        survfit(y ~ 1, weights = w, subset = w > 0)
    else
        NA
}

## Median survival time of the weighted Kaplan-Meier curve; NA for an
## empty node.
.pred_Surv_response <- function(y, w) {
    if (length(y) > 0)
        .median_survival_time(.pred_Surv(y, w))
    else
        NA
}
 
## Weighted class proportions of a factor response y, one per level of y
## (levels without observations get 0).
.pred_factor <- function(y, w) {
    ## total case weight per level; empty levels come back as NA
    level_weight <- tapply(w, y, sum)
    level_weight[is.na(level_weight)] <- 0
    out <- level_weight / sum(w)
    names(out) <- levels(y)
    out
}

## Predicted class of a factor response: the level with the largest
## weighted proportion (ties go to the first such level, via which.max),
## returned as a length-one factor with the levels and orderedness of y.
.pred_factor_response <- function(y, w) {
    prob <- .pred_factor(y, w)
    factor(which.max(prob), levels = seq_len(nlevels(y)),
           labels = levels(y), ordered = is.ordered(y))
}
                    
## Weighted mean of a numeric response, ignoring missing values.
.pred_numeric_response <- function(y, w) {
    weighted.mean(y, w = w, na.rm = TRUE)
}

## Empirical distribution function of y under case weights w.
## Weights must be integer-valued up to a tolerance of sqrt(machine eps);
## each observation is replicated by its rounded weight. Returns NA for an
## empty node and stops for non-integer weights.
.pred_ecdf <- function(y, w) {
    if (length(y) == 0) return(NA)
    iw <- as.integer(round(w))
    if (max(abs(w - iw)) >= sqrt(.Machine$double.eps))
        stop("cannot compute empirical distribution function with non-integer weights")
    ## replicate by the rounded weights: rep() truncates non-integer
    ## 'times', so a weight of 2.9999999999 would otherwise count as 2
    ecdf(rep(y, iw))
}

## Returns a function(p, ...) computing quantiles of y, where each
## observation is replicated according to its case weight in w.
.pred_quantile <- function(y, w) {
    expanded <- rep(y, w)
    function(p, ...) quantile(expanded, probs = p, ...)
}

## Kernel density estimate of y, returned as a function obtained by linear
## interpolation of the density grid (constant beyond the grid, rule = 2).
.pred_density <- function(y, w) {
    ## only integer-valued weights occur here, and density() complains
    ## about weights since R 4.3.0 (bandwidth selection ignores them), so
    ## the observations are replicated instead
    kde <- density(rep(y, w))
    approxfun(kde$x, kde$y, rule = 2)
}

### workhorse: compute predictions based on fitted / response data
##
## node:     the partynode tree.
## fitted:   terminal node id of each learning observation.
## response: learning response; weights: case weights.
## id:       node ids to compute predictions for (may include inner nodes).
## type:     kind of prediction; selects the node summary FUN when FUN is
##           NULL (based on the first class of the response).
## at:       if the node summaries are functions, they are evaluated at `at`.
## Returns a list (dim dropped) of per-node summaries named by node id.
.predict_party_constparty <- function(node, fitted, response, weights,
    id = id, type = c("response", "prob", "quantile", "density"),
    at = if (type == "quantile") c(0.1, 0.5, 0.9), FUN = NULL, ...) {

    type <- match.arg(type)
    if (is.null(FUN)) {

        ## map the response class onto one of "Surv", "factor", "numeric"
        rtype <- class(response)[1]
        if (rtype == "ordered") rtype <- "factor"    
        if (rtype == "integer") rtype <- "numeric"
        if (rtype == "AsIs") rtype <- "numeric"

        if (type %in% c("quantile", "density") && rtype != "numeric")
            stop("quantile and density estimation currently only implemented for numeric responses")

        FUN <- switch(rtype,
            "Surv" = if (type == "response") .pred_Surv_response else .pred_Surv,
            "factor" = if (type == "response") .pred_factor_response else .pred_factor,
            "numeric" = switch(type,
                "response" = .pred_numeric_response,
                "prob" = .pred_ecdf,
                "quantile" = .pred_quantile, 
                "density" = .pred_density) 
       )
    }
      
    ## empirical distribution in each leaf
    if (all(id %in% fitted)) {
        ## all requested ids are terminal nodes with observations:
        ## summarize each terminal node in fitted
        tab <- tapply(1:NROW(response), fitted, 
                      function(i) FUN(response[i], weights[i]), simplify = FALSE)
    } else {
        ### id may also refer to inner nodes
        ## pool the observations of all terminal nodes below each id
        tab <- as.array(lapply(sort(unique(id)), function(i) {
            index <- fitted %in% nodeids(node, i, terminal = TRUE)
            ret <- FUN(response[index], weights[index])
            ### no information about i in fitted
            if (all(!index)) ret[1] <- NA
            return(ret)
        }))
        names(tab) <- as.character(sort(unique(id)))
    }
    ## function-valued summaries (e.g. ecdf, quantile function) are
    ## evaluated at `at` when given
    if (inherits(tab[[1]], "function") && !is.null(at))
        tab <- lapply(tab, function(f) f(at))
    ## drop the array structure but keep the node id names
    tn <- names(tab)
    dim(tab) <- NULL
    names(tab) <- tn

    tab
}


### simplify structure of predictions
##
## tab: per-node predictions named by node id; id: node id of each
## observation; nam: observation names (or NULL).
## Returns a vector/factor if every node gives one atomic value, a matrix
## (one row per observation) if every node gives a numeric vector of the
## same length, and a list indexed by observation otherwise.
.simplify_pred <- function(tab, id, nam) {

    key <- as.character(id)
    sizes <- sapply(tab, length)

    ## one atomic value per node
    if (all(sizes == 1) && all(sapply(tab, is.atomic))) {
        flat <- do.call("c", tab)
        names(flat) <- names(tab)
        first <- tab[[1]]
        ### R 4.1.x allows to call c() on factors, this is needed for
        ### backward-compatibility (older c() returns the integer codes)
        if (is.factor(first) && !is.factor(flat)) {
            flat <- factor(flat[key], levels = 1:length(levels(first)),
                           labels = levels(first), ordered = is.ordered(first))
        } else {
            flat <- flat[key]
        }
        names(flat) <- nam
        return(flat)
    }

    ## numeric vectors of common length: one row per observation
    if (length(unique(sizes)) == 1 && all(sapply(tab, is.numeric))) {
        mat <- matrix(unlist(tab), nrow = length(tab), byrow = TRUE)
        colnames(mat) <- names(tab[[1]])
        rownames(mat) <- names(tab)
        mat <- mat[key, , drop = FALSE]
        rownames(mat) <- nam
        return(mat)
    }

    ## anything else: list indexed by observation
    out <- tab[key]
    names(out) <- nam
    out
}

## S3 generic: data (and fitted information) of the observations in node id.
data_party <- function(party, id = 1L) {
    UseMethod("data_party")
}

## Data of the observations falling into node id (all terminal nodes in
## the subtree rooted at id).
##
## Returns a data.frame with the columns of party$data followed by those of
## party$fitted (only party$fitted when party$data has no rows). Without
## fitted node ids, returns NULL if there is no data either and stops
## otherwise. For length(id) > 1 a list with one such result per id.
data_party.default <- function(party, id = 1L) {

    extract <- function(id) {
        if(is.null(party$fitted)) {
            ## without fitted node ids rows cannot be assigned to nodes
            if(nrow(party$data) == 0) return(NULL)
            stop("cannot subset data without fitted ids")
        }

        ### which terminal nodes follow node number id?
        nt <- nodeids(party, id, terminal = TRUE)
        wi <- party$fitted[["(fitted)"]] %in% nt

        ## index directly instead of subset(): subset() evaluates its
        ## condition inside the data frame, so a data column named "wi"
        ## would silently replace the local selection vector
        if (nrow(party$data) == 0)
            party$fitted[wi, , drop = FALSE]
        else
            cbind(party$data, party$fitted)[wi, , drop = FALSE]
    }
    if (length(id) > 1)
        return(lapply(id, extract))
    else 
        return(extract(id))
}

## Width of a party: the width of its root partynode.
width.party <- function(x, ...) {
    root <- node_party(x)
    width(root, ...)
}

## Depth of a party: the depth of its root partynode.
depth.party <- function(x, root = FALSE, ...) {
    tree <- node_party(x)
    depth(tree, root = root, ...)
}

## The call stored in the info slot of a party (NULL if absent).
getCall.party <- function(x, ...) {
    info <- x$info
    info$call
}

## The call stored in the info slot of a constparties object (NULL if absent).
getCall.constparties <- function(x, ...) {
    info <- x$info
    info$call
}

## Formula of a party: replaces x by its terms (terms(x)) and hands over
## to the next formula() method.
## NOTE(review): relies on NextMethod() forwarding the modified x rather
## than the original party object -- confirm.
formula.party <- function(x, ...) {
  x <- terms(x)
  NextMethod()
}

## Model frame of a party (passed as `formula`, following the generic).
##
## If the party stores data with rows, that data is returned as is.
## Otherwise the original call (getCall(formula)) is reduced to its
## formula/data/subset/na.action arguments and re-evaluated as
## stats::model.frame(..., drop.unused.levels = TRUE); data, na.action and
## subset given in ... override those of the stored call. Evaluation takes
## place in the environment of terms(formula), or parent.frame() if NULL.
model.frame.party <- function(formula, ...)
{
  mf <- formula$data
  if(nrow(mf) > 0L) return(mf)

  dots <- list(...)
  ## overriding arguments from ... (note: this local masks base::nargs)
  nargs <- dots[match(c("data", "na.action", "subset"), names(dots), 0L)]
  mf <- getCall(formula)
  mf <- mf[c(1L, match(c("formula", "data", "subset", "na.action"), names(mf), 0L))]
  mf$drop.unused.levels <- TRUE
  mf[[1L]] <- quote(stats::model.frame)
  mf[names(nargs)] <- nargs
  if(is.null(env <- environment(terms(formula)))) env <- parent.frame()
  eval(mf, env)
}


## S3 generic: prune off the subtrees rooted at the nodes with the given ids.
nodeprune <- function(x, ids, ...) {
    UseMethod("nodeprune")
}



## Prune a partynode tree: each node in ids is replaced by a terminal node
## keeping its id and info. A NULL path (as returned by .get_path, which
## presumably means the root -- TODO confirm) drops elements 2:4 of the
## underlying list. Nodes already removed by an earlier prune in the same
## call are skipped. Returns the pruned tree renumbered from 1; an empty
## ids vector leaves the tree unchanged (apart from the renumbering).
nodeprune.partynode <- function(x, ids, ...) {

    stopifnot(ids %in% nodeids(x))

    ### compute indices path to each node
    ### to be pruned off
    idxs <- lapply(ids, .get_path, obj = x)

    ### [[.partynode is NOT [[.list
    cls <- class(x)
    x <- unclass(x)

    ## seq_along(): for empty ids, 1:length(idxs) would iterate over
    ## c(1, 0) and fail with "subscript out of bounds"
    for (i in seq_along(idxs)) {
        ## path to be pruned
        idx <- idxs[[i]]
        if(!is.null(idx)) {
            ### check if we already pruned-off this node
            tmp <- try(x[[idx]], silent = TRUE)
            if(inherits(tmp, "try-error")) next()
            ### prune node by introducing a "new" terminal node
            x[[idx]] <- partynode(id = id_node(tmp), info = info_node(tmp))
        } else {
            ## if idx path is NULL prune everything
            x[2L:4L] <- NULL
        }
    }

    class(x) <- cls
    return(as.partynode(x, from = 1L))
}

## Fallback for objects without a nodeprune() method: always an error.
nodeprune.default <- function(x, ids, ...)
    stop("No ", sQuote("nodeprune"), " method for class ",
         paste(class(x), collapse = ", "), " implemented")

## Decision rules of a party: for node i (default: all terminal nodes) a
## single string combining, with " & ", the split conditions on the path
## from the root to i, written as R expressions (e.g. 'x > 3' or
## 'f %in% c("a", "b")'). For several nodes a character vector named by
## node name. The root itself yields "".
.list.rules.party <- function(x, i = NULL, ...) {
    if (is.null(i)) i <- nodeids(x, terminal = TRUE)
    if (length(i) > 1) {
        ret <- sapply(i, .list.rules.party, x = x)
        names(ret) <- if (is.character(i)) i else names(x)[i]
        return(ret)
    }
    ## a node name is translated into its position in names(x)
    if (is.character(i) && !is.null(names(x)))
        i <- which(names(x) %in% i)
    stopifnot(length(i) == 1 & is.numeric(i))
    stopifnot(i <= length(x) & i >= 1)
    i <- as.integer(i)
    ## data of node i without the fitted columns; only the column names and
    ## types (factor levels) are used below ('fit' is not used afterwards)
    dat <- data_party(x, i)  
    if (!is.null(x$fitted)) {
        findx <- which("(fitted)" == names(dat))[1]  
        fit <- dat[,findx:ncol(dat), drop = FALSE]   
        dat <- dat[,-(findx:ncol(dat)), drop = FALSE]
        if (ncol(dat) == 0)
            dat <- x$data
    } else {
        fit <- NULL  
        dat <- x$data
    }

    rule <- c()

    ## walk from the root towards node i, appending the condition that
    ## leads into the chosen kid to 'rule' at each step
    recFun <- function(node) {
        if (id_node(node) == i) return(NULL)   
        kid <- sapply(kids_node(node), id_node)
        ## last kid whose id does not exceed i (assumes depth-first ids)
        whichkid <- max(which(kid <= i))
        split <- split_node(node)
        ivar <- varid_split(split)
        svar <- names(dat)[ivar]
        index <- index_split(split)
        if (is.factor(dat[, svar])) {
            ## without an explicit index, levels at positions above the
            ## breakpoint go to kid 2, the others to kid 1
            if (is.null(index)) 
                index <- ((1:nlevels(dat[, svar])) > breaks_split(split)) + 1
            slevels <- levels(dat[, svar])[index == whichkid]
            srule <- paste(svar, " %in% c(\"", 
                paste(slevels, collapse = "\", \"", sep = ""), "\")",
                sep = "")
        } else {
            ## interval of the chosen kid; right_split() decides which
            ## side of the interval is closed
            ## NOTE(review): if several intervals map to the same kid,
            ## sbreak becomes a matrix and only its first entries are
            ## used -- confirm
            if (is.null(index)) index <- 1:length(kid)
            breaks <- cbind(c(-Inf, breaks_split(split)), 
                            c(breaks_split(split), Inf))
            sbreak <- breaks[index == whichkid,]
            right <- right_split(split)
            srule <- c()
            if (is.finite(sbreak[1]))
                srule <- c(srule, 
                    paste(svar, ifelse(right, ">", ">="), sbreak[1]))
            if (is.finite(sbreak[2]))
                srule <- c(srule, 
                    paste(svar, ifelse(right, "<=", "<"), sbreak[2]))
            srule <- paste(srule, collapse = " & ")
        }
        rule <<- c(rule, srule)
        return(recFun(node[[whichkid]]))
    }
    ## called for its side effect on 'rule'; the return value is unused
    node <- recFun(node_party(x))
    paste(rule, collapse = " & ")
}

## Try the partykit package in your browser
##
## Any scripts or data that you put into this service are public.
##
## partykit documentation built on April 14, 2023, 5:09 p.m.