library(testthat)
library(usethis)
# Load already included functions if relevant
pkgload::load_all(export_all = FALSE)

This vignette tries out the {fusen} style of vignette-driven development on this package. I will use it to work out how to fit the LASSO-based phylogenetic model in {fibre}.

First, let's get some data to work with.

library(ape)
library(readr)
library(dplyr)
library(glmnet)
library(RPANDA)

# Read the tree, the rate-shift results and the trait table from extdata/.
# NOTE(review): the tree read from the nexus file is replaced two lines below
# by bt_res$Phylogeny, so this first read has no effect on later code.
bird_tree <- read.nexus("extdata/HackettStage1_0001_1000_MCCTreeTargetHeights.nex")
bt_res <- read_rds("extdata/rateShiftInformation.RDS")
bird_tree <- bt_res$Phylogeny
bird_dat <- read_csv("extdata/AVONET3_BirdTree.csv")

# Build a `species` key with underscores instead of spaces so it can be
# matched against the tree's tip labels further down.
bird_dat <- bird_dat %>%
  mutate(species = gsub(" ", "_", Species3))

# NOTE(review): fit_ClaDS0 is presumably RPANDA's ClaDS0 sampler; with
# iteration = 500000 this is likely a long-running step — confirm before
# re-running interactively.
spec_rates <- fit_ClaDS0(bird_tree, "extdata/spec_rates.txt", nCPU = 1,
                         iteration = 500000, thin = 2000, update = 1000, adaptation = 5)

Next we extract the root-to-tip matrix using the existing {fibre} functions. We will be trying out the 'second order' style first, which gives a matrix that represents deviations in trait evolution at each node (a random 'trend' model).

library(makemyprior)
library(Matrix)

# bird_tree2 <- bird_tree
# bird_tree2$edge.length <- sqrt(bird_tree2$edge.length)
# rtp <- make_root2tip(bird_tree2, order = "first")

#scaling <- sqrt(typical_variance(as.matrix(rtp %*% t(rtp))))
#rtp <- rtp / scaling

Now to format the data to work with {lars} or {glmnet}

library(glmnet)
library(sAIC)
library(selectiveInference)

# Keep only the species that are present as tips in the tree.
bird_dat <- bird_dat[bird_dat$species %in% bird_tree$tip.label, ]

# Let {fibre} build its model data for a phylogenetic term on `species`.
test <- fibre_data(~ p(species, bird_tree), data = bird_dat)

# Root-to-tip matrix of the first (only) random-effect term.
rtp <- test$re[[1]]$rtp_mat

# Reorder the rows of the root-to-tip matrix to follow the row order of
# bird_dat, so that x lines up with the response vectors built from bird_dat.
x <- rtp[match(bird_dat$species, rownames(rtp)), ]

#' Model-selection criteria along a regularisation path
#'
#' Computes fit statistics for every `lambda` stored in `fit`: residual
#' MSE, residual standard deviation, AIC, AICc, BIC, a BIC variant with an
#' extra `1.001 * lchoose(p, df)` model-space penalty, and a rate estimate
#' `sigma^2 / (lambda * n)`.
#'
#' @param fit A fitted path object exposing `lambda`, `df`, `nobs`, `nulldev`
#'   and `dev.ratio`, and for which `coef()` returns either a coefficient
#'   matrix (rows: intercept + predictors, columns: lambdas) or a list of such
#'   matrices (one per response), as `glmnet` fits do.
#' @param x Predictor matrix the model was fitted to (without intercept column).
#' @param y Response vector, or response matrix for multi-response fits.
#' @param alpha Elastic-net mixing parameter used for the fit. Only
#'   `alpha == 0` (ridge) triggers the hat-matrix based degrees of freedom and
#'   leave-one-out computation.
#' @param penalty.factor Optional per-predictor penalty factors; only used to
#'   scale the ridge penalty when `alpha == 0`.
#'
#' @return A list with elements `bic`, `aic`, `aicc`, `bic_l`, `sigma`, `df`,
#'   `mse`, `rate_est` and `loocv`, each with one entry per lambda. `loocv`
#'   is all zeros unless `alpha == 0`.
#' @noRd
crits <- function(fit, x, y, alpha = 1, penalty.factor = NULL){

  lambda <- fit$lambda
  nlambda <- length(lambda)
  n <- fit$nobs
  p <- ncol(x)

  # Fitted values for every lambda. Multi-response fits return a list of
  # coefficient matrices; their fitted values are stacked by rows and the
  # response matrix is flattened column-wise so the two line up.
  betas <- coef(fit)
  design <- cbind(1, x)
  if (is.list(betas)) {
    yhat <- do.call(rbind, lapply(betas, function(b) design %*% b))
    y <- as.vector(y)
  } else {
    yhat <- design %*% betas
  }

  df <- fit$df
  loocv <- numeric(nlambda)

  if (!is.null(penalty.factor)) {
    scal <- penalty.factor * (n / sum(penalty.factor)) * p
  } else {
    scal <- 1
  }

  if (alpha == 0) {
    # Ridge: replace the path's df by the trace of the hat matrix and compute
    # the sum of squared leave-one-out residuals,
    # y' (I - H) diag(I - H)^-2 (I - H) y.
    pen <- diag(ncol(x))
    xtx <- t(x) %*% x
    ymat <- matrix(y, ncol = 1)
    eye_n <- diag(n)

    for (i in seq_len(nlambda)) {
      hatm <- x %*% solve(xtx + pen * lambda[i] * scal * n) %*% t(x)
      df[i] <- sum(diag(hatm))

      onemhat <- eye_n - hatm
      loocv[i] <- as.numeric(
        t(ymat) %*% onemhat %*% (eye_n * (diag(onemhat)^-2)) %*% onemhat %*% ymat
      )
    }
  }

  resid_mat <- (y - yhat)
  mse <- colMeans(resid_mat^2)
  sse <- colSums(resid_mat^2)

  sigma <- sqrt(sse / (n - df))

  # Deviance explained relative to the null model (nulldev - dev), used in
  # place of a log-likelihood term in the criteria below.
  tLL <- fit$nulldev - fit$nulldev * (1 - fit$dev.ratio)
  k <- df
  aicc <- -tLL + 2 * k + 2 * k * (k + 1) / (n - k - 1)
  aic <- -tLL + 2 * k
  bic <- log(n) * k - tLL

  # BIC with an additional penalty on the number of size-k subsets of p.
  bic_l <- bic + 1.001 * lchoose(p, k)

  rate_est <- (sigma^2) / (lambda * n)

  list(bic = bic, aic = aic, aicc = aicc, bic_l = bic_l, sigma = sigma, df = df,
       mse = mse, rate_est = rate_est, loocv = loocv)
}

vars <- c("Beak.Length_Culmen", "Beak.Length_Nares", "Beak.Width", "Beak.Depth", "Tarsus.Length", "Wing.Length", "Kipps.Distance", "Secondary1", "Hand-Wing.Index", "Tail.Length", "Mass", "Habitat", "Habitat.Density", "Migration", "Trophic.Level", "Trophic.Niche", "Primary.Lifestyle")

library(GGally)
ggpairs(bird_dat[ , vars])

cont_vars <- c("Beak.Length_Culmen", "Beak.Length_Nares", "Beak.Width", "Beak.Depth", "Tarsus.Length", "Wing.Length", "Kipps.Distance", "Secondary1", "Hand-Wing.Index", "Tail.Length", "Mass")

y_log <- log(bird_dat$Mass + 1) 
y <- bird_dat$Mass 

# Node numbers of every node except Ntip + 1 (the root in ape's numbering,
# which has no incoming edge): internal nodes first, then the tips.
# NOTE(review): presumably this is the column order of the root-to-tip
# matrix — confirm against make_root2tip().
edges_nums <- c((Ntip(bird_tree) + 2):(Ntip(bird_tree) + Nnode(bird_tree)), 1:Ntip(bird_tree))
# Row of bird_tree$edge whose child node (column 2) is each of those nodes.
edge_ord <- match(edges_nums, bird_tree$edge[ , 2])
# Length of the edge leading to each node, in the same order as edges_nums.
lens <- bird_tree$edge.length[edge_ord]

system.time({
  mass_log <- glmnet(x, y_log, standardize = FALSE, nlambda = 1000, lambda.min.ratio = 0.00001 / nrow(x),
                  penalty.factor = sqrt(lens))
  aics <- crits(mass_log, x, y_log)
})

best_mod <- which.min(aics$bic_l)
best_lam <- mass_log$lambda[best_mod]
aics$rate_est[best_mod] * 2

rtp_all <- make_root2tip(bird_tree, "both")
preds <- predict(mass_log, s = best_lam, newx = rtp_all)
tip_preds <- predict(mass_log, s = best_lam, newx = x)

# Per-edge coefficients at the selected lambda, with the intercept dropped.
rates <- coef(mass_log, s = best_lam)[-1]
nonzero <- which(rates != 0)
# Map the selected coefficients back to row indices of bird_tree$edge.
# (An earlier assignment that stored the edge matrix rows under the same name
# was overwritten immediately and has been removed.)
nonzero_edge <- edge_ord[nonzero]
# Parent node (column 1 of the edge matrix) of each selected edge.
nonzero_node <- bird_tree$edge[nonzero_edge, 1]

# log(|rate| + 1), used below as the symbol size for the edge labels.
nonzero_rates <- log(abs(rates[nonzero]) + 1)

plot_traits(bird_tree, preds, direction = "upwards", ftype = "off")
edgelabels(edge = nonzero_edge, pch = 21, col = "black", bg = "white",
           cex = nonzero_rates)

plot(aics$bic_l ~ log(aics$rate_est * 2), type = "l")

aic_df <- tibble(BIC = aics$bic_l, `Global Rate Parameter` = aics$rate_est * 2) %>%
  mutate(good = ifelse(BIC - min(BIC) <= 2, "deltaBIC <= 2", "deltaBIC > 2"))
ggplot(aic_df, aes(`Global Rate Parameter`, BIC)) +
  geom_path(size = 1.5) +
  geom_point(data = aic_df %>% filter(good == "deltaBIC <= 2"),
             colour = "red") +
  scale_x_log10() +
  theme_minimal() +
  theme(axis.title = element_text(size = 32),
        axis.text = element_text(size = 22))

plot_traits(bird_tree, preds, type = "fan")
plot(y_log ~ tip_preds)
abline(0, 1)

# Observed and fitted body mass at the tips, on the back-transformed scale and
# on the log scale used for fitting. predict() returns a one-column matrix, so
# it is flattened before being stored in the tibble. The log-scale columns are
# required by the log-log scatter and hexbin plots below, which previously
# referred to columns that did not exist in pred_df.
# NOTE(review): dividing by 1000 presumably converts grams to kilograms — confirm.
pred_df <- tibble(`Body Mass (observed)` = (exp(y_log) - 1) / 1000,
                  `Body Mass (predicted)` = (exp(as.vector(tip_preds)) - 1) / 1000,
                  `log Body Mass (observed)` = y_log,
                  `log Body Mass (predicted)` = as.vector(tip_preds))

# Log-scale scatter with the 1:1 line.
ggplot(pred_df, aes(`log Body Mass (predicted)`, `log Body Mass (observed)`)) +
  geom_point() +
  geom_abline(slope = 1, size = 1.5, colour = "grey30") +
  theme_minimal()

# Back-transformed values shown on log-spaced axes.
ggplot(pred_df, aes(`Body Mass (predicted)`, `Body Mass (observed)`)) +
  geom_point(alpha = 0.4) +
  geom_abline(slope = 1, size = 1.5, colour = "grey30") +
  theme_minimal() +
  coord_equal() +
  scale_x_continuous(trans = "log", breaks = c(0.01, 0.1, 1.0, 7), limits = c(NA, 12)) +
  scale_y_continuous(trans = "log", breaks = c(0.01, 0.1, 1.0, 7), limits = c(NA, 12)) +
  theme(axis.title = element_text(size = 32),
        axis.text = element_text(size = 22))

# Same log-scale comparison as a hexbin density.
ggplot(pred_df, aes(`log Body Mass (predicted)`, `log Body Mass (observed)`)) +
  geom_hex() +
  geom_abline(slope = 1, size = 1.5, colour = "grey30") +
  theme_minimal()



# Plot trait values on a tree as a continuous colour map.
#
# phy:        tree passed to ape::Ntip() and phytools::contMap().
# trait_vals: values whose first Ntip(phy) entries are used as tip values
#             (named by the corresponding rownames of trait_vals) and whose
#             remaining entries are supplied as user-defined ancestral states.
# ...:        forwarded to phytools::contMap().
plot_traits <- function(phy, trait_vals, ...) {
  n_tips <- ape::Ntip(phy)
  tip_idx <- seq_len(n_tips)

  tip_vals <- trait_vals[tip_idx]
  names(tip_vals) <- rownames(trait_vals)[tip_idx]

  # Everything after the tips is treated as internal-node states.
  node_idx <- setdiff(seq_len(length(trait_vals)), tip_idx)

  phytools::contMap(phy, tip_vals,
                    anc.states = trait_vals[node_idx],
                    method = "user", ...)
}


# Cross-validated lasso path on the untransformed response.
system.time({
  test3 <- cv.glmnet(x, y, standardize = FALSE, nlambda = 1000, lambda.min.ratio = 0.0000001, nfolds = 25)
})

# Select lambda by AICc on the full-data path stored inside the cv object.
# The criteria are recomputed for this fit so that their index refers to this
# fit's own lambda sequence (previously an undefined `test2` was indexed with
# criteria computed for a different model).
test2 <- test3$glmnet.fit
aics_test2 <- crits(test2, x, y)
best_lambda <- test2$lambda[which.min(aics_test2$aicc)]
coefs <- coef(test2, s = best_lambda)
plot(coefs)
# NOTE(review): this lists only strictly positive coefficients — confirm
# whether all non-zero coefficients were intended.
rownames(coefs)[which(coefs > 0)]


beak_traits <- bird_dat[ , c("Beak.Length_Culmen", "Beak.Length_Nares", "Beak.Width", "Beak.Depth")]
beak_traits <- scale(log(beak_traits + 1))
ggpairs(as.data.frame(beak_traits)) +
  theme_minimal()

system.time({
  beak_log <- glmnet(x, beak_traits, standardize = FALSE, nlambda = 1000, lambda.min.ratio = 0.00001 / nrow(x),
                  penalty.factor = sqrt(lens),
                  family = "mgaussian")
  aics <- crits(beak_log, x, beak_traits)
})

best_mod <- which.min(aics$bic_l)
best_lam <- beak_log$lambda[best_mod]
aics$rate_est[best_mod] * 2

rtp_all <- make_root2tip(bird_tree, "both")
preds <- predict(beak_log, s = best_lam, newx = rtp_all)
tip_preds <- predict(beak_log, s = best_lam, newx = x)

rates <- lapply(coef(beak_log, s = best_lam), function(x) x[-1])
rates <- do.call(cbind, rates)
nonzero <- which(rates[ , 1] != 0)

rates_st <- rates[nonzero, ]
rates_st <- rates_st / sqrt(rowSums(rates_st^2))

ggpairs(as.data.frame(rates_st),
        upper = list(continuous = "density", combo = "box_no_facet", discrete = "count", na = "na"),) +
  theme_minimal()

traj_size <- sqrt(rowSums(rates[nonzero, ]^2))

library(movMF)
test_MF <- lapply(seq_len(15), function(x) movMF(rates_st, x, nruns = 10))
plot(sapply(test_MF, BIC), type = "l")
best_MF <- test_MF[[which.min(sapply(test_MF, BIC))]]
preds_MF_clust <- predict(best_MF)
preds_MF_memb <- predict(best_MF, type = "memberships")

preds_MF_prob <- dmovMF(rates_st, best_MF$theta, best_MF$alpha)

commonness <- preds_MF_prob
hist(commonness)
quantile(commonness, c(0.05, 0.5, 0.95))

least <- which(commonness <= quantile(commonness, 0.15))
rates[nonzero, ][least, ]

ggpairs(as.data.frame(rates_st[least, ]))

plot(bird_tree, show.tip.label = FALSE)

#nonzero_edge <- bird_tree$edge[edge_ord, ][nonzero, ][least, ]
least_nonzero_edge <- edge_ord[nonzero][least]
least_nonzero_node <- bird_tree$edge[least_nonzero_edge, 2]
least_nonzero_node2 <- bird_tree$edge[least_nonzero_edge, 1]

nodelabels(node = least_nonzero_node, pch = 21, col = "black", bg = "red", cex = 2)
nodelabels(node = least_nonzero_node2, pch = 21, col = "black", bg = "blue", cex = 2)
#edgelabels(edge = least_nonzero_edge, pch = 21, col = "black", bg = "red")


least_nonzero_edge2 <- edge_ord[nonzero]
least_nonzero_node2 <- bird_tree$edge[least_nonzero_edge2, 2]
plot(bird_tree, show.tip.label = FALSE)
nodelabels(node = least_nonzero_node2, pch = 21, col = "black", bg = "red", cex = 1.25)


closest_dists <- find_nearest_node_dist(bird_tree, least_nonzero_node2,
                                        bt_res$SpeciationShiftLocations,
                                        "both")

plot(abs(closest_dists$steps) ~ log(preds_MF_prob))
plot(abs(closest_dists$length) ~ log(preds_MF_prob))

plot(closest_dists$length ~ I(-log(preds_MF_prob)))

summary(lm(abs(closest_dists$length) ~ log(preds_MF_prob)))

summary(lm(log(abs(closest_dists$length) + 1) ~ log(preds_MF_prob)))
plot(lm(sqrt(abs(closest_dists$length)) ~ log(preds_MF_prob)))

summary(lm(abs(closest_dists$length) ~ I(log(preds_MF_prob)^2)))
summary(lm(abs(closest_dists$steps) ~ log(preds_MF_prob)))

nulls <- replicate(1000, lm(sqrt_abs_dist ~ log_dens,
                            data = data.frame(sqrt_abs_dist = sample(sqrt(abs(closest_dists$length))),
                                              log_dens = log(preds_MF_prob))), 
                   simplify = FALSE)

xs <- range(log(preds_MF_prob))
xs <- seq(xs[1], xs[2], length.out = 100)

null_preds <- sapply(nulls, function(x) predict(x, newdata = data.frame(log_dens = xs)))

ribbon <- apply(null_preds, 1, function(x) quantile(x, c(0.025, 0.975)))
ribbon <- data.frame(neg_log_dens = -xs, lower = ribbon[1, ]^2, upper = ribbon[2, ]^2)

ribbon_68 <- apply(null_preds, 1, function(x) quantile(x, c(0.16, 0.84)))
ribbon_68 <- data.frame(neg_log_dens = -xs, lower = ribbon_68[1, ]^2, upper = ribbon_68[2, ]^2)

plot_dat <- data.frame(abs_dist = abs(closest_dists$length),
                       neg_log_dens = -log(preds_MF_prob),
                       traj_size = traj_size)

mod <- lm(sqrt_abs_dist ~ log_dens,
          data = data.frame(sqrt_abs_dist = sqrt(abs(closest_dists$length)),
                            log_dens = log(preds_MF_prob)))
fitted <- predict(mod, newdata = data.frame(log_dens = xs))
fitted_df <- data.frame(neg_log_dens = -xs, abs_dist = fitted^2)

ggplot(plot_dat, aes(neg_log_dens, abs_dist)) +
  geom_ribbon(aes(x = neg_log_dens,
                  ymin = lower,
                  ymax = upper), 
              data = ribbon,
              inherit.aes = FALSE,
              fill = "grey80") +
  geom_ribbon(aes(x = neg_log_dens,
                  ymin = lower,
                  ymax = upper), 
              data = ribbon_68,
              inherit.aes = FALSE,
              fill = "grey60") +
  #geom_smooth(method = "lm", se = FALSE, colour = "grey20") +
  geom_path(data = fitted_df, colour = "grey20", size = 1.3) +
  geom_point(size = 2) +
  #geom_point(aes(size = traj_size)) +
  #scale_size_continuous(trans = "log1p") +
  ylab("Time to Closest Speciation Shift (millions of years)") +
  xlab("Trajectory Novelty (-log(density))") +
  #scale_y_sqrt() +
  theme_minimal()


# `nulls` is a list of lm fits on permuted responses (replicate() was called
# with simplify = FALSE), so it cannot be indexed like a matrix. Extract the
# slope (second coefficient) of each null fit and summarise its distribution.
null_slopes <- vapply(nulls, function(m) unname(coef(m)[2]), numeric(1))
quantile(null_slopes, c(0.025, 0.975))

library(rgl)
vectors <- best_MF$theta
beak_WD <- apply(vectors[ , 3:4], 1, mean)
vectors <- cbind(vectors[ , 1:2], Beak.WD = beak_WD)

# Take the first three trait dimensions of each fitted mixture component and
# rescale every row to unit length so it can be drawn as a direction.
# drop = FALSE keeps a matrix even if the selected mixture has one component.
vectors <- best_MF$theta[ , 1:3, drop = FALSE]
norms <- sqrt(rowSums(vectors^2))
vectors <- vectors / norms

rgl::close3d()
rgl::open3d()
# One arrow per component; the number of components is chosen by BIC above,
# so it must not be hard-coded.
for (i in seq_len(nrow(vectors))) {
  rgl::arrow3d(c(0, 0, 0), vectors[i, ])
}

# Project the standardised per-edge rate vectors onto the same three
# dimensions and renormalise them to unit length.
rates_st2 <- rates_st[ , 1:3]
rates_st2 <- rates_st2 / sqrt(rowSums(rates_st2^2))

rgl::points3d(rates_st2)
rgl::spheres3d(0, 0, 0, front = "lines", back = "lines", color = "grey")
for (i in seq_len(nrow(vectors))) {
  rgl::arrow3d(c(0, 0, 0), vectors[i, ])
}
axes3d()

library(Rvcg)
ell <- shape::getellipse(1, 0.75)
cyl <- rgl::cylinder3d(matrix(c(0,0,0,1,0,0), ncol = 3), section = ell)
cone <- Rvcg::vcgCone(1, 0, 2)
cone <- scale3d(cone, 0.75, 1, 1)
cone <- translate3d(cone, 0, 2, 0)

shade3d(cyl, color = "yellow")
shade3d(cone, color = "yellow")

pred_mat <- predict(beak_log, s = best_lam, newx = x)[ , , 1]
plot(beak_traits[ , 1] ~ pred_mat[ , 1])
abline(0, 1)
plot(beak_traits[ , 2] ~ pred_mat[ , 2])
abline(0, 1)
plot(beak_traits[ , 3] ~ pred_mat[ , 3])
abline(0, 1)
plot(beak_traits[ , 4] ~ pred_mat[ , 4])
abline(0, 1)




# Back-transform observed and predicted beak traits with exp(.) - 1.
# NOTE(review): beak_traits was built as scale(log(beak_traits + 1)), so
# exp(.) - 1 inverts the log(+1) step but not the centring/scaling applied by
# scale() — confirm whether un-scaling first was intended.
beak_traits_bt <- exp(beak_traits) - 1
pred_mat_bt <- exp(pred_mat) - 1
# Observed vs. predicted for the first beak trait on this scale.
plot(beak_traits_bt[ , 1] ~ pred_mat_bt[ , 1])

system.time({
  beak_bt <- glmnet(x, beak_traits_bt, standardize = FALSE, nlambda = 1000, lambda.min.ratio = 0.00001 / nrow(x),
                    penalty.factor = sqrt(lens),
                    family = "mgaussian")
  aics_bt <- crits(beak_bt, x, beak_traits_bt)
})

best_mod_bt <- which.min(aics_bt$bic_l)
best_lam_bt <- beak_bt$lambda[best_mod_bt]
aics_bt$rate_est[best_mod_bt] * 2

#rtp_all <- make_root2tip(bird_tree, "both")
preds_bt <- predict(beak_bt, s = best_lam_bt, newx = rtp_all)
tip_preds <- predict(beak_bt, s = best_lam_bt, newx = x)

rates_bt <- lapply(coef(beak_bt, s = best_lam_bt), function(x) x[-1])
rates_bt <- do.call(cbind, rates_bt)
nonzero_bt <- which(rates_bt[ , 1] != 0)

rates_st_bt <- rates_bt[nonzero_bt, ]
rates_st_bt <- rates_st_bt / sqrt(rowSums(rates_st_bt^2))

ggpairs(as.data.frame(rates_st_bt),
        upper = list(continuous = "density", combo = "box_no_facet", discrete = "count", na = "na"),) +
  theme_minimal()

traj_size_bt <- sqrt(rowSums(rates_bt[nonzero_bt, ]^2))

test_MF_bt <- lapply(seq_len(18), function(x) movMF(rates_st_bt, x))
plot(sapply(test_MF_bt, BIC), type = "l")
best_MF_bt <- test_MF_bt[[which.min(sapply(test_MF_bt, BIC))]]
preds_MF_clust_bt <- predict(best_MF_bt)
preds_MF_memb_bt <- predict(best_MF_bt, type = "memberships")

preds_MF_prob_bt <- dmovMF(rates_st_bt, best_MF_bt$theta, best_MF_bt$alpha)

commonness_bt <- preds_MF_prob_bt
hist(commonness_bt)
quantile(commonness_bt, c(0.05, 0.5, 0.95))

least_bt <- which(commonness_bt <= quantile(commonness_bt, 0.10))
rates_bt[nonzero_bt, ][least_bt, ]

ggpairs(as.data.frame(rates_st_bt[least_bt, ]))

plot(bird_tree, show.tip.label = FALSE)

#nonzero_edge <- bird_tree$edge[edge_ord, ][nonzero, ][least, ]
least_nonzero_edge_bt <- edge_ord[nonzero_bt][least_bt]
least_nonzero_node_bt <- bird_tree$edge[least_nonzero_edge_bt, 2]
least_nonzero_node2_bt <- bird_tree$edge[least_nonzero_edge_bt, 1]

nodelabels(node = least_nonzero_node_bt, pch = 21, col = "black", bg = "red", cex = 2)
nodelabels(node = least_nonzero_node2_bt, pch = 21, col = "black", bg = "blue", cex = 2)
#edgelabels(edge = least_nonzero_edge, pch = 21, col = "black", bg = "red")

least_nonzero_edge2_bt <- edge_ord[nonzero_bt]
least_nonzero_node2_bt <- bird_tree$edge[least_nonzero_edge2_bt, 2]
plot(bird_tree, show.tip.label = FALSE)
nodelabels(node = least_nonzero_node2_bt, pch = 21, col = "black", bg = "red", cex = 1.25)


closest_dists_bt <- find_nearest_node_dist(bird_tree, least_nonzero_node2_bt,
                                           bt_res$SpeciationShiftLocations,
                                           "both")

plot(abs(closest_dists_bt$steps) ~ log(preds_MF_prob_bt))
plot(abs(closest_dists_bt$length) ~ log(preds_MF_prob_bt))

summary(lm(sqrt(abs(closest_dists_bt$length)) ~ log(preds_MF_prob_bt)))
summary(lm(sqrt(abs(closest_dists_bt$steps)) ~ log(preds_MF_prob_bt)))
plot(lm(sqrt(abs(closest_dists_bt$steps)) ~ log(preds_MF_prob_bt)))

plot(lm(abs(closest_dists$length) ~ log(preds_MF_prob)))

nulls_l_bt <- replicate(1000, lm(sqrt_abs_dist ~ log_dens,
                            data = data.frame(sqrt_abs_dist = sample(sqrt(abs(closest_dists_bt$length))),
                                              log_dens = log(preds_MF_prob_bt))), 
                   simplify = FALSE)

nulls_s_bt <- replicate(1000, lm(sqrt_abs_dist ~ log_dens,
                            data = data.frame(sqrt_abs_dist = sample(sqrt(abs(closest_dists_bt$steps))),
                                              log_dens = log(preds_MF_prob_bt))), 
                   simplify = FALSE)


xs_bt <- range(log(preds_MF_prob_bt))
xs_bt <- seq(xs_bt[1], xs_bt[2], length.out = 100)

null_preds_l_bt <- sapply(nulls_l_bt, function(x) predict(x, newdata = data.frame(log_dens = xs_bt)))
null_preds_s_bt <- sapply(nulls_s_bt, function(x) predict(x, newdata = data.frame(log_dens = xs_bt)))

ribbon_bt <- apply(null_preds_s_bt, 1, function(x) quantile(x, c(0.025, 0.975)))
ribbon_bt <- data.frame(neg_log_dens = -xs_bt, lower = ribbon_bt[1, ]^2, upper = ribbon_bt[2, ]^2)

ribbon_68_bt <- apply(null_preds_s_bt, 1, function(x) quantile(x, c(0.16, 0.84)))
ribbon_68_bt <- data.frame(neg_log_dens = -xs_bt, lower = ribbon_68_bt[1, ]^2, upper = ribbon_68_bt[2, ]^2)

plot_dat_bt <- data.frame(abs_dist = abs(closest_dists_bt$steps),
                       neg_log_dens = -log(preds_MF_prob_bt),
                       traj_size = traj_size_bt)

mod_bt <- lm(sqrt_abs_dist ~ log_dens,
          data = data.frame(sqrt_abs_dist = sqrt(abs(closest_dists_bt$steps)),
                            log_dens = log(preds_MF_prob_bt)))
fitted_bt <- predict(mod_bt, newdata = data.frame(log_dens = xs_bt))
fitted_df_bt <- data.frame(neg_log_dens = -xs_bt, abs_dist = fitted_bt^2)

ggplot(plot_dat_bt, aes(neg_log_dens, abs_dist)) +
  geom_ribbon(aes(x = neg_log_dens,
                  ymin = lower,
                  ymax = upper), 
              data = ribbon_bt,
              inherit.aes = FALSE,
              fill = "grey80") +
  geom_ribbon(aes(x = neg_log_dens,
                  ymin = lower,
                  ymax = upper), 
              data = ribbon_68_bt,
              inherit.aes = FALSE,
              fill = "grey60") +
  #geom_smooth(method = "lm", se = FALSE, colour = "grey20") +
  geom_path(data = fitted_df_bt, colour = "grey20", size = 1.3) +
  geom_point(size = 2, alpha = 0.5) +
  #geom_point(aes(size = traj_size)) +
  #scale_size_continuous(trans = "log1p") +
  ylab("Nodes between Closest Speciation Shift") +
  xlab("Trajectory Novelty (-log(density))") +
  #scale_y_sqrt() +
  theme_minimal()


pred_mat_bt <- predict(beak_bt, s = best_lam_bt, newx = x)[ , , 1]
plot(log(beak_traits_bt[ , 1] + 1) ~ log(pred_mat_bt[ , 1] + 1))
abline(0, 1)
plot(beak_traits_bt[ , 2] ~ pred_mat_bt[ , 2])
abline(0, 1)
plot(beak_traits_bt[ , 3] ~ pred_mat_bt[ , 3])
abline(0, 1)
plot(beak_traits_bt[ , 4] ~ pred_mat_bt[ , 4])
abline(0, 1)

res_bt <- pred_mat_bt - beak_traits_bt
hist(as.vector(res_bt), breaks = 100)
plot(log(as.vector(pred_mat_bt) + 1), as.vector(res_bt))

Here I am going to develop some functions to simulate from various phylogenetic branch regression models based on ridge, student-t, lasso, and horseshoe regression.

My function

#' Simulate from a phylogenetic branch regression
#'
#' Draws one normally distributed effect per edge of `phy`, with a per-edge
#' scale drawn from the distribution named by `rate_dist`, divides each effect
#' by the square root of the edge length, and multiplies the root-to-tip
#' matrix by the resulting vector to obtain simulated trait values.
#'
#' @param phy A tree with edge lengths (e.g. an `ape` `phylo` object).
#' @param rate_dist Distribution of the per-edge scale: `"gaussian"` (constant
#'   scale of 1), `"student-t"` (`1 / rgamma(n, 1, 1)`), `"laplacian"`
#'   (`rexp(n, 2)`) or `"horseshoe"` (`abs(rcauchy(n))^2`).
#' @param order `"first"` or `"second"`; passed to `make_root2tip()` when
#'   `rtp` is not supplied.
#' @param global_rate Positive multiplier applied to every per-edge scale.
#' @param rtp Optional precomputed root-to-tip matrix with one column per edge
#'   of `phy`. When `NULL` it is built with
#'   `make_root2tip(phy, "both", order = order)`.
#'
#' @return A one-column matrix of simulated trait values, `rtp %*% beta`, with
#'   one row per row of `rtp`.
#' @export
#'
#' @examples
#' \dontrun{
#' tree <- ape::rcoal(20)
#' sim <- fibre_sim(tree, rate_dist = "laplacian", order = "first")
#' }
fibre_sim <- function(phy, rate_dist = c("gaussian", "student-t", "laplacian", "horseshoe"), 
                     order = c("first", "second"), global_rate = 1, rtp = NULL) {

  rate_dist <- match.arg(rate_dist)
  order <- match.arg(order)

  # Edge lengths are needed to rescale the effects below; fail early with a
  # clear message instead of a non-conformable matrix error later on.
  if (is.null(phy$edge.length)) {
    stop("`phy` must have edge lengths.", call. = FALSE)
  }

  n <- ape::Nedge(phy)

  if (is.null(rtp)) {
    rtp <- make_root2tip(phy, "both", order = order)
  }

  if (ncol(rtp) != n) {
    stop("`rtp` must have one column per edge of `phy` (", n, "), not ",
         ncol(rtp), ".", call. = FALSE)
  }

  # Per-edge scale of the random effects.
  # NOTE(review): rnorm() below takes a standard deviation, while some of
  # these draws look like variance-scale quantities (e.g. the squared
  # half-Cauchy for the horseshoe) — confirm whether sqrt() is intended.
  lambda <- switch(rate_dist,
                   gaussian = 1,
                   `student-t` = 1 / rgamma(n, 1, 1),
                   laplacian = rexp(n, 2),
                   horseshoe = abs(rcauchy(n))^2)

  beta <- rnorm(n, 0, lambda * global_rate)
  # NOTE(review): this assumes the columns of `rtp` are in the same order as
  # phy$edge.length — confirm against make_root2tip().
  beta <- beta / sqrt(phy$edge.length)

  trait_vals <- rtp %*% matrix(beta, ncol = 1)

  trait_vals

}

# Plot simulated trait values on a tree as a continuous colour map.
#
# Values are first compressed with a signed square-root transform,
# sign(v) * sqrt(|v|), before plotting.
#
# phy:        tree passed to ape::Ntip() and phytools::contMap().
# trait_vals: values whose first Ntip(phy) entries are used as tip values
#             (named by the corresponding rownames of trait_vals) and whose
#             remaining entries are supplied as user-defined ancestral states.
# ...:        forwarded to phytools::contMap().
plot_traits <- function(phy, trait_vals, ...) {
  trait_vals <- sign(trait_vals) * sqrt(abs(trait_vals))

  n_tips <- ape::Ntip(phy)
  tip_idx <- seq_len(n_tips)

  tip_vals <- trait_vals[tip_idx]
  names(tip_vals) <- rownames(trait_vals)[tip_idx]

  # Everything after the tips is treated as internal-node states.
  node_idx <- setdiff(seq_len(length(trait_vals)), tip_idx)

  phytools::contMap(phy, tip_vals,
                    anc.states = trait_vals[node_idx],
                    method = "user", ...)
}

# Fixed seed so the simulated tree and the trait simulations are reproducible.
set.seed(3021122)

# phytools is never attached in this file (it is only used via `phytools::`),
# so the tree simulator has to be namespace-qualified as well.
tree <- phytools::pbtree(n = 100, scale = 5)
plot(tree)

dev.off()
old <- par(mfrow = c(2, 2))

gauss <- fibre_sim(tree, rate_dist = "gaussian", order = "first")
plot_traits(tree, gauss, fsize = c(0.001, 1), mar = c(3, 2, 2, 1))
title("Gaussian (ridge) e.g. 'Brownian motion'")

student <- fibre_sim(tree, rate_dist = "student-t", order = "first")
plot_traits(tree, student, fsize = c(0.001, 1), mar = c(3, 2, 2, 1))
title("Student-t")

laplace <- fibre_sim(tree, rate_dist = "laplacian", order = "first")
plot_traits(tree, laplace, fsize = c(0.001, 1), mar = c(3, 2, 2, 1))
title("Laplacian (lasso)")

hs <- fibre_sim(tree, rate_dist = "horseshoe", order = "first")
plot_traits(tree, hs, fsize = c(0.001, 1), mar = c(3, 2, 2, 1))
title("Horseshoe")

par(old)

dev.off()
old <- par(mfrow = c(2, 2))

gauss <- fibre_sim(tree, rate_dist = "gaussian", order = "second")
plot_traits(tree, gauss, fsize = c(0.001, 1), mar = c(3, 2, 2, 1))
title("Gaussian (ridge)")

student <- fibre_sim(tree, rate_dist = "student-t", order = "second")
plot_traits(tree, student, fsize = c(0.001, 1), mar = c(3, 2, 2, 1))
title("Student-t")

laplace <- fibre_sim(tree, rate_dist = "laplacian", order = "second")
plot_traits(tree, laplace, fsize = c(0.001, 1), mar = c(3, 2, 2, 1))
title("Laplacian (lasso)")

hs <- fibre_sim(tree, rate_dist = "horseshoe", order = "second")
plot_traits(tree, hs, fsize = c(0.001, 1), mar = c(3, 2, 2, 1))
title("Horseshoe")

par(old)
skeleton()
test_that("fibre_sim works", {

})
# Run but keep eval=FALSE to avoid infinite loop
# Execute in the console directly
fusen::inflate(flat_file = "dev/flat_skeleton.Rmd", vignette_name = "Go further")


rdinnager/fibre documentation built on Dec. 14, 2024, 10:33 a.m.