# Packages ----------------------------------------------------------------

library(torch)
library(torchvision)
library(torchdatasets)
library(luz)

# Seed both the R and the torch random number generators.
set.seed(1)
torch_manual_seed(1)

# Datasets and loaders ----------------------------------------------------

dir <- "./dogs-vs-cats" # caching directory

ds <- torchdatasets::dogs_vs_cats_dataset(
  dir,
  download = TRUE,
  # Per-image preprocessing: convert to a tensor, resize to 224 x 224, then
  # normalise each of the 3 channels with mean 0.5 and standard deviation 0.5.
  transform = function(img) {
    as_tensor <- torchvision::transform_to_tensor(img)
    resized <- torchvision::transform_resize(as_tensor, size = c(224, 224))
    torchvision::transform_normalize(
      resized,
      mean = rep(0.5, 3),
      std = rep(0.5, 3)
    )
  },
  # Convert the label to double and shift it down by one.
  # NOTE(review): presumably maps labels 1/2 to 0/1 for the BCE loss -- confirm.
  target_transform = function(label) {
    as.double(label) - 1
  }
)

# Random 70 / 30 split: sampled indices go to training, the rest (in
# ascending order) to validation.
n_obs <- length(ds)
train_id <- sample.int(n_obs, size = 0.7 * n_obs)
valid_id <- setdiff(seq_len(n_obs), train_id)

train_ds <- dataset_subset(ds, indices = train_id)
valid_ds <- dataset_subset(ds, indices = valid_id)

train_dl <- dataloader(
  train_ds,
  batch_size = 64,
  shuffle = TRUE,
  num_workers = 4
)
valid_dl <- dataloader(
  valid_ds,
  batch_size = 64,
  num_workers = 4
)

# Building the network ---------------------------------------------------

# Pretrained AlexNet whose existing parameters are frozen and whose
# classifier is replaced by a new head with `num_classes` outputs.
net <- torch::nn_module(
  initialize = function(num_classes) {
    self$model <- model_alexnet(pretrained = TRUE)

    # Freeze everything loaded so far; the classifier attached below is
    # created afterwards and is therefore not touched by this loop.
    for (p in self$parameters) {
      p$requires_grad_(FALSE)
    }

    self$model$classifier <- nn_sequential(
      nn_dropout(0.5),
      nn_linear(9216, 512),
      nn_relu(),
      nn_linear(512, 256),
      nn_relu(),
      nn_linear(256, num_classes)
    )
  },
  forward = function(x) {
    # Keep only the first output column, dropping the class dimension.
    logits <- self$model(x)
    logits[, 1]
  }
)

# Train -------------------------------------------------------------------

# Attach loss, optimizer and metric, then set the module and optimizer
# hyper-parameters.
model_spec <- net %>%
  setup(
    loss = nn_bce_with_logits_loss(),
    optimizer = optim_adam,
    metrics = list(
      luz_metric_binary_accuracy_with_logits()
    )
  ) %>%
  set_hparams(num_classes = 1) %>%
  set_opt_hparams(lr = 0.01)

fitted <- fit(
  model_spec,
  train_dl,
  epochs = 5,
  valid_data = valid_dl,
  verbose = TRUE
)

# Make predictions --------------------------------------------------------

# The network outputs logits; apply the sigmoid to the predictions.
valid_logits <- predict(fitted, valid_dl)
preds <- torch_sigmoid(valid_logits)

# Serialization -----------------------------------------------------------

luz_save(fitted, "model-dogs-and-cats.pt")
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.