R/explore_mod.R

Defines functions explore_mod

Documented in explore_mod

#' Explore (pre- and post-) model results.
#'
#'
#' @param taxa Scientific name of taxa (for labelling things).
#' @param common Common name of taxa (also used in labels).
#' @param mod_path Path to saved model.
#' @param mod_type Type of model (e.g. 'reporting rate', or 'occupancy')
#' @param resp_var What is the response variable?
#' @param geo_var Categorical variable in the model (usually biogeographic)
#' @param time_var Continuous time variable (usually `year`)
#' @param max_levels Maximum number of classes to include in categorical plots.
#' @param draws Number of draws from posterior distribution to display in plots.
#' @param reference_year Reference time at which to predict (and compare change)
#' @param recent_year Time at which to predict to compare against reference_year
#' @param quant_probs Quantiles for summarising.
#' @param limit_preds Logical. Limit predictions to within contexts (`time_var`
#' within `geo_var`) that exist in the data.
#'
#' @return A list with components
#' \describe{
#'   \item{count_char}{ggplot object. Count of levels within character variables.}
#'   \item{y_vs_char}{ggplot object. Response variable vs character variables.}
#'   \item{count_num}{ggplot object. Histogram of each numeric variable.}
#'   \item{y_vs_num}{ggplot object. Response variable vs numeric variables with `geom_smooth()`.}
#'   \item{pairs}{`GGally::ggpairs` object from `df`.}
#'   \item{pred}{Dataframe of `posterior_predict` for levels of interest.}
#'   \item{res}{`pred` summarised by `c(resp_var, time_var)`.}
#' }
#' @export
#'
#' @examples
  explore_mod <- function(taxa
                          , common
                          , mod_path
                          , mod_type
                          , resp_var = "prop"
                          , geo_var = "IBRA_SUB_N"
                          , time_var = "year"
                          , max_levels = 30
                          , draws = 200
                          , reference_year = 2000
                          , recent_year = 2010
                          , re_run = FALSE
                          , quant_probs = c(0.05, 0.5, 0.95)
                          , binom_denom = "trials"
                          , limit_preds = TRUE
                          ) {

    `:=` <- rlang::`:=`

    where <- tidyselect:::where

    taxa <- as.character(taxa)
    common <- as.character(common)

    print(taxa)

    #-------do_run--------

    tests <- tibble::tribble(~type, ~year
                             , "reference", reference_year
                             , "recent", recent_year
                             ) %>%
      tidyr::unnest(cols = c(.data$year))

    reference <- tests$year[tests$type == "reference"]

    recent <- tests$year[tests$type == "recent"]

    context <- c(geo_var, time_var, mod_type)

    mod <- rio::import(mod_path)

    df <- mod$data

    if(stats::family(mod)$family == "binomial") {

      df$success <- mod$y[,1]
      df[[binom_denom]] <- mod$y[,1] + mod$y[,2]

      df <- df %>%
        dplyr::select(!tidyselect::matches("cbind")) %>%
        dplyr::mutate(prop = success / !!rlang::ensym(binom_denom))

    }



    #-------setup explore-------

    res <- list()

    plot_titles <- bquote(~italic(.(taxa))*":" ~ .(common))

    has_ll <- sum(grepl("list_length"
                        , names(mod$coefficients))
                  ) > 0

    if(!resp_var %in% names(df)) {

      df$success <- df[,1][,1]
      df$trials <- df[,1][,1] + df[,1][,2]
      df$prop <- df$success / df$trials
      df[,1] <- NULL

    }

    # variables to explore
    var_exp <- c(resp_var
                 , colnames(df)
                 ) %>%
      unique()

    dat_exp <- df %>%
      dplyr::select(tidyselect::any_of(var_exp))

    has_numeric <- dat_exp %>%
      dplyr::select(-1) %>%
      dplyr::select(where(is.numeric)) %>%
      ncol() %>%
      `>` (0)

    has_character <- dat_exp %>%
      dplyr::select(-1) %>%
      dplyr::mutate(dplyr::across(where(is.factor),as.character)) %>%
      dplyr::select(where(is.character)) %>%
      ncol() %>%
      `>` (0)

    # Character variables
    if(has_character) {

      plot_data <- dat_exp %>%
        dplyr::mutate(dplyr::across(where(is.factor), as.character)) %>%
        dplyr::select_if(is.character) %>%
        tidyr::gather(variable, value, 1:ncol(.)) %>%
        dplyr::group_by(.data$variable) %>%
        dplyr::mutate(levels = dplyr::n_distinct(.data$value)) %>%
        dplyr::ungroup() %>%
        dplyr::filter(levels < max_levels)

      # count character
      res$count_char <- ggplot2::ggplot(data = plot_data) +
        ggplot2::geom_histogram(ggplot2::aes(.data$value)
                                , stat = "count"
                                ) +
        ggplot2::facet_wrap(~ .data$variable
                            , scales = "free"
                            ) +
        ggplot2::theme(axis.text.x = ggplot2::element_text(angle = 90
                                                           , vjust = 0.5
                                                           , hjust = 1
                                                           )
                       ) +
        ggplot2::labs(title = plot_titles
                      , subtitle = "Count of levels within character variables"
                      )

      # resp_var vs character
      plot_data <- dat_exp %>%
        dplyr::mutate({{ resp_var }} := as.factor(.data[[resp_var]])) %>%
        dplyr::mutate_if(is.factor,as.character) %>%
        dplyr::select_if(is.character) %>%
        dplyr::mutate({{resp_var}} := as.numeric(.data[[resp_var]])) %>%
        tidyr::gather(variable, value, 2:ncol(.)) %>%
        dplyr::group_by(.data$variable) %>%
        dplyr::mutate(levels = dplyr::n_distinct(.data$value)) %>%
        dplyr::ungroup() %>%
        dplyr::filter(levels < max_levels)

        res$y_vs_char <- ggplot2::ggplot(plot_data) +
        ggplot2::geom_boxplot(ggplot2::aes(x = .data$value
                                           , y = .data[[resp_var]]
                                           )
                              ) +
        ggplot2::facet_wrap(~ .data$variable
                            , scales = "free"
                            ) +
        ggplot2::theme(axis.text.x=ggplot2::element_text(angle = 90
                                                , vjust = 0.5
                                                )
                       ) +
        ggplot2::labs(title = plot_titles
                      , subtitle = paste0("Boxplots of response variable ("
                                          , resp_var
                                          , ") against character variables"
                                          )
                      )

    }

    # Numeric variables
    if(has_numeric) {

      plot_data <- dat_exp %>%
        dplyr::select(where(is.numeric)) %>%
        tidyr::gather(variable, value, 1:ncol(.))

      # Count numeric
      res$count_num <- ggplot2::ggplot(data = plot_data
                                       , ggplot2::aes(.data$value)
                                       ) +
        ggplot2::geom_histogram() +
        ggplot2::facet_wrap(~ .data$variable
                            , scales = "free"
                            ) +
        ggplot2::labs(title = plot_titles
                      , subtitle = "Histograms of numeric variables"
                      )

      # resp_var vs. Numeric
      plot_data <- dat_exp %>%
        dplyr::select(tidyselect::any_of(var_exp)) %>%
        dplyr::select(where(is.numeric)) %>%
        tidyr::gather(variable, value, 2:ncol(.)) %>%
        dplyr::arrange({{ resp_var }})

      res$y_vs_num <- ggplot2::ggplot(data = plot_data
                                      , ggplot2::aes(x = .data$value
                                                     , y = .data[[resp_var]]
                                                     )
                                      ) +
        ggplot2::geom_point(alpha = 0.5) +
        ggplot2::geom_smooth() +
        ggplot2::facet_wrap(~.data$variable, scales = "free") +
        ggplot2::theme(axis.text.x = ggplot2::element_text(angle = 90
                                                           , vjust = 0.5
                                                           )
                       ) +
        ggplot2::labs(title = plot_titles
                      , subtitle = paste0("Numeric variables plotted against response variable ("
                                          , resp_var
                                          , ")"
                                          )
                      )

    }

    plot_data <- dat_exp %>%
      dplyr::mutate(dplyr::across(where(is.character),factor)) %>%
      dplyr::select(where(~is.numeric(.x)|is.factor(.x) & dplyr::n_distinct(.x) < 15)) %>%
      dplyr::mutate(dplyr::across(where(is.factor),factor))

    res$pairs <- GGally::ggpairs(plot_data) +
      ggplot2::theme(axis.text.x = ggplot2::element_text(angle = 90
                                                         , vjust = 0.5
                                                         )
                     )


    #-------residuals-------

    if(length(residuals(mod)) == nrow(df)) {

      res$resid <- tibble::tibble(residual = residuals(mod)
                                  , fitted = fitted(mod)
                                  ) %>%
        dplyr::bind_cols(df)


      res$resid_plot <- ggplot2::ggplot(data = res$resid
                                        , ggplot2::aes(x = .data$fitted
                                                       , y = .data$residual
                                                       )
                                        ) +
        ggplot2::geom_point(size = 2) +
        ggplot2::geom_hline(yintercept = 0
                            , linetype = 2
                            , colour = "red"
                            ) +
        ggplot2::theme(axis.text.x = ggplot2::element_text(angle = 90
                                                           , vjust = 0.5
                                                           , hjust = 1
                                                           )
                       ) %>%
        ggplot2::scale_colour_viridis_d(end = 0.9)


      res$resid_plot_norm <- ggplot(res$resid
                                    , aes(residual)
                                    ) +
        ggplot2::geom_function(fun = dnorm
                               , colour = "light blue"
                               , size = 1
                              , args = list(mean = mean(res$resid$residual)
                                            , sd = sd(res$resid$residual)
                                            )
                               ) +
        ggplot2::geom_density(colour = "dark blue"
                              , size = 1
                              ) +
        ggplot2::labs(x = "value"
                      , y = "Density"
                      )

      res$resid_plot_num <- if(has_numeric) {

        plot_data <- res$resid %>%
          dplyr::select_if(is.numeric) %>%
          tidyr::pivot_longer(2:ncol(.))

        ggplot2::ggplot(data = plot_data
                        , ggplot2::aes(x = .data$value
                                       , y = .data$residual
                                       )
                        ) +
          ggplot2::geom_point(size = 2) +
          ggplot2::geom_smooth(method = "lm")  +
          ggplot2::geom_hline(yintercept = 0
                              , linetype = 2
                              , colour = "red"
                              ) +
          ggplot2::facet_wrap(~ .data$name
                              , scales = "free_x"
                              ) +
          ggplot2::theme(axis.text.x = ggplot2::element_text(angle = 90
                                                             , vjust = 0.5
                                                             , hjust = 1
                                                             )
                         ) +
          ggplot2::scale_colour_viridis_d()

      } else NULL


      res$resid_plot_char <- if(has_character) {

        plot_data <- res$resid %>%
          dplyr::mutate(dplyr::across(where(is.factor),as.character)) %>%
          dplyr::select(1
                        , where(is.character)
                        ) %>%
          tidyr::pivot_longer(2:ncol(.)) %>%
          dplyr::group_by(.data$name) %>%
          dplyr::mutate(levels = dplyr::n_distinct(.data$value)) %>%
          dplyr::ungroup() %>%
          dplyr::filter(levels < max_levels)

        ggplot2::ggplot(data = plot_data
                        , ggplot2::aes(x = .data$value
                                       , y = .data$residual
                                       )
                        ) +
          ggplot2::geom_boxplot() +
          ggplot2::geom_hline(ggplot2::aes(yintercept = 0)
                              , linetype = 2
                              , colour = "red"
                              ) +
          ggplot2::facet_wrap(~ .data$name
                              , scales = "free"
                              ) +
          ggplot2::theme(axis.text.x = ggplot2::element_text(angle = 90
                                                    , vjust = 0.5
                                                    , hjust = 1
                                                    )
                         )

      } else NULL

    }


    #---------post explore-------

    if(stats::family(mod)$family == "beta") class(mod) <- unique(c(class(mod),"betareg"))

    is_binomial_mod <- stats::family(mod)$family == "binomial"

    res$pred <- df %>%
      dplyr::distinct(dplyr::across(tidyselect::any_of(context))) %>%
      dplyr::mutate(list_length = if(has_ll) stats::median(exp(df$log_list_length)) else NULL
                    , log_list_length = if(has_ll) log(list_length) else NULL
                    , col = row.names(.)
                    , success = if(is_binomial_mod) 0 else NULL
                    , !!ensym(binom_denom) := if(is_binomial_mod) 100 else NULL
                    ) %>%
      tidybayes::add_epred_draws(mod
                                 , ndraws = draws
                                 , re_formula = NA
                                 , value = "pred"
                                 ) %>%
      dplyr::ungroup()


    res$res <- res$pred %>%
      dplyr::group_by(dplyr::across(tidyselect::any_of(context))) %>%
      dplyr::summarise(n = dplyr::n()
                       , nCheck = nrow(tibble::as_tibble(mod))
                       , modMean = mean(.data$pred)
                       , modMedian = stats::quantile(.data$pred, 0.5)
                       , modci90lo = stats::quantile(.data$pred, 0.05)
                       , modci90up = stats::quantile(.data$pred, 0.95)
                       , text = paste0(round(.data$modMedian,2)
                                       , " ("
                                       , round(.data$modci90lo,2)
                                       , " to "
                                       , round(.data$modci90up,2)
                                       , ")"
                                       )
                       ) %>%
      dplyr::ungroup()


    #------res plot data-------

    plot_data <- df %>%
      dplyr::distinct(dplyr::across(tidyselect::any_of(context))) %>%
      dplyr::mutate(success = 0
                    , !!ensym(binom_denom) := 100
                    ) %>%
      dplyr::full_join(tibble::tibble(probs = 0.5) %>%
                         {if(has_ll) (.) %>%
                             dplyr::mutate(list_length = purrr::map_dbl(probs
                                                                 , function(x) stats::quantile(unique(exp(df$log_list_length))
                                                                                               , probs = x
                                                                                               )
                                                                 )
                                           , log_list_length = log(list_length)
                                           , length = paste0("At list length quantile "
                                                             , probs
                                                             , " = "
                                                             , list_length
                                                             )
                             ) else (.)
                           }
                       , by = character()
                       ) %>%
      tidybayes::add_epred_draws(mod
                                 , ndraws = draws
                                 , re_formula = NA
                                 , value = "pred"
                                 )

    sub_title <-  if(has_ll) {

      paste0("List length corrected reporting rate.\nDashed red lines indicate years for comparison (see text).")

    } else {

      paste0(mod_type
             , ".\nDashed red lines indicate years for comparison (see text)."
             )

    }

    sub_title_line <- paste0(
      sub_title
      , if(has_ll) {

        paste0("\nLines are "
               , draws
               , " draws from posterior distribution.\n"
               , unique(plot_data$length)
               )

      } else {

        paste0("\nLines are "
               , draws
               , " draws from posterior distribution."
               )

      }
    )

    sub_title_ribbon <- paste0(sub_title
                               , "\nMedian (thick line) and 90% credible intervals (shaded)."
                               )

    #-------res plot_line-----------

    p <- ggplot2::ggplot(data = plot_data
                         , ggplot2::aes(x = .data[[time_var]]
                                        , y = .data[[resp_var]]
                                        )
                         ) +
      ggplot2::geom_line(ggplot2::aes(y = .data$pred
                                      , group = .draw
                                      )
                         , alpha = 0.5
                         ) +
      ggplot2::geom_vline(xintercept = tests$year
                          , linetype = 2
                          , colour = "red"
                          ) +
      ggplot2::facet_wrap(as.formula(paste0("~ "
                                            , geo_var
                                            )
                                     )
                          ) +
      ggplot2::theme(axis.text.x = ggplot2::element_text(angle = 90
                                                , vjust = 0.5
                                                , hjust = 1
                                                )
                     ) +
      ggplot2::labs(title = plot_titles
                    , subtitle = sub_title_line
                    )

    if(has_ll) {

      p <- p +
        ggplot2::geom_point(data = df
                             ,ggplot2::aes(x = .data$year
                                           , y = .data[[resp_var]]
                                           , colour = exp(.data$log_list_length)
                                           )
                             , width = 0.1
                             , height = 0.05
                             ) +
        ggplot2::scale_colour_viridis_c() +
        ggplot2::labs(colour = "List length")

    }

    if(!has_ll) {

      p <- p +
        ggplot2::geom_point(data = df
                             , ggplot2::aes(x = .data$year
                                           , y = .data[[resp_var]]
                                           , colour = .data[[binom_denom]]
                                           )
                             , width = 0.1
                             , height = 0.01
                             ) +
        ggplot2::scale_colour_viridis_c()

    }

    res$plot_line <- p


    #------res plot_ribbon-------

    p <- ggplot2::ggplot() +
      ggplot2::geom_ribbon(data = res$res
                           , ggplot2::aes(.data[[time_var]]
                                          , .data$modMean
                                          , ymin = .data$modci90lo
                                          , ymax = .data$modci90up
                                          )
                           , alpha = 0.4
                  ) +
      ggplot2::geom_line(data = res$res
                         , ggplot2::aes(x = .data[[time_var]]
                                        , y = .data$modMean
                                        )
                         , linetype = 1
                         , size = 1.5
                         ) +
      ggplot2::geom_vline(xintercept = tests$year
                          , linetype = 2
                          , colour = "red"
                          ) +
      ggplot2::facet_wrap(as.formula(paste0("~ "
                                            , geo_var
                                            )
                                     )
                          ) +
      ggplot2::labs(title = plot_titles
                    , subtitle = sub_title_ribbon
                    )

    if(has_ll) {

      p <- p +
        ggplot2::geom_point(data = df
                  , ggplot2::aes(.data[[time_var]]
                                 , .data[[resp_var]]
                                 , colour = exp(.data$log_list_length)
                                 )
                  , width = 0.1
                  , height = 0.05
                  ) +
      ggplot2::scale_colour_viridis_c() +
      ggplot2::labs(colour = "List length")

    }

    if(!has_ll) {

      p <- p +
        ggplot2::geom_point(data = df
                             , ggplot2::aes(.data[[time_var]]
                                            , .data[[resp_var]]
                                            , colour = .data[[binom_denom]]
                                            )
                             , width = 0.1
                             , height = 0.05
                             ) +
        ggplot2::scale_colour_viridis_c() +
        ggplot2::labs(colour = binom_denom)

    }

    res$plot_ribbon <- p


    #------year difference df-----------

    filt_preds <- df %>%
      dplyr::distinct(dplyr::across(any_of(geo_var))
                      , dplyr::across(any_of(time_var))
                      ) %>%
      dplyr::group_by(dplyr::across(any_of(geo_var))) %>%
      dplyr::filter(!!rlang::ensym(time_var) == min(!!rlang::ensym(time_var)) |
                      !!rlang::ensym(time_var) == max(!!rlang::ensym(time_var))
                    ) %>%
      dplyr::mutate(minmax = dplyr::case_when(!!rlang::ensym(time_var) == min(!!rlang::ensym(time_var)) ~ "min"
                                              , !!rlang::ensym(time_var) == max(!!rlang::ensym(time_var)) ~ "max"
                                              , TRUE ~ "neither"
                                              )
                    ) %>%
      dplyr::ungroup() %>%
      tidyr::pivot_wider(values_from = !!rlang::ensym(time_var)
                         , names_from = "minmax"
                         ) %>%
      na.omit() %>%
      dplyr::filter(max >= recent
                    , min <= reference
                    )


    if(any(!limit_preds, nrow(filt_preds) > 0)) {

      res$year_diff_df <- df %>%
        dplyr::distinct(dplyr::across(tidyselect::any_of(context[!context %in% time_var]))) %>%
        dplyr::full_join(tests
                         , by = character()
                         ) %>%
        dplyr::mutate(list_length = if(has_ll) median(exp(df$log_list_length)) else NULL
                      , log_list_length = if(has_ll) log(list_length) else NULL
                      , success = 0
                      , binom_denom := 100
                      , nCheck = nrow(tibble::as_tibble(mod))
                      , mod_type = mod_type
                      , taxa = taxa
                      , common = common
                      ) %>%
        {if(limit_preds) (.) %>% dplyr::inner_join(filt_preds) else (.)} %>%
        tidybayes::add_epred_draws(mod
                                   , ndraws = draws
                                   , re_formula = NA
                                   , value = "pred"
                                   ) %>%
        dplyr::ungroup() %>%
        dplyr::select(tidyselect::any_of(context)
                      , .data$type
                      , .data$pred
                      , .data$.draw
                      ) %>%
        tidyr::pivot_wider(names_from = "type"
                           , values_from = c(tidyselect::any_of(time_var)
                                             , "pred"
                                             )
                           ) %>%
        #setNames(gsub("\\d{4}", "", names(.))) %>%
        dplyr::mutate(diff = as.numeric(pred_recent - pred_reference)) %>%
        dplyr::filter(!is.na(diff))


        #-------year difference res---------

        res$year_diff_res <- res$year_diff_df %>%
          dplyr::group_by(dplyr::across(tidyselect::any_of(context))) %>%
          dplyr::summarise(nCheck = dplyr::n()
                           , lower = sum(diff < 0) / nCheck
                           , higher = sum(diff > 0) / nCheck
                           , meanDiff = mean(diff)
                           , medianDiff = median(diff)
                           , cilo = stats::quantile(diff, probs = 0.05)
                           , ciup = stats::quantile(diff, probs = 0.95)
                           , reference = unique({{ reference_year }})
                           , recent = unique({{ recent_year }})
                           ) %>%
          dplyr::ungroup() %>%
          envFunc::add_likelihood(col = "lower") %>%
          tidyr::unnest(cols = c(likelihood)) %>%
          dplyr::mutate(text = paste0(tolower(likelihood)
                                      , " to be lower in "
                                      , {{ geo_var }}
                                      , " ("
                                      , 100*round(lower,2)
                                      , "% chance)"
                                      )
                        , text = gsub("in Kangaroo Island","on Kangaroo Island",text)
                        )

        #------year difference plot--------

        plot_data <- res$year_diff_df %>%
          dplyr::group_by(dplyr::across(tidyselect::any_of(geo_var))) %>%
          dplyr::mutate(lower = sum(diff < 0) / dplyr::n()) %>%
          dplyr::ungroup() %>%
          dplyr::mutate(likelihood = purrr::map(lower
                                         , ~cut(.
                                                , breaks = c(0,envFunc::lulikelihood$maxVal)
                                                , labels = envFunc::lulikelihood$likelihood
                                                , include.lowest = TRUE
                                                )
                                         )
                        ) %>%
          tidyr::unnest(cols = c(likelihood)) %>%
          dplyr::mutate(likelihood = forcats::fct_expand(likelihood
                                                         ,levels(envFunc::lulikelihood$likelihood)
                                                         )
                        )

        res$year_diff_plot <- ggplot2::ggplot(data = plot_data
                                              , ggplot2::aes(.data$diff
                                                             , .data[[geo_var]]
                                                             , fill = .data$likelihood
                                                             )
                                              ) +
          ggridges::geom_density_ridges() +
          ggplot2::geom_vline(xintercept = 0
                     , linetype = 2
                     , colour = "red"
                     ) +
          ggplot2::scale_fill_viridis_d(drop = FALSE) +
          ggplot2::labs(title = plot_titles
                        , subtitle = paste0("Difference in "
                                            , recent
                                            , " "
                                            , tolower(mod_type)
                                            , " compared to "
                                            , reference
                                            )
                        , x = "Difference"
                        #, y = "IBRA Subregion"
                        , fill = "Likelihood of decrease"
                        , caption = paste0("Red dotted line indicates no change between "
                                           , reference
                                           , " and "
                                           , recent
                                           )
                        )

    }

    stuff <- ls() %>% grep("res", ., value = TRUE, invert = TRUE)

    rm(list = stuff)

    gc()

    return(res)

  }
Acanthiza/envTrend documentation built on Nov. 16, 2023, 4:24 p.m.