Handle the different input and output formats of fastai and transformers.
FakeLearner class, used only for testing
class FakeLearner:
    """Minimal stand-in for a learner, used only to exercise callbacks in tests.

    Every extra keyword argument becomes an attribute of the instance, and the
    callback is linked in both directions (`cb.learn` and `self.cb`) so that
    its event methods can read and modify that state.
    """

    def __init__(self, cb, **kwargs):
        """Store each item of `kwargs` as an attribute and attach `cb`."""
        for name, value in kwargs.items():
            setattr(self, name, value)
        # Two-way link: the callback reaches this object through `cb.learn`.
        cb.learn = self
        self.cb = cb

    def run_cb(self, event_name):
        """Call the callback method named `event_name` with no arguments."""
        getattr(self.cb, event_name)()
GPT2LMHeadCallback
# Start from a two-element prediction tuple; once `after_pred` has run, the
# test expects `learn.pred` to be just the first element.
learn = FakeLearner(
    pred=('last_hidden_state', 'past'),
    cb=GPT2LMHeadCallback(),
)
learn.run_cb('after_pred')
test_eq(learn.pred, 'last_hidden_state')
BertSeqClassificationCallback
# Token ids for a batch of two sequences; id 1 is passed as `pad_id` below.
input_ids = torch.tensor([
    [4, 3, 1, 1],
    [5, 6, 7, 1],
])
# Expected mask: 0 wherever `input_ids` holds the pad id, 1 everywhere else.
attention_mask = torch.tensor([
    [1, 1, 0, 0],
    [1, 1, 1, 0],
])
# After `begin_batch`, xb should hold the original ids followed by the mask.
learn = FakeLearner(
    xb=(input_ids,),
    cb=BertSeqClassificationCallback(pad_id=1),
)
learn.run_cb('begin_batch')
test_eq(learn.xb, (input_ids, attention_mask))
# Start from a one-element prediction tuple; once `after_pred` has run, the
# test expects `learn.pred` to be the bare element rather than the tuple.
learn = FakeLearner(
    pred=('logits',),
    cb=BertSeqClassificationCallback(pad_id=1),
)
learn.run_cb('after_pred')
test_eq(learn.pred, 'logits')