Nothing
# Chunk defaults for this document: display every code chunk, but do not run
# any of them.
knitr::opts_chunk$set(eval = FALSE, echo = TRUE)
First, install the blurr Python package, which provides the Hugging Face
Transformers integration:
# Install the blurr Python package directly from its GitHub repository via pip.
# pip only clones a repository when the URL carries the `git+` VCS scheme;
# a bare https URL to the repository page is downloaded as HTML and rejected.
reticulate::py_install("git+https://github.com/ohmeow/blurr", pip = TRUE)
Load the data, keeping only the first 1% of the training split so that training is fast:
library(fastai)
library(magrittr)
library(zeallot)

# Only the first 1% of the civil_comments training split is loaded.
df <- HF_load_dataset("civil_comments", split = "train[:1%]")
Select the label columns for multi-label classification and convert them to integer labels:
# Convert to a data.table and turn the six label columns into integer targets:
# each score is rounded to 0 decimal places, then stored as integer.
# The columns are modified in place by `:=`.
df <- data.table::as.data.table(df)
lbl_cols <- c("severe_toxicity", "obscene", "threat", "insult",
              "identity_attack", "sexual_explicit")
df[, (lbl_cols) := lapply(.SD, function(score) as.integer(round(score, 0))),
   .SDcols = lbl_cols]
Load DistilRoBERTa:
# Sequence-classification task on the "distilroberta-base" checkpoint. The
# pretrained config is patched so the model has one output per label column.
task <- HF_TASKS_ALL()$SequenceClassification
pretrained_model_name <- "distilroberta-base"
config <- AutoConfig()$from_pretrained(pretrained_model_name)
config$num_labels <- length(lbl_cols)

# zeallot's %<-% unpacks the four objects returned by get_hf_objects().
c(hf_arch, hf_config, hf_tokenizer, hf_model) %<-%
  get_hf_objects(pretrained_model_name, task = task, config = config)
Downloading: 100%|██████████| 899k/899k [00:00<00:00, 961kB/s] Downloading: 100%|██████████| 456k/456k [00:00<00:00, 597kB/s] Downloading: 100%|██████████| 331M/331M [03:26<00:00, 1.61MB/s]
Create the data blocks and dataloaders, then inspect one batch:
# Input: text tokenized by the HF tokenizer. Target: the label columns, which
# are already encoded, so MultiCategoryBlock receives encoded = TRUE.
blocks <- list(
  HF_TextBlock(hf_arch = hf_arch, hf_tokenizer = hf_tokenizer),
  MultiCategoryBlock(encoded = TRUE, vocab = lbl_cols)
)
dblock <- DataBlock(
  blocks = blocks,
  get_x = ColReader("text"),
  get_y = ColReader(lbl_cols),
  splitter = RandomSplitter()
)
dls <- dataloaders(dblock, df, bs = 8)
one_batch(dls)
[[1]] [[1]]$input_ids tensor([[ 0, 24268, 5257, ..., 1, 1, 1], [ 0, 287, 4505, ..., 1, 1, 1], [ 0, 38, 437, ..., 1, 1, 1], ..., [ 0, 152, 1129, ..., 1, 1, 1], [ 0, 85, 18, ..., 1, 1, 1], [ 0, 22014, 31, ..., 1, 1, 1]], device='cuda:0') [[1]]$attention_mask tensor([[1, 1, 1, ..., 0, 0, 0], [1, 1, 1, ..., 0, 0, 0], [1, 1, 1, ..., 0, 0, 0], ..., [1, 1, 1, ..., 0, 0, 0], [1, 1, 1, ..., 0, 0, 0], [1, 1, 1, ..., 0, 0, 0]], device='cuda:0') [[2]] TensorMultiCategory([[0., 0., 0., 0., 0., 0.], [0., 0., 0., 0., 0., 0.], [0., 0., 0., 0., 0., 0.], [0., 0., 0., 0., 0., 0.], [0., 0., 0., 0., 0., 0.], [0., 0., 0., 0., 0., 0.], [0., 0., 0., 0., 0., 0.], [0., 0., 0., 0., 0., 0.]], device='cuda:0')
# Wrap the HF model for fastai. Build a Learner with flattened BCE-with-logits
# loss and multi-label accuracy evaluated at a 0.2 threshold.
model <- HF_BaseModelWrapper(hf_model)
learn <- Learner(
  dls, model,
  opt_func = partial(Adam),
  loss_func = BCEWithLogitsLossFlat(),
  metrics = partial(accuracy_multi(), thresh = 0.2),
  cbs = HF_BaseModelCallback(),
  splitter = hf_splitter()
)
# NOTE(review): presumably the threshold the loss uses when decoding
# predictions into labels -- confirm against fastai's BCEWithLogitsLossFlat.
learn$loss_func$thresh <- 0.2
# The optimizer's parameter (layer) groups are derived from `splitter`.
learn$create_opt()
learn$freeze()
summary(learn)
See the model summary:
epoch train_loss valid_loss accuracy_multi time ------ ----------- ----------- --------------- ------ HF_BaseModelWrapper (Input shape: 8 x 391) ================================================================ Layer (type) Output Shape Param # Trainable ================================================================ Embedding 8 x 391 x 768 38,603,520 False ________________________________________________________________ Embedding 8 x 391 x 768 394,752 False ________________________________________________________________ Embedding 8 x 391 x 768 768 False ________________________________________________________________ LayerNorm 8 x 391 x 768 1,536 True ________________________________________________________________ Dropout 8 x 391 x 768 0 False ________________________________________________________________ Linear 8 x 391 x 768 590,592 False ________________________________________________________________ Linear 8 x 391 x 768 590,592 False ________________________________________________________________ Linear 8 x 391 x 768 590,592 False ________________________________________________________________ Dropout 8 x 12 x 391 x 391 0 False ________________________________________________________________ Linear 8 x 391 x 768 590,592 False ________________________________________________________________ LayerNorm 8 x 391 x 768 1,536 True ________________________________________________________________ Dropout 8 x 391 x 768 0 False ________________________________________________________________ Linear 8 x 391 x 3072 2,362,368 False ________________________________________________________________ Linear 8 x 391 x 768 2,360,064 False ________________________________________________________________ LayerNorm 8 x 391 x 768 1,536 True ________________________________________________________________ Dropout 8 x 391 x 768 0 False ________________________________________________________________ Linear 8 x 391 x 768 590,592 False 
________________________________________________________________ Linear 8 x 391 x 768 590,592 False ________________________________________________________________ Linear 8 x 391 x 768 590,592 False ________________________________________________________________ Dropout 8 x 12 x 391 x 391 0 False ________________________________________________________________ Linear 8 x 391 x 768 590,592 False ________________________________________________________________ LayerNorm 8 x 391 x 768 1,536 True ________________________________________________________________ Dropout 8 x 391 x 768 0 False ________________________________________________________________ Linear 8 x 391 x 3072 2,362,368 False ________________________________________________________________ Linear 8 x 391 x 768 2,360,064 False ________________________________________________________________ LayerNorm 8 x 391 x 768 1,536 True ________________________________________________________________ Dropout 8 x 391 x 768 0 False ________________________________________________________________ Linear 8 x 391 x 768 590,592 False ________________________________________________________________ Linear 8 x 391 x 768 590,592 False ________________________________________________________________ Linear 8 x 391 x 768 590,592 False ________________________________________________________________ Dropout 8 x 12 x 391 x 391 0 False ________________________________________________________________ Linear 8 x 391 x 768 590,592 False ________________________________________________________________ LayerNorm 8 x 391 x 768 1,536 True ________________________________________________________________ Dropout 8 x 391 x 768 0 False ________________________________________________________________ Linear 8 x 391 x 3072 2,362,368 False ________________________________________________________________ Linear 8 x 391 x 768 2,360,064 False ________________________________________________________________ LayerNorm 8 x 391 x 768 1,536 True 
________________________________________________________________ Dropout 8 x 391 x 768 0 False ________________________________________________________________ Linear 8 x 391 x 768 590,592 False ________________________________________________________________ Linear 8 x 391 x 768 590,592 False ________________________________________________________________ Linear 8 x 391 x 768 590,592 False ________________________________________________________________ Dropout 8 x 12 x 391 x 391 0 False ________________________________________________________________ Linear 8 x 391 x 768 590,592 False ________________________________________________________________ LayerNorm 8 x 391 x 768 1,536 True ________________________________________________________________ Dropout 8 x 391 x 768 0 False ________________________________________________________________ Linear 8 x 391 x 3072 2,362,368 False ________________________________________________________________ Linear 8 x 391 x 768 2,360,064 False ________________________________________________________________ LayerNorm 8 x 391 x 768 1,536 True ________________________________________________________________ Dropout 8 x 391 x 768 0 False ________________________________________________________________ Linear 8 x 391 x 768 590,592 False ________________________________________________________________ Linear 8 x 391 x 768 590,592 False ________________________________________________________________ Linear 8 x 391 x 768 590,592 False ________________________________________________________________ Dropout 8 x 12 x 391 x 391 0 False ________________________________________________________________ Linear 8 x 391 x 768 590,592 False ________________________________________________________________ LayerNorm 8 x 391 x 768 1,536 True ________________________________________________________________ Dropout 8 x 391 x 768 0 False ________________________________________________________________ Linear 8 x 391 x 3072 2,362,368 False 
________________________________________________________________ Linear 8 x 391 x 768 2,360,064 False ________________________________________________________________ LayerNorm 8 x 391 x 768 1,536 True ________________________________________________________________ Dropout 8 x 391 x 768 0 False ________________________________________________________________ Linear 8 x 391 x 768 590,592 False ________________________________________________________________ Linear 8 x 391 x 768 590,592 False ________________________________________________________________ Linear 8 x 391 x 768 590,592 False ________________________________________________________________ Dropout 8 x 12 x 391 x 391 0 False ________________________________________________________________ Linear 8 x 391 x 768 590,592 False ________________________________________________________________ LayerNorm 8 x 391 x 768 1,536 True ________________________________________________________________ Dropout 8 x 391 x 768 0 False ________________________________________________________________ Linear 8 x 391 x 3072 2,362,368 False ________________________________________________________________ Linear 8 x 391 x 768 2,360,064 False ________________________________________________________________ LayerNorm 8 x 391 x 768 1,536 True ________________________________________________________________ Dropout 8 x 391 x 768 0 False ________________________________________________________________ Linear 8 x 768 590,592 True ________________________________________________________________ Dropout 8 x 768 0 False ________________________________________________________________ Linear 8 x 6 4,614 True ________________________________________________________________ Total params: 82,123,014 Total trainable params: 615,174 Total non-trainable params: 81,507,840 Optimizer used: functools.partial(<function make_python_function.<locals>.python_function at 0x7fee7e8166a8>) Loss function: FlattenedLoss of BCEWithLogitsLoss() Model frozen 
up to parameter group #2 Callbacks: - TrainEvalCallback - Recorder - ProgressCallback - HF_BaseModelCallback
Finally, fit the model:
# Run the learning-rate finder. Its result `lrs` is only kept for inspection:
# training below uses a fixed maximum learning rate of 1e-2.
lrs <- lr_find(learn, suggestions = TRUE)
# One epoch with the one-cycle schedule.
fit_one_cycle(learn, 1, lr_max = 1e-2)
epoch train_loss valid_loss accuracy_multi time ------ ----------- ----------- --------------- ------ 0 0.040617 0.034286 0.993257 01:21
Predict:
# Lower the loss function's threshold to 0.02 before predicting, so labels
# with small predicted probabilities can still be returned.
learn$loss_func$thresh <- 0.02
predict(
  learn,
  "Those damned affluent white people should only eat their own food, like cod cakes and boiled potatoes. No enchiladas for them!"
)
$probabilities severe_toxicity obscene threat insult identity_attack sexual_explicit 1 9.302437e-07 0.004268706 0.0007849637 0.02687055 0.003282947 0.00232468 $labels [1] "insult"
Any scripts or data that you put into this service are public.
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.