predict.citodnn    R Documentation
Predict from a fitted dnn model

Usage:

## S3 method for class 'citodnn'
predict(
object,
newdata = NULL,
type = c("link", "response", "class"),
device = c("cpu", "cuda", "mps"),
reduce = c("mean", "median", "none"),
...
)
## S3 method for class 'citodnnBootstrap'
predict(
object,
newdata = NULL,
type = c("link", "response", "class"),
device = c("cpu", "cuda", "mps"),
reduce = c("mean", "median", "none"),
...
)
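
In the signatures above, type, device and reduce are given as vectors of allowed values. R's usual idiom for such arguments is match.arg(), which resolves an unset argument to the first element and accepts unambiguous prefixes. The sketch below assumes predict.citodnn() follows this standard idiom. It reproduces only that argument handling in a hypothetical helper, resolve_args(), and does not call cito.

```r
# Hypothetical helper with the same choice vectors as predict.citodnn();
# it only shows how match.arg() resolves them, it does not predict anything.
resolve_args <- function(type = c("link", "response", "class"),
                         device = c("cpu", "cuda", "mps"),
                         reduce = c("mean", "median", "none")) {
  list(type   = match.arg(type),    # unset -> first element, "link"
       device = match.arg(device),  # unset -> "cpu"
       reduce = match.arg(reduce))  # unset -> "mean"
}

str(resolve_args())               # the defaults
str(resolve_args(type = "resp"))  # a unique prefix is expanded to "response"
```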
Arguments:

object    a model created by dnn()
newdata   new data for predictions
type      type of predictions: "link" (the default) returns predictions on the
          scale of the linear predictor, "response" on the scale of the
          response, and "class" returns predicted classes (classification
          tasks only)
device    device on which the predictions are computed: "cpu" (the default),
          "cuda", or "mps"
reduce    how predictions from a bootstrapped model (class citodnnBootstrap)
          are aggregated across bootstrap samples: "mean" (the default),
          "median", or "none" (no aggregation)
...       additional arguments
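
For a classification model with a softmax-type output layer, the three type values are related in the usual way. "link" gives the raw network outputs (logits), "response" applies the softmax so that each row becomes a vector of class probabilities, and "class" picks the most probable class. The base-R sketch below shows this relationship on a hand-made logit matrix. It assumes a softmax output, as with a multi-class classification loss, and does not call cito. The helper softmax_rows() and the numbers are purely illustrative.

```r
# Hypothetical "link"-scale output: 2 observations x 3 classes (logits)
link <- matrix(c(2.0, 0.5, -1.0,
                 0.1, 0.3,  1.2),
               nrow = 2, byrow = TRUE,
               dimnames = list(NULL, c("setosa", "versicolor", "virginica")))

# "response": row-wise softmax turns logits into class probabilities
softmax_rows <- function(z) {
  e <- exp(z - apply(z, 1, max))  # subtract each row's max for numerical stability
  e / rowSums(e)                  # normalise so every row sums to 1
}
response <- softmax_rows(link)

# "class": label of the most probable class in each row
class_pred <- factor(colnames(link)[max.col(response, ties.method = "first")],
                     levels = colnames(link))
print(class_pred)
```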

Value:

A matrix of predictions.

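
For a bootstrapped model (class citodnnBootstrap), each bootstrap refit produces its own predictions, and reduce decides how these are combined into a single prediction matrix. The sketch below mimics that aggregation in base R on simulated predictions. The array layout (bootstrap sample by observation by output) and the helper reduce_preds() are hypothetical and need not match how cito stores unreduced predictions.

```r
# Simulated, unreduced bootstrap predictions:
# 3 bootstrap samples x 4 observations x 1 output (layout is an assumption)
set.seed(1)
boot_preds <- array(rnorm(3 * 4), dim = c(3, 4, 1))

# Hypothetical helper: aggregate over the first (bootstrap) dimension
reduce_preds <- function(p, reduce = c("mean", "median", "none")) {
  reduce <- match.arg(reduce)
  if (reduce == "none") return(p)  # keep every bootstrap prediction
  f <- if (reduce == "mean") mean else stats::median
  apply(p, c(2, 3), f)             # observations x outputs matrix
}

reduce_preds(boot_preds, "mean")    # 4 x 1 matrix of bootstrap means
reduce_preds(boot_preds, "median")  # 4 x 1 matrix of bootstrap medians
```

The median is less sensitive than the mean to a single outlying bootstrap refit, while "none" keeps the full set of bootstrap predictions, for example to compute intervals.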
Examples:

if (torch::torch_is_installed()) {
  library(cito)

  set.seed(222)
  # Hold out 25 randomly chosen rows of iris as a validation set
  validation_set <- sample(nrow(datasets::iris), 25)

  # Build and train the network on the remaining rows
  nn.fit <- dnn(Sepal.Length ~ ., data = datasets::iris[-validation_set, ])

  # Use the model on the validation set
  predictions <- predict(nn.fit, datasets::iris[validation_set, ])

  # Scatterplot of observed vs. predicted values
  plot(datasets::iris[validation_set, ]$Sepal.Length, predictions)

  # Mean absolute error (MAE)
  mean(abs(predictions - datasets::iris[validation_set, ]$Sepal.Length))
}