Source code for amulety.bcr_embeddings

"""
BCR embedding functions using various models.
"""
# Please order alphabetically by function name.
# ruff: noqa: N806

import logging
import math
import os
import subprocess
import time
from typing import Optional

import pandas as pd
import torch
from torch.nn.functional import pad

from amulety.protein_embeddings import custommodel
from amulety.utils import batch_loader, insert_space_every_other_except_cls

# Module-level logger named after this module; the embedding functions in this
# file report model loading, batch progress and timing through it.
logger = logging.getLogger(__name__)


def antiberty(
    sequences,
    cache_dir: Optional[str] = None,
    batch_size: int = 50,
    residue_level: bool = False,
):
    """
    Embeds sequences using the AntiBERTy model.

    Every input sequence is truncated to 510 characters before embedding.

    Parameters:
        sequences: pd.Series for single chain or pd.DataFrame for H+L mode. A DataFrame
            must contain a 'chain' column and one of the columns 'sequence_vdj_aa',
            'sequence_aa' or 'sequence' (looked up in that order).
        cache_dir: Optional[str]: Currently unused by this function.
        batch_size: int: Number of sequences to process in each batch.
        residue_level: bool: If True, returns residue-level embeddings.

    Returns:
        torch.Tensor of shape (n_sequences, 512), or (n_sequences, 510, 512) when
        residue_level is True. Residue-level embeddings are zero-padded or truncated
        along the residue axis so that every sequence occupies exactly 510 rows.

    Raises:
        ValueError: If a DataFrame is passed without a 'chain' column or without any
            recognized sequence column.
    """
    max_seq_length = 510

    # Validate and normalise the input before importing/loading the model so that
    # malformed input fails fast.
    if isinstance(sequences, pd.DataFrame):
        # H+L mode: DataFrame contains rows with 'chain' column indicating H or L
        if "chain" not in sequences.columns:
            raise ValueError("DataFrame input must contain 'chain' column for H+L mode")
        sequence_col_candidates = ["sequence_vdj_aa", "sequence_aa", "sequence"]
        sequence_col = next((col for col in sequence_col_candidates if col in sequences.columns), None)
        if sequence_col is None:
            raise ValueError(
                f"No recognized sequence column found in DataFrame. Expected one of: {sequence_col_candidates}"
            )
        raw_sequences = sequences[sequence_col]
    else:
        # Single chain mode - sequences is a Series
        raw_sequences = sequences

    prepared = raw_sequences.apply(lambda a: str(a)[:max_seq_length])
    prepared = prepared.str.replace("<cls><cls>", "[CLS][CLS]", regex=False)
    prepared = prepared.apply(insert_space_every_other_except_cls)
    # Collapse runs of spaces into a single space. The previous single-space ->
    # single-space replacement was a no-op.
    sequences_processed = prepared.str.replace(r" {2,}", " ", regex=True)

    from antiberty import AntiBERTyRunner

    antiberty_runner = AntiBERTyRunner()
    model_size = sum(p.numel() for p in antiberty_runner.model.parameters())
    logger.info("AntiBERTy loaded. Size: %s M", round(model_size / 1e6, 2))

    start_time = time.time()
    n_seqs = len(sequences_processed)
    # NOTE(review): the embedding width is derived as max_seq_length + 2 (= 512);
    # presumably this matches the model's hidden size - confirm.
    dim = max_seq_length + 2
    n_batches = math.ceil(n_seqs / batch_size)
    if residue_level:
        embeddings = torch.empty((n_seqs, max_seq_length, dim))
    else:
        embeddings = torch.empty((n_seqs, dim))

    batches = batch_loader(sequences_processed, batch_size)
    for batch_number, (start, end, batch) in enumerate(batches, start=1):
        logger.info("Batch %s/%s", batch_number, n_batches)
        x = antiberty_runner.embed(batch)
        if residue_level:
            x_keep = []
            for a in x:
                a = a.clone().detach()
                n_rows = a.shape[0]
                if n_rows < max_seq_length:
                    # Zero-pad the residue axis up to the fixed buffer length.
                    a = pad(a, (0, 0, 0, max_seq_length - n_rows))
                elif n_rows > max_seq_length:
                    # Longer embeddings would not fit the preallocated buffer.
                    a = a[:max_seq_length]
                # Always keep one entry per sequence so the stack lines up with
                # embeddings[start:end].
                x_keep.append(a)
        else:
            x_keep = [a.mean(axis=0) for a in x]
        embeddings[start:end] = torch.stack(x_keep)

    end_time = time.time()
    logger.info("Took %s seconds", round(end_time - start_time, 2))
    return embeddings
def ablang(
    sequences,
    batch_size: int = 50,
    residue_level: bool = False,
):
    """
    Embeds antibody sequences using the AbLang model.

    Note:
        AbLang consists of two models: one for heavy chains and one for light chains.
        Each AbLang model has two parts: AbRep (creates representations) and AbHead
        (predicts amino acids). Trained on antibody sequences in the OAS database,
        demonstrating power in restoring missing residues. This is a key capability
        for B-cell receptor repertoire sequencing data.

        Maximum sequence length: 160 amino acids (longer inputs are truncated).

        Reference: https://github.com/oxpig/AbLang

    Parameters:
        sequences: pd.Series for single chain or pd.DataFrame for H+L mode. A Series is
            treated as heavy chains only. A DataFrame must contain a 'chain' column and
            one of the columns 'sequence_vdj_aa', 'sequence_aa' or 'sequence' (looked up
            in that order). Rows whose chain is "H" go to the heavy model; every other
            chain value is handled as a light chain.
        batch_size: int: Number of sequences to process in each batch.
        residue_level: bool: If True, returns residue-level embeddings.

    Returns:
        torch.Tensor stacking one embedding per input sequence, in input order. With
        residue_level=True each embedding is zero-padded along its first axis to
        162 rows (max_seq_length + 2) when it has fewer.

    Raises:
        ImportError: On Python 3.14+ or when the ablang package is not installed.
        ValueError: If a DataFrame is passed without a 'chain' column or without any
            recognized sequence column.
    """
    import sys

    # Guard for Python version incompatibility (numba dependency unsupported on >=3.14)
    if sys.version_info >= (3, 14):
        raise ImportError(
            "AbLang is unavailable on Python 3.14+ due to numba version constraints. "
            "Please reinstall amulety under Python <3.14 to include the AbLang model."
        )
    try:
        import ablang  # noqa: F401
    except ImportError as e:
        raise ImportError("AbLang is not installed. Install with: pip install ablang (requires Python <3.14).") from e

    max_seq_length = 160
    padded_length = max_seq_length + 2

    def _prepare(seq):
        # Truncate, upper-case for AbLang, and drop the CLS separator tokens that
        # AbLang does not understand.
        return str(seq)[:max_seq_length].upper().replace("<CLS><CLS>", "")

    # Build a list of (sequence, chain_type) pairs for both supported input kinds.
    if isinstance(sequences, pd.DataFrame):
        # H+L mode: DataFrame contains rows with 'chain' column indicating H or L
        if "chain" not in sequences.columns:
            raise ValueError("DataFrame input must contain 'chain' column for H+L mode")
        sequence_col_candidates = ["sequence_vdj_aa", "sequence_aa", "sequence"]
        sequence_col = next((col for col in sequence_col_candidates if col in sequences.columns), None)
        if sequence_col is None:
            raise ValueError(
                f"No recognized sequence column found in DataFrame. Expected one of: {sequence_col_candidates}"
            )
        sequences_data = [
            (_prepare(seq), chain_type) for seq, chain_type in zip(sequences[sequence_col], sequences["chain"])
        ]
    else:
        # Single chain mode - assume heavy chain as default
        sequences_data = [(_prepare(seq), "H") for seq in sequences]

    # Load only the models that are needed for the chain types present.
    chain_types = {chain_type for _, chain_type in sequences_data}
    heavy_ablang = None
    light_ablang = None
    if "H" in chain_types:
        heavy_ablang = ablang.pretrained("heavy")
        heavy_ablang.freeze()
        logger.info("AbLang heavy chain model loaded")
    # Non-"H" chains are routed to the light group below. When such chains exist but
    # no heavy model is available as a fallback, the light model must be loaded even
    # if none of the labels is literally "L"; otherwise those rows are never embedded
    # and stacking fails on the remaining None placeholders.
    has_non_heavy = any(chain_type != "H" for chain_type in chain_types)
    if "L" in chain_types or (has_non_heavy and heavy_ablang is None):
        light_ablang = ablang.pretrained("light")
        light_ablang.freeze()
        logger.info("AbLang light chain model loaded")
    # With a heavy model but no "L" label, non-heavy chains fall back to the heavy model.
    light_group_model = light_ablang if light_ablang is not None else heavy_ablang

    mode = "rescoding" if residue_level else "seqcoding"

    def _to_tensor(seq_embedding):
        # Residue-level outputs vary in length; zero-pad them to a common number of
        # rows so they can be stacked.
        seq_tensor = torch.tensor(seq_embedding)
        if residue_level and seq_tensor.shape[0] < padded_length:
            seq_tensor = pad(seq_tensor, (0, 0, 0, padded_length - seq_tensor.shape[0]))
        return seq_tensor

    embeddings_list = []
    start_time = time.time()
    n_seqs = len(sequences_data)
    n_batches = math.ceil(n_seqs / batch_size)

    # Process in batches for efficiency
    for batch_number, start_idx in enumerate(range(0, n_seqs, batch_size), start=1):
        batch_data = sequences_data[start_idx : start_idx + batch_size]
        logger.info("Batch %s/%s", batch_number, n_batches)

        # Positions within the batch, grouped by the model that handles them.
        heavy_indices = [idx for idx, (_, chain_type) in enumerate(batch_data) if chain_type == "H"]
        light_indices = [idx for idx, (_, chain_type) in enumerate(batch_data) if chain_type != "H"]

        # Results are written back by position so the input order is preserved.
        batch_embeddings = [None] * len(batch_data)
        for indices, model in ((heavy_indices, heavy_ablang), (light_indices, light_group_model)):
            if not indices:
                continue
            group_seqs = [batch_data[idx][0] for idx in indices]
            for idx, seq_embedding in zip(indices, model(group_seqs, mode=mode)):
                batch_embeddings[idx] = _to_tensor(seq_embedding)

        embeddings_list.extend(batch_embeddings)

    embeddings = torch.stack(embeddings_list)
    end_time = time.time()
    logger.info("AbLang embedding completed. Took %s seconds", round(end_time - start_time, 2))
    return embeddings
def antiberta2(
    sequences,
    cache_dir: Optional[str] = None,
    residue_level: bool = False,
    batch_size: int = 50,
):
    """
    Embeds sequences using the antiBERTa2 RoFormer model.

    Every input sequence is truncated to 256 characters, and tokenized inputs are
    padded/truncated to 256 tokens.

    Parameters:
        sequences: pd.Series for single chain or pd.DataFrame for H+L mode. A DataFrame
            must contain a 'chain' column and one of the columns 'sequence_vdj_aa',
            'sequence_aa' or 'sequence' (looked up in that order).
        cache_dir: Optional[str]: Directory to cache the model files.
        residue_level: bool: If True, returns residue-level embeddings.
        batch_size: int: Number of sequences to process in each batch.

    Returns:
        torch.Tensor of shape (n_sequences, 1024), or (n_sequences, 256, 1024) when
        residue_level is True. Sequence-level embeddings are the mean of the last
        hidden state over the non-padding tokens.

    Raises:
        ValueError: If a DataFrame is passed without a 'chain' column or without any
            recognized sequence column.
    """
    max_seq_length = 256

    # Validate and normalise the input before importing transformers and fetching
    # the model so that malformed input fails fast.
    if isinstance(sequences, pd.DataFrame):
        # H+L mode: DataFrame contains rows with 'chain' column indicating H or L
        if "chain" not in sequences.columns:
            raise ValueError("DataFrame input must contain 'chain' column for H+L mode")
        sequence_col_candidates = ["sequence_vdj_aa", "sequence_aa", "sequence"]
        sequence_col = next((col for col in sequence_col_candidates if col in sequences.columns), None)
        if sequence_col is None:
            raise ValueError(
                f"No recognized sequence column found in DataFrame. Expected one of: {sequence_col_candidates}"
            )
        raw_sequences = sequences[sequence_col]
    else:
        # Single chain mode - sequences is a Series
        raw_sequences = sequences

    prepared = raw_sequences.apply(lambda a: str(a)[:max_seq_length])
    prepared = prepared.str.replace("<cls><cls>", "[CLS][CLS]", regex=False)
    prepared = prepared.apply(insert_space_every_other_except_cls)
    # Collapse runs of spaces into a single space. The previous single-space ->
    # single-space replacement was a no-op.
    prepared = prepared.str.replace(r" {2,}", " ", regex=True)
    sequences_array = prepared.values

    from transformers import RoFormerForMaskedLM, RoFormerTokenizer

    device = "cuda" if torch.cuda.is_available() else "cpu"
    tokenizer = RoFormerTokenizer.from_pretrained("alchemab/antiberta2", cache_dir=cache_dir)
    model = RoFormerForMaskedLM.from_pretrained("alchemab/antiberta2", cache_dir=cache_dir)
    model = model.to(device)
    model_size = sum(p.numel() for p in model.parameters())
    logger.info("AntiBERTa2 loaded. Size: %s M", model_size / 1e6)

    start_time = time.time()
    n_seqs = len(sequences_array)
    dim = 1024
    n_batches = math.ceil(n_seqs / batch_size)
    if residue_level:
        embeddings = torch.empty((n_seqs, max_seq_length, dim))
    else:
        embeddings = torch.empty((n_seqs, dim))

    batches = batch_loader(sequences_array, batch_size)
    for batch_number, (start, end, batch) in enumerate(batches, start=1):
        logger.info("Batch %s/%s.", batch_number, n_batches)
        # Every sequence is padded to max_seq_length tokens so the batch forms a
        # rectangular tensor.
        x = torch.tensor(
            [
                tokenizer.encode(
                    seq,
                    padding="max_length",
                    truncation=True,
                    max_length=max_seq_length,
                    return_special_tokens_mask=True,
                )
                for seq in batch
            ]
        ).to(device)
        attention_mask = (x != tokenizer.pad_token_id).float().to(device)
        with torch.no_grad():
            outputs = model(x, attention_mask=attention_mask, output_hidden_states=True)
            outputs = outputs.hidden_states[-1]
            outputs = list(outputs.detach())

        # Aggregate across the residues, ignoring the padded positions
        if not residue_level:
            for j, mask_row in enumerate(attention_mask):
                outputs[j] = outputs[j][mask_row == 1, :].mean(0)

        embeddings[start:end] = torch.stack(outputs)
        # Drop the per-batch tensors before the next batch is allocated.
        del x
        del attention_mask
        del outputs

    end_time = time.time()
    logger.info("Took %s seconds", round(end_time - start_time, 2))
    return embeddings
def balm_paired(
    sequences,
    cache_dir: str = "/tmp/amulety",
    residue_level: bool = False,
    batch_size: int = 50,
):
    """
    Embeds sequences using the BALM-paired model.

    The model is looked up under cache_dir; when it is not there, the archive is
    downloaded from Zenodo with curl and unpacked with tar. Embedding itself is
    delegated to custommodel with max_seq_length=510 and embedding_dimension=1024.

    Parameters:
        sequences: pd.Series for single chain or pd.DataFrame for H+L mode
        cache_dir: str: Directory to cache the model files. Created if missing.
        residue_level: bool: If True, returns residue-level embeddings.
        batch_size: int: Number of sequences to process in each batch.

    Returns:
        The embeddings returned by custommodel.

    Raises:
        RuntimeError: If the model cannot be downloaded or extracted, or if the
            expected model directory is still missing after extraction.
    """
    os.makedirs(cache_dir, exist_ok=True)
    model_name = "BALM-paired_LC-coherence_90-5-5-split_122222"
    model_path = os.path.join(cache_dir, model_name)
    embedding_dimension = 1024
    max_seq_length = 510

    if not os.path.exists(model_path):
        archive_path = os.path.join(cache_dir, "BALM-paired.tar.gz")
        archive_url = "https://zenodo.org/records/8237396/files/BALM-paired.tar.gz"
        try:
            # Run each step as its own argument-list command: a failure in any step
            # is reported (a multi-line shell script only surfaces the status of its
            # last command) and paths with spaces or shell metacharacters are safe.
            # --fail makes curl exit non-zero on HTTP errors instead of saving the
            # error page as the archive.
            subprocess.run(
                ["curl", "--fail", "-L", "-o", archive_path, archive_url],
                check=True,
                text=True,
                capture_output=True,
            )
            subprocess.run(
                ["tar", "-xzf", archive_path, "-C", cache_dir],
                check=True,
                text=True,
                capture_output=True,
            )
        except (subprocess.CalledProcessError, OSError) as e:
            # OSError covers a missing curl/tar executable.
            raise RuntimeError(f"Error downloading or extracting BALM-paired model: {e}") from e
        finally:
            # Remove the archive whether or not extraction succeeded.
            if os.path.exists(archive_path):
                os.remove(archive_path)

        if not os.path.exists(model_path):
            raise RuntimeError(
                f"Error downloading or extracting BALM-paired model: {model_path} not found after extraction"
            )

    embeddings = custommodel(
        sequences=sequences,
        model_path=model_path,
        embedding_dimension=embedding_dimension,
        batch_size=batch_size,
        max_seq_length=max_seq_length,
        cache_dir=cache_dir,
        residue_level=residue_level,
    )
    return embeddings