R/methods.r

Defines functions print.cv.best.models print.cv.result print.cv.models extract.metrics plot.cv.models extract.fit.cv.best.models extract.fit.cv.result extract.fit.cv.models extract.fit

#-----------------------------------------------------------------------------
#'	Extract predicted values
#'
#'	Extract predicted values from a \code{cv.models}, \code{cv.result} or
#'	\code{cv.best.models} object.
#'
#'	@param object
#'	an object of class \code{cv.models}, \code{cv.result} or
#'	\code{cv.best.models}.
#'
#'	@param index
#'	an integer specifying the index of the model to extract.
#'	If \code{object} is a \code{cv.models} object, this is the index among
#'	the models created by hyper-parameter tuning.
#'	If \code{object} is a \code{cv.best.models} object, this is the index
#'	among the models having the same performance measure.
#'	Ignored if \code{object} is a \code{cv.result} object.
#'
#'	@return a data.frame having the original response variable
#'	("response"), the predicted values ("prediction"), and the index in the
#'	original data ("index").
#'
#'	@export
#-----------------------------------------------------------------------------
extract.fit <- function(object, index = 1) {
	# S3 generic: dispatches on the class of `object`.
	UseMethod("extract.fit")
}

#-----------------------------------------------------------------------------
#'	@export
#'	@method extract.fit cv.models
#'	@describeIn extract.fit S3 method for \code{cv.models}.
#-----------------------------------------------------------------------------
extract.fit.cv.models <- function(object, index = 1) {
	# Pick the cross validation result of the requested model.
	selected <- object$cv.results[[index]]
	# Convert every element of `fits` to a data.frame, then stack them by row.
	tables <- lapply(selected$fits, function(fit) as.data.frame(fit))
	do.call(rbind, tables)
}

#-----------------------------------------------------------------------------
#'	@export
#'	@method extract.fit cv.result
#'	@describeIn extract.fit S3 method for \code{cv.result}.
#-----------------------------------------------------------------------------
extract.fit.cv.result <- function(object, index = 1) {
	# `index` is accepted only for signature compatibility with the generic;
	# a cv.result holds a single model, so it is not used here.
	# Each element of `fits` becomes a data.frame and all are stacked by row.
	do.call(rbind, lapply(object$fits, as.data.frame))
}

#-----------------------------------------------------------------------------
#'	@export
#'	@method extract.fit cv.best.models
#'	@describeIn extract.fit S3 method for \code{cv.best.models}.
#-----------------------------------------------------------------------------
extract.fit.cv.best.models <- function(object, index = 1) {
	# Take the requested element and let the generic dispatch on its class.
	chosen <- object[[index]]
	extract.fit(chosen)
}


#-----------------------------------------------------------------------------
#'	Plot predicted and actual values
#'
#'	Draw a scatter plot of predicted against actual values and add the
#'	identity line (y = x) to it.
#'
#'	@param x
#'	a \code{cv.models} or \code{cv.best.models} object.
#'
#'	@param index
#'	an integer specifying index of the model.
#'	For the detail, see \code{index} argument of \code{\link{extract.fit}}.
#'
#'	@param ...
#'	other arguments passed to plot.
#'
#'	@export
#-----------------------------------------------------------------------------
plot.cv.models <- function(x, index = 1, ...) {
	fits <- extract.fit(x, index)
	# Predictions on the horizontal axis, observed responses on the vertical.
	graphics::plot(
		fits$prediction, fits$response, xlab = "Prediction", ylab = "Response",
		...
	)
	# Identity line. Inside curve() the symbol `x` is curve's own plotting
	# variable, not the `x` argument of this function.
	graphics::curve(x * 1, add = TRUE)
}


#------------------------------------------------------------------------------
#'	Extract performance measures
#'
#'	Collect the \code{metrics} element of every entry in
#'	\code{object$cv.results} and combine them by row.
#'
#'	@param object
#'	a \code{cv.models} object.
#'	@return data.frame having all performance measures, with row names
#'	reset.
#'	@export
#------------------------------------------------------------------------------
extract.metrics <- function(object) {
	# One `metrics` entry per element of cv.results.
	per.model <- lapply(object$cv.results, function(res) res[["metrics"]])
	combined <- do.call(rbind, per.model)
	# Drop row names produced while binding.
	rownames(combined) <- NULL
	return(combined)
}


#------------------------------------------------------------------------------
#'	print method
#'	
#'	@param x
#'	a \code{cv.models}, \code{cv.result} or \code{cv.best.models} object.
#'	@param ... currently not used.
#'	
#'	@name print
#------------------------------------------------------------------------------
NULL

#------------------------------------------------------------------------------
#'	@method print cv.models
#'	@describeIn print S3 method for class 'cv.models'
#'	@export
#------------------------------------------------------------------------------
print.cv.models <- function(x, ...) {
	cat("Result of cross validation\n")
	cat(sprintf("Function name: %s\n", x$function.name))
	cat("Cross validation metrics:\n")
	# Metrics of every model in the object, combined into one table.
	print(extract.metrics(x))
	cat("\n")
	# print methods return their argument invisibly.
	invisible(x)
}

#------------------------------------------------------------------------------
#'	@method print cv.result
#'	@describeIn print S3 method for class 'cv.result'
#'	@export
#------------------------------------------------------------------------------
print.cv.result <- function(x, ...) {
	cat("Result of cross validation\n")
	cat(sprintf("Function name: %s\n", x$function.name))
	cat("Cross validation metrics:\n")
	print(x$metrics)
	cat("\n")
	# print methods return their argument invisibly.
	invisible(x)
}

#------------------------------------------------------------------------------
#'	@method print cv.best.models
#'	@describeIn print S3 method for class 'cv.best.models'
#'	@export
#------------------------------------------------------------------------------
print.cv.best.models <- function(x, ...) {
	hr <- "------------------------------------------------"
	n.models <- length(x)
	# Singular wording only for exactly one model; zero and many use plural.
	if (n.models == 1) {
		msg <- "-> 1 best model was found:"
	} else {
		msg <- sprintf("-> %s best models were found:", n.models)
	}
	msg <- paste(
		hr, "Result of model selection by cross validation", msg, hr,
		sep = "\n"
	)
	cat(msg)
	if (n.models == 0) {
		# No per-model section follows, so terminate the header line here.
		cat("\n")
	}
	# seq_along() yields an empty sequence for an empty object, whereas
	# 1:length(x) would iterate over c(1, 0) and fail on x[[1]].
	for (i in seq_along(x)) {
		cat(paste(sprintf("\nModel index %s", i), hr, sep = "\n"), "\n")
		print(x[[i]])
	}
	# print methods return their argument invisibly.
	invisible(x)
}
Marchen/cv.models documentation built on Sept. 2, 2020, 2:04 a.m.