R/getUpdate.R

Defines functions getUpdate

Documented in getUpdate

#' Retrieve the weight updates for each learning event.
#' 
#' @description For a given set of training data, the change in the summed 
#' cue-to-outcome weights between successive learning events is returned, 
#' for all outcomes, for a selection of outcomes, or only for the outcome(s) 
#' present at each learning event.
#' @export
#' @import data.table
#' @param data Data with columns \code{Cues} and \code{Outcomes},
#' as generated with \code{\link{createTrainingData}}.
#' @param wmlist A list with weightmatrices, generated by 
#' \code{\link{RWlearning}} or \code{\link{updateWeights}}. The list should 
#' contain one weight matrix per row of \code{data}.
#' @param split String, separator between cues or outcomes.
#' @param select.outcomes Optional selection of outcomes to limit the number of 
#' update values that are returned. The value of NULL (default) will 
#' return the updates for all outcomes of the final weight matrix. Note that 
#' specified values that are not in the weightmatrices will return NA 
#' without error or warning. 
#' Please use  \code{\link{getValues}} for returning all outcomes in the data.
#' @param present.outcome Logical: whether or not to output the update 
#' for the present outcome(s) only. Defaults to FALSE. Note that if set to 
#' TRUE, this parameter cancels the effect of \code{select.outcomes}.
#' @return If \code{present.outcome = FALSE}: a numeric matrix with one row 
#' per learning event and one column per tracked outcome. 
#' If \code{present.outcome = TRUE}: a named numeric vector with the update(s) 
#' for the outcome(s) present at each learning event, in event order. 
#' The update of an event is the summed weight of that event's cues to an 
#' outcome after the event, minus the same sum at the previous event; for the 
#' first event the summed weights themselves are returned.
#' @author Jacolien van Rij
#' @examples
#' # load example data:
#' data(dat)
#' 
#' # add obligatory columns Cues, Outcomes, and Frequency:
#' dat <- droplevels(dat[1:3,])
#' dat$Cues <- paste("BG", dat$Shape, dat$Color, sep="_")
#' dat$Outcomes <- dat$Category
#' dat$Frequency <- dat$Frequency1
#' head(dat)
#' 
#' 
#' # now use createTrainingData to sample from the specified frequencies: 
#' train <- createTrainingData(dat)
#' head(train)
#' 
#' # this training data can actually be used train network:
#' wm <- RWlearning(train)
#' 
#' # retrieve update values for all outcomes:
#' updates1 <- getUpdate(data=train, wmlist=wm)
#' head(updates1)
#' 
#' # retrieve update values for observed outcomes:
#' updates2 <- getUpdate(data=train, wmlist=wm, present.outcome=TRUE)
#' head(updates2)
#' 
#' # plot:
#' n <- which("animal" == train$Outcomes)
#' plot(n, updates2[n], type='l', 
#'     ylim=c(0,.1), 
#'     ylab="Weight updates", xlab="Learning event")
#' 

getUpdate <- function(wmlist, data,
                      select.outcomes = NULL, split = "_",
                      present.outcome = FALSE) {
  # check columns Cues, Outcomes
  if (!all(c("Cues", "Outcomes") %in% colnames(data))) {
    stop("Data frame should contain columns 'Cues' and 'Outcomes'.")
  }
  if (length(wmlist) != nrow(data)) {
    stop("Difference in size between wmlist and data.")
  }
  present.only <- isTRUE(present.outcome)

  # determine outcomes to track (ignored when present.only is TRUE):
  out <- colnames(getWM(wmlist))
  if (!is.null(select.outcomes)) {
    out <- select.outcomes
  }

  # Summed weights from `cues` to `outcomes` in weight matrix `wm`.
  # drop = FALSE keeps a matrix for a single cue or a single outcome column,
  # so colSums() always applies and the result is ALWAYS restricted to
  # `outcomes` (previously the single-cue case skipped this subsetting).
  # Outcome names absent from `wm` yield NA.
  summed.weights <- function(wm, cues, outcomes) {
    colSums(wm[cues, , drop = FALSE])[outcomes]
  }

  n.events <- nrow(data)
  update <- vector("list", n.events)
  wm.prev <- NULL
  for (x in seq_len(n.events)) {
    cur.cues <- getValues(data$Cues[x], split = split)
    targets <- if (present.only) {
      getValues(data$Outcomes[x], split = split)
    } else {
      out
    }
    wm.cur <- getWM(wmlist, x)
    val <- summed.weights(wm.cur, cur.cues, targets)
    if (x > 1) {
      # Reuse the matrix of event x - 1 instead of calling getWM() again.
      # NOTE(review): assumes the matrix of event x - 1 contains all cues of
      # event x (else indexing fails with "subscript out of bounds") - confirm
      # against getWM().
      val <- val - summed.weights(wm.prev, cur.cues, targets)
    }
    update[[x]] <- val
    wm.prev <- wm.cur
  }

  if (present.only) {
    unlist(update)
  } else {
    do.call("rbind", update)
  }
}

Try the edl package in your browser

Any scripts or data that you put into this service are public.

edl documentation built on Sept. 20, 2021, 9:09 a.m.