R/naive_bn.R

Defines functions naive_bn

Documented in naive_bn

#' Naive Bayes Classifier
#'
#' Estimates the class priors and the per-class / overall distribution
#' parameters needed by a naive Bayes classifier whose independents are all
#' numeric.
#'
#' @param data a data.table class modelling dataset; every column other than
#'   \code{target} is treated as a numeric independent.
#' @param method a character for distribution of independents. There are
#'   three options: "normal", "exp" and "gamma". With "gamma" logarithms of
#'   the independents are taken, so they should be strictly positive.
#' @param target a character for dependent variable name. A factor column
#'   contributes one class per level; any other column type contributes one
#'   class per distinct non-missing value (sorted).
#'
#' @return a list of modelling results. It always contains:
#'    target: a character for dependent variable name;
#'    prior: a data.table with columns \code{target} (class) and \code{prob}
#'      (share of rows belonging to that class);
#'    method: a character for distribution of independents.
#'    Depending on \code{method} it additionally contains:
#'    "normal": group.mean, overall.mean, group.sd, overall.sd;
#'    "exp": group.mean, overall.mean;
#'    "gamma": group.k, group.a, overall.k, overall.a, where k is the
#'      closed-form approximation of the gamma shape estimate and a is
#'      mean / k.
#'    The group.* data.tables have one row per class (first column is the
#'    target); the overall.* data.tables have a single row.
#'
#' @export
naive_bn <- function(data, method, target) {
  # Fail early with a clear message instead of a cryptic downstream error.
  if (!inherits(data, "data.table")) {
    stop("`data` must be a data.table", call. = FALSE)
  }
  if (!is.character(target) || length(target) != 1L ||
      !(target %in% names(data))) {
    stop("`target` must be the name of a single column in `data`",
         call. = FALSE)
  }
  if (!is.character(method) || length(method) != 1L ||
      !(method %in% c("normal", "exp", "gamma"))) {
    stop("`method` must be one of \"normal\", \"exp\" or \"gamma\"",
         call. = FALSE)
  }

  # Class priors: relative frequency of every class in the target column.
  # levels() is NULL for non-factors, so fall back to the distinct values.
  labels <- data[[target]]
  target.level <- if (is.factor(labels)) {
    levels(labels)
  } else {
    sort(unique(labels))
  }
  prior <- vapply(
    seq_along(target.level),
    function(i) sum(labels == target.level[i]) / length(labels),
    numeric(1)
  )
  prior <- data.table(target = target.level, prob = prior)

  # Independents only (target column dropped), plus the means shared by all
  # three methods. `by = c(target)` forces `target` to be read as a variable
  # holding the column name.
  features <- data[, !target, with = FALSE]
  overall.mean <- features[, lapply(.SD, mean)]
  group.mean <- data[, lapply(.SD, mean), by = c(target)]

  if (method == "normal") {
    result <- list(
      target = target, prior = prior, method = method,
      group.mean = group.mean, overall.mean = overall.mean,
      group.sd = data[, lapply(.SD, sd), by = c(target)],
      overall.sd = features[, lapply(.SD, sd)]
    )
  } else if (method == "exp") {
    result <- list(
      target = target, prior = prior, method = method,
      group.mean = group.mean, overall.mean = overall.mean
    )
  } else {
    # Gamma: with s = log(mean(x)) - mean(log(x)), the shape is approximated
    # by k = (3 - s + sqrt((s - 3)^2 + 24 s)) / (12 s) and a = mean(x) / k.
    shape_from_s <- function(s) {
      (3 - s + sqrt((s - 3)^2 + 24 * s)) / (12 * s)
    }

    overall.s <- features[
      , lapply(.SD, function(x) log(mean(x)) - mean(log(x)))
    ]
    overall.k <- shape_from_s(overall.s)
    overall.a <- overall.mean / overall.k

    # Same computation per class; the class column is split off for the
    # arithmetic and re-attached afterwards. Both grouped tables come from
    # `data` in the same row order, so their group rows line up.
    group.id <- group.mean[, target, with = FALSE]
    group.s <- data[
      , lapply(.SD, function(x) log(mean(x)) - mean(log(x))),
      by = c(target)
    ][, !target, with = FALSE]
    group.k <- shape_from_s(group.s)
    group.a <- group.mean[, !target, with = FALSE] / group.k

    result <- list(
      target = target, prior = prior, method = method,
      group.k = cbind(group.id, group.k),
      group.a = cbind(group.id, group.a),
      overall.k = overall.k, overall.a = overall.a
    )
  }

  result
}
xinzhou1023/nbcont documentation built on May 28, 2017, 7:38 a.m.