# ---------------------------
# method_wrappers.R
# Wrappers for classification methods.
# ---------------------------

# NOTE: cleanup
# ---------------------------
# METHOD WRAPPERS
# ---------------------------
# A method wrapper accepts:
#	1. train_data - Data without group identifiers
#	2. full_data - Data with infection status identifiers
#	3. method_params - Any parameters it requires, bundled in a list
#	4. protect_data - Data with protection status identifiers
# The parent wrapper is required to return:
#	1. The original data, the data with new labels, and the ground truth data
#	2. Scores comparing new assignments to old labels and to true labels
#	3. Test results for each pair of labels, for each component
#	4. Fold changes for each pair of labels, for each component
# A method wrapper is required to:
#	1. Parse its method_params list
#	2. Return:
#		a. full_data as old_data (i.e. old/inf labels)
#		b. train_data with the new labels as new_data (i.e. predicted status)
# ---------------------------
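
# For orientation, a minimal sketch of a wrapper satisfying this contract
#	(it simply echoes the existing labels; x and y play the roles of
#	train_data and its labels throughout this file):
#
# wrap_identity = function(x, y, ...){
# 	return(list(old_data=cbind(x, groups=y), new_data=cbind(x, groups=y)))
# }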

#' wrap_kmeans
#'
#' Cluster x with stats::kmeans, comparing a multi-center model (H1)
#' against a single-center model (H0) and keeping H1 only if it explains
#' enough of the total sum of squares.
#'
#' @export
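#' @examples
#' # Hypothetical call (mat and labels are stand-ins). Everything in
#' # method_params except min_ss and improve_cutoff is forwarded to
#' # stats::kmeans().
#' \dontrun{
#' res = wrap_kmeans(mat, labels,
#' 	method_params=list(centers=2, nstart=25, min_ss=0.5, improve_cutoff=0.05))
#' table(res$new_data[,"groups"])
#' }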
wrap_kmeans = function(x, y, ..., verbose=FALSE){
	method_params = list(...)$method_params
	improve_cutoff = method_params$improve_cutoff
	min_ss = method_params$min_ss

	# Not used by kmeans
	method_params$improve_cutoff = NULL
	method_params$min_ss = NULL

	method_params$x = x

	# H1: method_params$centers number of centers
	h1_model = do.call(kmeans, method_params)

	# If we don't converge in the Quick-TRANSfer stage, retry with another algorithm
	if (h1_model$ifault == 4){
		if (verbose){
			warning("Quick TRANSfer stage steps exceeded - trying MacQueen algorithm...", call.=TRUE)
		}
		method_params$algorithm = "MacQueen"
		h1_model = do.call(kmeans, method_params)
	}

	# H0: 1 center
	method_params$centers = 1
	h0_model = do.call(kmeans, method_params)

	# If H1 explains more than min_ss of the total sum of squares, take it.
	#	Otherwise, H0 (one cluster) is a good enough explanation.
	h1_ss = round(h1_model$betweenss / h1_model$totss, digits=3)

	if (h1_ss > min_ss){
		model = h1_model

	} else {
		model = h0_model
	}

	# Relabeled
	new_data = cbind(x, groups=model$cluster)

	gc()
	return(list(old_data=cbind(x, groups=y), new_data=new_data))
}

# ---------------------------
# Requires:
# library(e1071)
# library(caret)

#' wrap_svm
#'
#' Fit an e1071::svm() classifier on x/y and predict group labels for newdata.
#'
#' @export
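#' @examples
#' # Hypothetical call (train_mat and test_mat are stand-ins). method_params
#' # is forwarded to e1071::svm() once x and y are attached, so any svm()
#' # arguments (kernel, cost, ...) can go here.
#' \dontrun{
#' res = wrap_svm(train_mat, train_labels, newdata=test_mat,
#' 	method_params=list(kernel="radial", cost=1))
#' table(res$new_data[,"groups"])
#' }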
wrap_svm = function(x, y, newdata, ...){
	method_params = list(...)$method_params

	method_params$y_test = NULL

	method_params$x = as.data.frame(apply(x, 2, as.numeric))
	rownames(method_params$x) = rownames(x)
	method_params$y = y

	groups = c(globals_list$UnprotNum, globals_list$ProtNum)

	# NOTE: If all groups are the same, you will get "Model is empty!"
	#	So we set one sample (here, the last one) to whichever
	#	number we don't have so that the model will work.
	if (length(unique(method_params$y)) == 1){
		method_params$y[length(method_params$y)] = groups[groups != unique(method_params$y)]
	}

	method_params$y = factor(method_params$y, levels=c(globals_list$UnprotNum, globals_list$ProtNum))

	model = do.call(svm, method_params)

	rnames = rownames(newdata)
	newdata = as.data.frame(apply(newdata, 2, as.numeric))
	rownames(newdata) = rnames

	pred = predict(model, newdata)

	# If predict() returns more than two levels, something is wrong upstream
	if (length(fccu_plevels(pred)) > 2){
		stop("SVM went rogue - check why predict isn't giving classes!", call.=TRUE)
	}

	new_data = cbind(newdata, groups=pred)

	gc()
	return(list(old_data=cbind(x, groups=y), new_data=new_data))
}

# ---------------------------
# Requires:
# library(e1071)
# library(caret)

# NOTE: x must be positive data, newdata must be unlabeled data
#' wrap_one_svm
#'
#' Fit a one-class e1071::svm() on the positive (infected) samples in x,
#' then label newdata by whether each sample looks like the positives.
#'
#' @export
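#' @examples
#' # Hypothetical call: x holds only the positive samples and newdata the
#' # unlabeled ones; nu is a standard e1071::svm() one-class parameter.
#' \dontrun{
#' res = wrap_one_svm(positive_mat, positive_labels, newdata=unlabeled_mat,
#' 	method_params=list(kernel="radial", nu=0.1))
#' }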
wrap_one_svm = function(x, y, newdata, ...){
	method_params = list(...)$method_params

	method_params$y_test = NULL

	# Train on positives, i.e. infected subjects
	method_params$x = x
	method_params$y = as.numeric(y)

	method_params$type = "one-classification"

	model = do.call(svm, method_params)

	# Test on unlabeleds
	pred = predict(model, newdata)

	# predict() returns TRUE for in-class samples; map to numeric group
	#	labels without logical-to-numeric coercion surprises
	pred = ifelse(pred, globals_list$UnprotNum, globals_list$ProtNum)

	new_data = cbind(newdata, groups=pred)

	gc()
	return(list(old_data=cbind(x, groups=y), new_data=new_data))
}


# ---------------------------
# P/U methods!
# ---------------------------

# TODO: scores are weird
#' wrap_pu_bagging
#'
#' Positive-unlabeled bagging: bootstrap the unlabeled data, train
#' wrap_pu_bag_method on positives + each bootstrap, and aggregate the
#' out-of-bag predictions (via base_pu_bagging).
#'
#' @export
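#' @examples
#' # Hypothetical call. wrap_pu_bag_method must accept (x, y, newdata, ...)
#' # and pull method_params out of ..., so wrap_svm fits; k="positive" sizes
#' # each bootstrap sample to the number of positives.
#' \dontrun{
#' res = wrap_pu_bagging(mat, labels,
#' 	wrap_method_params=list(k="positive", max_iters=100,
#' 		wrap_pu_bag_method=wrap_svm),
#' 	method_params=list(kernel="radial", cost=1))
#' }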
wrap_pu_bagging = function(x, y, ...){
	wrap_method_params = list(...)$wrap_method_params
	method_params = list(...)$method_params

	n_positive = sum(y == globals_list$InfNum)
	n_unlabeled = sum(y != globals_list$InfNum)

	if (is.null(wrap_method_params$k)){
		warning("wrap_method_params$k not set - defaulting to n_unlabeled/2", call.=TRUE)
		wrap_method_params$k = n_unlabeled/2

	} else if (wrap_method_params$k == "positive"){
		wrap_method_params$k = n_positive
	}

	model = base_pu_bagging(x, y, wrap_method_params, method_params)

	new_data = cbind(x, groups=model$pred)

	gc()
	return(list(old_data=cbind(x, groups=y), new_data=new_data))
}

#' base_pu_bagging
#'
#' Workhorse for wrap_pu_bagging: builds max_iters bootstraps of the
#' unlabeled data and combines out-of-bag predictions into a single label
#' vector for the unlabeled samples.
#'
#' @export
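#' @examples
#' # Hypothetical direct call; wrap_pu_bagging normally resolves k before
#' # delegating here.
#' \dontrun{
#' model = base_pu_bagging(mat, labels,
#' 	wrap_method_params=list(k=20, max_iters=100, wrap_pu_bag_method=wrap_svm),
#' 	method_params=list(kernel="radial", cost=1))
#' }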
base_pu_bagging = function(x, y, wrap_method_params, method_params){
	# Train on positives, i.e. infected subjects
	full_data = cbind(x, groups=y)
	rownames(full_data) = 1:nrow(full_data)
	positive = full_data[y == globals_list$InfNum,]
	unlabeled = full_data[y != globals_list$InfNum,]

	# Bootstrap max_iters times, average results
	# List of bootstrapped training and evaluation data
	boots = lapply(1:wrap_method_params$max_iters, wrap_boot, unlabeled, positive, wrap_method_params$k)
	boot_xs = lapply(boots, function(x){ return(x$train[,colnames(x$train) != "groups"]) })
	boot_ys = lapply(boots, function(x){ return(x$train[,colnames(x$train) == "groups"]) })
	boot_oobs = lapply(boots, function(x){ return(x$oob) })

	boot_results = mapply(wrap_method_params$wrap_pu_bag_method, boot_xs, boot_ys, boot_oobs,
			      MoreArgs=list(method_params=method_params), SIMPLIFY=FALSE)

	# NOTE: If any results look weird, check that preds get assigned
	#	to the correct row index.

	# boot_results[[1:max_iters]][[old_data/new_data]]
	pred = lapply(boot_results, function(x){ return(x$new_data[,"groups"]) })
	rnames = lapply(boot_results, function(x){ return(rownames(x$new_data)) })
	pred = generate_pu_pred(pred, rnames)

	if (length(pred) + nrow(positive) != nrow(x)){
		warning("wrap_pu_bagging OOB predictions must have missed some samples.", call.=TRUE)
	}

	# Full label vector: positives keep their labels, unlabeleds get predictions
	groups = vector("numeric", nrow(x))
	groups[y == globals_list$InfNum] = positive[,"groups"]
	groups[y != globals_list$InfNum] = pred

	gc()
	return(list(pred=pred))
}
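
# For reference: base_pu_bagging assumes wrap_boot(i, unlabeled, positive, k)
#	returns list(train=..., oob=...), where train carries a "groups"
#	column and oob is the out-of-bag unlabeled data. A minimal sketch of
#	a conforming implementation (the real wrap_boot lives elsewhere in
#	this package):
#
# wrap_boot = function(i, unlabeled, positive, k){
# 	in_bag = sample(nrow(unlabeled), size=k, replace=TRUE)
# 	train = rbind(positive, unlabeled[in_bag, , drop=FALSE])
# 	oob = unlabeled[-unique(in_bag), , drop=FALSE]
# 	return(list(train=train, oob=oob))
# }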

#' wrap_pu_reliable_negative
#'
#' Reliable-negative PU learning: seed a set of reliable negatives, then
#' iteratively retrain on positives + reliable negatives, moving predicted
#' negatives out of the unlabeled pool until none remain. Labels newdata
#' with whichever of the first or final model scores better on y_test.
#'
#' @export
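#' @examples
#' # Hypothetical call. my_rn_svm stands in for any function taking
#' # (merged_data, method_params) and returning a model with a predict()
#' # method; y_test rides along in method_params for scoring.
#' \dontrun{
#' res = wrap_pu_reliable_negative(mat, labels, newdata=test_mat,
#' 	wrap_method_params=list(wrap_pu_rn_method=my_rn_svm),
#' 	method_params=list(kernel="radial", y_test=test_labels))
#' }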
wrap_pu_reliable_negative = function(x, y, newdata, ...){
	wrap_method_params = list(...)$wrap_method_params
	method_params = list(...)$method_params
	y_test = method_params$y_test
	y_test = factor(y_test, levels=c(globals_list$UnprotNum, globals_list$ProtNum))

	full_data = cbind(x, groups=y)
	rownames(full_data) = 1:nrow(full_data)

	positive = full_data[y == globals_list$InfNum,]
	unlabeled = full_data[y != globals_list$InfNum,]
	unlabeled = unlabeled[,colnames(unlabeled) != "groups"]

	# Get reliable negatives, update unlabeleds
	reliable_negatives = generate_reliable_negatives(x, full_data, positive, method_params)
	unlabeled = unlabeled[!(rownames(unlabeled) %in% rownames(reliable_negatives)),]

	full_data = as.matrix(full_data)
	positive = as.matrix(positive)
	unlabeled = as.matrix(unlabeled)
	reliable_negatives = as.matrix(reliable_negatives)

	positive[,"groups"] = rep(globals_list$InfNum, nrow(positive))

	first_model = NULL
	final_model = NULL

	# nrow() is NULL when reliable_negatives isn't a matrix/data frame,
	#	so check for that before comparing
	if (!is.null(nrow(reliable_negatives)) && nrow(reliable_negatives) > 0){
		# While stopping criteria not met
		while (TRUE){
			# Train svm on P + RN, where l_P = 1, l_RN = -1
			merged = rbind(positive, reliable_negatives)

			model = wrap_method_params$wrap_pu_rn_method(merged, method_params)

			if (is.null(first_model)){
				first_model = model
			}

			pred = predict(model, unlabeled)

			maybe_neg = (as.numeric(pred) == globals_list$UninfNum)
			# drop=FALSE keeps a single-row result as a matrix
			potential_neg = unlabeled[maybe_neg, , drop=FALSE]

			# If no negatives were identified, then we're done
			if (!(nrow(potential_neg) > 0)){
				final_model = model
				break

			# Else, assign predicted negatives as reliable negatives, and
			#	remove them from the unlabeled dataset
			# NOTE: This will not create groups that could translate 1:1
			#	to x, but that doesn't matter since we return based
			#	on behavior on newdata
			} else {
				# Update reliable negatives
				rn_groups = rep(globals_list$UninfNum, nrow(potential_neg))
				to_add = cbind(potential_neg, groups=rn_groups)
				reliable_negatives = rbind(reliable_negatives, to_add)

				# Update unlabeled (drop=FALSE guards the single-row case)
				unlabeled = unlabeled[!maybe_neg, , drop=FALSE]
			}
		}

		# Predict on new data
		first_pred = predict(first_model, newdata)
		final_pred = predict(final_model, newdata)

	# NOTE: If no reliable negatives are found, then we can predict nothing,
	#	and we therefore only have one cluster
	} else {
		first_pred = rep(globals_list$InfNum, nrow(newdata))
		final_pred = rep(globals_list$InfNum, nrow(newdata))
	}

	first_frame = cbind(obs=y_test, pred=first_pred)
	first_score = mcc_score(first_frame, levels(as.factor(y_test)))

	final_frame = cbind(obs=y_test, pred=final_pred)
	final_score = mcc_score(final_frame, levels(as.factor(y_test)))

	# NOTE: return some information about which model is better
	if (first_score >= final_score){
		pred = first_pred

	} else {
		pred = final_pred
	}

	new_data = cbind(newdata, groups=pred)

	gc()
	return(list(old_data=full_data, new_data=new_data))
}