R/segment.r

Defines functions segment

Documented in segment

#' Segment an FemFit dataset
#'
#' @description
#' Conducts a principal component analysis of the pressure traces, then fits a regression tree to the first principal component. Event terminal nodes within a JSONLabel are identified by "large" mean deviations from the terminal nodes found in the baseline JSONLabels.
#'
#' @param x An "FemFit" object.
#' @param cp. The complexity parameter provided to the regression tree.
#' @param numOfNodesToLabel A list of numeric vectors. Each vector holds the number of nodes to label as events within each non-baseline JSONLabel.
#' @param JSONLabel_Baseline A list of character vectors. Each vector holds the JSONLabels corresponding to the definite baselines.
#' @param JSONLabel_Sequence A list of character vectors. Each vector holds the JSONLabel sequence corresponding to \code{numOfNodesToLabel}. The default, \code{""}, uses \code{unique(x$df$JSONLabel)} of the session.
#' @param xval. The number of cross-validations parameter provided to the regression tree.
#' @param ... Arguments passed to \code{\link{rpart}}.
#'
#' @details
#' The elements of the list, \code{numOfNodesToLabel}, map to each unique \code{sessionID}. The numeric vector corresponds to the number of terminal nodes to label as events within each level of \code{JSONLabel}, where the numeric vector maps to \code{JSONLabel_Sequence}.
#'
#' \code{cp.}, \code{numOfNodesToLabel}, \code{JSONLabel_Baseline}, \code{JSONLabel_Sequence}, and \code{xval.} are recycled to the length of \code{unique(x$df$sessionID)}. \code{...} arguments are not vectorised.
#'
#' The length of the numeric vector must equal the number of non-\code{JSONLabel_Baseline} device labels.
#'
#' The regression tree is pruned with the one standard error rule before its terminal nodes are labelled.
#'
#' @return
#' The "FemFit" object, where \code{df} gains the columns \code{trmnlNode} and \code{trLabel}, and the \code{errorSummary} element is updated. \code{errorSummary} contains the terminal nodes with their assigned label and their start and stop timestamps.
#'
#' @seealso
#' \code{\link{prcomp}} and \code{\link{rpart}} for how the main statistical learning components work.
#'
#' \code{\link{segmentRefine_pcurve}} and \code{\link{segmentRefine_protocol}} for functions to identify events in noisy pressure traces post-segmentation.
#'
#' @examples
#' # Read in the FemFit data
#' AS005 = read_FemFit(c(
#'         "Datasets_AukRepeat/61aa0782289af385_283_csv.zip",
#'         "Datasets_AukRepeat/61aa0782289af385_284_csv.zip"
#'     ),
#'     remove.NAs = TRUE
#'   )
#'
#' # Segment the FemFit data
#' AS005 = segment(
#'     x = AS005,
#'     # Use the same control parameter for both sessions
#'     cp. = 0.001,
#'     # Define a different set in the number of terminal nodes to edit for each session
#'     numOfNodesToLabel = list(c(3, 1, 3, 4), c(4, 1, 5, 3)),
#'     # Set the baseline JSONLabels. segment() will look for both "relax30s_A" and "relax30s_B" in each session
#'     JSONLabel_Baseline = list(c("relax30s_A", "relax30s_B")),
#'     # Works in conjunction with the numOfNodesToLabel argument.
#'     # segment() will assume that the order of protocol is unique(AS005$df$JSONLabel), excluding the baseline JSONLabels
#'     JSONLabel_Sequence = list("")
#'   )
#'
#' @export
segment = function(x, cp. = 0.005, numOfNodesToLabel = list(c(3, 1, 6, 3)), JSONLabel_Baseline = list(c("relax30s_A", "relax30s_B")), JSONLabel_Sequence = list(""), xval. = 100, ...) {
  # ---- Argument validation -------------------------------------------------

  # Throw an error if the x argument is missing or is not an FemFit object.
    # inherits() already rejects NA, and is.na() on a multi-element list would hand a vector to the scalar `||`
  if (missing(x) || !inherits(x, "FemFit")) {
    stop("The x argument is not an FemFit object.", call. = FALSE)
  }

  # Throw an error if cp. is not a numeric or it has any NAs
  if (!is.numeric(cp.) || anyNA(cp.)) {
    stop("The provided cp. value is not a numeric.", call. = FALSE)
  }

  # Throw an error if cp. is less than 0
  if (any(cp. < 0)) {
    stop("The provided cp. value is less than zero. See ?rpart for details.", call. = FALSE)
  }

  # Throw a warning if cp. is less than 0.001
  if (any(cp. < 0.001)) {
    warning("The provided cp. value is less than 0.001. The naive labelling may not be easily reproducible without a preceding set.seed() call.", call. = FALSE)
  }

  # Helpers for validating the list arguments element by element
  allElementsAre = function (lst, predicate) all(vapply(lst, predicate, logical(1)))
  anyElementHasNA = function (lst) any(vapply(lst, anyNA, logical(1)))

  # Throw an error if numOfNodesToLabel is not a numeric or it has any NAs
  if (!allElementsAre(numOfNodesToLabel, is.numeric) || anyElementHasNA(numOfNodesToLabel)) {
    stop("A number of provided numOfNodesToLabel values are not numerics.", call. = FALSE)
  }

  # Throw an error if JSONLabel_Baseline is not a character or it has any NAs
  if (!allElementsAre(JSONLabel_Baseline, is.character) || anyElementHasNA(JSONLabel_Baseline)) {
    stop("A number of provided JSONLabel_Baseline value(s) are not characters.", call. = FALSE)
  }

  # Throw an error if JSONLabel_Sequence is not a character or it has any NAs
  if (!allElementsAre(JSONLabel_Sequence, is.character) || anyElementHasNA(JSONLabel_Sequence)) {
    stop("A number of provided JSONLabel_Sequence value(s) are not characters.", call. = FALSE)
  }

  # Throw an error if xval. is not a numeric or it has any NAs
  if (!is.numeric(xval.) || anyNA(xval.)) {
    stop("The provided xval. value is not a numeric.", call. = FALSE)
  }

  # Throw a warning if xval. is not a whole number
  if (any(xval. %% 1 != 0)) {
    warning("The provided xval. value is not a whole number. segment() will coerce it into a whole number.", call. = FALSE)
    xval. = round(xval., 0)
  }

  # Throw an error if xval. is less than 1
  if (any(xval. < 1)) {
    stop("The provided xval. value is less than 1. Must be a value greater than 1.", call. = FALSE)
  }

  # ---- Recycle the per-session arguments -------------------------------------

  # Recycle elements of cp., numOfNodesToLabel, JSONLabel_Baseline, JSONLabel_Sequence, and xval. if necessary
    # Also, name the elements with sessionID. The IDs are held as characters so that every
    # later lookup is by name, even when the sessionID column is a factor or a numeric
  sessionIDs = as.character(unique(x$df$sessionID))
  sessionID.len = length(sessionIDs)
  recycleBySession = function (arg) stats::setNames(rep(arg, length.out = sessionID.len), sessionIDs)

  cp. = recycleBySession(cp.)
  numOfNodesToLabel = recycleBySession(numOfNodesToLabel)
  JSONLabel_Baseline = recycleBySession(JSONLabel_Baseline)
  JSONLabel_Sequence = recycleBySession(JSONLabel_Sequence)
  xval. = recycleBySession(xval.)

  # Construct the regular expression with the JSONLabel_Baseline components
  regexp.JSONLabels = vapply(JSONLabel_Baseline, function (child) paste(child, collapse = "|"), character(1))

  # Row indices 1..k, clamped to [0, n]. Unlike 1:k this yields no rows when k is zero
    # and never runs past the number of available rows
  topRows = function (k, n) seq_len(max(0, min(floor(k), n)))

  # Check whether the length of numOfNodesToLabel is equal to the number of non-baseline JSONLabels
    # and name each count with the JSONLabel it belongs to
  numOfNodesToLabel = lapply(sessionIDs, function (child) {
    if (all(JSONLabel_Sequence[[child]] == "")) {
      # obs.JSONLabels defaults to the JSONLabels observed in the session if JSONLabel_Sequence[[child]] is missing
      obs.JSONLabels = unique(x$df$JSONLabel[which(x$df$sessionID == child)])
    } else {
      obs.JSONLabels = unique(JSONLabel_Sequence[[child]])
    }

    nonBaseline.JSONLabels = as.character(obs.JSONLabels[!grepl(regexp.JSONLabels[[child]], obs.JSONLabels)])
    nodes_Child = numOfNodesToLabel[[child]]

    if (length(nodes_Child) != length(nonBaseline.JSONLabels)) {
      stop(paste0("The number of non-baseline JSONLabels found in Session ", child, " does not match the length of the corresponding numOfNodesToLabel vector argument"), call. = FALSE)
    }

    names(nodes_Child) = nonBaseline.JSONLabels
    nodes_Child
  })
  names(numOfNodesToLabel) = sessionIDs

  # Throw a warning if trLabel exists in x$df and remove it from x$df
  if ("trLabel" %in% colnames(x$df)) {
    warning("trLabel already exists in the FemFit object. segment() will overwrite the previous trLabel.", call. = FALSE)
    x$df = x$df %>%
      dplyr::select(-trLabel)
  }

  # ---- Segment each session --------------------------------------------------

  x_Work = by(x$df, x$df$sessionID, function (x_Child) {
    sessionID_Child = as.character(x_Child$sessionID[1])

    # Generate the first principal component of the pressure traces with a non-zero variance
    x_Child$pc.1 = x_Child %>%
      dplyr::select(dplyr::starts_with("prssr_sensor")) %>%
      dplyr::select_if(function (trace) (stats::var(trace) != 0)) %>%
      stats::prcomp(., center = TRUE, scale. = TRUE) %>%
      .$x %>%
      .[, 1]

    # Fit a regression tree to split the first principal component with the explanatory variables time and JSONLabel
    fullTree = rpart::rpart(formula = pc.1 ~ time + JSONLabel, data = x_Child, method = "anova", cp = cp.[[sessionID_Child]], xval = xval.[[sessionID_Child]], ...)

    # Prune the regression tree with the one std. error rule:
      # take the first (simplest) tree whose cross-validated error is within one std. error of the minimum
    cpTable = fullTree$cptable
    minRow = which(cpTable[, "xerror"] == min(cpTable[, "xerror"]))[1]
    oneSeRow = which(cpTable[, "xerror"] <= cpTable[minRow, "xerror"] + cpTable[minRow, "xstd"])[1]
    rgrssnTree = rpart::prune(fullTree, cp = cpTable[oneSeRow, "CP"])

    # Recode the nodes so that they are labelled in temporal order (order of first appearance)
    nodeByLabel = data.frame(
      trmnlNode = match(rgrssnTree$where, unique(rgrssnTree$where)),
      JSONLabel = dplyr::pull(x_Child, JSONLabel)
    )

    # Identify where a node spans over two or more JSONLabels; every node/JSONLabel pair gets its own node
    vlookup = unique(nodeByLabel)
    vlookup$trmnlNode_New = seq_len(nrow(vlookup))

    # Load the updated nodes into the data.frame
    x_Child$trmnlNode = nodeByLabel %>%
      dplyr::left_join(y = vlookup, by = c("trmnlNode", "JSONLabel")) %>%
      dplyr::pull(trmnlNode_New)

    # Extract out the (principal component one) means for the device recorded baselines
    baselineMeans = x_Child %>%
      dplyr::filter(grepl(regexp.JSONLabels[[sessionID_Child]], JSONLabel)) %>%
      dplyr::group_by(trmnlNode) %>%
      dplyr::summarise(fitted = mean(pc.1))

    # Without a baseline node there is nothing to measure the deviations against
    if (nrow(baselineMeans) == 0) {
      stop(paste0("No JSONLabel_Baseline observations were found in Session ", sessionID_Child, "."), call. = FALSE)
    }

    fitted.Mus = stats::setNames(baselineMeans$fitted, paste0("bslineNode_", baselineMeans$trmnlNode))

    # Create a new set of labels to store the node labels into
    trLabels = gsub("bslineNode_([0-9]*)", "trLabel_\\1", names(fitted.Mus))
    names(trLabels) = names(fitted.Mus)
    trLabelCols = unname(trLabels)
    trLabels.len = length(trLabels)

    # Load in the right numOfNodesToLabel argument...
    numOfNodesToLabel_Child = numOfNodesToLabel[[sessionID_Child]]

    # Label the terminal nodes of a single JSONLabel, once per baseline node
    labelNodes = function (predErrs_Child) {
      label_Child = as.character(predErrs_Child$JSONLabel[1])

      # Initialise all node labels as baselines if this is a baseline JSONLabel
      if (grepl(regexp.JSONLabels[[sessionID_Child]], label_Child)) {
        predErrs_Child[trLabelCols] = "baseline"
        return (predErrs_Child)
      }

      nEvents = unname(numOfNodesToLabel_Child[label_Child])
      if (is.na(nEvents)) {
        stop(paste0("No numOfNodesToLabel value was found for JSONLabel ", label_Child, " in Session ", sessionID_Child, "."), call. = FALSE)
      }

      # Create the naive baseline and event labels for the exercises:
        # for every baseline node, the nEvents nodes with the largest error are the events
      votes = lapply(names(trLabels), function (bslineNode) {
        current_Label = trLabels[[bslineNode]]

        # order() is stable, so ties keep their original order
        predErrs_GrandChild = predErrs_Child[order(-predErrs_Child[[bslineNode]]), ]
        predErrs_GrandChild[current_Label] = "baseline"

        eventRows = topRows(nEvents, nrow(predErrs_GrandChild))
        if (length(eventRows) > 0) {
          predErrs_GrandChild[eventRows, current_Label] = "event"
        }

        predErrs_GrandChild[c("trmnlNode", current_Label)]
      })

      Reduce(function (tdf1, tdf2) dplyr::left_join(tdf1, tdf2, by = "trmnlNode"), votes) %>%
        {dplyr::left_join(predErrs_Child, ., by = "trmnlNode")}
    }

    # Create a data.frame object which contains each node's prediction error (RMSE) for each element in fitted.Mus
    predErrs = vapply(fitted.Mus, FUN.VALUE = numeric(nrow(x_Child)), FUN = function (mu) {
        (x_Child$pc.1 - mu)^2
      }) %>%
      dplyr::as_tibble() %>%
      dplyr::bind_cols(x_Child) %>%
      dplyr::group_by(sessionID, JSONLabel, trmnlNode) %>%
      dplyr::summarise_at(names(fitted.Mus), function (sqErr) sqrt(mean(sqErr, na.rm = TRUE))) %>%
      dplyr::ungroup() %>%
      # Create naive labels of baselines and events
      by(., .$JSONLabel, labelNodes) %>%
      {Reduce(function (dtf1, dtf2) dplyr::bind_rows(dtf1, dtf2), .)}

    # Majority vote to get the overall naive label for a node:
      # an event if at least half of the baseline nodes voted event
    isEvent = as.matrix(predErrs[trLabelCols]) == "event"
    isEvent[is.na(isEvent)] = FALSE
    predErrs$trLabel = unname(ifelse(rowSums(isEvent) / trLabels.len >= 0.5, "event", "baseline"))

    # Final safety check to ensure the number of event nodes returned is the specified number of nodes
    finalSafety = predErrs %>%
      dplyr::filter(trLabel == "event") %>%
      dplyr::group_by(JSONLabel) %>%
      dplyr::summarise(n = dplyr::n()) %>%
      dplyr::mutate(expectedN = unname(numOfNodesToLabel_Child[as.character(JSONLabel)]),
                    isExpected = n == expectedN)

    mismatched = which(!finalSafety$isExpected)

    if (length(mismatched) > 0) {
      index_Base = dplyr::bind_rows(lapply(mismatched, function (index) {
        label_Fix = finalSafety$JSONLabel[index]

        # Isolate the event terminal nodes
        eventNodes = predErrs$trmnlNode[which(predErrs$trLabel == "event" & predErrs$JSONLabel == label_Fix)]

        # Update the labels based on which nodes have the highest mean principal component one score
        index_Child = x_Child %>%
          dplyr::filter(JSONLabel == label_Fix,
                        trmnlNode %in% eventNodes) %>%
          dplyr::group_by(trmnlNode) %>%
          dplyr::summarise(mu = mean(pc.1)) %>%
          dplyr::arrange(dplyr::desc(mu)) %>%
          dplyr::mutate(trLabel_New = "baseline")

        # Only the nodes already voted as events are candidates, so the count can be reduced but not increased
        if (nrow(index_Child) < finalSafety$expectedN[index]) {
          warning(paste0("Fewer event nodes than requested were found for JSONLabel ", label_Fix, " in Session ", sessionID_Child, "."), call. = FALSE)
        }

        index_Child$trLabel_New[topRows(finalSafety$expectedN[index], nrow(index_Child))] = "event"

        dplyr::select(index_Child, trLabel_New, trmnlNode)
      }))

      # Update the predErrs object with the right number of event classifications
      predErrs = dplyr::left_join(predErrs, index_Base, by = "trmnlNode") %>%
        dplyr::mutate(trLabel = dplyr::if_else(!is.na(trLabel_New), trLabel_New, trLabel)) %>%
        dplyr::select(-trLabel_New)
    }

    # Append the trLabel to x_Child
    x_Child = x_Child %>%
      dplyr::left_join(y = dplyr::select(predErrs, sessionID, JSONLabel, trmnlNode, trLabel), by = c("sessionID", "JSONLabel", "trmnlNode")) %>%
      dplyr::mutate(trLabel = dplyr::if_else(is.na(trLabel), "baseline", trLabel))

    # Construct errorSummary: one row per terminal node with its label and the time span it covers
    nodeTimes = x_Child %>%
      dplyr::group_by(trmnlNode) %>%
      dplyr::summarise(start = min(time), stop = max(time))

    errorSummary = predErrs %>%
      dplyr::select(sessionID, JSONLabel, trmnlNode, trLabel) %>%
      dplyr::ungroup() %>%
      dplyr::left_join(nodeTimes, by = "trmnlNode")

    return (list(
      df = x_Child %>% dplyr::select(-pc.1),
      errorSummary = errorSummary
    ))
  })

  # ---- Setup the object to return to the end-user ----------------------------

  if (length(x_Work) == 1) {
    x$df = x_Work[[1]]$df
    x$errorSummary = x_Work[[1]]$errorSummary
  } else {
    # Bind every session at once. A pairwise Reduce() over lst$df loses the earlier
      # sessions from the third session onwards, as the accumulator is no longer a list
    x$df = dplyr::bind_rows(unname(lapply(x_Work, function (work) work$df)))
    x$errorSummary = dplyr::bind_rows(unname(lapply(x_Work, function (work) work$errorSummary)))
  }

  return (x)
}
TheGreatGospel/IVPSA documentation built on May 19, 2019, 1:47 a.m.