R/stimulation_pauses.R

Defines functions stimpulse_interpolate stimpulse_align stimpulse_extract stimpulse_find

Documented in stimpulse_align stimpulse_extract stimpulse_find stimpulse_interpolate

#' @name stimpulse_interpolate
#' @title Find and interpolate stimulation pulses
#' @param signal a channel signal trace (numeric vector)
#' @param sample_rate sample rate of \code{signal}, in Hz
#' @param pulse_duration stimulation pulse duration in seconds; also used as
#' the minimum gap between two detected pulse onsets
#' @param n_pulses suggested number of pulses; only used when
#' \code{threshold} is \code{NA}
#' @param threshold suggested threshold on the absolute first-order difference
#' of \code{signal}, used to find stimulation pulses; default is \code{NA}
#' (estimated from the signal)
#' @param pulse_info a list containing number of pulses \code{n_pulses},
#' onset index (\code{onset_index}, first time-point is 1), offset index
#' (\code{offset_index}); this can be generated by \code{stimpulse_find};
#' see 'Examples' below
#' @param expand_timepoints a length-two vector of point offsets (rounded to
#' integers): the number of points to include before the onset (non-positive)
#' and after the end of the longest pulse (non-negative);
#' \code{stimpulse_align} also uses this range as the maximum shift allowed
#' when aligning the pulses
#' @param center whether to center each extracted pulse by subtracting its
#' median; default is true
#' @param max_offset maximum (edge) offsets in seconds to interpolate the
#' pulses; default is \code{-0.0002} seconds before stimulation onset
#' and \code{0.0005} seconds after the stimulation offset
#' @returns \code{stimpulse_find} and \code{stimpulse_align} return the pulse
#' information (\code{pulse_info}) with the time-points of detected or corrected
#' stimulation onset and offset. The time-points are 1-indexed.
#' \code{stimpulse_extract} returns the signal snippets around the pulses, one
#' column per pulse;
#' \code{stimpulse_interpolate} returns the signal with the samples around
#' each pulse replaced by interpolated values.
#' @examples
#'
#'
#'
#' data("stimulation_signal")
#'
#' signal <- stimulation_signal$signal
#' sample_rate <- stimulation_signal$sample_rate
#'
#' # each pulse is roughly <0.001 seconds
#' pulse_durations <- 0.001
#'
#' # Initial pulses
#' pulse_info <- stimpulse_find(signal, sample_rate, pulse_durations)
#'
#' # number of pulses detected
#' pulse_info$n_pulses
#'
#' # extract responses 20 points before onset ~ 80 points after the end of
#' # the longest pulse
#' expand_timepoints <- c(-20, 80)
#' pulses_snippets <- stimpulse_extract(
#'   signal = signal,
#'   pulse_info = pulse_info,
#'   expand_timepoints = expand_timepoints
#' )
#'
#' # Visualize the pulses
#' snippet_time <- seq(
#'   expand_timepoints[[1]], by = 1,
#'   length.out = nrow(pulses_snippets)) / sample_rate * 1000
#' matplot(snippet_time, pulses_snippets, type = 'l', lty = 1, col = 'gray80',
#'         xlab = "Time (ms)", ylab = "uV", main = "Initial find")
#'
#' # Align the pulses
#' pulse_info <- stimpulse_align(signal, pulse_info)
#'
#' # Estimated pulse duration
#' estimated_duration <-
#'   (pulse_info$offset_index - pulse_info$onset_index + 1) / sample_rate
#'
#' # reload aligned pulses
#' pulses_snippets <- stimpulse_extract(
#'   signal = signal,
#'   pulse_info = pulse_info,
#'   expand_timepoints = expand_timepoints
#' )
#' matplot(snippet_time, pulses_snippets, type = 'l', lty = 1, col = 'gray80',
#'         xlab = "Time (ms)", ylab = "uV", main = "Aligned pulses")
#' lines(snippet_time, rowMeans(pulses_snippets), col = 'red')
#'
#'
#' # Interpolate the pulses
#' interpolated <- stimpulse_interpolate(
#'   signal = signal,
#'   sample_rate = sample_rate,
#'   pulse_info = pulse_info,
#'   max_offset = c(-0.0003, 0.0005)
#' )
#'
#'
#' interp_snippets <- stimpulse_extract(
#'   signal = interpolated,
#'   pulse_info = pulse_info,
#'   expand_timepoints = expand_timepoints
#' )
#'
#' oldpar <- par(mfrow = c(1, 2))
#' matplot(snippet_time, pulses_snippets, type = 'l', lty = 1,
#'         col = 'gray80', xlab = "Time (ms)", ylab = "uV",
#'         main = "Stim pulses", ylim = c(-600, 400))
#' lines(snippet_time, rowMeans(pulses_snippets), col = 'red')
#' abline(v = max(estimated_duration) * 1000, lty = 2)
#'
#' matplot(snippet_time, interp_snippets, type = 'l', lty = 1,
#'         col = 'gray80', xlab = "Time (ms)", ylab = "uV",
#'         main = "Interpolated 0.5 ms bandwidth")
#' lines(snippet_time, rowMeans(interp_snippets), col = 'red')
#' abline(v = max(estimated_duration) * 1000, lty = 2, col = "gray40")
#' abline(v = max(estimated_duration) * 1000 + 0.5, lty = 2)
#'
#' # restore the graphical parameters
#' par(oldpar)
#'
#'
#'
#'
#'
NULL

#' @rdname stimpulse_interpolate
#' @export
stimpulse_find <- function(signal, sample_rate, pulse_duration, n_pulses = NA, threshold = NA) {
  # Detect stimulation pulses as sharp jumps in the first-order difference of
  # `signal`. Onsets are the first supra-threshold jumps (shifted one sample
  # earlier); offsets are found the same way on the reversed difference.
  # When `threshold` is NA it is estimated from the data (optionally steered
  # by `n_pulses`); a user-supplied `threshold` is used as-is.
  # Returns list(n_pulses, threshold, onset_index, offset_index), 1-indexed.

  # differencing by one sample flattens the (slow) brain signal so that the
  # stimulation edges stand out
  signal_diff <- diff(signal)

  # strictly above every |diff|: no sample passes this threshold
  signal_diff_ceiling <- max(abs(signal_diff)) + 1

  # lower bound of the automatic threshold search
  search_floor <- abs(stats::median(signal_diff)) + 3 * stats::sd(signal_diff)

  # supra-threshold points closer than one pulse duration belong to one pulse
  min_idx_gap <- ceiling(pulse_duration * sample_rate)

  # Debounced onset finder: a supra-threshold index is kept only if it lies at
  # least `min_idx_gap` points after the previous supra-threshold index (kept
  # or not), so a burst of close detections collapses into a single onset
  find_onsets <- function(threshold, sdff) {
    candidates <- which(abs(sdff) > threshold) - 1L
    candidates[candidates <= 0] <- 1L
    keep <- c(TRUE, diff(candidates) >= min_idx_gap)
    candidates[keep]
  }

  # Bisect the threshold interval until exactly `n_pulses` onsets are found.
  # Always returns list(pulse_onset_idx, threshold); when no threshold yields
  # exactly `n_pulses`, the last (upper-half) attempt is returned
  match_n_pulses <- function(min_threshold, max_threshold, n_pulses, sdff) {
    lower <- find_onsets(min_threshold, sdff = sdff)
    if (length(lower) <= n_pulses || (max_threshold - min_threshold) <= 0) {
      return(list(
        pulse_onset_idx = lower,
        threshold = min_threshold
      ))
    }

    upper <- find_onsets(max_threshold, sdff = sdff)
    if (length(upper) >= n_pulses ||
        length(upper) == length(lower)) {
      return(list(
        pulse_onset_idx = upper,
        threshold = max_threshold
      ))
    }

    mid_threshold <- (min_threshold + max_threshold) / 2
    re <- match_n_pulses(
      min_threshold = min_threshold,
      max_threshold = mid_threshold,
      n_pulses = n_pulses,
      sdff = sdff
    )
    if (length(re$pulse_onset_idx) == n_pulses) {
      return(re)
    }

    match_n_pulses(
      min_threshold = mid_threshold,
      max_threshold = max_threshold,
      n_pulses = n_pulses,
      sdff = sdff
    )
  }

  if (is.na(threshold)) {
    if (is.na(n_pulses)) {
      # initial guess with a conservative 4-SD threshold; the number of onsets
      # it finds becomes the target pulse count
      threshold <- abs(stats::median(signal_diff)) + 4 * stats::sd(signal_diff)
      pulse_onset_idx <- find_onsets(threshold, signal_diff)
      n_pulses <- length(pulse_onset_idx)

      # refine: search between the 3-SD floor and the ceiling for a threshold
      # with the same pulse count; accept it only if the onsets barely move
      matched <- match_n_pulses(
        search_floor,
        signal_diff_ceiling,
        n_pulses = n_pulses,
        sdff = signal_diff
      )
      if (
        length(matched$pulse_onset_idx) == n_pulses &&
        all(abs(matched$pulse_onset_idx - pulse_onset_idx) < min_idx_gap)
      ) {
        pulse_onset_idx <- matched$pulse_onset_idx
        threshold <- matched$threshold
      }
    } else {
      matched <- match_n_pulses(
        search_floor,
        signal_diff_ceiling,
        n_pulses = n_pulses,
        sdff = signal_diff
      )
      pulse_onset_idx <- matched$pulse_onset_idx
      threshold <- matched$threshold
    }
  } else {
    # user-supplied threshold: use it directly (n_pulses is not consulted)
    pulse_onset_idx <- find_onsets(threshold, signal_diff)
  }

  pulse_onset_idx <- as.integer(pulse_onset_idx)
  n_pulses <- length(pulse_onset_idx)

  if (n_pulses == 0L) {
    stop("No stimulation pulse detected. Please set the signal threshold manually")
  }

  # offsets: run the same search on the reversed difference, then map the
  # reversed positions back onto the original time axis
  matched_rev <- match_n_pulses(
    search_floor,
    signal_diff_ceiling,
    n_pulses = n_pulses,
    sdff = rev(signal_diff)
  )
  pulse_offset_idx <- as.integer(sort(length(signal) - matched_rev$pulse_onset_idx))

  # onsets and offsets may not pair up one-to-one: use the first offset
  # after each onset, NA when there is none
  duration_idx <- vapply(pulse_onset_idx, function(idx) {
    later <- pulse_offset_idx[pulse_offset_idx > idx]
    if (length(later) > 0) {
      return(later[[1]] - idx)
    }
    NA_integer_
  }, FUN.VALUE = 0L)

  # fill missing durations with the nominal duration (no offsets at all) or
  # with the longest duration found
  if (all(is.na(duration_idx))) {
    duration_idx[] <- as.integer(min_idx_gap)
  } else {
    duration_idx[is.na(duration_idx)] <- max(duration_idx, na.rm = TRUE)
  }

  list(
    n_pulses = n_pulses,
    threshold = threshold,
    onset_index = pulse_onset_idx,
    offset_index = pulse_onset_idx + duration_idx
  )
}

#' @rdname stimpulse_interpolate
#' @export
stimpulse_extract <- function(signal, pulse_info, expand_timepoints = c(-10, 20), center = TRUE) {
  # Cut a window around every pulse and return them as a matrix with one row
  # per time point and one column per pulse. Every window starts
  # `-expand_timepoints[1]` points before the onset and spans the longest
  # pulse plus `expand_timepoints[2]` points, so all columns share a length.
  # Points falling outside `signal` are NA. With `center = TRUE` each column
  # has its median (ignoring NA) subtracted.

  if (length(pulse_info$onset_index) == 0) {
    stop("No stimulation pulse is specified")
  }

  # the window always contains the pulse: clamp the pre-onset expansion to be
  # non-positive and the post-pulse expansion to be non-negative
  expand_timepoints <- round(expand_timepoints)
  expand_timepoints[[1]] <- min(expand_timepoints[[1]], 0)
  expand_timepoints[[2]] <- max(expand_timepoints[[2]], 0)

  max_pulse_npoints <- max(pulse_info$offset_index - pulse_info$onset_index) + 1
  n_timepoints <- max_pulse_npoints + expand_timepoints[[2]] - expand_timepoints[[1]]
  time_points_delta <- seq_len(n_timepoints) - 1L

  window_start <- pulse_info$onset_index + expand_timepoints[[1]]

  # (time point x pulse) index grid; out-of-range positions become NA instead
  # of producing invalid (zero/negative) subscripts
  index <- outer(time_points_delta, window_start, FUN = "+")
  index[index < 1 | index > length(signal)] <- NA
  pulses <- matrix(
    signal[as.vector(index)],
    nrow = n_timepoints,
    ncol = length(window_start)
  )

  if (center) {
    # give all snippets a common baseline
    pulses <- sweep(pulses, 2L, apply(pulses, 2L, stats::median, na.rm = TRUE))
  }
  pulses
}

#' @rdname stimpulse_interpolate
#' @export
stimpulse_align <- function(signal, pulse_info, expand_timepoints = c(-10, 20)) {
  # Re-align detected pulses: build a robust template from the extracted,
  # median-centered pulses, then for each pulse pick the lag within
  # `expand_timepoints` that maximizes the mean product with the template.
  # Onset and offset indices are moved by that lag and `aligned` is set TRUE.

  if (length(pulse_info$onset_index) == 0) {
    stop("No stimulation pulse is specified")
  }

  # clamp the window so it always contains the pulse; the same bounds limit
  # how far a pulse may be shifted
  expand_timepoints <- round(expand_timepoints)
  expand_timepoints[[1]] <- min(expand_timepoints[[1]], 0)
  expand_timepoints[[2]] <- max(expand_timepoints[[2]], 0)
  pre_points <- -expand_timepoints[[1]]
  post_points <- expand_timepoints[[2]]

  # one column per pulse, one row per time point
  pulses_centered <- stimpulse_extract(
    signal,
    pulse_info,
    expand_timepoints = expand_timepoints,
    center = TRUE
  )

  # per time point: upper quartile where pulses are mostly positive, lower
  # quartile otherwise; NA samples (window past the signal edge) are ignored
  stim_template <- apply(pulses_centered, 1L, function(x) {
    prob <- if (isTRUE(mean(x, na.rm = TRUE) > 0)) 0.75 else 0.25
    stats::quantile(x, prob, na.rm = TRUE, names = FALSE)
  })
  n_template <- length(stim_template)

  lags <- seq(-pre_points, post_points)

  # shift `x` by `lag` points (positive moves samples later), padding with NA
  # and keeping the template length
  shift_slice <- function(x, lag) {
    if (lag > 0) {
      x <- c(rep(NA_real_, lag), x)[seq_len(n_template)]
    } else if (lag < 0) {
      x <- c(x[-seq_len(-lag)], rep(NA_real_, -lag))[seq_len(n_template)]
    }
    x
  }

  # best lag per pulse; lags whose overlap is entirely NA score -Inf
  offsets <- apply(pulses_centered, 2L, function(slice) {
    scores <- vapply(lags, function(lag) {
      mean(stim_template * shift_slice(slice, lag), na.rm = TRUE)
    }, FUN.VALUE = 0)
    scores[is.na(scores)] <- -Inf
    lags[[which.max(scores)]]
  })

  pulse_info$onset_index <- pulse_info$onset_index - offsets
  pulse_info$offset_index <- pulse_info$offset_index - offsets

  pulse_info$aligned <- TRUE
  pulse_info
}

#' @rdname stimpulse_interpolate
#' @export
stimpulse_interpolate <- function(signal, sample_rate, pulse_info, max_offset = c(-0.0002, 0.0005)) {
  # Replace the samples around each stimulation pulse with a smooth fit.
  # A ridge-regularized cubic B-spline basis (ord = 4) is fitted to each
  # median-centered pulse window using the points outside the masked region,
  # and the fitted values overwrite the masked region, which spans (relative
  # to each onset) from -1 - ceiling(|max_offset[1]| * sample_rate) to
  # longest pulse length + ceiling(|max_offset[2]| * sample_rate).
  # Returns `signal` with attributes `onset_timepoints` and `masked_offsets`.

  if (length(pulse_info$onset_index) == 0) {
    stop("No stimulation pulse is specified")
  }

  onset_index <- pulse_info$onset_index
  offset_index <- pulse_info$offset_index

  # gaps between consecutive pulses, including the signal edges:
  # start[i] is the previous offset (0 before the first pulse) and finish[i]
  # the next onset (length(signal) after the last pulse)
  start <- c(0, offset_index)
  finish <- c(onset_index, length(signal))

  # load -pre_points to post_points to interpolate
  max_duration_npts <- max(offset_index - onset_index) + 1
  # bounded by the shortest gap so a window never reaches the previous pulse
  pre_points <- min(finish - start) - 1
  # NOTE(review): for a last pulse shorter than the longest one, the window
  # may extend past length(signal) (giving NA samples) - confirm
  post_points <- max_duration_npts + min(finish - start, 100) - 1


  # time points relative to each onset
  time_points_template <- seq(-pre_points, post_points)
  # edge offsets in samples, signed as c(before onset, after offset)
  max_offset_pts <- ceiling(abs(max_offset) * sample_rate) * c(-1, 1)

  # Earlier lowess-based approach, kept for reference only (not executed):
  # # obtain signal slices and masks
  # mask <- time_points_template >= 0 & time_points_template <= max_duration_npts
  # signal_slices <- sapply(onset_index, function(idx) {
  #   slice <- signal[idx + time_points_template]
  #   slice
  # })
  #
  # # intercept <- apply(signal_slices, 2L, median, na.rm = TRUE)
  #
  # # centered_signals <- t(t(signal_slices) - intercept)
  #
  # mask_idx <- which(mask)
  # mask_idx_range <- range(mask_idx)
  # bandwidth_pts <- ceiling(min(mask_idx_range[[1]], length(time_points_template) - mask_idx_range[[2]], max_offset_pts)) - 1
  # bandwidth <- length(mask_idx) / length(time_points_template) * 2
  #
  # mask_expanded <- mask
  # mask_expanded_start <- mask_idx_range[[1]] - bandwidth_pts
  # mask_expanded_finish <- mask_idx_range[[2]] + bandwidth_pts
  # mask_expanded[seq(mask_expanded_start, mask_expanded_finish)] <- TRUE
  #
  #
  # fitted <- apply(signal_slices, 2, function(slice) {
  #   # slice <- signal_slices[,2]; slice0 <- slice
  #
  #   smoothed <- stats::lowess(time_points_template, slice, f = bandwidth, iter = 10)
  #   slice[mask_expanded] <- smoothed$y[mask_expanded]
  #
  #   # plot(slice0, type = 'l')
  #   # lines(smoothed$y, col = 'blue')
  #   # lines(slice, col = 'red')
  #
  #   slice
  # })
  #
  # env <- new.env(parent = emptyenv())
  # env$signal <- signal
  #
  # sapply(seq_along(onset_index), function(ii) {
  #   idx <- onset_index[[ii]]
  #   env$signal[idx + time_points_template] <- fitted[, ii]
  #   return()
  # })
  #
  # env$signal

  # prepare interpolation splines
  ntp <- length(time_points_template)

  # Two fitting regions (columns): before the masked window
  # [-pre_points, -1 + max_offset_pts[1]] and after it
  # [max_duration_npts + max_offset_pts[2], post_points]. Row 3 is a relative
  # weight used to split the knots between the regions.
  # NOTE(review): the first weight uses max_offset_pts[[2]]; the first
  # region's length would be pre_points - 1 + max_offset_pts[[1]] - confirm
  # whether [[1]] was intended
  plan <- cbind(
    c(-pre_points, -1 + max_offset_pts[[1]], (pre_points - max_offset_pts[[2]]) / ntp),
    c(max_duration_npts + max_offset_pts[[2]], post_points, (post_points - max_duration_npts - max_offset_pts[[2]] + 1) / ntp)
  )

  # 50 knots, cubic splines (order 4), ridge penalty added to the Gram matrix
  nknots <- 50
  ord <- 4
  regularization <- 0.01
  # distribute the nknots - ord interior knots in proportion to the weights,
  # then trim the largest allocation until the total fits
  nknots2 <- ceiling(plan[3, ] / sum(plan[3, ]) * (nknots - ord))
  while (sum(nknots2) > nknots - ord) {
    tmp <- which.max(nknots2)
    if (nknots2[tmp] >= 2) {
      nknots2[tmp] <- nknots2[tmp] - 1
    } else {
      break
    }
  }
  plan[3, ] <- nknots2
  # drop regions that received no knots
  plan <- plan[, nknots2 > 0, drop = FALSE]


  # evenly spaced knots within each region
  knots <- apply(plan, 2, function(item) {
    time_start <- item[[1]]
    time_end <- item[[2]]
    nknots_ <- item[[3]]
    # Chebychev points
    # cos((2 * seq_len(nknots_) - 1) / (nknots_ * 2) * pi) * (time_end - time_start) / 2 + (time_start + time_end) / 2

    seq(time_start, time_end, length.out = nknots_)
  })

  # `apply` may return a matrix or a list here; flatten either form
  knots <- sort(as.double(unlist(knots)))
  if (length(knots) > nknots - ord) {
    knots <- knots[seq_len(nknots - ord)]
  }

  # pad the knot vector with repeated boundary knots at the first and last
  # template time points
  knots <- c(
    rep(time_points_template[1], ord - 1),
    knots,
    rep(time_points_template[length(time_points_template)], nknots - length(knots) - 1)
  )

  # basis matrix: one row per basis function, one column per template time
  # point (transposed from splineDesign's layout)
  B <- t(
    splines::splineDesign(
      knots,
      time_points_template,
      ord = ord,
      outer.ok = TRUE,
      sparse = FALSE
    )
  )
  # defensive: make sure B is treated as a base matrix
  if (!inherits(B, "matrix")) {
    class(B) <- "matrix"
  }
  # regularized Gram matrix, inverted once and shared by all pulses
  bbt <- tcrossprod(B)
  diag(bbt) <- diag(bbt) + regularization
  bbtsolve <- qr.solve(bbt)

  # obtain masked signals
  # mask marks the template points that will be replaced
  mask <- time_points_template >= -1 + max_offset_pts[[1]] & time_points_template <= (max_duration_npts + max_offset_pts[[2]])
  # one column per pulse, rows follow time_points_template
  signal_slices <- sapply(onset_index, function(idx) {
    slice <- signal[idx + time_points_template]
    slice
  })

  intercept <- apply(signal_slices, 2L, median, na.rm = TRUE)

  centered_signals <- t(t(signal_slices) - intercept)

  # slice <- centered_signals[,1]
  # plot(time_points_template, slice, type = "l")
  # lines(lowess(time_points_template, slice, f = (sum(mask) + 5) / length(mask)), col = 'red')

  # Initial fit
  # spline coefficients (pulse x basis) from the unmasked points only.
  # NOTE(review): bbt is built from all template points, not only the
  # unmasked ones, so this is not the exact least-squares fit on the
  # unmasked points - confirm this is intended
  gamma <- tcrossprod(t(centered_signals[!mask, , drop = FALSE]), B[, !mask]) %*% bbtsolve
  # (time point x pulse) fitted values with each pulse's median added back
  fitted <- t(gamma %*% B + intercept)

  masked_timepoints <- time_points_template[mask]
  # env <- new.env(parent = emptyenv())
  # env$signal <- signal

  # (signal index, fitted value) pairs for the masked points of every pulse
  replacements <- lapply(seq_along(onset_index), function(ii) {
    idx <- onset_index[[ii]]

    cbind(idx + masked_timepoints, fitted[mask, ii])
  })
  replacements <- do.call("rbind", replacements)

  signal[replacements[, 1]] <- replacements[, 2]

  structure(
    signal,
    onset_timepoints = onset_index,
    masked_offsets = time_points_template[mask]
  )
}


# data("stimulation_signal")
#
# signal <- stimulation_signal$signal
# sample_rate <- stimulation_signal$sample_rate
# pulse_info <- stimpulse_find(signal, sample_rate, 0.001)
#
# pulses_snippets <- stimpulse_extract(signal, pulse_info, center = TRUE)
# matplot(pulses_snippets, type = 'l', lty = 1, col = 'gray80')
# lines(rowMeans(pulses_snippets), col = 'red')
#
# pulse_info <- stimpulse_align(signal, pulse_info)
# pulses_snippets <- stimpulse_extract(signal, pulse_info, center = TRUE)
# matplot(pulses_snippets, type = 'l', lty = 1, col = 'gray80')
# lines(rowMeans(pulses_snippets), col = 'red')
#
# # pulse_info$offset_index <- pulse_info$offset_index + 30
# interpolated <- stimpulse_interpolate(signal, sample_rate, pulse_info, max_offset = c(5, 40) / sample_rate)
#
# # interpolated[is.na(interpolated)] <- 0
# ravetools::diagnose_channel((interpolated[seq_len(20000)]), signal, srate = sample_rate, try_compress = FALSE, name = c("interpotad", "orig"))

# # visualize
# time = seq_along(signal) / 30000
# plotly::plot_ly(type = "scattergl", mode = "lines") |>
#   plotly::add_trace(
#     # data = data.frame(time = tpoints / 30000,
#     #                   snippet = snippet,
#     #                   interpolated = tmp),
#     x = time,
#     y = signal,
#     line = list(
#       color = "gray"
#     )
#   ) |>
#   plotly::add_trace(
#     x = time,
#     y = interpolated,
#     line = list(
#       color = "red"
#     )
#   )
# plotly::add_trace(
#   x = ravetools::decimate(time, 3, ftype = 'fir'),
#   y = ravetools::decimate(signal, 3, ftype = 'iir'),
#   line = list(
#     color = "black"
#   )
# ) |>
# plotly::add_trace(
#   x = ravetools::decimate(time, 3, ftype = 'iir'),
#   y = ravetools::decimate(interpolated, 3, ftype = 'iir'),
#   line = list(
#     color = "blue"
#   )
# )

Try the ravetools package in your browser

Any scripts or data that you put into this service are public.

ravetools documentation built on May 31, 2026, 9:06 a.m.