"""Tokenizers for TCR sequences (module ``tcr_deep_insight.model.tokenizers._tokenizer``)."""

import pandas as pd
import torch
from torch.nn import functional as F
from typing import Union, Iterable, List, Tuple, Optional
from einops import rearrange, repeat
from sklearn.model_selection import train_test_split
import datasets
import tqdm
import numpy as np
from pathlib import Path
        
from tokenizers import Tokenizer, Encoding

from transformers import (
    PreTrainedTokenizerBase,
)

from scatlasvae.utils._tensor_utils import get_k_elements

from ...utils._tcr_definitions import (
    HumanTCRAnnotations,
    MouseTCRAnnotations
)

from ...utils._compat import Literal
from ...utils._decorators import deprecated
from ...utils._amino_acids import (
    _AMINO_ACIDS_ADDITIONALS,
    _AMINO_ACIDS_INDEX_REVERSE,
    _AMINO_ACIDS_INDEX
)

# Directory containing this module; the bundled "tokenizer.json" is loaded from here.
MODULE_PATH: Path = Path(__file__).parent

class AminoAcidTokenizer(PreTrainedTokenizerBase):
    """Tokenizer for amino acids.

    The amino acid to token index mapping follows the same layout as tcr-bert
    (``_AMINO_ACIDS_INDEX``). Each sequence is encoded character by character,
    optionally prefixed with the CLS token and suffixed with the SEP token, and
    then right-padded with the PAD token. Sequences longer than the target
    length are NOT truncated.

    Parameters
    ----------
    model_max_length : int
        Default length every encoded sequence is padded to.
    append_cls_token : bool
        Prepend the CLS token unless the sequence already starts with it.
    append_eos_token : bool
        Append the SEP token unless the sequence already ends with it.
    **kwargs
        Forwarded to ``PreTrainedTokenizerBase``. ``pad_token``, ``unk_token``,
        ``mask_token``, ``sep_token`` and ``cls_token`` fall back to the
        corresponding ``_AMINO_ACIDS_ADDITIONALS`` entry when missing or falsy.
    """

    def __init__(
        self,
        *,
        model_max_length: int,
        append_cls_token: bool = True,
        append_eos_token: bool = True,
        **kwargs
    ) -> None:
        # Special-token kwarg -> key of its default in _AMINO_ACIDS_ADDITIONALS.
        #   pad_token:  fills sequences up to the fixed length
        #   unk_token:  represents an out-of-vocabulary token
        #   mask_token: masked position for masked-language-model pretraining;
        #               defaults to MASK (not PAD) so masked positions remain
        #               distinguishable from padding, as in the subclasses
        #   sep_token:  appended at the end of each sequence
        #   cls_token:  prepended at the start of each sequence
        default_special_tokens = {
            "pad_token": "PAD",
            "unk_token": "UNK",
            "mask_token": "MASK",
            "sep_token": "SEP",
            "cls_token": "CLS",
        }
        for kwarg_name, additional_key in default_special_tokens.items():
            if not kwargs.get(kwarg_name):
                kwargs[kwarg_name] = _AMINO_ACIDS_ADDITIONALS[additional_key]

        kwargs["model_max_length"] = model_max_length
        super(AminoAcidTokenizer, self).__init__(**kwargs)
        self._vocab_size = len(_AMINO_ACIDS_INDEX)
        self.append_cls_token = append_cls_token
        self.append_eos_token = append_eos_token
        # Each lookup raises KeyError if a custom special token is not in the vocabulary.
        self._pad_token_id = _AMINO_ACIDS_INDEX[kwargs["pad_token"]]
        self._unk_token_id = _AMINO_ACIDS_INDEX[kwargs["unk_token"]]
        self._mask_token_id = _AMINO_ACIDS_INDEX[kwargs["mask_token"]]
        self._sep_token_id = _AMINO_ACIDS_INDEX[kwargs["sep_token"]]
        self._cls_token_id = _AMINO_ACIDS_INDEX[kwargs["cls_token"]]

    @property
    def is_fast(self) -> bool:
        """Always ``False``: this is a pure-Python tokenizer."""
        return False

    @property
    def vocab_size(self) -> int:
        """Number of entries in ``_AMINO_ACIDS_INDEX``."""
        return self._vocab_size

    def _encode(self, aa: str, max_length: Optional[int] = None) -> torch.Tensor:
        """Encode one amino-acid string into a 1-D float tensor of token indices.

        Parameters
        ----------
        aa : str
            Amino-acid sequence. Every character must be a key of ``_AMINO_ACIDS_INDEX``.
        max_length : int, optional
            Pad target; defaults to ``self.model_max_length``.

        Raises
        ------
        KeyError
            If a character is not in the vocabulary.
        """
        cls_token = _AMINO_ACIDS_ADDITIONALS["CLS"]
        sep_token = _AMINO_ACIDS_ADDITIONALS["SEP"]
        if self.append_cls_token and not aa.startswith(cls_token):
            aa = cls_token + aa
        # Check the END of the sequence so an already-terminated sequence does
        # not receive a second SEP token.
        if self.append_eos_token and not aa.endswith(sep_token):
            aa += sep_token

        max_length = max_length or self.model_max_length
        if len(aa) < max_length:
            # Use the public property so padding matches `_unpad`.
            aa += self.pad_token * (max_length - len(aa))
        return torch.Tensor([_AMINO_ACIDS_INDEX[a] for a in aa])

    def _unpad(self, s) -> str:
        """Join the tokens of ``s`` into a string, dropping PAD tokens."""
        return ''.join(token for token in s if token != self.pad_token)

    def _decode(self, ids) -> str:
        """Map token indices back to characters and strip the padding."""
        return self._unpad([_AMINO_ACIDS_INDEX_REVERSE[t] for t in ids])

    def convert_tokens_to_ids(self, sequence: Union[Iterable[str], str]) -> dict:
        """Encode one sequence or an iterable of sequences.

        Returns
        -------
        dict
            ``{"indices": LongTensor (n_sequences, length), "mask": BoolTensor}``,
            where ``mask`` comes from :meth:`convert_ids_to_mask`.

        Raises
        ------
        TypeError
            If ``sequence`` is neither a string nor an iterable of strings.
        """
        if isinstance(sequence, str):
            # Single sequence -> batch of one row.
            ids = self._encode(sequence).unsqueeze(0).type(torch.LongTensor)
        elif isinstance(sequence, Iterable):
            ids = torch.vstack([self._encode(s) for s in sequence]).type(torch.LongTensor)
        else:
            raise TypeError(
                "sequence must be a str or an iterable of str, got %s" % type(sequence).__name__
            )
        mask = self.convert_ids_to_mask(ids)
        return {"indices": ids, "mask": mask}

    def convert_ids_to_tokens(
        self,
        ids: Union[torch.Tensor, np.ndarray, Iterable[int]]
    ) -> Iterable[str]:
        """Decode token indices into unpadded strings.

        A 1-D input returns one string. A 2-D input returns one string per row.
        Tensors are moved to the CPU first; numpy arrays and nested
        iterables are accepted as well.
        """
        if isinstance(ids, torch.Tensor):
            ids = ids.detach().cpu().numpy()
        ids = np.asarray(ids).astype(np.int64)
        if ids.ndim == 1:
            return self._decode(ids)
        return [self._decode(row) for row in ids]

    def convert_ids_to_mask(self, ids: torch.Tensor) -> torch.Tensor:
        """Return ``True`` where ``ids`` is not the PAD, CLS or SEP token index."""
        return (ids != self._pad_token_id) & (ids != self._cls_token_id) & (ids != self._sep_token_id)

    def to_dataset(
        self,
        ids: Iterable[str],
        chains: Iterable[str],

        split: bool = False
    ) -> datasets.DatasetDict:
        """Tokenize ``chains`` and wrap them in a ``datasets.DatasetDict``.

        Parameters
        ----------
        ids : sized iterable of str
            Identifier for each chain.
        chains : sized iterable of str
            Amino-acid sequences aligned with ``ids``.
        split : bool
            When ``True``, split rows with ``sklearn.model_selection.train_test_split``
            (default proportions) into "train" and "test". Otherwise return only "train".

        Raises
        ------
        ValueError
            If ``ids`` and ``chains`` differ in length.
        """
        if len(ids) != len(chains):
            raise ValueError(
                "Length of ids(%d) and chains(%d) do not match" % (len(ids), len(chains))
            )
        tokenized = self.convert_tokens_to_ids(chains)
        columns = {
            "id": np.array(ids),
            "chains": np.array(chains),
            "input_ids": tokenized["indices"],
            "attention_mask": tokenized["mask"],
        }
        if not split:
            return datasets.DatasetDict({
                "train": datasets.Dataset.from_dict(columns),
            })

        train_idx, test_idx = train_test_split(list(range(len(ids))))
        return datasets.DatasetDict({
            "train": datasets.Dataset.from_dict({k: v[train_idx] for k, v in columns.items()}),
            "test": datasets.Dataset.from_dict({k: v[test_idx] for k, v in columns.items()}),
        })

class TRabTokenizerForVJCDR3(AminoAcidTokenizer):
    """Tokenizer for paired TRA and TRB sequences.

    V and J genes are encoded as single tokens whose indices are offset by
    ``len(_AMINO_ACIDS_INDEX)`` (unknown genes map to offset + 0). CDR3
    sequences are encoded residue by residue. In the encoded output the alpha
    chain occupies the first ``tra_max_length`` positions and the beta chain the
    following ``trb_max_length`` positions.
    """

    def __init__(
        self,
        *,
        tra_max_length: int,
        trb_max_length: int,
        pad_token: Optional[str] = None,
        unk_token: Optional[str] = None,
        mask_token: Optional[str] = None,
        cls_token: Optional[str] = None,
        sep_token: Optional[str] = None,
        species: Literal['human', 'mouse'] = 'human',
        **kwargs
    ) -> None:
        # Validate up front. An unknown species previously left
        # VJ_GENES2INDEX unset and failed later inside _encode/_decode.
        if species == 'human':
            annotations = HumanTCRAnnotations
        elif species == 'mouse':
            annotations = MouseTCRAnnotations
        else:
            raise ValueError("species must be 'human' or 'mouse', got %r" % (species,))

        # Fall back to the default special tokens from _AMINO_ACIDS_ADDITIONALS.
        kwargs["pad_token"] = pad_token or _AMINO_ACIDS_ADDITIONALS["PAD"]
        kwargs["unk_token"] = unk_token or _AMINO_ACIDS_ADDITIONALS["UNK"]
        kwargs["mask_token"] = mask_token or _AMINO_ACIDS_ADDITIONALS["MASK"]
        kwargs["cls_token"] = cls_token or _AMINO_ACIDS_ADDITIONALS["CLS"]
        kwargs["sep_token"] = sep_token or _AMINO_ACIDS_ADDITIONALS["SEP"]
        super(TRabTokenizerForVJCDR3, self).__init__(
            model_max_length=tra_max_length + trb_max_length, **kwargs
        )
        self.tra_max_length = tra_max_length
        self.trb_max_length = trb_max_length
        self.species = species
        self.VJ_GENES2INDEX = annotations.VJ_GENES2INDEX
        self.VJ_GENES2INDEX_REVERSE = annotations.VJ_GENES2INDEX_REVERSE

    def _encode(
        self,
        aa: str,
        v_gene: Optional[str] = None,
        j_gene: Optional[str] = None,
        max_length: Optional[int] = None
    ) -> torch.Tensor:
        """Encode one chain as ``[V] CLS CDR3 SEP [J] PAD...`` into a 1-D float tensor.

        The V/J tokens are added only when BOTH ``v_gene`` and ``j_gene`` are truthy.
        The sequence is padded to ``max_length`` (default ``self.model_max_length``)
        and is not truncated.
        """
        cls_token = _AMINO_ACIDS_ADDITIONALS["CLS"]
        sep_token = _AMINO_ACIDS_ADDITIONALS["SEP"]
        tokens = list(aa)
        # `not tokens` guards against IndexError on an empty CDR3.
        if self.append_cls_token and (not tokens or tokens[0] != cls_token):
            tokens = [cls_token] + tokens
        if self.append_eos_token and (not tokens or tokens[-1] != sep_token):
            tokens.append(sep_token)

        max_length = max_length or self.model_max_length
        if v_gene and j_gene:
            tokens = [v_gene] + tokens + [j_gene]
        if len(tokens) < max_length:
            tokens += [self.pad_token] * (max_length - len(tokens))

        # Amino acids / special tokens use their own index. Anything else is treated
        # as a gene name and shifted past the amino-acid vocabulary.
        offset = len(_AMINO_ACIDS_INDEX)
        return torch.Tensor([
            _AMINO_ACIDS_INDEX[token] if token in _AMINO_ACIDS_INDEX
            else self.VJ_GENES2INDEX.get(token, 0) + offset
            for token in tokens
        ])

    def convert_tokens_to_ids(
        self,
        sequence: Union[List[Tuple[str]], Tuple[str]],
        alpha_vj: Optional[Union[List[Tuple[str]], Tuple[str]]] = None,
        beta_vj: Optional[Union[List[Tuple[str]], Tuple[str]]] = None
    ):
        """Encode (CDR3a, CDR3b) pairs, optionally with (V, J) genes per chain.

        Parameters
        ----------
        sequence : list of (str, str) or (str, str)
            A list of alpha/beta CDR3 pairs, or a single pair.
        alpha_vj, beta_vj : list of (str, str) or (str, str), optional
            Matching (V gene, J gene) pairs. They are used only when both are given.

        Returns
        -------
        dict
            ``indices`` (LongTensor, rows x (tra_max_length + trb_max_length)),
            ``mask`` (from :meth:`convert_ids_to_mask`) and ``token_type_ids``
            (0 for alpha positions, 1 for beta positions).
        """
        use_vj = alpha_vj is not None and beta_vj is not None
        if isinstance(sequence, list):
            if use_vj:
                alpha = torch.vstack([
                    self._encode(cdr3, vj[0], vj[1], max_length=self.tra_max_length)
                    for cdr3, vj in zip(get_k_elements(sequence, 0), alpha_vj)
                ])
                beta = torch.vstack([
                    self._encode(cdr3, vj[0], vj[1], max_length=self.trb_max_length)
                    for cdr3, vj in zip(get_k_elements(sequence, 1), beta_vj)
                ])
            else:
                alpha = torch.vstack([self._encode(pair[0], max_length=self.tra_max_length) for pair in sequence])
                beta = torch.vstack([self._encode(pair[1], max_length=self.trb_max_length) for pair in sequence])
            ids = torch.hstack([alpha, beta]).type(torch.LongTensor)
        else:
            # Single pair. The beta chain is padded to trb_max_length (it was
            # tra_max_length before), so it agrees with token_type_ids below.
            if use_vj:
                alpha = self._encode(sequence[0], alpha_vj[0], alpha_vj[1], max_length=self.tra_max_length)
                beta = self._encode(sequence[1], beta_vj[0], beta_vj[1], max_length=self.trb_max_length)
            else:
                alpha = self._encode(sequence[0], max_length=self.tra_max_length)
                beta = self._encode(sequence[1], max_length=self.trb_max_length)
            ids = torch.hstack([alpha, beta]).unsqueeze(0).type(torch.LongTensor)

        mask = self.convert_ids_to_mask(ids)
        token_type_ids = torch.hstack([
            torch.zeros(ids.shape[0], self.tra_max_length),
            torch.ones(ids.shape[0], self.trb_max_length),
        ]).type(torch.LongTensor)
        return {"indices": ids, "mask": mask, "token_type_ids": token_type_ids}

    def _decode(self, ids) -> List[str]:
        """Map indices to tokens (amino acids/special tokens or V/J gene names), keeping padding."""
        offset = len(_AMINO_ACIDS_INDEX)
        return [
            _AMINO_ACIDS_INDEX_REVERSE[t] if t in _AMINO_ACIDS_INDEX_REVERSE
            else self.VJ_GENES2INDEX_REVERSE[t - offset]
            for t in ids
        ]

    def _trab_decode(self, ids):
        """Decode one encoded pair into an ``(alpha, beta)`` tuple of unpadded strings."""
        dec = self._decode(ids)
        return (self._unpad(dec[:self.tra_max_length]), self._unpad(dec[self.tra_max_length:]))

    def convert_ids_to_tokens(self, ids: Union[torch.Tensor, np.ndarray, Iterable[int]]):
        """Decode indices into ``(alpha, beta)`` tuples.

        A 1-D input returns one tuple. A 2-D input returns a list of tuples, one
        per row. Tensors, numpy arrays and nested iterables are accepted.
        """
        if isinstance(ids, torch.Tensor):
            ids = ids.detach().cpu().numpy()
        ids = np.asarray(ids).astype(np.int64)
        if ids.ndim == 1:
            return self._trab_decode(ids)
        return [self._trab_decode(row) for row in ids]

    def to_dataset(
        self,
        df: Optional[pd.DataFrame] = None,
        ids: Iterable[str] = None,
        alpha_chains: Iterable[str] = None,
        beta_chains: Iterable[str] = None,
        alpha_v_genes: Iterable[str] = None,
        alpha_j_genes: Iterable[str] = None,
        beta_v_genes: Iterable[str] = None,
        beta_j_genes: Iterable[str] = None,
        pairing: Iterable[int] = None,
        split: bool = False
    ):
        """Build a ``datasets.DatasetDict`` from a DataFrame or from explicit columns.

        When ``df`` is given, its ``id``/``CDR3a``/``CDR3b`` columns override the
        explicit arguments. ``TRAV``/``TRAJ``/``TRBV``/``TRBJ`` are used only when all
        four are present, and ``pairing`` is used when present.
        """
        if df is not None:
            ids = df['id']
            alpha_chains = df['CDR3a']
            beta_chains = df['CDR3b']
            if all(column in df.columns for column in ('TRAV', 'TRAJ', 'TRBV', 'TRBJ')):
                alpha_v_genes = df['TRAV']
                alpha_j_genes = df['TRAJ']
                beta_v_genes = df['TRBV']
                beta_j_genes = df['TRBJ']
            if 'pairing' in df.columns:
                pairing = df['pairing']
        return self._to_dataset(
            ids, alpha_chains, beta_chains,
            alpha_v_genes, alpha_j_genes, beta_v_genes, beta_j_genes,
            pairing, split
        )

    def _to_dataset(
        self,
        ids: Iterable[str],
        alpha_chains: Iterable[str],
        beta_chains: Iterable[str],
        alpha_v_genes: Iterable[str] = None,
        alpha_j_genes: Iterable[str] = None,
        beta_v_genes: Iterable[str] = None,
        beta_j_genes: Iterable[str] = None,
        pairing: Iterable[int] = None,
        split: bool = False
    ) -> datasets.DatasetDict:
        """Tokenize the chains and assemble "train" (and, if ``split``, "test") datasets.

        Gene columns are used only when all four are provided. Missing ``pairing``
        defaults to 1 for every row.

        Raises
        ------
        ValueError
            If ``ids``, ``alpha_chains`` and ``beta_chains`` differ in length.
        """
        if not len(ids) == len(alpha_chains) == len(beta_chains):
            raise ValueError(
                "Length of ids(%d), alpha_chains(%d) and beta_chains(%d) do not match"
                % (len(ids), len(alpha_chains), len(beta_chains))
            )
        alpha_chains = list(alpha_chains)
        beta_chains = list(beta_chains)
        pairs = list(zip(alpha_chains, beta_chains))
        if all(genes is not None for genes in (alpha_v_genes, alpha_j_genes, beta_v_genes, beta_j_genes)):
            tokenized = self.convert_tokens_to_ids(
                pairs,
                list(zip(alpha_v_genes, alpha_j_genes)),
                list(zip(beta_v_genes, beta_j_genes)),
            )
        else:
            tokenized = self.convert_tokens_to_ids(pairs)

        if not split:
            return datasets.DatasetDict({
                "train": datasets.Dataset.from_dict({
                    "id": ids,
                    "alpha_chains": alpha_chains,
                    "beta_chains": beta_chains,
                    "input_ids": tokenized["indices"],
                    "token_type_ids": tokenized["token_type_ids"],
                    "attention_mask": tokenized["mask"],
                    "pairing": np.array(pairing) if pairing is not None else np.ones(len(alpha_chains), dtype=np.uint8),
                })
            })

        ids_array = np.array(ids)
        alpha_array = np.array(alpha_chains)
        beta_array = np.array(beta_chains)
        pairing_array = np.array(pairing) if pairing is not None else None

        def _subset(idx):
            # Build one split from the row positions in `idx`.
            return datasets.Dataset.from_dict({
                "id": ids_array[idx],
                "alpha_chains": alpha_array[idx],
                "beta_chains": beta_array[idx],
                "input_ids": tokenized["indices"][idx],
                "token_type_ids": tokenized["token_type_ids"][idx],
                "attention_mask": tokenized["mask"][idx],
                "pairing": pairing_array[idx] if pairing_array is not None else np.ones(len(idx), dtype=np.uint8),
            })

        train_idx, test_idx = train_test_split(list(range(len(ids))))
        return datasets.DatasetDict({"train": _subset(train_idx), "test": _subset(test_idx)})

    def __call__(self):
        """Not supported; use :meth:`convert_tokens_to_ids` or :meth:`to_dataset`."""
        raise NotImplementedError
class TCRabTokenizerForCDR123(AminoAcidTokenizer):
    """Deprecated tokenizer for paired TRA/TRB chains encoded as CDR1, CDR2 and CDR3.

    For each chain, CDR1 and CDR2 are looked up from the V gene (the
    ``TRAV2CDR1a``/``TRBV2CDR1b`` and ``TRAV2CDR2a``/``TRBV2CDR2b`` tables of the
    species annotations). They are joined with the CDR3 as
    ``CDR1 + '.' + CDR2 + '.' + CDR3`` before character-level encoding, so V genes
    are required (a missing V gene raises KeyError).
    """

    @staticmethod
    def _species_annotations(species):
        """Return the TCR annotation table for ``species`` ('human' or 'mouse')."""
        if species == 'human':
            return HumanTCRAnnotations
        if species == 'mouse':
            return MouseTCRAnnotations
        raise ValueError("species must be 'human' or 'mouse', got %r" % (species,))

    @deprecated(ymd=(2024, 6, 1))
    def __init__(
        self,
        tra_max_length: int,
        trb_max_length: int,
        pad_token: Optional[str] = None,
        unk_token: Optional[str] = None,
        mask_token: Optional[str] = None,
        cls_token: Optional[str] = None,
        sep_token: Optional[str] = None,
        species: Literal['human', 'mouse'] = 'human',
        **kwargs
    ) -> None:
        # Validate early. An unknown species previously failed in _encode with
        # UnboundLocalError.
        annotations = self._species_annotations(species)

        # Fall back to the default special tokens from _AMINO_ACIDS_ADDITIONALS.
        kwargs["pad_token"] = pad_token or _AMINO_ACIDS_ADDITIONALS["PAD"]
        kwargs["unk_token"] = unk_token or _AMINO_ACIDS_ADDITIONALS["UNK"]
        kwargs["mask_token"] = mask_token or _AMINO_ACIDS_ADDITIONALS["MASK"]
        kwargs["cls_token"] = cls_token or _AMINO_ACIDS_ADDITIONALS["CLS"]
        kwargs["sep_token"] = sep_token or _AMINO_ACIDS_ADDITIONALS["SEP"]
        super(TCRabTokenizerForCDR123, self).__init__(
            model_max_length=tra_max_length + trb_max_length, **kwargs
        )
        self.tra_max_length = tra_max_length
        self.trb_max_length = trb_max_length
        self.species = species
        # _decode falls back to VJ_GENES2INDEX_REVERSE for ids outside the
        # amino-acid vocabulary. It was never set before, which caused AttributeError.
        self.VJ_GENES2INDEX = annotations.VJ_GENES2INDEX
        self.VJ_GENES2INDEX_REVERSE = annotations.VJ_GENES2INDEX_REVERSE

    def _encode(
        self,
        aa: str,
        v_gene: Optional[str] = None,
        j_gene: Optional[str] = None,
        max_length: Optional[int] = None
    ) -> torch.Tensor:
        """Encode ``CLS CDR1 . CDR2 . CDR3 SEP PAD...`` for one chain into a 1-D float tensor.

        ``j_gene`` is accepted for signature compatibility but is not used.

        Raises
        ------
        KeyError
            If ``v_gene`` is not in the species' CDR1/CDR2 tables, or a character
            is not in ``_AMINO_ACIDS_INDEX``.
        """
        annotations = self._species_annotations(self.species)
        # Try the alpha-chain table first, then the beta-chain table.
        if v_gene in annotations.TRAV2CDR1a:
            cdr1 = annotations.TRAV2CDR1a[v_gene]
        else:
            cdr1 = annotations.TRBV2CDR1b[v_gene]
        if v_gene in annotations.TRAV2CDR2a:
            cdr2 = annotations.TRAV2CDR2a[v_gene]
        else:
            cdr2 = annotations.TRBV2CDR2b[v_gene]

        # NOTE(review): '.' is used as the CDR separator; confirm which token
        # index it maps to in _AMINO_ACIDS_INDEX.
        tokens = list(cdr1 + '.' + cdr2 + '.' + aa)
        cls_token = _AMINO_ACIDS_ADDITIONALS["CLS"]
        sep_token = _AMINO_ACIDS_ADDITIONALS["SEP"]
        if self.append_cls_token and (not tokens or tokens[0] != cls_token):
            tokens = [cls_token] + tokens
        if self.append_eos_token and (not tokens or tokens[-1] != sep_token):
            tokens.append(sep_token)

        max_length = max_length or self.model_max_length
        if len(tokens) < max_length:
            tokens += [self.pad_token] * (max_length - len(tokens))
        # Strict lookup: an unknown character raises a KeyError naming it,
        # instead of an opaque torch TypeError on None.
        return torch.Tensor([_AMINO_ACIDS_INDEX[token] for token in tokens])

    def convert_tokens_to_ids(
        self,
        sequence: Union[List[Tuple[str]], Tuple[str]],
        alpha_vj: Optional[Union[List[Tuple[str]], Tuple[str]]] = None,
        beta_vj: Optional[Union[List[Tuple[str]], Tuple[str]]] = None
    ):
        """Encode (CDR3a, CDR3b) pairs with their (V, J) genes.

        Returns ``indices`` (LongTensor, rows x (tra_max_length + trb_max_length)),
        ``mask`` and ``token_type_ids`` (0 = alpha positions, 1 = beta positions).
        """
        use_vj = alpha_vj is not None and beta_vj is not None
        if isinstance(sequence, list):
            if use_vj:
                alpha = torch.vstack([
                    self._encode(cdr3, vj[0], vj[1], max_length=self.tra_max_length)
                    for cdr3, vj in zip(get_k_elements(sequence, 0), alpha_vj)
                ])
                beta = torch.vstack([
                    self._encode(cdr3, vj[0], vj[1], max_length=self.trb_max_length)
                    for cdr3, vj in zip(get_k_elements(sequence, 1), beta_vj)
                ])
            else:
                alpha = torch.vstack([self._encode(pair[0], max_length=self.tra_max_length) for pair in sequence])
                beta = torch.vstack([self._encode(pair[1], max_length=self.trb_max_length) for pair in sequence])
            ids = torch.hstack([alpha, beta]).type(torch.LongTensor)
        else:
            # Single pair. The beta chain is padded to trb_max_length (it was
            # tra_max_length before), so it agrees with token_type_ids.
            if use_vj:
                alpha = self._encode(sequence[0], alpha_vj[0], alpha_vj[1], max_length=self.tra_max_length)
                beta = self._encode(sequence[1], beta_vj[0], beta_vj[1], max_length=self.trb_max_length)
            else:
                alpha = self._encode(sequence[0], max_length=self.tra_max_length)
                beta = self._encode(sequence[1], max_length=self.trb_max_length)
            ids = torch.hstack([alpha, beta]).unsqueeze(0).type(torch.LongTensor)

        mask = self.convert_ids_to_mask(ids)
        token_type_ids = torch.hstack([
            torch.zeros(ids.shape[0], self.tra_max_length),
            torch.ones(ids.shape[0], self.trb_max_length),
        ]).type(torch.LongTensor)
        return {"indices": ids, "mask": mask, "token_type_ids": token_type_ids}

    def _decode(self, ids):
        """Map indices to tokens (padding is kept). Ids beyond the amino-acid vocabulary are treated as genes."""
        offset = len(_AMINO_ACIDS_INDEX)
        return [
            _AMINO_ACIDS_INDEX_REVERSE[t] if t in _AMINO_ACIDS_INDEX_REVERSE
            else self.VJ_GENES2INDEX_REVERSE[t - offset]
            for t in ids
        ]

    def convert_ids_to_tokens(self, ids) -> Iterable[str]:
        """Delegate to :meth:`AminoAcidTokenizer.convert_ids_to_tokens` (which uses this class's ``_decode``)."""
        return super().convert_ids_to_tokens(ids)

    def convert_ids_to_mask(self, ids: torch.Tensor) -> torch.Tensor:
        """Delegate to :meth:`AminoAcidTokenizer.convert_ids_to_mask`."""
        return super().convert_ids_to_mask(ids)

    def to_dataset(
        self,
        df: Optional[pd.DataFrame] = None,
        ids: Iterable[str] = None,
        alpha_chains: Iterable[str] = None,
        beta_chains: Iterable[str] = None,
        alpha_v_genes: Iterable[str] = None,
        alpha_j_genes: Iterable[str] = None,
        beta_v_genes: Iterable[str] = None,
        beta_j_genes: Iterable[str] = None,
        pairing: Iterable[int] = None,
        split: bool = False
    ):
        """Build a ``datasets.DatasetDict`` from a DataFrame or from explicit columns.

        When ``df`` is given, its ``id``/``CDR3a``/``CDR3b`` columns override the
        explicit arguments. ``TRAV``/``TRAJ``/``TRBV``/``TRBJ`` are used only when all
        four are present, and ``pairing`` is used when present.
        """
        if df is not None:
            ids = df['id']
            alpha_chains = df['CDR3a']
            beta_chains = df['CDR3b']
            if all(column in df.columns for column in ('TRAV', 'TRAJ', 'TRBV', 'TRBJ')):
                alpha_v_genes = df['TRAV']
                alpha_j_genes = df['TRAJ']
                beta_v_genes = df['TRBV']
                beta_j_genes = df['TRBJ']
            if 'pairing' in df.columns:
                pairing = df['pairing']
        return self._to_dataset(
            ids, alpha_chains, beta_chains,
            alpha_v_genes, alpha_j_genes, beta_v_genes, beta_j_genes,
            pairing, split
        )

    def _to_dataset(
        self,
        ids: Iterable[str],
        alpha_chains: Iterable[str],
        beta_chains: Iterable[str],
        alpha_v_genes: Iterable[str] = None,
        alpha_j_genes: Iterable[str] = None,
        beta_v_genes: Iterable[str] = None,
        beta_j_genes: Iterable[str] = None,
        pairing: Iterable[int] = None,
        split: bool = False
    ) -> datasets.DatasetDict:
        """Tokenize the chains and assemble "train" (and, if ``split``, "test") datasets.

        Missing ``pairing`` defaults to 1 for every row.

        Raises
        ------
        ValueError
            If ``ids``, ``alpha_chains`` and ``beta_chains`` differ in length.
        """
        if not len(ids) == len(alpha_chains) == len(beta_chains):
            raise ValueError(
                "Length of ids(%d), alpha_chains(%d) and beta_chains(%d) do not match"
                % (len(ids), len(alpha_chains), len(beta_chains))
            )
        alpha_chains = list(alpha_chains)
        beta_chains = list(beta_chains)
        pairs = list(zip(alpha_chains, beta_chains))
        if all(genes is not None for genes in (alpha_v_genes, alpha_j_genes, beta_v_genes, beta_j_genes)):
            tokenized = self.convert_tokens_to_ids(
                pairs,
                list(zip(alpha_v_genes, alpha_j_genes)),
                list(zip(beta_v_genes, beta_j_genes)),
            )
        else:
            tokenized = self.convert_tokens_to_ids(pairs)

        if not split:
            return datasets.DatasetDict({
                "train": datasets.Dataset.from_dict({
                    "id": ids,
                    "alpha_chains": alpha_chains,
                    "beta_chains": beta_chains,
                    "input_ids": tokenized["indices"],
                    "token_type_ids": tokenized["token_type_ids"],
                    "attention_mask": tokenized["mask"],
                    "pairing": np.array(pairing) if pairing is not None else np.ones(len(alpha_chains), dtype=np.uint8),
                })
            })

        ids_array = np.array(ids)
        alpha_array = np.array(alpha_chains)
        beta_array = np.array(beta_chains)
        pairing_array = np.array(pairing) if pairing is not None else None

        def _subset(idx):
            # Build one split from the row positions in `idx`.
            return datasets.Dataset.from_dict({
                "id": ids_array[idx],
                "alpha_chains": alpha_array[idx],
                "beta_chains": beta_array[idx],
                "input_ids": tokenized["indices"][idx],
                "token_type_ids": tokenized["token_type_ids"][idx],
                "attention_mask": tokenized["mask"][idx],
                "pairing": pairing_array[idx] if pairing_array is not None else np.ones(len(idx), dtype=np.uint8),
            })

        train_idx, test_idx = train_test_split(list(range(len(ids))))
        return datasets.DatasetDict({"train": _subset(train_idx), "test": _subset(test_idx)})


# Pseudo-sequence tokenizer: a `tokenizers.Tokenizer` built from the
# tokenizer.json shipped next to this module.
TRabTokenizerForPseudoSequence = Tokenizer
# read_text opens and closes the file (the previous open(...).read() leaked the handle).
trab_tokenizer_for_pseudosequence = TRabTokenizerForPseudoSequence.from_str(
    (MODULE_PATH / "tokenizer.json").read_text(encoding="utf-8")
)
trab_tokenizer_for_pseudosequence.pad_token_id = 0
# NOTE(review): pad_token is the same string as bos_token ("<|bos|>"), and
# bos_token_id ends up as 3 (it used to be assigned 2 and then immediately 3).
# Both look like copy-paste slips. The resulting values are preserved here;
# confirm them against the vocabulary in tokenizer.json.
trab_tokenizer_for_pseudosequence.pad_token = "<|bos|>"
trab_tokenizer_for_pseudosequence.eos_token_id = 1
trab_tokenizer_for_pseudosequence.eos_token = "<|eos|>"
trab_tokenizer_for_pseudosequence.bos_token = "<|bos|>"
trab_tokenizer_for_pseudosequence.bos_token_id = 3

# Padded length of each CDR region in a TCR pseudo sequence.
default_max_length = {
    'CDR3a': 36,
    'CDR3b': 36,
    'CDR1b': 8,
    'CDR2b': 8,
    'CDR1a': 8,
    'CDR2a': 8,
}

# Order of the regions in the tokenized text:
# '^' + ':'.join([CDR1a, CDR2a, CDR3a, CDR1b, CDR2b, CDR3b]).
_PSEUDO_SEQUENCE_REGION_ORDER = ('CDR1a', 'CDR2a', 'CDR3a', 'CDR1b', 'CDR2b', 'CDR3b')


def tcr_pseudo_sequence_weight(
    cdr1a: float = 1.,
    cdr2a: float = 1.,
    cdr3a: float = 3.,
    cdr1b: float = 1.,
    cdr2b: float = 1.,
    cdr3b: float = 5.,
    max_length: Optional[dict] = None,
):
    """Build a per-position weight vector for a tokenized TCR pseudo sequence.

    The layout mirrors :func:`tokenize_tcr_pseudo_sequence_to_fixed_length`:
    position 0 is the leading '^', and each region (in
    ``_PSEUDO_SEQUENCE_REGION_ORDER``) is followed by one ':' separator, except
    the last. The '^' and separator positions keep weight 0. With the default
    ``max_length`` this gives the original 110-element layout (CDR1a at
    1:9, CDR2a 10:18, CDR3a 19:55, CDR1b 56:64, CDR2b 65:73, CDR3b 74:110).
    NOTE(review): this assumes the tokenizer emits one token per character.

    Parameters
    ----------
    cdr1a, cdr2a, cdr3a, cdr1b, cdr2b, cdr3b : float
        Weight assigned to every position of the corresponding region.
    max_length : dict, optional
        Region name -> padded length. Defaults to ``default_max_length``.

    Returns
    -------
    torch.Tensor
        1-D float tensor of per-position weights.
    """
    if max_length is None:
        max_length = default_max_length
    region_weights = {
        'CDR1a': cdr1a, 'CDR2a': cdr2a, 'CDR3a': cdr3a,
        'CDR1b': cdr1b, 'CDR2b': cdr2b, 'CDR3b': cdr3b,
    }
    n_regions = len(_PSEUDO_SEQUENCE_REGION_ORDER)
    total_length = 1 + sum(max_length[r] for r in _PSEUDO_SEQUENCE_REGION_ORDER) + (n_regions - 1)
    weight = torch.zeros(total_length)
    position = 1  # skip the leading '^'
    for region in _PSEUDO_SEQUENCE_REGION_ORDER:
        length = max_length[region]
        weight[position:position + length] = region_weights[region]
        position += length + 1  # skip the ':' separator after the region
    return weight


def _encode_prefixed_text(text):
    """Prefix ``text`` with '^', encode it, and return ``(token_string, ids, attention_mask)``."""
    token_string = '^' + text
    ids = np.array(trab_tokenizer_for_pseudosequence.encode(token_string).ids)
    # Ids below 6 are treated as non-attended (special) positions.
    # NOTE(review): confirm which tokens occupy ids 0-5 in tokenizer.json.
    return token_string, ids, ids >= 6


def tokenize_tcr_pseudo_sequence_to_fixed_length(
    pseudo_sequence,
    max_length = default_max_length,
    pad_token = '.',
    show_progress=False
):
    """Pad each CDR of a pseudo sequence to a fixed length and tokenize it.

    Input strings are ``'CDR1b:CDR2b:CDR3b:CDR1a:CDR2a:CDR3a'`` (beta first). The
    tokenized text puts alpha first: ``'^' + ':'.join([CDR1a, CDR2a, CDR3a, CDR1b,
    CDR2b, CDR3b])``, with each region right-padded with ``pad_token`` to
    ``max_length[region]`` (regions are not truncated).

    Parameters
    ----------
    pseudo_sequence : str or iterable of str
        One pseudo sequence or several.
    max_length : dict
        Region name -> padded length.
    pad_token : str
        Padding character.
    show_progress : bool
        Show a tqdm progress bar.

    Returns
    -------
    tuple
        ``(tokens, ids, attention_mask)``. These are single items for a str input
        and lists for an iterable input.

    Raises
    ------
    TypeError
        If ``pseudo_sequence`` is neither a str nor an iterable.
    """
    single = isinstance(pseudo_sequence, str)
    if single:
        pseudo_sequences = [pseudo_sequence]
    elif isinstance(pseudo_sequence, Iterable):
        pseudo_sequences = pseudo_sequence
    else:
        raise TypeError(
            "pseudo_sequence must be a str or an iterable of str, got %s" % type(pseudo_sequence).__name__
        )

    tokens, ids, attention_mask = [], [], []
    # Wrapping the iterable lets tqdm infer the total for sized inputs and also
    # works for generators.
    iterator = tqdm.tqdm(pseudo_sequences) if show_progress else pseudo_sequences
    for sequence in iterator:
        cdr1b, cdr2b, cdr3b, cdr1a, cdr2a, cdr3a = sequence.split(':')
        regions = {
            'CDR1a': cdr1a, 'CDR2a': cdr2a, 'CDR3a': cdr3a,
            'CDR1b': cdr1b, 'CDR2b': cdr2b, 'CDR3b': cdr3b,
        }
        padded = [
            regions[r] + pad_token * (max_length[r] - len(regions[r]))
            for r in _PSEUDO_SEQUENCE_REGION_ORDER
        ]
        token_string, token_ids, mask = _encode_prefixed_text(':'.join(padded))
        tokens.append(token_string)
        ids.append(token_ids)
        attention_mask.append(mask)

    if single:
        return tokens[0], ids[0], attention_mask[0]
    return tokens, ids, attention_mask


def tokenize_to_fixed_length(
    sequence,
    max_length,
    pad_token = '.',
):
    """Right-pad sequence(s) with ``pad_token`` to ``max_length`` and tokenize as ``'^' + sequence``.

    Returns ``(tokens, ids, attention_mask)``. These are single items for a str
    input and lists for an iterable input. Sequences are not truncated.

    Raises
    ------
    TypeError
        If ``sequence`` is neither a str nor an iterable.
    """
    single = isinstance(sequence, str)
    if single:
        sequences = [sequence]
    elif isinstance(sequence, Iterable):
        sequences = sequence
    else:
        raise TypeError(
            "sequence must be a str or an iterable of str, got %s" % type(sequence).__name__
        )

    tokens, ids, attention_mask = [], [], []
    for item in sequences:
        padded = item + pad_token * (max_length - len(item))
        token_string, token_ids, mask = _encode_prefixed_text(padded)
        tokens.append(token_string)
        ids.append(token_ids)
        attention_mask.append(mask)

    if single:
        return tokens[0], ids[0], attention_mask[0]
    return tokens, ids, attention_mask