# tests/testthat/test-predict.R

# Label the tests in this file as belonging to the "predict" context.
context(desc = "predict")

test_that("predict.dbn works as expected", {
	# Regression test: predictions of the trained DBN on the test data,
	# stored below as a 10 x 2 matrix (values listed in column-major order).
	pred <- predict(trained.mnist, test.dat)
	expected_pred <- structure(c(1.36068022900793, -3.93217205798754, 35.7254451713427, 
								 -13.8076382423929, -4.71351325675112, 30.4040762128152, 0.410025307175701, 
								 -2.12388546164525, -1.2066478128856, 2.79178278476448, -1.6070840520724, 
								 -5.40068151742813, -14.2830225517939, -7.12445103462788, 7.0039541996005, 
								 -14.0238083846545, 1.94602294098004, 14.8564747465857, -6.9961284483539, 
								 -9.22410750585415), .Dim = c(10L, 2L))

	# expect_equal() compares with a numeric tolerance rather than exact
	# equality; the stored expected_pred values are slightly imprecise, so an
	# exact comparison (expect_identical) would not be appropriate here.
	expect_equal(pred, expected_pred)

	# Works with 1 row: expected_pred[1, ] drops to a plain vector, so this
	# also checks that a single-row prediction comes back without dimensions.
	expect_equal(predict(trained.mnist, test.dat[1, , drop = FALSE]), expected_pred[1, ])
})


test_that("predict.rbm works as expected", {
	rbm <- pretrained.mnist[[1]]
	pred <- predict(rbm, test.dat)

	# The full 10 x 1000 prediction matrix is too big to store here, so pin its
	# shape plus a few elements picked by linear (column-major) index;
	# index 10000 is the last element of the matrix.
	expect_identical(dim(pred), c(10L, 1000L))
	expected_preds <- c(3.23825520196823e-07, 0.853223368782291, 0.908010582118508)
	expect_equal(pred[c(1, 42, 10000)], expected_preds)

	# Works with 1 row: with drop = FALSE the single-row result is kept as a
	# 1 x 1000 matrix and compared against the full stored first row.
	expected_row1 <- structure(c(3.23825520196823e-07, 1.2747660623425e-27, 1.62520539814456e-21, 
								 7.65097694331178e-33, 0.444310268076889, 0.71418019108177, 3.58820630854045e-08, 
								 0.0168191962265736, 0.324120450637274, 0.961825179808209, 9.30540201836536e-22, 
								 0.999999999999723, 0.992130704929728, 0.187793186023204, 0.0265888823855532, 
								 0.938111673685392, 5.42245254544129e-08, 0.99293726513523, 1.32865052638786e-16, 
								 0.113421121161031, 0.159326408131786, 6.99021045551389e-22, 2.07312223695019e-10, 
								 0.569583329124739, 0.402164273747326, 1.28850622151655e-22, 0.0318475293629036, 
								 0.999873753994536, 0.580295654937852, 0.605008732145501, 0.427927548254064, 
								 0.506465469335718, 0.0569522755129851, 9.91970230529577e-21, 
								 0.91602703337159, 2.78589238320679e-20, 1.24416211424903e-15, 
								 7.16148856989776e-24, 0.0830992370119461, 0.972656405678966, 
								 0.959763355757482, 0.999547710991095, 1.55936836321708e-10, 0.0133939961478821, 
								 0.904096798621082, 4.7716714936173e-25, 0.0189086692178496, 0.00546883951188438, 
								 5.58475553261727e-14, 0.884616676291911, 4.99042470982291e-07, 
								 0.999999247876863, 0.999999999999988, 3.49445043810474e-13, 3.00977603735493e-52, 
								 0.818287430351112, 0.999819069538944, 2.16420766474104e-18, 0.00250165769265763, 
								 2.60287332329798e-08, 0.784284572626387, 0.804339102971062, 9.72896849370898e-08, 
								 0.775399177404407, 0.967882768513023, 0.878866300938417, 0.49019672594595, 
								 0.912824122398108, 8.60514323813134e-10, 0.988857397881061, 2.83997001661585e-17, 
								 0.949646760519388, 0.168425509937223, 0.964028187600938, 0.0109356550371559, 
								 0.962662673832367, 1.06341865896743e-14, 0.952280516411567, 0.393386859342872, 
								 0.944778992555724, 0.817607065677308, 0.201519378686501, 0.999999998105695, 
								 2.77966041550397e-08, 0.393758076347974, 3.57661592828483e-07, 
								 2.55849813653747e-08, 0.969272440002828, 0.767617779769631, 0.970979116335913, 
								 0.918228427770301, 3.85257155049394e-06, 9.32012993406488e-26, 
								 1.15210334704503e-28, 1.84184074537e-07, 4.08919321751609e-17, 
								 3.9611314900018e-05, 0.78015371892169, 0.0963911856887342, 0.999996946434062, 
								 0.932055377926315, 0.935878697056096, 0.999996563089235, 0.957882363869008, 
								 0.148059868226008, 0.898905621693743, 0.99999910020579, 0.99999999999913, 
								 0.095281960923322, 0.176539090237364, 0.999248141626928, 0.330659182230064, 
								 0.261831649134774, 0.602458909634234, 0.0215199285467085, 0.99995259003056, 
								 0.99973797057982, 0.69928265783318, 0.0897187894658331, 4.49767735200124e-11, 
								 0.807284901042704, 0.000111586309084222, 5.68339234951314e-16, 
								 9.83853735424034e-09, 0.016835853121092, 0.999999999998535, 0.888061155013778, 
								 0.996295184562211, 0.991406219283425, 0.820375891022045, 1, 0.00195783599051965, 
								 0.979666929825277, 0.167361684830828, 3.12227322992233e-05, 2.43850070283933e-06, 
								 8.20680697531696e-23, 0.840824154293628, 0.922739043753179, 0.836266640473147, 
								 0.999999204635611, 1.83929399168311e-12, 7.52973553803561e-44, 
								 0.989509693116519, 0.000231683350079692, 0.999996533722758, 0.913869489177818, 
								 8.57675253861048e-16, 5.8607858870885e-14, 0.797957771432963, 
								 0.999998725888524, 0.194034760216961, 0.550711182843469, 1.62072829823335e-31, 
								 0.992403973160738, 0.452535934389046, 8.25067110750617e-26, 0.691953263536627, 
								 1, 5.7249867449699e-10, 0.670517730390977, 0.200007217863441, 
								 0.991522120723398, 0.000425151118954307, 1.69390813006314e-16, 
								 1.84362933868481e-14, 0.972962443663081, 6.86966101237509e-17, 
								 0.9997117203957, 0.978267931262471, 0.481479628417547, 0.845918239769021, 
								 5.17979965558435e-07, 0.999996931486529, 2.59450587413989e-15, 
								 0.96723864184322, 0.999999999909086, 0.999958521609608, 0.999999999964829, 
								 0.999482718552727, 0.999999985863858, 1.83714187252683e-12, 0.999989211705658, 
								 4.1157544038075e-47, 0.00746961482188896, 0.999999561302425, 
								 1.915952488664e-26, 2.04217877755456e-05, 0.475990702949127, 
								 0.698282867083937, 0.259881778729565, 4.93045591229937e-21, 0.931052494919767, 
								 9.91520567217172e-15, 8.27760439780022e-28, 5.13280566265941e-12, 
								 0.997543401669936, 0.75262909645926, 0.26771736239188, 0.917100465767929, 
								 0.519695640961409, 7.66579641276377e-38, 0.00619108451206794, 
								 0.954968416132938, 0.961946595797079, 0.374919824024143, 0.414573685916689, 
								 0.967037641826638, 1.34202966283877e-06, 7.33829168704546e-06, 
								 1.4500073866558e-05, 3.16833325840121e-21, 3.84959306457464e-07, 
								 0.00580488551513013, 0.995266201558471, 0.714318603254019, 2.94399215782232e-10, 
								 0.0933771472743558, 0.943911585829584, 0.960458871818204, 0.049278180503151, 
								 0.672289224330154, 0.772749902680489, 0.00118630654766909, 0.484426926778941, 
								 0.999999999165076, 0.987013828525936, 0.999992987690944, 0.852991450619175, 
								 0.546292536357702, 0.110144554545772, 0.94315329751876, 0.976637312240252, 
								 0.910585012736869, 1.8271959359984e-05, 0.744554752617489, 0.00103788603048029, 
								 0.99954947724342, 0.0319621119615618, 0.995308282159671, 1, 0.999985711262848, 
								 0.957304730850237, 5.28033284475385e-10, 2.1776970367098e-10, 
								 0.115114072443901, 0.0102409421698109, 1.93959018131162e-22, 
								 0.663982121846158, 0.961568178917304, 0.999999999891598, 7.41635911345469e-08, 
								 0.999999996932417, 0.99340268213451, 0.0150353610916381, 0.00803872726369283, 
								 0.93778443172767, 1.59447950163865e-07, 0.396558310658103, 0.999999999226936, 
								 8.96354886013489e-28, 0.716779290950649, 0.131603461064653, 6.85994906227074e-07, 
								 0.653659034784391, 0.835159187372871, 0.00183699009219512, 6.24682376525475e-34, 
								 0.337368154156897, 0.747443093650239, 0.724742818031943, 0.876725440820118, 
								 0.998969172713414, 2.42634792846844e-05, 0.989353364944005, 0.0364495380351688, 
								 0.950209372068147, 0.0686068582650694, 0.471656617395085, 0.876790483194587, 
								 1, 0.999482384030922, 0.00317368695300074, 1.41538814703924e-08, 
								 0.922314580242426, 0.912540923014179, 0.800529149208067, 6.46697322404322e-06, 
								 3.82289532641312e-10, 0.999999999607236, 1.06099936316866e-23, 
								 0.870491041793731, 0.999999755276536, 0.999823751840881, 1.06536282661015e-25, 
								 6.79986264420323e-23, 0.00108332744031669, 0.994939473070214, 
								 2.92591729919757e-32, 0.0579511417982094, 0.000151328543760137, 
								 0.988783021524725, 0.403546763580554, 0.785251130865957, 0.976641290471511, 
								 0.93980225044797, 1, 0.726571088756637, 0.724896279916309, 8.5258075703706e-09, 
								 0.911957636481532, 8.05209763366361e-26, 0.290949633096613, 0.268800444966793, 
								 5.95709339698268e-05, 0.894650207628881, 0.260144298197806, 0.464710294726368, 
								 1.15567738040076e-07, 3.66891155402129e-38, 0.0108167980382965, 
								 0.0387968921567866, 7.20737773027246e-13, 6.34104340473323e-23, 
								 0.558213387513506, 0.00315738228358128, 0.55405775805334, 0.124143081988982, 
								 6.53280210651047e-25, 0.841018748699923, 1, 1, 1.36335727027385e-09, 
								 0.83055024431747, 7.92487370416889e-07, 0.748529814935739, 4.52458515549227e-14, 
								 6.76501777351111e-10, 1.99382452949288e-11, 0.961160948778362, 
								 0.231990749859227, 8.47703808397378e-05, 0.00705935980724859, 
								 0.567220030591679, 0.23589999127173, 0.0399023382399479, 2.74405998373359e-27, 
								 0.99779002992035, 0.221605485498709, 1.08594437791599e-30, 8.72191005018717e-19, 
								 0.00950689902468092, 0.574535258465867, 0.928817063051905, 0.00106824951882869, 
								 0.815783075784067, 0.952277649611854, 1, 0.825711907249101, 0.426454747247484, 
								 0.954214127077153, 3.20817047326774e-12, 0.615657370553236, 1.10384105135242e-09, 
								 0.854056450937741, 0.256279042730395, 0.948383589425407, 0.388116273286498, 
								 0.999864967164405, 3.37081942825393e-17, 0.721111613529676, 0.0978797980887358, 
								 0.000710973006182442, 0.616039237568209, 0.96570714593266, 0.40168847072258, 
								 2.25180759843822e-21, 5.89199418480754e-13, 0.590387419711792, 
								 0.0031910207265499, 0.00691554983798541, 0.293401828286786, 4.1486650099177e-08, 
								 8.97045509992571e-19, 0.79430778300107, 0.785860279058499, 0.98935090547293, 
								 0.570148765638828, 1.86964510411956e-10, 0.998223731661057, 0.955876421741576, 
								 0.999999999999997, 0.753363075778991, 0.999999999474998, 0.977514082148559, 
								 8.99428932956739e-23, 9.48214074525514e-11, 0.306142416242105, 
								 0.00118659421068185, 0.00460080331219883, 0.806424813663988, 
								 0.78287091575837, 0.999988625723435, 0.521422203748527, 0.861692079748515, 
								 0.883342423125586, 0.999999882127012, 3.26094956176716e-06, 1.42978778869022e-33, 
								 0.00791593085532548, 0.17939259442653, 0.831857043575209, 0.160475664513042, 
								 0.999999999736539, 0.90085801188125, 0.345905500387904, 0.0010067536800896, 
								 0.944030625426112, 0.951928008667909, 0.0130668889956826, 0.00459568083832981, 
								 0.994862931789984, 2.12140616699818e-10, 9.52296823139664e-26, 
								 3.67741291964919e-07, 0.99999999958883, 0.00284973811400498, 
								 7.69260117784737e-09, 0.999958299022625, 0.778123368477642, 0.0885826633859301, 
								 1.45947454875514e-19, 0.00317738928893523, 0.676184103714131, 
								 0.0109265215457243, 0.98496546040988, 0.999863160559716, 0.361306364352767, 
								 3.29683544052312e-36, 0.939959966518264, 0.00925513981571421, 
								 0.000581586773671047, 1, 0.999980160400871, 0.992959245472862, 
								 4.08139745095887e-06, 0.538762126779133, 3.26078454311134e-28, 
								 0.881759279618468, 1.99987266365142e-07, 3.03059628821886e-09, 
								 8.43137221972453e-24, 0.4520013648143, 0.999920159421818, 0.982821642978874, 
								 0.000733551497080461, 0.995236776136048, 0.000190707788503476, 
								 0.983354323358879, 8.70121453369228e-15, 0.621061430377666, 0.00016397429260015, 
								 0.708440591386744, 0.864337259721277, 0.805209202938118, 0.889649462108362, 
								 8.21532699555946e-08, 0.103905665689623, 0.254328292813925, 0.435233033590724, 
								 0.995338571696625, 0.768580065654275, 1.04130484535963e-07, 4.16288931880217e-14, 
								 0.617789481270989, 0.805586904899139, 0.702443124796442, 9.19364853714306e-06, 
								 0.0193182113737422, 0.73087117195702, 0.940341887064515, 2.30777961611012e-37, 
								 0.0904954457796294, 5.02069239205325e-08, 1.57822700964065e-21, 
								 0.331871720103599, 5.37213529118093e-35, 3.39787674841149e-09, 
								 1, 0.108167576032237, 1.78161256780598e-11, 0.817688747229836, 
								 0.00827860307672816, 7.80886315511049e-05, 0.965170423336149, 
								 0.947409680153941, 3.37652007362996e-05, 0.184115456235309, 0.620886994259374, 
								 9.08108489328152e-06, 0.999518015055168, 7.9341522496818e-07, 
								 0.851483778753845, 1.38793598421363e-14, 0.885454686914303, 0.213246167660423, 
								 6.88196053872261e-19, 0.330726158374722, 2.36757362042601e-24, 
								 5.32300182169013e-08, 1.70261937657807e-33, 0.660508941228838, 
								 6.82317154465024e-05, 0.608180017725954, 0.863457600531788, 0.999386941505421, 
								 0.00457082373302068, 0.156083929333163, 0.932218517671301, 0.619787246463156, 
								 0.152285827670774, 0.208825421895999, 0.94869347505612, 1, 0.910555825024142, 
								 0.275697679523827, 1.6549390972382e-05, 8.00973589868094e-22, 
								 0.999957404702512, 1.52472928995643e-06, 0.807696392216548, 0.877752235419164, 
								 1.70325508895247e-26, 9.96452069260975e-08, 0.99712300006114, 
								 4.169036090563e-07, 1.52068065529797e-20, 0.875911248321864, 
								 0.0664809205576234, 0.47067348328737, 1.05673645641615e-17, 3.48814186533334e-18, 
								 1.70272028228311e-27, 1.1459566676181e-16, 1.20440282855701e-19, 
								 0.997598534477374, 0.745166380085395, 0.923018244814319, 1, 9.0800412618601e-23, 
								 3.81004267743531e-20, 0.721165772625664, 0.846670703054472, 0.00274296638602901, 
								 5.95455108252941e-18, 2.74089102688951e-05, 0.839251637833027, 
								 1.64732159343841e-07, 5.43839294544495e-16, 0.000711699264854365, 
								 7.75506850873861e-19, 4.35222904992415e-30, 0.0288422822625757, 
								 0.318417307033026, 0.174410927573296, 6.87553868608297e-27, 3.33794212060182e-14, 
								 0.92281068160328, 0.0816859535164543, 0.999999687059611, 0.999999970303569, 
								 0.590123111647907, 1.94546881819595e-05, 0.947783138351016, 0.679986115982311, 
								 4.07947219222819e-05, 4.30167235938482e-16, 3.35017162788622e-17, 
								 7.51942872683496e-09, 2.70054772197064e-05, 2.53497127058839e-07, 
								 0.118111540510991, 0.999999999999027, 0.674851554311374, 0.999971207268466, 
								 0.981418144918196, 0.878574510174331, 0.19023156797354, 9.01015720690091e-14, 
								 2.5992954700523e-07, 2.02206433239837e-09, 0.999999315001613, 
								 0.786184826284576, 0.996676524069982, 3.05303565826881e-12, 0.745889338213759, 
								 7.49451421639054e-24, 1.07732770824214e-11, 0.916286119900906, 
								 0.544328530679284, 0.0168798873682507, 7.44825203615605e-11, 
								 0.655958844914994, 3.26989694143936e-07, 2.78906335912252e-09, 
								 0.236077569297782, 0.60567989311608, 0.505586412797037, 0.00075347482649913, 
								 1.13505258149994e-08, 6.24508786650526e-11, 0.706917767762729, 
								 1.05583035327305e-24, 0.00183691478874839, 0.00016021444807227, 
								 0.918820623757843, 1, 0.999999999907836, 7.94019484383387e-20, 
								 0.0989185449582434, 0.162208769863642, 0.89306660554614, 0.0313111473237227, 
								 0.0236602000308463, 0.040466688430793, 0.999989428991899, 0.993990874425128, 
								 0.541918401470912, 0.0540738795443326, 4.05947547338988e-28, 
								 6.17644509080718e-05, 8.82290440728381e-52, 7.66684453031124e-33, 
								 0.91226403776575, 1.6107540518422e-22, 0.674154601248426, 0.630800543202178, 
								 2.45669154379724e-31, 1.07913548335517e-07, 0.790951909489114, 
								 0.4616767008508, 5.071428439388e-34, 1.71283657584009e-22, 4.16140864163771e-05, 
								 0.593371940224573, 0.000191943180179595, 2.14084825992986e-14, 
								 0.999946506065279, 0.892383168051007, 0.962222052702789, 0.000308687445369302, 
								 7.77367905815501e-13, 0.939625876540545, 1, 0.82846777742723, 
								 0.957623726264489, 1.52013606113077e-17, 7.32882368258331e-08, 
								 1, 0.969632469978442, 0.992996457835904, 5.15672402385898e-31, 
								 0.998448703212174, 2.79665809328639e-05, 3.10558615257816e-10, 
								 6.21741178913761e-05, 0.999588702489581, 0.990861456869948, 0.670914637329063, 
								 0.0422403759953232, 9.41717431094589e-11, 7.6576060078287e-26, 
								 8.30396854049113e-11, 0.000642670372912764, 0.588706398646988, 
								 7.17836904098678e-23, 0.871733252799866, 0.344539554028747, 0.920416201197336, 
								 0.793331631463113, 0.502179051830747, 5.19188366828396e-15, 2.03594574491464e-09, 
								 0.892940466663428, 2.56768770369244e-14, 2.86487445893983e-08, 
								 5.99757678556631e-25, 0.916158057980285, 0.999999999986379, 0.759232048309398, 
								 0.0339251028988822, 0.443775934413605, 0.056921155821427, 2.56230185862894e-17, 
								 0.930049114541805, 0.0181357235994395, 0.998309895266863, 0.99999988376812, 
								 0.77248202972291, 0.00030072595812062, 0.999999999990548, 0.607996249074198, 
								 0.118390802701761, 0.000134237256720282, 0.130255875078478, 3.62418514393657e-10, 
								 0.999993179931523, 0.407036891489173, 0.811404706936343, 1.23629707221203e-08, 
								 0.89140638033423, 7.5358327175466e-17, 6.25112838023295e-21, 
								 0.802180126725945, 0.999934720729824, 0.00818177571372915, 0.637474391277441, 
								 0.449026981801898, 0.617128329160202, 1.5262285036145e-22, 0.751300465415345, 
								 0.994739722346991, 0.757063208382713, 1.48017286452976e-06, 3.54906584309543e-08, 
								 0.801246667258498, 0.191424573353656, 2.60176078028874e-17, 0.677754649090389, 
								 0.420074809213223, 0.998134211460249, 0.999944101332981, 4.61178085041482e-10, 
								 3.46906831931557e-06, 0.999999997177973, 0.977270632227283, 1.50741368029394e-07, 
								 0.910976879883071, 3.48347743011666e-05, 0.00252643674502132, 
								 0.999969971102248, 0.952582324249016, 5.22453983617341e-07, 1, 
								 0.987584460281668, 0.237182304708705, 0.0711308557052641, 0.216740459447895, 
								 0.00287610882354019, 0.438267056226004, 1.26008976389518e-08, 
								 6.73513415467982e-09, 0.013882928181265, 0.999999999964029, 4.2324580034407e-08, 
								 0.812684020745125, 0.999999999999973, 0.743391895862251, 5.60692729306083e-13, 
								 0.0247918165878212, 0.000886258814868248, 3.9781083203786e-18, 
								 0.997174899146715, 0.723862401350798, 1.21650102891647e-29, 0.950524012177701, 
								 9.28902270798497e-07, 0.0690744192508556, 0.59100727279369, 0.0425621062336364, 
								 0.981096961967537, 0.733285484477444, 2.25962380541409e-08, 0.900030497499811, 
								 0.847096836508985, 0.0081742541076562, 0.849340020014125, 0.99999936700721, 
								 0.77386551109244, 0.999999999999999, 0.213461143512257, 0.36653425430243, 
								 1.06231323053878e-22, 0.992131082222062, 0.140013184833686, 0.999998806042511, 
								 0.948717375216009, 0.999650566486324, 0.989155287575616, 0.296618211603971, 
								 0.934817497130265, 0.517764319192733, 0.0995050552045481, 0.932305375054405, 
								 7.16034989955017e-22, 0.999998073569632, 0.995239552044378, 0.940046417556632, 
								 0.903650990732173, 0.90538770743337, 1.26120647603352e-10, 4.21319788862685e-09, 
								 0.287989920395021, 0.639413075726044, 9.00155523128498e-10, 0.186471073295553, 
								 0.585786444050744, 0.0535262778303885, 2.01390561141894e-23, 
								 0.0147907688529111, 0.0117487928463762, 2.63762023278352e-07, 
								 0.999999994341181, 0.515160944286638, 6.79373977613914e-05, 0.707521803617004, 
								 1.13474811469539e-09, 0.318886687194778, 0.215445664415334, 0.64929185219056, 
								 0.808648826632142, 0.902273032547609, 8.13984080004834e-09, 0.906751646361401, 
								 0.279248286525212, 0.916480112092266, 2.98154116762794e-24, 0.338725521068769, 
								 2.49489933565727e-07, 0.0918288627064727, 0.0278299837153009, 
								 1, 1.3719112522345e-08, 0.394374989887691, 0.909717956474002, 
								 0.999999996970794, 0.57181858560216, 0.999999955691494, 0.264479205867989, 
								 0.814495548435982, 0.952571413650522, 0.00100887774932033, 3.24242225195756e-06, 
								 0.999079675941707, 0.0476649570392156, 0.006110604270577, 0.000829318695027932, 
								 0.602863427167416, 0.842798084835237, 0.00035066000058287, 6.9349260582509e-17, 
								 0.999999935985678, 0.000548609023831971, 0.492297401667468, 6.73826132231764e-12, 
								 0.508630696248293, 7.88164564582984e-14, 0.194870509156846, 0.622457256916845, 
								 0.317096940899838, 0.936595001742703, 8.57696995587304e-32, 0.941700534926994, 
								 0.986871850090105, 0.029207017080162, 0.00685348507905203, 3.85645434520806e-17, 
								 0.969456185472064, 1.19410316425304e-10, 2.32318116750573e-49, 
								 0.978523666785946, 9.29917004308063e-11, 0.243933275539275, 2.128402292997e-09, 
								 1.09537283727101e-23, 0.841398591514587, 2.08800671204325e-07, 
								 0.0094400513672983, 0.269398622993753, 1.07349195218027e-24, 
								 0.576305404028249, 1.16427554937837e-09, 0.000100595025515557, 
								 2.7092687108082e-08, 0.928198606205269, 0.978103063794744, 1.9064031651123e-05, 
								 0.787088439437554, 0.999999999977321, 0.785243463165017, 2.38170103978064e-06, 
								 1.47199105030891e-06, 0.00202997481655525, 2.70068237121535e-23, 
								 8.98479357705248e-08, 0.0101710690959973, 0.847162504965558, 
								 0.703112453153436, 0.707678137382658, 0.999999999570885, 0.00847834126376322, 
								 0.00160808131750474, 0.998721944435896, 0.936482504450213, 0.060357431896606, 
								 0.272042635378934, 0.923111974005615, 0.117297305902335, 0.567007380412194, 
								 0.945481948820637, 0.999376433069411, 0.93510379103946, 5.15633893546605e-19, 
								 0.995263927359555, 0.493854975070094, 3.23429459004904e-29, 0.876188139467389, 
								 0.468449016145959, 0.623314947573191, 0.998871928226713, 0.973325687247417, 
								 0.931527033432776, 0.839367433132549, 0.977434167633327, 2.54062182595754e-16, 
								 0.998997909718511, 0.000208013118155735, 1.30943928673986e-05, 
								 0.608012680996382, 4.91776209082745e-06, 3.81096156854161e-30, 
								 0.534250224369447, 2.65463672891136e-14, 0.827586300658421, 5.2019320809169e-07, 
								 0.996184383661204, 1.99932965621014e-21, 0.999999957436665, 0.953933034058462, 
								 0.224041935026037, 0.576568356184237, 0.89480061497678, 4.17892007269883e-13, 
								 0.0890248503571716, 1.01935198800158e-07, 0.263110521718921, 
								 0.915282758584832, 0.772192687890798, 1.08547156348606e-34, 0.956290985681132, 
								 1.83043234356613e-43, 0.0192257316297258, 0.9738774936414, 0.468783326228224, 
								 1, 0.909539353707079, 0.505055056552719, 7.91447202516043e-08, 
								 0.995370028904695, 0.252199878299199, 2.29918291125531e-08, 1.02733246000866e-08, 
								 0.855643145009783, 0.266511373540389, 0.752710813331891, 0.000471468056670209, 
								 0.765017529381362, 0.0233522782764749, 0.000286933380194221, 
								 0.968614919618147, 0.0673500939235816, 0.447536671573333, 0.311201103566898, 
								 0.997691504252783, 0.00352821406780143, 0.915447569126618, 0.200261436073223, 
								 0.538667750643057, 9.22789497702296e-21, 0.938939055761523, 0.998325696060267, 
								 0.429940484593749, 0.999541966166039, 5.19064579386067e-07, 0.933446321121202, 
								 0.999986322003754, 0.999991964135869, 1.06233348674139e-07, 1.21622641803192e-42, 
								 2.35397841184291e-15, 0.00167217088884053, 2.28856186909396e-08, 
								 0.975505719738271, 6.26188514418996e-07, 0.999943774620244, 0.365377398079567, 
								 0.000160429101069922, 0.912738006200719, 1.08270708566231e-06, 
								 0.961145752675066), .Dim = c(1L, 1000L))
	expect_equal(predict(rbm, test.dat[1, , drop = FALSE], drop = FALSE), expected_row1)
})


test_that("predict errors if passed invalid data", {
	# Don't accept a vector: drop = TRUE collapses the single row into a plain
	# vector, which both the DBN and RBM predict methods must reject.
	expect_error(predict(trained.mnist, test.dat[1, , drop = TRUE]))
	expect_error(predict(pretrained.mnist[[1]], test.dat[1, , drop = TRUE]))

	# Don't accept wrong dimensions: only the first 20 input columns are kept
	# (presumably fewer than the models were trained on), and the error
	# message is expected to mention "column".
	expect_error(predict(trained.mnist, test.dat[, 1:20, drop = FALSE]), regexp = "column")
	expect_error(predict(trained.mnist[[1]], test.dat[, 1:20, drop = FALSE]), regexp = "column")
})
# xrobin/DeepLearning documentation built on Sept. 18, 2020, 5:23 a.m.