title: "Introduction to R Package ARTtransfer for Transfer Learning" author: "Boxiang Wang, Yunan Wu, and Chenglong Ye" date: "2024-10-16" output: rmarkdown::html_vignette vignette: > %\VignetteIndexEntry{Introduction to ARTtransfer} %\VignetteEngine{knitr::rmarkdown} %\VignetteEncoding{UTF-8}
# ARTtransfer

The `ARTtransfer` package implements Adaptive and Robust Transfer Learning (ART), a framework that enhances model performance on a primary task by integrating auxiliary data from related domains. ART leverages information from the auxiliary data while remaining robust to so-called negative transfer, so that non-informative auxiliary data do not degrade performance on the primary task.
To install the package from CRAN, run the following:
```r
# Install the R package from CRAN
install.packages("ARTtransfer")
```
This section demonstrates how to generate synthetic data for transfer learning and apply the ART framework using different models.
The function `generate_data()` allows you to simulate data for transfer learning, including primary, auxiliary, and noisy datasets. The response can be either continuous or binary, for regression and classification tasks, respectively.
```r
library(ARTtransfer)

# Generate synthetic datasets for transfer learning
dat <- generate_data(n0 = 100, K = 3, nk = 50, p = 10,
                     mu_trgt = 1, xi_aux = 0.5, ro = 0.3, err_sig = 1,
                     is_test = TRUE, task = "classification")

# Explore the generated data
cat("Primary dataset (X):", dim(dat$X), "\n")
cat("Auxiliary dataset 1 (X_aux[[1]]):", dim(dat$X_aux[[1]]), "\n")
cat("Test dataset (X_test):", dim(dat$X_test), "\n")
```
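Since `dat$X_aux` is a list with one element per auxiliary dataset (here `K = 3`), a quick way to inspect them all is to loop over the list. This is only a convenience sketch; it assumes nothing beyond the list structure shown above.

```r
# Sketch: print the dimensions of every auxiliary dataset.
# Assumes dat$X_aux is a list of length K, as suggested by dat$X_aux[[1]] above.
for (k in seq_along(dat$X_aux)) {
  cat("Auxiliary dataset", k, ":", dim(dat$X_aux[[k]]), "\n")
}
```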
Once the data is generated, you can use the `ART()` function to apply transfer learning. In this example, we fit a logistic regression using the wrapper function `fit_logit()`, which wraps the `glm()` function in R.
```r
# Fit the ART model using a generalized linear model (logistic regression);
# the result is stored under a name that does not mask the fit_logit() wrapper
fit_ART_logit <- ART(X = dat$X, y = dat$y,
                     X_aux = dat$X_aux, y_aux = dat$y_aux,
                     X_test = dat$X_test, func = fit_logit)

# Summary of the fit
summary(fit_ART_logit)
```
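The same call works with any wrapper that follows the interface described in the next section. As an illustration, here is a sketch that swaps in the package's `fit_random_forest()` wrapper; it assumes only that `ART()` is called exactly as above with a different `func` argument.

```r
# Sketch: the same ART() call, but with the random forest wrapper.
# Assumes fit_random_forest() follows the same interface as fit_logit().
fit_ART_rf <- ART(X = dat$X, y = dat$y,
                  X_aux = dat$X_aux, y_aux = dat$y_aux,
                  X_test = dat$X_test, func = fit_random_forest)
summary(fit_ART_rf)
```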
The `ARTtransfer` package provides several wrapper functions for model fitting, including:

- `fit_lm()`: Linear regression
- `fit_logit()`: Logistic regression
- `fit_random_forest()`: Random forest
- `fit_nnet()`: Neural network
- `fit_gbm()`: Gradient boosting machine

You can pass any of these functions to `ART()` using the `func` argument. You may also write your own function following the format of these wrappers. Specifically, the following requirements must be satisfied:
- The function must accept the arguments `X`, `y`, `X_val`, `y_val`, `X_test`, `min_prod`, and `max_prod`.
- It must return `dev` (the deviance) and `pred` (the predictions).
- If `is_coef = TRUE` and a regression model is being used, the function must also return the coefficients (`coef`).

A minimal sketch of a custom wrapper that satisfies these requirements is shown below.
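Everything in this sketch beyond the required argument names and returned elements is an assumption: the name `fit_my_logit()` is hypothetical, and treating `min_prod` and `max_prod` as bounds for clipping predicted probabilities before computing the deviance is a guess at their intended role, not documented behavior.

```r
# Hypothetical custom wrapper for ART(): a plain logistic regression via glm().
# It accepts the required arguments and returns dev, pred, and (optionally) coef.
fit_my_logit <- function(X, y, X_val, y_val, X_test,
                         min_prod = 1e-5, max_prod = 1 - 1e-5, is_coef = FALSE) {
  # Fit on the training data
  mod <- glm(y ~ ., data = data.frame(y = y, X), family = binomial())

  # Validation-set probabilities, clipped away from 0 and 1 (assumed role of
  # min_prod / max_prod) so the deviance below stays finite
  p_val <- predict(mod, newdata = data.frame(X_val), type = "response")
  p_val <- pmin(pmax(p_val, min_prod), max_prod)

  # Binomial deviance on the validation set
  dev <- -2 * sum(y_val * log(p_val) + (1 - y_val) * log(1 - p_val))

  # Test-set predictions
  pred <- predict(mod, newdata = data.frame(X_test), type = "response")

  out <- list(dev = dev, pred = pred)
  if (is_coef) out$coef <- coef(mod)  # coefficients, when requested
  out
}
```

Such a function could then be passed to `ART()` via `func = fit_my_logit`, just like the built-in wrappers.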
`ART_I_AM()` is an extension of the ART framework that automatically integrates three pre-defined models: random forest, AdaBoost, and neural networks. Users do not need to specify a fitting function for each dataset because the models are already built into the function. However, users can follow the format of `ART_I_AM()` and implement their own functions within this flexible and general framework.
To use `ART_I_AM()`, simply provide the primary and auxiliary datasets, and the integrated models will be applied automatically. Here we demonstrate ART-I-AM with its three built-in methods: random forest, boosting, and neural networks. Users can modify the R function `ART_I_AM()` to use any other functions, as long as each function is coded in the same format as the `func` argument of `ART()`.
```r
# Fit the ART_I_AM model using integrated Random Forest, AdaBoost, and Neural Network models
fit_I_AM <- ART_I_AM(X = dat$X, y = dat$y,
                     X_aux = dat$X_aux, y_aux = dat$y_aux,
                     X_test = dat$X_test)

# View the predictions and weights
fit_I_AM$pred_FTL
fit_I_AM$W_FTL
```
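Because the data were generated with `is_test = TRUE`, you can compare the test predictions against the held-out labels. The sketch below assumes that `generate_data()` also returns the test labels as `dat$y_test` and that `pred_FTL` contains predicted probabilities for the binary response; both are assumptions rather than documented facts.

```r
# Sketch: rough test-set accuracy of the ART_I_AM predictions.
# Assumes dat$y_test holds the 0/1 test labels and pred_FTL holds predicted
# probabilities; skip the thresholding step if pred_FTL already contains labels.
pred_class <- as.numeric(fit_I_AM$pred_FTL > 0.5)
mean(pred_class == dat$y_test)
```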
The `ARTtransfer` package provides a flexible framework for performing adaptive and robust transfer learning using various models. You can easily integrate your own functions and data to improve performance on primary tasks.