Handle the different formats of inputs and outputs between fastai and transformers

FakeLearner class (for testing only)

class FakeLearner():
    "Minimal stand-in for a fastai Learner, used to test callbacks in isolation."
    def __init__(self, cb, **kwargs):
        # Expose keyword arguments (e.g. pred, xb) as attributes, like a real Learner.
        for k, v in kwargs.items():
            setattr(self, k, v)
        # Attach the callback to this learner, as fastai does when a callback is added.
        cb.learn = self
        self.cb = cb

    def run_cb(self, event_name):
        # Fire a single callback event by name, e.g. 'begin_batch' or 'after_pred'.
        getattr(self.cb, event_name)()

GPT2LMHeadCallback

class GPT2LMHeadCallback[source]

GPT2LMHeadCallback() :: Callback

After the forward pass, keeps only the first element of the tuple returned by the GPT-2 model, since fastai expects learn.pred to be a single tensor rather than a tuple.
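
A minimal sketch of what this callback likely does, inferred from the test below (the actual implementation lives in the library; Callback is fastai's base callback class):

from fastai.callback.core import Callback

class GPT2LMHeadCallback(Callback):
    def after_pred(self):
        # The GPT-2 model returns a tuple such as (output, past, ...);
        # keep only the first element so the loss function sees a single tensor.
        self.learn.pred = self.pred[0]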

learn = FakeLearner(cb=GPT2LMHeadCallback(), pred=('last_hidden_state', 'past'))
learn.run_cb('after_pred')
test_eq(learn.pred, 'last_hidden_state')

BertSeqClassificationCallback

class BertSeqClassificationCallback[source]

BertSeqClassificationCallback(pad_id:int) :: Callback

It should be fine to use this callback with any BERT-like model, e.g. RoBERTa.
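
As a rough guide, here is a minimal sketch of the behaviour the tests below exercise (an assumed implementation, not the library's exact code):

import torch
from fastai.callback.core import Callback

class BertSeqClassificationCallback(Callback):
    def __init__(self, pad_id: int):
        self.pad_id = pad_id

    def begin_batch(self):
        # Derive the attention mask from the padding id (1 = real token, 0 = padding)
        # and pass it along with the input ids, as the transformers model expects.
        input_ids = self.xb[0]
        attention_mask = (input_ids != self.pad_id).long()
        self.learn.xb = (input_ids, attention_mask)

    def after_pred(self):
        # Sequence-classification models return a tuple whose first element is
        # the logits; keep only the logits so fastai's loss gets a single tensor.
        self.learn.pred = self.pred[0]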

input_ids = torch.tensor([[4, 3, 1, 1], 
                          [5, 6, 7, 1]])
attention_mask = torch.tensor([[1, 1, 0, 0], 
                               [1, 1, 1, 0]])

learn = FakeLearner(cb=BertSeqClassificationCallback(pad_id=1), xb=(input_ids,))
learn.run_cb('begin_batch')
test_eq(learn.xb, (input_ids, attention_mask))

learn = FakeLearner(cb=BertSeqClassificationCallback(pad_id=1), pred=('logits',))
learn.run_cb('after_pred')
test_eq(learn.pred, 'logits')
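
In real training (outside these tests) the callbacks would simply be passed to a fastai Learner; the dls, model and tokenizer objects below are assumed to exist already:

from fastai.text.all import *  # Learner, CrossEntropyLossFlat, etc.

learn = Learner(dls, model,
                loss_func=CrossEntropyLossFlat(),
                cbs=[BertSeqClassificationCallback(pad_id=tokenizer.pad_token_id)])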