SRCL_0_synthetic_data: SRCL synthetic data

Description Usage Arguments Examples

View source: R/SRCL_functions.R

Description

Reproduces the synthetic data used in the paper "Synergistic Cause Learning".

Usage

SRCL_0_synthetic_data(n)

Arguments

n

Number of observations to simulate in the synthetic data set.

Examples

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
	library(SRCL)
	library(graphics)
	# Colour palette; colours[k] is used to draw sub-group k further down.
	colours <- c("grey", "dodgerblue", "red", "orange", "green")

	# Data simulation ----
	# A fixed seed makes the simulated data set reproducible.
	set.seed(1234567)
	n_obs <- 100 # use 40 000 to replicate the paper
	data <- SRCL_0_synthetic_data(n_obs)

	# Code data monotonically ----
	# Every exposure whose linear-model coefficient for Y is negative is
	# flipped (x -> 1 - x) and renamed "Not_<name>". In a linear model this
	# negates that coefficient, so afterwards every estimable exposure has a
	# non-negative association with the outcome.
	fit <- lm(Y ~ ., data)
	fit
	# Coefficient 1 is the intercept, so with Y in column 1 (as assumed by
	# outcome_data below) coefficient k belongs to column k. which() ignores
	# NA coefficients (e.g. aliased exposures); a plain if() on NA would stop
	# with an error.
	recode <- which(fit$coefficients[-1] < 0) + 1
	for (j in recode) {
	  colnames(data)[j] <- paste0("Not_", colnames(data)[j])
	  data[, j] <- 1 - data[, j]
	}
	summary(lm(Y ~ ., data))
	exposure_data <- data[, -1]
	outcome_data <- data[, 1]

	# Model fit ----
	# Initialise a network with one input per exposure and 5 hidden units,
	# then train it in three successive rounds with learning rates that
	# shrink tenfold each time; each round continues from the previous model.
	model <- SRCL_1_initiate_neural_network(
	  inputs = ncol(exposure_data),
	  hidden = 5
	)
	learning_rates <- c(0.001, 0.0001, 0.00001)
	for (rate in learning_rates) {
	  model <- SRCL_2_train_neural_network(
	    exposure_data, outcome_data, model,
	    lr = rate,
	    epochs = 1000,
	    patience = 100,
	    plot_and_evaluation_frequency = 50
	  )
	}

	# Performance ----
	# Training curve: mean squared error against epoch.
	par(mar = c(5, 5, 2, 0))
	plot(model$train_performance,
	     type = "l", yaxs = "i",
	     xlab = "Epochs", ylab = "Mean squared error",
	     main = "Performance")

	# Model visualisation ----
	par(mar = rep(0, 4))
	SRCL_3_plot_neural_network(model, names(exposure_data), 5)

	# AUC ----
	# ROC curve of the predicted risks against the observed outcome, with the
	# AUC printed on the plot.
	library(pROC)
	par(mar = c(5, 5, 2, 0))
	pred <- SRCL_4_predict_risks(exposure_data, model)
	roc_curve <- roc(outcome_data, pred)
	plot(roc_curve, print.auc = TRUE, main = "Accuracy")

	# Risk contributions ----
	# Per-observation risk contributions from layer-wise relevance propagation
	# (includes a Baseline_risk column, used further down).
	r_c <- SRCL_5_layerwise_relevance_propagation(exposure_data, model)

	# Clustering ----
	# Ward clustering of the risk-contribution profiles into `groups` sub-groups.
	# Attaching fastcluster makes the hclust() calls below use its drop-in
	# replacement for stats::hclust().
	groups <- 3
	library(fastcluster)
	full_tree <- hclust(dist(r_c), method = "ward.D") # RAM memory intensive
	clus <- cutree(full_tree, groups)

	# Collapse identical (profile, cluster) rows: plyr's count() returns the
	# unique rows plus a trailing `freq` column, so the dendrogram is built over
	# distinct profiles, each weighted by how many observations share it.
	library(plyr)
	profile_counts <- count(cbind(r_c, clus))
	profile_freq <- profile_counts$freq
	profile_clus <- profile_counts$clus
	# Keep only the contribution columns (drop the trailing clus and freq).
	profiles <- profile_counts[, seq_len(ncol(profile_counts) - 2)]
	profile_tree <- hclust(dist(profiles), method = "ward.D",
	                       members = profile_freq)

	par(mfrow = c(1, 1), mar = c(5, 5, 5, 5))
	library(ggtree)
	library(ggplot2)
	ggtree(profile_tree, layout = "equal_angle") +
	  geom_tippoint(size = sqrt(profile_freq) / 2, alpha = .2,
	                color = colours[profile_clus]) +
	  ggtitle("Dendrogram") +
	  theme(plot.title = element_text(size = 15, face = "bold"))

	# Plot with the prevalence and mean risks ----
	# Bars on the unit square: sub-group k gets a bar whose width is its
	# prevalence (share of observations) and whose height is the sum of its
	# mean risk contributions; bars are laid out left to right. A horizontal
	# line marks the mean baseline risk. `total` (sum over sub-groups of
	# risk * prevalence) is also used by the risk-contribution table.
	par(mar = c(4, 5, 2, 1))
	plot(0, 0, type = "n",
	     xlim = c(0, 1), ylim = c(0, 1), asp = 1,
	     xaxs = "i", yaxs = "i", axes = FALSE, frame.plot = FALSE,
	     xlab = "Prevalence", ylab = "Risk",
	     main = "Prevalence and mean risk of sub-groups")
	tick_marks <- seq(0, 1, .2)
	axis(1, tick_marks)
	axis(2, tick_marks)
	rect(0, 0, 1, 1)
	x_left <- 0
	total <- 0
	for (k in seq_len(groups)) {
	  prev <- sum(clus == k) / length(clus)
	  risk <- sum(colMeans(as.matrix(r_c[clus == k, ])))
	  x_right <- x_left + prev
	  rect(xleft = x_left, ybottom = 0, xright = x_right, ytop = risk,
	       col = colours[k])
	  x_left <- x_right
	  total <- total + risk * prev
	}
	arrows(x0 = 0, x1 = 1, y0 = mean(r_c$Baseline_risk), lty = 1, length = 0)

	# The table with risk contributions ----
	# Draws a text table with one row per sub-group.
	# Left: a summary per sub-group -- size (n), events (e), prevalence,
	# risk (sum of mean contributions), excess = prev * (risk - mean baseline
	# risk) / total, and observed risk with a prop.test() confidence interval.
	# Right: mean (SD) of each column of r_c; text opacity and size scale
	# with the cell mean relative to the largest cell mean.
	# Uses `total` as accumulated by the prevalence/mean-risk plot.
	st <- 1.5
	# Percentage with one decimal, as printed in every sub-group summary.
	pct <- function(x) format(round(x * 100, 1), nsmall = 1)
	# d[g, k]: mean of contribution column k within sub-group g.
	d <- do.call(rbind, lapply(seq_len(groups), function(g) {
	  vapply(r_c[clus == g, , drop = FALSE], mean, numeric(1))
	}))
	rownames(d) <- paste("Group", 1:groups)
	colnames(d) <- names(r_c)
	baseline <- mean(r_c$Baseline_risk)
	par(mar = c(0, 0, 0, 0))
	plot(0, 0, type = "n", xlim = c(-ncol(d) - 5, 0), ylim = c(-nrow(d) - 1, 1),
	     axes = FALSE)
	text(c(-ncol(d)):c(-1), 0, rev(colnames(d)), srt = 25, cex = st)
	text(-ncol(d) - 5, 0, "Mean (SD) risk contributions\nby sub-group",
	     pos = 4, cex = st)
	for (i in seq_len(groups)) {
	  in_group <- clus == i
	  n_i <- sum(in_group)
	  events <- sum(outcome_data[in_group])
	  prev <- n_i / length(clus)
	  risk <- sum(colMeans(as.matrix(r_c[in_group, ])))
	  risk_obs <- mean(outcome_data[in_group])
	  ci <- prop.test(events, n_i)$conf.int
	  # Keep every string literal on one source line: a literal broken across
	  # lines embeds the newline and indentation into the printed label.
	  label <- paste0(
	    "Sub-group ", i, ": ", "n=", n_i, ", e=", events,
	    ",\nPrev=", pct(prev), "%, risk=", pct(risk),
	    "%, excess=", pct(prev * (risk - baseline) / total),
	    "%,\nObs risk=", pct(risk_obs), "% (",
	    paste0(pct(ci), collapse = "-"), "%)"
	  )
	  text(-ncol(d) - 5, -i, label, pos = 4, col = colours[i])
	}
	m <- max(d)
	for (g in seq_len(ncol(d))) {
	  for (i in seq_len(nrow(d))) {
	    value <- paste0(format(round(d[i, g], 2), nsmall = 2), "\n(",
	                    format(round(sd(r_c[clus == i, g]), 2), nsmall = 2), ")")
	    text(-g, -i, value, col = adjustcolor(colours[i], d[i, g] / m),
	         cex = st * d[i, g] / m)
	  }
	}

ekstroem/SRCL documentation built on Sept. 5, 2020, 8:59 p.m.