# compare
# compare class definition and show method
setClass( "compareIC" , slots=c( dSE="matrix" ) , contains="data.frame" )
#compare.show <- function( object ) {
# r <- format_show( object@output , #digits=c('default__'=1,'weight'=2,'SE'=2,'dSE'=2) )
# print( r )
#}
setMethod( "show" , "compareIC" , function(object) {
r <- format_show( object ,
digits=c('default__'=1,'weight'=2,'SE'=2,'dSE'=2) )
print( r )
} )
# new compare function, defaulting to WAIC
compare <- function( ... , n=1e3 , sort="WAIC" , func=WAIC , WAIC=TRUE , refresh=0 , warn=TRUE , result_order=c(1,5,3,6,2,4) , log_lik="log_lik" ) {
# retrieve list of models
L <- list(...)
if ( is.list(L[[1]]) && length(L)==1 )
L <- L[[1]]
if ( length(L)==1 ) stop( "Need more than one model to compare." )
# retrieve model names from function call
mnames <- match.call()
mnames <- as.character(mnames)[2:(length(L)+1)]
# use substitute to deparse the func argument
the_func <- func
if ( class(the_func) != "character" )
the_func <- deparse(substitute(func))
# check class of fit models and warn when more than one class represented
classes <- as.character(sapply( L , class ))
if ( any(classes!=classes[1]) & warn==TRUE ) {
warning("Not all model fits of same class.\nThis is usually a bad idea, because it implies they were fit by different algorithms.\nCheck yourself, before you wreck yourself.")
}
# check nobs for all models
# if different, warn
nobs_list <- try( sapply( L , nobs ) )
if ( any(nobs_list != nobs_list[1]) & warn==TRUE ) {
nobs_out <- paste( mnames , nobs_list , "\n" )
nobs_out <- concat(nobs_out)
warning(concat(
"Different numbers of observations found for at least two models.\nModel comparison is valid only for models fit to exactly the same observations.\nNumber of observations for each model:\n",nobs_out))
}
dSE.matrix <- matrix( NA , nrow=length(L) , ncol=length(L) )
# deprecate WAIC==TRUE/FALSE flag
# catch it and convert to func intent
if ( WAIC==FALSE ) func <- DIC # assume old code that wants DIC
if ( the_func=="DIC" ) {
IC.list <- lapply( L , function(z) DIC( z , n=n ) )
p.list <- sapply( IC.list , function(x) attr(x,"pD") )
}
if ( the_func %in% c("WAIC","PSIS","LOO") ) {
IC.list.pw <- lapply( L , function(z) do.call( the_func , list( z , n=n , refresh=refresh , pointwise=TRUE , warn=warn , log_lik=log_lik ) ) )
p.list <- sapply( IC.list.pw , function(x) sum( x$penalty ) )
se.list <- sapply( IC.list.pw , function(x) x$std_err[1] )
IC.list <- sapply( IC.list.pw , function(x) sum( x[[1]] ) )
# compute SE of differences between adjacent models from top to bottom in ranking
colnames(dSE.matrix) <- mnames
rownames(dSE.matrix) <- mnames
for ( i in 1:(length(L)-1) ) {
for ( j in (i+1):length(L) ) {
ic_ptw1 <- IC.list.pw[[i]][[1]]
ic_ptw2 <- IC.list.pw[[j]][[1]]
dSE.matrix[i,j] <- as.numeric( sqrt( length(ic_ptw1)*var( ic_ptw1 - ic_ptw2 ) ) )
dSE.matrix[j,i] <- dSE.matrix[i,j]
}#j
}#i
}
if ( !(the_func %in% c("DIC","WAIC","LOO","PSIS")) ) {
# unrecognized IC function; just wing it
IC.list <- lapply( L , function(z) func( z ) )
}
IC.list <- unlist(IC.list)
dIC <- IC.list - min( IC.list )
w.IC <- ICweights( IC.list )
if ( the_func=="DIC" )
result <- data.frame( DIC=IC.list , pD=p.list , dDIC=dIC , weight=w.IC )
if ( the_func=="WAIC" ) {
# find out which model has dWAIC==0
topm <- which( dIC==0 )
dSEcol <- dSE.matrix[,topm]
result <- data.frame( WAIC=IC.list , pWAIC=p.list , dWAIC=dIC ,
weight=w.IC , SE=se.list , dSE=dSEcol )
}
if ( the_func=="LOO" | the_func=="PSIS" ) {
topm <- which( dIC==0 )
dSEcol <- dSE.matrix[,topm]
result <- data.frame( PSIS=IC.list , pPSIS=p.list , dPSIS=dIC ,
weight=w.IC , SE=se.list , dSE=dSEcol )
}
if ( !(the_func %in% c("DIC","WAIC","LOO","PSIS")) ) {
result <- data.frame( IC=IC.list , dIC=dIC , weight=w.IC )
}
rownames(result) <- mnames
if ( !is.null(sort) ) {
if ( sort!=FALSE ) {
if ( the_func=="LOO" ) the_func <- "PSIS" # must match header
if ( sort=="WAIC" ) sort <- the_func
result <- result[ order( result[[sort]] ) , ]
}
}
result <- result[ , result_order ]
new( "compareIC" , result , dSE=dSE.matrix )
}
# plot method for compareIC results shows deviance in and expected deviance out of sample, for each model, ordered top-to-bottom by rank
setMethod("plot" , "compareIC" , function(x,y,xlim,SE=TRUE,dSE=TRUE,weights=FALSE,...) {
dev_in <- x[[1]] - x[[5]]*2 # criterion - penalty*2
dev_out <- x[[1]]
if ( !is.null(x[['SE']]) ) devSE <- x[['SE']]
dev_out_lower <- dev_out - devSE
dev_out_upper <- dev_out + devSE
if ( weights==TRUE ) {
dev_in <- ICweights(dev_in)
dev_out <- ICweights(dev_out)
dev_out_lower <- ICweights(dev_out_lower)
dev_out_upper <- ICweights(dev_out_upper)
}
n <- length(dev_in)
if ( missing(xlim) ) {
xlim <- c(min(dev_in),max(dev_out))
if ( SE==TRUE & !is.null(x[['SE']]) ) {
xlim[1] <- min(dev_in,dev_out_lower)
xlim[2] <- max(dev_out_upper)
}
}
main <- colnames(x)[1]
set_nice_margins()
dotchart( dev_in[n:1] , labels=rownames(x)[n:1] , xlab="deviance" , pch=16 , xlim=xlim , ... )
points( dev_out[n:1] , 1:n )
mtext(main)
# standard errors
if ( !is.null(x[['SE']]) & SE==TRUE ) {
for ( i in 1:n ) {
lines( c(dev_out_lower[i],dev_out_upper[i]) , rep(n+1-i,2) , lwd=0.75 )
}
}
if ( !all(is.na(x@dSE)) & dSE==TRUE ) {
# plot differences and stderr of differences
dcol <- col.alpha("black",0.5)
abline( v=dev_out[1] , lwd=0.5 , col=dcol )
diff_dev_lower <- dev_out - x$dSE
diff_dev_upper <- dev_out + x$dSE
if ( weights==TRUE ) {
diff_dev_lower <- ICweights(diff_dev_lower)
diff_dev_upper <- ICweights(diff_dev_upper)
}
for ( i in 2:n ) {
points( dev_out[i] , n+2-i-0.5 , cex=0.5 , pch=2 , col=dcol )
lines( c(diff_dev_lower[i],diff_dev_upper[i]) , rep(n+2-i-0.5,2) , lwd=0.5 , col=dcol )
}
}
})
if ( FALSE ) {
# AICc/BIC model comparison table
compare_old <- function( ... , nobs=NULL , sort="AICc" , BIC=FALSE , DIC=FALSE , delta=TRUE , DICsamples=1e4 ) {
require(bbmle)
if ( is.null(nobs) ) {
stop( "Must specify number of observations (nobs)." )
}
getdf <- function(x) {
if (!is.null(df <- attr(x, "df")))
return(df)
else if (!is.null(df <- attr(logLik(x), "df")))
return(df)
}
# need own BIC, as one in stats doesn't allow nobs
myBIC <- function(x,nobs) {
k <- getdf(x)
as.numeric( -2*logLik(x) + log(nobs)*k )
}
# retrieve list of models
L <- list(...)
if ( is.list(L[[1]]) && length(L)==1 )
L <- L[[1]]
# retrieve model names from function call
mnames <- match.call()
mnames <- as.character(mnames)[2:(length(L)+1)]
AICc.list <- sapply( L , function(z) AICc( z , nobs=nobs ) )
dAICc <- AICc.list - min( AICc.list )
post.AICc <- exp( -0.5*dAICc ) / sum( exp(-0.5*dAICc) )
if ( BIC==TRUE ) {
BIC.list <- sapply( L , function(z) myBIC( z , nobs=nobs ) )
dBIC <- BIC.list - min( BIC.list )
post.BIC <- exp( -0.5*dBIC ) / sum( exp(-0.5*dBIC) )
}
k <- sapply( L , getdf )
result <- data.frame( k=k , AICc=AICc.list , w.AICc=post.AICc )
if ( BIC==TRUE )
result <- data.frame( k=k , AICc=AICc.list , BIC=BIC.list , w.AICc=post.AICc , w.BIC=post.BIC )
if ( delta==TRUE ) {
r2 <- data.frame( dAICc=dAICc )
if ( BIC==TRUE ) r2 <- data.frame( dAICc=dAICc , dBIC=dBIC )
result <- cbind( result , r2 )
}
# DIC from quadratic approx posterior defined by vcov and coef
if ( DIC==TRUE ) {
DIC.list <- rep( NA , length(L) )
pD.list <- rep( NA , length(L) )
for ( i in 1:length(L) ) {
m <- L[[i]]
if ( class(m)=="map" ) {
post <- sample.qa.posterior( m , n=DICsamples )
message( paste("Computing DIC for model",mnames[i]) )
dev <- sapply( 1:nrow(post) ,
function(i) {
p <- post[i,]
names(p) <- names(post)
2*m@fminuslogl( p )
}
)
dev.hat <- deviance(m)
DIC.list[i] <- dev.hat + 2*( mean(dev) - dev.hat )
pD.list[i] <- ( DIC.list[i] - dev.hat )/2
}
}
ddic <- DIC.list - min(DIC.list)
wdic <- exp( -0.5*ddic ) / sum( exp(-0.5*ddic) )
rdic <- data.frame( DIC=as.numeric(DIC.list) , pD=pD.list , wDIC=wdic , dDIC=ddic )
result <- cbind( result , rdic )
}
# add model names to rows
rownames( result ) <- mnames
if ( !is.null(sort) ) {
result <- result[ order( result[[sort]] ) , ]
}
new( "compareIC" , output=result )
}
}#FALSE
# convert estimated D_test values to weights
ICweights <- function( dev ) {
d <- dev - min(dev)
f <- exp(-0.5*d)
w <- f/sum(f)
return(w)
}
# tests
if (FALSE) {
library(rethinking)
data(chimpanzees)
d <- list(
pulled_left = chimpanzees$pulled_left ,
prosoc_left = chimpanzees$prosoc_left ,
condition = chimpanzees$condition ,
actor = as.integer( chimpanzees$actor )
)
m0 <- map(
alist(
pulled_left ~ dbinom(1,theta),
logit(theta) <- a,
a ~ dnorm(0,1)
) ,
data=d,
start=list(a=0)
)
m1 <- map(
alist(
pulled_left ~ dbinom(1,theta),
logit(theta) <- a + bp*prosoc_left,
a ~ dnorm(0,1),
bp ~ dnorm(0,1)
) ,
data=d,
start=list(a=0,bp=0)
)
m2 <- map(
alist(
pulled_left ~ dbinom(1,theta),
logit(theta) <- a + bp*prosoc_left + bpc*condition*prosoc_left,
a ~ dnorm(0,1),
bp ~ dnorm(0,1),
bpc ~ dnorm(0,1)
) ,
data=d,
start=list(a=0,bp=0,bpc=0)
)
m3 <- map(
alist(
pulled_left ~ dbinom(1,theta),
logit(theta) <- a + bp*prosoc_left + bc*condition + bpc*condition*prosoc_left,
a ~ dnorm(0,1),
bp ~ dnorm(0,1),
bc ~ dnorm(0,1),
bpc ~ dnorm(0,1)
) ,
data=d,
start=list(a=0,bp=0,bc=0,bpc=0)
)
( x <- compare(m0,m1,m2,m3) )
plot(x)
# now ulam
m0 <- ulam(
alist(
pulled_left ~ dbinom(1,theta),
logit(theta) <- a,
a ~ dnorm(0,1)
) ,
data=d , log_lik=TRUE )
m1 <- ulam(
alist(
pulled_left ~ dbinom(1,theta),
logit(theta) <- a + bp*prosoc_left,
a ~ dnorm(0,1),
bp ~ dnorm(0,1)
) ,
data=d , log_lik=TRUE )
m2 <- ulam(
alist(
pulled_left ~ dbinom(1,theta),
logit(theta) <- a + bp*prosoc_left + bpc*condition*prosoc_left,
a ~ dnorm(0,1),
bp ~ dnorm(0,1),
bpc ~ dnorm(0,1)
) ,
data=d , log_lik=TRUE )
m3 <- ulam(
alist(
pulled_left ~ dbinom(1,theta),
logit(theta) <- a + bp*prosoc_left + bc*condition + bpc*condition*prosoc_left,
a ~ dnorm(0,1),
bp ~ dnorm(0,1),
bc ~ dnorm(0,1),
bpc ~ dnorm(0,1)
) ,
data=d , log_lik=TRUE )
m4 <- ulam(
alist(
pulled_left ~ dbinom(1,theta),
logit(theta) <- a + aj[actor] + bp*prosoc_left + bc*condition + bpc*condition*prosoc_left,
a ~ dnorm(0,1),
aj[actor] ~ dnorm(0,sigma_actor),
bp ~ dnorm(0,1),
bc ~ dnorm(0,1),
bpc ~ dnorm(0,1),
sigma_actor ~ dcauchy(0,1)
) ,
data=d , log_lik=TRUE )
( x1 <- compare(m0,m1,m2,m3,m4,log_lik="log_lik") )
plot(x1)
}
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.