Description Usage Arguments Examples
View source: R/SRCL_functions.R
Reproduces the synthetic data from the paper "Synergistic Cause Learning".
Usage: SRCL_0_synthetic_data(n)
Arguments:
n | number of observations for the synthetic data
library(SRCL)
library(graphics)
# Colour palette indexed by sub-group number in the plots further down.
colours <- c("grey", "dodgerblue", "red", "orange", "green")

# Data simulation
set.seed(1234567)
data <- SRCL_0_synthetic_data(100) # use 40 000 to replicate the paper

# Code data monotonistically: every exposure whose linear-model coefficient is
# negative is replaced by its complement (1 - x) and renamed "Not_<name>".
# The model is fitted once and reused for both the printout and the recoding.
fit <- lm(Y ~ ., data)
fit # auto-printed at top level for inspection
recode <- fit$coefficients < 0
# Coefficient 1 is the intercept, so coefficient i is matched to column i of
# `data` (this assumes Y is the first column, as the split below also does).
# which() ignores NA entries, so an exposure whose coefficient could not be
# estimated (NA, e.g. a constant column in a small sample) is left unchanged
# instead of aborting with "missing value where TRUE/FALSE needed".
for (i in setdiff(which(recode), 1L)) {
  colnames(data)[i] <- paste0("Not_", colnames(data)[i])
  data[, i] <- 1 - data[, i]
}
summary(lm(Y ~ ., data))
exposure_data <- data[, -1]
outcome_data <- data[, 1]
# Model fit: initiate a network with one input per exposure column and
# hidden = 5, then train it in three consecutive rounds. Each round continues
# from the model returned by the previous one, using the next learning rate.
model <- SRCL_1_initiate_neural_network(
  inputs = ncol(exposure_data),
  hidden = 5
)
learning_rates <- c(0.001, 0.0001, 0.00001)
for (lr_set in learning_rates) {
  model <- SRCL_2_train_neural_network(
    exposure_data, outcome_data, model,
    lr = lr_set,
    epochs = 1000,
    patience = 100,
    plot_and_evaluation_frequency = 50
  )
}
# Performance: line plot of model$train_performance
par(mar = c(5, 5, 2, 0))
plot(
  model$train_performance,
  type = "l", yaxs = "i",
  xlab = "Epochs", ylab = "Mean squared error",
  main = "Performance"
)

# Model visualisation (no margins, labelled with the exposure names)
par(mar = c(0, 0, 0, 0))
SRCL_3_plot_neural_network(model, names(exposure_data), 5)

# AUC: ROC curve of the predicted risks against the observed outcome
library(pROC)
par(mar = c(5, 5, 2, 0))
pred <- SRCL_4_predict_risks(exposure_data, model)
roc_curve <- roc(outcome_data, pred)
plot(roc_curve, print.auc = TRUE, main = "Accuracy")

# Risk contributions
r_c <- SRCL_5_layerwise_relevance_propagation(exposure_data, model)
# Clustering: Ward clustering of the risk-contribution rows, cut into `groups`
# clusters.
groups <- 3
library(fastcluster)
hc <- hclust(dist(r_c), method = "ward.D") # RAM memory intensive
clus <- cutree(hc, groups)

# Collapse identical (risk contributions, cluster) rows into unique profiles
# with a frequency, then cluster the unique profiles weighted by frequency.
library(plyr)
profile_counts <- count(cbind(r_c, clus))
pfreq <- profile_counts$freq
pclus <- profile_counts$clus
n_count_cols <- ncol(profile_counts)
# Drop the last two columns (cluster label and frequency) before the distances
unique_profiles <- profile_counts[, -c(n_count_cols - 1, n_count_cols)]
p <- hclust(dist(unique_profiles), method = "ward.D", members = pfreq)

# Dendrogram: tip size grows with profile frequency, tip colour marks cluster
par(mfrow = c(1, 1))
par(mar = c(5, 5, 5, 5))
library(ggtree)
library(ggplot2)
ggtree(p, layout = "equal_angle") +
  geom_tippoint(size = sqrt(pfreq) / 2, alpha = .2, color = colours[pclus]) +
  ggtitle("Dendrogram") +
  theme(plot.title = element_text(size = 15, face = "bold"))
# Plot with the prevalence and mean risks: one rectangle per sub-group, whose
# width is the group's prevalence and whose height is its summed mean risk
# contribution, stacked left to right inside the unit square.
par(mar = c(4, 5, 2, 1))
plot(
  0, 0,
  type = "n", asp = 1,
  xlim = c(0, 1), ylim = c(0, 1), xaxs = "i", yaxs = "i",
  axes = FALSE, frame.plot = FALSE,
  xlab = "Prevalence", ylab = "Risk",
  main = "Prevalence and mean risk of sub-groups"
)
tick_marks <- seq(0, 1, .2)
axis(1, tick_marks)
axis(2, tick_marks)
rect(0, 0, 1, 1)

prev0 <- 0 # left edge of the next rectangle (cumulative prevalence so far)
total <- 0 # prevalence-weighted sum of the group risks
for (i in seq_len(groups)) {
  in_group <- clus == i
  prev <- sum(in_group) / length(clus)
  risk <- sum(colMeans(as.matrix(r_c[in_group, ])))
  rect(
    xleft = prev0, ybottom = 0,
    xright = prev + prev0, ytop = risk,
    col = colours[i]
  )
  prev0 <- prev + prev0
  total <- total + risk * prev
}

# Horizontal line at the mean baseline risk (a headless arrow across the plot)
arrows(x0 = 0, x1 = 1, y0 = mean(r_c$Baseline_risk), lty = 1, length = 0)
# The table with risk contributions
st <- 1.5 # base text size for the table

# Formatting helpers: fixed number of decimals, as used in the table labels
fmt_1dp <- function(x) format(round(x, 1), nsmall = 1)
fmt_2dp <- function(x) format(round(x, 2), nsmall = 2)

# d[group, column] = mean of that r_c column within the sub-group
d <- matrix(NA_real_, nrow = groups, ncol = ncol(r_c))
for (g in seq_len(groups)) {
  for (i in seq_len(ncol(r_c))) {
    d[g, i] <- mean(r_c[clus == g, i])
  }
}
rownames(d) <- paste("Group", seq_len(groups))
colnames(d) <- names(r_c)

# Empty canvas: columns sit at x = -1, -2, ...; sub-groups at y = -1, -2, ...
par(mar = c(0, 0, 0, 0))
n_cols <- ncol(d)
label_x <- -n_cols - 5
plot(0, 0, type = "n", xlim = c(label_x, 0), ylim = c(-nrow(d) - 1, 1),
     axes = FALSE)
text((-n_cols):(-1), 0, rev(colnames(d)), srt = 25, cex = st)
text(label_x, 0, "Mean (SD) risk contributions\nby sub-group", pos = 4,
     cex = st)

# Row labels: size, events, prevalence, modelled risk, excess share and the
# observed risk with its prop.test confidence interval. `total` is the
# prevalence-weighted risk sum computed with the prevalence plot.
baseline_risk <- mean(r_c$Baseline_risk)
for (i in seq_len(groups)) {
  in_group <- clus == i
  n_group <- sum(in_group)
  events <- sum(outcome_data[in_group])
  prev <- n_group / length(clus)
  risk <- sum(colMeans(as.matrix(r_c[in_group, ])))
  risk_obs <- mean(outcome_data[in_group])
  excess <- prev * (risk - baseline_risk) / total * 100
  conf_limits <- prop.test(events, n_group)$conf.int * 100
  row_label <- paste0(
    "Sub-group ", i, ": ", "n=", n_group, ", e=", events,
    ",\nPrev=", fmt_1dp(prev * 100),
    "%,\nrisk=", fmt_1dp(risk * 100),
    "%, excess=", fmt_1dp(excess),
    "%,\nObs risk=", fmt_1dp(risk_obs * 100),
    "% (", paste0(fmt_1dp(conf_limits), collapse = "-"), "%)"
  )
  text(label_x, -i, row_label, pos = 4, col = colours[i])
}

# Cells: "mean\n(SD)", with opacity and size scaled by the cell mean relative
# to the largest cell mean.
m <- max(d)
for (g in seq_len(ncol(d))) {
  for (i in seq_len(nrow(d))) {
    cell_sd <- sd(r_c[clus == i, g])
    value <- paste0(fmt_2dp(as.numeric(d[i, g])), "\n(", fmt_2dp(cell_sd), ")")
    text(-g, -i, value,
         col = adjustcolor(colours[i], d[i, g] / m),
         cex = st * d[i, g] / m)
  }
}
|
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.