#' Style loss module
#'
#' Defines (via reticulate) and instantiates a Python `StyleLoss` class, an
#' `nn.Module` that compares the Gram matrix of its input, computed by a
#' `GramMatrix()` object and divided by the number of input elements, against
#' a stored target using `nn.MSELoss`. The module passes its input through
#' unchanged; its behaviour depends on its `mode` attribute:
#' `'none'` (default) does nothing extra, `'capture'` stores the current Gram
#' matrix as the target (optionally blended with `blend_weight`), and
#' `'loss'` sets `loss` to `strength * MSE(G, target)`.
#'
#' @param styleWeight Numeric weight applied to the style loss (stored as the
#'   module's `strength`).
#' @param normalizeGradients Logical; if `TRUE`, the module's explicit
#'   `backward()` divides the gradient by its L1 norm before weighting.
#' @return A Python `StyleLoss` object (a `torch.nn.Module`).
#' @export
#' @import reticulate
#' @references \insertRef{Gatys2016}{neuralstyleR}
#'
StyleLoss <- function(styleWeight = 1e2, normalizeGradients = FALSE) {
  # Python code must start at column 0 inside the string: Python indentation
  # is significant, so this text is deliberately not indented with the R code.
  style_loss_env <- py_run_string("
import torch
import torch.nn as nn


class StyleLoss(nn.Module):
    # Torch7-style API: forward(input) plus an explicit
    # backward(input, gradOutput) that fills self.gradInput.
    def __init__(self, strength, normalizeGradients, gram):
        super().__init__()
        self.strength = strength
        self.target = torch.Tensor()
        self.normalize = normalizeGradients
        self.loss = 0
        self.crit = nn.MSELoss()
        self.mode = 'none'
        self.gram = gram
        self.blend_weight = None
        self.G = None

    def forward(self, input):
        # torch ops such as div/mul/add are out-of-place: their results
        # must be reassigned or the normalization is silently lost.
        self.G = self.gram.forward(input).div(input.nelement())
        if self.mode == 'loss':
            self.loss = self.strength * self.crit(self.G, self.target)
        elif self.mode == 'capture':
            # Detach so the stored target does not keep the capture pass's
            # autograd graph alive (and is treated as a constant later).
            G = self.G.detach()
            if self.blend_weight is None:
                self.target = G.clone()
            elif self.target.nelement() == 0:
                self.target = G.mul(self.blend_weight)
            else:
                # target += blend_weight * G
                self.target = self.target.add(G, alpha=self.blend_weight)
        self.output = input
        return self.output

    def backward(self, input, gradOutput):
        if self.mode == 'loss':
            # nn.MSELoss has no backward method; use the gradient of the
            # mean-reduced MSE w.r.t. G: 2 * (G - target) / G.nelement().
            dG = (self.G - self.target).mul(2.0 / self.G.nelement())
            # Chain rule through the division by input.nelement() in forward.
            dG = dG.div(input.nelement())
            # NOTE(review): assumes the GramMatrix object exposes
            # backward(input, gradOutput) -- confirm against its definition.
            self.gradInput = self.gram.backward(input, dG)
            if self.normalize:
                self.gradInput = self.gradInput.div(
                    torch.norm(self.gradInput, 1) + 1e-8)
            self.gradInput = self.gradInput.mul(self.strength).add(gradOutput)
        else:
            self.gradInput = gradOutput
        return self.gradInput
")
  style_loss_env$StyleLoss(styleWeight, normalizeGradients, GramMatrix())
}
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.