# R/f_plot.R

#'@import nycflights13
#'@import ggplot2
#'@import RColorBrewer
#'@importFrom stringr str_extract_all

#' @title generate a most distinctive color scale
#' @description Strings all qualitative RColorBrewer palettes together (each
#'   at its maximum number of colours, 74 colours in total with the palettes
#'   shipped by RColorBrewer) and optionally removes colour families.
#'   For RGB colors see \href{https://www.rapidtables.com/web/color/index.html}{rapidtables}.
#' @param greys boolean, include grey colors (R == G == B), Default: TRUE
#' @param reds boolean, include red colors (R > 200, G < 50, B < 50), Default: TRUE
#' @param blues boolean, include blue colors (R < G < B), Default: TRUE
#' @param greens boolean, include green colors (G > R and G > B), Default: TRUE
#' @param faint boolean, include faint colors (R + G + B >= 600), Default: TRUE
#' @param only_unique boolean, do not allow color repetitions, Default: FALSE
#' @return vector with HEX colours
#' @rdname f_plot_col_vector74
#' @export
#' @import RColorBrewer
#'
f_plot_col_vector74 = function( greys = TRUE
                                , reds = TRUE
                                , blues = TRUE
                                , greens = TRUE
                                , faint = TRUE
                                , only_unique = FALSE ){

  # concatenate every qualitative palette, each at its maximum size
  pal_info      = RColorBrewer::brewer.pal.info
  qual_col_pals = pal_info[pal_info$category == 'qual', ]
  col_vector    = unlist( mapply( RColorBrewer::brewer.pal
                                  , qual_col_pals$maxcolors
                                  , rownames(qual_col_pals) ) )

  # RGB channels (0-255), col2rgb() returns one column per colour
  channels = col2rgb(col_vector)
  red      = channels['red', ]
  green    = channels['green', ]
  blue     = channels['blue', ]

  keep = rep(TRUE, length(col_vector))

  # greys have three identical channels: R == G == B
  if( ! greys )  keep = keep & ! ( red == green & green == blue )

  if( ! reds )   keep = keep & ! ( blue < 50 & green < 50 & red > 200 )

  if( ! greens ) keep = keep & ! ( green > red & green > blue )

  # NOTE(review): unlike the greens rule this only removes colours ordered
  # red < green < blue; blue-dominant colours with green < red are kept
  if( ! blues )  keep = keep & ! ( blue > green & green > red )

  # faint (light) colours have a high channel sum
  if( ! faint )  keep = keep & ( red + green + blue ) < 600

  hex = unname( col_vector[keep] )

  if( only_unique ) hex = unique(hex)

  return(hex)
}

#' @title color code all variables in a data_ls list.
#' @description Color coding is stable: the same data_ls list gets the same
#'   coding with every function call. Assigns the same colors to the boxcox
#'   transformed and untransformed variant of a variable. The palette is
#'   recycled if there are more variables than colors.
#'@param data_ls data_ls object generated by f_clean_data(), or a named list
#'   list( all_variables = < vector with column names >, boxcox_names = <
#'   optional vector with names of boxcox transformed columns >)
#' @param col_vector character vector denoting Hexcode colors, Default: f_plot_col_vector74()
#' @return tibble
#' \itemize{
#'   \item{variables}{variable name}
#'   \item{color}{HEX code color}
#'   }
#'
#' @examples
#' f_clean_data(mtcars) %>%
#'   f_boxcox() %>%
#'   f_plot_color_code_variables() %>%
#'   print()
#' @seealso \code{\link[stringr]{str_replace_all}}
#' @rdname f_plot_color_code_variables
#' @export
#' @importFrom stringr str_replace_all

f_plot_color_code_variables = function(data_ls
                                       , col_vector = f_plot_col_vector74() ){

  variables = data_ls$all_variables

  # recycle the palette so that every variable receives a colour;
  # rep(length.out = ) also copes with an empty variable list, where
  # col_vector[1:0] would have produced a spurious colour
  tib = tibble( variables = variables
                , color   = rep( col_vector, length.out = length(variables) )
                )

  # code boxcox variables with the same colours as their untransformed counterparts
  if( ! is.null(data_ls$boxcox_names) ){

    names_boxcox = data_ls$boxcox_names %>%
      stringr::str_replace_all('_boxcox', '')

    tib_boxcox = tib %>%
      filter( variables %in% names_boxcox ) %>%
      mutate( variables = paste0( variables, '_boxcox') )

    tib = tib %>%
      bind_rows(tib_boxcox)
  }

  return(tib)
}

#' @title adjust length of color vector, by repeating colors
#' @description Recycles col_vector until it has exactly n elements. If n is
#'   smaller than the length of col_vector the vector is truncated.
#' @param n length, Default: 74
#' @param col_vector vector containing colors, Default: f_plot_col_vector74()
#' @return vector containing colors of length n (empty for n = 0)
#' @examples
#'
#' length( f_plot_adjust_col_vector_length(100) )
#'
#' @rdname f_plot_adjust_col_vector_length
#' @export
f_plot_adjust_col_vector_length = function( n = 74, col_vector = f_plot_col_vector74() ){

  # rep(length.out = ) recycles and truncates in one step and keeps names;
  # unlike rep(...)[1:n] it returns an empty vector for n = 0 and does not
  # divide by zero for an empty col_vector
  return( rep( col_vector, length.out = n ) )

}



#'@title Plot Histograms
#'@description Function plots smart histograms for variables in a data_ls list
#'  generated by f_clean_data(). It supports three types of histograms: bar
#'  histograms, density histograms and violin plots. We can further specify a
#'  categorical variable to group on. The function falls back to a sensible
#'  standard output if keyword arguments are not applicable for the variable
#'  type. Thus we can easily pipe through long lists of variables and
#'  generate histograms for all variables in the input (see examples).
#'@param variable character vector naming the variable to be plotted, must be
#'  listed in data_ls$numericals or data_ls$categoricals
#'@param data_ls data_ls object generated by f_clean_data(), or a named list
#'  list( data = <dataframe>, numericals = < vector with column names of
#'  numerical columns>, categoricals = < vector with column names of
#'  categorical columns>)
#'@param group character vector naming the column to be used as grouping
#'  variable, 'None' or NULL for no grouping, Default: 'None'
#'@param graph_type one of c("violin", "bar", "line", "density"), "density" is
#'  an alias for "line". Violin plots require a grouping variable and fall back
#'  to "bar" without one, Default: 'violin'
#'@param y_axis one of c("count", "density"), Default: 'count'
#'@param auto_range boolean, Default: TRUE
#'@param n_breaks integer, number of histogram bins, Default: 30
#'@param x_min double, requires auto_range == FALSE, Default: 0
#'@param x_max double, requires auto_range == FALSE,  Default: 100
#'@param title character vector plot title, Default: ''
#'@param rug boolean, add a rug to histograms of numerical variables, Default: TRUE
#'@param add character vector one_of( c('mean','median','none') ), This feature
#'  is currently disabled because it does not seem to be supported by ggpubr
#'  under R 3.5, Default: 'mean'
#'@param col_vector vector with RGB colors, Default:
#'  f_plot_adjust_col_vector_length(100, RColorBrewer::brewer.pal(name =
#'  "Dark2", n = 8))
#'@param p_val boolean, annotate significance (pairwise Wilcoxon tests for
#'  violin plots, Chi-Square test for grouped categoricals), Default: TRUE
#'@param y_max double, requires auto_range == FALSE, ignored for violin plots,
#'  Default: 100
#'@param ... additional arguments passed to labs()
#'@return plot object
#' @examples
#' \dontrun{
#'
#' #plot single variable
#' data_ls = f_clean_data(mtcars)
#' f_plot_hist('disp', data_ls)
#' f_plot_hist('disp', data_ls, add = 'median')
#' f_plot_hist('disp', data_ls, add = 'none')
#' f_plot_hist('disp', data_ls, y_axis = 'density')
#' f_plot_hist('cyl', data_ls , group = 'gear' )
#' f_plot_hist('cyl', data_ls , group = 'gear', y_axis = 'density' )
#' f_plot_hist('cyl', data_ls, y_axis = 'density' )
#' f_plot_hist('cyl', data_ls, y_axis = 'count' )
#' f_plot_hist('disp', data_ls, graph_type = 'line', group = 'cyl')
#' f_plot_hist('disp', data_ls, graph_type = 'bar', group = 'cyl')
#' f_plot_hist('disp', data_ls, graph_type = 'violin', group = 'cyl'
#'              , caption ='caption', title = 'title', subtitle = 'subtitle')
#'
#'#plot all variables
#'vars = data_ls$all_variables[ data_ls$all_variables != 'cyl' ] %>%
#'  map( f_plot_hist, data_ls, group = 'cyl')
#'vars
#'}
#'@rdname f_plot_hist
#'@export
#'@importFrom RColorBrewer brewer.pal
#'@importFrom ggpubr ggdensity gghistogram ggviolin stat_compare_means
#'@importFrom stringr str_extract_all str_c
#'
#'
f_plot_hist = function(variable
                       , data_ls
                       , group = 'None'
                       , graph_type = 'violin'
                       , y_axis = 'count'
                       , auto_range = TRUE
                       , n_breaks = 30
                       , rug = TRUE
                       , x_min = 0
                       , x_max = 100
                       , y_max = 100
                       , title = ''
                       , col_vector = f_plot_adjust_col_vector_length( 100, RColorBrewer::brewer.pal(name = 'Dark2', n = 8) )
                       , p_val = TRUE
                       , add = 'mean'
                       , ...
                       ){

  data         = data_ls$data
  categoricals = data_ls$categoricals
  numericals   = data_ls$numericals

  is_numerical   = variable %in% numericals
  is_categorical = variable %in% categoricals

  if( ! is_numerical && ! is_categorical ){
    stop( paste0( "variable '", variable, "' is neither listed in "
                  , "data_ls$numericals nor in data_ls$categoricals" ) )
  }

  # y-axis: line plots are always drawn on the density scale
  if( y_axis == 'density' || graph_type == 'line' ){
    y_axis = '..density..'
  } else {
    y_axis = '..count..'
  }

  # 'None' and NULL both mean: no grouping
  if( is.null(group) || identical(group, 'None') ) group = NULL
  has_group = ! is.null(group)

  if( graph_type == 'density' ) graph_type = 'line'

  # violin plots require a grouping variable, default to bar histogram
  if( graph_type == 'violin' && ! has_group ) graph_type = 'bar'

  # numericals -------------------------------------------------------------------------------

  if( is_numerical ){

    #geom_freqpoly
    if( has_group && graph_type == 'line' ){

      p = ggpubr::ggdensity( data = data
                             , x = variable
                             , y = y_axis
                             , color = group
                             , fill = group
                             , palette = col_vector
                             , rug = rug
                             #, add = add
      )
    }

    #geom_histo
    if( has_group && graph_type == 'bar' ){

      p = ggpubr::gghistogram( data = data
                               , x = variable
                               , y = y_axis
                               , color = group
                               , fill = group
                               , palette = col_vector
                               , rug = rug
                               #, add = add
                               , bins = n_breaks
      )
    }

    #geom_violin
    if( has_group && graph_type == 'violin' ){

      p = ggplot( data, aes_string( x = group, y = variable ) ) +
        geom_violin( aes_string( fill = group ), alpha = 0.5, trim = FALSE ) +
        geom_boxplot( width = 0.25 ) +
        scale_fill_manual( values = col_vector )

      if( p_val ){
        p = p +
          ggpubr::stat_compare_means( comparisons = f_plot_generate_comparison_pairs(data, variable, group)
                                      , label  = "p.signif"
                                      , method = "wilcox.test" )
      }
    }

    # default if no grouping
    if( ! has_group && y_axis == '..count..' ){

      p = ggpubr::gghistogram( data = data
                               , x = variable
                               , y = y_axis
                               , rug = rug
                               #, add = add
                               , palette = col_vector
                               , bins = n_breaks
                               ) +
        # the second histogram is necessary for having better coloring options
        geom_histogram( fill = col_vector[1]
                        , color = 'black'
                        , bins = n_breaks )
    }

    if( ! has_group && y_axis == '..density..' ){

      p = ggpubr::ggdensity( data = data
                             , x = variable
                             , y = y_axis
                             , rug = rug
                             #, add = add
                             , palette = col_vector
                             ) +
        geom_histogram( fill = col_vector[1]
                        , color = 'black'
                        , alpha = 0.5
                        , bins = n_breaks )
    }

    # manual axis ranges
    if( ! auto_range ){

      if( graph_type == 'violin' ){
        # the variable is on the y axis of a violin plot, the x axis holds
        # the discrete groups and must not get continuous limits
        p = p +
          ylim( c( as.numeric(x_min), as.numeric(x_max) ) )
      } else {
        p = p +
          xlim( c( as.numeric(x_min), as.numeric(x_max) ) ) +
          ylim( c(0, y_max) )
      }
    }
  }

  # categoricals -----------------------------------------------------------------------------

  #geom_bar
  if( is_categorical ){

    # counts per level; with grouping, missing level combinations are
    # completed and their NA counts/percentages set to 0
    if( has_group ){
      data_gr = data %>%
        group_by( !! as.name(group), !! as.name(variable) ) %>%
        count() %>%
        ungroup() %>%
        complete( !! as.name(group), !! as.name(variable) ) %>%
        group_by( !! as.name(group) ) %>%
        mutate( Percent = n / sum(n, na.rm = TRUE )
                , Count = ifelse( is.na(n), 0, n )
                , Percent = ifelse( is.na(Percent), 0, Percent )
        )
    } else {
      data_gr = data %>%
        group_by( !! as.name(variable) ) %>%
        count() %>%
        ungroup() %>%
        mutate( Percent = n / sum(n, na.rm = TRUE )
                , Count = n )
    }

    y = if( y_axis == '..density..' ) 'Percent' else 'Count'

    if( has_group ){
      p = ggplot( data_gr ) +
        geom_col( aes_string( x = variable, y = y, fill = group )
                  , position = 'dodge' )
    } else {
      p = ggplot( data_gr ) +
        geom_col( aes_string( x = variable, y = y, fill = variable ) )
    }

    p = p +
      scale_fill_manual( values = col_vector )

    if( p_val && has_group ){

      suppressWarnings({
        chi = chisq.test( data[[group]], data[[variable]] )
      })

      # p.value can be NA for degenerate tables, only annotate real hits
      if( isTRUE( chi$p.value <= 0.05 ) ){
        p = p +
          labs( subtitle = paste('Chi-Square Test', f_stat_stars(chi$p.value) ) )
      }
    }
  }

  # set theme -------------------------------------------------------------------------------

  if( graph_type == 'violin' && has_group && ! is_categorical ){
    y_axis_str = variable
  } else if( y_axis == '..density..' ){
    y_axis_str = if( is_categorical ) 'Percent' else 'Density'
  } else {
    y_axis_str = 'Count'
  }

  # theme_gray() is a complete theme and resets every earlier theme() call,
  # therefore all theme tweaks are added after it
  p = p +
    labs( title = title, y = y_axis_str, ...) +
    theme_gray()

  if( is_categorical ){
    p = p +
      theme( axis.text.x = element_text( angle = 90 ) )
  }

  if( has_group && ( is_categorical || graph_type != 'violin' ) ){
    p = p +
      theme( legend.position = 'right')
  } else {
    p = p +
      theme( legend.position = 'None')
  }

  return(p)

}


#'@title plot variable distribution over time as reduced overlapping boxplots
#'@description It is difficult to compare two time series when you have more
#'  than one observation per timepoint without reducing all observations to a
#'  single statistic such as the mean. Instead, this visualisation plots for
#'  every timepoint the median and the lower and upper hinges returned by
#'  \code{boxplot.stats()} (roughly the 25th and 75th percentiles) and draws a
#'  continuous line between the medians of the timepoints.
#'@param variable character vector naming the variable to be plotted, must be
#'  listed in data_ls$numericals
#'@param data_ls data_ls object generated by f_clean_data(), or a named list
#'  list( data = <dataframe>, numericals = < vector with column names of
#'  numerical columns>)
#'@param group character vector naming the column to be used as grouping
#'  variable, Default: NULL
#'@param time_variable character vector naming the time variable to be plotted
#'@param time_variable_as_factor If TRUE will convert time_variable to a factor,
#'  this will equalize the distance between timepoints on the plots and drops
#'  the connective line between timepoints. Cannot be combined with
#'  normalize = TRUE, Default: F
#'@param normalize If TRUE the y variable will be divided by the time variable
#'  (rows where the time variable is 0 are dropped), useful if the y variable
#'  represents a cumulated sum, Default: F
#'@param time_unit character vector used as x-axis label, Default: 'day'
#'@return plot
#' @examples
#' \dontrun{
#'
#' set.seed(1)
#' data       = dplyr::sample_n( nycflights13::flights, 1000 )
#' data$is_ua = ifelse( data$carrier == 'UA', 'UA', 'other')
#' data$date  = data$year * 10000 + data$month * 100 + data$day
#' data$date  = lubridate::as_date( data$date )
#' data_ls    = f_clean_data( data, replace_neg_values_with_zero = F)
#' f_plot_time( 'arr_delay', 'month', data_ls, group = 'is_ua', time_unit = 'month', time_variable_as_factor = T)
#'
#' #without grouping
#' f_plot_time( 'arr_delay', 'month', data_ls, time_unit = 'month', time_variable_as_factor = F)
#'
#' }
#'
#'@rdname f_plot_time
#'@export
f_plot_time = function(variable
                       , time_variable
                       , data_ls
                       , time_variable_as_factor = F
                       , group = NULL
                       , normalize = F
                       , time_unit = 'day'){

  data       = data_ls$data
  numericals = data_ls$numericals

  # validate arguments before touching the data
  if( ! variable %in% numericals ){
    stop('variable not in numericals')
  }

  if( normalize && time_variable_as_factor ){
    # dividing by a factor yields only NA values
    stop('normalize = TRUE requires a numeric time variable, set time_variable_as_factor = FALSE')
  }

  time_variable_sym = as.name(time_variable)
  variable_sym      = as.name(variable)

  # make sure the time variable is in the requested format
  if( time_variable_as_factor ){

    data = data %>%
      mutate( !! time_variable_sym := as.factor( !! time_variable_sym ) )

  } else {

    # via character so that factors are converted by level label, not by code
    # NOTE(review): Date columns become NA on this route — confirm intended
    data = data %>%
      mutate( !! time_variable_sym := as.numeric( as.character( !! time_variable_sym ) ) )
  }

  # normalize data if applicable
  if( normalize ){

    data = data %>%
      filter( (!! time_variable_sym) != 0 ) %>%
      mutate( !! variable_sym := (!! variable_sym) / (!! time_variable_sym) )

    y_title = paste('Median of', variable, 'per', time_unit)

  } else {
    y_title = paste('Median of', variable)
  }

  # five boxplot statistics (whisker ends, hinges, median) per timepoint and group
  boxplot_sum = data %>%
    select( one_of( c(time_variable, variable, group ) ) ) %>%
    group_by_at( vars( one_of(time_variable, group ) ) ) %>%
    nest() %>%
    mutate( data            = map( data, variable )
            , boxplot       = map( data, boxplot.stats)
            , boxplot_stats = map( boxplot, 'stats')
            , box_min       = map_dbl( boxplot_stats, function(x) x[1] )
            , box_min_box   = map_dbl( boxplot_stats, function(x) x[2] )
            , box_median    = map_dbl( boxplot_stats, function(x) x[3] )
            , box_max_box   = map_dbl( boxplot_stats, function(x) x[4] )
            , box_max       = map_dbl( boxplot_stats, function(x) x[5] )
    ) %>%
    select( - boxplot_stats, - boxplot )

  p = ggplot( boxplot_sum, aes_string( x       = time_variable
                                       , y     = 'box_median'
                                       , color = group ) ) +
    geom_line( size = 1 ) +
    geom_crossbar( aes_string( ymin   = 'box_min_box'
                               , ymax = 'box_max_box'
                               , fill = group )
                   , stat     = 'identity'
                   , alpha    = 0.1
                   , position = 'identity' ) +
    geom_errorbar( aes_string( ymin   = 'box_median'
                               , ymax = 'box_median' )
                   , size = 1 ) +
    geom_point( size = 2 ) +
    labs( y = y_title
          , x = time_unit
          , subtitle = 'Boxes denote upper and lower 25% percentile')

  return(p)

}


#' @title generate a separate html file from a list of various objects
#' @description lists of graphical objects like html(taglists), plots, tabplots,
#'   grids can be converted to html files
#' @param obj_list object(s) to render, e.g. an htmltools::tagList or a list of plots, see type
#' @param type one of c('taglist','plots','tabplots','grids' ,
#'   'model_performance') some templates take additional arguments via the ...
#'   argument \describe{
#'   \item{taglist}{taglist created with htmltools::tagList,
#'   a good container for html widgets}
#'   \item{plots}{a list with ggplot objects,
#'   takes additional arguments: \emph{fig.height: Default 5, fig.width: Default
#'   7} }
#'   \item{tabplots}{ a list of objects created with tabplot::tableplot,
#'   takes additional arguments: \emph{fig.height: Default 5, fig.width: Default
#'   7, titles: list of titles must be same length as obj_list} }
#'   \item{grids}{a
#'   list of grids created with gridExtra::arrangeGrob, takes additional
#'   argument: \emph{height: Default 30} }
#'   \item{model_performance}{takes a
#'   taglist created with f_predict_plot_model_performance_regression, takes the
#'   additional arguments: \emph{ alluvial, plot object created with
#'   f_predict_plot_regression_alluvials, dist, list of two plots created with
#'   f_predict_plot_regression_distribution, render_points_as_png: Default:
#'   TRUE, takes screenshots of point plots, otherwise plotly will load
#'   all points into memory which is not compatible with large data sets} }
#'   }
#' @param output_file file_name of the html file, without .html suffix
#' @param title character vector of html document title, Default: 'Plots'
#' @param quiet boolean, suppress markdown console print output, Default: FALSE
#' @param ... additional arguments passed to rmarkdown::render argument params
#' @return path to the generated html file in the working directory
#' @examples
#'
#' # type = taglist---------------------------------------------------------------
#'
#' taglist = f_clean_data(mtcars) %>%
#'   f_boxcox() %>%
#'   f_pca() %>%
#'   f_pca_plot_components()
#'
#' f_plot_obj_2_html(taglist, type = "taglist", output_file =  'test_me', title = 'Plots')
#'
#' file.remove('test_me.html')
#'
#' #type = tabplot-----------------------------------------------------------------
#'
#' form = as.formula('disp~cyl+mpg+hp')
#' pipelearner::pipelearner(mtcars) %>%
#'   pipelearner::learn_models( rpart::rpart, form ) %>%
#'   pipelearner::learn_models( randomForest::randomForest, form ) %>%
#'   pipelearner::learn_models( e1071::svm, form ) %>%
#'   pipelearner::learn() %>%
#'   dplyr::mutate( imp = map2(fit, train, f_model_importance)
#'                  , tabplot = pmap( list( data = train
#'                                          , ranked_variables = imp
#'                                          , response_var = target
#'                                          , title = model
#'                  )
#'                  , f_model_importance_plot_tableplot
#'                  , limit = 5
#'                  )
#'   )  %>%
#'   .$tabplot %>%
#'   f_plot_obj_2_html( type = "tabplots", output_file =  'test_me', title = 'Plots')
#'
#' file.remove('test_me.html')
#'
#' #type = plots --------------------------------------------------------------------
#'
#' data_ls = f_clean_data(mtcars)
#' form = as.formula('disp~cyl+mpg+hp')
#' variable_color_code = f_plot_color_code_variables(data_ls)
#'
#' pipelearner::pipelearner(data_ls$data) %>%
#'  pipelearner::learn_models( rpart::rpart, form ) %>%
#'  pipelearner::learn_models( randomForest::randomForest, form ) %>%
#'  pipelearner::learn_models( e1071::svm, form ) %>%
#'  pipelearner::learn() %>%
#'  dplyr::mutate( imp = map2(fit, train, f_model_importance)
#'                 , plots = pmap( list( m = fit
#'                                         , ranked_variables = imp
#'                                         , title = model
#'                                         )
#'                                    , f_model_plot_variable_dependency_regression
#'                                    , formula = form
#'                                    , data_ls = data_ls
#'                                    , variable_color_code = variable_color_code
#'                                   )
#'  )  %>%
#'  .$plots %>%
#'  f_plot_obj_2_html( type = "plots"
#'                     , output_file =  'test_me'
#'                     , title = 'Plots'
#'                     , fig.width = 30
#'                     , fig.height = 21)
#'
#' file.remove('test_me.html')
#'
#' #type = grids -------------------------------------------------------------------
#'
#' data_ls = f_clean_data(mtcars)
#'
#' form = as.formula('disp~cyl+mpg+hp+am+gear+drat+wt+vs+carb')
#'
#' variable_color_code = f_plot_color_code_variables(data_ls)
#'
#' grids = pipelearner::pipelearner(data_ls$data) %>%
#'   pipelearner::learn_models( rpart::rpart, form ) %>%
#'   pipelearner::learn_models( randomForest::randomForest, form ) %>%
#'   pipelearner::learn_models( e1071::svm, form ) %>%
#'   pipelearner::learn() %>%
#'   dplyr::mutate( imp = map2(fit, train, f_model_importance)
#'                  , range_var = map_chr(imp, function(x) head(x,1)$row_names )
#'                  , grid = pmap( list( m = fit
#'                                       , title = model
#'                                       , variables = imp
#'                                       , range_variable = range_var
#'                                       , data = test
#'                  )
#'                  , f_model_plot_var_dep_over_spec_var_range
#'                  , formula = form
#'                  , data_ls = data_ls
#'                  , variable_color_code = variable_color_code
#'                  , log_y = F
#'                  , limit = 12
#'                  )
#'   )  %>%
#'   .$grid
#'
#' f_plot_obj_2_html( grids
#'                    , type = "grids"
#'                    , output_file = 'test_me'
#'                    , title = 'Grids'
#'                    , height = 30 )
#'
#' file.remove('test_me.html')
#'
#' #type = model_performance -------------------------------------------------------
#'
#' form = displacement ~ cylinders + mpg
#'
#' df = ISLR::Auto %>%
#'  mutate( name = paste( name, row_number() ) ) %>%
#'  pipelearner::pipelearner() %>%
#'  pipelearner::learn_models( rpart::rpart, form ) %>%
#'  pipelearner::learn_models( randomForest::randomForest, form ) %>%
#'  pipelearner::learn_models( e1071::svm, form ) %>%
#'  pipelearner::learn() %>%
#'  f_predict_pl_regression( 'name' ) %>%
#'  unnest(preds) %>%
#'  mutate( bins = cut(target1, breaks = 3 , dig.lab = 4)
#'          , title = model )
#'
#' dist = f_predict_plot_regression_distribution(df
#'                                              , col_title = 'title'
#'                                              , col_pred = 'pred'
#'                                              , col_obs = 'target1')
#'
#'
#' alluvial = f_predict_plot_regression_alluvials(df
#'                                               , col_id = 'name'
#'                                               , col_title = 'title'
#'                                               , col_pred = 'pred'
#'                                               , col_obs = 'target1')
#'
#'
#' taglist = f_predict_plot_model_performance_regression(df)
#'
#' f_plot_obj_2_html( taglist
#'                   , type = 'model_performance'
#'                   , output_file = 'test_me'
#'                   , dist = dist
#'                   , alluvial = alluvial
#'                   , render_points_as_png = TRUE
#'                  )
#'
#'
#' file.remove('test_me.html')
#'
#' @rdname f_plot_obj_2_html
#' @export
#' @importFrom readr read_file write_file
#' @importFrom rmarkdown render
#' @importFrom stringr str_replace
#' @import tabplot
#'
f_plot_obj_2_html = function(obj_list
                             , type
                             , output_file
                             , title = 'Plots'
                             , quiet = FALSE
                             , ...){

  # every supported object type has its own Rmd template in inst/templates
  file_name_template = switch(type
                             , taglist  = 'taglist_2_html_template.Rmd'
                             , plots    = 'plots_2_html_template.Rmd'
                             , tabplots = 'tabplots_2_html_template.Rmd'
                             , grids    = 'grids_2_html_template.Rmd'
                             , model_performance = 'model_perf_2_html_template.Rmd'
                             )

  if( is.null(file_name_template) ){
    stop( paste( 'No Rmd template found for type:', type) )
  }

  path_template = file.path( system.file(package = 'oetteR')
                             , 'templates'
                             , file_name_template)

  # the first occurrence of 'template' in the Rmd is replaced by the title
  txt = readr::read_file(path_template) %>%
    stringr::str_replace('template', title)

  # render from a temporary copy in the working directory and make sure the
  # copy is removed again, even if rendering fails
  readr::write_file( txt, file_name_template)
  on.exit( file.remove(file_name_template), add = TRUE )

  rmarkdown::render( file_name_template
                     , output_file = paste0( output_file, '.html')
                     , params      = list(obj_list = obj_list, ... )
                     , quiet       = quiet
                     )

  return( file.path( getwd(), paste0( output_file, '.html')) )

}


#' @title plot prettier dot plot
#' @description Color is continuously scaled based on the first principal
#'   component (PC1, computed with f_pca() on the selected columns) and the
#'   alpha value is the inverse of the local 2d point density, so dense
#'   regions are drawn more transparently. If PC1 or the density cannot be
#'   computed, a warning is issued and all points share one color and alpha.
#' @param df dataframe containing x,y pairs
#' @param col_x character vector denoting x axis values
#' @param col_y character vector denoting y axis values
#' @param col_facet character vector denoting facetting column, Default: NULL
#' @param size size of points, Default: 4
#' @param title character vector, Default: NULL
#' @param x_title character vector, Default: col_x
#' @param y_title character vector, Default: col_y
#' @param ... arguments passed to facet_wrap()
#' @return plot
#' @details Code adapted from \url{https://drsimonj.svbtle.com/pretty-scatter-plots-with-ggplot2}
#' @examples
#' df = ggplot2::diamonds %>%
#'   sample_n(2500)
#' col_x = 'carat'
#' col_y = 'price'
#' col_facet = 'cut'
#'
#' f_plot_pretty_points(df, col_x, col_y, col_facet, title = 'price of diamonds by carat')
#' @seealso \code{\link[fields]{interp.surface}} \code{\link[MASS]{kde2d}}
#' @rdname f_plot_pretty_points
#' @export
#' @importFrom fields interp.surface
#' @importFrom MASS kde2d
#' @importFrom viridis scale_color_viridis
f_plot_pretty_points = function(df
                                , col_x
                                , col_y
                                , col_facet = NULL
                                , size = 4
                                , title = NULL
                                , x_title = col_x
                                , y_title = col_y
                                , ...){

  sym_x     = as.name( col_x )
  sym_y     = as.name( col_y )
  sym_facet = if( is.null(col_facet) ) NULL else as.name( col_facet )

  df = select(df, one_of(col_x, col_y, col_facet) )

  #functions do not support tibble
  df = as.data.frame(df)

  # adds the columns PC1 and density (per facet if col_facet is given)
  add_pc1_and_density = function(df){

    pca_ls = f_pca( f_boxcox( f_clean_data( df ) ), include_ordered_categoricals = F )
    scores = pca_ls$pca$x

    # scores may be a matrix (as returned by prcomp), which does not support
    # [['PC1']], or a data frame
    df$PC1 = if( is.matrix(scores) ) scores[, 'PC1'] else scores[['PC1']]

    if( ! is.null(col_facet) ){
      df = df %>%
        group_by( !! sym_facet )
    }

    df %>%
      mutate( density = list( fields::interp.surface( MASS::kde2d( !! sym_x, !! sym_y )
                                                      , data.frame( !! sym_x, !! sym_y ) ) )
              , row_number = row_number()
              , density = unlist(density)[row_number]
              , density = ifelse(density < 1e-25, 1e-25, density) ## ggplot cannot handle extreme values
      )
  }

  pc1_density = safely( add_pc1_and_density )(df)

  if( is.null(pc1_density$result) ){

    msg = if( is.null(pc1_density$error) ) '' else conditionMessage(pc1_density$error)
    warning( 'f_plot_pretty_points: could not compute PC1 and point density, '
             , 'plotting with uniform color and alpha. ', msg
             , call. = FALSE )

    df = df %>%
      mutate( density = 1
              , PC1 = 1 )
  } else {
    df = pc1_density$result
  }

  p = ggplot(df, aes_string( col_x, col_y, color = 'PC1') ) +
    geom_point( aes( alpha = 1/density ), size = size ) +
    viridis::scale_color_viridis() +
    theme_minimal() +
    theme( legend.position = 'none') +
    labs( title = title, x = x_title, y = y_title )

  if( ! is.null(col_facet) ){

    form = as.formula( paste0( '~', col_facet) )

    p = p +
      facet_wrap( form, ... )
  }

  return(p)
}

#' @title generates comparison pairs for `ggpubr::stat_compare_means()`
#' @description generates all pairs of levels of col_group and keeps those
#'   whose wilcox.test() p value on col_var is <= thresh. Pairs for which the
#'   test fails are treated as not significant.
#' @param data dataframe
#' @param col_var character vector denoting variable column
#' @param col_group character vector denoting grouping column
#' @param thresh double, Default: 0.05
#' @return list of character vectors of length two, one per significant pair.
#'   Empty list if col_group has fewer than two levels or no pair is
#'   significant.
#' @examples
#'
#' f_plot_generate_comparison_pairs( mtcars, 'disp', 'cyl' )
#'
#' @seealso
#'  \code{\link[stringr]{str_split}}
#'  \code{\link[rlang]{UQ}}
#' @rdname f_plot_generate_comparison_pairs
#' @export
#' @importFrom stringr str_split
#' @importFrom rlang UQ
f_plot_generate_comparison_pairs = function(data, col_var, col_group, thresh = 0.05 ){

  sym_group = as.name(col_group)

  data = data %>%
    as_tibble() %>%
    select( one_of(col_group, col_var) ) %>%
    mutate( !! sym_group := as.factor( !! sym_group ) )

  lvl = levels( data[[col_group]] )

  # fewer than two levels: there is nothing to compare
  if( length(lvl) < 2 ) return( list() )

  # p value of wilcox.test() between the two levels encoded in comb as
  # 'level1,level2'; a failing test counts as not significant (p = 1)
  compare_means = function(comb){

    lvl1 = stringr::str_split(comb, ',')[[1]][1]
    lvl2 = stringr::str_split(comb, ',')[[1]][2]

    x = data %>%
      filter( rlang::UQ(sym_group) == lvl1 ) %>%
      .[[col_var]]

    y = data %>%
      filter( rlang::UQ(sym_group) == lvl2 ) %>%
      .[[col_var]]

    wilcox = safely(wilcox.test)(x, y)

    if( is.null(wilcox$error) ) wilcox$result$p.value else 1
  }

  suppressWarnings({

    # unique unordered level pairs, each stored as sorted 'a,b' string
    compare = expand.grid( a = lvl
                           , b = lvl
                           , stringsAsFactors = F) %>%
      filter( a != b ) %>%
      mutate( comb = map2( a, b, function(x,y) sort( c(x,y) ) )
              , comb = map( comb, paste0, collapse = ',' )
      ) %>%
      unnest( comb ) %>%
      group_by( comb ) %>%
      summarise() %>%
      # map_dbl yields a numeric column, so the threshold comparison is a
      # plain vector comparison instead of a list comparison
      mutate( p_val  = map_dbl( comb, compare_means )
              , comb = stringr::str_split(comb, ',') ) %>%
      filter( p_val <= thresh ) %>%
      .$comb

  })

  return(compare)
}
# erblast/oetteR documentation built on May 27, 2019, 12:11 p.m.