"""
TCR embedding functions using various models.
"""
# ruff: noqa: N806
# Please order alphabetically by function name.
import logging
import math
import time
from typing import Optional
import pandas as pd
import torch
from torch.nn.functional import pad
from amulety.utils import batch_loader
# Optional imports for TCR models are handled within individual functions
# Module-level logger used by the functions in this file to report model
# loading, batch progress, timings and missing optional dependencies.
logger = logging.getLogger(__name__)
[docs]
def check_tcr_dependencies():
    """Check which optional TCR embedding dependencies are installed.

    Probes the optional packages needed by each TCR model. It logs which models
    are usable and logs an installation hint for every model that is not.

    Returns:
        list[tuple[str, str]]: One ``(model_name, install_instructions)`` pair per
        model whose dependencies are missing. Empty when everything is available.
    """
    missing_deps = []
    available_models = []

    # TCR-BERT needs the BERT classes from transformers.
    try:
        from transformers import BertModel, BertTokenizer  # noqa: F401
    except ImportError:
        missing_deps.append(("TCR-BERT", "pip install transformers"))
    else:
        available_models.append("TCR-BERT")

    # TCRT5 needs the T5 classes from transformers.
    try:
        from transformers import T5ForConditionalGeneration, T5Tokenizer  # noqa: F401
    except ImportError:
        missing_deps.append(("TCRT5", "pip install transformers"))
    else:
        available_models.append("TCRT5")

    # Immune2Vec lives in protein_embeddings but is used for TCRs too. It needs
    # gensim plus the immune2vec repository's `embedding` package on the Python path.
    immune2vec_repo = "https://bitbucket.org/yaarilab/immune2vec_model.git"
    # The version specifier is quoted on purpose. Unquoted, a POSIX shell parses
    # `>=3.8.3` as an output redirection and installs an unpinned gensim.
    gensim_hint = f"pip install 'gensim>=3.8.3' && git clone {immune2vec_repo}"
    try:
        # Probe gensim on its own, so a missing or broken gensim is never
        # mistaken for a missing immune2vec checkout.
        import gensim  # noqa: F401
    except ImportError:
        missing_deps.append(("Immune2Vec", gensim_hint))
    else:
        try:
            from embedding import sequence_modeling  # noqa: F401
        except ImportError as e:
            # An error mentioning gensim means immune2vec was found but cannot use
            # the installed gensim. Otherwise the repository itself is not importable.
            if "gensim" in str(e):
                hint = gensim_hint
            else:
                hint = f"git clone {immune2vec_repo} && add to Python path"
            missing_deps.append(("Immune2Vec", hint))
        else:
            available_models.append("Immune2Vec")

    # Report results
    if available_models:
        logger.info("Available TCR models: %s", ", ".join(available_models))
    if missing_deps:
        logger.warning("Missing TCR model dependencies: %s", ", ".join(name for name, _ in missing_deps))
        for name, hint in missing_deps:
            logger.info("To enable %s: %s", name, hint)
    else:
        logger.info("All TCR embedding dependencies are available!")
    return missing_deps
[docs]
def tcr_bert(
    sequences,
    cache_dir: Optional[str] = None,
    batch_size: int = 32,
    residue_level: bool = False,
):
    """
    Embeds T-Cell Receptor (TCR) sequences using the TCR-BERT model.

    Args:
        sequences: Input TCR sequences (pd.Series for single chain or pd.DataFrame for H+L mode).
            A DataFrame must have a 'chain' column or one of 'sequence_vdj_aa',
            'sequence_aa' or 'sequence'. Paired chains joined with "<cls><cls>" are
            accepted; the separator is removed before tokenization.
        cache_dir: Directory to cache model files
        batch_size: Number of sequences to process in each batch
        residue_level: If True, return one embedding per token position instead of a
            single mean-pooled embedding per sequence.

    Returns:
        torch.Tensor (on CPU). Default shape is ``(n_sequences, hidden_size)``. Each row is
        the mean of the last hidden states over all non-padding positions, including
        [CLS] and [SEP]. With ``residue_level=True`` the shape is
        ``(n_sequences, 64, hidden_size)``, and padding positions are zero.
        ``hidden_size`` is read from the loaded model's config.

    Raises:
        ValueError: If a DataFrame has no recognised sequence column.
        ImportError: If the transformers library is not installed.
        RuntimeError: If neither TCR-BERT nor the BERT-base fallback can be loaded.

    Note:
        Pretrained on 88,403 human TRA/TRB sequences from VDJdb and PIRD.
        Non-fine-tuned version focused on human TCR data only. The maximum length of the
        sequences to be embedded is 64 tokens, including the [CLS] and [SEP] tokens.
        If TCR-BERT cannot be loaded, ``bert-base-uncased`` is used instead and a
        warning is logged. Its embeddings are not TCR-specific.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    max_seq_length = 64
    separator = "<cls><cls>"

    # Handle both Series (single chain) and DataFrame (H+L) inputs
    if isinstance(sequences, pd.DataFrame):
        if "chain" in sequences.columns:
            X = sequences["chain"]
        else:
            # Fallback: look for common sequence column names
            sequence_col_candidates = ["sequence_vdj_aa", "sequence_aa", "sequence"]
            sequence_col = next((col for col in sequence_col_candidates if col in sequences.columns), None)
            if sequence_col is None:
                raise ValueError(
                    f"No sequence column found in DataFrame. Expected 'chain' or one of {sequence_col_candidates}"
                )
            X = sequences[sequence_col]
    else:
        # Series mode: direct sequence input
        X = sequences

    # Remove the paired-chain separator BEFORE truncating. Truncating first can cut
    # the separator in half, and fragments such as "<cl" would become unknown tokens.
    # TCR-BERT then expects space-separated residues:
    # "CASSLAPGATNEKLFF" -> "C A S S L A P G A T N E K L F F".
    X = X.apply(lambda seq: " ".join(seq.replace(separator, " ")[:max_seq_length]))
    sequences = X.values

    logger.info("Loading TCR-BERT model for TCR embedding...")
    try:
        from transformers import BertModel, BertTokenizer
    except ImportError as e:
        # Report a missing library directly. Otherwise the fallback below would fail on
        # an unbound name and hide the real cause.
        logger.error("transformers library not found: %s", str(e))
        raise ImportError("transformers library is required for TCR-BERT. Install with: pip install transformers") from e
    try:
        # non-fine-tuned TCR-BERT model pre-trained on human data only
        model_name = "wukevin/tcr-bert"
        tokenizer = BertTokenizer.from_pretrained(model_name, cache_dir=cache_dir, do_lower_case=False)
        model = BertModel.from_pretrained(model_name, cache_dir=cache_dir)
        logger.info("Successfully loaded TCR-BERT model")
    except Exception as e:
        logger.warning("TCR-BERT model not available, using BERT-base as fallback: %s", str(e))
        try:
            tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", cache_dir=cache_dir)
            model = BertModel.from_pretrained("bert-base-uncased", cache_dir=cache_dir)
            logger.info("Using BERT-base as fallback for TCR-BERT")
        except Exception as e2:
            logger.error("Failed to load any BERT model: %s", str(e2))
            raise RuntimeError("Could not load TCR-BERT or fallback model") from e2
    model = model.to(device)
    model.eval()  # inference only: make sure dropout is disabled
    dim = model.config.hidden_size
    model_size = sum(p.numel() for p in model.parameters())
    logger.info("TCR-BERT model loaded. Size: %s M", round(model_size / 1e6, 2))

    start_time = time.time()
    n_seqs = len(sequences)
    n_batches = math.ceil(n_seqs / batch_size)
    shape = (n_seqs, max_seq_length, dim) if residue_level else (n_seqs, dim)
    embeddings = torch.empty(shape)
    for batch_idx, (start, end, batch) in enumerate(batch_loader(sequences, batch_size), start=1):
        logger.info("TCR-BERT Batch %s/%s.", batch_idx, n_batches)
        encoded = tokenizer(
            list(batch),
            padding="max_length",
            truncation=True,
            max_length=max_seq_length,
            return_tensors="pt",
        )
        input_ids = encoded["input_ids"].to(device)
        attention_mask = encoded["attention_mask"].to(device)
        with torch.no_grad():
            hidden = model(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
        # Zero out padding positions, so residue-level output has zeros there and
        # the mean below averages only real tokens.
        mask = attention_mask.unsqueeze(-1).to(hidden.dtype)
        hidden = hidden * mask
        if residue_level:
            batch_embeddings = hidden
        else:
            batch_embeddings = hidden.sum(dim=1) / mask.sum(dim=1)
        embeddings[start:end] = batch_embeddings.cpu()
        # Release this batch's device tensors before the next forward pass to keep peak memory down.
        del input_ids, attention_mask, hidden, mask, batch_embeddings
    end_time = time.time()
    logger.info("TCR-BERT embedding took %s seconds", round(end_time - start_time, 2))
    return embeddings
[docs]
def tcrt5(
    sequences,
    cache_dir: Optional[str] = None,
    batch_size: int = 32,
    residue_level: bool = False,
):
    """
    Embeds T-Cell Receptor (TCR) sequences using the TCRT5 model.

    Args:
        sequences: Input TCR sequences (pd.Series for single chain or pd.DataFrame for H+L mode).
            A DataFrame must have a 'chain' column or one of 'sequence_vdj_aa',
            'sequence_aa' or 'sequence'.
        cache_dir: Directory to cache model files
        batch_size: Number of sequences to process in each batch
        residue_level: If True, return one encoder embedding per token position instead
            of a single mean-pooled embedding per sequence.

    Returns:
        torch.Tensor (on CPU). Default shape is ``(n_sequences, d_model)``. Each row is
        the mean of the encoder's last hidden states over non-padding tokens. With
        ``residue_level=True`` the shape is ``(n_sequences, 20, d_model)``, and padding
        positions are zero. ``d_model`` is read from the loaded model's config
        (256 for TCRT5).

    Raises:
        ValueError: If a DataFrame has no recognised sequence column.
        ImportError: If transformers, or a backend needed to load the model or
            tokenizer, is missing.
        RuntimeError: If the TCRT5 model cannot be loaded for any other reason.

    Note:
        TCRT5 was pre-trained on masked span reconstruction using ~14M CDR3 β sequences from TCRdb
        and ~780k peptide-pseudosequence pairs from IEDB. This model only supports beta chains (H chains for TCR).
        Sequences are truncated to 20 characters before tokenization, and tokenization
        is capped at 20 tokens.
        Embedding dimension: 256.
        Reference: https://huggingface.co/dkarthikeyan1/tcrt5_pre_tcrdb
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    max_seq_length = 20

    # Handle both Series (single chain) and DataFrame (H+L) inputs
    if isinstance(sequences, pd.DataFrame):
        if "chain" in sequences.columns:
            X = sequences["chain"]
        else:
            # Fallback: look for common sequence column names
            sequence_col_candidates = ["sequence_vdj_aa", "sequence_aa", "sequence"]
            sequence_col = next((col for col in sequence_col_candidates if col in sequences.columns), None)
            if sequence_col is None:
                raise ValueError(
                    f"No sequence column found in DataFrame. Expected 'chain' or one of {sequence_col_candidates}"
                )
            X = sequences[sequence_col]
    else:
        # Series mode: direct sequence input
        X = sequences
    X = X.apply(lambda a: a[:max_seq_length])
    sequences = X.values

    logger.info("Loading TCRT5 model for TCR embedding...")
    try:
        from transformers import T5ForConditionalGeneration, T5Tokenizer
    except ImportError as e:
        logger.error("transformers library not found: %s", str(e))
        raise ImportError("transformers library is required for TCRT5. Install with: pip install transformers") from e
    model_name = "dkarthikeyan1/tcrt5_pre_tcrdb"
    try:
        tokenizer = T5Tokenizer.from_pretrained(model_name, cache_dir=cache_dir)
        model = T5ForConditionalGeneration.from_pretrained(model_name, cache_dir=cache_dir)
        model.to(device)
        model.eval()
    except ImportError as e:
        # transformers is installed, but a backend needed at load time is not. For
        # example, T5Tokenizer is SentencePiece-based. Surface the real message rather
        # than suggesting that transformers itself is missing.
        logger.error("Missing dependency for TCRT5: %s", str(e))
        raise ImportError(f"A dependency required by TCRT5 is missing: {e}") from e
    except Exception as e:
        logger.error("Failed to load TCRT5 model: %s", str(e))
        raise RuntimeError("Could not load TCRT5 model") from e
    dim = model.config.d_model

    # Time only the embedding itself, consistent with the log message.
    start_time = time.time()
    n_seqs = len(sequences)
    n_batches = math.ceil(n_seqs / batch_size)
    shape = (n_seqs, max_seq_length, dim) if residue_level else (n_seqs, dim)
    embeddings = torch.zeros(shape)
    for batch_idx, (start, end, batch) in enumerate(batch_loader(sequences, batch_size), start=1):
        logger.info("TCRT5 Batch %s/%s.", batch_idx, n_batches)
        # Encode the bare sequence (no peptide-MHC prompt) for a CDR3 embedding.
        # list() accepts both numpy-array and list batches.
        encoded_batch = tokenizer(
            list(batch), return_tensors="pt", padding=True, truncation=True, max_length=max_seq_length
        ).to(device)
        with torch.no_grad():
            # Only the encoder is needed for the embedding representation
            hidden = model.encoder(**encoded_batch).last_hidden_state  # [batch, seq_len_in_batch, dim]
        # Zero out padding tokens
        padding_mask = encoded_batch["attention_mask"].unsqueeze(-1).float()  # [batch, seq_len_in_batch, 1]
        masked_embedding = hidden * padding_mask
        if residue_level:
            # Pad the token axis with zeros up to max_seq_length, even if every
            # sequence in this batch is shorter.
            padding_length = max_seq_length - masked_embedding.size(1)
            # pad order: (embedding_dim_left, embedding_dim_right, seq_dim_left, seq_dim_right)
            batch_embeddings = pad(masked_embedding, pad=(0, 0, 0, padding_length), mode="constant", value=0)
        else:
            # Mean over real tokens only (padding contributes zero to the sum)
            batch_embeddings = masked_embedding.sum(dim=1) / padding_mask.sum(dim=1)
        embeddings[start:end] = batch_embeddings.cpu()
    end_time = time.time()
    logger.info("TCRT5 embedding took %s seconds", round(end_time - start_time, 2))
    return embeddings