knitr::opts_chunk$set( collapse = TRUE, comment = "#>" )
library(AccelBenchmark)
benchmark() is the main function of this package. The examples below show how to use it, starting with a Poisson mixture.
# special control parameters for quasi-Newton
qn_control = list(qn=2)

# set seed for reproducibility
# randomness comes from parameter initialization
# or the data generation process
set.seed(54321)

res_poi = benchmark(
  task = "poissmix",
  algorithm = c("raw", "squarem", "qn", "pem", "daarem", "nes"),
  ntimes = 20,
  control = list(maxiter=2000, tol=1e-7),
  control.spec = list(qn_v2=qn_control),
  verbose = FALSE
)

plot(res_poi, type="iter", start_from=5)
summary(res_poi, rate_tol=0.1)
Here, benchmark() takes these basic parameters.

task specifies the problem. It is a list containing initfn (the parameter initialization function), fixptfn (the fixed-point iteration function), objfn (the loss function), and any other parameters that fixptfn and objfn need. task can also be a function that returns such a list. Six built-in tasks can be selected by name: "poissmix" for Poisson mixture, "mvt_mmf" for multivariate t-distribution using the MMF algorithm, "lasso" for LASSO, "bvs" for Bayesian variable selection, "sinkhorn" for Sinkhorn iteration, and "tsne" for t-SNE.

algorithm is a character vector of the algorithms to benchmark: "raw" for the original algorithm, "squarem" for SQUAREM, "daarem" for DAAREM, "qn" for quasi-Newton, "pem" for parabolic EM, and "nes" for restarted Nesterov.

ntimes is the number of runs used in benchmarking.

control is a control list shared by all algorithms. For details on how to specify it, see the documentation of the individual packages turboEM, SQUAREM, and daarem.

control.spec is a named list that supplies an algorithm-specific control list for any of the algorithms in use.
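To make the roles of initfn, fixptfn, objfn, and the extra parameters concrete, here is a minimal base-R sketch that runs the plain fixed-point iteration on a task list by hand. The helper run_task_raw and the square-root task are hypothetical names introduced only for this illustration; this is not how the package is implemented internally, just one way such a list could be consumed.

```r
# Hypothetical helper (not part of AccelBenchmark): runs the unaccelerated
# fixed-point iteration on a task list of the shape described above.
run_task_raw <- function(task, maxiter = 1000, tol = 1e-7) {
  # Every element other than the three functions is treated as an extra
  # argument and forwarded by name to both fixptfn and objfn.
  extras <- task[setdiff(names(task), c("initfn", "fixptfn", "objfn"))]
  par <- task$initfn()
  for (iter in seq_len(maxiter)) {
    new_par <- do.call(task$fixptfn, c(list(par), extras))
    # Stop once successive iterates are close in Euclidean norm
    done <- sqrt(sum((new_par - par)^2)) < tol
    par <- new_par
    if (done) break
  }
  list(par = par, iter = iter,
       loss = do.call(task$objfn, c(list(par), extras)))
}

# Illustrative task: the Babylonian iteration for sqrt(target)
sqrt_task <- list(
  initfn  = function() 1,
  fixptfn = function(par, target) (par + target / par) / 2,
  objfn   = function(par, target) abs(par^2 - target),
  target  = 2
)

res_sqrt <- run_task_raw(sqrt_task)
res_sqrt$par
```

Because the extra element is forwarded by name, its name in the list (target here) has to match the argument name used in both fixptfn and objfn.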
benchmark() returns an object of class benchmark, for which plot() and summary() methods are implemented.

plot() draws the loss against either fixed-point evaluations (fpevals) or elapsed time for a single run, where the plotted loss is the difference from the smallest loss. The type parameter selects whether the x-axis shows fpevals or elapsed time, and the start_from parameter selects which run is plotted.

summary() computes all benchmark measures. It accepts two optional parameters. With rate_tol, a run whose loss is smaller than (1+rate_tol) * minimum_loss is regarded as converged. With loss_tol, a run whose loss is smaller than loss_tol is regarded as converged.
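The two convergence rules can be written out as a short base-R sketch. The helper is_converged is hypothetical and is not taken from the package; it also assumes that minimum_loss is the smallest final loss observed across the compared runs, which the text above does not spell out.

```r
# Hypothetical helper mirroring the two rules described above.
# Assumption: min_loss is the smallest final loss among the compared runs.
is_converged <- function(final_loss, min_loss = NULL,
                         rate_tol = NULL, loss_tol = NULL) {
  if (!is.null(loss_tol)) {
    # Absolute rule: the loss itself must fall below loss_tol
    return(final_loss < loss_tol)
  }
  stopifnot(!is.null(rate_tol), !is.null(min_loss))
  # Relative rule: the loss must be within a factor (1 + rate_tol) of the minimum
  final_loss < (1 + rate_tol) * min_loss
}

# Relative rule with made-up final losses for three algorithms
final_loss <- c(raw = 10.5, squarem = 10.0, qn = 12.0)
conv_rel <- is_converged(final_loss, min_loss = min(final_loss), rate_tol = 0.1)

# Absolute rule with made-up final losses for two runs
conv_abs <- is_converged(c(3e-9, 2e-5), loss_tol = 1e-7)
```

The relative rule only makes sense when the minimum loss is positive, so the absolute rule is the natural choice when the optimal loss is zero, as in the toy example later in this document.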
As noted above, task can also be a function that returns this kind of list, and arguments for that function can be passed through benchmark(). For example, the LASSO problem requires a regularization parameter $\lambda$, the weight of the $L_1$ penalty. We can benchmark it as follows, with lam playing the role of $\lambda$ (not run).
sqem_ctrl = list(step.min0=0)

set.seed(54321)
res_l1 = benchmark(
  "lasso",
  algorithm=c("raw", "squarem", "qn_v3", "pem", "daarem", "nes"),
  ntimes = 200,
  control = list(maxiter=20000),
  control.spec = list(squarem=sqem_ctrl),
  lam = 1
)
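To show where lam enters a LASSO fixed-point map, here is a base-R sketch of proximal gradient descent (ISTA) for the objective $\frac{1}{2}\|y - X\beta\|^2 + \lambda\|\beta\|_1$. This is one standard way to pose LASSO as a fixed-point iteration and is only an assumption for illustration; it is not necessarily how the built-in "lasso" task is implemented. The names soft_threshold, lasso_fixptfn, lasso_objfn, and step are hypothetical.

```r
# Soft-thresholding: the proximal map of t * ||.||_1
soft_threshold <- function(z, t) sign(z) * pmax(abs(z) - t, 0)

# One ISTA step for 0.5 * ||y - X beta||^2 + lam * ||beta||_1
lasso_fixptfn <- function(par, X, y, lam, step) {
  grad <- drop(crossprod(X, X %*% par - y))
  soft_threshold(par - step * grad, step * lam)
}

# The objective; it accepts the same extra arguments as the fixed-point map,
# including step, which it does not use
lasso_objfn <- function(par, X, y, lam, step) {
  0.5 * sum((y - X %*% par)^2) + lam * sum(abs(par))
}

# Simulated data with a sparse coefficient vector
set.seed(1)
X <- matrix(rnorm(50 * 5), 50, 5)
y <- drop(X %*% c(2, 0, 0, -1, 0) + rnorm(50, sd = 0.1))

# Step size 1 / L, with L the largest eigenvalue of X'X
step <- 1 / max(eigen(crossprod(X), symmetric = TRUE, only.values = TRUE)$values)

beta <- rep(0, 5)
loss_path <- numeric(200)
for (k in 1:200) {
  beta <- lasso_fixptfn(beta, X, y, lam = 1, step = step)
  loss_path[k] <- lasso_objfn(beta, X, y, lam = 1, step = step)
}
```

With this step size the objective does not increase from one iteration to the next, and a larger lam shrinks more coefficients to exactly zero.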
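Now let's walk through how to benchmark your own task. Consider a toy example where the fixed-point iteration is $f(x) = b / (a + x)$ with $a, b > 0$. Setting $x = b / (a + x)$ gives $x^2 + ax - b = 0$, whose positive root $(\sqrt{a^2 + 4b} - a) / 2$ is the fixed point reached when starting from $x > 0$. Since the fixed point is known in closed form, we can define the loss as the distance between the current iterate and it (the code below computes the square root of the squared error). Here is the code for benchmarking the task, with $x$ initialized uniformly in $[0, 1]$.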
# Initialization function
# make sure this function takes no parameters
toy_init = function(){
  runif(1, 0, 1)
}

# Fixed point iteration
toy_fixptfn = function(par, a, b){
  b / (a + par)
}

# Loss function
toy_objfn = function(par, a, b){
  sqrt((par - (sqrt(a^2 + 4*b) - a) / 2)^2)
}

# task function
# returned list can contain any extra parameters
# they will be passed in fixptfn and objfn
# make sure their parameter names match
toy_task = function(a=1, b=1){
  list(
    initfn = toy_init,
    fixptfn = toy_fixptfn,
    objfn = toy_objfn,
    a = a,
    b = b
  )
}

# benchmarking
set.seed(54321)
res_toy = benchmark(
  toy_task,
  algorithm=c("raw", "squarem", "qn", "pem", "daarem", "nes"),
  ntimes = 10,
  control = list(maxiter=1000),
  verbose = FALSE,
  a = 11,
  b = 100
)

summary(res_toy, loss_tol=1e-7)
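As a sanity check on the closed-form fixed point used in the toy loss, the following base-R sketch iterates $f(x) = b / (a + x)$ directly with the same values $a = 11$, $b = 100$ and compares the result with $(\sqrt{a^2 + 4b} - a) / 2$. The variable names are chosen here for illustration, and the package is not needed to run it.

```r
a <- 11
b <- 100

# Closed-form fixed point: the positive root of x^2 + a * x - b = 0
x_star <- (sqrt(a^2 + 4 * b) - a) / 2

# Plain fixed-point iteration from a starting value in [0, 1]
x <- 0.5
for (k in 1:200) {
  x <- b / (a + x)
}

# Distance between the iterate and the closed-form value
abs(x - x_star)

# Local contraction factor |f'(x_star)| = b / (a + x_star)^2
contraction <- b / (a + x_star)^2
```

The contraction factor is below 1 (roughly 0.35 for these values), which is why the unaccelerated iteration converges here, and it gives a baseline rate against which the accelerated algorithms can be compared.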