cvMTL: K-fold cross-validation

View source: R/MTL.R

cvMTL: R Documentation

K-fold cross-validation

Description

Performs k-fold cross-validation to select the cross-task regularization parameter λ_1 from a candidate sequence (Lam1_seq).
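The loop behind selecting λ_1 by k-fold cross-validation can be sketched in base R. This is a minimal illustration under stated assumptions: a single-task ridge regression stands in for the multi-task solver, and the helpers fit_ridge and cv_select_lambda, along with the returned fields (cvm, lam_min), are hypothetical. They are not part of RMTL and do not describe what cvMTL returns.

```r
# Sketch of k-fold CV for choosing lambda_1 from a candidate sequence.
# ASSUMPTION: a single-task ridge regression stands in for the MTL solver;
# only the cross-validation loop is being illustrated, not RMTL's algorithm.

# Hypothetical helper: closed-form ridge fit, returns a p x 1 coefficient matrix
fit_ridge <- function(X, y, lambda) {
  p <- ncol(X)
  solve(crossprod(X) + lambda * diag(p), crossprod(X, y))
}

# Hypothetical helper: average held-out MSE for each lambda in lam_seq
cv_select_lambda <- function(X, y, lam_seq, nfolds = 5) {
  n <- nrow(X)
  # Assign each sample to one of nfolds folds of (near) equal size, once,
  # so every candidate lambda is scored on the same splits
  folds <- sample(rep(seq_len(nfolds), length.out = n))
  cvm <- sapply(lam_seq, function(lam) {
    errs <- sapply(seq_len(nfolds), function(f) {
      test <- folds == f
      beta <- fit_ridge(X[!test, , drop = FALSE], y[!test], lam)
      mean((y[test] - X[test, , drop = FALSE] %*% beta)^2)
    })
    mean(errs)  # average error across the held-out folds
  })
  list(lam_seq = lam_seq, cvm = cvm, lam_min = lam_seq[which.min(cvm)])
}

# Usage on toy data with the same candidate grid as the cvMTL default
set.seed(42)
Xs <- matrix(rnorm(100 * 5), nrow = 100, ncol = 5)
ys <- drop(Xs %*% c(1, -1, 0.5, 0, 0) + rnorm(100))
res <- cv_select_lambda(Xs, ys, 10^seq(1, -4, -1), nfolds = 5)
res$lam_min
```

Each candidate is refit nfolds times on the training folds and scored on the held-out fold; the candidate with the lowest average error is kept. The error measure cvMTL uses for classification may differ from the squared error used here.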

Usage

cvMTL(
  X,
  Y,
  type = "Classification",
  Regularization = "L21",
  Lam1_seq = 10^seq(1, -4, -1),
  Lam2 = 0,
  G = NULL,
  k = 2,
  opts = list(init = 0, tol = 10^-3, maxIter = 1000),
  stratify = FALSE,
  nfolds = 5,
  ncores = 2,
  parallel = FALSE
)
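The default candidate grid for λ_1 in the Usage above can be inspected directly in base R. It holds six values running from 10 down to 10^-4, each a factor of 10 smaller than the previous one. A denser grid over the same range is shown as an illustration only; whether a finer grid helps depends on the data.

```r
# Default candidate grid for lambda_1, as written in the Usage above
Lam1_seq <- 10^seq(1, -4, -1)
print(Lam1_seq)

# Illustration only: 11 log-spaced values over the same range
Lam1_fine <- 10^seq(1, -4, length.out = 11)
print(Lam1_fine)
```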

Arguments

X

A set (list) of feature matrices, one per task

Y

A set of responses, which can be binary (classification problem) or continuous (regression problem). Binary outcomes must take values in {1, -1}.

type

The type of problem; must be "Regression" or "Classification"

Regularization

The type of MTL algorithm (cross-task regularizer). The value must be one of "L21", "Lasso", "Trace", "Graph" or "CMTL".

Lam1_seq

A sequence of positive candidate values for λ_1, which controls the cross-task regularization

Lam2

A non-negative constant λ_2 (default 0) used to improve the generalization performance

G

A matrix encoding the network information. This parameter is only used in MTL with graph structure (Regularization="Graph").

k

A positive number that modulates the structure of the clusters (default 2). This parameter is only used in MTL with clustering structure (Regularization="CMTL"). Note that a larger value suits a more complex clustering structure.

opts

Options for the optimization procedure. Through this list one can set the initial search point (init), the tolerance (tol) and the maximum number of iterations (maxIter). The default value is list(init=0, tol=10^-3, maxIter=1000).

stratify

If TRUE, stratified cross-validation is performed (default FALSE)

nfolds

The number of cross-validation folds (default 5)

ncores

The number of cores used for parallel computing with the default value of 2

parallel

If TRUE, the computation is run in parallel using ncores cores (default FALSE)
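Preparing inputs for a classification problem can be sketched in base R. This assumes, as the argument descriptions suggest, that X is a list with one feature matrix per task and Y a list with one response vector per task, with binary labels recoded from 0/1 to {1, -1}. The helper stratified_folds is hypothetical and only illustrates what stratification means (each fold keeps roughly the same proportion of +1 and -1 labels within a task); the exact scheme cvMTL uses when stratify=TRUE is not stated here.

```r
# ASSUMPTION: X = list of per-task feature matrices, Y = list of per-task
# label vectors coded in {1, -1}. Toy data: 3 tasks, 20 samples, 4 features.
set.seed(1)
n_tasks <- 3
p <- 4
X <- lapply(seq_len(n_tasks), function(t) matrix(rnorm(20 * p), nrow = 20, ncol = p))

# Raw labels coded 0/1, recoded to the {1, -1} convention
Y01 <- lapply(seq_len(n_tasks), function(t) rbinom(20, 1, 0.5))
Y <- lapply(Y01, function(y) ifelse(y == 1, 1, -1))

# Hypothetical helper: stratified fold ids for one task. Samples of each
# class are spread over the folds separately, so class proportions are
# (nearly) preserved in every fold.
stratified_folds <- function(y, nfolds = 5) {
  folds <- integer(length(y))
  for (cls in unique(y)) {
    idx <- which(y == cls)
    folds[idx] <- sample(rep(seq_len(nfolds), length.out = length(idx)))
  }
  folds
}

# One vector of fold ids per task
fold_ids <- lapply(Y, stratified_folds, nfolds = 5)
```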

Value

The estimated λ_1 and related information

Examples

library(RMTL)
# Create the example data
data <- Create_simulated_data(Regularization = "L21", type = "Classification")
# Perform the cross-validation
cvfit <- cvMTL(data$X, data$Y, type = "Classification", Regularization = "L21",
    Lam2 = 0, opts = list(init = 0, tol = 10^-6, maxIter = 1500), nfolds = 5,
    stratify = TRUE, Lam1_seq = 10^seq(1, -4, -1))
# Show meta-information
str(cvfit)
# Plot the CV accuracies across the lam1 sequence
plot(cvfit)

RMTL documentation built on May 2, 2022, 5:06 p.m.