#' Content loss module
#'
#' Builds a PyTorch \code{nn.Module} (defined through \pkg{reticulate}) that
#' measures the mean squared error between a layer's activations and a
#' previously captured target activation, scaled by \code{contentWeight}.
#'
#' The returned module has a \code{mode} attribute, initially \code{"none"}:
#' \itemize{
#'   \item \code{"capture"}: \code{forward()} stores a detached copy of its
#'     input as the target.
#'   \item \code{"loss"}: \code{forward()} stores
#'     \code{MSE(input, target) * strength} in the \code{loss} attribute.
#' }
#' In every mode \code{forward()} returns its input unchanged.
#'
#' @param contentWeight Numeric scalar multiplied into the loss and into the
#'   gradient computed by the module's \code{backward()} method.
#' @param normalizeGradients Logical; if \code{TRUE}, \code{backward()} divides
#'   the content gradient by its L1 norm (plus \code{1e-8}) before scaling.
#' @return A Python \code{ContentLoss} object (a \code{torch.nn.Module}).
#'   The Python \code{torch} package must be importable.
#' @export
#' @import reticulate
#' @references \insertRef{Gatys2016}{neuralstyleR}
#'
ContentLoss <- function(contentWeight = 1e-3, normalizeGradients = FALSE) {
  # The Python source deliberately starts at column 0 inside the string:
  # Python is indentation-sensitive, so the class below is indented for
  # Python, not for R.
  py_env <- py_run_string(
"
import torch.nn as nn
import torch


class ContentLoss(nn.Module):
    def __init__(self, strength, normalizeGradients):
        super(ContentLoss, self).__init__()
        self.strength = strength
        self.target = torch.Tensor()
        self.normalize = normalizeGradients
        self.loss = 0
        self.crit = nn.MSELoss()
        self.mode = 'none'

    def forward(self, input):
        if self.mode == 'loss':
            self.loss = self.crit(input, self.target) * self.strength
        elif self.mode == 'capture':
            # The target is a constant: detach it so no autograd graph
            # is retained or back-propagated through.
            self.target = input.detach().clone()
        self.output = input
        return self.output

    def backward(self, input, gradOutput):
        # Manual gradient ported from Lua Torch updateGradInput. PyTorch
        # autograd does not call this; it only runs if invoked explicitly.
        if self.mode == 'loss':
            n = input.nelement()
            if n > 0 and n == self.target.nelement():
                # d/dx mean((x - t)^2) = 2 * (x - t) / N  (MSELoss 'mean').
                diff = input.detach() - self.target.reshape_as(input)
                self.gradInput = diff.mul(2.0 / n)
            else:
                # No usable target: contribute no content gradient.
                self.gradInput = torch.zeros_like(input)
            # In-place ops: the out-of-place div/mul/add return new tensors
            # and would silently discard the result.
            if self.normalize:
                self.gradInput.div_(torch.norm(self.gradInput, p=1) + 1e-8)
            self.gradInput.mul_(self.strength)
            self.gradInput.add_(gradOutput)
        else:
            self.gradInput = gradOutput.clone()
        return self.gradInput
"
  )
  py_env$ContentLoss(contentWeight, normalizeGradients)
}
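# NOTE(review): the two lines below appear to be web-page embed boilerplate
# (not R code); they are kept as comments so the file still parses.
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.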