R/predictions.R

Defines functions as_array predict.fastai.learner.Learner

Documented in as_array predict.fastai.learner.Learner

#' @title Predict
#'
#' @description Predict on the new data in `row` by building a test dataloader from it. Returns probabilities together with predictions decoded by the loss function.
#'
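#' @param object a fastai `Learner`
#' @param row new data (for example, a data frame) to build the test dataloader from
#' @param ... additional arguments to pass
#' @return the result of the learner's `get_preds()` called with `with_decoded = TRUE`: probabilities, targets and decoded predictions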
#' @export
predict.fastai.learner.Learner <- function(object, row, ...) {

  #object$predict(reticulate::r_to_py(row)$iloc[0])[[3]]$numpy()
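  # Remove metrics before predicting. `try()` guards against learners whose
  # `metrics` cannot be indexed. Because reticulate objects are references,
  # this change persists on the Learner after the call returns.
  error_check <- try(object$metrics[0], silent = TRUE)

  if (!inherits(error_check, "try-error")) {
    object$metrics <- object$metrics[0]
  }

  # Build a dataloader for the new rows and predict with decoded outputs
  test_dl <- object$dls$test_dl(row)
  predictions <- object$get_preds(dl = test_dl, with_decoded = TRUE)

  return(predictions)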

}
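The method name works because reticulate gives each wrapped Python object an R class vector derived from its Python class hierarchy (such as `"fastai.learner.Learner"`, ..., `"python.builtin.object"`), so an ordinary `predict()` call reaches this method through S3 dispatch. The base-R sketch below imitates that without Python: `fake_learner` is a hypothetical stand-in with a hand-set class vector, and the method is a stub for illustration only, not the package's implementation.

```r
# Hypothetical stand-in for a reticulate-wrapped Learner: only the class
# vector matters here, set by hand to mimic what reticulate would assign.
fake_learner <- structure(
  list(),
  class = c("fastai.learner.Learner", "python.builtin.object")
)

# Stub method with the same name as the package's method; it only reports
# that dispatch happened (illustration only, do not load alongside fastai).
predict.fastai.learner.Learner <- function(object, row, ...) {
  sprintf("predict.fastai.learner.Learner called on %d row(s)", nrow(row))
}

# stats::predict() is an S3 generic, so this call is routed to the stub above.
msg <- predict(fake_learner, data.frame(x = 1:3))
print(msg)
```

With a real fastai `Learner`, the same `predict(learn, new_rows)` call reaches the method defined in this file.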

#' @title As array
#'
#' @description Convert a tensor to an R array, moving it to the CPU first.
#'
#' @param tensor a torch tensor
#' @return array
#'
#' @export
as_array <- function(tensor) {
  as.array(tensor$cpu()$numpy())
}
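`as_array()` relies only on the tensor exposing `$cpu()` and `$numpy()`; with a real torch tensor, reticulate then converts the returned NumPy array into an R array. The sketch below swaps in a hypothetical list-based mock for the tensor (an assumption so it runs without Python) and repeats `as_array()` so the block is self-contained.

```r
as_array <- function(tensor) {
  as.array(tensor$cpu()$numpy())
}

# Hypothetical mock of a torch tensor: `$cpu()` returns an object whose
# `$numpy()` yields the data. With real torch + reticulate, `$numpy()` would
# return a NumPy array that reticulate converts to an R array.
fake_tensor <- list(
  cpu = function() {
    list(numpy = function() matrix(c(0.1, 0.9, 0.7, 0.3), nrow = 2))
  }
)

probs <- as_array(fake_tensor)
print(dim(probs))
```

Calling `$cpu()` first matters for GPU tensors, which cannot be converted to NumPy directly.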


fastai documentation built on March 31, 2023, 11:41 p.m.