R/prepareNetwork.R

#' Build a VGG19-based style-transfer network
#'
#' Copies the leading layers of torchvision's pretrained VGG19 \code{features}
#' stack into a new \code{nn.Sequential} and inserts content- and style-loss
#' modules directly after the requested layer indices.
#'
#' @param tvWeight Weight passed to \code{TVLoss}. A \code{TVLoss} module is
#'   placed first in the network only when \code{tvWeight > 0}.
#' @param poolingMethod \code{"avg"} replaces the copied VGG19 max-pooling
#'   layers with \code{nn.AvgPool2d} modules that reuse the original
#'   \code{kernel_size}, \code{stride}, \code{padding} and \code{ceil_mode};
#'   any other value keeps the original pooling layers.
#' @param normalizeGradients Passed to every \code{ContentLoss} and
#'   \code{StyleLoss} module.
#' @param contentWeight Weight passed to each \code{ContentLoss} module.
#' @param styleWeight Weight passed to each \code{StyleLoss} module.
#' @param dtype Tensor type string handed to \code{$type()} on the VGG19
#'   model, on every inserted module and on the final network.
#' @param contentLayerIndices 0-based indices into VGG19 \code{features} after
#'   which a \code{ContentLoss} module is inserted. Defaults to \code{11}.
#' @param styleLayerIndices 0-based indices into VGG19 \code{features} after
#'   which a \code{StyleLoss} module is inserted. Defaults to
#'   \code{c(1, 3, 6, 8, 13)}. If an index appears in both
#'   \code{contentLayerIndices} and \code{styleLayerIndices}, the content
#'   module is inserted first, followed by the style module.
#' @return A list with elements \code{net} (the assembled
#'   \code{nn.Sequential}), \code{contentLayers} (list of the inserted
#'   \code{ContentLoss} modules, in network order) and \code{styleLayers}
#'   (list of the inserted \code{StyleLoss} modules, in network order).
#' @export
#' @import reticulate
#' @references \insertRef{Simonyan2014}{neuralstyleR}
#' @references \insertRef{Gatys2016}{neuralstyleR}
#'
prepareNetwork <- function(tvWeight = 1e-3, poolingMethod = "avg",
                           normalizeGradients = FALSE, contentWeight = 5e0,
                           styleWeight = 1e2, dtype = "torch.FloatTensor",
                           contentLayerIndices = 11L,
                           styleLayerIndices = c(1L, 3L, 6L, 8L, 13L)) {
    # Max-pooling positions in torchvision's VGG19 `features` stack
    # (configuration "E", no batch norm). With the default loss indices only
    # 4 and 9 are reached, matching the original hard-coded behaviour.
    vggPoolIndices <- c(4L, 9L, 18L, 27L, 36L)

    # Validate before importing Python modules so bad input fails fast.
    contentLayerIndices <- as.integer(contentLayerIndices)
    styleLayerIndices <- as.integer(styleLayerIndices)
    lossIndices <- c(contentLayerIndices, styleLayerIndices)
    if (length(lossIndices) == 0L || anyNA(lossIndices) ||
        any(lossIndices < 0L)) {
        stop("contentLayerIndices and styleLayerIndices must together contain ",
             "at least one non-negative layer index", call. = FALSE)
    }
    # Copy VGG19 layers only up to the deepest loss layer (0:13 by default).
    lastLayer <- max(lossIndices)
    # Scalar flag, hoisted out of the loop and combined with `&&` below.
    useAvgPool <- poolingMethod == "avg"

    nn <- import("torch.nn")
    torchvision <- import("torchvision")

    # NOTE(review): newer torchvision releases deprecate `pretrained =` in
    # favour of `weights =` -- confirm against the installed version.
    cnn <- torchvision$models$vgg19(pretrained = TRUE)$type(dtype)
    features <- cnn$features

    net <- nn$Sequential()
    if (tvWeight > 0) {
        net$add_module("TVLoss", TVLoss(tvWeight = tvWeight)$type(dtype))
    }

    contentLayers <- list()
    styleLayers <- list()
    for (i in seq.int(0L, lastLayer)) {
        # Assumes reticulate's `[` forwards to Python `__getitem__`, so `i` is
        # a 0-based index; `i` stays an R integer so Python receives an int.
        layer <- features[i]

        if (useAvgPool && i %in% vggPoolIndices) {
            message("Replacing max pooling at layer ", i,
                    " with average pooling")
            avgMod <- nn$AvgPool2d(kernel_size = layer$kernel_size,
                                   stride = layer$stride,
                                   padding = layer$padding,
                                   ceil_mode = layer$ceil_mode)
            net$add_module(as.character(i), avgMod$type(dtype))
        } else {
            net$add_module(as.character(i), layer)
        }

        if (i %in% contentLayerIndices) {
            message("Setting up content layer")
            contentLayer <- ContentLoss(contentWeight,
                                        normalizeGradients)$type(dtype)
            contentLayers[[length(contentLayers) + 1L]] <- contentLayer
            net$add_module(paste("contentlayer", i), contentLayer)
        }
        if (i %in% styleLayerIndices) {
            message("Setting up style layer")
            styleLayer <- StyleLoss(styleWeight,
                                    normalizeGradients)$type(dtype)
            styleLayers[[length(styleLayers) + 1L]] <- styleLayer
            # NOTE: "styletlayer" (sic) is kept so existing module names,
            # and therefore state_dict keys, stay unchanged.
            net$add_module(paste("styletlayer", i), styleLayer)
        }
    }

    net$type(dtype)
    list(net = net, contentLayers = contentLayers, styleLayers = styleLayers)
}
David-J-R/neuralstyleR documentation built on May 8, 2019, 1:54 p.m.