#' A basic datasource (DS)
#'
#' The standard datasource used to get training and test splits of data.
#'
#' This 'basic' datasource is the datasource that will be used for most
#' analyses. It can generate either training and test sets for data that was
#' recorded simultaneously, or pseudo-populations for data that was not
#' recorded simultaneously.
#'
#' Like all datasources, this datasource takes data in binned format and has a
#' `get_data()` method that is never called explicitly by the user of the
#' package; rather, it is called internally by a cross-validation object to
#' get training and test splits of the data that can be passed to a
#' classifier.
#'
#' @param binned_data A string specifying a path to a file that has data in
#' binned format, or a data frame containing data in binned format.
#'
#' @param labels A string specifying the name of the labels that should be
#' decoded. This label must correspond to one of the columns in the binned
#' data whose name starts with 'labels.'. For example, if the binned data
#' contained a column called labels.stimulus_ID that you wanted to decode,
#' then you would set this argument to "stimulus_ID".
#'
#' @param num_cv_splits A number specifying how many cross-validation splits
#' should be used.
#'
#' @param use_count_data If the binned data contains firing rates that were
#' originally computed from spike counts, then setting use_count_data = TRUE
#' will convert the firing rates back into spike counts. This is useful for
#' classifiers that operate on spike count data, e.g., the
#' poisson_naive_bayes_CL.
#'
#' @param num_label_repeats_per_cv_split A number specifying how many times each
#' label should be repeated in each cross-validation split.
#'
#' @param label_levels A vector of strings specifying which specific label
#' levels should be used. If this is set to NULL, then all available label
#' levels will be used.
#'
#' @param num_resample_sites The number of sites that should be randomly
#' selected when constructing training and test vectors. This number must be
#' less than or equal to the number of sites available that have
#' num_cv_splits * num_label_repeats_per_cv_split repetitions of each label
#' level. For example, if num_cv_splits = 6 and
#' num_label_repeats_per_cv_split = 3, then only sites with at least 18
#' repetitions of each label level are available to be selected.
#'
#' @param site_IDs_to_use A vector of integers specifying which sites should be
#' used. If this is NULL (the default), then all sites that have
#' num_cv_splits * num_label_repeats_per_cv_split repetitions of each label
#' level will be used, and a message reporting how many sites are being used
#' will be displayed.
#'
#' @param site_IDs_to_exclude A vector of integers specifying which sites should
#' be excluded.
#'
#' @param randomly_shuffled_labels A Boolean specifying whether the labels
#' should be randomly shuffled prior to running an analysis (i.e., prior to
#' the first call to the get_data() method). This is used to create a null
#' distribution for assessing whether decoding results are above chance.
#'
#' @param create_simultaneous_populations If the data from all sites was
#' recorded simultaneously, then setting this variable to 1 (or TRUE) will
#' cause the get_data() method to return simultaneously recorded populations
#' rather than pseudo-populations.
#'
#' @return This constructor creates an NDR datasource object with the class
#' `ds_basic`. Like all NDR datasource objects, this datasource will be used
#' by the cross-validator to generate training and test data sets.
#'
#'
#' @examples
#' # A typical example of creating a datasource to be passed to a cross-validation object
#' data_file <- system.file(file.path("extdata", "ZD_150bins_50sampled.Rda"), package = "NeuroDecodeR")
#' ds <- ds_basic(data_file, "stimulus_ID", 18)
#'
#' # If one has many repeats of each label, decoding can be faster if one
#' # uses fewer CV splits and repeats each label multiple times in each split.
#' ds <- ds_basic(data_file, "stimulus_ID", 6,
#' num_label_repeats_per_cv_split = 3
#' )
#'
#' # One can specify a subset of label levels to be used in decoding. Here
#' # we just do a three-way decoding analysis between "car", "hand" and "kiwi".
#' ds <- ds_basic(data_file, "stimulus_ID", 18,
#'                label_levels = c("car", "hand", "kiwi")
#' )
#'
#' # One never explicitly calls the get_data() function, but rather this is
#' # called by the cross-validator. However, to illustrate what this function
#' # does, we can call it explicitly here to get training and test data:
#' all_cv_data <- get_data(ds)
#' names(all_cv_data)
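#'
#' # As a rough sketch of how one could create a null distribution (using the
#' # same data file as above), the labels can be randomly shuffled before any
#' # data is retrieved; decoding accuracies obtained this way should be at
#' # chance levels:
#' ds <- ds_basic(data_file, "stimulus_ID", 18,
#'                randomly_shuffled_labels = TRUE
#' )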
#'
#' @family datasource
#'
# the constructor
#' @export
ds_basic <- function(binned_data,
                     labels,
                     num_cv_splits,
                     use_count_data = FALSE,
                     num_label_repeats_per_cv_split = 1,
                     label_levels = NULL,
                     num_resample_sites = NULL,
                     site_IDs_to_use = NULL,
                     site_IDs_to_exclude = NULL,
                     randomly_shuffled_labels = FALSE,
                     create_simultaneous_populations = 0) {
  if (is.character(binned_data)) {
    binned_file_name <- binned_data
  } else {
    binned_file_name <- "Loaded binned data given"
  }

  # load the data if a string was given, and check that the data is in binned format
  binned_data <- check_and_load_binned_data(binned_data)

  if (use_count_data) {
    binned_data <- convert_rates_to_counts(binned_data)
  }
  # store the original binned data in order to compute site_IDs_to_use if it is NULL
  # (a bit hacky and not memory efficient, but should be ok)
  binned_data_org <- binned_data

  # remove all label columns that aren't being used, and rename the label
  # column that is being used to "labels"
  label_col_ind <- match(paste0("labels.", labels), names(binned_data))

  # also keep the variable trial_number if it exists
  if ("trial_number" %in% colnames(binned_data)) {
    binned_data <- binned_data |>
      dplyr::select("siteID", starts_with("time"), "trial_number", labels = all_of(label_col_ind))
  } else {
    binned_data <- binned_data |>
      dplyr::select("siteID", starts_with("time"), labels = all_of(label_col_ind))
  }
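
  # binned_data now contains only the siteID, the time.* activity columns,
  # trial_number (if it exists), and the labels being decoded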
  # only use the specified label_levels
  if (!is.null(label_levels)) {
    binned_data <- dplyr::filter(binned_data, labels %in% label_levels)
  } else {
    label_levels <- as.list(levels(binned_data$labels)) # stored as a list (apparently for ds_generalization)
  }
  if (is.null(site_IDs_to_use)) {

    # if site_IDs_to_use is not specified, use all sites that have enough label repetitions
    site_IDs_to_use <- get_siteIDs_with_k_label_repetitions(
      binned_data_org,
      labels,
      k = num_cv_splits * num_label_repeats_per_cv_split,
      label_levels = unlist(label_levels)
    )

    # print a message about which sites are being used
    message(
      paste0(
        "Automatically selecting site_IDs_to_use.",
        " Since num_cv_splits = ", num_cv_splits,
        " and num_label_repeats_per_cv_split = ", num_label_repeats_per_cv_split,
        ", all sites that have ", num_cv_splits * num_label_repeats_per_cv_split,
        " repetitions have been selected. This yields ", length(site_IDs_to_use),
        " sites that will be used for decoding (out of ", length(unique(binned_data$siteID)),
        " total)."
      )
    )

    # give an error if no sites are available for decoding
    # (could require at least 2 sites, but allowing one site for now)
    if (length(site_IDs_to_use) < 1) {
      stop(
        paste(
          "\nNo sites are available that have enough trial repetitions based on",
          "the num_cv_splits, num_label_repeats_per_cv_split, and label_levels that were specified.",
          "Please use different values for these parameters, and/or manually specify the site_IDs_to_use."
        )
      )
    }

    # free up some memory since this is not used elsewhere
    rm(binned_data_org)
  }
  if (!is.null(site_IDs_to_exclude)) {
    site_IDs_to_use <- setdiff(site_IDs_to_use, site_IDs_to_exclude)
  }

  if (length(label_levels) != length(unique(label_levels))) {
    warning("Some labels were listed twice. Duplication will be ignored.")
  }

  if (is.null(num_resample_sites)) {
    num_resample_sites <- length(site_IDs_to_use)
  }
  if (num_resample_sites < 1) {
    stop("num_resample_sites must be greater than 0.")
  }

  if (!(create_simultaneous_populations %in% c(0, 1))) {
    stop("create_simultaneous_populations must be set to 0 or 1.")
  }
  # check that the data is valid for getting simultaneously recorded populations
  if (create_simultaneous_populations == 1 || create_simultaneous_populations == TRUE) {

    # for simultaneously recorded data, there should be the same number of
    # trials for each label at every site
    num_trials_for_each_label_for_each_site <- binned_data |>
      dplyr::group_by(.data$siteID, labels) |>
      dplyr::summarize(n = n()) |>
      tidyr::pivot_wider(names_from = "labels", values_from = "n")
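
    # the resulting data frame has one row per site and one column per label
    # level, giving the number of trials each site has for each label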
    # for some reason select(-.data$siteID) isn't working
    num_trials_for_each_label_for_each_site$siteID <- NULL

    if (sum(sapply(lapply(num_trials_for_each_label_for_each_site, unique), length) != 1)) {
      stop(paste(
        "There are not the same number of repeated labels/trials for each site,",
        "which there should be for simultaneously recorded data."
      ))
    }
    # add a variable called 'trial_number' if it doesn't exist and the data was recorded simultaneously
    if (!("trial_number" %in% colnames(binned_data))) {
      warning(paste(
        "No variable named trial_number in the binned_data.\n",
        "Attempting to add this variable to decode simultaneously recorded data",
        "by assuming all trials for each site are in the same sequential order."
      ))

      num_trials_each_site <- binned_data |>
        dplyr::group_by(.data$siteID) |>
        dplyr::summarize(n = n()) |>
        dplyr::select("n")

      # assuming the trials are in order for each site, otherwise there is no way to align them
      binned_data$trial_number <- rep(1:num_trials_each_site$n[1], dim(num_trials_each_site)[1])
      binned_data <- dplyr::select(binned_data, "siteID", "trial_number", everything())
    }
    # shuffle the labels if specified (shuffling them the same way for each site)
    if (randomly_shuffled_labels == TRUE) {
      min_site_ID <- min(binned_data$siteID)
      first_site_data <- dplyr::filter(binned_data, .data$siteID == min_site_ID)
      the_labels <- sample(first_site_data$labels)
      binned_data$labels <- rep(the_labels, length(unique(binned_data$siteID)))
    }

    # add a variable label_trial_combo that identifies each simultaneous trial
    binned_data <- binned_data |>
      mutate(label_trial_combo = paste0(binned_data$labels, "_", binned_data$trial_number))
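
    # e.g., a trial where labels == "car" and trial_number == 12 gets the
    # label_trial_combo "car_12", which identifies that simultaneous trial
    # across all sites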
    # end of pre-processing for simultaneously recorded data
  } else {

    # shuffle the labels if specified
    if (randomly_shuffled_labels == TRUE) {
      binned_data <- binned_data |>
        ungroup() |>
        group_by(.data$siteID) |>
        mutate(labels = labels[sample(row_number())]) |>
        ungroup()
    }
  }
  # create the main data structure
  the_ds <- list(
    binned_file_name = binned_file_name,
    binned_data = binned_data,
    labels = labels,
    num_cv_splits = num_cv_splits,
    num_label_repeats_per_cv_split = num_label_repeats_per_cv_split,
    label_levels = label_levels,
    num_resample_sites = num_resample_sites,
    site_IDs_to_use = site_IDs_to_use,
    site_IDs_to_exclude = site_IDs_to_exclude,
    randomly_shuffled_labels = randomly_shuffled_labels,
    create_simultaneous_populations = create_simultaneous_populations
  )

  attr(the_ds, "class") <- "ds_basic"

  the_ds
}
#' @inherit get_data
#' @keywords internal
#' @export
get_data.ds_basic <- function(ds_obj) {
  binned_data <- ds_obj$binned_data
  num_cv_splits <- ds_obj$num_cv_splits
  num_trials_used_per_label <- ds_obj$num_cv_splits * ds_obj$num_label_repeats_per_cv_split
  num_resample_sites <- ds_obj$num_resample_sites
  site_IDs_to_use <- ds_obj$site_IDs_to_use
  create_simultaneous_populations <- ds_obj$create_simultaneous_populations
  num_label_repeats_per_cv_split <- ds_obj$num_label_repeats_per_cv_split

  # the code that actually gets the data used to train and test the classifier -----------

  # randomly select which sites to use in this resample run
  curr_sites_to_use <- sample(site_IDs_to_use, num_resample_sites)
  binned_data <- dplyr::filter(binned_data, .data$siteID %in% curr_sites_to_use)
  if (create_simultaneous_populations == 1) {

    # use one site to select the trials to use, and then apply the same
    # selection to all sites
    curr_label_trials_to_use <- binned_data |>
      dplyr::filter(.data$siteID == binned_data$siteID[1]) |>
      select(labels, "label_trial_combo") |>
      group_by(labels) |>
      sample_n(size = num_trials_used_per_label) |>
      pull(.data$label_trial_combo)

    # apply the specific simultaneous trials selected to all sites
    all_k_fold_data <- binned_data |>
      dplyr::filter(.data$label_trial_combo %in% curr_label_trials_to_use) |>
      dplyr::mutate(label_trial_siteID_combo = paste0(.data$label_trial_combo, "_", .data$siteID))

    # Arrange the rows for each site of all_k_fold_data to be in the random
    # order specified by curr_label_trials_to_use. This ensures a different
    # random ordering of the data each time get_data() is called.
    curr_label_trials_to_use_siteID <- paste0(
      curr_label_trials_to_use, "_",
      all_k_fold_data$siteID
    )
    all_k_fold_data <- all_k_fold_data[match(
      curr_label_trials_to_use_siteID,
      all_k_fold_data$label_trial_siteID_combo
    ), ]

    all_k_fold_data <- all_k_fold_data |>
      dplyr::select(-"label_trial_combo", -"label_trial_siteID_combo")
  } else {

    # for data not recorded simultaneously, sample trials independently for each site
    all_k_fold_data <- binned_data |>
      group_by(labels, .data$siteID) |>
      sample_n(size = num_trials_used_per_label)
  }
  # remove the variable trial_number if it exists in all_k_fold_data
  if ("trial_number" %in% names(all_k_fold_data)) {
    all_k_fold_data <- select(all_k_fold_data, -"trial_number")
  }

  unique_labels <- unique(all_k_fold_data$labels)
  num_sites <- length(unique(binned_data$siteID))
  num_labels <- length(unique_labels)
  # arrange the data by siteID and labels before adding on the CV_slice_ID
  all_k_fold_data <- dplyr::arrange(all_k_fold_data, .data$siteID, labels)

  # A CV slice is a group of trials that contains one example of each label.
  # These slices are mapped into CV splits, where each split contains
  # num_label_repeats_per_cv_split slices of each label.
  CV_slice_ID <- rep(1:num_trials_used_per_label, num_labels * num_sites)

  # add the CV slice ID to the data
  all_k_fold_data$CV_slice_ID <- CV_slice_ID
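
  # e.g., with num_cv_splits = 3 and num_label_repeats_per_cv_split = 2, each
  # site/label combination gets slice IDs 1:6, and slices {1, 2}, {3, 4} and
  # {5, 6} become the test data for CV splits 1, 2 and 3 respectively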
  # pad the siteID with zeros so that sites are listed as site_0001, site_0002, etc.
  all_k_fold_data$siteID <- paste0("site_", stringr::str_pad(all_k_fold_data$siteID, 4, pad = "0"))

  # reshape the data so that there is one column for each site
  long_data <- tidyr::pivot_longer(all_k_fold_data,
    -c("siteID", "labels", "CV_slice_ID"),
    names_to = "time_bin",
    values_to = "activity"
  )

  all_cv_data <- tidyr::pivot_wider(long_data,
    names_from = "siteID",
    values_from = "activity"
  ) |>
    select(labels, "time_bin", CV_slice_ID, everything()) |>
    mutate(time_bin = as.factor(.data$time_bin))
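
  # all_cv_data now has one row for each trial and time bin combination, and
  # one column for each site (along with the labels, time_bin and
  # CV_slice_ID columns)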
  # create columns CV_1, CV_2, ... that list which points are training points
  # and which points are test points in each cross-validation split
  for (iCV in 1:num_cv_splits) {
    start_ind <- ((iCV - 1) * num_label_repeats_per_cv_split) + 1
    end_ind <- iCV * num_label_repeats_per_cv_split
    curr_cv_block_inds <- start_ind:end_ind
    all_cv_data[[paste0("CV_", iCV)]] <-
      ifelse(all_cv_data$CV_slice_ID %in% curr_cv_block_inds, "test", "train")
  }
  # remove the original CV_slice_ID field and ungroup (tests fail without the ungroup)
  all_cv_data <- dplyr::select(all_cv_data, -CV_slice_ID) |>
    dplyr::ungroup()

  # add train_labels and test_labels columns
  all_cv_data <- all_cv_data |>
    mutate(train_labels = labels) |>
    rename(test_labels = labels) |>
    select("train_labels", "test_labels", everything())

  all_cv_data
} # end get_data()
#' @inherit get_parameters
#' @keywords internal
#' @export
get_parameters.ds_basic <- function(ndr_obj) {
  ndr_obj$binned_data <- NULL

  variable_lengths <- sapply(ndr_obj, length)
  length_one_variables <- variable_lengths[variable_lengths < 2]
  length_one_variables <- ndr_obj[names(length_one_variables)]

  # convert NULL values to NAs so that these variables are retained
  length_one_variables <- sapply(length_one_variables, function(x) ifelse(is.null(x), NA, x))

  parameter_df <- data.frame(val = unlist(length_one_variables)) |>
    tibble::rownames_to_column("key") |>
    tidyr::spread("key", "val") |>
    dplyr::mutate(dplyr::across(where(is.factor), as.character))

  parameter_df$label_levels <- list(sort(unlist(ndr_obj$label_levels)))
  parameter_df$site_IDs_to_use <- list(ndr_obj$site_IDs_to_use)

  names(parameter_df) <- paste(class(ndr_obj), names(parameter_df), sep = ".")

  parameter_df
}