Source code for amulety.utils

"""Main module."""
import logging
import subprocess
import warnings
from typing import Iterable

import pandas as pd

logger = logging.getLogger(__name__)


class ConditionalFormatter(logging.Formatter):
    """Formatter that can bypass the configured layout on a per-record basis.

    Records carrying a truthy ``simple`` attribute (e.g. set through
    ``extra={"simple": True}``) are rendered as their bare message; every
    other record goes through the regular ``logging.Formatter`` formatting.
    """

    def format(self, record):
        """
        Render a log record.

        Parameters:
            record (logging.LogRecord): The record to render.

        Returns:
            str: The plain message if ``record.simple`` is truthy, otherwise
            the normally formatted record.
        """
        wants_plain_message = getattr(record, "simple", False)
        if wants_plain_message:
            return record.getMessage()
        return super().format(record)
def batch_loader(data: Iterable, batch_size: int):
    """
    This function generates batches from the provided data.

    Parameters:
        data (Iterable): The data to be batched. Despite the annotation it must
            support ``len()`` and slicing (e.g. a list, tuple or pandas Series).
        batch_size (int): The size of each batch. Must be a positive integer.

    Yields:
        tuple: A tuple containing the start index, end index (exclusive), and the batch of data.

    Raises:
        ValueError: If batch_size is smaller than 1. Because this is a generator,
            the error is raised when iteration starts, not when the function is called.
    """
    # A negative step would make range() empty and silently drop all data,
    # so reject non-positive sizes explicitly.
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}.")

    num_samples = len(data)
    for start_idx in range(0, num_samples, batch_size):
        # The last batch may be shorter than batch_size.
        end_idx = min(start_idx + batch_size, num_samples)
        yield start_idx, end_idx, data[start_idx:end_idx]
[docs] def insert_space_every_other_except_cls(input_string: str): """ This function inserts a space after every character in the input string, except for the '[CLS]' token. Parameters: input_string (str): The input string where spaces are to be inserted. Returns: str: The modified string with spaces inserted. """ parts = input_string.split("[CLS]") modified_parts = ["".join([char + " " for char in part]).strip() for part in parts] result = " [CLS] ".join(modified_parts) return result
def get_cdr3_sequence_column(airr: pd.DataFrame, default_sequence_col: str):
    """
    Pick the CDR3 sequence column to use for TCR data.

    Candidate columns are tried in order of preference ("junction_aa", then
    "cdr3_aa"); the first one that exists and holds at least one non-null
    value wins.

    Parameters:
        airr (pd.DataFrame): AIRR DataFrame
        default_sequence_col (str): Column name returned when no usable CDR3 column exists

    Returns:
        str: The selected CDR3 column name, or default_sequence_col as fallback
    """
    for candidate in ("junction_aa", "cdr3_aa"):
        if candidate not in airr.columns:
            continue
        # A column that is entirely null carries no sequence information.
        if airr[candidate].notna().any():
            logger.info(f"Using CDR3 column: {candidate}")
            return candidate

    logger.warning(f"No CDR3 columns found, using default: {default_sequence_col}")
    return default_sequence_col
def process_airr(
    airr_df: pd.DataFrame,
    chain_mode: str,
    sequence_col: str = "sequence_vdj_aa",
    cell_id_col: str = "cell_id",
    duplicate_col: str = "duplicate_count",
    receptor_type: str = "all",
    mode: str = "concat",
):
    """
    Processes AIRR-seq data and returns the sequences to embed together with the processed table.

    Uses AMULETY's unified H/L interface for both BCR and TCR data
    (heavy: IGH, TRB, TRD; light: IGL, IGK, TRA, TRG).
    See embed_airr() function documentation for detailed chain parameter explanations.

    Parameters:
        airr_df (pandas.DataFrame): Input AIRR rearrangement table as a pandas DataFrame. It is not
            modified; processing happens on a copy. If it has no "locus" column, the locus is
            derived from the first three characters of "v_call".
        chain_mode (str): The input chain, one of ["H", "L", "HL", "LH", "H+L"].
        sequence_col (str): The name of the column containing the amino acid sequences to embed.
        cell_id_col (str): The name of the column containing the single-cell barcode.
        duplicate_col (str): The name of the numeric column used to select the best chain when
            multiple chains of the same type exist per cell. Default: "duplicate_count".
        receptor_type (str): The receptor type to validate, one of ["BCR", "TCR", "all"] (case-insensitive).
            - "BCR": keeps only BCR chains (IGH, IGL, IGK); errors if only TCR chains are present
            - "TCR": keeps only TCR chains (TRA, TRB, TRG, TRD); errors if only BCR chains are present
            - "all": allows both BCR and TCR chains in the same file
        mode (str): Mode to use in concatenating sequences. By default it concatenates the sequences
            when the HL chain is provided (concat), it can also tabulate the sequences alone (tab)
            or together with the locus and segment (tab_locus_gene).

    Returns:
        tuple: ``(sequences, data)`` where ``sequences`` is the ``sequence_col`` column
        (pandas.Series) of ``data`` and ``data`` is the processed pandas.DataFrame.

    Raises:
        ValueError: If chain_mode or mode is invalid, the required chains are not present,
            receptor_type validation fails, neither "locus" nor "v_call" is available, or a
            paired chain mode ("HL"/"LH") is requested for bulk data.
    """
    allowed_sequence_input = ["H", "L", "HL", "LH", "H+L"]
    if chain_mode not in allowed_sequence_input:
        raise ValueError(f"Input x must be one of {allowed_sequence_input}.")

    allowed_modes = ["concat", "tab", "tab_locus_gene"]
    if mode not in allowed_modes:
        raise ValueError(f"Mode must be one of {allowed_modes}.")

    # Work on a copy so the caller's table is never modified.
    data = airr_df.copy()

    # Create locus column if not present (IGHV1-69*01 -> IGH, TRBV7-9*01 -> TRB, ...)
    if "locus" not in data.columns:
        if "v_call" not in data.columns:
            raise ValueError(
                "Input data has neither a 'locus' column nor a 'v_call' column to derive the locus from."
            )
        data["locus"] = data["v_call"].apply(lambda call: call[:3] if isinstance(call, str) else call)

    # Missing loci are ignored here: they match no chain and would break sorted() below.
    present_loci = set(data["locus"].dropna().unique())

    # Determine available chains based on locus
    heavy_loci = {"IGH", "TRB", "TRD"}  # Heavy chains: IGH for BCR, TRB/TRD for TCR
    light_loci = {"IGL", "IGK", "TRA", "TRG"}  # Light chains: IGL/IGK for BCR, TRA/TRG for TCR
    available_chains = set()
    if present_loci & heavy_loci:
        available_chains.add("H")
    if present_loci & light_loci:
        available_chains.add("L")

    # Validate chain availability
    if chain_mode == "H" and "H" not in available_chains:
        raise ValueError(
            f"Chain parameter 'H' requires heavy chain data, but no heavy chain loci found. "
            f"Available loci: {', '.join(sorted(present_loci))}. Use --chain L for light chain analysis."
        )
    elif chain_mode == "L" and "L" not in available_chains:
        raise ValueError(
            f"Chain parameter 'L' requires light chain data, but no light chain loci found. "
            f"Available loci: {', '.join(sorted(present_loci))}. Use --chain H for heavy chain analysis."
        )
    elif chain_mode in ["HL", "LH", "H+L"] and not available_chains.issuperset({"H", "L"}):
        missing = {"H", "L"} - available_chains
        if "H" in missing:
            raise ValueError(
                f"Chain parameter '{chain_mode}' requires heavy chain data, but no heavy chain loci found. "
                f"Available loci: {', '.join(sorted(present_loci))}. Use --chain L for light chain analysis."
            )
        elif "L" in missing:
            raise ValueError(
                f"Chain parameter '{chain_mode}' requires light chain data, but no light chain loci found. "
                f"Available loci: {', '.join(sorted(present_loci))}. Use --chain H for heavy chain analysis."
            )

    # ===== RECEPTOR TYPE VALIDATION =====
    bcr_loci = {"IGH", "IGL", "IGK"}
    tcr_loci = {"TRA", "TRB", "TRG", "TRD"}
    bcr_present = bool(present_loci & bcr_loci)
    tcr_present = bool(present_loci & tcr_loci)

    if receptor_type.upper() == "BCR":
        if tcr_present and bcr_present:
            tcr_chains = present_loci & tcr_loci
            logger.warning(
                "TCR chains (%s) detected in BCR-only mode. These will be removed and only BCR chains used.",
                list(tcr_chains),
            )
            data = data[data["locus"].isin(bcr_loci)]
        elif tcr_present and not bcr_present:
            raise ValueError("No BCR chains (IGH, IGL, IGK) found in data")
    elif receptor_type.upper() == "TCR":
        if bcr_present and tcr_present:
            bcr_chains = present_loci & bcr_loci
            logger.warning(
                "BCR chains (%s) detected in TCR-only mode. These will be removed and only TCR chains used.",
                list(bcr_chains),
            )
            data = data[data["locus"].isin(tcr_loci)]
        elif bcr_present and not tcr_present:
            raise ValueError("No TCR chains (TRA, TRB, TRG, TRD) found in data.")
    elif receptor_type.upper() == "ALL":
        logger.info("Processing both BCR and TCR sequences from the file.")
    else:
        raise ValueError(f"receptor_type must be one of ['BCR', 'TCR', 'all'], got '{receptor_type}'")

    # ===== UNIFIED CHAIN MAPPING =====
    # Map loci to unified H/L interface; anything that is not a heavy locus is treated as light.
    data.loc[:, "chain"] = data.loc[:, "locus"].apply(lambda x: "H" if x in ["IGH", "TRB", "TRD"] else "L")

    # Check for gamma/delta TCR and warn about model compatibility
    gamma_delta_chains = present_loci & {"TRG", "TRD"}
    if gamma_delta_chains and receptor_type.upper() in ["TCR", "ALL"]:
        logger.warning(
            "Gamma/Delta TCR chains (%s) detected. Note: TCR-specific models (TCR-BERT, Trex, TCREMP) "
            "are primarily trained on Alpha/Beta TCRs. For Gamma/Delta TCRs, consider using general protein "
            "models (ESM2, ProtT5) which support all TCR types.",
            list(gamma_delta_chains),
        )

    # Determine data type
    is_bulk = cell_id_col not in data.columns
    is_single_cell = not is_bulk and data[cell_id_col].notna().all()
    is_mixed = not is_bulk and not is_single_cell

    if is_bulk:
        logger.info("Bulk AIRR data detected (no cell_id column).")
    elif is_single_cell:
        logger.info("Single-cell AIRR data detected (all entries have cell_id).")
    elif is_mixed:
        logger.info("Mixed AIRR data detected (some entries have cell_id, others do not).")

    # Process based on chain_mode
    if chain_mode in ["HL", "LH"]:
        # HL/LH modes: error for bulk, process for single-cell/mixed
        if is_bulk:
            raise ValueError(f'Chain = "{chain_mode}" is invalid for bulk mode. Please use "H+L", "H" or "L" instead.')

        # Warning for LH order
        if chain_mode == "LH":
            warnings.warn(
                "LH (Light-Heavy) chain order detected. Most paired models are trained on HL (Heavy-Light) order. "
                "Using LH order may result in reduced accuracy. Consider using --chain_mode HL for better performance.",
                UserWarning,
            )

        # If mixed data, filter to only sequences with cell_id
        if is_mixed:
            before_filter = len(data)
            data = data.loc[data[cell_id_col].notna(),]
            removed_count = before_filter - len(data)
            if removed_count > 0:
                logger.info("Removed %d sequences without cell_id for paired chain processing", removed_count)

        # NOTE(review): for mode "tab"/"tab_locus_gene" the paired table has "H"/"L" columns and
        # appears to lack `sequence_col`, so the final column selection would fail - confirm usage.
        data = concatenate_heavylight(data, sequence_col, cell_id_col, duplicate_col, order=chain_mode, mode=mode)

    elif chain_mode == "H+L":
        # H+L mode: same processing for all data types
        # Add dummy cell_id col if bulk data
        if is_bulk:
            data[cell_id_col] = pd.NA

        if mode == "tab_locus_gene":
            # For models that need H+L in tab_locus_gene format
            data = process_h_plus_l(data, sequence_col, cell_id_col, duplicate_col, mode=mode)
        else:
            # For other models that need H+L in separate entries
            data = process_h_plus_l(data, sequence_col, cell_id_col, duplicate_col, mode="tab")

    else:
        # Single chain mode (H or L): same processing for all data types
        before_filter = len(data)
        data = data.loc[data["chain"] == chain_mode]
        removed_count = before_filter - len(data)
        if removed_count > 0:
            logger.info("Removed %d sequences not matching %s chain", removed_count, chain_mode)

        if is_bulk:
            data[cell_id_col] = pd.NA
        elif is_single_cell:
            # "concat" has no meaning for a single chain, so it falls back to plain tabular output;
            # other modes (e.g. tab_locus_gene for TCREMP) are passed through.
            if mode == "concat":
                data = process_h_plus_l(data, sequence_col, cell_id_col, duplicate_col, mode="tab")
            else:
                data = process_h_plus_l(data, sequence_col, cell_id_col, duplicate_col, mode=mode)

    return data.loc[:, sequence_col], data
def concatenate_heavylight(
    data: pd.DataFrame,
    sequence_col: str,
    cell_id_col: str,
    duplicate_col: str = "duplicate_count",
    order: str = "HL",
    mode: str = "concat",
):
    """
    Concatenates heavy and light chain per cell using AMULETY's unified H/L interface.

    Concatenates sequences as: Heavy<cls><cls>Light (HL order) or Light<cls><cls>Heavy (LH order)
    for both BCR (IGH + IGL/IGK) and TCR (TRB/TRD + TRA/TRG) data.
    See embed_airr() documentation for chain mappings.

    If a cell contains multiple chains of the same type, selects the one with highest value
    in the selection column (first occurrence on ties). Cells lacking a heavy or a light chain
    are dropped.

    Parameters:
        data (pandas.DataFrame): Input data containing heavy and light chain information.
            Must include columns: cell_id_col, "locus", "chain", duplicate_col, sequence_col
            (plus "v_call" for mode "tab_locus_gene").
        sequence_col (str): The name of the column containing the amino acid sequences to embed.
        cell_id_col (str): The name of the column containing the single-cell barcode.
        duplicate_col (str): The name of the numeric column used to select the best chain when
            multiple chains of the same type exist per cell. Default: "duplicate_count".
        order (str): Chain concatenation order, either "HL" (Heavy-Light) or "LH" (Light-Heavy).
            Only used in "concat" mode. Default: "HL".
        mode (str): Mode to use in concatenating sequences. By default it concatenates the
            sequences (concat), it can also tabulate the sequences alone (tab) or together
            with the locus and segment (tab_locus_gene).

    Returns:
        pandas.DataFrame: One row per cell having both chains, with columns cell_id_col, "H" and "L".
        In "concat" mode an additional sequence_col column holds the concatenated sequence;
        in "tab_locus_gene" mode additional "<locus>V" / "<locus>J" columns hold the gene calls.

    Raises:
        ValueError: If required columns are missing, duplicate_col is not numeric or contains NaN,
            no cell has both chains, or order/mode is invalid.
    """
    # TODO add check if multiple heavy chains per cell and warn users
    colnames = [cell_id_col, "locus", "chain", duplicate_col, sequence_col]
    missing_cols = [col for col in colnames if col not in data.columns]
    if missing_cols:
        raise ValueError(
            f"Column(s) {missing_cols} is/are not present in the input data "
            "and are needed to concatenate heavy and light chains."
        )

    # Check that duplicate_col is numeric
    if not pd.api.types.is_numeric_dtype(data[duplicate_col]):
        raise ValueError(
            f"Selection column '{duplicate_col}' must be numeric. Found dtype: {data[duplicate_col].dtype}"
        )

    # Check that duplicate_col does not contain NaN values
    if data[duplicate_col].isna().any():
        raise ValueError(f"Selection column '{duplicate_col}' contains NaN values. Please remove them or fix them.")

    # Keep the best chain per (cell, chain type);
    # if tie in maximum duplicate_col value, return the first occurrence
    data = data.loc[data.groupby([cell_id_col, "chain"])[duplicate_col].idxmax()]

    # First pivot dataframe according to chain column values (H and L)
    data_chain = data.pivot(index=cell_id_col, columns="chain", values=sequence_col)
    # If one chain type is absent from the whole table the pivot lacks that column; add it as
    # missing so such input ends in the "no paired cells" error below instead of a KeyError.
    for chain in ("H", "L"):
        if chain not in data_chain.columns:
            data_chain[chain] = pd.NA
    data_chain = data_chain.reset_index(level=cell_id_col)

    n_cells = data_chain.shape[0]
    data_chain = data_chain.dropna(axis=0)
    n_dropped = n_cells - data_chain.shape[0]
    if n_dropped > 0:
        logger.info("Dropping %s cells with missing heavy or light chain...", n_dropped)

    # Throw error if no rows left after dropping
    if data_chain.shape[0] == 0:
        raise ValueError("No cells with both heavy and light chains found.")

    if mode == "concat":
        # Concatenate based on order parameter
        if order == "HL":
            data_chain.loc[:, sequence_col] = data_chain["H"] + "<cls><cls>" + data_chain["L"]
        elif order == "LH":
            data_chain.loc[:, sequence_col] = data_chain["L"] + "<cls><cls>" + data_chain["H"]
        else:
            raise ValueError(f"Invalid order parameter: {order}. Must be 'HL' or 'LH'.")
        return data_chain

    elif mode == "tab":
        return data_chain

    elif mode == "tab_locus_gene":
        if "v_call" not in data.columns:
            raise ValueError("Column 'v_call' is not present in the input data and is needed for mode 'tab_locus_gene'.")

        # Create locus_vgene and locus_jgene columns for TCREMP format
        data_full = data.copy()  # Work with the per-cell best chains selected above
        # For V genes: locus + V (e.g., TRA -> TRAV, TRB -> TRBV)
        data_full.loc[:, "locus_vgene"] = data_full["locus"] + "V"
        # For J genes: locus + J (e.g., TRA -> TRAJ, TRB -> TRBJ)
        data_full.loc[:, "locus_jgene"] = data_full["locus"] + "J"

        # Second pivot for V genes
        data_vgene = data_full.pivot(index=cell_id_col, columns="locus_vgene", values="v_call")
        data_vgene = data_vgene.reset_index()

        # Left merges keep only the paired cells retained in data_chain; an outer merge would
        # re-introduce the cells that were just dropped for lacking a heavy or light chain.
        result = data_chain.merge(data_vgene, on=cell_id_col, how="left")

        if "j_call" in data_full.columns:
            # Third pivot for J genes
            data_jgene = data_full.pivot(index=cell_id_col, columns="locus_jgene", values="j_call")
            data_jgene = data_jgene.reset_index()
            result = result.merge(data_jgene, on=cell_id_col, how="left")
        else:
            # Add placeholder J gene columns for consistency when J gene data is not available
            for locus in data_full["locus"].unique():
                result[f"{locus}J"] = "Unknown"

        return result

    else:
        raise ValueError(f"Invalid mode parameter: {mode}. Must be 'concat', 'tab', or 'tab_locus_gene'.")
def process_h_plus_l(
    data: pd.DataFrame, sequence_col: str, cell_id_col: str, duplicate_col: str = "duplicate_count", mode: str = "tab"
):
    """
    Processes both heavy and light chains separately for H+L, H, or L formats.

    Returns a DataFrame with heavy and/or light chain sequences, keeping them as separate
    entries rather than concatenating them. Supports different output modes including
    tab_locus_gene format.

    Note: the columns are validated, but no per-cell chain selection is performed here;
    every input row is kept.

    Parameters:
        data (pandas.DataFrame): Input data containing chain information. It is not modified.
        sequence_col (str): The name of the column containing the amino acid sequences.
        cell_id_col (str): The name of the column containing the single-cell barcode.
        duplicate_col (str): The name of the numeric selection column (validated only).
        mode (str): Output mode - "tab" for simple tabular format, "tab_locus_gene" for extended
            format with V/J gene information.

    Returns:
        pandas.DataFrame: For "tab", the input DataFrame itself (same object). For "tab_locus_gene",
        a new DataFrame with columns cell_id_col, "sequence_id", sequence_col, "chain", "locus" and one
        "<locus>V" / "<locus>J" column per locus present; each row is filled only in the columns of
        its own locus, the others hold pd.NA.

    Raises:
        ValueError: If a required column is missing, duplicate_col is not numeric, or mode is invalid.
    """
    # Validate selection column
    if duplicate_col not in data.columns:
        raise ValueError(f"Selection column '{duplicate_col}' not found in data.")

    if cell_id_col not in data.columns:
        raise ValueError(f"Cell ID column '{cell_id_col}' not found in data.")

    if not pd.api.types.is_numeric_dtype(data[duplicate_col]):
        raise ValueError(
            f"Selection column '{duplicate_col}' must be numeric. Found dtype: {data[duplicate_col].dtype}"
        )

    # Ensure the sequence column is properly included in the output
    if sequence_col not in data.columns:
        raise ValueError(f"Sequence column '{sequence_col}' not found in data.")

    if mode == "tab":
        # Simple tabular format - keep chains as separate entries
        # Keep original sequence_id for metadata merging - chain info is preserved in 'chain' column
        return data

    elif mode == "tab_locus_gene":
        # Extended format with V/J gene information for models like TCREMP
        # This handles H+L, H, or L chains separately with gene information
        required_cols = ["sequence_id", "chain", "locus", "v_call"]
        missing_cols = [col for col in required_cols if col not in data.columns]
        if missing_cols:
            raise ValueError(f"Column(s) {missing_cols} not found in data but needed for mode 'tab_locus_gene'.")

        # Start with base columns (a copy, so the caller's frame is left untouched)
        result = data[[cell_id_col, "sequence_id", sequence_col, "chain", "locus"]].copy()

        has_j_call = "j_call" in data.columns
        for locus in data["locus"].unique():
            v_col = f"{locus}V"
            j_col = f"{locus}J"
            # Initialize with NA, then fill only the rows that belong to this locus
            result[v_col] = pd.NA
            result[j_col] = pd.NA

            mask = data["locus"] == locus
            result.loc[mask, v_col] = data.loc[mask, "v_call"]
            if has_j_call:
                result.loc[mask, j_col] = data.loc[mask, "j_call"]
            else:
                # Set placeholder values for missing J gene information
                result.loc[mask, j_col] = "Unknown"

        return result

    else:
        raise ValueError(f"Invalid mode parameter: {mode}. Must be 'tab' or 'tab_locus_gene'.")
def check_dependencies():
    """
    Check if optional embedding dependencies are installed and provide installation instructions.

    This function checks all model types (BCR, TCR, and protein language models) for missing
    dependencies by attempting the relevant imports and, for TCREMP, by running ``tcremp-run -h``.
    The outcome is also reported through the module logger.

    Returns:
        list: List of tuples (model_name, installation_command) for missing dependencies
    """
    tcremp_install = (
        "git clone https://github.com/antigenomics/tcremp.git && cd tcremp && pip install . "
        "(requires Python 3.11+)"
    )

    missing_deps = []
    available_models = []

    # Check BCR models (included in requirements.txt but good to verify installation)
    try:
        from antiberty import AntiBERTyRunner  # noqa: F401

        available_models.append("AntiBERTy")
    except ImportError:
        missing_deps.append(("AntiBERTy", "pip install antiberty"))

    try:
        import ablang  # noqa: F401

        available_models.append("AbLang")
    except ImportError:
        missing_deps.append(("AbLang", "pip install ablang"))

    # Check TCR models
    # TCREMP is a command line tool, so probe the executable instead of importing a module.
    try:
        result = subprocess.run(["tcremp-run", "-h"], capture_output=True, text=True, timeout=10)
    except (OSError, subprocess.SubprocessError):
        # OSError covers a missing executable (FileNotFoundError) as well as e.g. a
        # non-executable file (PermissionError); SubprocessError covers the timeout.
        # This is a best-effort probe, so any such failure means "not available".
        tcremp_available = False
    else:
        tcremp_available = result.returncode == 0

    if tcremp_available:
        available_models.append("TCREMP")
    else:
        missing_deps.append(("TCREMP", tcremp_install))

    try:
        from transformers import BertModel, BertTokenizer  # noqa: F401

        available_models.append("TCR-BERT")
    except ImportError:
        missing_deps.append(("TCR-BERT", "pip install transformers"))

    try:
        from transformers import T5ForConditionalGeneration, T5Tokenizer  # noqa: F401

        available_models.append("TCRT5")
    except ImportError:
        missing_deps.append(("TCRT5", "pip install transformers"))

    # Check protein language models
    try:
        from transformers import EsmModel, EsmTokenizer  # noqa: F401

        available_models.append("ESM2")
    except ImportError:
        missing_deps.append(("ESM2", "pip install transformers"))

    try:
        import sentencepiece  # noqa: F401
        from transformers import T5EncoderModel, T5Tokenizer  # noqa: F401

        available_models.append("ProtT5")
    except ImportError as e:
        # The error text tells which of the two packages is the missing one.
        if "sentencepiece" in str(e):
            missing_deps.append(("ProtT5", "pip install transformers sentencepiece"))
        else:
            missing_deps.append(("ProtT5", "pip install transformers"))

    try:
        import gensim  # noqa: F401
        from embedding import sequence_modeling  # noqa: F401

        available_models.append("Immune2Vec")
    except ImportError as e:
        if "gensim" in str(e):
            missing_deps.append(
                (
                    "Immune2Vec",
                    "pip install gensim>=3.8.3 && git clone https://bitbucket.org/yaarilab/immune2vec_model.git",
                )
            )
        else:
            missing_deps.append(
                ("Immune2Vec", "git clone https://bitbucket.org/yaarilab/immune2vec_model.git && add to Python path")
            )

    # Report results
    if available_models:
        logger.info("Available models: %s", ", ".join(available_models))

    if missing_deps:
        logger.warning("Missing model dependencies: %s", ", ".join([dep[0] for dep in missing_deps]))
    else:
        logger.info("All embedding dependencies are available!")

    return missing_deps