# R/plot.gensvm.R

# Defines functions gensvm.plot.markers gensvm.fill.colors gensvm.plot.colors gensvm.plot.1d gensvm.plot.2d plot.gensvm

# Documented in plot.gensvm

#' @title Plot the simplex space of the fitted GenSVM model
#' 
#' @description This function creates a plot of the simplex space for a fitted 
#' GenSVM model and the given data set. This function works for datasets with 
#' two or three classes. For more than three classes, the simplex space is too 
#' high-dimensional to visualize easily.
#'
#' @param x A fitted \code{gensvm} object
#' @param labels the labels to color points with. If this is omitted the 
#' predicted labels are used.
#' @param newdata the dataset to plot. If this is NULL the training data is 
#' used.
#' @param with.margins plot the margins
#' @param with.shading show shaded areas for the class regions
#' @param with.legend show the legend for the class labels
#' @param center.plot ensure that the boundaries and margins are always visible 
#' in the plot
#' @param xlim allows the user to force certain plot limits. If set, these 
#' bounds will be used for the horizontal axis.
#' @param ylim allows the user to force certain plot limits. If set, these 
#' bounds will be used for the vertical axis and the value of center.plot will 
#' be ignored
#' @param ... further arguments are passed to the builtin plot() function
#'
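#' @return the fitted model \code{x}, invisibly. If the model does not have 
#' two or three classes, or if the number of columns of \code{newdata} differs 
#' from the number of features of the model, a message is printed and 
#' \code{NULL} is returned invisibly.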
#'
#' @author
#' Gerrit J.J. van den Burg, Patrick J.F. Groenen \cr
#' Maintainer: Gerrit J.J. van den Burg <gertjanvandenburg@gmail.com>
#'
#' @references
#' Van den Burg, G.J.J. and Groenen, P.J.F. (2016). \emph{GenSVM: A Generalized 
#' Multiclass Support Vector Machine}, Journal of Machine Learning Research, 
#' 17(225):1--42. URL \url{https://jmlr.org/papers/v17/14-526.html}.
#'
#' @seealso
#' \code{\link{plot.gensvm.grid}}, \code{\link{predict.gensvm}}, 
#' \code{\link{gensvm}}, \code{\link{gensvm-package}}
#'
#' @method plot gensvm
#'
#' @export
#'
#' @importFrom grDevices rgb
#' @importFrom graphics legend par plot polygon segments
#'
#' @examples
#' x <- iris[, -5]
#' y <- iris[, 5]
#'
#' # train the model
#' fit <- gensvm(x, y)
#'
#' # plot the simplex space
#' plot(fit)
#'
#' # plot and use the true colors (easier to spot misclassified samples)
#' plot(fit, y)
#'
#' # plot only misclassified samples
#' x.mis <- x[predict(fit) != y, ]
#' y.mis.true <- y[predict(fit) != y]
#' plot(fit, newdata=x.mis)
#' plot(fit, y.mis.true, newdata=x.mis)
#'
#' # plot a 2-d model
#' xx <- x[y %in% c('versicolor', 'virginica'), ]
#' yy <- y[y %in% c('versicolor', 'virginica')]
#' fit <- gensvm(xx, yy, kernel='rbf', max.iter=1000)
#' plot(fit)
#'
plot.gensvm <- function(x, labels, newdata=NULL, with.margins=TRUE, 
                        with.shading=TRUE, with.legend=TRUE, center.plot=TRUE, 
                        xlim=NULL, ylim=NULL, ...)
{
    fit <- x

    # Only two- and three-class models can be drawn on a flat plot.
    if (!(fit$n.classes %in% c(2,3))) {
        cat("Error: Can only plot with 2 or 3 classes\n")
        return(invisible(NULL))
    }

    # The training data is not stored on the fit; it is recovered by 
    # re-evaluating the `x` argument of the fitting call in the caller's 
    # frame. Do this only when it is actually needed, and at most once, so 
    # that a linear model can still be plotted on new data after the training 
    # object has gone out of scope.
    x.train <- NULL
    if (is.null(newdata)) {
        x.train <- eval.parent(fit$call$x)
        newdata <- x.train
    }

    # Sanity check
    if (ncol(newdata) != fit$n.features) {
        cat("Error: Number of features of fitted model and testing data",
            "disagree.\n")
        return(invisible(NULL))
    }

    if (fit$kernel == 'linear') {
        # Linear kernel: project the data directly, with a leading column of 
        # ones multiplying the first row of the coefficient matrix.
        V <- coef(fit)
        Z <- cbind(matrix(1, dim(newdata)[1], 1), as.matrix(newdata))
        S <- Z %*% V
        y.pred <- predict(fit, newdata)
        y.pred.int <- match(y.pred, fit$classes)
    } else {
        # Nonlinear kernels: the projection is computed by the compiled 
        # routine, which needs the training data as well.
        if (is.null(x.train))
            x.train <- eval.parent(fit$call$x)

        # the compiled routine receives the kernel as a zero-based index
        kernels <- c("linear", "poly", "rbf", "sigmoid")
        kernel.idx <- match(fit$kernel, kernels) - 1
        plotdata <- .Call("R_gensvm_plotdata_kernels",
                          as.matrix(newdata),
                          as.matrix(x.train),
                          as.matrix(fit$V),
                          as.integer(nrow(fit$V)),
                          as.integer(ncol(fit$V)),
                          as.integer(nrow(x.train)),
                          as.integer(nrow(newdata)),
                          as.integer(fit$n.features),
                          as.integer(fit$n.classes),
                          as.integer(kernel.idx),
                          fit$gamma,
                          fit$coef,
                          fit$degree,
                          fit$kernel.eigen.cutoff
                          )
        S <- plotdata$ZV
        y.pred.int <- plotdata$y.pred
    }

    colors <- gensvm.plot.colors(fit$n.classes)
    markers <- gensvm.plot.markers(fit$n.classes)

    # Color and mark by the supplied labels if given, else by the predictions
    indices <- if (missing(labels)) y.pred.int else match(labels, fit$classes)

    col.vector <- colors[indices]
    mark.vector <- markers[indices]

    # With two classes there is a single coordinate; pad with a zero column so 
    # the points are drawn along the horizontal axis.
    if (fit$n.classes == 2)
        S <- cbind(S, matrix(0, nrow(S), 1))

    # Use a square plotting region for this plot only: restore the caller's 
    # setting on exit instead of leaking it into later plots.
    old.par <- par(pty="s")
    on.exit(par(old.par), add=TRUE)

    # When centering, widen the limits the user did not set so that the 
    # boundaries and margins are always inside the plot.
    if (center.plot) {
        if (is.null(xlim))
            xlim <- c(min(S[, 1], -1.2), max(S[, 1], 1.2))
        if (is.null(ylim))
            ylim <- c(min(S[, 2], -0.75), max(S[, 2], 1.2))
    }
    plot(S[, 1], S[, 2], col=col.vector, pch=mark.vector, ylab='', xlab='', 
         asp=1, xlim=xlim, ylim=ylim, ...)

    if (fit$n.classes == 3)
        gensvm.plot.2d(fit$classes, with.margins, with.shading, with.legend, 
                       center.plot)
    else
        gensvm.plot.1d(fit$classes, with.margins, with.shading, with.legend, 
                       center.plot)

    invisible(fit)
}


# Overlay the decision boundaries of a three-class model on the current plot.
#
# A plot must already exist: every line and polygon is extended to the edges 
# of the plotting region as reported by par("usr").
#
# classes       labels used as the legend text
# with.margins  draw the dashed margin lines
# with.shading  fill the three regions that lie beyond the margins
# with.legend   draw a legend just to the right of the plotting region
# center.plot   not referenced in the body; accepted so that the signature 
#               matches the call made by plot.gensvm
#
# Called for its drawing side effects.
gensvm.plot.2d <- function(classes, with.margins, with.shading, 
                           with.legend, center.plot)
{
    # par("usr") is c(x1, x2, y1, y2): the extremes of the plotting region
    limits <- par("usr")
    xmin <- limits[1]
    xmax <- limits[2]
    ymin <- limits[3]
    ymax <- limits[4]

    # The two slanted boundaries have slope +/- 1/sqrt(3); the margin lines 
    # parallel to them are shifted vertically by sqrt(4/3).
    root3 <- sqrt(3)
    shift <- sqrt(4/3)

    # draw the fixed boundaries
    segments(0, 0, 0, ymin)
    segments(0, 0, xmax, xmax/root3)
    segments(xmin, abs(xmin)/root3, 0, 0)

    if (with.margins) {
        # margin from left below decision boundary to center
        segments(xmin, -xmin/root3 - shift, -1, -1/root3, lty=2)

        # margin from left center to down
        segments(-1, -1/root3, -1, ymin, lty=2)

        # margin from right center to down
        segments(1, -1/root3, 1, ymin, lty=2)

        # margin from right center to right boundary
        segments(1, -1/root3, xmax, xmax/root3 - shift, lty=2)

        # margin from center to top left
        segments(xmin, -xmin/root3 + shift, 0, shift, lty=2)

        # margin from center to top right
        segments(0, shift, xmax, xmax/root3 + shift, lty=2)
    }

    if (with.shading) {
        fill <- gensvm.fill.colors()
        # bottom left
        polygon(c(xmin, -1, -1, xmin), 
                c(ymin, ymin, -1/root3, -xmin/root3 - shift), 
                col=fill$green, border=NA)
        # bottom right
        polygon(c(1, xmax, xmax, 1), 
                c(ymin, ymin, xmax/root3 - shift, -1/root3), 
                col=fill$blue, border=NA)
        # top
        polygon(c(xmin, 0, xmax, xmax, xmin), 
                c(-xmin/root3 + shift, shift, xmax/root3 + shift, ymax, ymax), 
                col=fill$orange, border=NA)
    }

    if (with.legend) {
        # place the legend 5% of the plot width to the right of the region; 
        # xpd=TRUE allows drawing outside the plotting region
        offset <- abs(xmax - xmin) * 0.05
        colors <- gensvm.plot.colors()
        markers <- gensvm.plot.markers()
        legend(xmax + offset, ymax, classes, col=colors, pch=markers, 
               xpd=TRUE)
    }
}

# Overlay the decision boundary of a two-class model on the current plot.
#
# A plot must already exist: every line and polygon is extended to the edges 
# of the plotting region as reported by par("usr").
#
# classes       labels used as the legend text
# with.margins  draw the dashed margin lines at x = -1 and x = 1
# with.shading  fill the regions left of x = -1 and right of x = 1
# with.legend   draw a legend just to the right of the plotting region
# center.plot   not referenced in the body; accepted so that the signature 
#               matches the call made by plot.gensvm
#
# Called for its drawing side effects.
gensvm.plot.1d <- function(classes, with.margins, with.shading, with.legend, 
                           center.plot)
{
    # par("usr") is c(x1, x2, y1, y2): the extremes of the plotting region
    limits <- par("usr")
    xmin <- limits[1]
    xmax <- limits[2]
    ymin <- limits[3]
    ymax <- limits[4]

    # draw the fixed boundary: a vertical line through the origin
    segments(0, ymin, 0, ymax)

    if (with.margins) {
        segments(-1, ymin, -1, ymax, lty=2)
        segments(1, ymin, 1, ymax, lty=2)
    }
    if (with.shading) {
        fill <- gensvm.fill.colors()
        # left of the left margin
        polygon(c(xmin, -1, -1, xmin), c(ymin, ymin, ymax, ymax), 
                col=fill$blue, border=NA)
        # right of the right margin
        polygon(c(1, xmax, xmax, 1), c(ymin, ymin, ymax, ymax), 
                col=fill$orange, border=NA)
    }
    if (with.legend) {
        # place the legend 5% of the plot width to the right of the region; 
        # xpd=TRUE allows drawing outside the plotting region
        offset <- abs(xmax - xmin) * 0.05
        colors <- gensvm.plot.colors(2)
        markers <- gensvm.plot.markers(2)
        legend(xmax + offset, ymax, classes, col=colors, pch=markers, 
               xpd=TRUE)
    }
}

# Point colors for the classes, returned as a one-column character matrix of 
# hex color strings.
#
# For K == 3 the order is green, blue, orange; for any other K only blue and 
# orange are returned, in that order.
gensvm.plot.colors <- function(K=3)
{
    as.hex <- function(r, g, b) rgb(r, g, b, maxColorValue=255)

    blue <- as.hex(31, 119, 180)
    orange <- as.hex(255, 127, 14)
    green <- as.hex(44, 160, 44)

    chosen <- if (K == 3) c(green, blue, orange) else c(blue, orange)
    as.matrix(chosen)
}

# Translucent fill colors for the shaded class regions.
#
# Returns a list with elements `blue`, `orange` and `green`, each a hex color 
# string with an alpha channel of 51 out of 255.
gensvm.fill.colors <- function()
{
    translucent <- function(r, g, b) rgb(r, g, b, 51, maxColorValue=255)

    list(blue=translucent(31, 119, 180),
         orange=translucent(255, 127, 14),
         green=translucent(44, 160, 44))
}

# Plotting symbols (pch codes) for the classes: the first K of 15, 16, 17.
#
# seq_len() is used rather than 1:K so that K = 0 yields an empty vector 
# instead of indexing with c(1, 0).
gensvm.plot.markers <- function(K=3)
{
    markers <- c(15, 16, 17)
    markers[seq_len(K)]
}

# Try the gensvm package in your browser

# Any scripts or data that you put into this service are public.

# gensvm documentation built on Feb. 16, 2023, 5:58 p.m.