Nothing
#############################################################################
#
# This file is part of the R package "RSNNS".
#
# Author: Christoph Bergmeir
# Supervisor: José M. Benítez
# Copyright (c) DiCITS Lab, Sci2s group, DECSAI, University of Granada.
#
# This library is free software; you can redistribute it and/or
# modify it under the terms of the GNU Library General Public
# License as published by the Free Software Foundation; either
# version 2 of the License, or (at your option) any later version.
#
# This library is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# Library General Public License for more details.
#
# You should have received a copy of the GNU Library General Public License
# along with this library; see the file COPYING.LIB. If not, write to
# the Free Software Foundation, Inc., 51 Franklin Street, Fifth Floor,
# Boston, MA 02110-1301, USA.
#
#############################################################################
#' SnnsR low-level function to train a network and test it in every training iteration.
#'
#' The learning, update and (optionally) pruning functions are set on the
#' network, pattern sets are created from the given matrices, the network is
#' initialized and then trained for \code{maxit} epochs. If test data is given,
#' the test error is computed after every epoch. If a pruning function is
#' given, the network is afterwards pruned and retrained repeatedly until the
#' error exceeds the maximal allowed error.
#'
#' @title Train a network and test it in every training iteration
#' @param inputsTrain a matrix with inputs for the network
#' @param targetsTrain the corresponding targets
#' @param initFunc the initialization function to use
#' @param initFuncParams the parameters for the initialization function
#' @param learnFunc the learning function to use
#' @param learnFuncParams the parameters for the learning function
#' @param updateFunc the update function to use
#' @param updateFuncParams the parameters for the update function
#' @param outputMethod the output method of the net
#' @param maxit maximum of iterations to learn
#' @param shufflePatterns should the patterns be shuffled?
#' @param computeError should the error be computed in every iteration?
#' (always \code{TRUE} when a pruning function is given)
#' @param inputsTest a matrix with inputs to test the network
#' @param targetsTest the corresponding targets for the test input
#' @param serializeTrainedObject should the net be serialized after training
#' and the serialization be stored in the object?
#' @param pruneFunc the pruning function to use
#' @param pruneFuncParams the parameters for the pruning function. Unlike the other functions,
#' these have to be given in a named list. The elements \code{max_pr_error_increase},
#' \code{pr_accepted_error}, \code{no_of_pr_retrain_cycles} and \code{min_error_to_stop}
#' are required. See the pruning demos for further explanation.
#' @return a list containing:
#' \item{fitValues}{the fitted values, i.e. outputs of the training inputs}
#' \item{IterativeFitError}{The SSE in every iteration/epoch on the training set
#' (only present if \code{computeError} is \code{TRUE})}
#' \item{testValues}{the predicted values, i.e. outputs of the test inputs}
#' \item{IterativeTestError}{The SSE in every iteration/epoch on the test set}
#' @rdname SnnsRObject-train
#' @name SnnsRObject$train
#' @usage \S4method{train}{SnnsR}(inputsTrain, targetsTrain=NULL,
#' initFunc="Randomize_Weights", initFuncParams=c(1.0, -1.0),
#' learnFunc="Std_Backpropagation", learnFuncParams=c(0.2, 0),
#' updateFunc="Topological_Order", updateFuncParams=c(0.0),
#' outputMethod="reg_class", maxit=100, shufflePatterns=TRUE,
#' computeError=TRUE, inputsTest=NULL, targetsTest=NULL,
#' serializeTrainedObject=TRUE, pruneFunc=NULL, pruneFuncParams=NULL)
#' @aliases train,SnnsR-method SnnsR__train
SnnsR__train <- function(snnsObject, inputsTrain, targetsTrain=NULL,
    initFunc="Randomize_Weights", initFuncParams=c(1.0, -1.0),
    learnFunc="Std_Backpropagation", learnFuncParams=c(0.2, 0),
    updateFunc="Topological_Order", updateFuncParams=c(0.0), outputMethod="reg_class",
    maxit=100, shufflePatterns=TRUE, computeError=TRUE, inputsTest=NULL, targetsTest=NULL,
    serializeTrainedObject=TRUE, pruneFunc=NULL, pruneFuncParams=NULL) {

  # snnsObject is the SnnsR object whose network is trained; all kernel calls
  # below are dispatched through it.

  # Pads a parameter vector with zeros up to (at least) five entries. The
  # maximal possible amount of parameters is always passed, as unused
  # parameters might otherwise yield unforeseeable effects in the C code.
  # Empty or NULL parameter vectors result in all zeros.
  padParams <- function(params, size = 5L) {
    padded <- numeric(size)
    if (length(params) > 0) {
      padded[seq_along(params)] <- params
    }
    padded
  }

  testing <- !is.null(inputsTest)
  pruning <- !is.null(pruneFunc)

  # Fail early (before the potentially long training) if pruning is requested
  # but parameters needed by the pruning loop are missing.
  if (pruning) {
    pruneSettings <- list(
      max_pr_error_increase = pruneFuncParams$max_pr_error_increase,
      pr_accepted_error = pruneFuncParams$pr_accepted_error,
      no_of_pr_retrain_cycles = pruneFuncParams$no_of_pr_retrain_cycles,
      min_error_to_stop = pruneFuncParams$min_error_to_stop
    )
    isMissing <- vapply(pruneSettings, is.null, logical(1))
    if (any(isMissing)) {
      stop("pruneFuncParams is missing required element(s): ",
           paste(names(pruneSettings)[isMissing], collapse = ", "),
           call. = FALSE)
    }
  }

  result <- list()

  expandedLearnFuncParams <- padParams(learnFuncParams)
  expandedInitFuncParams <- padParams(initFuncParams)
  expandedUpdateFuncParams <- padParams(updateFuncParams)

  if (pruning) {
    snnsObject$setFFLearnFunc(learnFunc)
    snnsObject$setPrunFunc(pruneFunc)
    # the pruning loop needs the training error of the last epoch
    computeError <- TRUE
  } else {
    snnsObject$setLearnFunc(learnFunc)
  }

  snnsObject$setUpdateFunc(updateFunc)

  if (is.null(targetsTrain)) {
    patSetTrain <- snnsObject$createPatSet(inputsTrain)
  } else {
    patSetTrain <- snnsObject$createPatSet(inputsTrain, targetsTrain)
  }

  if (computeError) {
    errorTrain <- numeric(maxit)
  }

  if (testing) {
    patSetTest <- snnsObject$createPatSet(inputsTest, targetsTest)
    errorTest <- numeric(maxit)
  }

  snnsObject$shufflePatterns(shufflePatterns)
  snnsObject$setCurrPatSet(patSetTrain$set_no)
  snnsObject$DefTrainSubPat()

  if (learnFunc == "RadialBasisLearning") {
    # RBF: first initialization step with Kohonen
    snnsObject$initializeNet(c(0, 0, 0, 0, 0), "RBF_Weights_Kohonen")
  }

  snnsObject$initializeNet(expandedInitFuncParams, initFunc)

  for (i in seq_len(maxit)) {

    if (pruning) {
      res <- snnsObject$learnAllPatternsFF(expandedLearnFuncParams)
    } else {
      res <- snnsObject$learnAllPatterns(expandedLearnFuncParams)
    }

    if (computeError) {
      errorTrain[i] <- res[[2]]
    }

    if (testing) {
      # switch to the test patterns, evaluate, and switch back
      snnsObject$setCurrPatSet(patSetTest$set_no)
      snnsObject$DefTrainSubPat()

      #TODO: Why doesn't testAllPatterns work with RadialBasisLearning?
      #And learning even with all parameters = 0 alters the results..
      res <- snnsObject$testAllPatterns(expandedLearnFuncParams)
      errorTest[i] <- res[[2]]

      snnsObject$setCurrPatSet(patSetTrain$set_no)
      snnsObject$DefTrainSubPat()
    }
  }

  #-----------------------------------------------------
  # pruning

  if (pruning) {

    firstError <- errorTrain[maxit]

    # maximal error the pruned net may have
    maxError <- firstError * (1 + pruneSettings$max_pr_error_increase / 100)
    if (maxError < pruneSettings$pr_accepted_error) {
      maxError <- pruneSettings$pr_accepted_error
    }

    PR_ALL_PATTERNS <- -1

    # prune as long as the net error stays <= maximal error
    repeat {

      # delete some weights
      snnsObject$callPrunFunc(PR_ALL_PATTERNS)

      # calculate net error of the pruned net
      prRes <- snnsObject$calcMeanDeviation(PR_ALL_PATTERNS)
      netError <- prRes[[2]]

      # retrain network; netError always holds the most recent error, so the
      # stop criterion below is also well-defined if no retraining took place
      if (netError > pruneSettings$min_error_to_stop) {
        for (j in seq_len(pruneSettings$no_of_pr_retrain_cycles)) {
          reRes <- snnsObject$learnAllPatternsFF(expandedLearnFuncParams)
          netError <- reRes[[2]]
        }
      }

      if (netError > maxError) {
        break
      }
    }
  }

  #-----------------------------------------------------

  if (computeError) {
    result$IterativeFitError <- errorTrain
  }

  snnsObject$setCurrPatSet(patSetTrain$set_no)
  result$fitValues <- snnsObject$predictCurrPatSet(outputMethod, expandedUpdateFuncParams)

  if (testing) {
    result$IterativeTestError <- errorTest

    snnsObject$setCurrPatSet(patSetTest$set_no)
    result$testValues <- snnsObject$predictCurrPatSet(outputMethod, expandedUpdateFuncParams)
    snnsObject$deletePatSet(patSetTest$set_no)
  } else {
    result$IterativeTestError <- NULL
    result$testValues <- NULL
  }

  #snns auto-reorganizes the pattern set numbers, so the first generated pattern set
  #has to be deleted at last
  snnsObject$deletePatSet(patSetTrain$set_no)

  if (serializeTrainedObject) {
    s <- snnsObject$serializeNet("RSNNS_untitled")
    snnsObject@variables$serialization <- s$serialization
  }

  result
}
Any scripts or data that you put into this service are public.
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.