R/psupertime_plots.R

Defines functions plot_profiles_of_gene_clusters plot_heatmap_of_gene_clusters plot_go_results .make_plot_dt .do_topgo_for_cluster .calc_clusters_dt psupertime_go_analysis psupertime_go_analysis_old plot_double_psupertime_confusion plot_double_psupertime_genes plot_double_psupertime_contour plot_double_psupertime double_psupertime plot_predictions_against_classes .check_conf_params plot_new_data_over_psupertime project_onto_psupertime .process_new_data plot_specified_genes_over_psupertime plot_identified_genes_over_psupertime plot_identified_gene_coefficients plot_labels_over_psupertime .make_col_vals plot_train_results psupertime_plot_all

Documented in .calc_clusters_dt .check_conf_params .do_topgo_for_cluster double_psupertime .make_col_vals .make_plot_dt plot_double_psupertime plot_double_psupertime_confusion plot_double_psupertime_contour plot_double_psupertime_genes plot_go_results plot_heatmap_of_gene_clusters plot_identified_gene_coefficients plot_identified_genes_over_psupertime plot_labels_over_psupertime plot_new_data_over_psupertime plot_predictions_against_classes plot_profiles_of_gene_clusters plot_specified_genes_over_psupertime plot_train_results project_onto_psupertime psupertime_go_analysis psupertime_go_analysis_old psupertime_plot_all

# psupertime_plots.R

#' Convenience function to do multiple plots
#'
#' @importFrom ggplot2 ggsave
#' @param psuper_obj Psupertime object, output from psupertime
#' @param output_dir Directory to save to
#' @param tag Label for all files
#' @param label_name Description for the ordered labels in the legend (e.g. 'Developmental stage (days)')
#' @param ext Image format for outputs, compatible with ggsave (eps, ps, tex, pdf, jpeg, tiff, png, bmp, svg, wmf)
#' @export
psupertime_plot_all <- function(psuper_obj, output_dir='.', tag='', label_name='Ordered labels', ext='png') {
	# validate model
	cat('plotting results\n')
	g 			= plot_train_results(psuper_obj)
	plot_file 	= file.path(output_dir, sprintf('%s training results.%s', tag, ext))
	ggsave(plot_file, g, height=6, width=6)

	g 			= plot_labels_over_psupertime(psuper_obj, label_name)
	plot_file 	= file.path(output_dir, sprintf('%s labels over psupertime.%s', tag, ext))
	ggsave(plot_file, g, height=6, width=12)

	g 			= plot_identified_gene_coefficients(psuper_obj)
	plot_file 	= file.path(output_dir, sprintf('%s identified genes.%s', tag, ext))
	ggsave(plot_file, g, height=6, width=8)

	g 			= plot_identified_genes_over_psupertime(psuper_obj, label_name)
	plot_file 	= file.path(output_dir, sprintf('%s identified genes over psupertime.%s', tag, ext))
	ggsave(plot_file, g, height=8, width=12)

	g 			= plot_predictions_against_classes(psuper_obj)
	plot_file 	= file.path(output_dir, sprintf('%s predictions over psupertime, original data.%s', tag, ext))
	ggsave(plot_file, g, height=6, width=10)
}

#' Plot results of training
#'
#' @param psuper_obj Psupertime object, output from psupertime
#' @return ggplot2 object showing test and training performance of classifier.
#' @export
#' @importFrom ggplot2 aes
#' @importFrom ggplot2 facet_grid
#' @importFrom ggplot2 geom_line
#' @importFrom ggplot2 geom_point
#' @importFrom ggplot2 geom_linerange
#' @importFrom ggplot2 geom_vline
#' @importFrom ggplot2 ggplot
#' @importFrom ggplot2 guides
#' @importFrom ggplot2 labs
#' @importFrom ggplot2 scale_colour_manual
#' @importFrom ggplot2 scale_fill_brewer
#' @importFrom ggplot2 scale_size_manual
#' @importFrom ggplot2 theme_bw
plot_train_results <- function(psuper_obj) {
	# unpack
	ps_params 		= psuper_obj$ps_params
	scores_dt 		= psuper_obj$scores_dt
	glmnet_best 	= psuper_obj$glmnet_best

	# add sparsity to plots
	sparse_dt 		= data.table(
		lambda 		= glmnet_best$lambda,
		score_var 	= 'sparsity',
		data 		= 'train',
		mean 		= apply(abs(glmnet_best$beta)>0, 2, sum),
		se 			= NA
		)

	# calculate mean scores
	mean_scores 	= scores_dt[, 
		list(
			mean 	= mean(score_val), 
			se 		= sd(score_val)/sqrt(.N)
			), 
		by = list(lambda, score_var, data)
		]

	# where should vertical lines go?
	lines_best 		= copy(psuper_obj$best_dt)
	dummy_dt 		= data.table(score_var=c('sparsity', 'xentropy', 'class_error'))
	lines_best 		= lines_best[ dummy_dt, on='score_var']
	lines_best[, selected := score_var==ps_params$score ]

	# add nice labels for accuracy measures
	plot_dt 		= rbind(mean_scores, sparse_dt)
	measures_dt = data.table(
		score_var 		= c('xentropy', 'class_error', 'sparsity'),
		nice_score_var 	= c('Cross entropy', 'Classification error', 'Non-zero genes')
		)
	plot_dt 		= measures_dt[plot_dt, on='score_var']
	lines_best 		= measures_dt[lines_best, on='score_var']
	
	# which measure used for model selection?
	nice_sel_var 		= measures_dt[ score_var==ps_params$score ]$nice_score_var

	# set up
	g = ggplot(plot_dt) +
		aes( x=log10(lambda), y=mean, colour=data )
	
	# # plot each fold
	# g = g + geom_point(data=scores_dt, aes(fill=factor(fold), y=score_val), colour='transparent', shape=21 ) +
	# 	scale_fill_brewer( palette='Set1' )

	# plot test and training data
	g = g + geom_linerange(aes(ymin=mean-se, ymax=mean+se) ) +
		geom_point() +
		geom_line() +
		scale_colour_manual( values=c('grey', 'black') )

	# annotate with best lambdas, tidy up
	g = g + geom_vline(data=lines_best, aes(xintercept=log10(best_lambda), size=selected), colour='grey', linetype='solid' ) +
		geom_vline(data=lines_best, aes(xintercept=log10(next_lambda), size=selected), colour='grey', linetype='dashed' ) +
		scale_size_manual( values = c(0.5, 1) ) +
		guides( size=FALSE )

	# label nicely
	g = g + 
		facet_grid( nice_score_var ~ ., scales='free_y' ) +
		theme_bw() +
		labs(
			x 		= 'log10( lambda )'
			,y 		= 'Accuracy measure'
			,colour = 'Data'
			# ,fill 	= 'Fold'
			,title 	= sprintf('%s used for model selection', nice_sel_var)
			) +
		theme(
			plot.title 	= element_text( size=10, hjust=1 )
			)

	return(g)
}

#' Define RColorBrewer palette to use; default is RdBu.
#'
#' @importFrom RColorBrewer brewer.pal
#' @importFrom grDevices colorRampPalette
#' @param y_labels List of labels used for training
#' @return Colour values
#' @keywords internal
.make_col_vals <- function(y_labels, palette='RdBu') {
	n_labels 	= length(levels(y_labels))
	max_col 	= 11
	if (n_labels==1) {
		col_vals 	= brewer.pal(3, palette)
		col_vals 	= col_vals[1]
	} else if (n_labels==2) {
		col_vals 	= brewer.pal(3, palette)
		col_vals 	= col_vals[-2]
	} else if (n_labels<=max_col) {
		col_vals 	= brewer.pal(n_labels, palette)
	} else {
		col_pal 	= brewer.pal(max_col, palette)
		col_vals 	= colorRampPalette(col_pal)(n_labels)
	}
	col_vals 	= rev(col_vals)

	return(col_vals)
}

#' Plots labels over their projected values on psupertime.
#'
#' @param psuper_obj Psupertime object, output from psupertime
#' @param label_name Description for the ordered labels in the legend (e.g. 'Developmental stage (days)')
#' @param palette RColorBrewer palette to use
#' @return ggplot2 object
#' @export
#' @importFrom ggplot2 aes
#' @importFrom ggplot2 geom_density
#' @importFrom ggplot2 geom_vline
#' @importFrom ggplot2 ggplot
#' @importFrom ggplot2 guide_legend
#' @importFrom ggplot2 guides
#' @importFrom ggplot2 labs
#' @importFrom ggplot2 scale_colour_manual
#' @importFrom ggplot2 scale_fill_manual
#' @importFrom ggplot2 scale_x_continuous
#' @importFrom scales pretty_breaks
plot_labels_over_psupertime <- function(psuper_obj, label_name='Ordered labels', palette='RdBu') {
	# unpack
	proj_dt 		= psuper_obj$proj_dt
	cuts_dt 		= psuper_obj$cuts_dt

	# make nice colours
	col_vals 		= .make_col_vals(proj_dt$label_input, palette)

	# plot
	g = ggplot(proj_dt) +
		aes( x=psuper, fill=label_input, colour=label_input ) +
		geom_density( alpha=0.5 ) +
		scale_fill_manual( values=col_vals ) +
		geom_vline( data=cuts_dt, aes(xintercept=psuper, colour=label_input) ) +
		scale_colour_manual( values=col_vals ) +
		guides(
			fill 	= guide_legend(override.aes = list(alpha=1))
			,colour = FALSE
			) +
		scale_x_continuous( breaks=pretty_breaks() ) +
		labs(
			x 		= 'psupertime'
			,y 		= 'Density'
			,fill 	= label_name
			) +
		theme_bw()

	return(g)
}

#' Plots top coefficients
#'
#' @param psuper_obj Psupertime object, output from psupertime
#' @return ggplot2 object
#' @export
#' @importFrom ggplot2 aes
#' @importFrom ggplot2 element_text
#' @importFrom ggplot2 geom_hline
#' @importFrom ggplot2 geom_segment
#' @importFrom ggplot2 geom_point
#' @importFrom ggplot2 ggplot
#' @importFrom ggplot2 labs
#' @importFrom ggplot2 scale_y_continuous
#' @importFrom ggplot2 theme
#' @importFrom ggplot2 theme_bw
#' @importFrom scales pretty_breaks
plot_identified_gene_coefficients <- function(psuper_obj, n=20, abs_cutoff=0.05) {
	# prepare plot
	plot_dt 	= psuper_obj$beta_dt[ abs_beta > abs_cutoff ]
	plot_dt 	= plot_dt[ 1:min(n, nrow(plot_dt)) ]
	max_val 	= ceiling(max(plot_dt$abs_beta)*10)/10

	# plot
	g 	= ggplot(plot_dt) +
		aes( x=symbol, xend=symbol, y=beta, yend=0 ) +
		geom_segment( colour='black' ) +
		geom_point( colour='blue', size=5 ) +
		# geom_hline( yintercept=0, colour='grey' ) +
		scale_y_continuous( breaks=pretty_breaks(), limits=c(-max_val, max_val) ) +
		theme_bw() +
		theme(
			axis.text.x 	= element_text( angle=-45, hjust=0 )
			) +
		labs(
			x 	= 'Gene',
			y 	= 'Coefficient value'
			)
	return(g)
}

#' Plots profiles of identified genes against psupertime.
#'
#' @param psuper_obj Psupertime object, output from psupertime
#' @param label_name Description for the ordered labels in the legend (e.g. 'Developmental stage (days)')
#' @param n_to_plot Maximum number of genes to plot (default 20)
#' @param palette RColorBrewer palette to use
#' @param plot_ratio ratio of columns to rows (default is 5:4)
#' @return ggplot2 object
#' @export
#' @importFrom data.table data.table
#' @importFrom data.table melt.data.table
#' @importFrom ggplot2 aes
#' @importFrom ggplot2 element_blank
#' @importFrom ggplot2 facet_wrap
#' @importFrom ggplot2 geom_point
#' @importFrom ggplot2 geom_smooth
#' @importFrom ggplot2 ggplot
#' @importFrom ggplot2 labs
#' @importFrom ggplot2 scale_colour_manual
#' @importFrom ggplot2 scale_shape_manual
#' @importFrom ggplot2 scale_x_continuous
#' @importFrom ggplot2 scale_y_continuous
#' @importFrom ggplot2 theme
#' @importFrom ggplot2 theme_bw
#' @importFrom scales pretty_breaks
plot_identified_genes_over_psupertime <- function(psuper_obj, label_name='Ordered labels', n_to_plot=20, palette='RdBu', plot_ratio=1.25) {
	# unpack
	proj_dt 	= psuper_obj$proj_dt
	beta_dt 	= psuper_obj$beta_dt
	x_data 		= psuper_obj$x_data
	ps_params 	= psuper_obj$ps_params

	# aset
	beta_nzero 	= beta_dt[ abs_beta > 0 ]
	n_nzero 	= nrow(beta_nzero)
	top_genes 	= as.character(beta_nzero[1:min(n_to_plot, nrow(beta_nzero))]$symbol)

	# set up data for plotting
	plot_wide 	= cbind(proj_dt, data.table(x_data[, top_genes, drop=FALSE]))
	plot_dt 	= melt.data.table(plot_wide, id=c('cell_id', 'psuper', 'label_input', 'label_psuper'), measure=top_genes, variable.name='symbol')
	plot_dt[, symbol := factor(symbol, levels=top_genes)]

	# get colours
	col_vals 	= .make_col_vals(plot_dt$label_input, palette)
	n_genes 	= length(top_genes)
	ncol 		= ceiling(sqrt(n_genes*plot_ratio))
	nrow 		= ceiling(n_genes/ncol)

	# plot
	g =	ggplot(plot_dt) +
		aes( x=psuper, y=value) +
		geom_point( size=1, aes(colour=label_input) ) +
		geom_smooth(se=FALSE, colour='black') +
		scale_colour_manual( values=col_vals ) +
		scale_shape_manual( values=c(1, 16) ) +
		scale_x_continuous( breaks=pretty_breaks() ) +
		scale_y_continuous( breaks=pretty_breaks() ) +
		facet_wrap( ~ symbol, scales='free_y', nrow=nrow, ncol=ncol ) +
		theme_bw() +
		theme(
			axis.text.x = element_blank()
			) +
		labs(
			x 		= 'psupertime'
			,y 		= 'z-scored log2 expression'
			,colour = label_name
			)
	return(g)
}

#' Plots profiles of hand-selected genes against psupertime.
#'
#' @param psuper_obj psupertime object, output from psupertime
#' @param extra_genes List of genes to be plotted (these must be in the set of genes used for calculating psupertime, e.g. highly variable genes)
#' @param label_name Description for the ordered labels in the legend (e.g. 'Developmental stage (days)')
#' @param palette RColorBrewer palette to use
#' @param plot_ratio ratio of columns to rows (default is 5:4)
#' @return ggplot2 object
#' @export
#' @importFrom data.table data.table
#' @importFrom data.table melt.data.table
#' @importFrom ggplot2 aes
#' @importFrom ggplot2 element_blank
#' @importFrom ggplot2 facet_wrap
#' @importFrom ggplot2 geom_point
#' @importFrom ggplot2 geom_smooth
#' @importFrom ggplot2 ggplot
#' @importFrom ggplot2 labs
#' @importFrom ggplot2 scale_colour_manual
#' @importFrom ggplot2 scale_x_continuous
#' @importFrom ggplot2 scale_y_continuous
#' @importFrom ggplot2 theme
#' @importFrom ggplot2 theme_bw
#' @importFrom scales pretty_breaks
plot_specified_genes_over_psupertime <- function(psuper_obj, extra_genes, label_name='Ordered labels', palette='RdBu', plot_ratio=1.25) {
	# unpack
    proj_dt     = psuper_obj$proj_dt
    beta_dt     = psuper_obj$beta_dt
    x_data      = psuper_obj$x_data
    ps_params 	= psuper_obj$ps_params

	# restrict to specified genes
	extra_genes = intersect(extra_genes, colnames(x_data))
	if (length(extra_genes)==0) {
		warning('genes not found; did not plot')
		return()
	}

	# set up data
    plot_wide   = cbind(proj_dt, data.table(x_data[, extra_genes, drop=FALSE]))
    plot_dt     = melt.data.table(
    	plot_wide, 
    	id 				= c("psuper", "label_input", "label_psuper"), 
    	measure 		= extra_genes, 
    	variable.name 	= "symbol"
    	)
    plot_dt[, `:=`(symbol, factor(symbol, levels = extra_genes))]

	# set up plot
	col_vals 	= .make_col_vals(plot_dt$label_input, palette)
	n_genes 	= length(extra_genes)
	ncol 		= ceiling(sqrt(n_genes*plot_ratio))
	nrow 		= ceiling(n_genes/ncol)

	# plot
	g =	ggplot(plot_dt) +
		aes( x=psuper, y=value ) +
		geom_point( size=1, aes(colour=label_input) ) +
		geom_smooth(se=FALSE, colour='black') +
		scale_colour_manual( values=col_vals ) +
		scale_x_continuous( breaks=pretty_breaks() ) +
		scale_y_continuous( breaks=pretty_breaks() ) +
		facet_wrap( ~ symbol, scales='free_y', nrow=nrow, ncol=ncol ) +
		theme_bw() +
		theme(
			axis.text.x = element_blank()
			) +
		labs(
			x 		= 'psupertime'
			,y 		= 'z-scored log2 expression'
			,colour = label_name
			)
	return(g)
}

#' @keywords internal
.process_new_data <- function(psuper_obj, new_x) {
	# process new_x
	params_copy 				= psuper_obj$ps_params
	params_copy$sel_genes 		= 'list'
	params_copy$gene_list 		= colnames(psuper_obj$x_data)
	params_copy$min_expression 	= 0
	sel_genes 					= .select_genes(new_x, params_copy)
	new_data 					= .make_x_data(new_x, sel_genes, params_copy)
	return(new_data)
}

#' Gives projection of data onto psupertime (either using original data, or new data)
#'
#' @param psuper_obj Psupertime object, output from psupertime
#' @param new_x, new_y Optional pair of new data and labels
#' @return data.table with projection and labels
#' @export
project_onto_psupertime <- function(psuper_obj, new_x=NULL, new_y=NULL, process=FALSE) {
	# unpack
	glmnet_best 	= psuper_obj$glmnet_best
	best_lambdas 	= psuper_obj$best_lambdas

	# project new data
	if ( is.null(new_x) & is.null(new_y) ) {
		x_in 			= psuper_obj$x_data
		y_in 			= psuper_obj$y
	} else if ( !is.null(new_x) & !is.null(new_y) ) {
		if (process==TRUE) {
			x_in 			= .process_new_data(psuper_obj, new_x)
		} else {
			x_in 			= new_x
		}
		if (!is.factor(new_y)) {
			new_y 			= factor(new_y)
			message('converting new_y into factor, with the following ordered values:')
			message(paste(levels(new_y), ', '))
			message('(define new_y as a factor if you prefer a different ordering)')
		}
		y_in 			= factor(new_y)
	} else {
		stop('either both of new_x and new_y must be given, or neither')
	}

	proj_dt 		= .calc_proj_dt(glmnet_best, x_in, y_in, best_lambdas)

	return(proj_dt)
}

#' Plots profiles of hand-selected genes against psupertime.
#'
#' @param psuper_obj Psupertime object, output from psupertime
#' @param new_x,new_y Optional data to predict with psuper_obj
#' @param palette RColorBrewer palette to use
#' @return ggplot2 object
#' @export
#' @importFrom cowplot plot_grid
#' @importFrom ggplot2 ggplot
#' @importFrom ggplot2 aes_string
#' @importFrom ggplot2 geom_raster
#' @importFrom ggplot2 scale_fill_distiller
#' @importFrom ggplot2 expand_limits
#' @importFrom ggplot2 labs
#' @importFrom ggplot2 theme_bw
#' @importFrom scales pretty_breaks
plot_new_data_over_psupertime <- function(psuper_obj, new_x, new_y, labels=c('Original', 'New data'), palette='BrBG', process=FALSE) {
	# project new data
	proj_new 	= project_onto_psupertime(psuper_obj, new_x, new_y, process)

	# make nice colours
	col_vals 	= .make_col_vals(proj_new$label_input, palette)

	# get cutpoints
	cuts_dt 	= psuper_obj$cuts_dt

	# do plot
	x_label 	= sprintf('psupertime trained on %s', labels[[1]])
	g1 = plot_labels_over_psupertime(psuper_obj, label_name=labels[[1]]) +
		xlab( x_label )
	g2 = ggplot(proj_new) +
		aes( x=psuper, fill=label_input, colour=label_input) +
		geom_density( alpha=0.5 ) +
		geom_vline( data=cuts_dt, aes(xintercept=psuper), colour='black' ) +
		scale_fill_manual( values=col_vals ) +
		scale_colour_manual( values=col_vals ) +
		guides(
			fill 	= guide_legend(override.aes = list(alpha=1))
			,colour = FALSE
			) +
		scale_x_continuous( breaks=pretty_breaks() ) +
		labs(
			x 		= x_label
			,y 		= 'Density'
			,fill 	= labels[[2]]
			) +
		theme_bw()

	# give same x range
	proj_orig 	= psuper_obj$proj_dt
	x_range 	= c(
		floor(min(quantile(proj_new$psuper, prob=0.01), quantile(proj_orig$psuper, prob=0.01))),
		ceiling(max(quantile(proj_new$psuper, 0.99), quantile(proj_orig$psuper, 0.99)))
		)
	g1 			= g1 + coord_cartesian( xlim=x_range )
	g2 			= g2 + coord_cartesian( xlim=x_range )

	# put into grid
	g			= plot_grid(plotlist=list(g1, g2), labels=NULL, nrow=2, ncol=1, align='v', axis='lr')

	return(g)
}


#' Check variables for confusion matrices
#'
#' @param plot_var Variable to plot: prop_true is proportion of true labels, prop_predict is proportion of predicted labels, N is # of cells
#' @return list with checked plot_var, and nice label
#' @internal
.check_conf_params <- function(plot_var) {
	plot_var_list 	= c('prop_true', 'N', 'prop_predict')
	plot_var 		= match.arg(plot_var, plot_var_list)
	labels_list 	= c(prop_true='Proportion\nof labelled\nclass\n', N='# of cells', prop_predict='Proportion\nof predicted\nclass\n')
	plot_label 		= labels_list[[plot_var]]

	return( list(plot_var=plot_var, plot_label=plot_label) )
}

#' Plots confusion matrix of true labels against predicted labels.
#'
#' @param psuper_obj Psupertime object, output from psupertime
#' @param new_x,new_y Optional data to predict with psuper_obj
#' @param plot_var Variable to plot: prop_true is proportion of true labels, prop_predict is proportion of predicted labels, N is # of cells
#' @param palette RColorBrewer palette to use
#' @return ggplot2 object
#' @export
#' @importFrom data.table data.table
#' @importFrom ggplot2 aes
#' @importFrom ggplot2 expand_limits
#' @importFrom ggplot2 geom_text
#' @importFrom ggplot2 geom_raster
#' @importFrom ggplot2 ggplot
#' @importFrom ggplot2 labs
#' @importFrom ggplot2 scale_fill_distiller
#' @importFrom ggplot2 scale_x_discrete
#' @importFrom ggplot2 theme_bw
#' @importFrom scales pretty_breaks
plot_predictions_against_classes <- function(psuper_obj, new_x=NULL, new_y=NULL, process=FALSE, plot_var='prop_true', palette='BuPu') {
	# decide what to plot
	conf_params 	= .check_conf_params(plot_var)
	plot_var 		= conf_params$plot_var
	plot_label 		= conf_params$plot_label

	# unpack
	which_idx 		= psuper_obj$best_lambdas$which_idx
	glmnet_best 	= psuper_obj$glmnet_best

	# define fn to handle y
	.get_y_in <- function(new_y) {
		if (is.null(new_x)) {
			if ( length(new_y) != length(psuper_obj$y) ) {
				stop('when no new_x given, new_y must be same length as original y')
			}
		}
		if (!is.factor(new_y)) {
			new_y 			= factor(new_y)
			message('converting new_y into factor, with the following ordered values:')
			message(paste(levels(new_y), ', '))
			message('(define new_y as a factor if you prefer a different ordering)')
		}
		y_in 			= factor(new_y)
		return(y_in)
	}

	# what inputs to use?
	if ( is.null(new_x) & is.null(new_y) ) {
		x_in 			= psuper_obj$x_data
		y_in 			= psuper_obj$y

	} else if ( is.null(new_x) & !is.null(new_y) ) {
		x_in 			= psuper_obj$x_data
		y_in 			= .get_y_in(new_y)

	} else if ( !is.null(new_x) & !is.null(new_y) ) {
		x_in 			= .process_new_data(psuper_obj, new_x)
		y_in 			= .get_y_in(new_y)

	} else if ( !is.null(new_x) & is.null(new_y) ) {
		stop('to use new_x, new_y must also be given')

	} else {
		stop('aargh some unexpected error')
	}

	# get predicted classes for each thing
	predictions 	= .predict_glmnetcr_propodds(glmnet_best, x_in, y_in)
	pred_classes 	= factor(predictions$class[, which_idx], levels=levels(psuper_obj$y))
	predict_dt 		= data.table( predicted=pred_classes, true=y_in )

	# count and average
	counts_dt 		= predict_dt[, .N, by=list(predicted, true)]
	counts_dt[, prop_true 		:= N / sum(N), by=true ]
	counts_dt[, prop_predict 	:= N / sum(N), by=predicted ]

	# define where borders should be
	borders_dt 		= counts_dt[as.character(true)==as.character(predicted), list(true, predicted) ]

	# plot grid
	g = ggplot(counts_dt) +
		aes( y=true, x=predicted ) +
		geom_tile( aes_string(fill=plot_var) ) +
		geom_tile(data=borders_dt, aes(y=true, x=predicted), fill=NA, colour='black', size=0.5) +
		geom_text( aes(label=N) ) +
		scale_x_discrete( drop=FALSE ) +
		scale_fill_distiller( palette=palette, direction=1, breaks=pretty_breaks() )
	if (plot_var=='N') {
		g = g + expand_limits( fill=0 )
	} else {
		g = g + expand_limits( fill=c(0,1) )
	}
	g = g + labs(
			x 		= 'Predicted class'
			,y 		= 'Labelled class'
			,fill 	= plot_label
			) +
		theme_bw()
	return(g)
}

#' Projects two different psupertimes onto each other
#'
#' @importFrom forcats fct_drop
#' @param psuper_1, psuper_2 Two previously calculated psupertime objects
#' @param labels Character vector of length two, labelling the psupertime inputs
#' @return data.table containing projections in both directions
#' @export
double_psupertime <- function(psuper_1, psuper_2, run_names=NULL, process=FALSE) {
	# check run_names
	if ( is.null(run_names) ) {
		run_names 	= c('1','2')
		message('using default values for run_names:', paste(run_names, sep=', '))
	} else {
		if ( !is.character(run_names) | length(unique(run_names))!=2 ) {
			stop('run_names must be character vector of length two with no repeated values')
		}
	}

	# repack
	psuper_list 	= list(psuper_1, psuper_2)
	n_psupers 		= length(psuper_list)
	# names(psuper_list) 	= run_names

	# loop through projections on both
	doubles_dt 		= data.table()
	for (ii in 1:n_psupers) {
		# unload
		psuper_ii 		= psuper_list[[ii]]
		label_ii 		= run_names[[ii]]

		for (jj in 1:n_psupers) {
			# unload
			psuper_jj 		= psuper_list[[jj]]
			label_jj 		= run_names[[jj]]

			# get appropriate projection
			if (ii == jj) {
				proj_ii_on_jj 	= psuper_ii$proj_dt
			} else {
				proj_ii_on_jj 	= project_onto_psupertime(psuper_jj, psuper_ii$x_data, psuper_ii$y, process)
			}

			# label
			proj_ii_on_jj[, input 		:= label_ii ]
			proj_ii_on_jj[, projection 	:= label_jj ]
			n_digits = ceiling(log10(nrow(psuper_ii$x_data)))
			proj_ii_on_jj[, cell_id 	:= sprintf(sprintf('%%s_%%0%dd', n_digits), label_ii, 1:nrow(psuper_ii$x_data)) ]

			# store
			doubles_dt 		= rbind(doubles_dt, proj_ii_on_jj)
		}
	}

	# sort out levels
	lvls_all 		= c()
	for (ii in 1:n_psupers) {
		lvls_temp 	= setdiff(levels(psuper_list[[ii]]$y), lvls_all)
		lvls_all 	= c(lvls_all, lvls_temp)
	}
	doubles_dt[, label_input := factor(label_input, levels=lvls_all) ]

	# make wide, sort out levels?
	doubles_wide 	= dcast(doubles_dt, input + cell_id + label_input ~ projection, value.var=c('psuper', 'label_psuper'))
	for (ii in 1:n_psupers) {
		label 	= run_names[[ii]]
		doubles_wide[[ paste0('label_psuper_', label) ]] 	= fct_drop(doubles_wide[[ paste0('label_psuper_', label) ]])
		# levels(doubles_wide[[ paste0('label_psuper_', label) ]]) 	= levels(psuper_list[[ii]]$y)
	}

	# make lists of levels
	levels_list 		= lapply(psuper_list, function(p) levels(p$y) )
	names(levels_list) 	= run_names

	# put into list
	double_obj 	= list(
		run_names 		= run_names
		,levels_list 	= levels_list
		,doubles_dt 	= doubles_dt
		,doubles_wide 	= doubles_wide
		)
	return(double_obj)
}

#' Projects two different psupertimes onto each other, using points, side by side
#'
#' To do this, psupertime builds an internal \code{double_psupertime} object containing 
#' the projections. Given two psupertime objects \code{psuper_1} and \code{psuper_2}, you can 
#' call it in two ways: 
#' 
#' (1) By specifying the two psupertime objects you want to project:
#' \code{plot_double_psupertime(psuper_1=psuper_1, psuper_2=psuper_2)}
#' 
#' (2) Or by first constructing a \code{double_psupertime} object:
#' \code{double_obj = double_psupertime(psuper_1, psuper_2)}
#' \code{plot_double_psupertime(double_obj=double_obj)}
#'
#' For the coefficients of the two objects to be meaningfully applied to each 
#' other, the data needs to have been processed in the same way for each. We 
#' therefore recommend first preprocessing the data (either via \code{psupertime}'s 
#' defaults, or via your preferred method, then running \code{psupertime} with 
#' \code{smooth=FALSE} and \code{scale=FALSE}.
#'
#' @param double_obj Result of applying double_psupertime to two previously calculated psupertime objects
#' @param psuper_1, psuper_2 Two previously calculated psupertime objects
#' @param run_names Character vector of length two, labelling the psupertime inputs
#' @return ggplot object plotting the two against each other
#' @export
#' @importFrom ggplot2 aes_string
#' @importFrom ggplot2 facet_grid
#' @importFrom ggplot2 geom_point
#' @importFrom ggplot2 ggplot
#' @importFrom ggplot2 labs
#' @importFrom ggplot2 scale_colour_manual
#' @importFrom ggplot2 theme_bw
plot_double_psupertime <- function(double_obj=NULL, psuper_1=NULL, psuper_2=NULL, run_names=NULL, process=FALSE) {
	# check inputs
	if (is.null(double_obj)) {
		if ( is.null(psuper_1) | is.null(psuper_2) ) {
			stop('either a double_obj must be given, or psuper_1 and psuper_2 must both be given')
		} else {
			double_obj 		= double_psupertime(psuper_1, psuper_2, run_names, process)
		}
	}

	# unpack
	run_names 		= double_obj$run_names
	label_x 		= run_names[[1]]
	label_y 		= run_names[[2]]
	doubles_wide 	= double_obj$doubles_wide

	# make colours
	col_vals 		= .make_col_vals(doubles_wide$label_input)

	# add facet labels
	plot_dt 		= copy(doubles_wide)
	plot_dt[, input_label := paste0('Input data: ', input) ]

	# do some plotting
	g = ggplot(plot_dt) +
		aes_string(
			x 			= paste0('psuper_', label_x)
			,y 			= paste0('psuper_', label_y)
			,colour 	= paste0('label_input')
			) +
		geom_point() +
		scale_colour_manual( values=col_vals ) +
		facet_grid( . ~ input_label) +
		theme_bw() +
		labs(
			x 			= paste0('Psupertime trained on ', label_x)
			,y 			= paste0('Psupertime trained on ', label_y)
			,colour 	= 'Known\nlabels'
			)

	return(g)
}

#' Projects two different psupertimes on top of each other
#'
#' See `plot_double_psupertime` for further detail.
#' 
#' @param double_obj Result of applying double_psupertime to two previously calculated psupertime objects
#' @param psuper_1, psuper_2 Two previously calculated psupertime objects
#' @param run_names Character vector of length two, labelling the psupertime inputs
#' @return ggplot object plotting the two against each other
#' @export
#' @importFrom ggplot2 aes_string
#' @importFrom ggplot2 geom_density2d
#' @importFrom ggplot2 ggplot
#' @importFrom ggplot2 labs
#' @importFrom ggplot2 scale_colour_brewer
#' @importFrom ggplot2 theme_bw
plot_double_psupertime_contour <- function(double_obj=NULL, psuper_1=NULL, psuper_2=NULL, run_names=NULL) {
	# check run_names
	if ( is.null(run_names) ) {
		run_names 	= c('1','2')
		message('using default values for run_names:', paste(run_names, sep=', '))
	} else {
		if ( !is.character(run_names) | length(unique(run_names))!=2 ) {
			stop('run_names must be character vector of length two with no repeated values')
		}
	}
	# check inputs
	if (is.null(double_obj)) {
		if ( is.null(psuper_1) | is.null(psuper_2) ) {
			stop('either a double_obj must be given, or psuper_1 and psuper_2 must both be given')
		} else {
			double_obj 		= double_psupertime(psuper_1, psuper_2, run_names)
		}
	}

	# unpack
	run_names 		= double_obj$run_names
	label_x 		= run_names[[1]]
	label_y 		= run_names[[2]]
	doubles_wide 	= double_obj$doubles_wide

	# do some plotting
	g = ggplot(doubles_wide) +
		aes_string(
			x 			= paste0('psuper_', label_x)
			,y 			= paste0('psuper_', label_y)
			,colour 	= 'input'
			) +
		geom_density2d() +
		scale_colour_brewer( palette='Set1' ) +
		theme_bw() +
		labs(
			x 			= paste0('Psupertime run on ', label_x)
			,y 			= paste0('Psupertime run on ', label_y)
			,colour 	= 'Input\ndata'
			)

	return(g)
}

#' Compares coefficients for genes learned from different psupertimes
#'
#' @param psuper_1, psuper_2 Two previously calculated psupertime objects
#' @param run_names Character vector of length two, labelling the psupertime inputs
#' @return ggplot object plotting the two sets of coefficients
#' @export
#' @importFrom data.table setnames
#' @importFrom ggplot2 aes
#' @importFrom ggplot2 aes
#' @importFrom ggplot2 geom_point
#' @importFrom ggplot2 ggplot
#' @importFrom ggplot2 labs
#' @importFrom ggplot2 theme_bw
plot_double_psupertime_genes <- function(psuper_1, psuper_2, run_names=NULL) {
	# check run_names
	if ( is.null(run_names) ) {
		run_names 	= c('1','2')
		message('using default values for run_names:', paste(run_names, sep=', '))
	} else {
		if ( !is.character(run_names) | length(unique(run_names))!=2 ) {
			stop('run_names must be character vector of length two with no repeated values')
		}
	}

	# get genes from both
	old_names 	= c('beta', 'abs_beta')
	genes_1_dt 	= psuper_1$beta_dt[ abs_beta > 0 ]
	setnames(genes_1_dt, old_names, paste0(old_names, '_1'))
	genes_2_dt 	= psuper_2$beta_dt[ abs_beta > 0 ]
	setnames(genes_2_dt, old_names, paste0(old_names, '_2'))

	# join together, tidy up
	genes_dt 	= merge(genes_1_dt, genes_2_dt, by='symbol', all=TRUE, )
	genes_dt[ is.na(beta_1), beta_1 := 0 ]
	genes_dt[ is.na(abs_beta_1), abs_beta_1 := 0 ]
	genes_dt[ is.na(beta_2), beta_2 := 0 ]
	genes_dt[ is.na(abs_beta_2), abs_beta_2 := 0 ]

	# plot
	g = ggplot(genes_dt) +
		aes( x=beta_1, y=beta_2 ) +
		geom_point( alpha=0.5 ) +
		theme_bw() +
		labs(
			x 		= paste0('Coefficient for ', run_names[[1]])
			,y 		= paste0('Coefficient for ', run_names[[2]])
			)

	return(g)
}

#' Plots the confusion matrices of two psupertime objects against each other
#'
#' See `plot_double_psupertime` for further detail.
#'
#' @param double_obj Result of applying double_psupertime to two previously calculated psupertime objects
#' @param psuper_1, psuper_2 Two previously calculated psupertime objects
#' @param run_names Character vector of length two, labelling the psupertime inputs
#' @param palette RColorBrewer palette to use
#' @return cowplot plot_grid object, showing known and predicted labels for each dataset, and each set of predictions
#' @export
#' @importFrom cowplot plot_grid
#' @importFrom forcats fct_drop
#' @importFrom ggplot2 aes
#' @importFrom ggplot2 aes_string
#' @importFrom ggplot2 expand_limits
#' @importFrom ggplot2 geom_text
#' @importFrom ggplot2 geom_tile
#' @importFrom ggplot2 ggplot
#' @importFrom ggplot2 labs
#' @importFrom ggplot2 scale_fill_distiller
#' @importFrom ggplot2 scale_x_discrete
#' @importFrom ggplot2 theme_bw
#' @importFrom scales pretty_breaks
plot_double_psupertime_confusion <- function(double_obj=NULL, psuper_1=NULL, psuper_2=NULL, run_names=NULL, plot_var='prop_true', palette='BuPu') {
	if ( !requireNamespace("cowplot", quietly=TRUE) ) {
		message('cowplot not installed; not plotting confusion matrix')
		return()
	}

	# decide what to plot
	conf_params 	= .check_conf_params(plot_var)
	plot_var 		= conf_params$plot_var
	plot_label 		= conf_params$plot_label

	# check inputs
	if (is.null(double_obj)) {
		if ( is.null(psuper_1) | is.null(psuper_2) ) {
			stop('either a double_obj must be given, or psuper_1 and psuper_2 must both be given')
		} else {
			double_obj 		= double_psupertime(psuper_1, psuper_2, run_names)
		}
	}

	# unpack
	run_names 		= double_obj$run_names
	label_x 		= run_names[[1]]
	label_y 		= run_names[[2]]
	doubles_dt 		= double_obj$doubles_dt

	# set up
	input_list 		= unique(doubles_dt$input)
	n_inputs 		= length(input_list)
	proj_list 		= unique(doubles_dt$projection)
	n_projs 		= length(proj_list)
	g_list 			= list()
		
	# get factor lists
	levels_list 	= double_obj$levels_list

	# do multiple plots
	for (ii in 1:n_inputs) {
		for (jj in 1:n_projs) {
			# restrict to this combo of inputs/predictions
			input_ii 		= input_list[[ii]]
			psuper_jj 		= proj_list[[jj]]
			counts_dt 		= doubles_dt[ input==input_ii & projection==psuper_jj, .N, by=list(label_input, label_psuper) ]

			# calculate proportions
			counts_dt[, prop_true 		:= N / sum(N), 		by=label_input ]
			counts_dt[, prop_predict 	:= N / sum(N), 	by=label_psuper ]

			# tidy up labels
			counts_dt[, label_input 	:= fct_drop(label_input) ]
			counts_dt[, label_input 	:= factor(label_input, levels=levels_list[[input_ii]])]
			counts_dt[, label_psuper 	:= fct_drop(label_psuper) ]
			counts_dt[, label_psuper 	:= factor(label_psuper, levels=levels_list[[psuper_jj]])]

			# define where borders should be
			borders_dt 		= counts_dt[as.character(label_input)==as.character(label_psuper), list(label_input, label_psuper) ]

			# plot grid
			g = ggplot(counts_dt) +
				aes( y=label_input, x=label_psuper ) +
				geom_tile( aes_string(fill=plot_var) ) +
				geom_tile(data=borders_dt, aes(y=label_input, x=label_psuper), fill=NA, colour='black', size=0.5) +
				geom_text( aes(label=N) ) +
				scale_x_discrete( drop=FALSE ) +
				scale_fill_distiller( palette=palette, direction=1, breaks=pretty_breaks(), guide=FALSE ) +
				theme_bw()

			# colouring for tiles
			if (plot_var=='N') {
				g = g + expand_limits( fill=0 )
			} else {
				g = g + expand_limits( fill=c(0,1) )
			}

			# x, y labels
			if ( ii==n_inputs ) {
				g = g + labs( x=paste0('Predicted: ', run_names[[jj]]) )
			} else {
				g = g + labs( x=NULL )
			}
			if ( jj==1 ) {
				g = g + labs( y=paste0('Known: ', run_names[[ii]]) )
			} else {
				g = g + labs( y=NULL )
			}

			g_list[[ (ii - 1)*n_inputs + jj ]] 	= g
		}
	}

	g_grid 			= plot_grid(plotlist=g_list, labels=NULL, nrow=n_inputs, ncol=n_projs, align='h', axis='b')

	return(g_grid)
}

#' GO enrichment analysis for genes learned from different psupertimes
#'
#' @importFrom data.table data.table
#' @importFrom data.table setnames
#' @importFrom topGO runTest
#' @importFrom topGO GenTable
#' @param psuper_obj A previously calculated psupertime object
#' @param org_mapping Organism to use for annotations (e.g. 'org.Mm.eg.db', 'org.Hs.eg.db')
#' @return data.table containing results of GO enrichment analysis
#' @internal
psupertime_go_analysis_old <- function(psuper_obj, org_mapping) {
	# can we do this?
	if ( !requireNamespace("topGO", quietly=TRUE) ) {
		message('topGO not installed; not doing GO analysis')
		return()
	}
	library('topGO')

	# unpack
	psuper 			= scale(psuper_obj$proj_dt$psuper)
	n_obs 			= length(psuper)
	x_data 			= psuper_obj$x_data

	# calculate correlations
	corrs 			= as.vector(matrix(psuper, nrow=1) %*% x_data) / n_obs
	names(corrs) 	= colnames(x_data)

	# calculate p values for these
	t_stat 			= (corrs*sqrt(n_obs-2))/sqrt(1-corrs^2)
	p_vals 			= 2*(1 - pt(abs(t_stat),(n_obs-2)))

	# do GO in various ways
	go_dt 			= data.table()
	for (up_or_down in c('both', 'up', 'down')) {
		# do ranking
		if (up_or_down=='both') {
			scores 		= abs(corrs)

		} else if (up_or_down=='up') {
			scores 		= corrs

		} else if (up_or_down=='down') {
			scores 		= -corrs

		}
		scores[ scores < 0 ] 	= 0
		scores 		= sort(scores, decreasing=TRUE)
		if ( sum(scores > 0)==0 ) {
			next
		}

		# make topGO object
		topGO_data = new("topGOdata", 
			description 		= up_or_down, 
			allGenes 			= scores, 
			geneSel 			= function(x) {x>0.1},
			annot 				= topGO::annFUN.org, 
			mapping 			= org_mapping, 
			ontology 			= 'BP',
			ID 					= 'symbol'
			)

		# run enrichment tests on these, extract results
		go_weight 	= runTest(topGO_data, algorithm = "weight01", statistic = "fisher")
		go_temp 	= data.table(GenTable(topGO_data, 
			p_go 	= go_weight, 
			orderBy 		= 'p_go', 
			ranksOf 		= 'p_go', 
			topNodes 		= 1000
			))
		setnames(go_temp, 'p_go', 'temp')
		go_temp[, p_go := as.numeric(temp) ]
		go_temp[ temp == '< 1e-30', p_go := 9e-31 ]
		go_temp[, temp := NULL ]
		go_temp[, direction := up_or_down]
		go_temp[, rank := 1:nrow(go_temp)]

		# store
		go_dt 		= rbind(go_dt, go_temp)
	}

	# change column order
	setcolorder(go_dt, c('direction', 'rank'))
	# print top terms
	p_cutoff 		= 5e-2
	n_terms_cutoff 	= 5
	print_dt 		= go_dt[ p_go < p_cutoff & Significant>n_terms_cutoff ]
	if (nrow(print_dt)==0) {
		message(sprintf('no GO terms met the cutoffs (p-value < %.1e and at least %d genes significant)', p_cutoff, n_terms_cutoff))
	} else {
		message('Significant GO terms:')
		print(print_dt)
	}

	return(go_dt)
}

#' GO enrichment analysis for genes learned from different psupertimes
#'
#' @importFrom data.table set
#' @importFrom data.table setorder
#' @importFrom fastcluster hclust
#' @param psuper_obj A previously calculated psupertime object
#' @param org_mapping Organism to use for annotations (e.g. 'org.Mm.eg.db', 'org.Hs.eg.db')
#' @return data.table containing results of GO enrichment analysis
#' @export
psupertime_go_analysis <- function(psuper_obj, org_mapping, k=5, sig_cutoff=5) {
	if ( !requireNamespace("topGO", quietly=TRUE) ) {
		message('topGO not installed; not doing GO analysis')
		return()
	}
	if ( !requireNamespace("fastcluster", quietly=TRUE) ) {
		message('fastcluster not installed; not doing GO analysis')
		return()
	}

	# unpack
	glmnet_best 	= psuper_obj$glmnet_best
	best_lambdas 	= psuper_obj$best_lambdas
	proj_dt 		= copy(psuper_obj$proj_dt)
	x_data 			= copy(psuper_obj$x_data)
	beta_dt 		= psuper_obj$beta_dt
	cuts_dt 		= psuper_obj$cuts_dt

	# put cells in nice order, label projections
	rownames(x_data) 		= sprintf('cell_%04d', 1:nrow(x_data))
	set(proj_dt, i=NULL, 'cell_id', rownames(x_data))
	setorder(proj_dt, psuper)

	# do clustering on symbols
	message('clustering genes')
	hclust_obj 		= fastcluster::hclust(dist(t(x_data)), method='complete')

	# extract clusters from them
	clusters_dt 	= .calc_clusters_dt(hclust_obj, x_data, proj_dt, k)
	go_results 		= .do_topgo_for_cluster(clusters_dt, sig_cutoff, org_mapping)

	# make plot_dt
	plot_dt 		= .make_plot_dt(x_data, hclust_obj, proj_dt, clusters_dt)

	# assemble outputs
	go_list = list(
		clusters_dt 	= clusters_dt,
		go_results 		= go_results,
		plot_dt 		= plot_dt,
		cuts_dt 		= copy(psuper_obj$cuts_dt)
		)

	return(go_list)
}

#' make nice data.table of hierarchical clusters
#'
#' @param hclust_obj Result of hclust
#' @param x_data Data used to calculate psuper_obj
#' @param proj_dt Projection of cells onto psupertime
#' @return data.table containing clusters of genes, ordered according to correlation with psupertime
#' @internal
.calc_clusters_dt <- function(hclust_obj, x_data, proj_dt, k=5) {
	# make thing
	clusters_dt 	= data.table( h_clust=cutree(hclust_obj, k=k), symbol=colnames(x_data))
	# add clustering
	clusters_dt[, N:=.N, by=h_clust ]

	# order by correlation with psupertime
	temp_dt 			= data.table(melt(x_data, varnames=c('cell_id', 'symbol')))
	temp_dt 			= clusters_dt[ temp_dt, on='symbol' ]
	means_dt 			= temp_dt[, list(mean=mean(value)), by=list(cell_id, h_clust) ]
	means_dt 			= proj_dt[ means_dt, on='cell_id' ]
	corrs_dt 			= means_dt[, list( cor=cor(mean, psuper) ), by=h_clust]
	setorder(corrs_dt, cor)
	corrs_dt[, clust := 1:.N ]
	corrs_dt[, clust := factor(clust)]
	
	# add clusters ordered by size back in
	clusters_dt			= corrs_dt[ clusters_dt, on='h_clust' ]
	clusters_dt[, clust_label := factor(sprintf('%02d (%d genes)', clust, N)) ]
	clusters_dt[, h_clust := NULL ]
	setorder(clusters_dt, clust, symbol)

	return(clusters_dt)
}

#' Calculate GO enrichment for each cluster vs all other genes
#'
#' @param clusters_dt 
#' @param sig_cutoff How many genes should be in the cluster for us to consider a GO term?
#' @return data.table with GO term results
#' @internal
.do_topgo_for_cluster <- function(clusters_dt, sig_cutoff, org_mapping) {
	# set up
	all_clusters 	= unique(clusters_dt[N>=sig_cutoff]$clust)
	go_results 		= data.table()

	# loop through clusters
	message(sprintf('calculating GO enrichments for %d clusters:', length(all_clusters)))
	for (c in all_clusters) {
		message('.', appendLF=FALSE)
		gene_list 			= factor( as.integer(clusters_dt$clust == c) )
		names(gene_list) 	= clusters_dt$symbol

		# make topGO object
		suppressMessages({
			topGO_data 	= new("topGOdata", 
			description 		= c, 
			allGenes 			= gene_list, 
			# geneSelectionFun 	= function(x) {x==TRUE},
			annot 				= annFUN.org, 
			mapping 			= org_mapping, 
			ontology 			= 'BP',
			ID 					= 'symbol'
			)
		})
		# run enrichment tests on these, extract results
		suppressMessages({go_weight 	= runTest(topGO_data, algorithm = "weight", statistic = "fisher")})
		n_terms 		= length(go_weight@score)
		temp_results 	= data.table(GenTable(topGO_data, 
			p_go 	= go_weight, 
			orderBy 		= 'p_go', 
			ranksOf 		= 'p_go', 
			topNodes 		= n_terms
			))
		temp_results[ , cluster := c ]

		# store
		go_results 	= rbind(go_results, temp_results)
	}
	message('')

	# tidy up
	setnames(go_results, 'p_go', 'tmp')
	go_results[, p_go := as.numeric(tmp) ]
	go_results[ tmp == '< 1e-30', p_go := 9e-31 ]
	go_results[ , tmp := NULL ]
	go_results[ , cluster := factor(cluster, levels=all_clusters) ]

	return(go_results)
}

#' Internal function
#'
#' @param x_data 
#' @param hclust_obj 
#' @param proj_dt 
#' @param clusters_dt 
#' @return data.table for plotting
#' @internal
.make_plot_dt <- function(x_data, hclust_obj, proj_dt, clusters_dt) {
	# plot
	plot_dt 			= data.table(melt(x_data, varnames=c('cell_id', 'symbol')))

	# nice ordering
	symbol_order 		= colnames(x_data)[hclust_obj$order]
	plot_dt[, symbol 	:= factor(symbol, levels=symbol_order)]
	plot_dt[, cell_id 	:= factor(cell_id, levels=proj_dt$cell_id)]

	# put this into plotting 
	plot_dt 			= clusters_dt[ plot_dt, on='symbol' ]
	plot_dt 			= proj_dt[ plot_dt, on='cell_id' ]

	return(plot_dt)
}

#' Plots the significant GO terms for each cluster
#'
#' @param go_results Output from GO analysis
#' @param sig_cutoff What is the minimum number of annotated genes to display a GO term?
#' @param p_cutoff What is the maximum p-value to display a GO term?
#' @return bar plot
#' @export
#' @importFrom data.table setorder
#' @importFrom ggplot2 ggplot
#' @importFrom ggplot2 aes
#' @importFrom ggplot2 geom_col
#' @importFrom ggplot2 facet_grid
#' @importFrom ggplot2 coord_flip
#' @importFrom ggplot2 scale_y_continuous
#' @importFrom ggplot2 labs
#' @importFrom ggplot2 theme_bw
#' @importFrom scales pretty_breaks
plot_go_results <- function(go_list, sig_cutoff=5, p_cutoff=0.1) {
	# unpack
	go_results 	= go_list$go_results

	# set up
	plot_dt 	= go_results[ Significant>=sig_cutoff & p_go<p_cutoff ]
	setorder(plot_dt, cluster, -p_go)
	plot_dt[, N := .N, by=Term ]
	plot_dt[	  , term_n := Term ]
	plot_dt[ N > 1, term_n := paste0(Term, '_', 1:.N), by=Term ]
	plot_dt[, term_n := factor(term_n, levels=plot_dt$term_n) ]

	# plot
	g = ggplot(plot_dt) +
		aes( x=term_n, y=-log10(p_go) ) +
		geom_col() +
		scale_y_continuous( breaks=pretty_breaks() ) +
		facet_grid( cluster ~ ., scales='free_y', space='free_y') +
		coord_flip() +
		labs(
			x 	= NULL
			,y 	= '-log10( p-value )'
			) +
		theme_bw()
	return(g)
}

#' Plot heatmap of gene clusters
#'
#' @param go_list Output from GO analysis
#' @return ggplot object
#' @export
#' @importFrom ggplot2 ggplot
#' @importFrom ggplot2 aes
#' @importFrom ggplot2 geom_tile
#' @importFrom ggplot2 scale_fill_distiller
#' @importFrom ggplot2 facet_grid
#' @importFrom ggplot2 theme
#' @importFrom ggplot2 element_blank
#' @importFrom ggplot2 theme_bw
#' @importFrom ggplot2 labs
plot_heatmap_of_gene_clusters <- function(go_list) {
	# unpack
	plot_dt 	= go_list$plot_dt

	# plot
	g = ggplot(plot_dt) +
		aes( x=cell_id, y=symbol, fill=value ) +
		geom_tile() +
		scale_fill_distiller( palette='RdBu', limits=c(-3, 3) ) +
		facet_grid( clust_label ~ ., scale='free_y', space='free_y' ) +
		theme_bw() +
		theme(
			axis.text 	 = element_blank()
			,axis.ticks  = element_blank()
			) +
		labs(
			x 		= 'Cell'
			,y 		= 'Symbol'
			,fill 	= 'z-scored gene\nexpression'
			)
	return(g)
}

#' Plot heatmap of gene clusters
#'
#' @param go_list Output from GO analysis
#' @param label_name Description for the ordered labels in the legend (e.g. 'Developmental stage (days)')
#' @param palette RColorBrewer palette to use
#' @return ggplot object
#' @export
#' @importFrom ggplot2 ggplot
#' @importFrom ggplot2 aes
#' @importFrom ggplot2 geom_vline
#' @importFrom ggplot2 scale_colour_manual
#' @importFrom ggplot2 geom_rug
#' @importFrom ggplot2 geom_smooth
#' @importFrom ggplot2 facet_grid
#' @importFrom ggplot2 theme_bw
#' @importFrom ggplot2 theme
#' @importFrom ggplot2 element_blank
#' @importFrom ggplot2 labs
plot_profiles_of_gene_clusters <- function(go_list, label_name='Ordered labels', palette='RdBu') {
	# unpack
	plot_dt 	= go_list$plot_dt
	cuts_dt 	= go_list$cuts_dt

	# set up what to plot
	means_dt 	= plot_dt[, list(value=mean(value)), by=list(psuper, clust_label)]

	# make nice colours
	col_vals 	= .make_col_vals(cuts_dt$label_input, palette)

	# plot
	g = ggplot(means_dt) +
		geom_vline(data=cuts_dt, aes(xintercept=psuper, colour=label_input)) +
		scale_colour_manual( values=col_vals ) +
		geom_smooth( colour='black', span=0.2, method='loess', aes( x=psuper, y=value ) )
	n_cells 	= length(unique(plot_dt$cell_id))
	if ( n_cells<=2000 ) {
		rug_dt 		= unique(plot_dt[, list(psuper, cell_id)])
		g = g + geom_rug(data=rug_dt, sides='b', alpha=0.1, aes(x=psuper) )
	}
	g = g + facet_grid( clust_label ~ ., scales='free_y' ) +
		theme_bw() +
		theme(
			axis.text 	= element_blank()
			) +
		labs(
			x 		= 'psupertime'
			,y 		= 'z-scored gene expression'
			,colour = label_name
			)
	return(g)
}
wmacnair/psupertime documentation built on July 10, 2020, 8:12 p.m.