predict.citocnn: Predict with a fitted CNN model

View source: R/cnn.R

predict.citocnn    R Documentation

Predict with a fitted CNN model

Description

This function generates predictions from a Convolutional Neural Network (CNN) model that was created using the cnn function.

Usage

## S3 method for class 'citocnn'
predict(
  object,
  newdata = NULL,
  type = c("link", "response", "class"),
  device = NULL,
  batchsize = NULL,
  ...
)

Arguments

object

A model created by cnn.

newdata

A multidimensional array containing the new data for which predictions are to be made. The dimensions of newdata must match those of the training data, except for the first dimension, which indexes the samples. If NULL, predictions are made for the data the model was trained on.
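Because the first dimension indexes samples, subsetting with the default drop = TRUE silently removes length-one dimensions, so a single sample no longer has the training layout. The base-R sketch below uses a hypothetical array of 10 samples, 1 channel and 28 x 28 pixels (an assumed layout) to show why drop = FALSE is needed.

```r
# Hypothetical training-style array: samples x channels x height x width
X <- array(runif(10 * 1 * 28 * 28), dim = c(10, 1, 28, 28))

# Default drop = TRUE collapses the sample and channel dimensions (both length 1)
dim(X[1, , , ])    # 28 28

# drop = FALSE keeps all four dimensions
newdata <- X[1, , , , drop = FALSE]
dim(newdata)       # 1 1 28 28

# All dimensions except the first (samples) agree with the training data
stopifnot(identical(dim(newdata)[-1], dim(X)[-1]))
```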

type

A character string specifying the type of prediction to be made. Options are:

  • "link": Scale of the linear predictor.

  • "response": Scale of the response.

  • "class": The predicted class labels (for classification tasks).
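For a classification model trained with loss = "softmax" (as in the example below), the three scales are usually related as follows: "link" gives one raw score per class, "response" turns each row of scores into class probabilities with the softmax function, and "class" picks the most probable level. The base-R sketch below illustrates this on a hypothetical 3 x 3 score matrix with hypothetical levels "A", "B" and "C"; it describes the general relationship under that assumption and is not cito's internal code.

```r
# Hypothetical link-scale output: 3 samples (rows) x 3 classes (columns)
link <- matrix(c( 2.0, 0.5, -1.0,
                  0.1, 0.2,  3.0,
                 -0.5, 1.5,  0.0),
               nrow = 3, byrow = TRUE)
class_levels <- c("A", "B", "C")  # hypothetical class labels

# Response scale: row-wise softmax turns scores into probabilities
softmax_rows <- function(z) {
  z <- z - apply(z, 1, max)  # subtract each row's maximum for numerical stability
  e <- exp(z)
  e / rowSums(e)             # each row now sums to 1
}
response <- softmax_rows(link)

# Class: the level with the highest probability in each row
pred_class <- factor(class_levels[max.col(response, ties.method = "first")],
                     levels = class_levels)
pred_class
```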

device

Device to be used for making predictions. Options are "cpu", "cuda", and "mps". Default is "cpu".

batchsize

An integer specifying the number of samples to be processed at the same time. If NULL, the function uses the same batchsize that was used when training the model. Default is NULL.
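Conceptually, batchsize only controls how the samples in newdata are grouped into forward passes: a smaller value lowers peak memory use at the cost of more passes. The sketch below uses a hypothetical helper, batch_indices(), to show how 10 samples would be partitioned with batchsize = 4; it illustrates the idea and is not cito's implementation.

```r
# Hypothetical helper: sample indices processed in each batch
batch_indices <- function(n_samples, batchsize) {
  idx <- seq_len(n_samples)
  split(idx, ceiling(idx / batchsize))  # consecutive groups of at most batchsize
}

batches <- batch_indices(n_samples = 10, batchsize = 4)
lengths(batches)  # the last batch holds the remaining samples
```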

...

Additional arguments (currently not used).

Value

A matrix of predictions. If type is "class", a factor of predicted class labels is returned.

Examples


if (torch::torch_is_installed()) {
  library(cito)

  set.seed(222)

  ## Use the GPU if one is available
  device <- if (torch::cuda_is_available()) "cuda" else "cpu"

  ## Data
  shapes <- cito:::simulate_shapes(320, 28)
  X <- shapes$data
  Y <- shapes$labels

  ## Architecture
  architecture <- create_architecture(conv(5), maxPool(), conv(5), maxPool(), linear(10))

  ## Build and train network
  cnn.fit <- cnn(X, Y, architecture, loss = "softmax", epochs = 50,
                 validation = 0.1, lr = 0.05, device = device)

  ## Get predictions of the validation set
  ## (drop = FALSE keeps the sample dimension)
  valid <- cnn.fit$data$validation
  predictions <- predict(cnn.fit, newdata = X[valid, , , , drop = FALSE], type = "class")

  ## Classification accuracy: share of correctly predicted labels
  accuracy <- mean(predictions == Y[valid])
}
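Beyond overall accuracy, a confusion matrix shows which classes are mistaken for each other. The sketch below applies base R's table() to short hypothetical label vectors standing in for predictions and Y[valid]; both are assumed to be factors with the same levels.

```r
# Hypothetical stand-ins for Y[valid] (observed) and predictions (predicted)
lv <- c("A", "B", "C")
observed  <- factor(c("A", "A", "B", "C", "C", "B"), levels = lv)
predicted <- factor(c("A", "B", "B", "C", "A", "B"), levels = lv)

# Rows: predicted class; columns: observed class
confusion <- table(predicted = predicted, observed = observed)
confusion

# Correct predictions lie on the diagonal
accuracy <- sum(diag(confusion)) / sum(confusion)
accuracy
```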


citoverse/cito documentation built on Jan. 16, 2025, 11:49 p.m.