predict.citocnn | R Documentation |
This function generates predictions from a Convolutional Neural Network (CNN) model that was created using the cnn
function.
## S3 method for class 'citocnn'
predict(
object,
newdata = NULL,
type = c("link", "response", "class"),
device = NULL,
batchsize = NULL,
...
)
object |
a model created by cnn. |
newdata |
A multidimensional array representing the new data for which predictions are to be made. The dimensions of |
type |
A character string specifying the type of prediction to be made. Options are "link", "response", and "class". |
device |
Device to be used for making predictions. Options are "cpu", "cuda", and "mps". Default is "cpu". |
batchsize |
An integer specifying the number of samples to be processed at the same time. If |
... |
Additional arguments (currently not used). |
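The three values of type can be related with a small base-R sketch. It assumes that, for a model trained with a softmax loss, "link" corresponds to raw scores (logits), "response" to the softmax probabilities of those scores, and "class" to the label with the highest score. The logits and the class names "a", "b", "c" are made up for illustration; this is not cito's internal code.

```r
# Toy logits for 3 samples and 3 hypothetical classes (not cito output)
link <- matrix(c( 2.0, 0.5, -1.0,
                  0.1, 1.2,  0.3,
                 -0.5, 0.0,  2.5),
               nrow = 3, byrow = TRUE,
               dimnames = list(NULL, c("a", "b", "c")))

# Row-wise softmax; subtracting the row maximum avoids overflow in exp()
softmax <- function(z) {
  e <- exp(z - apply(z, 1, max))
  e / rowSums(e)
}

# "response"-like output: one row of probabilities per sample
response <- softmax(link)

# "class"-like output: factor holding the label with the largest score
pred_class <- factor(colnames(link)[max.col(link, ties.method = "first")],
                     levels = colnames(link))
```

Because softmax is monotone within a row, taking the largest logit and taking the largest probability select the same class.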
A matrix of predictions. If type is "class", a factor of predicted class labels is returned.
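A factor of predicted labels can be compared with the true labels using base R alone. The sketch below uses hypothetical toy vectors (truth and predicted are not produced by cito) to build a confusion matrix and derive the accuracy from it.

```r
# Hypothetical true labels and predictions (toy values for illustration only)
lvls <- c("a", "b", "c")
truth     <- factor(c("a", "b", "a", "c", "b", "c"), levels = lvls)
predicted <- factor(c("a", "b", "c", "c", "b", "a"), levels = lvls)

# Rows are true classes, columns are predicted classes
confusion <- table(truth = truth, predicted = predicted)

# Accuracy is the share of samples on the diagonal
accuracy <- sum(diag(confusion)) / sum(confusion)
```

Giving both factors the same levels keeps the table square even when a class is never predicted.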
if(torch::torch_is_installed()){
library(cito)
set.seed(222)
device <- ifelse(torch::cuda_is_available(), "cuda", "cpu")
## Data
shapes <- cito:::simulate_shapes(320, 28)
X <- shapes$data
Y <- shapes$labels
## Architecture
architecture <- create_architecture(conv(5), maxPool(), conv(5), maxPool(), linear(10))
## Build and train network
cnn.fit <- cnn(X, Y, architecture, loss = "softmax", epochs = 50, validation = 0.1, lr = 0.05, device=device)
## Get predictions of the validation set
valid <- cnn.fit$data$validation
predictions <- predict(cnn.fit, newdata = X[valid,,,,drop=FALSE], type="class")
## Classification accuracy
accuracy <- sum(predictions == Y[valid])/length(valid)
}
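To illustrate what processing samples in batches means for the batchsize argument, the following base-R sketch splits a 4-dimensional array along its first (sample) dimension and combines the per-batch results. The helper predict_in_batches and the stand-in toy_predict are hypothetical names introduced here; this is only a sketch of the general idea, not how cito implements batching.

```r
# Hypothetical helper: apply `predict_fn` to chunks along the first dimension
predict_in_batches <- function(x, predict_fn, batchsize) {
  n <- dim(x)[1]
  # Group sample indices into consecutive batches of at most `batchsize`
  batches <- split(seq_len(n), ceiling(seq_len(n) / batchsize))
  results <- lapply(batches, function(idx) {
    predict_fn(x[idx, , , , drop = FALSE])  # drop = FALSE keeps all 4 dimensions
  })
  # Stack the per-batch matrices back into one matrix, in sample order
  do.call(rbind, unname(results))
}

# Stand-in "model": mean value per sample, returned as a one-column matrix
toy_predict <- function(x) matrix(apply(x, 1, mean), ncol = 1)

# Toy input: 10 samples, 1 channel, 4 x 4 pixels
x <- array(seq_len(10 * 1 * 4 * 4), dim = c(10, 1, 4, 4))

batched <- predict_in_batches(x, toy_predict, batchsize = 3)
full    <- toy_predict(x)
```

Here batched and full hold the same values; batching changes only how many samples are held in memory at once, which matters for large inputs or limited device memory.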