R/trainGweonsNearestNeighbor.R

Defines functions trainGweonsNearestNeighbor

Documented in trainGweonsNearestNeighbor

#' Trains Gweons Nearest Neighbor model
#'
#' Function does some preprocessing and creates a document term matrix to be used for the Nearest Neighbor model.
#'
#' @param data a data.table created with \code{\link{removeFaultyAndUncodableAnswers_And_PrepareForAnalysis}}
#' @param preprocessing a list with elements
#' \describe{
#'   \item{stopwords}{a character vector, use \code{tm::stopwords("de")} for German stopwords.}
#'   \item{stemming}{\code{NULL} for no stemming and \code{"de"} for stemming using the German porter stemmer.}
#'   \item{strPreprocessing}{\code{TRUE} if \code{\link{stringPreprocessing}} shall be used.}
#'   \item{removePunct}{\code{TRUE} if \code{\link[tm]{removePunctuation}} shall be used.}
#' }
#' The elements \code{strPreprocessing} and \code{removePunct} are required and must each be a single
#' \code{TRUE} or \code{FALSE}; otherwise an informative error is raised.
#'
#' @seealso
#' \code{\link{predictGweonsNearestNeighbor}}
#'
#' @references
#' Gweon, H., Schonlau, M., Kaczmirek, L., Blohm, M., and Steiner, S. (2017). Three Methods for Occupation Coding Based on Statistical Learning. Journal of Official Statistics 33(1), pp. 101--122
#'
#' @return a list with elements
#' \describe{
#'   \item{matrix}{the sparse document-term matrix of the preprocessed answers, requested from \code{\link{asDocumentTermMatrix}} as a \code{dgCMatrix}.}
#'   \item{vect.vocab}{the \code{vect.vocab} element returned by \code{\link{asDocumentTermMatrix}}.}
#'   \item{preprocessing}{the \code{preprocessing} list that was used.}
#'   \item{code}{the \code{code} column of \code{data}.}
#'   \item{num.allowed.codes}{the number of codes in \code{attr(data, "classification")$code} (\code{0} if that attribute is missing).}
#' }
#' @import data.table
#' @import text2vec
#' @export
#'
#' @examples
#' # set up data
#' data(occupations)
#' allowed.codes <- c("71402", "71403", "63302", "83112", "83124", "83131", "83132", "83193", "83194", "-0004", "-0030")
#' allowed.codes.titles <- c("Office clerks and secretaries (without specialisation)-skilled tasks", "Office clerks and secretaries (without specialisation)-complex tasks", "Gastronomy occupations (without specialisation)-skilled tasks",
#'  "Occupations in child care and child-rearing-skilled tasks", "Occupations in social work and social pedagogics-highly complex tasks", "Pedagogic specialists in social care work and special needs education-unskilled/semiskilled tasks", "Pedagogic specialists in social care work and special needs education-skilled tasks", "Supervisors in education and social work, and of pedagogic specialists in social care work", "Managers in education and social work, and of pedagogic specialists in social care work",
#'  "Not precise enough for coding", "Student assistants")
#' proc.occupations <- removeFaultyAndUncodableAnswers_And_PrepareForAnalysis(occupations, colNames = c("orig_answer", "orig_code"), allowed.codes, allowed.codes.titles)
#'
#' # Recommended configuration
#' dtmModel <- trainGweonsNearestNeighbor(proc.occupations,
#'                  preprocessing = list(stopwords = tm::stopwords("de"), stemming = "de", strPreprocessing = TRUE, removePunct = FALSE))
#' # Configuration used by Gweon et al. (2017)
#' dtmModel <- trainGweonsNearestNeighbor(proc.occupations,
#'                  preprocessing = list(stopwords = tm::stopwords("de"), stemming = "de", strPreprocessing = FALSE, removePunct = TRUE))
#' # Configuration used for most other approaches in this package
#' dtmModel <- trainGweonsNearestNeighbor(proc.occupations,
#'                  preprocessing = list(stopwords = character(0), stemming = NULL, strPreprocessing = TRUE, removePunct = FALSE))
#'
#' #######################################################
#' ## RUN A GRID SEARCH (takes some time)
#' \donttest{
#' # split sample in test and training data
#' set.seed(3451345)
#' n.test <- 50
#' group <- sample(c(rep("test", n.test), rep("training", nrow(proc.occupations) - n.test)))
#' splitted.data <- split(proc.occupations, group)
#'
#' # create a grid of all combinations to be tried
#' model.grid <- data.table(expand.grid(stopwords = c(TRUE, FALSE), stemming = c(FALSE, "de"), strPreprocessing = c(TRUE, FALSE), nearest.neighbors.multiplier = c(0.05, 0.1, 0.2)))
#'
#' # Do grid search
#' for (i in seq_len(nrow(model.grid))) {
#'   res.model <- trainGweonsNearestNeighbor(splitted.data$training, preprocessing = list(stopwords = if (model.grid[i, stopwords]) tm::stopwords("de") else character(0),
#'                                                                                        stemming = if (model.grid[i, stemming == "de"]) "de" else NULL,
#'                                                                                        strPreprocessing = model.grid[i, strPreprocessing],
#'                                                                                        removePunct = !model.grid[i, strPreprocessing]))
#'
#'   res.proc <- predictGweonsNearestNeighbor(res.model, splitted.data$test,
#'                                         tuning = list(nearest.neighbors.multiplier = model.grid[i, nearest.neighbors.multiplier]))
#'   res.proc <- expandPredictionResults(res.proc, allowed.codes = allowed.codes, method.name = "NearestNeighbor_Gweon")
#'
#'   ac <- accuracy(calcAccurateAmongTopK(res.proc, k = 1), n = nrow(splitted.data$test))
#'   ll <- logLoss(res.proc)
#'   sh <- sharpness(res.proc)
#'
#'   model.grid[i, acc := ac[, acc]]
#'   model.grid[i, acc.se := ac[, se]]
#'   model.grid[i, acc.N := ac[, N]]
#'   model.grid[i, acc.prob0 := ac[, count.pred.prob0]]
#'   model.grid[i, loss.full := ll[1, logscore]]
#'   model.grid[i, loss.full.se := ll[1, se]]
#'   model.grid[i, loss.full.N := ll[1, N]]
#'   model.grid[i, loss.sub := ll[2, logscore]]
#'   model.grid[i, loss.sub.se := ll[2, se]]
#'   model.grid[i, loss.sub.N := ll[2, N]]
#'   model.grid[i, sharp := sh[, sharpness]]
#'   model.grid[i, sharp.se := sh[, se]]
#'   model.grid[i, sharp.N := sh[, N]]
#' }
#'
#' model.grid[order(stopwords, stemming, strPreprocessing, nearest.neighbors.multiplier)]
#'
#'
#' }
trainGweonsNearestNeighbor <- function(data,
                                       preprocessing = list(stopwords = tm::stopwords("de"), stemming = "de", strPreprocessing = FALSE, removePunct = TRUE)) {

  # Validate the logical switches before touching the data. A missing or NA
  # flag would otherwise fail inside `if ()` with an uninformative message
  # ("argument is of length zero" / "missing value where TRUE/FALSE needed").
  # `exact = FALSE` keeps the partial name matching that `$` allows below.
  for (flag in c("removePunct", "strPreprocessing")) {
    value <- preprocessing[[flag, exact = FALSE]]
    if (length(value) != 1L || is.na(as.logical(value))) {
      stop(sprintf("preprocessing$%s must be a single TRUE or FALSE.", flag), call. = FALSE)
    }
  }

  # preprocessing of the answer texts
  if (preprocessing$removePunct) {
    ans <- data[, tm::removePunctuation(ans)]
  } else {
    ans <- data[, ans]
  }

  if (preprocessing$strPreprocessing) {
    ans <- stringPreprocessing(ans)
  }

  # prepare text for efficient computation -> transform to sparse matrix
  dtm.result <- asDocumentTermMatrix(ans, vect.vocab = NULL,
                                     stopwords = preprocessing$stopwords,
                                     stemming = preprocessing$stemming,
                                     type = "dgCMatrix")

  # length(NULL) is 0, so a missing "classification" attribute yields 0 here
  list(matrix = dtm.result$dtm,
       vect.vocab = dtm.result$vect.vocab,
       preprocessing = preprocessing,
       code = data[, code],
       num.allowed.codes = length(attr(data, "classification")$code))
}
malsch/occupationCoding documentation built on March 14, 2024, 8:09 a.m.