import os
from typing import Dict, List, Optional, Union, Mapping
import pandas as pd
import scanpy as sc
import numpy as np
import faiss
from faiss import IndexFlatL2
from ._constants import (
FAISS_INDEX_BACKEND,
TDI_RESULT_FIELD
)
from ..utils._compat import Literal
from ..utils._logger import Colors, get_tqdm
from ..utils._utilities import FLATTEN
from ..utils._tcr_definitions import TRAB_DEFINITION
from ..utils._tcr import TCR
def check_is_parquet_serializable(obj: pd.DataFrame):
for c in obj.columns:
if obj[c].dtype == 'O':
obj[c] = obj[c].astype(str)
elif isinstance(obj[c].dtype, pd.CategoricalDtype):
obj[c] = obj[c].astype(str)
return obj
class TDICluster(list):
@property
def number_of_unique_tcrs(self):
return len(list(map(lambda x: x.to_tcr_string(), self)))
@property
def number_of_individuals(self):
return len(np.unique(list(map(lambda x: x.individual, self))))
def __repr__(self):
return f'<TDICluster at {hex(id(self))} with {len(self)} TCRs>'
[docs]
class TDIResult:
"""
A class to store the TDI clustering result
:param _data: the clustering result
:param _tcr_df: the tcr dataframe
:param _tcr_adata: the tcr anndata
:param _gex_adata: the gex anndata
:param _cluster_label: the cluster label
:param faiss_index: the faiss index
:param low_memory: whether to use low memory mode
"""
def __init__(
self,
_data: sc.AnnData,
_tcr_df: pd.DataFrame = None,
_tcr_adata: sc.AnnData = None,
_gex_adata: Optional[sc.AnnData] = None,
_cluster_label: Optional[str] = None,
faiss_index: Optional[IndexFlatL2] = None,
low_memory: bool = False
):
self._data = _data
self._data.obs['TCRab'] = list(map(TDICluster, self._data.obs['TCRab']))
self._tcr_df = _tcr_df
self._tcr_adata = _tcr_adata
self._gex_adata = _gex_adata
self._cluster_label = _cluster_label
self._data.uns["_cluster_label"] = _cluster_label
self.faiss_index = faiss_index
self.low_memory = low_memory
@property
def data(self):
return self._data
@property
def cluster_label(self):
return self._cluster_label
@property
def I(self):
return self._data.uns['I']
@property
def D(self):
return self._data.uns['D']
@property
def tcr_df(self):
return self._tcr_df
def __getattribute__(self, name):
if name in object.__getattribute__(self, '_data').obs.columns:
return object.__getattribute__(self, '_data').obs[name].to_numpy()
else:
return super().__getattribute__(name)
[docs]
def select(self, indices):
return TDIResult(
self._data[indices],
self._tcr_df,
self._tcr_adata,
self._gex_adata,
self._cluster_label,
self.faiss_index,
self.low_memory
)
@cluster_label.setter
def cluster_label(self, cluster_label: str):
assert(cluster_label in self._data.obs.columns)
self._cluster_label = cluster_label
@property
def tcr_adata(self):
return self._tcr_adata
@tcr_adata.setter
def tcr_adata(self, tcr_adata: sc.AnnData):
self._tcr_adata = tcr_adata
@property
def gex_adata(self):
return self._gex_adata
@gex_adata.setter
def gex_adata(self, gex_adata: sc.AnnData):
self._gex_adata = gex_adata
def __repr__(self) -> str:
base_string = f'{Colors.GREEN}TDIResult{Colors.NC} object containing {Colors.CYAN}{self.data.shape[0]}{Colors.NC} clusters\n'
if self._cluster_label is not None:
base_string += f' Cluster label: {self._cluster_label}\n'
if self._gex_adata is not None:
base_string += f' GEX data: \n' +\
' ' + self._gex_adata.__repr__().replace('\n', '\n ')
if self._tcr_adata is not None:
base_string += f' TCR data: \n' +\
' ' + self._tcr_adata.__repr__().replace('\n', '\n ')
return base_string
[docs]
def save_to_disk(self,
save_path,
save_cluster_result_as_csv=True,
save_tcr_data=True,
save_gex_data=True
):
"""
Save the cluster result to disk
:param save_path: the path to save the cluster result
save_cluster_result_as_csv: whether to save the cluster result as csv files
:param save_tcr_data: whether to save the tcr data
:param save_gex_data: whether to save the gex data
"""
if not os.path.exists(save_path):
os.mkdir(save_path)
column0 = None
if isinstance(self._data.obs['TCRab'].iloc[0], list):
column0 = list(self._data.obs["TCRab"])
self._data.obs["TCRab"] = list(
map(
lambda x: ",".join(list(map(lambda x: x.to_string(), x))),
column0,
)
)
self._data.write_h5ad(os.path.join(save_path, 'cluster_data.h5ad'))
if column0 is not None:
self._data.obs["TCRab"] = column0
if self._tcr_df is not None:
self._tcr_df = check_is_parquet_serializable(self._tcr_df)
self._tcr_df.to_parquet(os.path.join(save_path, 'tcr_data.parquet'))
if save_tcr_data and self._tcr_adata is not None:
self._tcr_adata.write_h5ad(os.path.join(save_path, 'tcr_data.h5ad'))
if save_gex_data and self._gex_adata is not None:
self._gex_adata.write_h5ad(os.path.join(save_path, 'gex_data.h5ad'))
if self.faiss_index is not None:
try:
faiss.write_index(self.faiss_index, os.path.join(save_path, 'faiss_index.faiss'))
except:
pass
# print('Failed to save faiss index')
if save_cluster_result_as_csv:
cluster_tcr = self.to_pandas_dataframe_tcr()
cluster_index = self.to_pandas_dataframe_cluster_index()
cluster_tcr.to_csv(os.path.join(save_path, "cluster_tcr.csv"), index=False)
cluster_index.to_csv(
os.path.join(save_path, "cluster_index.csv"), index=False
)
[docs]
@classmethod
def load_from_disk(
cls,
save_path: str,
tcr_data_path: Optional[str] = None,
gex_adata_path: Optional[str] = None
):
"""
Load the cluster result from disk
:param save_path: the path to load the cluster result
:param tcr_data_path: the path to load the tcr data
:param gex_adata_path: the path to load the gex data
"""
data = sc.read_h5ad(os.path.join(save_path, 'cluster_data.h5ad'))
if type(data.obs['TCRab'].iloc[0]) == str:
data.obs["TCRab"] = list(
map(
lambda x: list(
map(lambda z: TCR.from_string(z), filter(lambda tcr: tcr != '-', x.split(",")))
),
data.obs["TCRab"],
)
)
tcr_df = None
if os.path.exists(os.path.join(save_path, 'tcr_data.parquet')):
tcr_df = pd.read_parquet(os.path.join(save_path, 'tcr_data.parquet'))
if tcr_data_path is not None:
print('Loading tcr data from {}'.format(tcr_data_path))
tcr_adata = sc.read_h5ad(tcr_data_path)
else:
if os.path.exists(os.path.join(save_path, 'tcr_data.h5ad')):
tcr_adata = sc.read_h5ad(os.path.join(save_path, 'tcr_data.h5ad'))
else:
tcr_adata = None
if gex_adata_path is not None:
print('Loading gex data from {}'.format(gex_adata_path))
gex_adata = sc.read_h5ad(gex_adata_path)
else:
if os.path.exists(os.path.join(save_path, 'gex_data.h5ad')):
gex_adata = sc.read_h5ad(os.path.join(save_path, 'gex_data.h5ad'))
else:
gex_adata = None
cluster_label = None
if '_cluster_label' in data.uns:
cluster_label = data.uns['_cluster_label']
faiss_index = None
if os.path.exists(os.path.join(save_path, 'faiss_index.faiss')):
try:
faiss_index = faiss.read_index(os.path.join(save_path, 'faiss_index.faiss'))
except:
pass
return cls(data, tcr_df, tcr_adata, gex_adata, cluster_label, faiss_index=faiss_index)
[docs]
def get_tcrs_for_cluster(
self,
label: Optional[Union[Mapping[str, str], Mapping[str, List[str]]]] = None,
rank: int = 0,
rank_by: Literal["convergence", "disease_association"] = 'convergence',
min_unique_tcr_number: int = 4,
min_individual_number: int = 2,
min_cell_number: int = 10,
min_tcr_convergence_score: Optional[float] = None,
min_disease_association_score: Optional[float] = None,
return_background_tcrs: bool = False,
additional_label_key_values: Optional[Dict[str, List[str]]] = None
):
"""
Get the tcrs for a specific cluster
:param label: the cluster label
:param rank: the rank of the tcrs to return
:param rank_by: the metric to rank the tcrs
:param min_unique_tcr_number: the minimum number of unique tcrs in the cluster
:param min_individual_number: the minimum number of individuals in the cluster
:param min_cell_number: the minimum number of cells in the cluster
:param min_tcr_convergence_score: the minimum convergence score
:param min_disease_association_score: the minimum disease specificity score
:param return_background_tcrs: whether to return other tcrs in the cluster
:param additional_label_key_values: additional label key values to filter the cluster
:return: a dictionary containing the tcrs and their metadata
"""
return self._get_tcrs_for_cluster(
label,
rank_by,
rank,
min_unique_tcr_number=min_unique_tcr_number,
min_individual_number=min_individual_number,
min_cell_number=min_cell_number,
min_tcr_convergence_score = min_tcr_convergence_score,
min_disease_association_score = min_disease_association_score,
return_background_tcrs=return_background_tcrs,
additional_label_key_values=additional_label_key_values
)
[docs]
def to_pandas_dataframe_tcr(
self,
rank_by: Literal["convergence", "disease_association", "individual", "unique_tcr"] = 'convergence',
return_background_tcrs: bool = False
):
"""
Convert the cluster result to a pandas dataframe.
:param rank_by: the metric to rank the tcrs
:param return_background_tcrs: whether to return background tcrs for each cluster
:return: a pandas dataframe containing the cluster result
"""
ret = []
cluster_indices = []
cluster_labels = []
if rank_by == 'convergence':
self.data.obs.sort_values(TDI_RESULT_FIELD.CONVERGENCE.value, ascending=False, inplace=True)
elif rank_by == 'disease_association':
self.data.obs.sort_values(TDI_RESULT_FIELD.DISEASE_ASSOCIATION.value, ascending=True, inplace=True)
elif rank_by == 'individual':
self.data.obs.sort_values(TDI_RESULT_FIELD.NUMBER_OF_INDIVIDUAL.value, ascending=False, inplace=True)
elif rank_by == 'unique_tcr':
self.data.obs.sort_values(TDI_RESULT_FIELD.NUMBER_OF_UNIQUE_TCR.value, ascending=False, inplace=True)
if self.cluster_label in self.data.obs.columns:
for i, l, j in zip(
self.data.obs.iloc[:, 0],
self.data.obs[self.cluster_label],
self.data.obs["cluster_index"],
):
if isinstance(i, str):
ret.append(
list(
map(
lambda x: x.split("=")[:7] + ['foreground'],
filter(lambda x: x != "-", i.split(",")),
)
)
)
elif isinstance(i, list):
ret.append(
list(
map(
lambda x: (
x.cdr3a,
x.cdr3b,
x.trav,
x.traj,
x.trbv,
x.trbj,
x.individual,
'foreground'
),
i,
)
)
)
cluster_indices.append([j] * len(ret[-1]))
cluster_labels.append([l] * len(ret[-1]))
size = len(ret[-1])
if return_background_tcrs:
cluster_labels.append([])
ret.append([])
tcr_clusters = self.get_tcrs_by_cluster_index(j)
c = 0
for i, background_l in zip(tcr_clusters['tcrs'], tcr_clusters['labels']):
if background_l != l:
ret[-1].append(i.split("=")[:7] + ['background'])
c += 1
cluster_labels[-1].append(background_l)
if c == size:
break
cluster_indices.append([j] * c)
df = pd.DataFrame(FLATTEN(ret), columns=TRAB_DEFINITION + ["individual", "type"])
df["cluster_index"] = FLATTEN(cluster_indices)
df["cluster_label"] = FLATTEN(cluster_labels)
else:
for i, j in zip(self.data.obs.iloc[:, 0], self.data.obs["cluster_index"]):
if isinstance(i, str):
ret.append(
list(
map(
lambda x: x.split("=")[:7],
filter(lambda x: x != "-", i.split(",")),
)
)
)
elif isinstance(i, list):
ret.append(
list(
map(
lambda x: (
x.cdr3a,
x.cdr3b,
x.trav,
x.traj,
x.trbv,
x.trbj,
x.individual,
),
i,
)
)
)
cluster_indices.append([j] * len(ret[-1]))
df = pd.DataFrame(FLATTEN(ret), columns=TRAB_DEFINITION + ["individual"])
df["cluster_index"] = FLATTEN(cluster_indices)
return df
[docs]
def to_pandas_dataframe_cluster_index(self):
return self.data.obs.loc[
:,
["cluster_index"]
+ list(
filter(
lambda x: x not in ["cluster_index", "TCRab"], self.data.obs.columns
)
),
]
def _get_tcrs_for_cluster(
self,
label: Optional[str] = None,
rank_by: Literal["convergence", "disease_association", "individual", "unique_tcr"] = 'convergence',
rank: int = 0,
min_unique_tcr_number: int = 4,
min_individual_number: int = 2,
min_cell_number: int = 10,
min_tcr_convergence_score: Optional[float] = None,
min_disease_association_score: Optional[float] = None,
return_background_tcrs: bool = False,
additional_label_key_values: Optional[Dict[str, List[str]]] = None
):
if rank_by not in ['convergence', 'disease_association', "individual", "unique_tcr"]:
raise ValueError('rank_by must be one of ["convergence", "disease_association"], got {}'.format(rank_by))
min_tcr_convergence_score = min_tcr_convergence_score if min_tcr_convergence_score is not None else -1
min_disease_association_score = min_disease_association_score if min_disease_association_score is not None else -1
if additional_label_key_values is not None:
additional_indices = np.bitwise_and.reduce(
[
np.array(self.data.obs[key] == value) for key, value in additional_label_key_values.items()
]
)
else:
additional_indices = np.ones(len(self.data.obs), dtype=bool)
if label is not None:
label_criteria = np.bitwise_and.reduce([
np.array(self.data.obs[k] == v)
if type(v) == str
else np.array(self.data.obs[k].isin(v)) for k, v in label.items()])
result_tcr = self.data.obs[
label_criteria &
np.array(self.data.obs[TDI_RESULT_FIELD.NUMBER_OF_UNIQUE_TCR] >= min_unique_tcr_number) &
np.array(self.data.obs[TDI_RESULT_FIELD.NUMBER_OF_INDIVIDUAL] >= min_individual_number) &
np.array(self.data.obs[TDI_RESULT_FIELD.CONVERGENCE] >= min_tcr_convergence_score) &
np.array(self.data.obs[TDI_RESULT_FIELD.DISEASE_ASSOCIATION] >= min_disease_association_score) &
additional_indices
]
else:
result_tcr = self.data.obs[
np.array(self.data.obs[TDI_RESULT_FIELD.NUMBER_OF_UNIQUE_TCR] >= min_unique_tcr_number) &
np.array(self.data.obs[TDI_RESULT_FIELD.NUMBER_OF_INDIVIDUAL] >= min_individual_number) &
np.array(self.data.obs[TDI_RESULT_FIELD.CONVERGENCE] >= min_tcr_convergence_score) &
np.array(self.data.obs[TDI_RESULT_FIELD.DISEASE_ASSOCIATION] >= min_disease_association_score) &
additional_indices
]
if TDI_RESULT_FIELD.NUMBER_OF_CELL in result_tcr.columns:
result_tcr = result_tcr[
np.array(result_tcr[TDI_RESULT_FIELD.NUMBER_OF_CELL] >= min_cell_number)
]
if rank_by == 'convergence':
result = result_tcr.sort_values(
TDI_RESULT_FIELD.CONVERGENCE,
ascending=False
)
elif rank_by == 'disease_association':
result = result_tcr.sort_values(
'disease_association_score',
ascending=False
)
elif rank_by == 'individual':
result = result_tcr.sort_values(
'number_of_individuals',
ascending=False
)
elif rank_by == 'unique_tcr':
result = result_tcr.sort_values(
'number_of_unique_tcrs',
ascending=False
)
tcrs = result.iloc[rank,0]
cdr3a, cdr3b, trav, traj, trbv, trbj, individual = list(map(
list,
np.array(list(map(lambda x: (x.cdr3a, x.cdr3b, x.trav, x.traj, x.trbv, x.trbj, x.individual), tcrs))).T
))
if return_background_tcrs:
return {
'cluster_index': result['cluster_index'][rank],
'tcrs': tcrs,
'cdr3a': cdr3a,
'cdr3b': cdr3b,
'trav': trav,
'traj': traj,
'trbv': trbv,
'trbj': trbj,
'individual': individual
}, self.get_tcrs_by_cluster_index(result['cluster_index'][rank], len(tcrs))
else:
return {
'cluster_index': result['cluster_index'][rank],
'tcrs': tcrs,
'cdr3a': cdr3a,
'cdr3b': cdr3b,
'trav': trav,
'traj': traj,
'trbv': trbv,
'trbj': trbj,
'individual': individual
}
def _get_tcrs_gex_embedding_coordinates(self, use_rep: str = 'X_umap') -> Dict[str, np.ndarray]:
"""
Get the gex embedding coordinates for each tcr
:param use_rep: the representation to use. Default: 'X_umap'
:return: a dictionary containing the gex embedding coordinates for each tcr
"""
result_tcr = self.data.obs
coordinates = {}
tcrs2int = dict(zip(self.gex_adata.obs['tcr'], range(len(self.gex_adata))))
pbar = get_tqdm()(total=len(result_tcr), bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}')
for i in range(len(result_tcr)):
tcrs = list(result_tcr.iloc[i,0].split(","))
coordinates[result_tcr.index[i]] = self.gex_adata.obsm[use_rep][
list(map(tcrs2int.get, list(filter(lambda x: x != '-', tcrs))))
]
pbar.update(1)
pbar.close()
return coordinates
[docs]
def get_tcrs_by_cluster_index(
self,
cluster_index: int,
_n_after: int = 0
) -> List[str]:
"""
Get the tcrs for a specific cluster
:param cluster_index: the cluster index
:param _n_after: the number of tcrs to skip
"""
_result = self.tcr_df.iloc[
self.I[int(cluster_index)]
]
cdr3a = list(_result['CDR3a'])[_n_after:]
cdr3b = list(_result['CDR3b'])[_n_after:]
trav = list(_result['TRAV'])[_n_after:]
traj = list(_result['TRAJ'])[_n_after:]
trbv = list(_result['TRBV'])[_n_after:]
trbj = list(_result['TRBJ'])[_n_after:]
tcrs = list(_result['tcr'])[_n_after:]
labels = None
if self.cluster_label:
labels = list(_result[self.cluster_label])[_n_after:]
individual = list(_result['individual'])[_n_after:]
return {
'cluster_index': cluster_index,
'tcrs': tcrs,
'cdr3a': cdr3a,
'cdr3b': cdr3b,
'trav': trav,
'traj': traj,
'trbv': trbv,
'trbj': trbj,
'individual': individual,
'labels': labels
}