R/metric.R

Defines functions LossMetric CorpusBLEUMetric JaccardCoeff Dice foreground_acc SpearmanCorrCoef PearsonCorrCoef R2Score ExplainedVariance exp_rmspe msle mae rmse mse RocAucMulti RecallMulti PrecisionMulti MatthewsCorrCoefMulti JaccardMulti FBetaMulti F1ScoreMulti BrierScoreMulti APScoreMulti accuracy_multi preplexity MatthewsCorrCoef RocAucBinary RocAuc Recall Precision Jaccard HammingLoss FBeta F1Score CohenKappa BrierScore BalancedAccuracy APScoreBinary top_k_accuracy accuracy optim_metric skm_to_fastai AccumMetric flatten_check

Documented in AccumMetric accuracy accuracy_multi APScoreBinary APScoreMulti BalancedAccuracy BrierScore BrierScoreMulti CohenKappa CorpusBLEUMetric Dice ExplainedVariance exp_rmspe F1Score F1ScoreMulti FBeta FBetaMulti flatten_check foreground_acc HammingLoss Jaccard JaccardCoeff JaccardMulti LossMetric mae MatthewsCorrCoef MatthewsCorrCoefMulti mse msle optim_metric PearsonCorrCoef Precision PrecisionMulti preplexity R2Score Recall RecallMulti rmse RocAuc RocAucBinary RocAucMulti skm_to_fastai SpearmanCorrCoef top_k_accuracy

#' @title Flatten check
#'
#' @description Check that `inp` and `targ` have the same number of elements and flatten them.
#'
#'
#' @param inp predictions
#' @param targ targets
#' @return tensor
#' @export
flatten_check <- function(inp, targ) {
  # With neither argument supplied, hand back the underlying
  # `flatten_check` object itself instead of invoking it.
  if (missing(inp) && missing(targ)) {
    return(metrics()$flatten_check)
  }

  impl <- metrics()$flatten_check
  impl(inp = inp, targ = targ)
}


#' @title AccumMetric
#'
#' @description Stores predictions and targets on CPU in accumulate to perform final calculations with `func`.
#'
#'
#' @param func function
#' @param dim_argmax dimension argmax
#' @param activation activation
#' @param thresh threshold point
#' @param to_np to matrix or not
#' @param invert_arg invert arguments
#' @param flatten flatten
#' @param ... additional arguments to pass
#' @return None
#' @export
AccumMetric <- function(func, dim_argmax = NULL, activation = "no",
                        thresh = NULL, to_np = FALSE,
                        invert_arg = FALSE, flatten = TRUE,
                        ...) {
  # Collect every argument (plus anything in `...`) to forward by name.
  py_args <- list(
    func = func, dim_argmax = dim_argmax, activation = activation,
    thresh = thresh, to_np = to_np, invert_arg = invert_arg,
    flatten = flatten, ...
  )

  # Optional entries still at NULL are removed from the list, so they are
  # not forwarded at all.
  for (key in c("dim_argmax", "thresh")) {
    if (is.null(py_args[[key]])) {
      py_args[[key]] <- NULL
    }
  }

  do.call(metrics()$AccumMetric, py_args)
}

#' @title Skm to fastai
#'
#' @description Convert `func` from sklearn$metrics to a fastai metric
#'
#'
#' @param func function
#' @param is_class is classification or not
#' @param thresh threshold point
#' @param axis axis
#' @param activation activation
#' @param ... additional arguments to pass
#' @return None
#' @export
skm_to_fastai <- function(func, is_class = TRUE, thresh = NULL,
                          axis = -1, activation = NULL,
                          ...) {
  # `axis` is coerced to integer; everything else is forwarded as given.
  py_args <- list(
    func = func, is_class = is_class, thresh = thresh,
    axis = as.integer(axis), activation = activation, ...
  )

  # NULL-valued optional entries are removed rather than forwarded.
  for (key in c("thresh", "activation")) {
    if (is.null(py_args[[key]])) {
      py_args[[key]] <- NULL
    }
  }

  do.call(metrics()$skm_to_fastai, py_args)
}


#' @title Optim metric
#'
#' @description Replace metric `f` with a version that optimizes argument `argname`
#'
#'
#' @param f f
#' @param argname argname
#' @param bounds bounds
#' @param tol tol
#' @param do_neg do_neg
#' @param get_x get_x
#' @return None
#' @export
optim_metric <- function(f, argname, bounds,
                         tol = 0.01, do_neg = TRUE, get_x = FALSE) {
  # All six arguments are forwarded unchanged, by name.
  forwarded <- list(
    f = f, argname = argname, bounds = bounds,
    tol = tol, do_neg = do_neg, get_x = get_x
  )

  target <- metrics()$optim_metric
  do.call(target, forwarded)
}


#' @title Accuracy
#'
#' @description Compute accuracy with `targ` when `inp` is bs * n_classes
#'
#'
#' @param inp predictions
#' @param targ targets
#' @param axis axis
#' @return None
#' @export
accuracy <- function(inp, targ, axis = -1) {
  # As soon as either tensor is supplied, perform the actual call;
  # `axis` is coerced to integer before forwarding.
  if (!missing(inp) || !missing(targ)) {
    call_args <- list(inp = inp, targ = targ, axis = as.integer(axis))
    return(do.call(metrics()$accuracy, call_args))
  }

  # Bare call: return the underlying `accuracy` object uncalled.
  metrics()$accuracy
}

# Tag the wrapper with the name of the function it forwards to.
attr(accuracy, "py_function_name") <- "accuracy"

#' @title Top_k_accuracy
#'
#' @description Computes the Top-k accuracy (`targ` is in the top `k` predictions of `inp`)
#'
#'
#' @param inp predictions
#' @param targ targets
#' @param k k
#' @param axis axis
#' @return None
#'
#' @examples
#'
#' \dontrun{
#'
#' loaders = loaders()
#'
#' data = Data_Loaders(loaders['train'], loaders['valid'])$cuda()
#'
#' model = nn$Sequential() +
#'   nn$Flatten() +
#'   nn$Linear(28L * 28L, 10L)
#' metrics = list(accuracy,top_k_accuracy)
#' learn = Learner(data, model, loss_func = F$cross_entropy, opt_func = Adam,
#'                 metrics = metrics)
#'
#' }
#'
#' @export
top_k_accuracy <- function(inp, targ, k = 5, axis = -1) {
  # Any supplied tensor triggers the real call; `k` and `axis` are
  # coerced to integer before forwarding.
  if (!missing(inp) || !missing(targ)) {
    call_args <- list(
      inp = inp, targ = targ,
      k = as.integer(k), axis = as.integer(axis)
    )
    return(do.call(metrics()$top_k_accuracy, call_args))
  }

  # Bare call: return the underlying `top_k_accuracy` object uncalled.
  metrics()$top_k_accuracy
}

# Tag the wrapper with the name of the function it forwards to.
attr(top_k_accuracy, "py_function_name") <- "top_k_accuracy"

#' @title APScoreBinary
#'
#' @description Average Precision for single-label binary classification problems
#'
#'
#' @param axis axis
#' @param average average
#' @param pos_label pos_label
#' @param sample_weight sample_weight
#' @return None
#' @export
APScoreBinary <- function(axis = -1, average = "macro", pos_label = 1, sample_weight = NULL) {
  # `axis` and `pos_label` are coerced to integer before forwarding.
  py_args <- list(
    axis = as.integer(axis),
    average = average,
    pos_label = as.integer(pos_label),
    sample_weight = sample_weight
  )

  # A NULL `sample_weight` is removed so it is not forwarded at all.
  if (is.null(sample_weight)) {
    py_args[["sample_weight"]] <- NULL
  }

  do.call(metrics()$APScoreBinary, py_args)
}

#' @title BalancedAccuracy
#'
#' @description Balanced Accuracy for single-label binary classification problems
#'
#'
#' @param axis axis
#' @param sample_weight sample_weight
#' @param adjusted adjusted
#' @return None
#' @export
BalancedAccuracy <- function(axis = -1, sample_weight = NULL, adjusted = FALSE) {
  # `axis` is coerced to integer; the rest is forwarded as given.
  py_args <- list(
    axis = as.integer(axis),
    sample_weight = sample_weight,
    adjusted = adjusted
  )

  # A NULL `sample_weight` is removed so it is not forwarded at all.
  if (is.null(sample_weight)) {
    py_args[["sample_weight"]] <- NULL
  }

  do.call(metrics()$BalancedAccuracy, py_args)
}


#' @title BrierScore
#'
#' @description Brier score for single-label classification problems
#'
#'
#' @param axis axis
#' @param sample_weight sample_weight
#' @param pos_label pos_label
#' @return None
#' @export
BrierScore <- function(axis = -1, sample_weight = NULL, pos_label = NULL) {
  # Only `axis` is coerced to integer; `pos_label` is forwarded as given.
  py_args <- list(
    axis = as.integer(axis),
    sample_weight = sample_weight,
    pos_label = pos_label
  )

  # NULL-valued optional entries are removed rather than forwarded.
  for (key in c("sample_weight", "pos_label")) {
    if (is.null(py_args[[key]])) {
      py_args[[key]] <- NULL
    }
  }

  do.call(metrics()$BrierScore, py_args)
}


#' @title CohenKappa
#'
#' @description Cohen kappa for single-label classification problems
#'
#'
#' @param axis axis
#' @param labels labels
#' @param weights weights
#' @param sample_weight sample_weight
#' @return None
#' @export
CohenKappa <- function(axis = -1, labels = NULL, weights = NULL, sample_weight = NULL) {
  # `axis` is coerced to integer; the rest is forwarded as given.
  py_args <- list(
    axis = as.integer(axis),
    labels = labels,
    weights = weights,
    sample_weight = sample_weight
  )

  # Only `sample_weight` and `weights` are removed when NULL; a NULL
  # `labels` entry stays in the list and is still forwarded.
  for (key in c("sample_weight", "weights")) {
    if (is.null(py_args[[key]])) {
      py_args[[key]] <- NULL
    }
  }

  do.call(metrics()$CohenKappa, py_args)
}


#' @title F1Score
#'
#' @description F1 score for single-label classification problems
#'
#'
#'
#' @param axis axis
#' @param labels labels
#' @param pos_label pos_label
#' @param average average
#' @param sample_weight sample_weight
#' @return None
#' @export
F1Score <- function(axis = -1, labels = NULL, pos_label = 1, average = "binary", sample_weight = NULL) {
  # `axis` and `pos_label` are coerced to integer before forwarding.
  py_args <- list(
    axis = as.integer(axis),
    labels = labels,
    pos_label = as.integer(pos_label),
    average = average,
    sample_weight = sample_weight
  )

  # A NULL `sample_weight` is removed; a NULL `labels` is still forwarded.
  if (is.null(sample_weight)) {
    py_args[["sample_weight"]] <- NULL
  }

  do.call(metrics()$F1Score, py_args)
}

#' @title FBeta
#'
#' @description FBeta score with `beta` for single-label classification problems
#'
#'
#' @param beta beta
#' @param axis axis
#' @param labels labels
#' @param pos_label pos_label
#' @param average average
#' @param sample_weight sample_weight
#' @return None
#' @export
FBeta <- function(beta, axis = -1, labels = NULL, pos_label = 1, average = "binary", sample_weight = NULL) {
  # Without `beta` there is nothing to construct: return the underlying
  # `FBeta` object uncalled.
  if (missing(beta)) {
    return(metrics()$FBeta)
  }

  # `axis` and `pos_label` are coerced to integer before forwarding.
  py_args <- list(
    beta = beta,
    axis = as.integer(axis),
    labels = labels,
    pos_label = as.integer(pos_label),
    average = average,
    sample_weight = sample_weight
  )

  # A NULL `sample_weight` is removed; a NULL `labels` is still forwarded.
  if (is.null(sample_weight)) {
    py_args[["sample_weight"]] <- NULL
  }

  do.call(metrics()$FBeta, py_args)
}

#' @title HammingLoss
#'
#' @description Hamming loss for single-label classification problems
#'
#'
#' @param axis axis
#' @param sample_weight sample_weight
#' @return None
#' @export
HammingLoss <- function(axis = -1, sample_weight = NULL) {
  # `axis` is coerced to integer before forwarding.
  py_args <- list(
    axis = as.integer(axis),
    sample_weight = sample_weight
  )

  # A NULL `sample_weight` is removed so it is not forwarded at all.
  if (is.null(sample_weight)) {
    py_args[["sample_weight"]] <- NULL
  }

  do.call(metrics()$HammingLoss, py_args)
}


#' @title Jaccard
#'
#' @description Jaccard score for single-label classification problems
#'
#'
#' @param axis axis
#' @param labels labels
#' @param pos_label pos_label
#' @param average average
#' @param sample_weight sample_weight
#' @return None
#' @export
Jaccard <- function(axis = -1, labels = NULL, pos_label = 1,
                    average = "binary", sample_weight = NULL) {
  # `axis` and `pos_label` are coerced to integer before forwarding.
  py_args <- list(
    axis = as.integer(axis),
    labels = labels,
    pos_label = as.integer(pos_label),
    average = average,
    sample_weight = sample_weight
  )

  # A NULL `sample_weight` is removed; a NULL `labels` is still forwarded.
  if (is.null(sample_weight)) {
    py_args[["sample_weight"]] <- NULL
  }

  do.call(metrics()$Jaccard, py_args)
}


#' @title Precision
#'
#' @description Precision for single-label classification problems
#'
#'
#' @param axis axis
#' @param labels labels
#' @param pos_label pos_label
#' @param average average
#' @param sample_weight sample_weight
#' @return None
#' @export
Precision <- function(axis = -1, labels = NULL, pos_label = 1,
                      average = "binary", sample_weight = NULL) {
  # `axis` and `pos_label` are coerced to integer before forwarding.
  py_args <- list(
    axis = as.integer(axis),
    labels = labels,
    pos_label = as.integer(pos_label),
    average = average,
    sample_weight = sample_weight
  )

  # A NULL `sample_weight` is removed; a NULL `labels` is still forwarded.
  if (is.null(sample_weight)) {
    py_args[["sample_weight"]] <- NULL
  }

  do.call(metrics()$Precision, py_args)
}


#' @title Recall
#'
#' @description Recall for single-label classification problems
#'
#'
#' @param axis axis
#' @param labels labels
#' @param pos_label pos_label
#' @param average average
#' @param sample_weight sample_weight
#' @return None
#' @export
Recall <- function(axis = -1, labels = NULL, pos_label = 1,
                   average = "binary", sample_weight = NULL) {
  # `axis` and `pos_label` are coerced to integer before forwarding.
  py_args <- list(
    axis = as.integer(axis),
    labels = labels,
    pos_label = as.integer(pos_label),
    average = average,
    sample_weight = sample_weight
  )

  # A NULL `sample_weight` is removed; a NULL `labels` is still forwarded.
  if (is.null(sample_weight)) {
    py_args[["sample_weight"]] <- NULL
  }

  do.call(metrics()$Recall, py_args)
}


#' @title RocAuc
#'
#' @description Area Under the Receiver Operating Characteristic
#' Curve for single-label multiclass classification problems
#'
#'
#' @param axis axis
#' @param average average
#' @param sample_weight sample_weight
#' @param max_fpr max_fpr
#' @param multi_class multi_class
#' @return None
#' @export
RocAuc <- function(axis = -1, average = "macro",
                   sample_weight = NULL, max_fpr = NULL, multi_class = "ovr") {
  # `axis` is coerced to integer; the rest is forwarded as given.
  py_args <- list(
    axis = as.integer(axis),
    average = average,
    sample_weight = sample_weight,
    max_fpr = max_fpr,
    multi_class = multi_class
  )

  # NULL-valued optional entries are removed rather than forwarded.
  for (key in c("sample_weight", "max_fpr")) {
    if (is.null(py_args[[key]])) {
      py_args[[key]] <- NULL
    }
  }

  do.call(metrics()$RocAuc, py_args)
}


#' @title RocAucBinary
#'
#' @description Area Under the Receiver Operating Characteristic Curve for single-label binary classification problems
#'
#'
#' @param axis axis
#' @param average average
#' @param sample_weight sample_weight
#' @param max_fpr max_fpr
#' @param multi_class multi_class
#' @return None
#'
#' @examples
#'
#' \dontrun{
#'
#' model = dls %>% tabular_learner(layers=c(200,100,100,200),
#' config = tabular_config(embed_p = 0.3, use_bn = FALSE),
#' metrics = list(accuracy, RocAucBinary(),
#'                Precision(), Recall(),
#'                F1Score()))
#'
#' }
#'
#' @export
RocAucBinary <- function(axis = -1, average = "macro",
                         sample_weight = NULL, max_fpr = NULL, multi_class = "raise") {
  # `axis` is coerced to integer; the rest is forwarded as given.
  py_args <- list(
    axis = as.integer(axis),
    average = average,
    sample_weight = sample_weight,
    max_fpr = max_fpr,
    multi_class = multi_class
  )

  # NULL-valued optional entries are removed rather than forwarded.
  for (key in c("sample_weight", "max_fpr")) {
    if (is.null(py_args[[key]])) {
      py_args[[key]] <- NULL
    }
  }

  do.call(metrics()$RocAucBinary, py_args)
}


#' @title MatthewsCorrCoef
#'
#' @description Matthews correlation coefficient for single-label classification problems
#' @param ... parameters to pass
#' @return None
#' @export
MatthewsCorrCoef <- function(...) {
  # Capture (and thereby evaluate) whatever the caller passed.
  extra <- list(...)
  impl <- metrics()$MatthewsCorrCoef

  # NOTE(review): with no arguments this returns the underlying callable
  # itself rather than the result of calling it, unlike e.g. `F1Score()`
  # in this file -- confirm this is intended.
  if (length(extra) == 0) {
    return(impl)
  }

  impl(...)
}


#' @title Perplexity
#'
#' @description Perplexity (exponential of cross-entropy loss) for Language Models
#'
#' @param ... parameters to pass
#' @return None
#' @export
preplexity <- function(...) {
  # Capture (and thereby evaluate) whatever the caller passed.
  dots <- list(...)

  # Bare call: return the underlying `perplexity` object uncalled.
  if (length(dots) == 0) {
    return(metrics()$perplexity)
  }

  do.call(metrics()$perplexity, dots)
}

#' @title Accuracy_multi
#'
#' @description Compute accuracy when `inp` and `targ` are the same size.
#'
#'
#' @param inp predictions
#' @param targ targets
#' @param thresh threshold point
#' @param sigmoid sigmoid
#' @return None
#' @export
accuracy_multi <- function(inp, targ, thresh = 0.5, sigmoid = TRUE) {
  # As soon as either tensor is supplied, perform the actual call.
  if (!missing(inp) || !missing(targ)) {
    call_args <- list(inp = inp, targ = targ, thresh = thresh, sigmoid = sigmoid)
    return(do.call(metrics()$accuracy_multi, call_args))
  }

  # Bare call: return the underlying `accuracy_multi` object uncalled.
  metrics()$accuracy_multi
}


#' @title APScoreMulti
#'
#' @description Average Precision for multi-label classification problems
#'
#'
#' @param sigmoid sigmoid
#' @param average average
#' @param pos_label pos_label
#' @param sample_weight sample_weight
#' @return None
#' @export
APScoreMulti <- function(sigmoid = TRUE, average = "macro",
                         pos_label = 1, sample_weight = NULL) {
  # `pos_label` is coerced to integer before forwarding.
  py_args <- list(
    sigmoid = sigmoid,
    average = average,
    pos_label = as.integer(pos_label),
    sample_weight = sample_weight
  )

  # A NULL `sample_weight` is removed so it is not forwarded at all.
  if (is.null(sample_weight)) {
    py_args[["sample_weight"]] <- NULL
  }

  do.call(metrics()$APScoreMulti, py_args)
}


#' @title BrierScoreMulti
#'
#' @description Brier score for multi-label classification problems
#'
#'
#' @param thresh thresh
#' @param sigmoid sigmoid
#' @param sample_weight sample_weight
#' @param pos_label pos_label
#' @return None
#' @export
BrierScoreMulti <- function(thresh = 0.5, sigmoid = TRUE,
                            sample_weight = NULL, pos_label = NULL) {
  # Everything is forwarded as given; no integer coercion happens here.
  py_args <- list(
    thresh = thresh,
    sigmoid = sigmoid,
    sample_weight = sample_weight,
    pos_label = pos_label
  )

  # NULL-valued optional entries are removed rather than forwarded.
  for (key in c("sample_weight", "pos_label")) {
    if (is.null(py_args[[key]])) {
      py_args[[key]] <- NULL
    }
  }

  do.call(metrics()$BrierScoreMulti, py_args)
}



#' @title F1ScoreMulti
#'
#' @description F1 score for multi-label classification problems
#'
#' @param thresh thresh
#' @param sigmoid sigmoid
#' @param labels labels
#' @param pos_label pos_label
#' @param average average
#' @param sample_weight sample_weight
#' @return None
#' @export
F1ScoreMulti <- function(thresh = 0.5, sigmoid = TRUE, labels = NULL,
                         pos_label = 1, average = "macro",
                         sample_weight = NULL) {
  # `pos_label` is coerced to integer before forwarding.
  py_args <- list(
    thresh = thresh,
    sigmoid = sigmoid,
    labels = labels,
    pos_label = as.integer(pos_label),
    average = average,
    sample_weight = sample_weight
  )

  # NULL-valued optional entries are removed rather than forwarded.
  for (key in c("sample_weight", "labels")) {
    if (is.null(py_args[[key]])) {
      py_args[[key]] <- NULL
    }
  }

  do.call(metrics()$F1ScoreMulti, py_args)
}


#' @title FBetaMulti
#'
#' @description FBeta score with `beta` for multi-label classification problems
#'
#'
#' @param beta beta
#' @param thresh thresh
#' @param sigmoid sigmoid
#' @param labels labels
#' @param pos_label pos_label
#' @param average average
#' @param sample_weight sample_weight
#' @return None
#' @export
FBetaMulti <- function(beta, thresh = 0.5, sigmoid = TRUE, labels = NULL,
                       pos_label = 1, average = "macro", sample_weight = NULL) {
  # `beta` has no default and is evaluated here, so it must be supplied;
  # `pos_label` is coerced to integer before forwarding.
  py_args <- list(
    beta = beta,
    thresh = thresh,
    sigmoid = sigmoid,
    labels = labels,
    pos_label = as.integer(pos_label),
    average = average,
    sample_weight = sample_weight
  )

  # NULL-valued optional entries are removed rather than forwarded.
  for (key in c("sample_weight", "labels")) {
    if (is.null(py_args[[key]])) {
      py_args[[key]] <- NULL
    }
  }

  do.call(metrics()$FBetaMulti, py_args)
}





#' @title JaccardMulti
#'
#' @description Jaccard score for multi-label classification problems
#'
#' @param thresh thresh
#' @param sigmoid sigmoid
#' @param labels labels
#' @param pos_label pos_label
#' @param average average
#' @param sample_weight sample_weight
#' @return None
#' @export
JaccardMulti <- function(thresh = 0.5, sigmoid = TRUE,
                         labels = NULL, pos_label = 1,
                         average = "macro", sample_weight = NULL) {
  # `pos_label` is coerced to integer before forwarding.
  py_args <- list(
    thresh = thresh,
    sigmoid = sigmoid,
    labels = labels,
    pos_label = as.integer(pos_label),
    average = average,
    sample_weight = sample_weight
  )

  # NULL-valued optional entries are removed rather than forwarded.
  for (key in c("sample_weight", "labels")) {
    if (is.null(py_args[[key]])) {
      py_args[[key]] <- NULL
    }
  }

  do.call(metrics()$JaccardMulti, py_args)
}


#' @title MatthewsCorrCoefMulti
#'
#' @description Matthews correlation coefficient for multi-label classification problems
#'
#'
#' @param thresh thresh
#' @param sigmoid sigmoid
#' @param sample_weight sample_weight
#' @return None
#' @export
MatthewsCorrCoefMulti <- function(thresh = 0.5, sigmoid = TRUE, sample_weight = NULL) {
  # Everything is forwarded as given, by name.
  py_args <- list(
    thresh = thresh,
    sigmoid = sigmoid,
    sample_weight = sample_weight
  )

  # A NULL `sample_weight` is removed so it is not forwarded at all.
  if (is.null(sample_weight)) {
    py_args[["sample_weight"]] <- NULL
  }

  do.call(metrics()$MatthewsCorrCoefMulti, py_args)
}


#' @title PrecisionMulti
#'
#' @description Precision for multi-label classification problems
#'
#' @param thresh thresh
#' @param sigmoid sigmoid
#' @param labels labels
#' @param pos_label pos_label
#' @param average average
#' @param sample_weight sample_weight
#' @return None
#' @export
PrecisionMulti <- function(thresh = 0.5, sigmoid = TRUE, labels = NULL,
                           pos_label = 1, average = "macro",
                           sample_weight = NULL) {
  # `pos_label` is coerced to integer before forwarding.
  py_args <- list(
    thresh = thresh,
    sigmoid = sigmoid,
    labels = labels,
    pos_label = as.integer(pos_label),
    average = average,
    sample_weight = sample_weight
  )

  # NULL-valued optional entries are removed rather than forwarded.
  for (key in c("sample_weight", "labels")) {
    if (is.null(py_args[[key]])) {
      py_args[[key]] <- NULL
    }
  }

  do.call(metrics()$PrecisionMulti, py_args)
}

#' @title RecallMulti
#'
#' @description Recall for multi-label classification problems
#'
#'
#' @param thresh thresh
#' @param sigmoid sigmoid
#' @param labels labels
#' @param pos_label pos_label
#' @param average average
#' @param sample_weight sample_weight
#' @return None
#' @export
RecallMulti <- function(thresh = 0.5, sigmoid = TRUE, labels = NULL,
                        pos_label = 1, average = "macro",
                        sample_weight = NULL) {
  # `pos_label` is coerced to integer before forwarding.
  py_args <- list(
    thresh = thresh,
    sigmoid = sigmoid,
    labels = labels,
    pos_label = as.integer(pos_label),
    average = average,
    sample_weight = sample_weight
  )

  # NULL-valued optional entries are removed rather than forwarded.
  for (key in c("sample_weight", "labels")) {
    if (is.null(py_args[[key]])) {
      py_args[[key]] <- NULL
    }
  }

  do.call(metrics()$RecallMulti, py_args)
}

#' @title RocAucMulti
#'
#' @description Area Under the Receiver Operating Characteristic Curve for multi-label binary classification problems
#'
#'
#' @param sigmoid sigmoid
#' @param average average
#' @param sample_weight sample_weight
#' @param max_fpr max_fpr
#' @return None
#' @export
RocAucMulti <- function(sigmoid = TRUE, average = "macro",
                        sample_weight = NULL, max_fpr = NULL) {
  # Everything is forwarded as given, by name.
  py_args <- list(
    sigmoid = sigmoid,
    average = average,
    sample_weight = sample_weight,
    max_fpr = max_fpr
  )

  # NULL-valued optional entries are removed rather than forwarded.
  for (key in c("sample_weight", "max_fpr")) {
    if (is.null(py_args[[key]])) {
      py_args[[key]] <- NULL
    }
  }

  do.call(metrics()$RocAucMulti, py_args)
}

#' @title MSE
#'
#' @description Mean squared error between `inp` and `targ`.
#'
#'
#' @param inp predictions
#' @param targ targets
#' @return None
#'
#'
#' @examples
#'
#' \dontrun{
#'
#' model = dls %>% tabular_learner(layers=c(200,100,100,200),
#' metrics = list(mse(),rmse()) )
#'
#' }
#'
#' @export
mse <- function(inp, targ) {
  # Bare call (no tensors): return the underlying `mse` object uncalled,
  # e.g. for use inside a `metrics = list(...)` argument.
  if (missing(inp) && missing(targ)) {
    return(metrics()$mse)
  }

  impl <- metrics()$mse
  impl(inp = inp, targ = targ)
}


#' @title RMSE
#'
#' @description Root mean squared error
#'
#'
#' @param preds predictions
#' @param targs targets
#' @return None
#'
#'
#' @examples
#'
#' \dontrun{
#'
#' model = dls %>% tabular_learner(layers=c(200,100,100,200),
#' metrics = list(mse(),rmse()) )
#'
#' }
#'
#' @export
rmse <- function(preds, targs) {
  # Any supplied argument triggers the real call.
  if (!missing(preds) || !missing(targs)) {
    impl <- metrics()$rmse
    return(impl(preds = preds, targs = targs))
  }

  # Bare call: return the underlying `rmse` object uncalled.
  metrics()$rmse
}

#' @title MAE
#'
#' @description Mean absolute error between `inp` and `targ`.
#'
#'
#' @param inp predictions
#' @param targ targets
#' @return None
#' @export
mae <- function(inp, targ) {
  # Bare call (no tensors): return the underlying `mae` object uncalled.
  if (missing(inp) && missing(targ)) {
    return(metrics()$mae)
  }

  impl <- metrics()$mae
  impl(inp = inp, targ = targ)
}


#' @title MSLE
#'
#' @description Mean squared logarithmic error between `inp` and `targ`.
#'
#' @param inp predictions
#' @param targ targets
#' @return None
#' @export
msle <- function(inp, targ) {
  # Any supplied argument triggers the real call.
  if (!missing(inp) || !missing(targ)) {
    impl <- metrics()$msle
    return(impl(inp = inp, targ = targ))
  }

  # Bare call: return the underlying `msle` object uncalled.
  metrics()$msle
}


#' @title Exp_rmspe
#'
#' @description Root mean square percentage error of the exponential of predictions and targets
#'
#'
#' @param preds predictions
#' @param targs targets
#' @return None
#' @export
exp_rmspe <- function(preds, targs) {
  # Bare call (no tensors): return the underlying `exp_rmspe` object
  # uncalled.
  if (missing(preds) && missing(targs)) {
    return(metrics()$exp_rmspe)
  }

  impl <- metrics()$exp_rmspe
  impl(preds = preds, targs = targs)
}


#' @title Explained Variance
#'
#' @description Explained variance between predictions and targets
#'
#'
#' @param sample_weight sample_weight
#' @return None
#' @export
ExplainedVariance <- function(sample_weight = NULL) {
  # `sample_weight` is forwarded only when it is not NULL; otherwise the
  # target is called with no arguments.
  py_args <- if (is.null(sample_weight)) {
    list()
  } else {
    list(sample_weight = sample_weight)
  }

  do.call(metrics()$ExplainedVariance, py_args)
}

#' @title R2Score
#'
#' @description R2 score between predictions and targets
#'
#'
#' @param sample_weight sample_weight
#' @return None
#' @export
R2Score <- function(sample_weight = NULL) {
  # `sample_weight` is forwarded only when it is not NULL; otherwise the
  # target is called with no arguments.
  py_args <- if (is.null(sample_weight)) {
    list()
  } else {
    list(sample_weight = sample_weight)
  }

  do.call(metrics()$R2Score, py_args)
}


#' @title PearsonCorrCoef
#'
#' @description Pearson correlation coefficient for regression problem
#'
#'
#' @param dim_argmax dim_argmax
#' @param activation activation
#' @param thresh thresh
#' @param to_np to_np
#' @param invert_arg invert_arg
#' @param flatten flatten
#' @return None
#' @export
PearsonCorrCoef <- function(dim_argmax = NULL, activation = "no",
                            thresh = NULL, to_np = FALSE, invert_arg = FALSE, flatten = TRUE) {
  # Everything is forwarded as given, by name.
  py_args <- list(
    dim_argmax = dim_argmax,
    activation = activation,
    thresh = thresh,
    to_np = to_np,
    invert_arg = invert_arg,
    flatten = flatten
  )

  # NULL-valued optional entries are removed rather than forwarded.
  for (key in c("dim_argmax", "thresh")) {
    if (is.null(py_args[[key]])) {
      py_args[[key]] <- NULL
    }
  }

  do.call(metrics()$PearsonCorrCoef, py_args)
}


#' @title SpearmanCorrCoef
#'
#' @description Spearman correlation coefficient for regression problem
#'
#'
#' @param dim_argmax dim_argmax
#' @param axis axis
#' @param nan_policy nan_policy
#' @param activation activation
#' @param thresh thresh
#' @param to_np to_np
#' @param invert_arg invert_arg
#' @param flatten flatten
#' @return None
#' @export
SpearmanCorrCoef <- function(dim_argmax = NULL, axis = 0, nan_policy = "propagate",
                             activation = "no", thresh = NULL, to_np = FALSE,
                             invert_arg = FALSE, flatten = TRUE) {
  # `axis` is coerced to integer; the rest is forwarded as given.
  py_args <- list(
    dim_argmax = dim_argmax,
    axis = as.integer(axis),
    nan_policy = nan_policy,
    activation = activation,
    thresh = thresh,
    to_np = to_np,
    invert_arg = invert_arg,
    flatten = flatten
  )

  # NULL-valued optional entries are removed rather than forwarded.
  for (key in c("dim_argmax", "thresh")) {
    if (is.null(py_args[[key]])) {
      py_args[[key]] <- NULL
    }
  }

  do.call(metrics()$SpearmanCorrCoef, py_args)
}



#' @title Foreground accuracy
#'
#' @description Computes non-background accuracy for multiclass segmentation
#'
#'
#' @param inp predictions
#' @param targ targets
#' @param bkg_idx bkg_idx
#' @param axis axis
#' @return None
#' @export
foreground_acc <- function(inp, targ, bkg_idx = 0, axis = 1) {
  # Bare call (no tensors): return the underlying `foreground_acc` object
  # uncalled.
  if (missing(inp) && missing(targ)) {
    return(metrics()$foreground_acc)
  }

  # `bkg_idx` and `axis` are coerced to integer before forwarding.
  impl <- metrics()$foreground_acc
  impl(
    inp = inp,
    targ = targ,
    bkg_idx = as.integer(bkg_idx),
    axis = as.integer(axis)
  )
}

#' @title Dice coefficient
#'
#' @description Dice coefficient metric for binary target in segmentation
#'
#'
#' @param axis axis
#' @return None
#' @export
Dice <- function(axis = 1) {
  # `axis` is coerced to integer before forwarding.
  ctor <- metrics()$Dice
  ctor(axis = as.integer(axis))
}

#' @title JaccardCoeff
#'
#' @description Implementation of the Jaccard coefficient that is lighter in RAM
#'
#'
#' @param axis axis
#' @return None
#' @export
JaccardCoeff <- function(axis = 1) {
  # `axis` is coerced to integer before forwarding.
  ctor <- metrics()$JaccardCoeff
  ctor(axis = as.integer(axis))
}

#' @title CorpusBLEUMetric
#'
#' @description BLEU metric calculated over the validation corpus
#'
#'
#' @param vocab_sz vocab_sz
#' @param axis axis
#' @return None
#' @export
CorpusBLEUMetric <- function(vocab_sz = 5000, axis = -1) {
  # Both arguments are coerced to integer before forwarding.
  ctor <- metrics()$CorpusBLEUMetric
  ctor(
    vocab_sz = as.integer(vocab_sz),
    axis = as.integer(axis)
  )
}


#' @title LossMetric
#'
#' @description Create a metric from `loss_func.attr` named `nm`
#'
#' @param attr attr
#' @param nm nm
#' @return None
#' @export
LossMetric <- function(attr, nm = NULL) {
  # Both arguments are forwarded as given; unlike most wrappers in this
  # file, a NULL `nm` is still passed along rather than dropped.
  ctor <- metrics()$LossMetric
  ctor(attr = attr, nm = nm)
}

Try the fastai package in your browser

Any scripts or data that you put into this service are public.

fastai documentation built on March 31, 2023, 11:41 p.m.