# R/ContentLoss.R

#' ContentLoss
#'
#' Build a PyTorch content-loss module (through \pkg{reticulate}) for neural
#' style transfer.
#'
#' The returned module is a pass-through layer: \code{forward(input)} always
#' returns \code{input} unchanged. Its side effect depends on the Python
#' attribute \code{mode}:
#' \itemize{
#'   \item \code{'capture'}: store a detached copy of the input as the
#'     content target.
#'   \item \code{'loss'}: set \code{loss} to the mean squared error between
#'     the input and the stored target, multiplied by \code{strength}.
#'   \item any other value (the initial value is \code{'none'}): do nothing.
#' }
#' A manual \code{backward(input, gradOutput)} method is also provided. In
#' \code{'loss'} mode it returns the (optionally L1-normalized) MSE gradient
#' scaled by \code{strength}, plus \code{gradOutput}. Otherwise it returns a
#' copy of \code{gradOutput}.
#'
#' @param contentWeight Numeric scalar; stored as the module's
#'   \code{strength} and used to weight both the loss and its gradient.
#' @param normalizeGradients Scalar flag; when true, \code{backward()}
#'   divides the content gradient by its L1 norm (plus \code{1e-8}) before
#'   weighting.
#' @return A reticulate reference to a Python \code{ContentLoss} object
#'   (a \code{torch.nn.Module} subclass).
#'
#' @export
#' @import reticulate
#' @references \insertRef{Gatys2016}{neuralstyleR}
#'
ContentLoss <- function(contentWeight = 1e-3, normalizeGradients = FALSE) {
  stopifnot(
    is.numeric(contentWeight),
    length(contentWeight) == 1L,
    length(normalizeGradients) == 1L
  )

  # With its default local = FALSE, py_run_string() defines the class in
  # Python's main module and returns that module. The class is therefore
  # (re)defined on every call to this function.
  py_main <- py_run_string(
    "
import torch.nn as nn
import torch
class ContentLoss(nn.Module):
        def __init__(self, strength, normalizeGradients):
            super(ContentLoss, self).__init__()
            self.strength = strength
            self.target = torch.Tensor()
            self.normalize = normalizeGradients
            self.loss = 0
            self.crit = nn.MSELoss()
            self.mode = 'none'
        def forward(self, input):
            if self.mode == 'loss':
                self.loss = self.crit(input, self.target) * self.strength
            elif self.mode == 'capture':
                # Detach so the stored target carries no autograd history.
                self.target = input.detach().clone()
            self.output = input
            return self.output
        def backward(self, input, gradOutput):
            if self.mode != 'loss':
                self.gradInput = gradOutput.clone()
                return self.gradInput
            if input.nelement() == self.target.nelement():
                # Gradient of MSELoss with mean reduction: 2 * (x - t) / N.
                # PyTorch tensor ops are out-of-place, so results are
                # reassigned rather than relying on in-place mutation.
                grad = (input - self.target) * (2.0 / input.nelement())
                if self.normalize:
                    grad = grad / (torch.norm(grad, 1) + 1e-8)
                self.gradInput = grad * self.strength + gradOutput
            else:
                # No comparable target: contribute nothing, pass gradient on.
                self.gradInput = gradOutput.clone()
            return self.gradInput
"
  )

  py_main$ContentLoss(contentWeight, normalizeGradients)
}
# David-J-R/neuralstyleR documentation built on May 8, 2019, 1:54 p.m.