knitr::opts_chunk$set(echo = TRUE)
library(HealthyOralMicrobiome)
library(caret)
library(doParallel)
library(PRROC)
library(ggplot2)
library(cowplot)
# Load the sample metadata and the ASV count table that ship with the
# HealthyOralMicrobiome package. The printed dimensions and column totals are
# sanity checks only; nothing below depends on them.
Metadata <- HealthyOralMicrobiome::Healthy_Metadata
dim(Metadata)

ASV_tab <- HealthyOralMicrobiome::Healthy_ASV_table
dim(ASV_tab)

colSums(ASV_tab)

# Build the "complete" and "validation" sets with identical filter settings.
# Element 1 of each result is used below as the metadata and element 2 as the
# taxa table.
# NOTE(review): presumably `depth` is a minimum read depth per sample and
# `cutoff_prop` a prevalence cutoff for taxa -- confirm against Filt_samples().
Comp_ASV_data <- Filt_samples(
  depth = 5000,
  setType = "complete",
  Taxa_table = ASV_tab,
  Metadata = Metadata,
  parallel = 20,
  cutoff_prop = 0.1
)
dim(Comp_ASV_data[[2]])

Valid_ASV_data <- Filt_samples(
  depth = 5000,
  setType = "validation",
  Taxa_table = ASV_tab,
  Metadata = Metadata,
  parallel = 20,
  cutoff_prop = 0.1
)
dim(Valid_ASV_data[[2]])

# Convert counts to relative abundance (RA): transpose each taxa table, then
# divide every row by its own row total.
Comp_ASV_RA <- data.frame(t(Comp_ASV_data[[2]]))
Comp_ASV_RA <- sweep(
  Comp_ASV_RA,
  MARGIN = 1,
  STATS = rowSums(Comp_ASV_RA),
  FUN = "/"
)

Valid_ASV_RA <- data.frame(t(Valid_ASV_data[[2]]))
Valid_ASV_RA <- sweep(
  Valid_ASV_RA,
  MARGIN = 1,
  STATS = rowSums(Valid_ASV_RA),
  FUN = "/"
)
## Set up machine learning parameters

## Seven candidate mtry values, evenly spaced from 1 to 306 / 3 and rounded.
## NOTE(review): 306 looks like the number of ASV features in Comp_ASV_RA --
## confirm, or derive it from ncol(Comp_ASV_RA).
mtry <- round(seq(1, 306 / 3, length.out = 7))
mtry
grid <- expand.grid(mtry = mtry)

## 3-fold cross-validation scored with caret::twoClassSummary, with class
## probabilities enabled. No `repeats` value is given, so caret's default
## number of repeats applies.
cv <- trainControl(
  method = "repeatedcv",
  number = 3,
  returnResamp = "final",
  summaryFunction = caret::twoClassSummary,
  classProbs = TRUE,
  savePredictions = TRUE
)

## Recode sex as "M" where A_SDC_GENDER is 1 and "F" otherwise.
Comp_sex_classes <- ifelse(Comp_ASV_data[[1]]$A_SDC_GENDER == 1, "M", "F")

cl <- makeCluster(20)
registerDoParallel(cl)

set.seed(1995)
## `finally` shuts the worker processes down even when train() fails, so an
## error does not leave 20 orphaned R sessions behind.
Sex_Train <- tryCatch(
  train(
    Comp_ASV_RA, Comp_sex_classes,
    method = "rf",
    trControl = cv,
    metric = "ROC",
    tuneGrid = grid,
    ntree = 1501,
    importance = TRUE
  ),
  finally = stopCluster(cl)
)

## Predict class probabilities for the validation samples with the trained
## sex classifier.
Sex_train_predict <- predict.train(Sex_Train, newdata = Valid_ASV_RA, type = "prob")
Sex_train_predict

## Code the true validation labels as 1 where A_SDC_GENDER is 1 (the class
## labelled "M" during training) and 0 otherwise.
Valid_Sex_classes <- ifelse(as.numeric(as.character(Valid_ASV_data[[1]]$A_SDC_GENDER)) == 1, 1, 0)

## ROC curve for the predicted probability of "M" against the true labels.
## `curve = TRUE` keeps the curve points for plotting and `rand.compute = TRUE`
## also computes the random-classifier baseline.
AUC_SEX <- PRROC::roc.curve(
  scores.class0 = Sex_train_predict$M,
  weights.class0 = Valid_Sex_classes,
  curve = TRUE,
  rand.compute = TRUE
)
AUC_SEX

## Draw the sex-classifier ROC curve with base graphics. Wrapped in a function
## so the drawing can be handed to ggdraw() below.
p1 <- function() {
  plot(AUC_SEX, color = FALSE, xlab = "False Positive Rate", rand.plot = TRUE)
}

## Turn the base-graphics ROC plot into an object that plot_grid() can place.
AUROC_SEX <- ggdraw(p1)

## Models for regression

## Train a random forest on one numeric outcome with caret.
##
## Tunes mtry over `tune_grid` using 3-fold cross-validation, running the
## resamples on a temporary parallel cluster.
##
## Arguments:
##   Genus_data  predictor table, passed to train() as `x`
##   feature     outcome vector, passed to train() as `y`
##   seed        RNG seed set just before training; NULL leaves the RNG as is
##   n_cores     number of worker processes to start
##   tune_grid   data frame of mtry values to try; defaults to the
##               script-level `grid`
##
## Returns the fitted object produced by train().
Train_numeric_RF <- function(Genus_data, feature, seed = 1995, n_cores = 20,
                             tune_grid = grid) {

  cv_num <- trainControl(
    method = "repeatedcv",
    number = 3,
    returnResamp = "final",
    savePredictions = TRUE
  )

  cl <- makeCluster(n_cores)
  # Stop the workers however the function exits, including when train() fails.
  on.exit(stopCluster(cl), add = TRUE)
  registerDoParallel(cl)

  # Earlier code passed `seed = 1995` as an extra train() argument, which is
  # only forwarded to the model-fitting function and does not seed R's RNG.
  # Seed the RNG directly, the same way the sex classifier is seeded.
  if (!is.null(seed)) {
    set.seed(seed)
  }

  train(
    Genus_data, feature,
    method = "rf",
    trControl = cv_num,
    tuneGrid = tune_grid,
    ntree = 1501,
    importance = TRUE
  )
}

## Rebuild the mtry tuning grid used by Train_numeric_RF(): seven values evenly
## spaced from 1 to 306 / 3, rounded.
## NOTE(review): 306 looks like the number of ASV features -- confirm.
mtry <- round(seq(1, 306 / 3, length.out = 7))
mtry
grid <- expand.grid(mtry = mtry)

## Metadata columns to model as numeric outcomes.
features_to_test <- c("A_SDC_AGE_CALC",
                      "PM_STANDING_HEIGHT_AVG",
                      "PM_WAIST_AVG",
                      "PM_BIOIMPED_WEIGHT",
                      "PM_WAIST_HIP_RATIO",
                      "PM_BIOIMPED_FFM",
                      "SALT_SEASONING",
                      "NUTS_SEEDS_SERVINGS_PER_DAY",
                      "NUT_VEG_DAY_QTY",
                      "REFINED_GRAIN_SERVINGS_DAY_QTY",
                      "A_SLE_LIGHT_EXP",
                      "PM_BIOIMPED_BMI",
                      "NUT_JUICE_DAY_QTY",
                      "A_HS_DENTAL_VISIT_LAST")

## Train one model per feature, stored under the feature's column name. A plain
## loop is kept so that models already trained stay in `Models` if a later fit
## fails.
Models <- list()
for (feature_name in features_to_test) {
  Models[[feature_name]] <- Train_numeric_RF(
    Comp_ASV_RA,
    Comp_ASV_data[[1]][[feature_name]]
  )
}

## Cache the fitted models so the plotting section can reload them without
## retraining.
saveRDS(Models, file="~/Private/Sequences/Redo_Combined_Data/deblur/regression_Models_Complete.rds")

## Model performance plot

## Reload the cached regression models (a list keyed by metadata column name).
Models <- readRDS("~/Private/Sequences/Redo_Combined_Data/deblur/regression_Models_Complete.rds")

## Predict the outcome for `new_data` with a fitted caret model.
##
## Arguments:
##   Model     fitted object returned by train()
##   new_data  predictor table to predict on
##
## Returns whatever caret::predict.train() returns for the model.
Get_Valid_Results <- function(Model, new_data) {
  caret::predict.train(Model, new_data)
}

## Validation-set predictions for every feature, keyed by feature name.
Valid_Pred <- lapply(
  Models[features_to_test],
  Get_Valid_Results,
  new_data = Valid_ASV_RA
)
## Score predictions against observed values.
##
## Arguments:
##   Pred  predicted values
##   Real  observed values
##
## Returns the performance metrics computed by caret::postResample().
Get_R2_results <- function(Pred, Real) {
  caret::postResample(Pred, Real)
}

## Validation metrics for every feature, keyed by feature name.
R2_values <- lapply(
  setNames(features_to_test, features_to_test),
  function(feature_name) {
    Get_R2_results(
      Valid_Pred[[feature_name]],
      Valid_ASV_data[[1]][[feature_name]]
    )
  }
)

## Keep only R-squared from each metric set, as a numeric vector named by
## feature. Selecting by name (previously position 2) fails loudly if
## postResample() ever returns a different set of metrics.
R2_values_vector <- vapply(
  R2_values,
  function(metrics) metrics[["Rsquared"]],
  numeric(1)
)

## One row per modelled feature; row names are the metadata column names
## carried over from R2_values_vector.
RF_data_df <- data.frame("R2" = R2_values_vector)
rownames(RF_data_df)

## Display labels and categories, listed in the same order as features_to_test.
RF_data_df$Feature <- c("Age", "Height", "Waist Size", "Weight", "Waist Hip Ratio",
                        "Fat Free Mass", "Salt Usage", "Nut/Seed Servings", "Vegetable Servings", "Refined Grains Servings",
                        "Sleeping Light Exposure", "Body Mass Index", "Juice Servings", "Last Dental Visit")
RF_data_df$Type <- c("Age", "Anthropometric", "Anthropometric", "Anthropometric", "Anthropometric",
                     "Anthropometric", "Diet", "Diet", "Diet", "Diet", "Lifestyle", "Anthropometric", "Diet", "Lifestyle")

## Fill colour per category. The order here is the order of the legend keys.
type_colors <- c(
  "Age" = "#696969",
  "Diet" = "#dda0dd",
  "Anthropometric" = "#fa8072",
  "Lifestyle" = "#ff1493"
)

## Vectorised lookup; any Type without an entry above falls back to black.
RF_data_df$color <- unname(type_colors[RF_data_df$Type])
RF_data_df$color[is.na(RF_data_df$color)] <- "black"

theme_set(theme_classic())

## Order the bars by R-squared (factor levels from highest to lowest).
RF_data_df$Feature <- factor(
  RF_data_df$Feature,
  levels = RF_data_df$Feature[order(RF_data_df$R2, decreasing = TRUE)]
)

## Horizontal bar chart of validation R-squared, one facet row per category
## (strip labels hidden). The legend breaks are given explicitly so each colour
## is paired with its own category label instead of relying on the sort order
## of the hex strings.
RF_regression_plot <- ggplot(RF_data_df, aes(x = Feature, y = R2, fill = color)) +
  geom_bar(stat = "identity") +
  coord_flip() +
  facet_grid(rows = vars(Type), scales = "free", space = "free") +
  ggtitle("Random Forest Performance") +
  theme(plot.title = element_text(hjust = 0.5)) +
  theme(strip.background = element_blank(), strip.text.y = element_blank()) +
  guides(fill = guide_legend(title = "Type")) +
  scale_fill_identity(
    guide = "legend",
    breaks = unname(type_colors),
    labels = names(type_colors)
  ) +
  ylab(expression(paste(R^2)))

RF_regression_plot

## Final plot

## Combine the sex-classifier ROC panel (A) and the regression R-squared
## panel (B) into one figure, then display it.
Final_fig <- cowplot::plot_grid(
  AUROC_SEX,
  RF_regression_plot,
  labels = c("A", "B")
)
Final_fig


## Source: nearinj/Nearing_et_al_2020_Oral_Microbiome documentation, built on Sept. 6, 2020, 8:25 p.m.