predict.nn: Neural network prediction


Description

Computes predictions from an artificial neural network of class nn, as produced by neuralnet().

Usage

## S3 method for class 'nn'
predict(object, newdata, rep = 1, all.units = FALSE, ...)

Arguments

object

Neural network of class nn.

newdata

New data, given as a data.frame or matrix.

rep

Integer indicating which repetition of the trained neural network should be used.

all.units

If TRUE, return the output of all units instead of the final output only.

...

Further arguments passed to or from other methods.
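The roles of rep and all.units can be illustrated with a base-R sketch of the forward pass a prediction performs. This is not the neuralnet source: the function forward_sketch and its parameters are hypothetical, and the sketch assumes that each repetition stores one weight matrix per layer with the intercept weights in the first row, a logistic activation for hidden units, and no activation on the output layer when linear.output = TRUE.

```r
# Hypothetical sketch, NOT the neuralnet implementation.
# Assumption: weights_by_rep[[rep]] is a list of weight matrices, one per layer,
# each with the intercept (bias) weights in its first row.
logistic <- function(x) 1 / (1 + exp(-x))

forward_sketch <- function(weights_by_rep, newdata, rep = 1, all.units = FALSE,
                           act.fct = logistic, linear.output = TRUE) {
  weights <- weights_by_rep[[rep]]          # select one repetition's weight set
  n_layers <- length(weights)
  out <- as.matrix(newdata)
  units <- list(out)                        # start with the input layer
  for (i in seq_len(n_layers)) {
    out <- cbind(1, out) %*% weights[[i]]   # prepend a column of 1s for the intercept
    is_output <- i == n_layers
    # Hidden layers always use act.fct; the output layer only if linear.output = FALSE
    if (!is_output || !linear.output) {
      out <- act.fct(out)
    }
    units[[i + 1]] <- out                   # keep every layer's output
  }
  if (all.units) units else out
}

# Usage: 2 inputs -> 3 hidden units -> 1 output, with arbitrary random weights
set.seed(1)
w <- list(list(matrix(rnorm(3 * 3), nrow = 3),   # (1 + 2 inputs) x 3 hidden units
               matrix(rnorm(4 * 1), nrow = 4)))  # (1 + 3 hidden units) x 1 output
x <- iris[1:5, c("Petal.Length", "Petal.Width")]
p <- forward_sketch(w, x, rep = 1, linear.output = FALSE)
u <- forward_sketch(w, x, rep = 1, all.units = TRUE, linear.output = FALSE)
```

In this sketch, p is a 5 x 1 matrix (one column per output unit), while u is a list holding the inputs followed by the output of each layer.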

Value

A matrix of predictions with one column per output unit. If all.units = TRUE, a list of matrices containing the output of each unit.
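Because each column holds one output unit, class labels are typically obtained by thresholding a single column or by taking the row-wise maximum across several columns, as the examples below do. The following sketch uses made-up prediction matrices (not output from a trained network) and base R's max.col(), which with ties.method = "first" matches apply(pred, 1, which.max). The label vector assumes the columns follow the order of the left-hand side of the model formula.

```r
# Illustrative prediction matrices with made-up values
pred_binary <- matrix(c(0.92, 0.10, 0.55), ncol = 1)
pred_multi <- matrix(c(0.90, 0.05, 0.05,
                       0.10, 0.70, 0.20,
                       0.05, 0.30, 0.65), ncol = 3, byrow = TRUE)

# One output unit: threshold the single column at 0.5
binary_class <- pred_binary[, 1] > 0.5

# Several output units: index of the column with the largest output in each row
multi_class <- max.col(pred_multi, ties.method = "first")

# Assumed column order, matching the formula's left-hand side
labels <- c("setosa", "versicolor", "virginica")[multi_class]
```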

Author(s)

Marvin N. Wright

Examples

library(neuralnet)

# Split data into training (2/3) and test (1/3) sets; set a seed for reproducibility
set.seed(42)
train_idx <- sample(nrow(iris), 2/3 * nrow(iris))
iris_train <- iris[train_idx, ]
iris_test <- iris[-train_idx, ]

# Binary classification
nn <- neuralnet(Species == "setosa" ~ Petal.Length + Petal.Width, iris_train, linear.output = FALSE)
pred <- predict(nn, iris_test)
table(iris_test$Species == "setosa", pred[, 1] > 0.5)

# Multiclass classification
nn <- neuralnet((Species == "setosa") + (Species == "versicolor") + (Species == "virginica")
                 ~ Petal.Length + Petal.Width, iris_train, linear.output = FALSE)
pred <- predict(nn, iris_test)
table(iris_test$Species, apply(pred, 1, which.max))

neuralnet documentation built on May 2, 2019, 9:17 a.m.