#' Gets matrix of TCE from observed network.
#'
#' Row `i` of `R_obs` is divided by `1 + R_obs[i, i]` and the diagonal of the
#' result is then set to 1.
#'
#' @param R_obs D x D matrix of observed effects.
#' @param normalize A length D vector which is used to convert R_tce from the
#'   per-allele to the per-variance scale. Each entry should be the
#'   std dev of the corresponding phenotype. Set to NULL for no normalization.
#' @return D x D matrix of total causal effects.
get_tce <- function(R_obs, normalize = NULL) {
  # A length-D vector times a D x D matrix recycles down the columns, so
  # entry [i, j] is scaled by 1 / (1 + R_obs[i, i]).
  R_tce <- (1 / (1 + diag(R_obs))) * R_obs
  diag(R_tce) <- 1
  if (!is.null(normalize)) {
    # Fail with a clear message instead of a non-conformable-array error.
    stopifnot(length(normalize) == nrow(R_obs))
    # Entry [i, j] is multiplied by normalize[i] / normalize[j].
    R_tce <- R_tce * outer(normalize, 1 / normalize)
  }
  return(R_tce)
}

#' Gets observed network from direct effects.
#'
#' Solves `(I - G) X = G` for `X`, i.e. returns `(I - G)^{-1} G` without
#' forming the inverse explicitly.
#'
#' @param G D x D matrix of direct effects.
#' @return D x D matrix of observed effects.
get_observed <- function(G) {
  D <- nrow(G)
  return(solve(diag(D) - G, G))
}

#' Fits exact model to data.
#'
#' Inverts `R`, divides each column of the inverse by its diagonal entry and
#' subtracts the result from the identity.
#'
#' @param R D x D matrix of "total causal effects". `NA` entries are replaced
#'   by 0 before inversion.
#' @return A list with two elements: `G`, the D x D matrix with zero diagonal
#'   of deconvoluted direct effects, and `R_inv`, the inverse of `R`.
get_direct <- function(R) {
  D <- nrow(R)
  R[is.na(R)] <- 0

  R_inv <- solve(R)
  # Column j is divided by R_inv[j, j], which makes the scaled diagonal 1 and
  # therefore the diagonal of G exactly 0.
  G <- diag(D) - sweep(R_inv, 2, diag(R_inv), "/")
  return(list("G" = G, "R_inv" = R_inv))
}

#' Simulates the 0/1 adjacency matrix of a growing directed graph.
#'
#' Starts from two nodes joined by the edge 1 -> 2. Until the graph has `D`
#' nodes, each iteration either appends a new node that receives one edge
#' from an existing node, or tries to add an edge between two existing nodes.
#' Source nodes are drawn with weight `d_out + a_out` and target nodes with
#' weight `log(d_in + a_in)`.
#'
#' @param D Integer. Number of nodes. The two-node seed graph is returned
#'   unchanged when `D <= 2`.
#' @param p Float between 0 and 1. Probability that an iteration appends a
#'   new node rather than linking two existing ones.
#' @param DAG Bool. If TRUE, an edge v -> w between existing nodes is only
#'   added when `v < w`, which keeps the matrix strictly upper triangular.
#'   If FALSE, it is added when `v != w` and the reverse edge is absent.
#' @param a_in Float. Offset added to the in-degree before taking the log.
#' @param a_out Float. Offset added to the out-degree.
#' @return Square matrix `A` where `A[v, w] = 1` encodes the edge v -> w.
scale_free <- function(D, p = 0.5, DAG = FALSE, a_in = 1, a_out = 1) {
  A <- matrix(c(0, 0, 1, 0), nrow = 2)
  while (nrow(A) < D) {
    d_in <- colSums(A)
    d_out <- rowSums(A)
    choice <- stats::runif(1)
    D_curr <- nrow(A)
    if (choice < p) { # Edge from existing to new
      v <- sample(D_curr, 1, prob = d_out + a_out)
      a_next <- rep(0, D_curr)
      a_next[v] <- 1
      # Append the new node as the last column (its incoming edge) and an
      # all-zero last row (no outgoing edges yet).
      A <- rbind(cbind(A, a_next, deparse.level = 0), rep(0, D_curr + 1))
    } else { # Edge between existing
      v <- sample(D_curr, 1, prob = d_out + a_out)
      w <- sample(D_curr, 1, prob = log(d_in + a_in))
      if (DAG) {
        # Only forward edges, so node order remains a topological order.
        if (v < w) {
          A[v, w] <- 1
        }
      } else if (v != w && A[w, v] == 0) {
        # No self loops and no two-node cycles.
        A[v, w] <- 1
      }
    }
  }
  return(A)
}

#' Simulates the 0/1 adjacency matrix of a random directed graph.
#'
#' Every off-diagonal entry is set to 1 independently. For a DAG the entries
#' are drawn with probability `2 * p`, the nodes are randomly permuted and
#' only the strict upper triangle is kept.
#'
#' @param D Integer. Number of nodes.
#' @param p Float. Probability of including each edge.
#' @param DAG Bool. TRUE to return a strictly upper-triangular matrix.
#' @return D x D matrix with zero diagonal.
random <- function(D, p = 3 / D, DAG = FALSE) {
  # (DAG + 1) doubles the draw probability when half of the matrix is about
  # to be discarded by the upper-triangle mask below.
  A <- matrix(as.integer(stats::runif(D * D) < (DAG + 1) * p), nrow = D)
  diag(A) <- 0
  if (DAG) {
    perm <- sample.int(D)
    A <- A[perm, perm, drop = FALSE]
    A <- A * upper.tri(A)
  }
  return(A)
}

#' Simulates random graph, optionally a DAG
#'
#' Edge weights are sampled from a PERT distribution with random sign. Their
#' magnitudes lie between `v / 2` and `2 * v` with mode `v` (0.1 to 0.4 with
#' a mode of 0.2 by default) and are intended to represent
#' "root variance-explained". If the spectral radius of the weighted graph
#' exceeds `max_eig`, all weights are shrunk by a common factor so that it
#' equals `max_eig`.
#'
#' @param D Integer. Number of nodes to simulate.
#' @param graph String. One of 'scalefree' or 'random'. Type of network
#'   to simulate.
#' @param p Float between 0 and 1. Parameter for network generation. For
#'   scalefree, p is the probability of adding a new node at each iteration.
#'   For random, it is the probability of including each edge.
#' @param v Float. Mode of the pert distribution for simulated edge weights.
#' @param DAG Bool. TRUE to ensure the returned graph is a DAG.
#' @param max_eig Float. Largest allowed eigenvalue modulus of the returned
#'   matrix.
#' @return D x D matrix of edge weights with row and column names
#'   "V1", ..., "VD".
#' @export
generate_network <- function(D, graph = 'scalefree', p = 0.4, v = 0.2, DAG = FALSE, max_eig = 0.9){
  if (graph == "scalefree") {
    A <- scale_free(D, p, DAG)
  } else if (graph == "random") {
    A <- random(D, p, DAG)
  } else {
    stop("graph must be one of 'scalefree' or 'random'", call. = FALSE)
  }

  G <- A
  n_edges <- sum(A != 0)
  # Skip sampling when the graph is empty.
  if (n_edges > 0) {
    G[A != 0] <- sample(c(1, -1), n_edges, replace = TRUE) *
      mc2d::rpert(n_edges, min = v / 2, mode = v, max = 2 * v)
  }
  # eigen() orders eigenvalues by modulus only for asymmetric input; take
  # the maximum so a symmetric G is handled as well.
  rg <- max(Mod(eigen(G, only.values = TRUE)$values))
  if (rg > max_eig) {
    G <- G * (max_eig / rg)
  }
  rownames(G) <- paste0("V", seq_len(D))
  colnames(G) <- paste0("V", seq_len(D))
  return(G)
}

#' Turns a fully observed network into a partially observed network.
#' This turns a fully observed network matrix into a partially observed one
#' where TCEs of missing nodes are integrated into pseudo-DCEs for the observed
#' nodes.
#' @param dataset List generated by `generate_dataset`. The elements `G`,
#'   `Y`, `targets` and `eps` are read; `eps` in the result is `NULL` when
#'   the input has no such element.
#' @param p Float 0 to 1, proportion of nodes to hide.
#' @return List with the censored `Y`, `targets`, `G`, `R` and `eps`.
#' @export
censor_dataset <- function(dataset, p) {
  R <- get_tce(get_observed(dataset$G))
  D <- nrow(R)
  keep_cols <- sort(sample.int(D, size = round((1 - p) * D)))
  # Keep control rows plus the rows whose intervention target survives.
  keep_rows <- dataset$targets %in% c("control", paste0("V", keep_cols))
  # drop = FALSE keeps matrices intact when a single node or row is kept.
  R_cens <- R[keep_cols, keep_cols, drop = FALSE]
  # Re-deriving direct effects from the sub-matrix of total effects folds
  # the paths through hidden nodes into the remaining edges.
  G_cens <- get_direct(R_cens)$G
  Y_cens <- dataset$Y[keep_rows, keep_cols, drop = FALSE]
  targets_cens <- dataset$targets[keep_rows]
  eps_cens <- dataset$eps[keep_cols]
  return(list(Y = Y_cens, targets = targets_cens, G = G_cens, R = R_cens,
              eps = eps_cens))
}

#' Simulates control and intervention samples from a weighted graph.
#'
#' Rows are ordered as `N_cont` control samples followed by `N_int` samples
#' for an intervention on each node in turn. Data are generated as
#' `(XB + eps) %*% solve(I - G)` and then centred and scaled by the mean and
#' standard deviation of the control rows. The graph is rescaled to match.
#'
#' @param G D x D matrix of direct effects.
#' @param N_cont Integer. Number of control samples.
#' @param N_int Integer. Number of intervention samples per node.
#' @param int_beta Float, or length D vector. Shift applied to a node in the
#'   samples where it is the intervention target.
#' @param noise String. Noise model, currently only 'gaussian'.
#' @return List with `Y` (N x D data matrix), `targets` (length N vector of
#'   "control" or "V<i>"), `G` and `R` (graph and total effects on the
#'   normalized scale) and `int_beta` (shift divided by the control std dev).
generate_data_inspre <- function(G, N_cont, N_int, int_beta=-2, noise='gaussian'){
  D <- nrow(G)
  int_sizes <- rep(N_int, D)

  # Ncs[d] is the last row before the block of samples targeting node d.
  Ncs <- cumsum(c(N_cont, int_sizes))
  N <- Ncs[length(Ncs)]
  XB <- matrix(0, nrow = N, ncol = D)

  for (d in seq_len(D)) {
    # Guard against a descending index range when a node has no samples.
    if (int_sizes[d] > 0) {
      XB[(Ncs[d] + 1):Ncs[d + 1], d] <- 1
    }
  }
  XB <- t(t(XB) * int_beta)

  # Pick noise variances so that every node has the same value of
  # colSums(G^2) + eps_vars.
  net_vars <- colSums(G^2)
  eps_vars <- max(0.9, max(net_vars)) - net_vars + 0.1
  if (noise == "gaussian") {
    # sd is recycled along the D rows before transposing, so column d of
    # eps has standard deviation sqrt(eps_vars[d]).
    eps <- t(matrix(stats::rnorm(D * N, sd = sqrt(eps_vars)), nrow = D, ncol = N))
  } else {
    stop("noise must be 'gaussian'", call. = FALSE)
  }
  Y <- (XB + eps) %*% solve(diag(D) - G)
  # This mimics perturb-seq normalization
  Y_cont <- Y[seq_len(N_cont), , drop = FALSE]
  mu_cont <- colMeans(Y_cont)
  sd_cont <- apply(Y_cont, 2, stats::sd)
  Y <- t((t(Y) - mu_cont) / sd_cont)

  # Also need to normalize the graph
  R <- get_tce(get_observed(G), normalize = sd_cont)
  G <- get_direct(R)$G
  int_beta <- int_beta / sd_cont

  colnames(Y) <- paste0("V", seq_len(D))
  targets <- c(rep("control", N_cont),
               paste0("V", rep(seq_len(D), times = int_sizes)))
  return(list(Y = Y, targets = targets, G = G, R = R, int_beta = int_beta))
}

#' Simulates an intervention dataset.
#' Simulates a graph-intervention dataset with corresponding graph.
#' @param D Integer. Number of nodes to simulate.
#' @param N_cont Integer. Number of control samples to simulate.
#' @param N_int Integer. Number of intervention samples to simulate per node.
#'   The returned rows are sliced as `1:(N_cont + N_int * D)`.
#' @param int_beta Float. Intervention effect, passed on to the data
#'   generating function.
#' @param graph String. One of 'scalefree' or 'random'. Type of network
#'   to simulate.
#' @param v Float. Mode of the pert distribution for edge weights, passed to
#'   `generate_network`. Also sets the mode and maximum of the confounder
#'   weights.
#' @param p Float. Network generation parameter passed to `generate_network`.
#' @param DAG Bool. TRUE to ensure the returned graph is a DAG.
#' @param C Integer. Number of fully connected confounding nodes to simulate.
#' @param noise String. Noise model to simulate, currently just "gaussian".
#' @param model Data generating model. One of "inspre" or "dotears".
#' @return List with elements `Y`, `targets`, `R`, `G`, `var_all`, `var_obs`,
#'   `var_conf`, `var_eps` and `int_beta`. `Y`, `R` and `G` are restricted to
#'   the first `D` nodes.
#' @export
generate_dataset <- function(D, N_cont, N_int, int_beta=-2,
                             graph = 'scalefree', v = 0.2, p = 0.4,
                             DAG = FALSE, C = floor(0.1*D), noise = 'gaussian',
                             model = 'inspre'){
  G <- generate_network(D, graph, p, v, DAG)
  if(C > 0){
    # Draw a C x D block of confounder weights; the mode depends on the graph
    # type. NOTE(review): the positional rpert arguments are presumably
    # n, min, mode, max -- confirm against mc2d.
    if(graph == "scalefree"){
      new_vars <- matrix(mc2d::rpert(D*C, 0.01, v^(log(D)/log(log(D))), 2*v), nrow=C)
    } else if(graph == "random"){
      new_vars <- matrix(mc2d::rpert(D*C, 0.01, v^(log(D)), 2*v), nrow=C)
    # NOTE(review): the closing braces of the inner if/else-if and of the
    # enclosing if(C > 0) are not visible in this view of the file.

    # Append the confounders as C extra rows and add C all-zero columns, so G
    # becomes (D+C) x (D+C) with no entries in the confounder columns.
    G <- cbind(rbind(G, new_vars), matrix(0, nrow=D+C, ncol=C))
  if(model == 'inspre'){
    data = generate_data_inspre(G, N_cont, N_int, int_beta, noise)
  } else {
  # NOTE(review): the body of this else branch (presumably the "dotears"
  # generator named in the docs) and its closing brace are not visible.

  Y <- data$Y
  G <- data$G
  # Per-node variance of the graph-propagated signal: from all nodes, from
  # the D observed nodes only, and from the C confounders only.
  var_all <- apply(Y %*% G, 2, var)[1:D]
  var_obs <- apply(Y[,1:D] %*% G[1:D, 1:D], 2, var)
  if(C > 0){
    var_conf <- apply(Y[,(D+1):(D+C)] %*% G[(D+1):(D+C), 1:D], 2, var)
  } else {
    var_conf <- rep(0, length=D)
  # NOTE(review): presumably relies on each column of Y having unit
  # variance, so the remainder is the noise share -- confirm.
  var_eps <- 1-var_all

  # Keep only the first D nodes and the first N_cont + N_int*D rows.
  # NOTE(review): data$int_beta is returned without the same 1:D slicing.
  return(list(Y=Y[1:(N_cont+N_int*D), 1:D], targets=data$targets[1:(N_cont+N_int*D)], R=data$R[1:D, 1:D],
              G=G[1:D, 1:D], var_all=var_all, var_obs=var_obs,
              var_conf=var_conf, var_eps=var_eps, int_beta=data$int_beta))

#' Calculates metrics for model evaluation.
#'
#' Only off-diagonal entries are compared. Entries are classified by sign
#' (negative, zero, positive) after thresholding at `eps`.
#'
#' @param X DxD matrix of predicted parameters.
#' @param X_true DxD matrix of true parameters
#' @param eps float. Absolute values below eps will be considered 0.
#' @return List with `precision`, `recall` and `F1` (a predicted entry counts
#'   as correct only if it is non-zero with the true sign), `rmse`, `mae`,
#'   `shd` (number of entries whose sign differs) and `weight_acc` (sign
#'   accuracy with each true sign class weighted by the inverse of its
#'   size). Ratios are `NaN` when their denominator is zero.
#' @export
calc_metrics <- function(X, X_true, eps = 1e-8) {
  D <- ncol(X)
  # !diag(D) is TRUE exactly on the off-diagonal positions.
  off_diag <- !diag(D)
  X <- X[off_diag]
  X_true <- X_true[off_diag]
  X[abs(X) < eps] <- 0
  X_true[abs(X_true) < eps] <- 0

  rmse <- sqrt(mean((X - X_true)^2))
  mae <- mean(abs(X - X_true))

  sign_X <- sign(X)
  sign_Xt <- sign(X_true)
  # TN: both zero. TS: both non-zero with equal sign.
  TN <- sum(!(abs(sign_X) | abs(sign_Xt)))
  TS <- sum(sign_X == sign_Xt) - TN
  # FN: predicted zero, truly non-zero. FS: every other mismatch, i.e.
  # spurious edges and edges with the wrong sign.
  FN <- sum((1 - abs(sign_X)) & abs(sign_Xt))
  FS <- sum(sign_X != sign_Xt) - FN

  shd <- sum(sign_X != sign_Xt)
  N_pos <- sum(X_true > 0)
  N_neg <- sum(X_true < 0)
  N_zero <- sum(X_true == 0)
  weights <- sign_Xt
  weights[sign_Xt > 0] <- 1 / N_pos
  weights[sign_Xt < 0] <- 1 / N_neg
  weights[sign_Xt == 0] <- 1 / N_zero
  weight_acc <- sum((sign_X == sign_Xt) * weights) / sum(weights)
  precision <- TS / (TS + FS)
  recall <- TS / (TS + FN)

  return(list("precision" = precision, "recall" = recall,
              "F1" = 2 * precision * recall / (precision + recall),
              "rmse" = rmse, "mae" = mae, "shd" = shd,
              "weight_acc" = weight_acc))
}

#' Calculates STARS cross-validation.
#'
#' Fits `method` once on the correlation matrix of all rows, then refits it
#' on `cv_folds` random row subsamples and measures, per lambda, how unstable
#' the set of non-zero entries is across subsamples.
#'
#' @param X N x D data matrix.
#' @param method Function that returns a list with two elements. "theta", a
#'   D x D x nlambda matrix and "lambda" a list of lambda values used.
#' @param train_prop Float, probability with which each row of X is included
#'   in a subsample.
#' @param cv_folds Integer, number of random subsamples to draw.
#' @return List with `theta` and `lambda` from the fit on all rows, and
#'   `D_hat`, the mean of `2 * xi * (1 - xi)` per lambda, where `xi` is the
#'   fraction of subsamples in which an entry is non-zero.
stars_cv <- function(X, method, train_prop = 0.8, cv_folds = 10){
  N <- nrow(X)
  D <- ncol(X)
  S_full <- cor_w_se(X)$S_hat
  method_res <- method(S_full)
  theta_hat <- method_res$theta
  lambda <- method_res$lambda
  # Counts, per entry and lambda, the subsamples with a non-zero estimate.
  xi_mat <- array(0, dim = c(D, D, length(lambda)))
  for (fold in seq_len(cv_folds)) {
    train <- stats::runif(N) < train_prop
    S_train <- cor_w_se(X[train, , drop = FALSE])$S_hat
    theta_cv <- method(S_train)$theta
    theta_nz <- abs(theta_cv) > 1e-8
    xi_mat <- xi_mat + theta_nz
  }
  xi_mat <- xi_mat / cv_folds
  xi_mat <- 2 * xi_mat * (1 - xi_mat)
  D_hat <- apply(xi_mat, 3, mean)
  return(list(
    "theta" = theta_hat,
    "lambda" = lambda,
    "D_hat" = D_hat))
}

#' Estimates correlation matrix of X in the presence of missing data.
#' Uses simple approximation sqrt((1-r^2)/(n-2)) which is technically
#' incorrect but works well in practice.
#' @param X NxD matrix with optionally missing entries.
#' @return List with `S_hat` (D x D pairwise-complete correlation matrix),
#'   `SE_S` (D x D matrix of approximate standard errors) and `N` (D x D
#'   matrix with the number of rows in which both columns are observed).
cor_w_se <- function(X) {
  S_hat <- stats::cor(X, use = "pairwise.complete.obs")
  # The cross-product of the 0/1 observation indicators counts, for each
  # pair of columns, the rows where both are non-missing.
  observed <- 1 * !is.na(X)
  N <- crossprod(observed)
  SE_S <- sqrt((1 - S_hat^2) / (N - 2))
  return(list("S_hat" = S_hat, "SE_S" = SE_S, "N" = N))
}

#' Converts simulated dataset to score object for use with GIES function.
#'
#' The target list holds the empty target (control) first, followed by one
#' single-node target per column of `dataset$Y`. Each sample is mapped to
#' index 1 if its target is "control" and to `i + 1` if its target equals
#' the name of column `i`.
#'
#' @param dataset Dataset generated by `generate_dataset`
#' @return A "GaussL0penIntScore" object built from `dataset$Y`.
#' @export
dataset_to_score <- function(dataset){
  targets <- c(list(integer(0)), as.list(seq_len(ncol(dataset$Y))))
  # Position in c("control", V1, ..., VD) is exactly the index into
  # `targets`; unmatched labels map to 0.
  target.index <- match(dataset$targets,
                        c("control", colnames(dataset$Y)),
                        nomatch = 0L)
  return(methods::new("GaussL0penIntScore", data = dataset$Y,
                      targets = targets, target.index = target.index))
}
