R/ds_basic.R

Defines functions get_parameters.ds_basic get_data.ds_basic ds_basic

Documented in ds_basic

#' A basic datasource (DS)
#'
#' The standard datasource used to get training and test splits of data.
#'
#' This 'basic' datasource is the datasource that will most commonly be used for
#' most analyses. It can generate training and test sets for data that has been
#' recorded simultaneously or pseudo-populations for data that was not recorded
#' simultaneously.
#'
#' Like all datasources, this datasource takes binned format data
#' and has a `get_data()` method that is called by a cross-validation object to
#' get training and testing splits of data that can be passed to a classifier.
#'
#'
#' @param binned_data A string giving the path to a file that has data in
#'   binned format, or a data frame that is in binned format.
#'
#' @param var_to_decode A string specifying the name of the labels that
#'  should be decoded, given without the 'labels.' prefix (e.g., "stimulus_ID").
#'  The binned data must have a column named 'labels.' followed by this name.
#'
#' @param num_cv_splits A number specifying how many cross-validation splits
#'  should be used.
#'
#' @param use_count_data If the binned data holds firing rates that were
#'   computed from neural spike counts, then setting use_count_data = TRUE will
#'   convert the rates back into spike counts. This is useful for classifiers
#'   that work on spike count data, e.g., the poisson_naive_bayes_CL.
#'
#' @param num_label_repeats_per_cv_split A number specifying how many times each
#'   label should be repeated in each cross-validation split.
#'
#' @param label_levels_to_use A vector of strings specifying specific label
#'   levels that should be used. If this is set to NULL then all label levels
#'   available will be used.
#'
#' @param num_resample_sites The number of sites that should be randomly
#'   selected when constructing training and test vectors. This number needs to
#'   be less than or equal to the number of sites available that have
#'   num_cv_splits * num_label_repeats_per_cv_split repeats.
#'
#' @param site_IDs_to_use A vector of integers specifying which sites should be
#'   used.
#'
#' @param site_IDs_to_exclude A vector of integers specifying which sites should
#'   be excluded.
#'
#' @param randomly_shuffled_labels_before_running A boolean specifying whether
#'   the labels should be shuffled prior to the get_data() function being
#'   called. This is used when one wants to create a null distribution for
#'   comparing when decoding results are above chance.
#'
#' @param create_simultaneously_recorded_populations If the data from all sites
#'   was recorded simultaneously, then setting this variable to 1 will cause the
#'   get_data() function to return simultaneous populations rather than
#'   pseudo-populations.
#'
#'
#' @examples
#' # A typical example of creating a datasource to be passed to a cross-validation object
#' data_file <- system.file("extdata/ZD_150bins_50sampled.Rda", package = "NDTr")
#' ds <- ds_basic(data_file, "stimulus_ID", 18)
#'
#' # If one has many repeats of each label, decoding can be faster if one
#' # uses fewer CV splits and repeats each label multiple times in each split.
#' ds <- ds_basic(data_file, "stimulus_ID", 6,
#'   num_label_repeats_per_cv_split = 3
#' )
#'
#' # One can specify a subset of labels levels to be used in decoding. Here
#' #  we just do a three-way decoding analysis between "car", "hand" and "kiwi".
#' ds <- ds_basic(data_file, "stimulus_ID", 18,
#'   label_levels_to_use = c("car", "hand", "kiwi")
#' )
#'
#' # One never explicitly calls the get_data() function, but rather this is
#' # done by the cross-validator. However, to illustrate what this function
#' # does, we can call it explicitly here to get training and test data:
#' all_cv_data <- NDTr:::get_data(ds)
#' names(all_cv_data)
#' @family datasource
#'


# the constructor
#' @export
ds_basic <- function(binned_data,
                     var_to_decode,
                     num_cv_splits,
                     use_count_data = FALSE,
                     num_label_repeats_per_cv_split = 1,
                     label_levels_to_use = NULL,
                     num_resample_sites = NULL,
                     site_IDs_to_use = NULL,
                     site_IDs_to_exclude = NULL,
                     randomly_shuffled_labels_before_running = FALSE,
                     create_simultaneously_recorded_populations = 0) {

  # Builds a 'ds_basic' object: a list holding the pre-processed binned data
  # (siteID, time bin columns, optional trial_number, and the chosen label
  # column renamed to 'labels') together with the settings that get_data()
  # reads when it creates training/test splits.
  #
  # Stops if var_to_decode is not a single string naming an existing
  # 'labels.' column, if create_simultaneously_recorded_populations is not
  # 0 or 1, or if simultaneous populations are requested but the sites do not
  # all have the same number of trials per label.


  # validate the cheap arguments before any data is loaded --------------------

  if (!is.character(var_to_decode) || length(var_to_decode) != 1) {
    stop("var_to_decode must be a single string.")
  }

  # Only the values 0 and 1 (or FALSE/TRUE) are acted on, here and in
  # get_data(). Any other value used to slip through the range check and was
  # then silently treated as 0.
  if (length(create_simultaneously_recorded_populations) != 1 ||
    !(create_simultaneously_recorded_populations %in% c(0, 1))) {
    stop("create_simultaneously_recorded_populations must be set to 0 or 1.")
  }


  if (is.character(binned_data)) {
    binned_file_name <- binned_data
  } else {
    binned_file_name <- "Loaded binned data given"
  }

  # load data if a string is given and check that the data is in binned format
  binned_data <- check_and_load_binned_data(binned_data)

  if (use_count_data) {
    binned_data <- convert_rates_to_counts(binned_data)
  }


  # remove all labels that aren't being used, and rename the labels that are
  # being used to "labels"
  label_col_name <- paste0("labels.", var_to_decode)
  label_col_ind <- match(label_col_name, names(binned_data))

  if (is.na(label_col_ind)) {
    stop(paste0(
      "The binned data does not have a column named '", label_col_name,
      "', so the variable '", var_to_decode, "' can not be decoded."
    ))
  }

  # also keep the variable trial_number if it exists
  if ("trial_number" %in% colnames(binned_data)) {
    binned_data <- binned_data %>%
      dplyr::select(
        "siteID", dplyr::starts_with("time"), "trial_number",
        labels = dplyr::all_of(label_col_ind)
      )
  } else {
    binned_data <- binned_data %>%
      dplyr::select(
        "siteID", dplyr::starts_with("time"),
        labels = dplyr::all_of(label_col_ind)
      )
  }


  # only use specified label_levels
  if (!is.null(label_levels_to_use)) {
    binned_data <- dplyr::filter(binned_data, .data$labels %in% label_levels_to_use)
  } else if (is.factor(binned_data$labels)) {
    label_levels_to_use <- as.list(levels(binned_data$labels))
  } else {
    # levels() is NULL for a non-factor column, which would record an empty
    # set of label levels; fall back to the distinct values that are present
    label_levels_to_use <- as.list(sort(unique(as.character(binned_data$labels))))
  }


  if (is.null(site_IDs_to_use)) {
    site_IDs_to_use <- unique(binned_data$siteID)
  }

  if (!is.null(site_IDs_to_exclude)) {
    site_IDs_to_use <- setdiff(site_IDs_to_use, site_IDs_to_exclude)
  }

  if (length(label_levels_to_use) != length(unique(label_levels_to_use))) {
    warning("Some labels were listed twice. Duplication will be ignored.")
  }

  if (is.null(num_resample_sites)) {
    num_resample_sites <- length(site_IDs_to_use)
  }


  # check if data is valid to get simultaneously recorded data
  if (create_simultaneously_recorded_populations == 1) {

    # for simultaneously recorded data there should be the same number of labels for each site
    num_trials_for_each_label_for_each_site <- binned_data %>%
      dplyr::group_by(.data$siteID, .data$labels) %>%
      dplyr::summarize(n = dplyr::n(), .groups = "drop")

    if (length(unique(num_trials_for_each_label_for_each_site$n)) != 1) {
      stop(paste(
        "There are not the same number of repeated labels/trials for each site which",
        "which there should be for simultaneously recorded data."
      ))
    }


    # add a variable called 'trial_number' if it doesn't exist and the data was recorded simultaneously
    if (!("trial_number" %in% colnames(binned_data))) {

      warning(paste(
        "No variable named trial_number in the binned_data.\n",
        "Attempting to add this variable to decode simultaneously recorded data",
        "by assuming all trials for each site are in the same sequential order."
      ))

      # Number the trials within each site in row order. Counting per site
      # (rather than repeating 1:n down the whole table) gives the same
      # numbers when each site's rows are contiguous and stays correct when
      # the rows of different sites are interleaved.
      binned_data$trial_number <- stats::ave(
        seq_len(nrow(binned_data)), binned_data$siteID,
        FUN = seq_along
      )

      binned_data <- dplyr::select(binned_data, "siteID", "trial_number", dplyr::everything())

    }


    # shuffle the labels if specified (shuffle labels the same way for each site)
    if (randomly_shuffled_labels_before_running == TRUE) {

      # NOTE(review): repeating the first site's shuffled labels down the
      # table assumes each site's rows are contiguous and in the same trial
      # order -- confirm against how the binned data is produced
      min_site_ID <- min(binned_data$siteID)
      first_site_data <- dplyr::filter(binned_data, .data$siteID == min_site_ID)
      the_labels <- sample(first_site_data$labels)
      binned_data$labels <- rep(the_labels, length(unique(binned_data$siteID)))

    }


    # add variable label_trial_combo, used by get_data() to pick the same
    # trials from every site
    binned_data <- binned_data %>%
      dplyr::mutate(label_trial_combo = paste0(.data$labels, "_", .data$trial_number))


    # end pre-processing for simultaneously recorded data...
  } else {

    # shuffle the labels if specified (independently within each site)
    if (randomly_shuffled_labels_before_running == TRUE) {
      binned_data <- binned_data %>%
        dplyr::ungroup() %>%
        dplyr::group_by(.data$siteID) %>%
        dplyr::mutate(labels = .data$labels[sample(dplyr::row_number())]) %>%
        dplyr::ungroup()
    }
  }


  # create the main data structure
  the_ds <- list(
    binned_file_name = binned_file_name,
    binned_data = binned_data,
    var_to_decode = var_to_decode,
    num_cv_splits = num_cv_splits,
    num_label_repeats_per_cv_split = num_label_repeats_per_cv_split,
    label_levels_to_use = label_levels_to_use,
    num_resample_sites = num_resample_sites,
    site_IDs_to_use = site_IDs_to_use,
    site_IDs_to_exclude = site_IDs_to_exclude,
    randomly_shuffled_labels_before_running = randomly_shuffled_labels_before_running,
    create_simultaneously_recorded_populations = create_simultaneously_recorded_populations
  )


  class(the_ds) <- "ds_basic"
  the_ds

}




#' @export
get_data.ds_basic <- function(ds_obj) {

  # Draws one random resampling of the datasource's binned data and lays it
  # out for cross-validation.
  #
  # ds_obj: a ds_basic object as built by ds_basic().
  #
  # Returns an ungrouped data frame with the columns train_labels and
  # test_labels (both copies of the label), time_bin (a factor), one column
  # per sampled site named site_XXXX, and one column CV_i per cross-validation
  # split whose entries are "test" or "train".

  binned_data <- ds_obj$binned_data
  num_cv_splits <- ds_obj$num_cv_splits
  num_label_repeats_per_cv_split <- ds_obj$num_label_repeats_per_cv_split
  num_trials_used_per_label <- num_cv_splits * num_label_repeats_per_cv_split

  create_simultaneously_recorded_populations <- ds_obj$create_simultaneously_recorded_populations
  num_resample_sites <- ds_obj$num_resample_sites
  site_IDs_to_use <- ds_obj$site_IDs_to_use



  # the code that actually gets the data used to train and test the classifier  -----------

  # Sample positions rather than values: sample(x, size) treats a single
  # number x as 1:x, so a lone site ID such as 5 would be drawn from 1:5.
  # For more than one site this is exactly what sample() does internally.
  curr_site_inds <- sample.int(length(site_IDs_to_use), num_resample_sites)
  curr_sites_to_use <- site_IDs_to_use[curr_site_inds]
  binned_data <- dplyr::filter(binned_data, .data$siteID %in% curr_sites_to_use)


  if (create_simultaneously_recorded_populations == 1) {

    # use one site to select the trials to use and then apply to all sites
    reference_site_ID <- binned_data$siteID[1]

    curr_label_trials_to_use <- binned_data %>%
      dplyr::filter(.data$siteID == reference_site_ID) %>%
      dplyr::select("labels", "label_trial_combo") %>%
      dplyr::group_by(.data$labels) %>%
      dplyr::sample_n(size = num_trials_used_per_label) %>%
      dplyr::pull("label_trial_combo")

    # apply specific simultaneous trials selected to all sites
    all_k_fold_data <- binned_data %>%
      dplyr::filter(.data$label_trial_combo %in% curr_label_trials_to_use) %>%
      dplyr::select(-dplyr::all_of("label_trial_combo"))

  } else {

    # for data not recorded simultaneously, sample trials independently for
    # every label at every site
    all_k_fold_data <- binned_data %>%
      dplyr::group_by(.data$labels, .data$siteID) %>%
      dplyr::sample_n(size = num_trials_used_per_label)
  }


  # remove the variable trial_number if it exists in all_k_fold_data
  if ("trial_number" %in% names(all_k_fold_data)) {
    all_k_fold_data <- dplyr::select(all_k_fold_data, -dplyr::all_of("trial_number"))
  }


  unique_labels <- unique(all_k_fold_data$labels)
  num_sites <- length(unique(binned_data$siteID))
  num_labels <- length(unique_labels)


  # arrange the data by siteID and labels before adding on the CV_slice_ID
  all_k_fold_data <- dplyr::arrange(all_k_fold_data, .data$siteID, .data$labels)


  # CV_slice_ID is a group of data that has one example for each label
  #  - these groups are mapped into CV blocks where blocks contain
  #    num_label_repeats_per_cv_split of each label
  # After the arrange above each (site, label) pair is a run of
  # num_trials_used_per_label rows, so the IDs 1..k repeat once per pair.
  all_k_fold_data$CV_slice_ID <- rep(
    seq_len(num_trials_used_per_label),
    num_labels * num_sites
  )

  # paste "site_" in front of the zero-padded siteID so that it is listed as
  # site_0001, site_0002, etc
  all_k_fold_data$siteID <- paste0("site_", stringr::str_pad(all_k_fold_data$siteID, 4, pad = "0"))

  # convert so that there is one column for each site
  # (gather/spread are kept on purpose so the row and column order of the
  # result does not change)
  melted_data <- tidyr::gather(
    all_k_fold_data, "time_bin", "activity",
    -dplyr::all_of(c("siteID", "labels", "CV_slice_ID"))
  )
  all_cv_data <- tidyr::spread(melted_data, "siteID", "activity") %>%
    dplyr::select("labels", "time_bin", "CV_slice_ID", dplyr::everything()) %>%
    dplyr::mutate(time_bin = as.factor(.data$time_bin))

  # create the columns CV_1, CV_2, ... which list which points are training
  # points and which points are test points in each cross-validation split
  for (iCV in seq_len(num_cv_splits)) {
    # slices ((iCV - 1) * r + 1) ... (iCV * r) form the test block of split iCV
    curr_cv_block_inds <- (iCV - 1) * num_label_repeats_per_cv_split +
      seq_len(num_label_repeats_per_cv_split)

    all_cv_data[[paste0("CV_", iCV)]] <- ifelse(
      all_cv_data$CV_slice_ID %in% curr_cv_block_inds, "test", "train"
    )
  }


  # remove the original CV_slice_ID field and drop any grouping carried over
  # from the sampling step (tests fail if the result is still grouped)
  all_cv_data <- all_cv_data %>%
    dplyr::select(-dplyr::all_of("CV_slice_ID")) %>%
    dplyr::ungroup()


  # add train_labels and test_labels columns
  all_cv_data <- all_cv_data %>%
    dplyr::mutate(train_labels = .data$labels) %>%
    dplyr::rename(test_labels = "labels") %>%
    dplyr::select("train_labels", "test_labels", dplyr::everything())


  all_cv_data

} # end get_data()






#' @export
get_parameters.ds_basic <- function(ndtr_obj) {

  # Summarises the settings of a ds_basic object as a one-row data frame.
  # Fields holding at most one value become ordinary columns (NULL is stored
  # as NA), label_levels_to_use and site_IDs_to_use become list columns, and
  # every column name is prefixed with the object's class.

  # the binned data itself is not a parameter
  ndtr_obj$binned_data <- NULL

  # pick out the fields that hold zero or one value
  field_sizes <- lengths(ndtr_obj)
  scalar_field_names <- names(field_sizes[field_sizes < 2])
  scalar_fields <- ndtr_obj[scalar_field_names]

  # convert null values to NAs so that the variables are retained
  scalar_values <- sapply(
    scalar_fields,
    function(field) ifelse(is.null(field), NA, field)
  )

  # one row per field (key/val), then spread out to one column per field
  key_val_df <- data.frame(val = unlist(scalar_values))
  key_val_df <- tibble::rownames_to_column(key_val_df, "key")

  parameter_df <- tidyr::spread(key_val_df, "key", "val")
  parameter_df <- dplyr::mutate(
    parameter_df,
    dplyr::across(where(is.factor), as.character)
  )

  # fields that can hold several values are stored as list columns
  parameter_df$label_levels_to_use <- list(sort(unlist(ndtr_obj$label_levels_to_use)))
  parameter_df$site_IDs_to_use <- list(ndtr_obj$site_IDs_to_use)

  # prefix each column with the class name, e.g. ds_basic.num_cv_splits
  names(parameter_df) <- paste(class(ndtr_obj), names(parameter_df), sep = ".")

  parameter_df

}
emeyers/NDTr documentation built on Aug. 8, 2020, 3:41 p.m.