train_PAN: Trains PAN model


Source: R/train_PAN.R

Description

Trains a Predictive Adversarial Network (PAN) model. Training alternates between the two networks: the adversarial model is trained on the whole dataloader, and the classifier is trained on a single mini-batch. The result is a fairer classifier.
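To make the alternating schedule concrete, here is a minimal base-R sketch under stated assumptions. It is not FairPAN's implementation, which trains torch nn_module objects with the supplied optimizers. The function `toy_train_pan()` is hypothetical; it assumes a logistic-regression classifier, a logistic adversary that sees only the classifier's predicted probability, and a classifier objective of BCE(y, p) - lambda * BCE(s, q). The argument names n_ep_pan, batch_size, learning_rate_adv, learning_rate_clf and lambda are borrowed from train_PAN only for readability.

```r
# Toy, base-R sketch of PAN-style alternating training (no torch).
# HYPOTHETICAL: toy_train_pan() is not part of FairPAN. It assumes a
# logistic-regression classifier, a logistic adversary that sees only the
# classifier's predicted probability, and the classifier objective
# BCE(y, p) - lambda * BCE(s, q).

sigmoid <- function(z) 1 / (1 + exp(-z))

# Mean binary cross-entropy, with probabilities clipped to avoid log(0)
bce <- function(target, prob) {
  prob <- pmin(pmax(prob, 1e-7), 1 - 1e-7)
  -mean(target * log(prob) + (1 - target) * log(1 - prob))
}

toy_train_pan <- function(X, y, s, n_ep_pan = 50, batch_size = 32,
                          learning_rate_adv = 0.1, learning_rate_clf = 0.1,
                          lambda = 1) {
  n <- nrow(X)
  w <- rep(0, ncol(X))  # classifier weights
  a <- 0; b <- 0        # adversary slope and intercept
  batches <- split(seq_len(n), ceiling(seq_len(n) / batch_size))
  history <- data.frame(epoch = seq_len(n_ep_pan),
                        clf_loss = rep(NA_real_, n_ep_pan),
                        adv_loss = rep(NA_real_, n_ep_pan))

  for (epoch in seq_len(n_ep_pan)) {
    # 1) Adversary: one gradient step per mini-batch, over every batch
    for (idx in batches) {
      p <- as.vector(sigmoid(X[idx, , drop = FALSE] %*% w))
      err <- sigmoid(a * p + b) - s[idx]  # per-sample dBCE/d(adversary logit)
      a <- a - learning_rate_adv * mean(err * p)
      b <- b - learning_rate_adv * mean(err)
    }
    # 2) Classifier: one gradient step on a single random mini-batch
    idx <- batches[[sample.int(length(batches), 1)]]
    Xb <- X[idx, , drop = FALSE]
    p <- as.vector(sigmoid(Xb %*% w))
    q <- sigmoid(a * p + b)
    grad_clf <- crossprod(Xb, p - y[idx]) / length(idx)
    grad_adv <- crossprod(Xb, (q - s[idx]) * a * p * (1 - p)) / length(idx)
    # Descend on the classification loss, ascend on the adversary's loss
    w <- w - learning_rate_clf * as.vector(grad_clf - lambda * grad_adv)

    # Record full-data losses for monitoring
    p_all <- as.vector(sigmoid(X %*% w))
    history$clf_loss[epoch] <- bce(y, p_all)
    history$adv_loss[epoch] <- bce(s, sigmoid(a * p_all + b))
  }
  list(weights = w, adversary = c(a = a, b = b), history = history)
}

# Usage on synthetic data where feature 2 is correlated with s
set.seed(1)
n <- 200
s <- rbinom(n, 1, 0.5)                   # binary sensitive attribute
X <- cbind(1, rnorm(n) + s, rnorm(n))    # intercept column + two features
y <- rbinom(n, 1, sigmoid(X[, 2] - 0.5)) # label depends on the s-linked feature
fit <- toy_train_pan(X, y, s, n_ep_pan = 20, batch_size = 32, lambda = 0.5)
tail(fit$history)
```

One plausible reading of the schedule: giving the adversary a full pass per epoch keeps it close to its best response, so each single classifier step is pushed against a reasonably well-trained adversary.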

Usage

train_PAN(
  dsl,
  clf_model,
  adv_model,
  clf_optimizer,
  adv_optimizer,
  dev,
  sensitive_train,
  sensitive_test,
  n_ep_pan = 50,
  batch_size,
  learning_rate_adv,
  learning_rate_clf,
  lambda,
  verbose = TRUE,
  monitor = TRUE
)

Arguments

dsl

dataset_loader object for the classifier network

clf_model

net (nn_module): the classifier model, preferably already pretrained

adv_model

net (nn_module): the adversarial model, preferably already pretrained

clf_optimizer

optimizer for the classifier model, taken from pretraining

adv_optimizer

optimizer for the adversarial model, taken from pretraining

dev

device used for computation ("cuda" or "cpu")

sensitive_train

integer vector of the sensitive attribute used for training

sensitive_test

integer vector of the sensitive attribute used for testing

n_ep_pan

number of epochs for PAN training

batch_size

batch size used in the adversarial model's dataset_loader

learning_rate_adv

learning rate of the adversarial model

learning_rate_clf

learning rate of the classifier

lambda

parameter regulating the learning process (intuition: the bigger it is, the fairer the predictions and the worse the classifier's accuracy)

verbose

logical indicating whether to print the monitored outputs

monitor

logical indicating whether to monitor the learning process (monitoring tends to slow down training, but provides useful information for adjusting the parameters and the training process)
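The trade-off controlled by lambda can be written as the classifier's per-step objective. This page does not state the exact loss, so the following is a hedged sketch assuming the common adversarial-debiasing formulation: the classifier f with parameters θ_c produces predictions ŷ, the adversary g with parameters θ_a tries to recover the sensitive attribute s from ŷ, and the classifier minimizes its own loss while maximizing the adversary's.

$$
\hat{y} = f_{\theta_c}(x), \qquad \hat{s} = g_{\theta_a}(\hat{y}),
$$
$$
\min_{\theta_a} \; \mathcal{L}_{\text{adv}}(s, \hat{s}), \qquad
\min_{\theta_c} \; \mathcal{L}_{\text{clf}}(y, \hat{y}) - \lambda \, \mathcal{L}_{\text{adv}}(s, \hat{s})
$$

Under this assumption, lambda = 0 reduces to ordinary classifier training, and larger lambda puts more weight on making ŷ uninformative about s, which matches the intuition given for the lambda argument.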

Value

NULL if monitor is FALSE; a list of metrics if it is TRUE


ModelOriented/FairPAN documentation built on Dec. 17, 2021, 4:19 a.m.