# R/StyleLoss.R

#' StyleLoss
#'
#' Build a style-loss layer for neural style transfer. The layer is a Python
#' \code{torch.nn.Module} subclass, defined inline through reticulate, that
#' compares a Gram matrix of the activations passing through it (computed by
#' \code{GramMatrix()}) against a captured target Gram matrix. The Gram matrix
#' is divided by the number of elements of the input.
#'
#' The returned module has a \code{mode} field (initially \code{"none"}):
#' \itemize{
#'   \item \code{"capture"}: \code{forward()} stores the Gram matrix as the
#'     target. If \code{blend_weight} is set, the target is the Gram matrix
#'     scaled by it, or that scaled matrix is added to an existing target.
#'   \item \code{"loss"}: \code{forward()} sets \code{loss} to
#'     \code{styleWeight} times the mean squared error against the target, and
#'     \code{backward()} computes the corresponding input gradient.
#'   \item any other value: input and gradients pass through unchanged.
#' }
#' \code{forward()} always returns its input unchanged.
#'
#' @param styleWeight Numeric multiplier applied to the loss and its gradient.
#' @param normalizeGradients Logical; if \code{TRUE}, \code{backward()} divides
#'   the gradient by its L1 norm (plus \code{1e-8}) before scaling by
#'   \code{styleWeight}.
#' @return A Python \code{StyleLoss} module object (reticulate reference).
#' @export
#' @import reticulate
#' @references \insertRef{Gatys2016}{neuralstyleR}
#'
StyleLoss <- function(styleWeight = 1e2, normalizeGradients = FALSE) {
  # NOTE(review): this (re)defines the StyleLoss class in Python's __main__
  # on every call; harmless, but could be cached if construction gets hot.
  #
  # The Python code below was ported from Lua Torch, where :div()/:mul()/
  # :add() mutate in place. PyTorch's Tensor.div/mul/add return NEW tensors,
  # so every result is reassigned explicitly rather than discarded.
  style_loss_main <- py_run_string(
    "
import torch.nn as nn
import torch
class StyleLoss(nn.Module):
    def __init__(self, strength, normalizeGradients, gram):
        super(StyleLoss, self).__init__()
        self.strength = strength
        self.target = torch.Tensor()
        self.normalize = normalizeGradients
        self.loss = 0
        self.crit = nn.MSELoss()
        self.mode = 'none'
        self.gram = gram
        self.blend_weight = None
        self.G = None
    def forward(self, input):
        # Gram matrix normalised by the number of input elements.
        self.G = self.gram.forward(input).div(input.nelement())
        if self.mode == 'loss':
            self.loss = self.strength * self.crit(self.G, self.target)
        elif self.mode == 'capture':
            if self.blend_weight is None:
                self.target = self.G.clone()
            elif self.target.nelement() == 0:
                self.target = self.G.clone().mul(self.blend_weight)
            else:
                # Accumulate: target += blend_weight * G
                self.target = self.target.add(self.G.mul(self.blend_weight))
        self.output = input
        return self.output
    def backward(self, input, gradOutput):
        if self.mode == 'loss':
            # nn.MSELoss has no backward() method; use the analytic gradient
            # of mean((G - target)^2) w.r.t. G (default 'mean' reduction).
            dG = self.G.sub(self.target).mul(2.0 / self.G.nelement())
            dG = dG.div(input.nelement())
            self.gradInput = self.gram.backward(input, dG)
            if self.normalize:
                self.gradInput = self.gradInput.div(torch.norm(self.gradInput, 1) + 1e-8)
            self.gradInput = self.gradInput.mul(self.strength).add(gradOutput)
        else:
            self.gradInput = gradOutput
        return self.gradInput
"
  )

  style_loss_main$StyleLoss(styleWeight, normalizeGradients, GramMatrix())
}
# David-J-R/neuralstyleR documentation built on May 8, 2019, 1:54 p.m.