R/fit_model.R

#' Fit VAST to data
#'
#' \code{fit_model} fits a spatio-temporal model to data
#'
#' This function is the user-interface for the multiple mid-level functions that
#' perform separate components of a spatio-temporal analysis:
#' \itemize{
#' \item determine the extrapolation-grid \code{\link{make_extrapolation_info}},
#' \item define spatial objects \code{\link{make_spatial_info}},
#' \item build covariates from a formula interface \code{\link{make_covariates}},
#' \item assemble data \code{\link[VAST]{make_data}},
#' \item build model \code{\link[VAST]{make_model}},
#' \item estimate parameters \code{\link[TMBhelper]{fit_tmb}}, and
#' \item check for obvious problems with the estimates \code{\link[VAST]{check_fit}}.
#' }
#' Please see the reference documentation for each of those functions (e.g., \code{?make_extrapolation_info}) for a list of arguments used by each mid-level function.
#'
#' Specifically, the mid-level functions called by \code{fit_model(.)} look for arguments in the following order of precedence (from highest to lowest precedence):
#' \enumerate{
#' \item \code{fit_model(.)} prioritizes using named arguments passed directly to \code{fit_model(.)}. If arguments are passed this way, they are used instead of other options below.
#' \item If an argument is not supplied directly to \code{fit_model(.)}, then \code{fit_model(.)} looks for elements in input \code{settings}, as typically created by \code{\link{make_settings}}.
#' \item If an argument is not supplied via (1) or (2) above, then each mid-level function uses default values defined in those function arguments, e.g., see \code{args(make_extrapolation_info)} for defaults for function \code{make_extrapolation_info(.)}
#' }
#' Collectively, this order of precedence allows users to specify inputs for a specific project via method (1), allows the package author to change defaults for a given purpose
#' through the settings created by \code{make_settings(.)} via method (2), while still falling back to package defaults via method (3).
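#'
#' As a hypothetical sketch (argument values and data vectors \code{Lat}, \code{Lon}, \code{Year}, \code{Catch}, and \code{Area} are purely illustrative), the three methods look like:
#' ```
#' # (2) Project defaults stored in `settings` by make_settings(.)
#' settings = make_settings( n_x = 50, Region = "eastern_bering_sea", purpose = "index" )
#' # (1) A named argument passed directly to fit_model(.) takes highest precedence;
#' #     here `n_x` overrides settings$n_x when building spatial information
#' fit = fit_model( settings = settings, n_x = 100,
#'                  Lat_i = Lat, Lon_i = Lon, t_i = Year, b_i = Catch, a_i = Area )
#' # (3) If neither supplies a value, each mid-level function applies its own default, e.g.:
#' args( make_extrapolation_info )
#' ```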
#'
#' Variables are indexed internally for locations \code{g}, categories \code{c}, and times \code{y}.
#' Location index \code{g} corresponds to the Longitude-Latitude pair \code{fit$extrapolation_list$Data_Extrap[which(fit$spatial_list$g_e==g),c('Lon','Lat')]};
#' time index \code{y} corresponds to values in \code{fit$year_labels}; and
#' category index \code{c} corresponds to values in \code{fit$category_names}.
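#'
#' For example (a hypothetical sketch, assuming indices \code{g}, \code{c}, and \code{y} are within range for a fitted object \code{fit}):
#' ```
#' D = fit$Report$D_gct[g, c, y]                   # density for one grid-cell / category / time
#' lonlat = fit$extrapolation_list$Data_Extrap[ which(fit$spatial_list$g_e==g), c('Lon','Lat') ]
#' year = fit$year_labels[y]                       # label for time index y
#' category = fit$category_names[c]                # label for category index c
#' ```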
#'
#' @inheritParams make_extrapolation_info
#' @inheritParams make_spatial_info
#' @inheritParams make_covariates
#' @inheritParams VAST::make_data
#' @inheritParams VAST::make_model
#' @inheritParams TMBhelper::fit_tmb
#' @param settings Output from \code{\link{make_settings}}
#' @param run_model Boolean indicating whether to run the model or simply return the inputs and built TMB object
#' @param test_fit Boolean indicating whether to apply \code{\link[VAST]{check_fit}} before calculating standard errors, to test for parameters hitting bounds etc; defaults to TRUE
#' @param category_names character vector specifying names for labeling categories \code{c_i}
#' @param year_labels character vector specifying names for labeling times \code{t_i}
#' @param ... additional arguments to pass to \code{\link{make_extrapolation_info}}, \code{\link{make_spatial_info}}, \code{\link[VAST]{make_data}}, \code{\link[VAST]{make_model}}, or \code{\link[TMBhelper]{fit_tmb}},
#' where arguments are matched by name against each function.  If an argument doesn't match, it is still passed to \code{\link[VAST]{make_data}}.  Note that \code{\link{make_spatial_info}}
#' passes named arguments to \code{\link[fmesher]{fm_mesh_2d}}.
#'
#' @return Object of class \code{fit_model}, containing formatted inputs and outputs from VAST
#' \describe{
#'   \item{\code{parameter_estimates}}{Output from \code{\link[TMBhelper]{fit_tmb}}; see that documentation for definition of contents}
#'   \item{\code{extrapolation_list}}{Output from \code{\link{make_extrapolation_info}}; see that documentation for definition of contents}
#'   \item{\code{spatial_list}}{Output from \code{\link{make_spatial_info}}; see that documentation for definition of contents}
#'   \item{\code{data_list}}{Output from \code{\link[VAST]{make_data}}; see that documentation for definition of contents}
#'   \item{\code{tmb_list}}{Output from \code{\link[VAST]{make_model}}; see that documentation for definition of contents}
#'   \item{\code{ParHat}}{Tagged list of maximum likelihood estimates of fixed effects and empirical Bayes estimates of random effects, following the format of initial values generated by \code{\link[VAST]{make_parameters}}; see that documentation for definition of contents}
#'   \item{\code{Report}}{Tagged list of VAST outputs. For example, estimated density for grid \code{g}, category \code{c}, and time \code{y} is available as \code{fit$Report$D_gct[g,c,y]}; see Details section for description of indexing}
#' }
#'
#' @family wrapper functions
#' @seealso \code{\link[VAST]{VAST}} for general documentation, \code{\link[FishStatsUtils]{make_settings}} for generic settings, \code{\link[FishStatsUtils]{fit_model}} for model fitting, and \code{\link[FishStatsUtils]{plot_results}} for generic plots
#' @seealso VAST wiki \url{https://github.com/James-Thorson-NOAA/VAST/wiki} for examples documenting many different use-cases and features.
#' @seealso GitHub mainpage \url{https://github.com/James-Thorson-NOAA/VAST#description} for a list of user resources and publications documenting features
#' @seealso \code{\link{summary.fit_model}} for methods to summarize output, including obtain a dataframe of estimated densities and an explanation of DHARMa Probability-Integral-Transform residuals
#' @seealso \code{\link{predict.fit_model}} for methods to predict at new locations using existing or updated covariate values
#'
#' @examples
#' \dontrun{
#' # Load packages
#' library(VAST)
#'
#' # load data set
#' # see `?load_example` for list of stocks with example data
#' # that are installed automatically with `FishStatsUtils`.
#' example = load_example( data_set="EBS_pollock" )
#'
#' # Make settings
#' settings = make_settings( n_x=50,
#'          Region=example$Region,
#'          purpose="index",
#'          strata.limits=example$strata.limits )
#'
#' # Run model
#' fit = fit_model( "settings"=settings,
#'     "Lat_i"=example$sampling_data[,'Lat'],
#'     "Lon_i"=example$sampling_data[,'Lon'],
#'     "t_i"=example$sampling_data[,'Year'],
#'     "c_i"=rep(0,nrow(example$sampling_data)),
#'     "b_i"=example$sampling_data[,'Catch_KG'],
#'     "a_i"=example$sampling_data[,'AreaSwept_km2'],
#'     "v_i"=example$sampling_data[,'Vessel'] )
#'
#' # Plot results
#' plot_results( settings=settings, fit=fit )
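#'
#' # Hypothetical sketch: extra arguments matched by name against
#' # fmesher::fm_mesh_2d (e.g., `cutoff`; value purely illustrative)
#' # are forwarded via `...` when building the spatial mesh
#' fit_coarse = fit_model( "settings" = settings,
#'     "Lat_i" = example$sampling_data[,'Lat'],
#'     "Lon_i" = example$sampling_data[,'Lon'],
#'     "t_i" = example$sampling_data[,'Year'],
#'     "b_i" = example$sampling_data[,'Catch_KG'],
#'     "a_i" = example$sampling_data[,'AreaSwept_km2'],
#'     "cutoff" = 50 )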
#' }
#'
#' @export
#' @md
# Using https://cran.r-project.org/web/packages/roxygen2/vignettes/rd-formatting.html for guidance on markdown-enabled documentation
fit_model <-
function( settings,
          Lat_i,
          Lon_i,
          t_i,
          b_i,
          a_i,
          c_iz = rep(0,length(b_i)),
          v_i = rep(0,length(b_i)),
          working_dir = tempdir(),
          X1config_cp = NULL,
          X2config_cp = NULL,
          covariate_data,
          X1_formula = ~ 0,
          X2_formula = ~ 0,
          Q1config_k = NULL,
          Q2config_k = NULL,
          catchability_data,
          Q1_formula = ~ 0,
          Q2_formula = ~ 0,
          newtonsteps = 1,
          silent = TRUE,
          build_model = TRUE,
          run_model = TRUE,
          test_fit = TRUE,
          category_names = NULL,
          year_labels = NULL,
          framework = "TMBad",
          use_new_epsilon = TRUE,
          ... ){

  # Capture extra arguments to function
  extra_args = list(...)
  # Backwards-compatible way to capture previous format to input extra arguments for each function via specific input-lists
  extra_args = c( extra_args,
             extra_args$extrapolation_args,
             extra_args$spatial_args,
             extra_args$optimize_args,
             extra_args$model_args )
  start_time = Sys.time()

  # Assemble inputs
  data_frame = data.frame( "Lat_i"=Lat_i, "Lon_i"=Lon_i, "a_i"=a_i, "v_i"=v_i, "b_i"=b_i, "t_i"=t_i, "c_iz"=c_iz )

  # Decide which years to plot
  if(is.null(year_labels)) year_labels = paste0( seq(min(t_i),max(t_i)) )
  if(is.null(category_names)) category_names = paste0( 1:(max(c_iz,na.rm=TRUE)+1) )
  years_to_plot = which( seq(min(t_i),max(t_i)) %in% t_i )

  # Save record
  message("\n### Writing output from `fit_model` in directory: ", working_dir)
  dir.create(working_dir, showWarnings=FALSE, recursive=TRUE)
  #save( settings, file=file.path(working_dir,"Record.RData"))
  capture.output( settings, file=file.path(working_dir,"settings.txt"))

  # Build extrapolation grid
  if( is.null(extra_args$extrapolation_list) ){
    message("\n### Making extrapolation-grid")
    extrapolation_args_default = list(Region = settings$Region,
                               strata.limits = settings$strata.limits,
                               zone = settings$zone,
                               max_cells = settings$max_cells,
                               DirPath = working_dir)
    extrapolation_args_input = combine_lists( input = extra_args,
                             default = extrapolation_args_default,
                             args_to_use = formalArgs(make_extrapolation_info) )
    extrapolation_list = do.call( what=make_extrapolation_info, args=extrapolation_args_input )
  }else{
    extrapolation_args_input = NULL
    extrapolation_list = extra_args$extrapolation_list
  }

  # Build information regarding spatial location and correlation
  if( is.null(extra_args$spatial_list) ){
    message("\n### Making spatial information")
    spatial_args_default = list( grid_size_km = settings$grid_size_km,
                         n_x = settings$n_x,
                         Method = settings$Method,
                         Lon_i = Lon_i,
                         Lat_i = Lat_i,
                         Extrapolation_List = extrapolation_list,
                         DirPath = working_dir,
                         Save_Results = TRUE,
                         fine_scale = settings$fine_scale,
                         knot_method = settings$knot_method,
                         mesh_package = settings$mesh_package )
    spatial_args_input = combine_lists( input=extra_args, default=spatial_args_default, args_to_use=c(formalArgs(make_spatial_info),formalArgs(fmesher::fm_mesh_2d)) )
    spatial_list = do.call( what=make_spatial_info, args=spatial_args_input )
  }else{
    spatial_args_input = NULL
    spatial_list = extra_args$spatial_list
  }

  # Build data
  # Do *not* restrict inputs to formalArgs(make_data) because other potential inputs are still parsed by make_data for backwards compatibility
  if( is.null(extra_args$data_list) ){
    message("\n### Making data object") # VAST::
    if(missing(covariate_data)) covariate_data = NULL
    if(missing(catchability_data)) catchability_data = NULL
    data_args_default = list( "Version" = settings$Version,
                      "FieldConfig" = settings$FieldConfig,
                      "OverdispersionConfig" = settings$OverdispersionConfig,
                      "RhoConfig" = settings$RhoConfig,
                      "VamConfig" = settings$VamConfig,
                      "ObsModel" = settings$ObsModel,
                      "c_iz" = c_iz,
                      "b_i" = b_i,
                      "a_i" = a_i,
                      "v_i" = v_i,
                      "s_i" = spatial_list$knot_i-1,
                      "t_i" = t_i,
                      "spatial_list" = spatial_list,
                      "Options" = settings$Options,
                      "Aniso" = settings$use_anisotropy,
                      "X1config_cp" = X1config_cp,
                      "X2config_cp" = X2config_cp,
                      "covariate_data" = covariate_data,
                      "X1_formula" = X1_formula,
                      "X2_formula" = X2_formula,
                      "Q1config_k" = Q1config_k,
                      "Q2config_k" = Q2config_k,
                      "catchability_data" = catchability_data,
                      "Q1_formula" = Q1_formula,
                      "Q2_formula" = Q2_formula )
    data_args_input = combine_lists( input=extra_args, default=data_args_default )  # Do *not* use args_to_use
    data_list = do.call( what=make_data, args=data_args_input )
    #return(data_list) }
  }else{
    data_args_input = NULL
    data_list = extra_args$data_list
  }

  # Build object
  message("\n### Making TMB object")
  model_args_default = list("TmbData" = data_list,
                     "RunDir" = working_dir,
                     "Version" = settings$Version,
                     "RhoConfig" = settings$RhoConfig,
                     "loc_x" = spatial_list$loc_x,
                     "Method" = spatial_list$Method,
                     "build_model" = build_model,
                     "framework" = framework )
  model_args_input = combine_lists( input=extra_args, default=model_args_default, args_to_use=formalArgs(make_model) )
  tmb_list = do.call( what=make_model, args=model_args_input )

  # Run the model or optionally don't
  if( run_model==FALSE | build_model==FALSE ){
    # Build and output
    input_args = list( "extra_args" = extra_args,
               "extrapolation_args_input" = extrapolation_args_input,
               "model_args_input" = model_args_input,
               "spatial_args_input" = spatial_args_input,
               "data_args_input" = data_args_input )
    Return = list( "data_frame" = data_frame,
           "extrapolation_list" = extrapolation_list,
           "spatial_list" = spatial_list,
           "data_list" = data_list,
           "tmb_list" = tmb_list,
           "year_labels" = year_labels,
           "years_to_plot" = years_to_plot,
           "category_names" = category_names,
           "settings" = settings,
           "input_args" = input_args)
    class(Return) = "fit_model"
    return(Return)
  }
  if(silent==TRUE) tmb_list$Obj$env$beSilent()

  # Check for obvious problems with model
  if( test_fit==TRUE ){
    message("\n### Checking model at initial values")
    LogLike0 = tmb_list$Obj$fn( tmb_list$Obj$par )
    Gradient0 = tmb_list$Obj$gr( tmb_list$Obj$par )
    if( any( Gradient0==0 ) ){
      message("\n")
      stop("Please check model structure; some parameter has a gradient of zero at starting values\n", call.=FALSE)
    }else{
      message("All fixed effects have a nonzero gradient")
    }
  }

  # Optimize object
  message("\n### Estimating parameters")
  # have user override upper, lower, and loopnum
  optimize_args_default1 = list( lower = tmb_list$Lower,
                         upper = tmb_list$Upper,
                         loopnum = 1)
  optimize_args_default1 = combine_lists( default=optimize_args_default1, input=extra_args, args_to_use=formalArgs(TMBhelper::fit_tmb) )
  # auto-override user inputs for optimizer-related inputs for first test run
  optimize_args_input1 = list(obj = tmb_list$Obj,
                       savedir = NULL,
                       newtonsteps = 0,
                       bias.correct = FALSE,
                       control = list(eval.max = 50000, iter.max = 50000, trace = 1),
                       quiet = TRUE,
                       getsd = FALSE )
  # combine
  optimize_args_input1 = combine_lists( default=optimize_args_default1, input=optimize_args_input1, args_to_use=formalArgs(TMBhelper::fit_tmb) )
  parameter_estimates1 = do.call( what=TMBhelper::fit_tmb, args=optimize_args_input1 )

  # Check fit of model (i.e., evidence of non-convergence based on bounds, approaching zero, etc)
  if(exists("check_fit") & test_fit==TRUE ){
    problem_found = VAST::check_fit( parameter_estimates1 )
    if( problem_found==TRUE ){
      message("\n")
      stop("Please change model structure to avoid problems with parameter estimates and then re-try; see details in `?check_fit`\n", call.=FALSE)
    }
  }

  # Override default bias-correction
  if( (use_new_epsilon==TRUE) & (settings$bias.correct==TRUE) & (framework=="TMBad") & ("Index_ctl" %in% settings$vars_to_correct) ){
    settings$vars_to_correct = setdiff(settings$vars_to_correct, c("Index_ctl","Index_cyl"))
    # If length(settings$vars_to_correct)==0, then fit_tmb currently bias-corrects all parameters, so fixing that here
    if( length(settings$vars_to_correct)==0 ){
      settings$bias.correct = FALSE
    }
    settings$vars_to_correct = c( settings$vars_to_correct, "eps_Index_ctl" )
  }

  # Restart estimates after checking parameters
  optimize_args_default2 = list( obj = tmb_list$Obj,
                         lower = tmb_list$Lower,
                         upper = tmb_list$Upper,
                         savedir = working_dir,
                         bias.correct = settings$bias.correct,
                         newtonsteps = newtonsteps,
                         bias.correct.control = list(sd = FALSE, split = NULL, nsplit = 1, vars_to_correct = settings$vars_to_correct),
                         control = list(eval.max = 10000, iter.max = 10000, trace = 1),
                         loopnum = 1,
                         getJointPrecision = TRUE,
                         start_time_elapsed = parameter_estimates1$time_for_run )
  # combine while over-riding defaults using user inputs
  optimize_args_input2 = combine_lists( input=extra_args, default=optimize_args_default2, args_to_use=formalArgs(TMBhelper::fit_tmb) )
  # over-ride inputs to start from previous MLE
  optimize_args_input2 = combine_lists( input=list(startpar=parameter_estimates1$par), default=optimize_args_input2 )
  parameter_estimates2 = do.call( what=TMBhelper::fit_tmb, args=optimize_args_input2 )

  # Override default bias-correction
  if( (use_new_epsilon==TRUE) & (framework=="TMBad") & ("eps_Index_ctl" %in% settings$vars_to_correct) & !is.null(parameter_estimates2$SD) ){
    message("\n### Applying faster epsilon bias-correction estimator")
    fit = list( "parameter_estimates"=parameter_estimates2, "tmb_list"=tmb_list, "input_args"=list("model_args_input"=model_args_input) )
    parameter_estimates2$SD = apply_epsilon( fit )
  }

  # Extract standard outputs
  if( "par" %in% names(parameter_estimates2) ){
    if( !is.null(tmb_list$Obj$env$intern) && tmb_list$Obj$env$intern==TRUE ){
      Report = as.list(tmb_list$Obj$env$reportenv)
    }else{
      Report = tmb_list$Obj$report()
    }
    ParHat = tmb_list$Obj$env$parList( parameter_estimates2$par )

    # Label stuff
    Report = amend_output( Report = Report,
                           TmbData = data_list,
                           Map = tmb_list$Map,
                           Sdreport = parameter_estimates2$SD,
                           year_labels = year_labels,
                           category_names = category_names,
                           extrapolation_list = extrapolation_list )
  }else{
    Report = ParHat = "Model is not converged"
  }

  # Build and output
  input_args = list( "extra_args" = extra_args,
             "extrapolation_args_input" = extrapolation_args_input,
             "model_args_input" = model_args_input,
             "spatial_args_input" = spatial_args_input,
             "optimize_args_input1" = optimize_args_input1,
             "optimize_args_input2" = optimize_args_input2,
             "data_args_input" = data_args_input )
  Return = list( "data_frame" = data_frame,
         "extrapolation_list" = extrapolation_list,
         "spatial_list" = spatial_list,
         "data_list" = data_list,
         "tmb_list" = tmb_list,
         "parameter_estimates" = parameter_estimates2,
         "Report" = Report,
         "ParHat" = ParHat,
         "year_labels" = year_labels,
         "years_to_plot" = years_to_plot,
         "category_names" = category_names,
         "settings" = settings,
         "input_args" = input_args,
         "X1config_cp" = X1config_cp,
         "X2config_cp" = X2config_cp,
         "covariate_data" = covariate_data,
         "X1_formula" = X1_formula,
         "X2_formula" = X2_formula,
         "Q1config_k" = Q1config_k,
         "Q2config_k" = Q1config_k,
         "catchability_data" = catchability_data,
         "Q1_formula" = Q1_formula,
         "Q2_formula" = Q2_formula,
         "total_time" = Sys.time() - start_time )

  # Add stuff for effects package
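  # (presumably because the effects package reconstructs model frames from a stored lm-style
  #  `call`; the dummy calls below are built from each formula with a zero placeholder response)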
  Return$effects = list()
  if( !is.null(catchability_data) ){
    catchability_data_full = data.frame( catchability_data, "linear_predictor"=0 )
    Q1_formula_full = update.formula(Q1_formula, linear_predictor~.+0)
    call_Q1 = lm( Q1_formula_full, data=catchability_data_full)$call
    Q2_formula_full = update.formula(Q2_formula, linear_predictor~.+0)
    call_Q2 = lm( Q2_formula_full, data=catchability_data_full)$call
    Return$effects = c( Return$effects, list(call_Q1=call_Q1, call_Q2=call_Q2, catchability_data_full=catchability_data_full) )
  }
  if( !is.null(covariate_data) ){
    covariate_data_full = data.frame( covariate_data, "linear_predictor"=0 )
    X1_formula_full = update.formula(X1_formula, linear_predictor~.+0)
    call_X1 = lm( X1_formula_full, data=covariate_data_full)$call
    X2_formula_full = update.formula(X2_formula, linear_predictor~.+0)
    call_X2 = lm( X2_formula_full, data=covariate_data_full)$call
    Return$effects = c( Return$effects, list(call_X1=call_X1, call_X2=call_X2, covariate_data_full=covariate_data_full) )
  }

  # Add stuff for marginaleffects package
  Return$last.par.best =  tmb_list$Obj$env$last.par.best

  # class and return
  class(Return) = "fit_model"
  return( Return )
}

#' Print parameter estimates and standard errors.
#'
#' @title Print parameter estimates
#' @param x Output from \code{\link{fit_model}}
#' @param ... Not used
#' @return Invisibly returns \code{x$parameter_estimates}
#' @method print fit_model
#' @export
print.fit_model <- function(x, ...)
{
  cat("fit_model(.) result\n")
  if( "parameter_estimates" %in% names(x) ){
    print( x$parameter_estimates )
  }else{
    cat("`parameter_estimates` not available in `fit_model`\n")
  }
  invisible(x$parameter_estimates)
}

#' Plot results from a fitted model.
#'
#' @title Plot fitted model
#' @param x Output from \code{\link{fit_model}}
#' @param what String specifying what elements of results to plot; options include `extrapolation_grid`, `spatial_info` (or `inla_mesh`), and `results`
#' @param ... Arguments passed to \code{\link{plot_results}}
#' @return NULL
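#' @examples
#' \dontrun{
#' # Hypothetical usage sketch, assuming `fit` is output from fit_model(.)
#' plot( fit, what="extrapolation_grid" )   # map the extrapolation-grid
#' plot( fit, what="spatial_info" )         # plot the spatial mesh and knots
#' plot( fit )                              # default `what="results"` calls plot_results(.)
#' }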
#' @method plot fit_model
#' @export
plot.fit_model <- function(x, what="results", ...)
{
  if(!is.character(what)) stop("Check `what` in `plot.fit_model`")

  ## Plot extrapolation-grid
  if( length(grep(what, "extrapolation_grid")) ){
    cat("\n### Running `plot.make_extrapolation_info`\n")
    plot( x$extrapolation_list )
    return(invisible(NULL))
  }

  ## Plot spatial mesh
  if( length(grep(what, c("spatial_info","inla_mesh"))) ){
    cat("\n### Running `plot.make_spatial_info`\n")
    plot( x$spatial_list )
    return(invisible(NULL))
  }

  # diagnostic plots
  if( length(grep(what, "results")) ){
    cat("\n### Running `plot_results`\n")
    ans = plot_results( x, ... )
    return(invisible(ans))
  }

  stop( "input `what` not matching available options" )
}

#' Extract summary of spatial estimates
#'
#' \code{summary.fit_model} extracts commonly used quantities derived from a fitted VAST model
#'
#' \code{summary.fit_model} facilitates common queries of model output, including:
#' \itemize{
#' \item \code{what="density"} returns a tagged list containing element \code{Density_dataframe},
#' which lists the estimated density for every Latitude-Longitude-Year-Category combination
#' for every modelled location in the extrapolation-grid.
#' \item \code{what="residuals"} returns a DHARMa object containing PIT residuals;
#' See details section for more information.
#' }
#'
#' For calculating residuals, the function calls package \code{\link[DHARMa]{DHARMa}}
#' to create a diagnostic object for simulation-based quantile residuals.
#' It specifically simulates replicated data sets from the predictive distribution of data
#' conditional on estimated fixed and random effects. It then
#' calculates probability-integral-transform (PIT) residuals from the observed and simulated values.
#' It then replaces the automatically calculated residuals in the DHARMa object with these PIT residuals,
#' so that DHARMa can be used to plot them. PIT residuals are used because the original DHARMa calculations
#' are not correct when using a delta-model (due to additional jittered values added by DHARMa when detecting multiple 0-valued observations), hence
#' the need to call this function to correctly calculate PIT residuals for a delta-model.
#'
#' Note that \code{summary(fit, ..., type=0)} uses \code{\link[FishStatsUtils]{oneStepPredict_deltaModel}} to calculate one-step-ahead
#' residuals.  These are probably the most appropriate residuals for model evaluation, but are also *very* slow to calculate relative
#' to other methods.
#'
#' @inheritParams simulate_data
#' @inheritParams DHARMa::plotResiduals
#'
#' @param x Output from \code{\link{fit_model}}
#' @param what String indicating what to summarize; options are `density`, `index`, `residuals`, or `estimates` (a.k.a. `parhat`)
#' @param n_samples Number of samples used when \code{what="residuals"}
#' @param ... additional arguments passed to \code{\link[DHARMa]{plotResiduals}} when \code{what="residuals"}
#'
#' @return Invisibly returns the requested summary: a tagged list including \code{Density_dataframe} for \code{what="density"} or \code{what="index"}, a DHARMa object for \code{what="residuals"}, or a list containing \code{ParHat} for \code{what="estimates"}
#' @seealso \code{\link{plot_quantile_residuals}} to plot output of \code{summary.fit_model(x,what="residuals")}
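#' @examples
#' \dontrun{
#' # Hypothetical usage sketch, assuming `fit` is output from fit_model(.)
#' # Tagged list with a data frame of densities for every grid-cell / category / year
#' out = summary( fit, what="density" )
#' head( out$Density_dataframe )
#'
#' # DHARMa object with PIT residuals substituted (see Details)
#' dharmaRes = summary( fit, what="residuals", n_samples=250 )
#' }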
#' @method summary fit_model
#' @export
summary.fit_model <-
function( x,
          what = "density",
          n_samples = 250,
          working_dir = NULL,
          type = 1,
          random_seed = NULL,
          form = NULL,
          category_names = x$category_names,
          year_labels = x$year_labels,
          ...)
{
  ans = NULL

  # Check and implement units and labels
  x$Report = amend_output( fit = x,
                           year_labels = year_labels,
                           category_names = category_names )

  if( tolower(what) %in% c("density","index") ){
    # Load location of extrapolation-grid
    ans[["extrapolation_grid"]] = print( x$extrapolation_list, quiet=TRUE )

    # Load density estimates
    if( tolower(what) == "density" ){
      ans[["Density_array"]] =  x$Report[["D_gct"]]
    }else if( tolower(what) == "index" ){
      ans[["Density_array"]] =  x$Report[["Index_gctl"]]
    }

    # Exclude boundary knots
    if( !( x$settings$fine_scale==TRUE | x$spatial_list$Method=="Stream_network" ) ){
      index_tmp = x$spatial_list$NN_Extrap$nn.idx[ which(x$extrapolation_list[["Area_km2_x"]]>0), 1 ]
      ans[["Density_array"]] = ans[["Density_array"]][ index_tmp,,,drop=FALSE]
    }
    # Error check
    if( any(sapply(dimnames(ans[["Density_array"]]),FUN=is.null)) ){
      stop("`summary.fit_model` assumes that arrays are labeled")
    }

    # Expand as grid
    Density_dataframe = expand.grid( dimnames(ans[["Density_array"]]) )
    Density_dataframe = cbind( Density_dataframe, ans[["extrapolation_grid"]][Density_dataframe[,'Site'],], as.vector(ans[["Density_array"]]) )
    colnames(Density_dataframe)[ncol(Density_dataframe)] = tolower(what)

    # Save output
    ans[["Density_dataframe"]] = Density_dataframe
    ans[['year_labels']] = x[['year_labels']]
    rownames(Density_dataframe) = NULL

    # Print to terminal
    cat("\n### Printing head of and tail `Density_dataframe`, and returning data frame in output object\n")
    print(head(Density_dataframe))
    print(tail(Density_dataframe))
  }

  # Residuals
  if( tolower(what) == "residuals" ){
    # extract objects
    Obj = x$tmb_list$Obj

    # Change n_g
    # Must change back explicitly because TMB appears to pass env as a pointer, so changes in copy affect original x outside of function!
    n_g_orig = Obj$env$data$n_g
    revert_settings = function(n_g){Obj$env$data$n_g = n_g}
    on.exit( revert_settings(n_g_orig) )
    Obj$env$data$n_g = 0

    if( type %in% c(1,4) ){
      b_iz = matrix(NA, nrow=length(x$data_list$b_i), ncol=n_samples)
      message( "Sampling from the distribution of data conditional on estimated fixed and random effects" )
      for( zI in 1:n_samples ){
        if( zI%%max(1,floor(n_samples/10)) == 0 ){
          message( "  Finished sample ", zI, " of ",n_samples )
        }
        b_iz[,zI] = simulate_data( fit=list(tmb_list=list(Obj=Obj)), type=type, random_seed=list(random_seed+zI,NULL)[[1+is.null(random_seed)]] )$b_i
      }
      #if( any(is.na(x$data_list$b_i)) ){
      #  stop("dharmaRes not designed to work when any observations have b_i=NA")
      #}
      # Substitute any observation where b_i = NA with all zeros, which will then have a uniform PIT
      which_na = which(is.na(x$data_list$b_i))
      if( length(which_na) > 0 ){
        x$data_list$b_i[which_na] = 0
        b_iz[which_na,] = 0
        warning("When calculating DHARMa residuals, replacing instances where b_i=NA with a uniform PIT residual")
      }
      if( any(is.na(b_iz)) ){
        stop("Check simulated residuals for NA values")
      }

      # Run DHARMa
      # Adding jitters because DHARMa version 0.3.2.0 sometimes still throws an error with method="traditional" and integer=FALSE unless jitters are added
      dharmaRes = DHARMa::createDHARMa( simulatedResponse = strip_units(b_iz) + 1e-10*array(rnorm(prod(dim(b_iz))),dim=dim(b_iz)),
                                        observedResponse = strip_units(x$data_list$b_i) + 1e-10*rnorm(length(x$data_list$b_i)),
                                        fittedPredictedResponse = strip_units(x$Report$D_i),
                                        integer = FALSE)
      #dharmaRes = DHARMa::createDHARMa(simulatedResponse=strip_units(b_iz),
      #  observedResponse=strip_units(x$data_list$b_i),
      #  fittedPredictedResponse=strip_units(x$Report$D_i),
      #  method="PIT")

      # Save to report error
      if( FALSE ){
        all = list( simulatedResponse=strip_units(b_iz), observedResponse=strip_units(x$data_list$b_i), fittedPredictedResponse=strip_units(x$Report$D_i) )
        #save(all, file=paste0(root_dir,"all.RData") )
        dharmaRes = DHARMa::createDHARMa(simulatedResponse=all$simulatedResponse + rep(1,nrow(all$simulatedResponse))%o%c(0.001*rnorm(1),rep(0,ncol(all$simulatedResponse)-1)),
          observedResponse=all$observedResponse,
          fittedPredictedResponse=all$fittedPredictedResponse,
          method="PIT")
      }

      # Calculate probability-integral-transform (PIT) residuals
      message( "Substituting probability-integral-transform (PIT) residuals for DHARMa-calculated residuals" )
      prop_lessthan_i = apply( b_iz<outer(x$data_list$b_i,rep(1,n_samples)),
        MARGIN=1,
        FUN=mean )
      prop_lessthanorequalto_i = apply( b_iz<=outer(x$data_list$b_i,rep(1,n_samples)),
        MARGIN=1,
        FUN=mean )
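      # Draw uniformly between P(sim < obs) and P(sim <= obs): this randomized PIT handles
      # exact ties (e.g., the point mass at zero in a delta-model), so residuals are
      # Uniform(0,1) under a correctly specified model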
      PIT_i = runif(min=prop_lessthan_i, max=prop_lessthanorequalto_i, n=length(prop_lessthan_i) )
      # cbind( "Difference"=dharmaRes$scaledResiduals - PIT_i, "PIT"=PIT_i, "Original"=dharmaRes$scaledResiduals, "b_i"=x$data_list$b_i )
      dharmaRes$scaledResiduals = PIT_i
    }else if( type==0 ){
      # Check for issues
      if( !all(x$data_list$ObsModel_ez[,1] %in% c(1,2)) ){
        stop("oneStepAhead residuals only code for gamma and lognormal distributions")
      }

      # Run OSA
      message( "Running oneStepPredict_deltaModel for each observation, to then load them into DHARMa object for plotting" )
      osa = TMBhelper::oneStepPredict_deltaModel( obj = x$tmb_list$Obj,
        observation.name = "b_i",
        method = "cdf",
        data.term.indicator = "keep",
        deltaSupport = 0,
        trace = TRUE )

      # Build DHARMa object on fake inputs and load OSA into DHARMa object
      dharmaRes = DHARMa::createDHARMa(simulatedResponse=matrix(rnorm(x$data_list$n_i*10,mean=x$data_list$b_i),ncol=10),
        observedResponse=x$data_list$b_i,
        fittedPredictedResponse=x$Report$D_i,
        integer=FALSE)
      dharmaRes$scaledResiduals = pnorm(osa$residual)
    }else{
      stop("`type` only makes sense for 0 (oneStepAhead), 1 (conditional, a.k.a. measurement error) or 4 (unconditional) simulations")
    }

    # do plot
    if( is.null(working_dir) ){
      plot_dharma(dharmaRes, ...)
    }else if(!is.na(working_dir) ){
      png(file=file.path(working_dir,"quantile_residuals.png"), width=8, height=4, res=200, units='in')
        plot_dharma(dharmaRes, form=form, ...)
      dev.off()
    }

    # Return stuff
    ans = dharmaRes
    message( "Invisibly returning output from `DHARMa::createDHARMa`, e.g., to apply `plot.DHARMa` to this output")
  }

  if( tolower(what) %in% c("parhat","estimates") ){
    ans[["estimates"]] = x$ParHat
    cat("\n### Printing slots of `ParHat`, and returning list in output object")
    print(names(x$ParHat))
  }

  if( is.null(ans) ){
    stop( "`summary.fit_model` not implemented for inputted value of argument `what`" )
  }

  # diagnostic plots
  return(invisible(ans))
}


#' Predict density for new samples (\emph{Beta version; may change without notice})
#'
#' \code{predict.fit_model} calculates predictions given new data
#'
#' \code{predict.fit_model} is designed with two purposes in mind:
#' \enumerate{
#' \item If \code{new_covariate_data=NULL} as by default, then the model uses the covariate values supplied during original model fits,
#'       and interpolates as needed from those supplied values to new predicted locations.  This then uses *exactly* the same information
#'       as was available during model fitting.
#' \item If \code{new_covariate_data} is supplied with new values (e.g., at locations for predictions), then these values are used in
#'       combination with original covariate values when interpolating to new values.  However, supplying \code{new_covariate_data}
#'       at the same Lat-Lon-Year combination as any original covariate value will remove the matching rows of the original covariates, such that originally fitted data
#'       can be predicted using alternative values for covariates (e.g., when calculating partial dependence plots)
#' }
#'
#' @inheritParams make_covariates
#' @inheritParams VAST::make_data
#' @param x Output from \code{\link{fit_model}}
#' @param what Which output from \code{fit$Report} should be extracted; default is predicted density
#' @param keep_old_covariates Whether to add new_covariate_data to existing data.
#'        This is useful when predicting values at new locations, but does not work
#'        when predicting at locations with existing data (because the interpolated
#'        covariate values would conflict between existing and new covariate values), e.g.,
#'        when calculating partial dependence plots for existing data.
#'
#' @return Vector of predictions, i.e., element \code{what} of \code{fit$Report} for each new observation
#'
#' @examples
#' \dontrun{
#'
#' # Showing use of package pdp for partial dependence plots
#' pred.fun = function( object, newdata ){
#'   predict( x=object,
#'     Lat_i = object$data_frame$Lat_i,
#'     Lon_i = object$data_frame$Lon_i,
#'     t_i = object$data_frame$t_i,
#'     a_i = object$data_frame$a_i,
#'     what = "P1_iz",
#'     new_covariate_data = newdata,
#'     do_checks = FALSE )
#' }
#'
#' library(ggplot2)
#' library(pdp)
#' Partial = partial( object = fit,
#'                    pred.var = "BOT_DEPTH",
#'                    pred.fun = pred.fun,
#'                    train = fit$covariate_data )
#' autoplot(Partial)
#'
#' }
#'
#' @method predict fit_model
#' @export
predict.fit_model <- function(x,
                  what="D_i",
                  Lat_i,
                  Lon_i,
                  t_i,
                  a_i,
                  c_iz = rep(0,length(t_i)),
                  v_i = rep(0,length(t_i)),
                  new_covariate_data = NULL,
                  new_catchability_data = NULL,
                  do_checks = TRUE,
                  working_dir = getwd() )
{
  message("`predict.fit_model(.)` is in beta-testing, and please explore results carefully prior to using")

  # Check issues
  if( !(what%in%names(x$Report)) || (length(x$Report[[what]])!=x$data_list$n_i) ){
    stop("`what` can only take a few options")
  }
  if( !is.null(new_covariate_data) ){
    # Confirm all columns are available
    if( !all(colnames(x$covariate_data) %in% colnames(new_covariate_data)) ){
      stop("Please ensure that all columns of `x$covariate_data` are present in `new_covariate_data`")
    }
    # Eliminate unnecessary columns
    new_covariate_data = new_covariate_data[,match(colnames(x$covariate_data),colnames(new_covariate_data))]
    # Eliminate old-covariates that are also present in new_covariate_data
    NN = RANN::nn2( query=x$covariate_data[,c('Lat','Lon','Year')], data=new_covariate_data[,c('Lat','Lon','Year')], k=1 )
    if( any(NN$nn.dist==0) ){
      x$covariate_data = x$covariate_data[-which(NN$nn.dist==0),,drop=FALSE]
    }
  }
  if( !is.null(new_catchability_data) ){
    stop("Option not implemented")
  }

  # Process covariates
  catchability_data = rbind( x$catchability_data, new_catchability_data )
  covariate_data = rbind( x$covariate_data, new_covariate_data )

  # Process inputs
  PredTF_i = c( x$data_list$PredTF_i, rep(1,length(t_i)) )
  b_i = c( x$data_frame[,"b_i"], rep(1,length(t_i)) )
  c_iz = rbind( matrix(x$data_frame[,grep("c_iz",names(x$data_frame))]), matrix(c_iz) )
  Lat_i = c( x$data_frame[,"Lat_i"], Lat_i )
  Lon_i = c( x$data_frame[,"Lon_i"], Lon_i )
  a_i = c( x$data_frame[,"a_i"], a_i )
  v_i = c( x$data_frame[,"v_i"], v_i )
  t_i = c( x$data_frame[,"t_i"], t_i )
  #assign("b_i", b_i, envir=.GlobalEnv)

  # Build information regarding spatial location and correlation
  message("\n### Re-making spatial information")
  spatial_args_new = list("anisotropic_mesh"=x$spatial_list$MeshList$anisotropic_mesh, "Kmeans"=x$spatial_list$Kmeans,
    "Lon_i"=Lon_i, "Lat_i"=Lat_i )
  spatial_args_input = combine_lists( input=spatial_args_new, default=x$input_args$spatial_args_input )
  spatial_list = do.call( what=make_spatial_info, args=spatial_args_input )

  # Check spatial_list
  if( !isTRUE(all.equal(spatial_list$MeshList, x$spatial_list$MeshList)) ){
    stop("`MeshList` generated during `predict.fit_model` doesn't match that of original fit; please email package author to report issue")
  }

  # Build data
  # Do *not* restrict inputs to formalArgs(make_data) because other potential inputs are still parsed by make_data for backwards compatibility
  message("\n### Re-making data object")
  data_args_new = list( "c_iz"=c_iz, "b_i"=b_i, "a_i"=a_i, "v_i"=v_i, "PredTF_i"=PredTF_i,
    "t_i"=t_i, "spatial_list"=spatial_list,
    "covariate_data"=covariate_data, "catchability_data"=catchability_data )
  data_args_input = combine_lists( input=data_args_new, default=x$input_args$data_args_input )  # Do *not* use args_to_use
  data_list = do.call( what=make_data, args=data_args_input )
  data_list$n_g = 0

  # Build object
  message("\n### Re-making TMB object")
  model_args_default = list("TmbData"=data_list, "RunDir"=working_dir, "Version"=x$settings$Version,
    "RhoConfig"=x$settings$RhoConfig, "loc_x"=spatial_list$loc_x, "Method"=spatial_list$Method)
  model_args_input = combine_lists( input=list("Parameters"=x$ParHat),
    default=model_args_default, args_to_use=formalArgs(make_model) )
  tmb_list = do.call( what=make_model, args=model_args_input )

  # Extract output
  Report = tmb_list$Obj$report()
  Y_i = Report[[what]][(1+nrow(x$data_frame)):length(Report$D_i)]

  # sanity check
  #if( all.equal(covariate_data,x$covariate_data) & Report$jnll!=x$Report$jnll){
  if( do_checks==TRUE && (Report$jnll!=x$Report$jnll) ){
    message("Problem detected in `predict.fit_model`; returning outputs for diagnostic purposes")
    Return = list("Report"=Report, "data_list"=data_list)
    return(Return)
  }

  # return prediction
  return(Y_i)
}