# Source code for tcr_deep_insight.model.modeling_bert._model

# Pytorch
import torch
from torch import nn, optim
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
from torch.distributions import kl_divergence as kld
from torch.optim.lr_scheduler import ReduceLROnPlateau


# Hugginface Transformers
from transformers import (
    BertConfig,
    PreTrainedModel,
    BertForMaskedLM
)

# Third Party Transformers
from .externals.bert_layers import (
    BertForMaskedLM as TritonBertForMaskedLM
)

# Third Party packages
from sklearn.neighbors import KNeighborsClassifier
from joblib import dump, load
import datasets
import numpy as np
from pathlib import Path
import umap
import numpy as np
import pandas as pd
import scanpy as sc
from scipy.sparse import issparse
from sklearn.utils import class_weight

# Built-in
import time
from collections import Counter
from itertools import chain
from copy import deepcopy
import json
from typing import Callable, Mapping, Union, Iterable, Tuple, Optional, Mapping
import os
import warnings

# Friendly dependencites
from scatlasvae.model._primitives import *
from scatlasvae.model._primitives import Linear, FCLayer, SAE
from scatlasvae.utils._loss import LossFunction
from scatlasvae.utils._parallelizer import Parallelizer

# Package
from ..tokenizers._tokenizer import TRabTokenizerForVJCDR3
from ._config import get_config
from ._layers import (
    OuterProductMean,
    TriangleAttentionStartingNode,
    TriangleAttentionEndingNode,
    PairTransition,
    JointEmbedder,
)
from .._model_utils import add, ModelOutput
from ...utils._logger import mt, Colors, get_tqdm
from ...utils._compat import Literal


# Directory that contains this module file.
MODULE_PATH = Path(__file__).parent

# NOTE(review): importing this module installs a process-wide "ignore" filter for every
# warning category, which also silences warnings raised by unrelated code in the same
# process. Consider scoping it with ``warnings.catch_warnings()`` around the noisy calls.
warnings.simplefilter("ignore")


class ModuleBase(nn.Module):
    """``nn.Module`` that records the device it was last moved to in ``self.device``."""

    def to(self, *args, **kwargs):
        """Move the module exactly like ``nn.Module.to`` and remember the target device.

        All argument forms of ``nn.Module.to`` are forwarded unchanged (device, dtype,
        tensor, ``device=`` keyword, ``non_blocking=``). ``self.device`` is updated only
        when a device is actually given, so a dtype-only call such as
        ``.to(torch.float16)`` leaves it untouched.

        :return: ``self``, like ``nn.Module.to``.
        """
        super(ModuleBase, self).to(*args, **kwargs)
        device = kwargs.get("device")
        if device is None and args:
            first = args[0]
            if isinstance(first, torch.Tensor):
                # ``.to(other_tensor)`` moves to that tensor's device.
                device = first.device
            elif isinstance(first, (str, torch.device)) or (
                isinstance(first, int) and not isinstance(first, bool)
            ):
                device = first
        if device is not None:
            self.device = device
        return self
    

class TRabModelingBertForVJCDR3(ModuleBase):
    """BERT masked-LM over abTCR (V, J, CDR3) token sequences with a pooled linear head."""

    def __init__(
        self,
        bert_config: BertConfig,
        pooling: Literal["cls", "mean", "max", "pool", "trb", "tra", "weighted"] = "mean",
        pooling_cls_position: int = 0,
        pooling_weight=(0.1, 0.9),
        labels_number: int = 1,
        device="cuda",
    ) -> None:
        """
        TRABModel is a BERT model that takes in an abTCR sequence.

        :param bert_config: BertConfig used to build the underlying ``BertForMaskedLM``.
        :param pooling: Pooling method applied to the last hidden layer, one of
            "cls", "mean", "max", "pool", "trb", "tra", "weighted".
        :param pooling_cls_position: Token position read by the "cls" and "pool"
            strategies. The value 1 also switches "tra"/"trb"/"weighted" to the
            cls-inclusive index layout implemented in ``_pool``.
        :param pooling_weight: Two-element sequence used by "weighted" pooling:
            ``pooling_weight[0]`` scales the first-half (TRA) tokens and
            ``pooling_weight[1]`` the second-half (TRB) tokens.
        :param labels_number: Output dimension of the classifier head.
        :param device: Device the module is moved to after construction.

        :example:
            >>> from transformers import BertConfig
            >>> model = TRabModelingBertForVJCDR3(BertConfig(), labels_number=4, device="cpu")
        """
        super(TRabModelingBertForVJCDR3, self).__init__()
        self.model = BertForMaskedLM(bert_config)
        # Only used by the "pool" strategy.
        self.pooler = nn.Sequential(
            nn.Linear(bert_config.hidden_size, bert_config.hidden_size),
            nn.Tanh(),
        )
        self.pooling = pooling
        self.pooling_cls_position = pooling_cls_position
        self.pooling_weight = pooling_weight
        self.config = bert_config
        self.labels_number = labels_number
        self.classifier = nn.Sequential(
            nn.Dropout(0.1),
            Linear(bert_config.hidden_size, labels_number, init='final'),
        )
        self.device = device
        self.to(device)

    def __repr__(self):
        return (
            f'{Colors.GREEN}TRABModel{Colors.NC} object containing:\n'
            f' bert_config: {self.config}\n'
            f' pooling: {self.pooling}\n'
            f' pooling_cls_position: {self.pooling_cls_position}\n'
            f' pooling_weight: {self.pooling_weight}\n'
            f' labels_number: {self.labels_number}\n'
        )

    def forward(
        self,
        *,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        labels: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        output_hidden_states=True,
    ):
        """
        Forward pass of the model.

        :param input_ids: Input token ids.
        :param attention_mask: Attention mask.
        :param labels: Optional masked-LM labels forwarded to ``BertForMaskedLM``.
        :param token_type_ids: Optional token type ids forwarded to ``BertForMaskedLM``.
        :param output_hidden_states: When False, pooling and classification are
            skipped and "hidden_states"/"prediction_out" are returned as None.
        :return: dict with "output" (the ``BertForMaskedLM`` output), "hidden_states"
            (the pooled embedding) and "prediction_out" (the classifier output).
        :raises ValueError: if ``self.pooling`` is not a recognized strategy.
        """
        output = self.model.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True,
            token_type_ids=token_type_ids,
            labels=labels,
            output_hidden_states=output_hidden_states,
        )
        hidden_states = None
        prediction_out = None
        # ``output.hidden_states`` is only populated when hidden states were requested.
        if output_hidden_states:
            hidden_states = self._pool(output.hidden_states[-1])
            prediction_out = self.classifier(hidden_states)
        return {
            "output": output,
            "hidden_states": hidden_states,
            "prediction_out": prediction_out,
        }

    def _pool(self, last_hidden: torch.Tensor) -> torch.Tensor:
        """Reduce the last hidden layer to one vector per sequence according to ``self.pooling``.

        NOTE(review): "tra"/"trb"/"weighted" assume the alpha chain occupies the first
        half of the sequence axis and the beta chain the second half. This is inferred
        from the index arithmetic below; confirm against the tokenizer.
        """
        pooling = self.pooling
        cls_pos = self.pooling_cls_position
        if pooling == "mean":
            return last_hidden[:, 1:, :].mean(1)
        if pooling == "max":
            return last_hidden[:, 1:, :].max(1)[0]
        if pooling == "cls":
            return last_hidden[:, cls_pos, :]
        if pooling == "pool":
            return self.pooler(last_hidden[:, cls_pos, :])
        if pooling not in ("tra", "trb", "weighted"):
            raise ValueError("Unrecognized pool strategy")

        half = int(last_hidden.shape[1] / 2)
        device = last_hidden.device
        if cls_pos == 1:
            # cls-inclusive layout: position ``cls_pos`` is prepended to each chain's indices.
            lead = torch.tensor([cls_pos], device=device)
            tra_idx = torch.hstack([lead, torch.arange(2, half, device=device)])
            trb_idx = torch.hstack([lead, torch.arange(half + 2, half * 2, device=device)])
        else:
            tra_idx = torch.arange(1, half, device=device)
            trb_idx = torch.arange(half + 1, half * 2, device=device)

        if pooling == "tra":
            pooled = last_hidden[:, tra_idx, :].mean(1)
        elif pooling == "trb":
            pooled = last_hidden[:, trb_idx, :].mean(1)
        else:
            # pooling_weight[0] -> TRA half, pooling_weight[1] -> TRB half, in both layouts.
            # (The non-cls branch previously had the weights reversed and a slicing typo
            # that made it raise unconditionally.)
            pooled = (
                last_hidden[:, tra_idx, :] * self.pooling_weight[0]
                + last_hidden[:, trb_idx, :] * self.pooling_weight[1]
            ).mean(1)

        if cls_pos == 1:
            # Positions 0 and ``half`` are in neither index set; the cls-inclusive layout
            # adds them back after averaging.
            pooled = pooled + last_hidden[:, 0] + last_hidden[:, half]
        return pooled
class TRabModelingBertForPseudoSequence(nn.Module):
    """BERT masked-LM over fixed-layout TCR pseudo-sequences with a pooled linear head."""

    # Sequence-axis windows of the last hidden layer averaged by the slice-based strategies.
    # NOTE(review): "cdr3a" (19:55) lies inside the "trb" window (1:55) and "cdr3b" (74:)
    # inside the "tra" window (56:), so the "tra"/"trb" names look swapped. Kept unchanged
    # because trained checkpoints and results may depend on it; confirm before changing.
    _SLICE_POOLING = {
        "mean": slice(1, None),
        "cdr3b": slice(74, None),
        "cdr3a": slice(19, 55),
        "tra": slice(56, None),
        "trb": slice(1, 55),
    }

    def __init__(
        self,
        bert_config: BertConfig,
        pooling: Union[
            Callable,
            Literal["cls", "mean", "max", "cdr3a", "cdr3b", "tra", "trb", "weighted"],
        ] = "mean",
        pooling_weight: Optional[torch.Tensor] = None,
        labels_number: int = 1,
        use_triton: bool = False,
        device="cuda",
    ) -> None:
        """
        :param bert_config: BertConfig used to build the masked-LM backbone.
        :param pooling: A strategy name, or a callable mapping the last hidden layer
            to the pooled embedding.
        :param pooling_weight: Tensor multiplied with the last hidden layer before
            averaging; required when ``pooling == "weighted"``.
        :param labels_number: Output dimension of the classifier head.
        :param use_triton: Use the Triton ``BertForMaskedLM`` implementation.
        :param device: Device the module is moved to after construction.
        """
        super(TRabModelingBertForPseudoSequence, self).__init__()
        if use_triton:
            self.model = TritonBertForMaskedLM(bert_config)
        else:
            self.model = BertForMaskedLM(bert_config)
        self.classifier = nn.Sequential(
            nn.Dropout(0.1),
            Linear(bert_config.hidden_size, labels_number, init='final'),
        )
        self.pooling = pooling
        self.pooling_weight = pooling_weight.to(device) if pooling_weight is not None else None
        self.config = bert_config
        self.device = device
        self.labels_number = labels_number
        self.to(device)

    def forward(
        self,
        *,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        labels: Optional[torch.Tensor] = None,
        output_hidden_states=True,
    ):
        """
        Forward pass of the model.

        :param input_ids: Input token ids.
        :param attention_mask: Attention mask.
        :param labels: Optional masked-LM labels forwarded to the backbone.
        :param output_hidden_states: When False, pooling and classification are
            skipped and "hidden_states"/"prediction_out" are returned as None.
        :return: dict with "output" (backbone output), "hidden_states" (pooled
            embedding) and "prediction_out" (classifier output).
        :raises ValueError: if ``self.pooling`` is not a recognized strategy, or
            "weighted" is used without ``pooling_weight``.
        """
        output = self.model.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True,
            labels=labels,
            output_hidden_states=output_hidden_states,
        )
        hidden_states = None
        prediction_out = None
        if output_hidden_states:
            hidden_states = self._pool(output.hidden_states[-1], input_ids.device)
            prediction_out = self.classifier(hidden_states)
        return {
            "output": output,
            "hidden_states": hidden_states,
            "prediction_out": prediction_out,
        }

    def _pool(self, last_hidden: torch.Tensor, device) -> torch.Tensor:
        """Reduce the last hidden layer to one vector per sequence according to ``self.pooling``."""
        pooling = self.pooling
        if isinstance(pooling, str):
            window = self._SLICE_POOLING.get(pooling)
            if window is not None:
                return last_hidden[:, window, :].mean(1)
            if pooling == "max":
                return last_hidden[:, 1:, :].max(1)[0]
            if pooling == "cls":
                return last_hidden[:, 0, :]
            if pooling == "weighted":
                if self.pooling_weight is None:
                    raise ValueError("pooling='weighted' requires pooling_weight")
                return (last_hidden * self.pooling_weight.to(device)).mean(1)
            # Previously an unknown string silently produced None here.
            raise ValueError("Unrecognized pool strategy")
        if callable(pooling):
            return pooling(last_hidden)
        raise ValueError("Unrecognized pool strategy")
class TCRpMHCPairUpdateBlock(nn.Module):
    """One update step of a TCR x pMHC pair representation ``z``.

    Adds an outer-product-mean term built from the TCR and pMHC token embeddings,
    then combines triangle attention (starting node, then ending node) and a pair
    transition into ``z`` through the ``add`` helper.
    """

    def __init__(self, bert_config) -> None:
        super(TCRpMHCPairUpdateBlock, self).__init__()
        hidden_size = bert_config.hidden_size
        num_heads = bert_config.num_attention_heads
        self.opm = OuterProductMean(hidden_size, hidden_size, hidden_size)
        self.tri_att_start = TriangleAttentionStartingNode(hidden_size, hidden_size, num_heads)
        self.tri_att_end = TriangleAttentionEndingNode(hidden_size, hidden_size, num_heads)
        self.transition = PairTransition(hidden_size, 1)

    def forward(
        self,
        z: torch.Tensor,
        z_mask: torch.Tensor,
        tcr_hidden_states: torch.Tensor,
        pmhc_hidden_states: torch.Tensor,
        tcr_attention_mask: torch.Tensor,
        pmhc_attention_mask: torch.Tensor,
        inplace_safe=False,
        output_attentions=False,
    ):
        """Return the updated ``z``, or ``(z, start_attn, end_attn)`` when
        ``output_attentions`` is True."""
        z = z + self.opm(
            tcr_hidden_states, pmhc_hidden_states, tcr_attention_mask, pmhc_attention_mask
        )
        if output_attentions:
            o, a1 = self.tri_att_start(z, z_mask, output_attentions=output_attentions)
            z = add(z, o, inplace_safe)
            o, a2 = self.tri_att_end(z, z_mask, output_attentions=output_attentions)
            z = add(z, o, inplace_safe)
        else:
            z = add(z, self.tri_att_start(z, z_mask, output_attentions=output_attentions), inplace_safe)
            z = add(z, self.tri_att_end(z, z_mask, output_attentions=output_attentions), inplace_safe)
        z = add(z, self.transition(z, z_mask), inplace_safe)
        if output_attentions:
            return z, a1, a2
        return z


class ContactDistogramHead(nn.Module):
    def __init__(self, c_z, no_bins=64):
        """
        Args:
            c_z: Input channel dimension
            no_bins: Number of distogram bins
        """
        super(ContactDistogramHead, self).__init__()
        self.c_z = c_z
        self.no_bins = no_bins
        self.linear = Linear(self.c_z, self.no_bins, init="final")

    def _forward(self, z):
        """
        Args:
            z: [*, N_res, N_res, C_z] pair embedding
        Returns:
            [*, N, N, no_bins] unnormalized distogram logits (no softmax applied)
        """
        return self.linear(z)

    def forward(self, z):
        return self._forward(z)


class TCRpMHCPairAttention(nn.Module):
    """Encode a pMHC sequence with BERT and build/refine a TCR x pMHC pair representation."""

    def __init__(self, bert_config: BertConfig, use_triton: bool = False) -> None:
        super(TCRpMHCPairAttention, self).__init__()
        if use_triton:
            self.model = TritonBertForMaskedLM(bert_config)
        else:
            self.model = BertForMaskedLM(bert_config)
        self.embedder = JointEmbedder(
            bert_config.hidden_size,
            bert_config.hidden_size,
            bert_config.hidden_size,
        )
        self.layers = nn.ModuleList([
            TCRpMHCPairUpdateBlock(bert_config)
            for _ in range(bert_config.num_hidden_layers)
        ])
        self.distogram_head = ContactDistogramHead(bert_config.hidden_size)
        self.contact_head = Linear(bert_config.hidden_size, 1, init="final")

    def forward(
        self,
        *,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        tcr_hidden_states: torch.Tensor,
        tcr_attention_mask: torch.Tensor,
        output_hidden_states=True,
        output_triangular_attentions=True,
    ):
        """
        :param input_ids: pMHC token ids.
        :param attention_mask: pMHC attention mask.
        :param tcr_hidden_states: Last hidden layer of the TCR encoder.
        :param tcr_attention_mask: TCR attention mask.
        :param output_hidden_states: Kept for signature compatibility; the pMHC
            encoder's hidden states are always required and always requested.
        :param output_triangular_attentions: Collect per-layer triangle attentions.
        :return: ModelOutput with z, z_mask, distogram_logits, contact_logits and
            triangular_attentions (empty list when not collected).
        """
        output = self.model.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True,
            # The last hidden layer is read below, so it must always be returned.
            output_hidden_states=True,
        )
        pmhc_hidden_states = output.hidden_states[-1]
        z, z_mask = self.embedder(
            tcr_hidden_states,
            pmhc_hidden_states,
            tcr_attention_mask,
            attention_mask,
        )
        triangular_attentions = []
        for layer in self.layers:
            result = layer(
                z,
                z_mask,
                tcr_hidden_states,
                pmhc_hidden_states,
                tcr_attention_mask,
                attention_mask,
                output_attentions=output_triangular_attentions,
            )
            if output_triangular_attentions:
                z, a1, a2 = result
                triangular_attentions.append((a1, a2))
            else:
                z = result
        distogram_logits = self.distogram_head(z)
        # Binding logit is read from the pair entry at (0, 0).
        contact_logits = self.contact_head(z[:, 0, 0, :])
        return ModelOutput(dict(
            z=z,
            z_mask=z_mask,
            distogram_logits=distogram_logits,
            contact_logits=contact_logits,
            triangular_attentions=triangular_attentions,
        ))


def softmax_cross_entropy(logits, labels):
    """Cross-entropy of ``labels`` against ``softmax(logits)`` over the last axis.

    ``labels`` presumably holds one-hot or probability targets with the same
    last-axis size as ``logits`` -- TODO confirm with callers. Returns the
    unreduced per-element loss (last axis summed out).
    """
    loss = -1 * torch.sum(
        labels * torch.nn.functional.log_softmax(logits, dim=-1),
        dim=-1,
    )
    return loss


class TRabModelingBertForPseudoSequenceWithContactModule(TRabModelingBertForPseudoSequence):
    """Pseudo-sequence TCR model plus a TCR x pMHC contact/binding module."""

    def __init__(
        self,
        bert_config: BertConfig,
        pooling: Literal["cls", "mean", "max"] = "mean",
        labels_number: int = 1,
        use_triton: bool = False,
        pretrained_checkpoint: Optional[str] = None,
        device="cuda",
    ) -> None:
        """
        :param pretrained_checkpoint: Optional path to a state dict for the base
            (pre-contact-module) model, loaded before the contact module is created.
        """
        # Keyword arguments: the parent's third positional parameter is
        # ``pooling_weight``, so a positional call mis-bound labels_number/use_triton.
        super(TRabModelingBertForPseudoSequenceWithContactModule, self).__init__(
            bert_config,
            pooling=pooling,
            labels_number=labels_number,
            use_triton=use_triton,
            device=device,
        )
        if pretrained_checkpoint is not None:
            # torch.load unpickles the file -- only load trusted checkpoints.
            # map_location lets GPU-saved checkpoints load on CPU-only hosts;
            # load_state_dict then copies values onto the module's own device.
            self.load_state_dict(torch.load(pretrained_checkpoint, map_location="cpu"))
        self.contact_module = TCRpMHCPairAttention(bert_config, use_triton)
        self.device = device
        self.to(device)

    def forward(
        self,
        *,
        tcr_input_ids: torch.Tensor,
        tcr_attention_mask: torch.Tensor,
        pmhc_input_ids: torch.Tensor,
        pmhc_attention_mask: torch.Tensor,
        tcr_pmhc_distogram: Optional[torch.Tensor] = None,
        tcr_pmhc_binding: Optional[torch.Tensor] = None,
        output_triangular_attentions: bool = False,
    ):
        """
        :return: dict with "tcr_output", "contact_output", and "contact_loss" /
            "binding_loss" (None when the corresponding target is not given).
        """
        tcr_output = self.model.forward(
            input_ids=tcr_input_ids,
            attention_mask=tcr_attention_mask,
            return_dict=True,
            output_hidden_states=True,
        )
        contact_output = self.contact_module(
            input_ids=pmhc_input_ids,
            attention_mask=pmhc_attention_mask,
            tcr_hidden_states=tcr_output.hidden_states[-1],
            tcr_attention_mask=tcr_attention_mask,
            output_triangular_attentions=output_triangular_attentions,
        )
        contact_loss = None
        if tcr_pmhc_distogram is not None:
            contact_loss = (softmax_cross_entropy(
                contact_output.distogram_logits,
                tcr_pmhc_distogram,
            ) * contact_output.z_mask).mean()
        binding_loss = None
        if tcr_pmhc_binding is not None:
            binding_loss = torch.nn.functional.binary_cross_entropy_with_logits(
                contact_output.contact_logits,
                tcr_pmhc_binding,
            )
        return {
            "tcr_output": tcr_output,
            "contact_output": contact_output,
            "contact_loss": contact_loss,
            "binding_loss": binding_loss,
        }