luz_callback_mixup: Mixup callback

luz_callback_mixup    R Documentation

Mixup callback

Description

Implementation of 'mixup: Beyond Empirical Risk Minimization'. So far, it has been tested only on categorical data, where targets are expected to be integer class labels rather than one-hot encoded vectors. This callback is meant to be used together with nn_mixup_loss().

Usage

luz_callback_mixup(alpha = 0.4, ..., run_valid = FALSE, auto_loss = FALSE)

Arguments

alpha

Parameter of the beta distribution from which the mixing coefficients are sampled.

...

Currently unused; present only to force the arguments that follow to be named.

run_valid

Should mixup also be applied during validation?

auto_loss

Should the callback automatically modify the loss function? If TRUE, the loss function is wrapped to create the mixup loss; in that case, make sure your loss function does not apply any reduction. If run_valid = FALSE, the loss is mean-reduced during validation.
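To illustrate why auto_loss = TRUE requires an unreduced loss, here is a minimal base-R sketch. It is not luz's actual implementation (which operates on torch tensors), and the helper names mixup_loss_wrapper() and ce_none() are hypothetical. The wrapper needs one loss value per observation so that each value can be multiplied by its own mixing weight before the mean is taken.

```r
# Hypothetical helper (not part of luz): wraps a loss function that returns
# one value per observation (no reduction) into a mixup loss.
mixup_loss_wrapper <- function(loss_fn) {
  function(output, target1, target2, weight) {
    # Per-observation losses against both targets
    l1 <- loss_fn(output, target1)
    l2 <- loss_fn(output, target2)
    # Item-wise linear combination, then mean reduction over the batch
    mean(weight * l1 + (1 - weight) * l2)
  }
}

# Hypothetical unreduced cross-entropy on a matrix of class probabilities:
# one row per observation, integer targets used as column indices.
ce_none <- function(probs, target) {
  -log(probs[cbind(seq_len(nrow(probs)), target)])
}

probs <- matrix(c(0.7, 0.2, 0.1,
                  0.1, 0.8, 0.1), nrow = 2, byrow = TRUE)
mixup_ce <- mixup_loss_wrapper(ce_none)
res <- mixup_ce(probs, target1 = c(1L, 2L), target2 = c(2L, 3L),
                weight = c(0.9, 0.6))
```

If loss_fn returned an already-averaged scalar, the per-pair weights could no longer be matched to individual observations, which is why the reduction has to be left to the wrapper.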

Details

Overall, we follow the fastai implementation. Namely:

  • We work with a single dataloader only, randomly mixing two observations from the same batch.

  • We linearly combine losses computed for both targets: loss(output, new_target) = weight * loss(output, target1) + (1-weight) * loss(output, target2)

  • We draw different mixing coefficients for every pair.

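  • We replace weight with weight = max(weight, 1 - weight), so that the original observation always gets the larger weight. This avoids duplicates: otherwise, mixing (i, j) with weight w and (j, i) with weight 1 - w would yield the same observation.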
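The steps listed above can be sketched in base R as follows. This is an illustration under stated assumptions, not luz's code: mixup_batch() is a hypothetical name, inputs are plain numeric matrices instead of torch tensors, and the coefficients are drawn from Beta(alpha, alpha) as in the original mixup paper. The returned target1, target2 and weight are what the combined loss from the second step consumes.

```r
# Hypothetical sketch of batch-level mixup (not luz's implementation).
# `x`: numeric matrix, one row per observation; `y`: integer class labels.
mixup_batch <- function(x, y, alpha = 0.4) {
  n <- nrow(x)
  # Pair each observation with another one from the same batch
  perm <- sample.int(n)
  # A different mixing coefficient for every pair, from Beta(alpha, alpha)
  weight <- rbeta(n, alpha, alpha)
  # max(weight, 1 - weight): the original observation keeps the larger share
  weight <- pmax(weight, 1 - weight)
  # Row-wise linear combination (weight is recycled down each column)
  x_mixed <- weight * x + (1 - weight) * x[perm, , drop = FALSE]
  list(x = x_mixed, target1 = y, target2 = y[perm],
       weight = weight, partner = perm)
}

set.seed(1)
x <- matrix(rnorm(8 * 3), nrow = 8)
y <- sample.int(3, 8, replace = TRUE)
b <- mixup_batch(x, y, alpha = 0.4)
```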

Value

A luz_callback

See Also

nn_mixup_loss(), nnf_mixup()

Other luz_callbacks: luz_callback_auto_resume(), luz_callback_csv_logger(), luz_callback_early_stopping(), luz_callback_interrupt(), luz_callback_keep_best_model(), luz_callback_lr_scheduler(), luz_callback_metrics(), luz_callback_mixed_precision(), luz_callback_model_checkpoint(), luz_callback_profile(), luz_callback_progress(), luz_callback_resume_from_checkpoint(), luz_callback_train_valid(), luz_callback()

Examples

if (torch::torch_is_installed()) {
# Create a mixup callback with the default settings (alpha = 0.4)
mixup_callback <- luz_callback_mixup()
}


mlverse/torchlight documentation built on Sept. 19, 2024, 11:22 p.m.