make_preds_prob: Makes probabilistic predictions


View source: R/predict.R

Description

Makes probabilistic predictions from a neural network trained for a classification task.

Usage

make_preds_prob(model, test_ds, dev)

Arguments

model

neural network classification model

test_ds

torch dataset object holding the test data to make predictions on

dev

device used for calculations ("cpu", or a GPU device such as "cuda")

Value

numeric vector of predicted probabilities
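This page does not describe how make_preds_prob computes its probabilities. As a minimal sketch, assuming a binary classifier whose network outputs one raw score (logit) per class, a probability vector of this kind can be obtained with a row-wise softmax, keeping the column for the positive class. The helper softmax_rows, the made-up logits, and the 0.5 cut-off are hypothetical illustrations in base R, not part of fairpan.

```r
# Hypothetical helper (not part of fairpan): row-wise softmax that turns
# a matrix of raw class scores (logits) into class probabilities.
softmax_rows <- function(logits) {
  # subtract each row's maximum for numerical stability
  shifted <- logits - apply(logits, 1, max)
  e <- exp(shifted)
  # divide each row by its own sum so every row sums to 1
  e / rowSums(e)
}

# Made-up logits for 3 observations (rows) and 2 classes (columns)
logits <- matrix(c(2, 0,
                   0, 2,
                   1, 1), nrow = 3, byrow = TRUE)

probs <- softmax_rows(logits)

# Probability of the second (assumed positive) class per observation,
# i.e. a numeric vector like the one described under Value
p_pos <- probs[, 2]

# Optional: hard 0/1 labels using an assumed 0.5 threshold
labels <- as.integer(p_pos > 0.5)
```

Thresholding is shown only as one common way to turn such probabilities into class labels; the appropriate cut-off depends on the application.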

Examples

## Not run: 
library(torch)
library(fairpan)

dev       <- "cpu"

# presaved torch model
model1    <- torch_load(system.file("extdata", "preclf", package = "fairpan"))

# presaved output of the preprocess function
processed <- torch_load(system.file("extdata", "processed", package = "fairpan"))

# wrap the processed data into torch datasets (dsl$test_ds is used below)
dsl       <- dataset_loader(processed$train_x, processed$train_y,
                            processed$test_x, processed$test_y,
                            batch_size = 5, dev = dev)

# vector of predicted probabilities for the test set
preds1    <- make_preds_prob(model1, dsl$test_ds, dev)

## End(Not run)

ModelOriented/FairPAN documentation built on Dec. 17, 2021, 4:19 a.m.