## ----------------------------------------------------------------------
## Wrapper function for MADlib's ARIMA
## ----------------------------------------------------------------------
## Generic entry point for MADlib ARIMA fitting.  Dispatches on the
## classes of `x` and `ts`, then stores the call that reached this
## generic in the `call` element of whatever the selected method returns.
setGeneric(
    "madlib.arima",
    def = function (x, ts, ...) {
        ## the call has to be recorded here, in the generic function
        user.call <- match.call()
        model <- standardGeneric("madlib.arima")
        model$call <- user.call
        model
    })
## Declare an S4 class carrying the same name as the S3 class that
## .madlib.arima.css() assigns to its result.
setClass(Class = "arima.css.madlib")
## ----------------------------------------------------------------------
## One of the two interfaces of madlib.arima
## (db.Rquery, db.Rquery) interface.  `x` and `ts` must each expose
## exactly one column name.  They are bound together with cbind2(),
## followed by every element of `by` when it is given, and the call is
## forwarded to the formula interface using the formula
## "<names(x)> ~ <names(ts)>" (plus "| g1+g2+..." for grouping columns).
setMethod(
    "madlib.arima",
    signature(x = "db.Rquery", ts = "db.Rquery"),
    def = function (x, ts, by = NULL, order=c(1,1,1),
                    seasonal = list(order = c(0,0,0), period = NA),
                    include.mean = TRUE, method = "CSS",
                    optim.method = "LM",
                    optim.control = list(), ...)
    {
        method <- match.arg(method)
        optim.method <- match.arg(optim.method)

        one.column.each <- length(names(x)) == 1 && length(names(ts)) == 1
        if (!one.column.each)
            stop("ARIMA can only have one time stamp column and ",
                 "one time series value column !")

        combined <- cbind2(x, ts)
        model.str <- paste(names(x), "~", names(ts))

        ## `by` is expected to be a list of db.Rquery objects
        if (!is.null(by)) {
            group.cols <- character(0)
            for (k in seq_len(length(by))) {
                group.k <- by[[k]]
                combined <- cbind2(combined, group.k)
                group.cols <- c(group.cols, names(group.k))
            }
            rhs.groups <- paste(group.cols, collapse = "+")
            model.str <- paste(model.str, "|", rhs.groups)
        }

        ## hand over to the formula interface
        madlib.arima(formula(model.str), combined, order, seasonal,
                     include.mean, method, optim.method, optim.control,
                     ...)
    })
## ----------------------------------------------------------------------
## (formula, db.obj) interface.  Collects the extra arguments in `...`,
## the entries of `optim.control` and the named model settings into one
## argument list, then calls .madlib.arima.css() with it.  Any `method`
## other than "CSS" (compared case-insensitively) is rejected.
setMethod(
    "madlib.arima",
    signature(x = "formula", ts = "db.obj"),
    def = function (x, ts, order=c(1,1,1),
                    seasonal = list(order = c(0,0,0), period = NA),
                    include.mean = TRUE, method = "CSS",
                    optim.method = "LM",
                    optim.control = list(), ...)
    {
        ## Build the argument list by hand so the worker function does
        ## its own argument checking/matching
        fit.args <- c(list(...), optim.control)
        fit.args[["formula"]] <- x
        fit.args[["data"]] <- ts
        fit.args[["order"]] <- order
        fit.args[["seasonal"]] <- seasonal
        fit.args[["include.mean"]] <- include.mean
        fit.args[["optim.method"]] <- optim.method

        if (tolower(method) != "css")
            stop("Right now MADlib's ARIMA does not support methods ",
                 "other than \"CSS\" !")

        do.call(.madlib.arima.css, fit.args)
    })
## ----------------------------------------------------------------------
## Fit an ARIMA model through MADlib's arima_train (CSS method).
##
## Arguments:
##   formula        analysed together with `data` by .get.params(); it
##                  must yield exactly one independent variable (the
##                  time stamp) and no grouping term
##   data           a db.obj holding the series
##   order          vector of length 3, read as (p, d, q)
##   seasonal       must equal the default; anything else is rejected
##   include.mean   pasted into the arima_train call
##   optim.method   only "LM" (case-insensitive) is accepted
##   tau, e1, e2, e3, max.iter, hessian.delta, param.init, chunk.size
##                  optimizer settings, sent to arima_train as one
##                  comma-separated text argument
##   ...            not used
##
## Returns a list of class "arima.css.madlib" with elements coef, s.e.,
## series, time.stamp, time.series, sigma2, loglik, iter.num, exec.time,
## residuals, model, statistics (the last three are db.data.frame
## objects for the output tables) and temp.source.
##
## Stops with an error on unsupported optimizer/seasonal/grouping
## settings, on an `order` whose length is not 3, when `data` is not a
## db.obj, or on HAWQ 1.1.
.madlib.arima.css <- function (formula, data, order=c(1,1,1),
                               seasonal = list(order = c(0, 0, 0),
                               period = NA),
                               include.mean = TRUE, optim.method = "LM",
                               tau = 1e-3, e1 = 1e-15, e2 = 1e-15,
                               e3 = 1e-15, max.iter = 100,
                               hessian.delta = 1e-4, param.init = "zero",
                               chunk.size = 10000, ...)
{
    if (tolower(optim.method) != "lm")
        stop("Right now MADlib ARIMA only supports ",
             "LM optimization method !")

    if (!identical(seasonal, list(order = c(0, 0, 0), period = NA)))
        stop("Right now MADlib ARIMA does not support seasonal order !")

    if (length(order) != 3)
        stop("ARIMA needs order to be an integer array of length 3 !")

    ## make sure fitting to db.obj
    if (! is(data, "db.obj"))
        stop("madlib.arima cannot be used on the object ",
             deparse(substitute(data)))

    ## Only newer versions of MADlib are supported
    .check.madlib.version(data)

    conn.id <- conn.id(data)

    db <- .get.dbms.str(conn.id)
    if (db$db.str == "HAWQ" && grepl("^1\\.1", db$version.str))
        stop("MADlib on HAWQ 1.1 does not support ARIMA !")

    warnings <- .suppress.warnings(conn.id)
    ## restore the warning settings on every exit path, including the
    ## stop() calls and database errors below
    on.exit(.restore.warnings(warnings), add = TRUE)

    ## analyze the formula
    analyzer <- .get.params(formula, data)
    data <- analyzer$data
    params <- analyzer$params
    is.tbl.source.temp <- analyzer$is.tbl.source.temp
    tbl.source <- content(data)

    if (length(params$ind.vars) != 1)
        stop("Only one time stamp is allowed !")

    ## dependent, independent and grouping strings
    if (is.null(params$grp.str))
        grp <- "NULL"
    else
        stop("Right now MADlib does not support grouping in ARIMA !")

    ## allow expressions as time series and time stamp
    ## create intermediate tables to accommodate this
    col.names <- names(data)
    if (!(.strip(params$ind.vars, "\"") %in% col.names) ||
        !(.strip(params$dep.str, "\"") %in% col.names)) {
        new.src <- .unique.string()
        res <- db.q(paste("create table ", new.src, " as ",
                          "select ", params$dep.str, " as tval, ",
                          params$ind.vars, " as tid from ", tbl.source,
                          sep = ""),
                    conn.id = conn.id, verbose = FALSE)
        ## drop the previous temporary source on the same connection
        ## that every other statement in this function uses
        if (is.tbl.source.temp) delete(tbl.source, conn.id = conn.id)
        is.tbl.source.temp <- TRUE
        tbl.source <- new.src
        params$dep.str <- "tval"
        params$ind.vars <- "tid"
    }

    ## construct SQL string
    madlib <- schema.madlib(conn.id) # MADlib schema name
    tbl.output <- .unique.string()
    order.str <- paste("array[", toString(order), "]", sep = "")
    ## max.iter and chunk.size are formatted without scientific notation
    ## so that values such as 100000 are not sent as "1e+05"
    optim.control.str <- paste("max_iter=",
                               format(max.iter, scientific = FALSE),
                               ", tau=", tau, ", e1=", e1, ", e2=", e2,
                               ", e3=", e3, ", hessian_delta=",
                               hessian.delta, ", chunk_size=",
                               format(chunk.size, scientific = FALSE),
                               ", param_init=\"",
                               param.init, "\"", sep = "")
    sql <- paste("select ", madlib, ".arima_train('",
                 tbl.source, "', '", tbl.output, "', '",
                 params$ind.vars, "', '", params$dep.str, "', ",
                 grp, ", ", include.mean, ", ", order.str, ", '",
                 optim.control.str, "')", sep = "")

    ## execute and get the result
    res <- db.q(sql, conn.id=conn.id, verbose = FALSE)

    p <- order[1]
    d <- order[2]
    q <- order[3]

    ## retrieve the coefficients
    ## The model table is read positionally from its first row: each
    ## present block (AR, MA, mean) occupies a (value, std. error) pair
    ## of columns, hence the (p>0)*2 and ((p>0)+(q>0))*2 offsets.
    ## NOTE(review): a single cell is read per block; presumably lk()
    ## returns array cells that as.numeric() expands -- confirm.
    res <- lk(tbl.output, conn.id=conn.id, -1)
    rst <- list()
    rst$coef <- numeric(0) # coefficients
    rst$s.e. <- numeric(0) # standard errors
    coef.names <- character(0)
    if (p != 0) {
        rst$coef <- c(rst$coef, as.numeric(res[1,1]))
        rst$s.e. <- c(rst$s.e., as.numeric(res[1,2]))
        coef.names <- c(coef.names, paste("ar", 1:p, sep = ""))
    }

    if (q != 0) {
        rst$coef <- c(rst$coef, as.numeric(res[1,(p>0)*2+1]))
        rst$s.e. <- c(rst$s.e., as.numeric(res[1,(p>0)*2+2]))
        coef.names <- c(coef.names, paste("ma", 1:q, sep = ""))
    }

    if (include.mean && d == 0) {
        rst$coef <- c(rst$coef, as.numeric(res[1,((p>0)+(q>0))*2+1]))
        rst$s.e. <- c(rst$s.e., as.numeric(res[1,((p>0)+(q>0))*2+2]))
        coef.names <- c(coef.names, "mean")
    }

    names(rst$coef) <- coef.names
    names(rst$s.e.) <- coef.names

    ## retrieve the statistics
    rst$series <- content(data)
    rst$time.stamp <- params$ind.vars
    rst$time.series <- params$dep.str

    res <- preview(paste(tbl.output, "_summary", sep = ""),
                   conn.id=conn.id, "all")
    rst$sigma2 <- res$residual_variance
    rst$loglik <- res$log_likelihood
    rst$iter.num <- res$iter_num
    rst$exec.time <- res$exec_time

    ## create db.data.frame objects for the output tables
    rst$residuals <- db.data.frame(paste(tbl.output, "_residual", sep = ""),
                                   conn.id = conn.id, verbose = FALSE)
    rst$model <- db.data.frame(tbl.output, conn.id = conn.id,
                               verbose = FALSE)
    rst$statistics <- db.data.frame(paste(tbl.output, "_summary", sep = ""),
                                    conn.id = conn.id, verbose = FALSE)

    ## If temp.source is TRUE, the delete(...) function
    ## will delete it
    if (is.tbl.source.temp) rst$temp.source <- TRUE
    else rst$temp.source <- FALSE

    class(rst) <- "arima.css.madlib"
    rst
}
## ----------------------------------------------------------------------
## summary method for "arima.css.madlib" objects: the fitted object is
## handed back unchanged; extra arguments are ignored.
summary.arima.css.madlib <- function (object, ...)
{
    return(object)
}
## ----------------------------------------------------------------------
## print method for "arima.css.madlib" objects.
##
## Writes the stored call, a two-row table of coefficients and their
## standard errors, and a line with sigma^2 and the log likelihood.
##
## Arguments:
##   x       the fitted object; its call, coef, s.e., sigma2 and loglik
##           elements are read
##   digits  number of significant digits passed to format()
##   ...     ignored
##
## Returns `x` invisibly, as print methods conventionally do.
print.arima.css.madlib <- function (x,
                                    digits = max(3L,
                                    getOption("digits") - 3L),
                                    ...)
{
    call.text <- paste(deparse(x$call), sep = "\n", collapse = "\n")
    cat("\nCall:\n", call.text, "\n\n", sep = "")

    cat("Coefficients:\n")
    est.text <- format(x$coef, digits = digits)
    se.text <- format(x$s.e., digits = digits)
    tbl <- rbind(est.text, se.text)
    ## first row (the estimates) is deliberately left unlabelled
    row.names(tbl) <- c("", "s.e.")
    print(format(data.frame(tbl), justify = "right"))

    cat("\n")
    cat("sigma^2 estimated as ", format(x$sigma2, digits = digits),
        ", part log likelihood = ", format(x$loglik, digits = digits),
        sep = "")
    cat("\n")

    invisible(x)
}
## -----------------------------------------------------------------------
## show for "arima.css.madlib" objects: simply delegates to print().
show.arima.css.madlib <- function (object)
{
    print(x = object)
}
## ----------------------------------------------------------------------
## Some functionalities will be implemented in the future
## predict method for "arima.css.madlib" objects: runs MADlib's
## arima_forecast on the model table stored in `object$model`.
##
## Arguments:
##   object   a fitted object whose `model` element is the model table
##   n.ahead  number of steps to forecast; a single positive whole number
##   ...      ignored
##
## Returns a db.data.frame for the newly created forecast table.
## Stops with an error when `n.ahead` is not a single positive whole
## number.
predict.arima.css.madlib <- function(object, n.ahead = 1, ...)
{
    ## n.ahead is pasted into SQL, so reject anything that would yield
    ## several statements or a non-integer literal
    if (!is.numeric(n.ahead) || length(n.ahead) != 1 ||
        !is.finite(n.ahead) || n.ahead < 1 ||
        n.ahead != floor(n.ahead))
        stop("n.ahead must be a single positive integer !")

    conn.id <- conn.id(object$model)
    warnings <- .suppress.warnings(conn.id)
    ## restore the warning settings even when the forecast query fails
    on.exit(.restore.warnings(warnings), add = TRUE)

    tbl.output <- .unique.string()
    tbl.model <- .strip(content(object$model), "\"")
    madlib <- schema.madlib(conn.id) # MADlib schema name
    ## format() without scientific notation keeps e.g. 100000 from
    ## being sent as "1e+05"
    sql <- paste("select ", madlib, ".arima_forecast('",
                 tbl.model, "', '", tbl.output, "',",
                 format(n.ahead, scientific = FALSE), ")", sep = "")
    db.q(sql, conn.id=conn.id, verbose = FALSE)

    db.data.frame(tbl.output, conn.id=conn.id, verbose = FALSE)
}
## Add the following code to your website.
## For more information on customizing the embed code, read Embedding Snippets.