R/helperfuns.R

Defines functions fit_vector_length split_term_and_lambda make_lasso select_pen_i f2s expand_form expand_forms create_full_lists

# Build, for every distribution parameter of `family`, the same "full"
# formula `~ 1 + <every column of data>`.
#
# data:   object whose column names become the formula terms.
# family: passed to make_tfd_dist(), whose "nrparams_dist" attribute gives
#         the number of formula copies, and to names_families(), which
#         supplies the names of the returned list.
# Returns a named list of identical formulas.
create_full_lists <- function(data, family)
{
  rhs <- paste(colnames(data), collapse = " + ")
  full_form <- as.formula(paste0("~ 1 + ", rhs))
  # how many distribution parameters need their own formula
  n_params <- attr(make_tfd_dist(family), "nrparams_dist")
  setNames(rep(list(full_form), n_params), names_families(family))
}

# Expand every formula in a list into its penalised form.
#
# lof:      list of formulas, one per distribution parameter.
# controls: control list; select_pen_i() extracts the penalty settings for
#           the k-th formula from it.
# type:     model type, forwarded to select_pen_i() and expand_form().
# Returns an unnamed list of expanded formulas with the same length as `lof`
# (an empty list for an empty `lof`).
expand_forms <- function(lof, controls, type)
{
  
  # seq_along() instead of 1:length(): for an empty `lof` the latter gives
  # c(1, 0) and fails on lof[[1]].
  lapply(seq_along(lof), function(k)
    expand_form(
      lof[[k]],
      # penalty settings of the k-th distribution parameter
      controls = select_pen_i(controls, k, type),
      type = type
    )
  )
  
}

# Expand the right-hand side of one formula into penalised model terms.
#
# form:     a formula such as `~ 1 + x1 + x2`; only its intercept flag and its
#           term labels are used.
# controls: list with a scalar `gammas` in [0, 1], `lambdas` (multiplied into
#           the `la` penalty arguments) and `sterm_default`, a function called
#           as sterm_default(term_label, la = <penalty>) whose result is pasted
#           into the formula (presumably a single string - confirm).
# type:     "standard": each term appears as make_lasso(term, la =
#           gammas * lambdas) plus a smooth term with la =
#           (1 - gammas) * lambdas;
#           "reduced": only smooth terms, with la = lambdas.
#
# Returns the formula `~ <intercept> + <expanded terms>`. A formula without
# term labels comes back as the intercept-only formula (`~ 1` or `~ 0`).
# Errors on an unsupported `type` or on a gamma that is not a single number
# in [0, 1].
expand_form <- function(form, controls, type)
{
  
  # Validate up front so an unsupported type fails before any work is done.
  if (!(length(type) == 1L && type %in% c("standard", "reduced")))
    stop("No such type as >", type, "< supported.")
  
  gammas <- controls$gammas
  lambdas <- controls$lambdas
  # `||` short-circuits on scalars; non-numeric, missing or non-scalar gammas
  # get the range message instead of an obscure `if` condition error.
  if (!is.numeric(gammas) || length(gammas) != 1L || is.na(gammas) ||
      gammas < 0 || gammas > 1)
    stop("gamma values must be in [0,1]")
  
  tls <- terms.formula(form)
  int <- paste0("~ ", attr(tls, "intercept"))
  vars <- attr(tls, "term.labels")
  
  # Nothing to expand: pasting empty term strings would produce "~ 1 +  + ",
  # which is not a parseable formula.
  if (length(vars) == 0L)
    return(as.formula(int))
  
  # The reduced model has no lasso part, so the smooth terms carry the whole
  # penalty and the gamma passed in is ignored.
  if (type == "reduced") gammas <- 0
  
  nlin <- paste(unlist(lapply(vars, controls$sterm_default,
                              la = pmax(0, 1 - gammas) * lambdas)),
                collapse = " + ")
  if (type == "reduced")
    return(as.formula(paste(int, nlin, sep = " + ")))
  
  lin <- paste(unlist(lapply(vars, make_lasso, la = gammas * lambdas)),
               collapse = " + ")
  as.formula(paste(int, lin, nlin, sep = " + "))
  
}

# Convert a formula (or any R object) into a single-line character string.
# deparse() splits long expressions over several lines and indents the
# continuation lines. Those indents are trimmed before the pieces are joined
# with single spaces, so the result never contains runs of padding (which
# would defeat patterns expecting e.g. "la = <value>" with one space).
f2s <- function(form)
  paste(trimws(deparse(form, width.cutoff = 500L)), collapse = " ")

# Pick the penalty settings belonging to the i-th distribution parameter.
#
# ctrl: control list holding per-parameter `gammas` and `lambdas`.
# i:    index of the distribution parameter.
# type: model type. For "reduced" gamma is set to a placeholder inside [0, 1]
#       rather than indexed, since expand_form() replaces it with 0 for that
#       type.
# Returns `ctrl` with `gammas` and `lambdas` reduced to their i-th element.
select_pen_i <- function(ctrl, i, type)
{
  ctrl$gammas <- if (type == "reduced") 0.1337 else ctrl$gammas[[i]]
  ctrl$lambdas <- ctrl$lambdas[[i]]
  ctrl
}

# Build the string form of a lasso-penalised term, e.g.
# make_lasso("x1", 0.5) gives "lasso(x1, la = 0.5)". Vectorised over `x` and
# `la` via paste0()'s recycling.
make_lasso <- function(x, la)
{
  inner <- paste0(x, ", la = ", la)
  paste0("lasso(", inner, ")")
}

# Split a term string such as "s(x, la = 0.5)" into its penalty value and the
# term without the penalty argument.
#
# term: character string holding a model term.
# Returns list(la, term): `la` is whatever extractval(term, "la") reports, and
# `term` is the input with every ", la = <number>" argument removed.
split_term_and_lambda <- function(term)
{
  
  la <- extractval(term, "la")
  # <number> is "5", "5.", "5.2" or ".5", optionally followed by an exponent
  # whose sign may be omitted ("1e-05", "1e+05", "1e5"). Any amount of
  # whitespace is accepted around "la" and "=".
  la_pattern <- ",\\s*la\\s*=\\s*([0-9]+\\.[0-9]*|\\.?[0-9]+)([eE][+-]?[0-9]+)?"
  term <- gsub(la_pattern, "", term)
  list(la = la, term = term)
  
}

# Recycle or truncate `vec` so that it has exactly `len` elements.
#
# vec:  vector or list to adjust; returned unchanged if it already has
#       length `len`.
# len:  target length (the number of formulas in list_of_formulas).
# name: label used in the warning/error messages.
# warn: if TRUE, warn when the length actually had to be changed.
# Returns `vec` repeated cyclically (too short) or cut down (too long) to
# `len` elements, keeping names. Errors if `vec` is empty but `len > 0`,
# because there is nothing to repeat.
fit_vector_length <- function(vec, len, name, warn = TRUE)
{
  
  n <- length(vec)
  if (n == len)
    return(vec)
  if (n == 0L)
    stop(name, " must have at least one element to be matched to length ",
         len, ".", call. = FALSE)
  
  if (warn)
    warning(paste0(name, " do not have the same length as list_of_formulas. "),
            "Matching lengths by repeating or subsetting values.",
            call. = FALSE)
  
  # rep(length.out =) both recycles and truncates, and (unlike 1:len)
  # correctly yields an empty result for len == 0.
  rep(vec, length.out = len)
}
neural-structured-additive-learning/sparsedistreg documentation built on May 13, 2022, 3:56 a.m.