R/setter.R

Defines functions set_predictions set_targets set_segments decorate_generic

Documented in decorate_generic set_predictions set_segments set_targets

# ------------------------------------------------------------------------------
# -------- Prediction ----------------------------------------------------------
# ------------------------------------------------------------------------------

#' Build the decorated prediction frame from a test data.frame
#'
#' Thin wrapper around `decorate_generic()`: the selected columns are gathered
#' into a "prediction" value column, with the originating column name stored
#' in "prediction_type". Rows whose prediction is `NA` are then dropped.
#'
#' @inheritParams decorate_generic
#'
#' @return `data.frame()`\cr Decorated prediction frame with columns
#' "prediction" and "prediction_type" in addition to id columns.
#' @export
set_predictions <- function(col_names, id_names, test_data) {
  decorated <- decorate_generic(
    test_data = test_data,
    id_names = id_names,
    col_names = col_names,
    key_name = "prediction_type",
    value_name = "prediction"
  )

  # TODO: temporary. See task 5cebcb93-9339-475b-9a16-a6d69a0dc8e5
  # Missing predictions are filtered out before the frame is handed back.
  dplyr::filter(decorated, !is.na(prediction))
}

# ------------------------------------------------------------------------------
# -------- Target --------------------------------------------------------------
# ------------------------------------------------------------------------------

#' Build the decorated target frame from a test data.frame
#'
#' Thin wrapper around `decorate_generic()`: the selected columns are gathered
#' into a "target" value column, with the originating column name stored in
#' "target_type".
#'
#' @inheritParams decorate_generic
#'
#' @return `data.frame()` \cr Data.frame with id columns, columns
#' "target_type" and "target"
#' @export
set_targets <- function(col_names, id_names, test_data) {
  # All the work is delegated; only the output column names are fixed here.
  decorate_generic(
    test_data = test_data,
    id_names = id_names,
    col_names = col_names,
    value_name = "target",
    key_name = "target_type"
  )
}


# ------------------------------------------------------------------------------
# -------- Segment -------------------------------------------------------------
# ------------------------------------------------------------------------------


#' Build the segment frame from columns of a test data.frame
#'
#' Thin wrapper around `decorate_generic()` called without a key column: the
#' selected columns are combined into a single "segment" column instead of
#' being gathered.
#'
#' @inheritParams decorate_generic
#'
#' @return `data.frame()` \cr Data frame with id_names and "segment".
#' @export
set_segments <- function(col_names, id_names, test_data) {
  # key_name = NULL selects the "combine columns" mode of decorate_generic().
  decorate_generic(
    test_data = test_data,
    id_names = id_names,
    col_names = col_names,
    key_name = NULL,
    value_name = "segment"
  )
}

# ------------------------------------------------------------------------------
# -------- setter --------------------------------------------------------------
# ------------------------------------------------------------------------------

#' Extract data and decorate it as expected for an Assesser attribute
#'
#' From a test data frame, extracts the `col_names` and `id_names` columns and
#' reshapes the `col_names` columns:
#'
#' * if `key_name` is given, they are gathered into a key column named
#'   `key_name` (holding the original column names) and a value column named
#'   `value_name`;
#' * if `key_name` is `NULL`, they are combined, separated by "_", into a
#'   single column named `value_name`.
#'
#' All inputs are validated first; any violation raises an assertion error
#' with an explanatory message.
#'
#' @param test_data `data.frame()` \cr A dataframe with vectors as
#'   columns from which to fetch the data.
#' @param col_names `character()` \cr Column names to extract from the
#'   dataframe.
#' @param id_names `character()` \cr Columns used as unique line ids. Taken
#'   together they must identify each row of `test_data` uniquely.
#' @param value_name `character(1)` Name of value output column.
#' @param key_name `character(1)` Name of key column when extracted data is
#'   gathered. If NULL, then value columns are not gathered but combined.
#'
#' @return `data.frame()` \cr A data frame with the `id_names` columns,
#'   `value_name`, and optionally `key_name`. When `key_name` is given the key
#'   column is of type factor, otherwise the value column is of type factor.
#'
#' @export
decorate_generic <- function(
  test_data,
  col_names,
  id_names,
  value_name,
  key_name = NULL
) {

  # input assertions
  assertthat::assert_that(
    is.data.frame(test_data),
    msg = "test_data must be a dataframe"
  )
  assertthat::assert_that(
    is.character(col_names) &&
      is.character(id_names) &&
      is.character(value_name),
    msg = "col_names, id_names, value_name must be character vectors"
  )

  requested_cols <- c(col_names, id_names)

  assertthat::assert_that(
    anyDuplicated(requested_cols) == 0,
    msg = "Column cannot be in col_names and id_names at the same time"
  )
  assertthat::assert_that(
    is.null(key_name) || (length(key_name) == 1 && is.character(key_name)),
    msg = "key_name must be a length 1 character vector"
  )
  assertthat::assert_that(
    length(value_name) == 1,
    msg = "value_name must be a length 1 character vector"
  )

  # The missing names are collapsed into ONE string: an error message must be
  # a single string, even when several requested columns are absent.
  missing_cols <- setdiff(requested_cols, names(test_data))
  assertthat::assert_that(
    length(missing_cols) == 0,
    msg = paste(
      "column name or id name missing from dataframe:",
      paste(missing_cols, collapse = ", ")
    )
  )

  id_frame <- test_data[id_names]
  assertthat::assert_that(
    nrow(id_frame) == dplyr::n_distinct(id_frame),
    msg = "Id columns do not define each line uniquely"
  )

  input_df <- test_data[requested_cols]

  if (!is.null(key_name)) {
    # Long format: one row per (id, original column) pair.
    output_df <- tidyr::gather(
      input_df,
      key = !!key_name,
      value = !!value_name,
      -tidyselect::one_of(id_names)
    )
    output_df[[key_name]] <- as.factor(output_df[[key_name]])
  } else {
    # No key column requested: paste the selected columns into a single one.
    output_df <- tidyr::unite(
      input_df,
      col = !!value_name,
      tidyselect::one_of(col_names),
      sep = "_",
      remove = TRUE
    )
    output_df[[value_name]] <- as.factor(output_df[[value_name]])
  }

  output_df
}
signaux-faibles/MLsegmentr documentation built on Aug. 29, 2019, 2:22 p.m.