R/check.derivatives.R

Defines functions check.derivatives

Documented in check.derivatives

# Copyright (C) 2011 Jelmer Ypma. All Rights Reserved.
# This code is published under the L-GPL.
#
# File:   check.derivatives.R
# Author: Jelmer Ypma
# Date:   24 July 2011
#
# Compare analytic derivatives with finite difference approximations.
#
# Input:
#        .x : compare at this point
#        func : calculate finite difference approximation for the gradient of this function
#        func_grad : function to calculate analytic gradients
#        check_derivatives_tol : show deviations larger than this value (optional)
#        check_derivatives_print : amount of output: 'all', 'errors' or 'none' (optional)
#        func_grad_name : name of function to show in output (optional)
#       ... : arguments that are passed to the user-defined function (func and func_grad)
#
# Output: list with analytic gradients, finite difference approximations, relative errors
#         and a comparison of the relative errors to the tolerance.
#
# CHANGELOG:
#   27/10/2013: Added relative_error and flag_derivative_warning to output list.
#   05/05/2014: Replaced cat by message, so messages can now be suppressed by suppressMessages.



#' Check analytic gradients of a function using finite difference
#' approximations
#'
#' This function compares the analytic gradients of a function with a finite
#' difference approximation and prints the results of these checks.
#'
#' The relative error of each element is
#' \code{abs((analytic - finite_difference) / finite_difference)}. Where the
#' finite difference approximation is exactly zero this ratio is undefined, so
#' the absolute value of the analytic gradient is used instead. An element is
#' flagged when its relative error exceeds \code{check_derivatives_tol}.
#'
#' @param .x point at which the comparison is done.
#' @param func function to be evaluated.
#' @param func_grad function calculating the analytic gradients.
#' @param check_derivatives_tol option determining when differences between the
#' analytic gradient and its finite difference approximation are flagged as an
#' error.
#' @param check_derivatives_print option related to the amount of output. 'all'
#' means that all comparisons are shown, 'errors' only shows comparisons that
#' are flagged as an error, and 'none' shows the number of errors only. Any
#' other value triggers a warning and is treated as 'none'.
#' @param func_grad_name option to change the name of the gradient function
#' that shows up in the output.
#' @param ...  further arguments passed to the functions func and func_grad.
#'
#' @return A list with the elements \code{analytic} (the analytic gradient),
#' \code{finite_difference} (its finite difference approximation),
#' \code{relative_error} (the non-negative relative errors), and
#' \code{flag_derivative_warning} (logical, \code{TRUE} where the relative
#' error exceeds the tolerance).
#'
#' @export
#'
#' @author Jelmer Ypma
#'
#' @seealso \code{\link[nloptr:nloptr]{nloptr}}
#'
#' @keywords optimize interface
#'
#' @examples
#'
#' library('nloptr')
#'
#' # example with correct gradient
#' f <- function( x, a ) {
#'     return( sum( ( x - a )^2 ) )
#' }
#'
#' f_grad <- function( x, a ) {
#'     return( 2*( x - a ) )
#' }
#'
#' check.derivatives( .x=1:10, func=f, func_grad=f_grad,
#'     check_derivatives_print='none', a=runif(10) )
#'
#' # example with incorrect gradient
#' f_grad <- function( x, a ) {
#'     return( 2*( x - a ) + c(0,.1,rep(0,8)) )
#' }
#'
#' check.derivatives( .x=1:10, func=f, func_grad=f_grad,
#'     check_derivatives_print='errors', a=runif(10) )
#'
#' # example with incorrect gradient of vector-valued function
#' g <- function( x, a ) {
#'     return( c( sum(x-a), sum( (x-a)^2 ) ) )
#' }
#'
#' g_grad <- function( x, a ) {
#'     return( rbind( rep(1,length(x)) + c(0,.01,rep(0,8)), 2*(x-a) + c(0,.1,rep(0,8)) ) )
#' }
#'
#' check.derivatives( .x=1:10, func=g, func_grad=g_grad,
#'     check_derivatives_print='all', a=runif(10) )
#'
check.derivatives <-
    function(
        .x,
        func,
        func_grad,
        check_derivatives_tol = 1e-04,
        check_derivatives_print = 'all',
        func_grad_name = 'grad_f',
        ...
    )
{

    analytic_grad <- func_grad( .x, ... )

    finite_diff_grad <- finite.diff( func, .x, ... )

    # Relative error is undefined where the finite difference is exactly zero;
    # fall back to the magnitude of the analytic gradient there. The abs() is
    # essential: without it a wrong *negative* analytic derivative yields a
    # negative "error" that can never exceed the tolerance and goes unflagged.
    relative_error <- ifelse(
                        finite_diff_grad == 0,
                        abs( analytic_grad ),
                        abs( ( analytic_grad - finite_diff_grad ) / finite_diff_grad )
                      )

    flag_derivative_warning <- relative_error > check_derivatives_tol
    num_errors <- sum( flag_derivative_warning )

    if ( ! ( check_derivatives_print %in% c('all','errors','none') ) ) {
        warning( paste0( "Value '", check_derivatives_print, "' for check_derivatives_print is unknown; use 'all' (default), 'errors', or 'none'." ) )
        check_derivatives_print <- 'none'
    }

    # Determine index labels of the vector / matrix elements for printing.
    # format() pads all labels to a common width, so they align vertically.
    # Matrix labels are generated in column-major order (row index varies
    # fastest) to match the order in which paste0() walks the gradient matrix.
    if ( is.matrix( analytic_grad ) ) {
        num_rows <- nrow( analytic_grad )
        num_cols <- ncol( analytic_grad )
        indices <- paste(
                    format( rep( seq_len( num_rows ), times=num_cols ) ),
                    format( rep( seq_len( num_cols ), each=num_rows ) ),
                    sep=', '
                   )
    }
    else {
        indices <- format( seq_along( analytic_grad ) )
    }

    # Build the report, one line per compared element:
    #   <marker> <name>[ <index> ] = <analytic> ~ <finite diff>   [<rel. error>]
    format_comparison <- function( marker, index, analytic, finite_diff, rel_err ) {
        paste0(
            marker,
            " ", func_grad_name, "[ ", index, " ] = ",
            format( analytic, scientific=TRUE ),
            " ~ ",
            format( finite_diff, scientific=TRUE ),
            "   [",
            format( rel_err, scientific=TRUE ),
            "]", collapse="\n"
        )
    }

    # Print results (via message, so they can be silenced with suppressMessages).
    message( "Derivative checker results: ", num_errors, " error(s) detected." )
    if ( check_derivatives_print == 'all' ) {
        message( "\n",
            format_comparison(
                ifelse( flag_derivative_warning, "*", " " ),
                indices,
                analytic_grad,
                finite_diff_grad,
                relative_error
            ),
            "\n\n"
        )
    }
    else if ( check_derivatives_print == 'errors' && num_errors > 0 ) {
        # Only flagged elements are shown, so every line carries the "*" marker.
        message( "\n",
            format_comparison(
                "*",
                indices[ flag_derivative_warning ],
                analytic_grad[ flag_derivative_warning ],
                finite_diff_grad[ flag_derivative_warning ],
                relative_error[ flag_derivative_warning ]
            ),
            "\n\n"
        )
    }
    # 'none': only the summary line above is printed.

    return(
        list(
            "analytic"                = analytic_grad,
            "finite_difference"       = finite_diff_grad,
            "relative_error"          = relative_error,
            "flag_derivative_warning" = flag_derivative_warning
        )
    )
}

Try the nloptr package in your browser

Any scripts or data that you put into this service are public.

nloptr documentation built on May 28, 2022, 1:17 a.m.