Training a VAE on Gene Expression (GEX) data

This tutorial shows how to train a variational autoencoder (VAE) on GEX data. We use the scAtlasVAE package to integrate multi-batch GEX data, with the human huARdb v2 reference dataset as the example.

Load the reference dataset

import tcr_deep_insight as tdi
import torch

gex_reference_adata = tdi.data.human_gex_reference_v2()
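
Before training, it can help to confirm that the columns later passed as `batch_key` are actually present in the dataset's cell annotations. The following is a minimal sketch of such a check. It assumes the loaded reference is an AnnData object whose `.obs` table holds the `study_name` and `sample_name` columns; `missing_batch_keys` is a hypothetical helper (not part of tcr_deep_insight), and `cell_type` and `donor_id` are made-up column names used only for illustration.

```python
from typing import Iterable, List


def missing_batch_keys(obs_columns: Iterable[str], batch_keys: Iterable[str]) -> List[str]:
    """Return the batch keys that are absent from the annotation columns."""
    available = set(obs_columns)
    return [key for key in batch_keys if key not in available]


# Stand-in column names. With the real data, pass gex_reference_adata.obs.columns
# instead (assumption: the reference is an AnnData object).
example_columns = ["study_name", "sample_name", "cell_type"]

print(missing_batch_keys(example_columns, ["study_name", "sample_name"]))
print(missing_batch_keys(example_columns, ["study_name", "donor_id"]))
```

An empty list means every requested batch key was found; any names returned would need to be added or corrected before constructing the model.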

Construct and train the GEX model

# GEXModelingVAE is an alias of scatlasvae.model.scAtlasVAE

model = tdi.model.GEXModelingVAE(
    gex_reference_adata,
    batch_key=['study_name', 'sample_name'],
    n_latent=10,
    batch_hidden_dim=24
)

model.fit()

Extract the GEX embedding and save the trained model

gex_reference_adata.obsm['X_gex'] = model.get_latent_representation()

torch.save(
    model.state_dict(),
    "/PATH/TO//tcr_deep_insight/data/pretrained_weights/human_scatlasvae_gex_v2.ckpt"
)
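
To reuse a checkpoint later, the saved state dict can be loaded back into a model built with the same architecture. The sketch below shows the generic PyTorch save and load round trip on a small stand-in module. It assumes that `GEXModelingVAE` behaves like a regular `torch.nn.Module`, which the `state_dict()` call above suggests; the `nn.Linear` stand-in and the temporary file path are illustrative only.

```python
import os
import tempfile

import torch
from torch import nn

# Stand-in for the trained model; any nn.Module is saved the same way.
trained = nn.Linear(4, 2)

with tempfile.TemporaryDirectory() as tmp_dir:
    ckpt_path = os.path.join(tmp_dir, "example.ckpt")  # illustrative path
    torch.save(trained.state_dict(), ckpt_path)

    # Rebuild a module with the same architecture, then restore its weights.
    restored = nn.Linear(4, 2)
    state = torch.load(ckpt_path, map_location="cpu")
    restored.load_state_dict(state)

# Compare every saved tensor with its restored counterpart.
weights_match = all(
    torch.equal(a, b)
    for a, b in zip(trained.state_dict().values(), restored.state_dict().values())
)
print(weights_match)
```

With the real model, the equivalent would presumably be to construct `tdi.model.GEXModelingVAE` with the same arguments used for training and call `load_state_dict` on it with the loaded checkpoint.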

For more details on scAtlasVAE, please refer to the scAtlasVAE documentation.

Note

The trained model is available on Zenodo.

Clustering and UMAP visualization

Downstream analysis can be performed with the standard scanpy workflow.

import scanpy as sc

sc.pp.neighbors(gex_reference_adata, use_rep='X_gex', n_neighbors=15)
sc.tl.umap(gex_reference_adata)
sc.tl.leiden(gex_reference_adata)

sc.pl.umap(gex_reference_adata, color='leiden')
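
One quick sanity check on the integration is whether each Leiden cluster draws cells from several studies rather than from a single one. The sketch below computes per-cluster study fractions from two label sequences using only the standard library. `cluster_composition` is a hypothetical helper, and the labels are made-up stand-ins for `gex_reference_adata.obs['leiden']` and `gex_reference_adata.obs['study_name']`.

```python
from collections import Counter, defaultdict
from typing import Dict, Sequence


def cluster_composition(
    clusters: Sequence[str], batches: Sequence[str]
) -> Dict[str, Dict[str, float]]:
    """For each cluster, return the fraction of its cells coming from each batch."""
    if len(clusters) != len(batches):
        raise ValueError("clusters and batches must have the same length")
    counts: Dict[str, Counter] = defaultdict(Counter)
    for cluster, batch in zip(clusters, batches):
        counts[cluster][batch] += 1
    return {
        cluster: {batch: n / sum(counter.values()) for batch, n in counter.items()}
        for cluster, counter in counts.items()
    }


# Made-up labels for six cells; replace with the real obs columns.
clusters = ["0", "0", "0", "1", "1", "1"]
batches = ["study_A", "study_B", "study_A", "study_B", "study_B", "study_B"]

composition = cluster_composition(clusters, batches)
for cluster, fractions in composition.items():
    print(cluster, fractions)
```

A cluster made up almost entirely of one study may point to a residual batch effect, though it can also reflect a cell population that only that study sampled, so the fractions are a prompt for closer inspection rather than a verdict.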