Source code for tcr_deep_insight.model.modeling_bert._trainer

from collections import Counter
import torch
from torch import nn
from torch.nn import functional as F
import json
import tqdm
from abc import ABC
import datasets
from typing import Optional, Union, Tuple, Mapping
from transformers import PreTrainedModel
import numpy as np
from pathlib import Path
from tokenizers import Tokenizer
import pandas as pd

from scatlasvae.utils._tensor_utils import one_hot

from ._collator import TRabCollatorForVJCDR3, AminoAcidsCollator
from ._model import TRabModelingBertForPseudoSequence, TRabModelingBertForVJCDR3
from .._model_utils import EarlyStopping

class TrainerBase(ABC):
    def attach_train_dataset(self, train_dataset: Optional[datasets.Dataset]):
        self.train_dataset = train_dataset

    def attach_test_dataset(self, test_dataset: Optional[datasets.Dataset]):
        self.test_dataset = test_dataset

[docs] class TRabModelingBertForVJCDR3Trainer(TrainerBase): def __init__( self, model: TRabModelingBertForVJCDR3, collator: TRabCollatorForVJCDR3, train_dataset: Optional[datasets.Dataset] = None, test_dataset: Optional[datasets.Dataset] = None, optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None), loss_weight: Mapping[str, float] = {"reconstruction_loss": 1}, device: str = 'cuda' ) -> None: self.model = model self.optimizer, self.scheduler = optimizers self.device = device if self.device != model.device: self.device = model.device self.collator = collator self.loss_weight = loss_weight self.attach_train_dataset(train_dataset) self.attach_test_dataset(test_dataset)
[docs] def fit( self, *, max_epoch: int, max_train_sequence: int = 0, n_per_batch: int = 10, shuffle: bool = False, balance_label: bool = False, label_weight: Optional[torch.Tensor] = None, show_progress: bool = False, early_stopping: bool = False ): """Main training entry point""" loss = [] self.model.train() if early_stopping: early_stopper = EarlyStopping() if max_train_sequence == 0: max_train_sequence = len(self.train_dataset) if shuffle: self.train_dataset.shuffle() for i in range(1, max_epoch + 1): epoch_total_loss = [] if show_progress: pbar = tqdm.tqdm(total=max_train_sequence // n_per_batch) for j in range(0, max_train_sequence, n_per_batch): self.optimizer.zero_grad() epoch_indices = np.array( self.train_dataset[j:j+n_per_batch]['input_ids'] ) epoch_attention_mask = np.array( self.train_dataset[j:j+n_per_batch]['attention_mask'] ) if 'cdr3_mr_mask' in self.train_dataset.features.keys(): cdr3_mr_mask = np.array( self.train_dataset[j:j+n_per_batch]['cdr3_mr_mask'] ) else: cdr3_mr_mask = self.train_dataset[j:j+n_per_batch]['attention_mask'] epoch_token_type_ids = torch.tensor( self.train_dataset[j:j+n_per_batch]['token_type_ids'], dtype=torch.int64 ).to(self.device) epoch_indices_mlm, epoch_attention_mask_mlm = self.collator( epoch_indices, epoch_attention_mask, cdr3_mr_mask ) # print(tokenizer.convert_ids_to_tokens(torch.tensor(epoch_indices_mlm))) epoch_indices_mlm = epoch_indices_mlm.to(self.device) epoch_attention_mask_mlm = epoch_attention_mask_mlm.to(self.device) epoch_label_ids = None if 'tcr_label' in self.train_dataset.features.keys(): epoch_label_ids = torch.tensor( self.train_dataset[j:j+n_per_batch]['tcr_label'], dtype=torch.int64 ).to(self.device) if balance_label: label_counter = Counter(epoch_label_ids.cpu().numpy()) min_label_number = np.min(list(label_counter.values())) min_label = {v:k for k,v in label_counter.items()}[min_label_number] min_indices = np.argwhere(epoch_label_ids.cpu().numpy() == min_label).squeeze() for k in label_counter.keys(): if k != min_label: indices = np.argwhere(epoch_label_ids.cpu().numpy() == k).squeeze() np.random.shuffle(indices) indices = indices[:min_label_number] min_indices = np.concatenate([min_indices, indices]) epoch_indices_mlm = epoch_indices_mlm[min_indices] epoch_attention_mask_mlm = epoch_attention_mask_mlm[min_indices] epoch_label_ids = epoch_label_ids[min_indices] epoch_token_type_ids = epoch_token_type_ids[min_indices] epoch_indices = epoch_indices[min_indices] self.optimizer.zero_grad() output = self.model.forward( input_ids = epoch_indices_mlm, attention_mask = epoch_attention_mask_mlm, token_type_ids = epoch_token_type_ids, labels = torch.tensor(epoch_indices, dtype=torch.int64).to(self.device) ) if epoch_label_ids is not None: prediction_loss = nn.BCEWithLogitsLoss( weight=label_weight.to(self.device) )( output["prediction_out"], one_hot(epoch_label_ids.unsqueeze(1), self.model.labels_number) ) else: prediction_loss = torch.tensor(0.) total_loss = output["output"].loss * self.loss_weight['reconstruction_loss'] + prediction_loss self.optimizer.zero_grad() total_loss.backward() self.optimizer.step() epoch_total_loss.append( total_loss.item() ) if show_progress: pbar.update(1) pbar.set_postfix({ f'Learning rate': self.optimizer.param_groups[0]['lr'], f'Loss': np.mean(epoch_total_loss[-4:]), f'Prediction': prediction_loss.item() }) if self.scheduler is not None: self.scheduler.step(np.mean(epoch_total_loss)) if early_stopping: early_stopper(np.mean(epoch_total_loss)) if early_stopper.early_stop: print("Early stopping at epoch {}".format(i)) break if show_progress: pbar.close() print("epoch {} | total loss {:.2e}".format(i, np.mean(epoch_total_loss))) loss.append(np.mean(epoch_total_loss)) self.model.train(False) return loss
[docs] def evaluate( self, *, n_per_batch: int = 10, max_train_sequence: int = None, max_test_sequence: int = None, show_progress: bool = False ): """Main training entry point""" self.model.eval() ### Train all_train_result = { "aa": [], "aa_pred": [], "aa_gt": [], "av": [], "bv": [] } if self.train_dataset is not None: max_train_sequence = len(self.train_dataset) if max_train_sequence is None else max_train_sequence if show_progress: pbar = tqdm.tqdm(total=max_train_sequence // n_per_batch) for j in range(0, max_train_sequence, n_per_batch): self.optimizer.zero_grad() epoch_indices = np.array( self.train_dataset[j:j+n_per_batch]['input_ids'] ) epoch_attention_mask = np.array( self.train_dataset[j:j+n_per_batch]['attention_mask'] ) if 'cdr3_mr_mask' in self.train_dataset.features.keys(): cdr3_mr_mask = np.array( self.train_dataset[j:j+n_per_batch]['cdr3_mr_mask'] ) else: cdr3_mr_mask = self.train_dataset[j:j+n_per_batch]['attention_mask'] epoch_token_type_ids = torch.tensor( self.train_dataset[j:j+n_per_batch]['token_type_ids'], dtype=torch.int64 ).to(self.device) epoch_indices_mlm, epoch_attention_mask_mlm = self.collator( epoch_indices, epoch_attention_mask, cdr3_mr_mask ) # print(tokenizer.convert_ids_to_tokens(torch.tensor(epoch_indices_mlm))) epoch_indices_mlm = epoch_indices_mlm.to(self.device) epoch_attention_mask_mlm = epoch_attention_mask_mlm.to(self.device) epoch_label_ids = None if 'tcr_label' in self.train_dataset.features.keys(): epoch_label_ids = torch.tensor( self.train_dataset[j:j+n_per_batch]['tcr_label'], dtype=torch.int64 ).to(self.device) self.optimizer.zero_grad() output = self.model.forward( input_ids = epoch_indices_mlm, attention_mask = epoch_attention_mask_mlm, token_type_ids = epoch_token_type_ids, # tcr_label_ids = epoch_label_ids, labels = torch.tensor(epoch_indices, dtype=torch.int64).to(self.device) ) predictions = output['output']['logits'].topk(1)[1].squeeze() evaluate_mask = epoch_indices_mlm == 22 ground_truth = torch.tensor(epoch_indices, dtype=torch.int64).to(self.device) evaluate_mask_trav = torch.zeros_like(evaluate_mask, dtype=torch.bool) evaluate_mask_trav[:, 0] = True evaluate_mask_trbv = torch.zeros_like(evaluate_mask, dtype=torch.bool) evaluate_mask_trbv[:, 48] = True evaluate_mask[:, [0, 48]] = False result = (predictions == ground_truth)[evaluate_mask].detach().cpu().numpy() all_train_result['aa'].append(result) all_train_result['aa_pred'].append(predictions[evaluate_mask].detach().cpu().numpy()) all_train_result['aa_gt'].append(ground_truth[evaluate_mask].detach().cpu().numpy()) result = (predictions == ground_truth)[evaluate_mask_trav].detach().cpu().numpy() all_train_result['av'].append(result) result = (predictions == ground_truth)[evaluate_mask_trbv].detach().cpu().numpy() all_train_result['bv'].append(result) if show_progress: pbar.update(1) if show_progress: pbar.close() ### Test all_test_result = { "aa": [], "av": [], "bv": [] } if self.test_dataset is not None: if show_progress: pbar = tqdm.tqdm(total=max_test_sequence // n_per_batch) max_test_sequence = len(self.test_dataset) if max_test_sequence is None else max_test_sequence for j in range(0, max_test_sequence, n_per_batch): self.optimizer.zero_grad() epoch_indices = np.array( self.test_dataset[j:j+n_per_batch]['input_ids'] ) epoch_attention_mask = np.array( self.test_dataset[j:j+n_per_batch]['attention_mask'] ) if 'cdr3_mr_mask' in self.test_dataset.features.keys(): cdr3_mr_mask = np.array( self.test_dataset[j:j+n_per_batch]['cdr3_mr_mask'] ) else: cdr3_mr_mask = self.test_dataset[j:j+n_per_batch]['attention_mask'] epoch_token_type_ids = torch.tensor( self.test_dataset[j:j+n_per_batch]['token_type_ids'], dtype=torch.int64 ).to(self.device) epoch_indices_mlm, epoch_attention_mask_mlm = self.collator( epoch_indices, epoch_attention_mask, cdr3_mr_mask ) # print(tokenizer.convert_ids_to_tokens(torch.tensor(epoch_indices_mlm))) epoch_indices_mlm = epoch_indices_mlm.to(self.device) epoch_attention_mask_mlm = epoch_attention_mask_mlm.to(self.device) epoch_label_ids = None if 'tcr_label' in self.test_dataset.features.keys(): epoch_label_ids = torch.tensor( self.test_dataset[j:j+n_per_batch]['tcr_label'], dtype=torch.int64 ).to(self.device) self.optimizer.zero_grad() output = self.model.forward( input_ids = epoch_indices_mlm, attention_mask = epoch_attention_mask_mlm, token_type_ids = epoch_token_type_ids, # tcr_label_ids = epoch_label_ids, labels = torch.tensor(epoch_indices, dtype=torch.int64).to(self.device) ) predictions = output['output']['logits'].topk(1)[1].squeeze() evaluate_mask = epoch_indices_mlm == 22 ground_truth = torch.tensor(epoch_indices, dtype=torch.int64).to(self.device) evaluate_mask_trav = torch.zeros_like(evaluate_mask, dtype=torch.bool) evaluate_mask_trav[:, 0] = True evaluate_mask_trbv = torch.zeros_like(evaluate_mask, dtype=torch.bool) evaluate_mask_trbv[:, 48] = True evaluate_mask[:, [0, 48]] = False result = (predictions == ground_truth)[evaluate_mask].detach().cpu().numpy() all_test_result['aa'].append(result) result = (predictions == ground_truth)[evaluate_mask_trav].detach().cpu().numpy() all_test_result['av'].append(result) result = (predictions == ground_truth)[evaluate_mask_trbv].detach().cpu().numpy() all_test_result['bv'].append(result) if show_progress: pbar.update(1) if show_progress: pbar.close() return all_train_result, all_test_result
class TRabModelingBertForPseudoSequenceTrainer(TrainerBase): def __init__( self, model: TRabModelingBertForPseudoSequence, collator: AminoAcidsCollator, train_dataset: Optional[Union[datasets.Dataset, pd.DataFrame]] = None, test_dataset: Optional[Union[datasets.Dataset, pd.DataFrame]] = None, optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None), loss_weight: Mapping[str, float] = {"reconstruction_loss": 1}, device: str = 'cuda' ) -> None: self.model = model self.optimizer, self.scheduler = optimizers self.device = device if self.device != model.device: self.device = model.device self.collator = collator self.loss_weight = loss_weight self.attach_train_dataset(train_dataset) self.attach_test_dataset(test_dataset) def fit( self, *, max_epoch: int, n_per_batch: int = 10, shuffle: bool = False, show_progress: bool = False, early_stopping: bool = False, cache_file: Optional[Path] = None ): loss = [] self.model.train() if early_stopping: early_stopper = EarlyStopping() max_train_sequence = len(self.train_dataset) if shuffle: self.train_dataset.shuffle() for i in range(1, max_epoch + 1): epoch_total_loss = [] if show_progress: pbar = tqdm.tqdm(total=max_train_sequence // n_per_batch) for j in range(0, max_train_sequence, n_per_batch): self.optimizer.zero_grad() epoch_indices = np.array( self.train_dataset[j:j+n_per_batch]['input_ids'] ) epoch_attention_mask = np.array( self.train_dataset[j:j+n_per_batch]['attention_mask'] ) epoch_indices_mlm, epoch_attention_mask_mlm = self.collator( epoch_indices, epoch_attention_mask ) # print(tokenizer.convert_ids_to_tokens(torch.tensor(epoch_indices_mlm))) epoch_indices_mlm = epoch_indices_mlm.to(self.device) epoch_attention_mask_mlm = epoch_attention_mask_mlm.to(self.device) epoch_label_ids = None if 'tcr_label' in self.train_dataset.features.keys(): epoch_label_ids = torch.tensor( self.train_dataset[j:j+n_per_batch]['tcr_label'], dtype=torch.int64 ).to(self.device) self.optimizer.zero_grad() output = self.model.forward( input_ids = epoch_indices_mlm, attention_mask = epoch_attention_mask_mlm, labels = torch.tensor( epoch_indices, dtype=torch.int64 ).to(self.device) ) if epoch_label_ids is not None: prediction_loss = nn.BCEWithLogitsLoss()( output["prediction_out"], one_hot(epoch_label_ids.unsqueeze(1), self.model.labels_number) ) else: prediction_loss = torch.tensor(0.) total_loss = output["output"].loss * self.loss_weight['reconstruction_loss'] + prediction_loss self.optimizer.zero_grad() total_loss.backward() self.optimizer.step() epoch_total_loss.append( total_loss.item() ) if show_progress: pbar.update(1) pbar.set_postfix({ f'Learning rate': self.optimizer.param_groups[0]['lr'], f'Loss': np.mean(epoch_total_loss[-4:]), f'Prediction': prediction_loss.item() }) if self.scheduler is not None: self.scheduler.step(np.mean(epoch_total_loss)) if early_stopping: early_stopper(np.mean(epoch_total_loss)) if early_stopper.early_stop: print("Early stopping at epoch {}".format(i)) break if show_progress: pbar.close() print("epoch {} | total loss {:.2e}".format(i, np.mean(epoch_total_loss))) loss.append(np.mean(epoch_total_loss)) self.model.train(False) return loss def evaluate( self, *, n_per_batch: int = 10, max_train_sequence: int = None, max_test_sequence: int = None, show_progress: bool = False ): self.model.eval() ### Train all_train_result = { "cdr1a_aa": [0,0], "cdr2a_aa": [0,0], "cdr3a_aa": [0,0], "cdr1b_aa": [0,0], "cdr2b_aa": [0,0], "cdr3b_aa": [0,0], } if self.train_dataset is not None: max_train_sequence = len(self.train_dataset) if max_train_sequence is None else max_train_sequence if show_progress: pbar = tqdm.tqdm(total=max_train_sequence // n_per_batch) for j in range(0, max_train_sequence, n_per_batch): epoch_indices = np.array( self.train_dataset[j:j+n_per_batch]['input_ids'] ) epoch_attention_mask = np.array( self.train_dataset[j:j+n_per_batch]['attention_mask'] ) epoch_indices_mlm, epoch_attention_mask_mlm = self.collator( epoch_indices, epoch_attention_mask ) epoch_indices_mlm = epoch_indices_mlm.to(self.device) epoch_attention_mask_mlm = epoch_attention_mask_mlm.to(self.device) output = self.model.forward( input_ids = epoch_indices_mlm, attention_mask = epoch_attention_mask_mlm, labels = torch.tensor( epoch_indices, dtype=torch.int64 ).to(self.device) ) predictions = output['output']['logits'].topk(1)[1].squeeze() evaluate_mask = epoch_indices_mlm == 4 ground_truth = torch.tensor(epoch_indices, dtype=torch.int64).to(self.device) ground_truth_cdr1a = ground_truth[:,1:9] predictions_cdr1a = predictions[:,1:9] cdr1a_prediction = ( ((ground_truth_cdr1a == predictions_cdr1a) * evaluate_mask[:,1:9].detach()).sum(), evaluate_mask[:,1:9].sum() ) ground_truth_cdr2a = ground_truth[:,9:18] predictions_cdr2a = predictions[:,9:18] cdr2a_prediction = ( ((ground_truth_cdr2a == predictions_cdr2a) * evaluate_mask[:,9:18].detach()).sum(), evaluate_mask[:,9:18].sum() ) ground_truth_cdr3a = ground_truth[:,18:55] predictions_cdr3a = predictions[:,18:55] cdr3a_prediction = ( ((ground_truth_cdr3a == predictions_cdr3a) * evaluate_mask[:,18:55].detach()).sum(), evaluate_mask[:,18:55].sum() ) ground_truth_cdr1b = ground_truth[:,55:64] predictions_cdr1b = predictions[:,55:64] cdr1b_prediction = ( ((ground_truth_cdr1b == predictions_cdr1b) * evaluate_mask[:,55:64].detach()).sum(), evaluate_mask[:,55:64].sum() ) ground_truth_cdr2b = ground_truth[:,64:73] predictions_cdr2b = predictions[:,64:73] cdr2b_prediction = ( ((ground_truth_cdr2b == predictions_cdr2b) * evaluate_mask[:,64:73].detach()).sum(), evaluate_mask[:,64:73].sum() ) ground_truth_cdr3b = ground_truth[:,73:110] predictions_cdr3b = predictions[:,73:110] cdr3b_prediction = ( ((ground_truth_cdr3b == predictions_cdr3b) * evaluate_mask[:,73:110].detach()).sum(), evaluate_mask[:,73:110].sum() ) all_train_result['cdr1a_aa'][0] += cdr1a_prediction[0].item() all_train_result['cdr1a_aa'][1] += cdr1a_prediction[1].item() all_train_result['cdr2a_aa'][0] += cdr2a_prediction[0].item() all_train_result['cdr2a_aa'][1] += cdr2a_prediction[1].item() all_train_result['cdr3a_aa'][0] += cdr3a_prediction[0].item() all_train_result['cdr3a_aa'][1] += cdr3a_prediction[1].item() all_train_result['cdr1b_aa'][0] += cdr1b_prediction[0].item() all_train_result['cdr1b_aa'][1] += cdr1b_prediction[1].item() all_train_result['cdr2b_aa'][0] += cdr2b_prediction[0].item() all_train_result['cdr2b_aa'][1] += cdr2b_prediction[1].item() all_train_result['cdr3b_aa'][0] += cdr3b_prediction[0].item() all_train_result['cdr3b_aa'][1] += cdr3b_prediction[1].item() if show_progress: pbar.update(1) if show_progress: pbar.close() ### Test all_test_result = { "cdr1a_aa": [0,0], "cdr2a_aa": [0,0], "cdr3a_aa": [0,0], "cdr1b_aa": [0,0], "cdr2b_aa": [0,0], "cdr3b_aa": [0,0], } if self.test_dataset is not None: max_test_sequence = len(self.test_dataset) if max_test_sequence is None else max_test_sequence if show_progress: pbar = tqdm.tqdm(total=max_test_sequence // n_per_batch) for j in range(0, max_test_sequence, n_per_batch): epoch_indices = np.array( self.test_dataset[j:j+n_per_batch]['input_ids'] ) epoch_attention_mask = np.array( self.test_dataset[j:j+n_per_batch]['attention_mask'] ) epoch_indices_mlm, epoch_attention_mask_mlm = self.collator( epoch_indices, epoch_attention_mask ) epoch_indices_mlm = epoch_indices_mlm.to(self.device) epoch_attention_mask_mlm = epoch_attention_mask_mlm.to(self.device) output = self.model.forward( input_ids = epoch_indices_mlm, attention_mask = epoch_attention_mask_mlm, labels = torch.tensor( epoch_indices, dtype=torch.int64 ).to(self.device) ) predictions = output['output']['logits'].topk(1)[1].squeeze() evaluate_mask = epoch_indices_mlm == 4 ground_truth = torch.tensor(epoch_indices, dtype=torch.int64).to(self.device) ground_truth_cdr1a = ground_truth[:,1:9] predictions_cdr1a = predictions[:,1:9] cdr1a_prediction = ( ((ground_truth_cdr1a == predictions_cdr1a) * evaluate_mask[:,1:9].detach()).sum(), evaluate_mask[:,1:9].sum() ) ground_truth_cdr2a = ground_truth[:,9:18] predictions_cdr2a = predictions[:,9:18] cdr2a_prediction = ( ((ground_truth_cdr2a == predictions_cdr2a) * evaluate_mask[:,9:18].detach()).sum(), evaluate_mask[:,9:18].sum() ) ground_truth_cdr3a = ground_truth[:,18:55] predictions_cdr3a = predictions[:,18:55] cdr3a_prediction = ( ((ground_truth_cdr3a == predictions_cdr3a) * evaluate_mask[:,18:55].detach()).sum(), evaluate_mask[:,18:55].sum() ) ground_truth_cdr1b = ground_truth[:,55:64] predictions_cdr1b = predictions[:,55:64] cdr1b_prediction = ( ((ground_truth_cdr1b == predictions_cdr1b) * evaluate_mask[:,55:64].detach()).sum(), evaluate_mask[:,55:64].sum() ) ground_truth_cdr2b = ground_truth[:,64:73] predictions_cdr2b = predictions[:,64:73] cdr2b_prediction = ( ((ground_truth_cdr2b == predictions_cdr2b) * evaluate_mask[:,64:73].detach()).sum(), evaluate_mask[:,64:73].sum() ) ground_truth_cdr3b = ground_truth[:,73:110] predictions_cdr3b = predictions[:,73:110] cdr3b_prediction = ( ((ground_truth_cdr3b == predictions_cdr3b) * evaluate_mask[:,73:110].detach()).sum(), evaluate_mask[:,73:110].sum() ) all_test_result['cdr1a_aa'][0] += cdr1a_prediction[0].item() all_test_result['cdr1a_aa'][1] += cdr1a_prediction[1].item() all_test_result['cdr2a_aa'][0] += cdr2a_prediction[0].item() all_test_result['cdr2a_aa'][1] += cdr2a_prediction[1].item() all_test_result['cdr3a_aa'][0] += cdr3a_prediction[0].item() all_test_result['cdr3a_aa'][1] += cdr3a_prediction[1].item() all_test_result['cdr1b_aa'][0] += cdr1b_prediction[0].item() all_test_result['cdr1b_aa'][1] += cdr1b_prediction[1].item() all_test_result['cdr2b_aa'][0] += cdr2b_prediction[0].item() all_test_result['cdr2b_aa'][1] += cdr2b_prediction[1].item() all_test_result['cdr3b_aa'][0] += cdr3b_prediction[0].item() all_test_result['cdr3b_aa'][1] += cdr3b_prediction[1].item() if show_progress: pbar.update(1) if show_progress: pbar.close() return all_train_result, all_test_result