plot.citocnn: Plot a fitted CNN model

View source: R/cnn.R

Description

This function plots the architecture of a convolutional neural network (CNN) model fitted with the cnn function.

Usage

## S3 method for class 'citocnn'
plot(x, ...)

Arguments

x

A model created by cnn.

...

Additional arguments (currently not used).

Value

The original model object x, returned invisibly.
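plot.citocnn is an S3 method, so plot(x) dispatches to it whenever x has class 'citocnn'. Because it returns x invisibly, a bare plot(cnn.fit) call draws the figure without printing the model at the console. The self-contained sketch below shows that general dispatch-and-return pattern using a hypothetical toy class (toycnn) with a hypothetical layers field. It is not cito's implementation, and its box diagram is only an illustration of the idea.

```r
## Toy stand-in for a fitted model. The class name "toycnn" and the
## 'layers' field are hypothetical and unrelated to cito's internals.
toy <- structure(
  list(layers = c("conv(5)", "maxPool()", "conv(5)", "maxPool()", "linear(10)")),
  class = "toycnn"
)

## S3 method: plot() dispatches here because class(toy) is "toycnn"
plot.toycnn <- function(x, ...) {
  n <- length(x$layers)
  graphics::plot.new()
  graphics::plot.window(xlim = c(0, 1), ylim = c(0, n + 1))
  ## One labelled box per layer, first layer at the top
  ys <- rev(seq_len(n))
  graphics::rect(0.2, ys - 0.35, 0.8, ys + 0.35)
  graphics::text(0.5, ys, x$layers)
  invisible(x)  # hand the object back without auto-printing it
}

pdf(NULL)          # draw to a null device so the sketch runs headless
plot(toy)          # draws the figure; the returned object is not printed
out <- plot(toy)   # the same object can still be captured
invisible(dev.off())
```

Since the object comes back unchanged, capturing it (as in out <- plot(toy)) leaves the model exactly as it was.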

Examples


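if (torch::torch_is_installed()) {
  library(cito)

  set.seed(222)

  ## Use the GPU if one is available, otherwise fall back to the CPU
  device <- ifelse(torch::cuda_is_available(), "cuda", "cpu")

  ## Data
  ## simulate_shapes() is an internal cito helper (hence :::); it returns
  ## simulated images in $data and their class labels in $labels
  shapes <- cito:::simulate_shapes(320, 28)
  X <- shapes$data
  Y <- shapes$labels

  ## Architecture: two convolution + max-pooling blocks, then a linear layer
  architecture <- create_architecture(conv(5), maxPool(), conv(5), maxPool(), linear(10))

  ## Build and train network (10% of the data is held out for validation)
  cnn.fit <- cnn(X, Y, architecture, loss = "softmax", epochs = 50,
                 validation = 0.1, lr = 0.05, device = device)

  ## Structure of Neural Network
  plot(cnn.fit)
}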


citoverse/cito documentation built on Jan. 16, 2025, 11:49 p.m.