s2netR: Trains a generalized extended linear joint trained model...

View source: R/extJT.R

s2netR R Documentation

Trains a generalized extended linear joint trained model using semi-supervised data.

Description

This function is a wrapper for the C++ class s2net: it creates the C++ object and fits the model to the input data.

Usage

s2netR(data, params, loss = "default", frame = "ExtJT", proj = "auto", 
        fista = NULL, S3 = TRUE)

Arguments

data

An s2Data object with the (training) data.

params

An s2Params object with the model hyperparameters.

loss

Loss function. One of "default" (inferred from the data), "linear", or "logit".

frame

The semi-supervised framework. One of "ExtJT" (the extended linear joint trained model) or "JT" (the linear joint trained model of Ryan and Culp, 2015).

proj

Should the unlabeled data be shifted to remove the model's effect? One of "no", "yes", or "auto". With "auto", the function decides whether to shift the unlabeled data based on the angle between beta and the center of the data.

fista

FISTA setup parameters: an object of class s2Fista.

S3

Logical: should the function return an S3 object (the default) or a C++ object?
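The "auto" option of proj depends on the angle between the coefficient vector beta and the center of the data. The snippet below is a minimal sketch of how such an angle can be computed, assuming that "center" means the column-mean vector of the unlabeled covariates xU. The helper angle_to_center() is hypothetical and not part of s2net, and the threshold s2netR uses internally is not documented on this page, so the sketch only computes the angle and makes no shifting decision.

```r
# Hypothetical helper (not part of s2net): angle, in degrees, between a
# coefficient vector `beta` and the column means of the unlabeled data `xU`.
# Assumes "center of the data" = colMeans(xU); this is an illustration only.
angle_to_center <- function(beta, xU) {
  center <- colMeans(as.matrix(xU))
  cos_theta <- sum(beta * center) / (sqrt(sum(beta^2)) * sqrt(sum(center^2)))
  # Clamp to [-1, 1] to guard against floating-point overshoot before acos()
  cos_theta <- max(-1, min(1, cos_theta))
  acos(cos_theta) * 180 / pi
}

# Toy unlabeled data whose column means are (2, 2)
xU <- matrix(c(1, 3, 1, 3), ncol = 2)

angle_to_center(c(1, 0), xU)   # beta at 45 degrees from the center
angle_to_center(c(1, -1), xU)  # beta orthogonal to the center
```

An angle near 90 degrees means beta is nearly orthogonal to the center of the unlabeled data, so a shift along that center barely changes the linear predictor; how s2netR maps the angle to a decision is not specified here.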

Value

Returns an object of S3 class s2netR if S3 = TRUE, or a C++ object of class s2net otherwise.

Author(s)

Juan C. Laria

References

Ryan, K. J., & Culp, M. V. (2015). On semi-supervised linear regression in covariate shift problems. The Journal of Machine Learning Research, 16(1), 3183-3217.

See Also

s2net

Examples

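library(s2net)

# Partition P1 of auto_mpg: labeled data (xL, yL) and unlabeled covariates xU
data("auto_mpg")
train = s2Data(xL = auto_mpg$P1$xL, yL = auto_mpg$P1$yL, xU = auto_mpg$P1$xU)

# Fit the extended linear joint trained (ExtJT) model with a linear loss
model = s2netR(train,
               s2Params(lambda1 = 0.1,
                        lambda2 = 0,
                        gamma1 = 0.1,
                        gamma2 = 100,
                        gamma3 = 0.1),
               loss = "linear",
               frame = "ExtJT",
               proj = "auto",
               fista = s2Fista(5000, 1e-7, 1, 0.8))

# Validation set: the unlabeled covariates with their held-out responses,
# preprocessed in the same way as the training data
valid = s2Data(auto_mpg$P1$xU, auto_mpg$P1$yU, preprocess = train)
ypred = predict(model, valid$xL)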

## Not run: 
if(require(ggplot2)){
  ggplot() + 
    aes(x = ypred, y = valid$yL) + geom_point() + 
    geom_abline(intercept = 0, slope = 1, linetype = 2)
}

## End(Not run)
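The plot above compares predictions with the observed responses visually; a numeric summary such as the root mean squared error can complement it. A minimal sketch, assuming ypred and valid$yL are numeric (a one-column matrix is flattened with as.numeric()). The helper rmse() is hypothetical and not part of s2net.

```r
# Hypothetical helper (not part of s2net): root mean squared error
rmse <- function(pred, obs) {
  pred <- as.numeric(pred)  # flatten a possible one-column matrix
  obs <- as.numeric(obs)
  stopifnot(length(pred) == length(obs))
  sqrt(mean((pred - obs)^2))
}

# Toy check with made-up numbers: errors (0, 0, 2), so RMSE = sqrt(4 / 3)
rmse(c(1, 2, 3), c(1, 2, 5))
```

After running the example above, rmse(ypred, valid$yL) gives the validation error of the fitted model.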


s2net documentation built on July 1, 2022, 1:06 a.m.