
# Defines functions: fix_fit, .onLoad
#
# Documented in: fix_fit

# Package-level state, populated later by .onLoad():
#   env     - private environment that caches helper Python modules
#             (matplotlib, warnings) and a reference to fix_fit().
#   fastai2 - placeholder binding; .onLoad() replaces it with the
#             delay-loaded Python `fastai` module.
env <- new.env(hash = TRUE)
fastai2 <- NULL

.onLoad <- function(libname, pkgname) {
  # Package load hook. `libname` / `pkgname` are part of the standard hook
  # signature and are not used here.

  # Bind the Python `fastai` module with reticulate's delayed loading, so the
  # import itself is deferred until `fastai2` is first used. The `priority`
  # and `environment` entries are handed to reticulate as-is (presumably to
  # steer which Python environment is bound -- confirm against the
  # reticulate version in use).
  # `<<-` is required: it updates the namespace-level `fastai2` binding
  # instead of creating a local variable.
  fastai2 <<- reticulate::import("fastai", delay_load = list(
    priority = 10,
    environment = "r-fastai"
  ))

  # NOTE(review): machine-specific check. `cran_` is FALSE only when this
  # exact python.exe exists (one particular Windows check machine) and TRUE
  # everywhere else. A hard-coded user path is fragile; an environment
  # variable or option would be a sturdier switch.
  cran_ <- !file.exists(
    "C:/Users/ligges/AppData/Local/r-miniconda/envs/r-reticulate/python.exe"
  )

  # NOTE(review): py_module_available() needs a Python session, so this
  # likely initialises Python at load time even though `fastai` is delayed.
  if (cran_ && reticulate::py_module_available("matplotlib")) {
    env[["matplot"]] <- reticulate::import("matplotlib")
    env[["warnings"]] <- reticulate::import("warnings")
    env[["fix_fit"]] <- fix_fit
    # Install the custom progress writer. Failures are deliberately ignored
    # so that a broken Python setup never prevents the package from loading.
    try(env[["fix_fit"]](), silent = TRUE)
    # env[["bs_find"]] <- bs_finder
  }
}

#' Fix fit
#'
#' Replace the output writer of the Python `fastprogress` module with an R
#' function, so that training progress is printed from R as a pandoc-style
#' table. Unless `disable_graph` is `TRUE`, every printed data row is also
#' appended to `to_df.csv` in `tempdir()` and used to draw loss/metric curves
#' with ggplot2.
#'
#' @param disable_graph to remove dynamic plot, by default is FALSE
#' @return Called for its side effect of replacing
#'   `fastprogress$fastprogress$WRITER_FN`; the return value is not meant to
#'   be used.
#' @export
fix_fit = function(disable_graph = FALSE) {

  fastaip <- reticulate::import('fastprogress')

  # Blank out the progress bar fill text.
  fastaip$progress_bar$fill = ''

  if(!disable_graph) {
    # Writer variant that prints the table AND records/plots the history.
    fastaip$fastprogress$WRITER_FN = function(value, ..., sep=' ', end='\n', flush = FALSE) {
      args = list(
        value, ...)

      #save(args, file = "~/Downloads/args.RData")

      # Tokenise the first argument on single spaces and drop empty tokens.
      text = unlist(strsplit(trimws(args[[1]]),' '))
      text = text[!text=='']
      # drop
      # Lines containing "Epoch" or a |----| rule are ignored entirely.
      text_ = grepl( "Epoch", args[[1]], fixed = TRUE)

      text_2 = grepl('\\|\\-+\\|',args[[1]], fixed = FALSE)

      # NOTE(review): `&` on what are presumably length-1 logicals; `&&` is
      # the idiomatic scalar form.
      if(!text_ & !text_2) {
        # A first token containing 'epoch' marks the header (column-name) line.
        lgl = grepl('epoch', text)

        # temp file
        # CSV that accumulates one row per recorded line for plotting.
        nm = paste(tempdir(),'to_df.csv',sep = '/')

        # save column names // write to temp dir
        if(lgl[1]) {
          # remove old train from cache
          # NOTE(review): no visible code follows the comment above, so a
          # `to_df.csv` left by a previous fit is not removed here; new rows
          # would be appended to it -- confirm.
          # write
          tmm = tempdir()
          tmp_name = paste(tmm,"output.txt",sep = '/')
          # NOTE(review): `fileConn` is not created in this branch (the
          # disable_graph branch below does `fileConn <- file(tmp_name)`
          # first), so this line errors unless that statement was lost.
          # No close() of the connection is visible either.
          writeLines(text, fileConn)

        if(lgl[1]) {
          # Header line: build an empty data frame named after the tokens.
          df <- data.frame(matrix(ncol = length(text), nrow = 0))
          colnames(df) <- text
          # add row for tidy output
          # (a throwaway random row so kable() renders a full table; only
          # the first two output lines, presumably header + rule, are printed)
          df[nrow(df) + 1,] = as.character(round(stats::runif(ncol(df)),3))
          df = knitr::kable(df, format = "pandoc")
          cat(df[1:2], sep="\n")

          # In interactive sessions, close the current graphics device
          # (errors ignored) before a new fit starts plotting.
          if(interactive()) {
            try(dev.off(), TRUE)

          # NOTE(review): the body of set_theme is not visible in this copy
          # (the definition is cut off); restore it from the original file.
          set_theme = function() {
          invisible(try(set_theme(), TRUE))

        } else {

          ## restore from temp
          # Data line: reuse the column names saved from the header line.
          tmm = tempdir()
          # NOTE(review): `tmp_name` is unused here; the path is rebuilt
          # inline in readLines() below.
          tmp_name = paste(tmm,"output.txt",sep = '/')
          text2 = readLines(paste(tmm,"output.txt",sep = '/'))
          df <- data.frame(matrix(ncol = length(text2), nrow = 0))
          colnames(df) <- text2

          # add actual row
          # Print only the data line (third line of the kable output);
          # wrapped in try() so a token/column count mismatch is skipped.
          silent_fun = function() {
            df[nrow(df) + 1,] = text
            df = knitr::kable(df, format = "pandoc")
            cat(df[3], sep="\n")
          prnt = try(silent_fun(), TRUE)
          if(!inherits(prnt, 'try-error')) {
            # if !fail then repeat and collect data to temp dir
            df[nrow(df) + 1,] = text
            to_df = df
            # The `time` column is dropped before saving.
            to_df$time = NULL
            # if file is there, then read and row bind
            if(file.exists(nm)) {
              to_df_orig = read.csv(nm)
              to_df = rbind(to_df_orig, to_df)
              to_df$time = NULL
            write.csv(to_df, nm, row.names = FALSE)
            # visualize but first make data frame numeric in case of character

            # Re-reading the CSV lets read.csv() type the columns as numeric.
            to_df = read.csv(nm)

            loss_names = grepl('loss', names(to_df))

            # Split into loss columns and everything else (metrics); both keep
            # the first column, presumably `epoch` -- confirm.
            losses = cbind(to_df[1], to_df[loss_names])
            metrics_ = cbind(to_df[1], to_df[!names(to_df) %in% names(losses)])
            ## ggplot
            # Build a line+point plot of `column_name` columns against epoch.
            # NOTE(review): the `colour` argument is not used in the visible
            # body.
            column_fun <- function(column_name, df, yaxis, colour) {

              # NOTE(review): aes_string() is deprecated in current ggplot2;
              # aes(x = .data$epoch) is the modern equivalent.
              lp <- ggplot2::ggplot(df, ggplot2::aes_string('epoch'))

              strings = column_name
              if(length(strings) > 1) {
                for (i in 1:length(strings)) {
                  variable = ggplot2::sym(strings[i])
                  lp <- lp + ggplot2::geom_line(ggplot2::aes(y = !!variable, colour = !!strings[i])) +
                    # add points
                    ggplot2::geom_point(ggplot2::aes(y = !!variable, colour = !!strings[i]))
                # NOTE(review): min(df)/max(df) span every value in `df`
                # (loss/metric columns too), not only `epoch`; breaks over
                # df$epoch were probably intended -- confirm.
                lp <- lp +
                  ggplot2::scale_x_continuous(breaks = seq(min(df)-1, max(df), 1)) +
                  ggplot2::ylab(yaxis) + ggplot2::labs(colour = yaxis) + ggplot2::theme(legend.position="bottom",
                                                                                        legend.margin=ggplot2::margin(t = 0, unit='cm'),
              } else {
                variable <- ggplot2::sym(column_name)
                strings = column_name
                lp = lp + ggplot2::geom_line(ggplot2::aes(y = !!variable, colour = column_name)) +
                  # add points
                  ggplot2::geom_point(ggplot2::aes(y = !!variable, colour = column_name)) +
                  ggplot2::scale_x_continuous(breaks = seq(min(df)-1, max(df), 1)) +
                  ggplot2::ylab(yaxis) + ggplot2::labs(colour = yaxis) + ggplot2::theme(legend.position="bottom",
                                                                                        legend.margin=ggplot2::margin(t = 0, unit='cm'),

            # Plot once at least two rows exist: losses above metrics when
            # both are present, otherwise whichever group exists.
            # NOTE(review): the code that displays `figure` / `p1` / `p2` is
            # not visible in this copy.
            result_fun = function() {
              if(nrow(to_df)>1) {
                if(ncol(metrics_)>1 & ncol(losses)>1) {
                  p1 = column_fun(names(metrics_)[!names(metrics_) %in% 'epoch'], metrics_, 'Metrics', 'darkgreen')
                  p2 = column_fun(names(losses)[!names(losses) %in% 'epoch'], losses, 'Loss', 'red')

                  figure <- ggpubr::ggarrange(p2, p1,
                                              labels = c("", ""),
                                              ncol = 1, nrow = 2)
                } else if (ncol(metrics_)>1 & ncol(losses)<=1) {
                  p1 = column_fun(names(metrics_)[!names(metrics_) %in% 'epoch'], metrics_, 'Metrics', 'darkgreen')

                } else if (ncol(metrics_)<=1 & ncol(losses)>1) {
                  p2 = column_fun(names(losses)[!names(losses) %in% 'epoch'], losses, 'Loss', 'red')

                } else {
              paste('done plot')

              # Plotting errors are silently ignored.
              try(result_fun(), TRUE)


  } else {

    # Writer variant without plotting: prints the table only.
    fastaip$fastprogress$WRITER_FN = function(value, ..., sep=' ', end='\n', flush = FALSE) {
      args = list(
        value, ...)

      text = unlist(strsplit(trimws(args[[1]]),' '))
      text = text[!text=='']

      text_ = grepl( 'Epoch', args[[1]], fixed = TRUE)
      text_2 = grepl('\\|\\-+\\|',args[[1]], fixed = FALSE)

      if (!text_ & !text_2) {
        lgl = grepl('epoch', text)

        # save column names // write to temp dir
        if(lgl[1]) {
          tmm = tempdir()
          tmp_name = paste(tmm,"output.txt",sep = '/')
          # NOTE(review): no close(fileConn) is visible in this copy.
          fileConn <- file(tmp_name)
          writeLines(text, fileConn)

        if(lgl[1]) {
          df <- data.frame(matrix(ncol = length(text), nrow = 0))
          colnames(df) <- text
          # add row for tidy output
          df[nrow(df) + 1,] = as.character(round(stats::runif(ncol(df)),3))
          df = knitr::kable(df, format = "pandoc")
          cat(df[1:2], sep="\n")
        } else {
          ## restore from temp
          tmm = tempdir()
          tmp_name = paste(tmm,"output.txt",sep = '/')
          text2 = readLines(paste(tmm,"output.txt",sep = '/'))
          df <- data.frame(matrix(ncol = length(text2), nrow = 0))
          colnames(df) <- text2
          # add actual row
          silent_fun = function() {
            df[nrow(df) + 1,] = text
            df = knitr::kable(df, format = "pandoc")
            cat(df[3], sep="\n")
          try(silent_fun(), TRUE)





# Try the fastai package in your browser.
#
# Any scripts or data that you put into this service are public.
#
# fastai documentation built on June 22, 2024, 11:15 a.m.