{ "cells": [ { "cell_type": "markdown", "id": "9d80f419", "metadata": {}, "source": [ "# Embedding with a fine-tuned custom model using AMULETY\n", "\n", "In this tutorial we will showcase how to fine-tune the **AntiBERTy** antibody language model to predict binding to the SARS-CoV-2 Spike protein (S), as published in [Wang et al. 2025](https://doi.org/10.1371/journal.pcbi.1012153).\n", "\n", "\n", "\n", "The tutorial goes through the following steps:\n", "- Loading antibody sequences and S-binding labels\n", "- Formatting sequences for AntiBERTy\n", "- Using a grouped, stratified cross-validation split\n", "- Fine-tuning AntiBERTy with Hugging Face `Trainer`\n", "- Evaluating with AUC, MCC, balanced accuracy, etc.\n", "- Using AMULETY to embed new sequences with the fine-tuned model.\n", "\n", "\n", "## Set-up\n", "\n", "We recommend running this tutorial on a GPU with at least 25 GB of RAM. On such a setup, the full notebook typically completes in about 15 minutes.\n", "\n", "First, download a parquet file containing the training data:\n", " - `S_CDR3.parquet` from [figshare](https://figshare.com/articles/dataset/Fine-tuning_Pre-trained_Antibody_Language_Models_for_Antigen_Specificity_Prediction/25342924)\n", "\n", "The dataset contains the following columns:\n", "- `HL` or `H`: sequences (heavy + light vs heavy only).\n", "- `label`: includes antigen binding values like `\"S+\"`, `\"S1+\"`, `\"S2+\"` (positive) and others (negative).\n", "- `subject`: donor / study ID for grouped cross-validation.\n" ] }, { "cell_type": "markdown", "id": "fb486424", "metadata": {}, "source": [ "## Install dependencies (run once per session)" ] }, { "cell_type": "code", "execution_count": 1, "id": "6f4883f2", "metadata": {}, "outputs": [], "source": [ "#!pip install -q antiberty transformers datasets scikit-learn biopython pyarrow" ] }, { "cell_type": "markdown", "id": "3c60cd3c", "metadata": {}, "source": [ "## Imports" ] }, { "cell_type": "code", "execution_count": 1, "id": 
"baf4eddc", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/mw957/.conda/envs/bcrembed/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", " from .autonotebook import tqdm as notebook_tqdm\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "PyTorch: 2.3.0\n", "CUDA available: True\n" ] }, { "data": { "text/plain": [ "device(type='cuda')" ] }, "execution_count": 1, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import os\n", "import random\n", "from collections import Counter\n", "\n", "import numpy as np\n", "import pandas as pd\n", "import torch\n", "\n", "from sklearn.metrics import (\n", " precision_score, recall_score, f1_score,\n", " matthews_corrcoef, roc_auc_score,\n", " average_precision_score, balanced_accuracy_score\n", ")\n", "\n", "from sklearn.model_selection import StratifiedGroupKFold\n", "from datasets import Dataset, DatasetDict, ClassLabel\n", "import transformers\n", "from transformers import (\n", " AutoTokenizer,\n", " AutoModelForSequenceClassification,\n", " TrainingArguments,\n", " Trainer\n", ")\n", "\n", "import antiberty\n", "\n", "print(\"PyTorch:\", torch.__version__)\n", "print(\"CUDA available:\", torch.cuda.is_available())\n", "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", "device" ] }, { "cell_type": "markdown", "id": "c91d8518", "metadata": {}, "source": [ "## Configuration" ] }, { "cell_type": "markdown", "id": "ba28897b", "metadata": {}, "source": [ "If needed, update `DATA_DIR` and `OUTPUT_DIR` below to the paths where your data (the downloaded parquet file) and models are stored.\n" ] }, { "cell_type": "code", "execution_count": 2, "id": "04ca3624", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Run ID: S_antiBERTy_HL_fine_tuning_CDR3\n", "Data dir: 
../data/\n", "Output dir: ../models/\n" ] } ], "source": [ "# Which column in the parquet to use as sequences\n", "MODEL_TYPE = \"HL\" \n", "\n", "# Which dataset variant to use\n", "SEQUENCE_SCOPE = \"CDR3\" \n", "\n", "# Path to your data directory\n", "DATA_DIR = \"../data/\" \n", "\n", "# Path where models and logs will be saved\n", "OUTPUT_DIR = \"../models/\"\n", "\n", "# Training hyperparameters\n", "BATCH_SIZE = 64 \n", "LR = 1e-5 \n", "N_EPOCHS = 10 \n", "\n", "RANDOM_STATE_OUTER = 7 if SEQUENCE_SCOPE == \"CDR3\" else 9\n", "RANDOM_STATE_INNER = 1\n", "\n", "RUN_ID = f\"S_antiBERTy_{MODEL_TYPE}_fine_tuning_{SEQUENCE_SCOPE}\"\n", "print(\"Run ID:\", RUN_ID)\n", "print(\"Data dir:\", DATA_DIR)\n", "print(\"Output dir:\", OUTPUT_DIR)" ] }, { "cell_type": "markdown", "id": "f9c78f55", "metadata": {}, "source": [ "## Helper functions: model loading, freezing, formatting, metrics" ] }, { "cell_type": "code", "execution_count": null, "id": "00785585", "metadata": {}, "outputs": [], "source": [ "def get_antiberty_paths():\n", " \"\"\"Locate AntiBERTy model + vocab from the antiberty package.\"\"\"\n", " project_path = os.path.dirname(os.path.realpath(antiberty.__file__))\n", " trained_dir = os.path.join(project_path, \"trained_models\")\n", " model_dir = os.path.join(trained_dir, \"AntiBERTy_md_smooth\")\n", " vocab = os.path.join(trained_dir, \"vocab.txt\")\n", " print(\"AntiBERTy model:\", model_dir)\n", " print(\"AntiBERTy vocab:\", vocab)\n", " return model_dir, vocab\n", "\n", "\n", "def load_antiberty_classifier(num_labels: int = 2):\n", " \"\"\"Load AntiBERTy as a sequence-classification model + tokenizer.\"\"\"\n", " model_dir, vocab = get_antiberty_paths()\n", " tokenizer = transformers.BertTokenizer(\n", " vocab_file=vocab,\n", " do_lower_case=False\n", " )\n", " model = AutoModelForSequenceClassification.from_pretrained(\n", " model_dir,\n", " num_labels=num_labels\n", " )\n", " model.to(device)\n", " size = sum(p.numel() for p in 
model.parameters())\n", " print(f\"Model size: {size/1e6:.2f}M parameters\")\n", " return model, tokenizer\n", "\n", "\n", "def freeze_antiberty_layers(model, train_last_n_layers: int = 3):\n", " \"\"\"Freeze embeddings and early encoder layers of AntiBERTy.\"\"\"\n", " for p in model.bert.embeddings.parameters():\n", " p.requires_grad = False\n", "\n", " total_layers = len(model.bert.encoder.layer) # AntiBERTy has 8 layers\n", " for layer in model.bert.encoder.layer[: total_layers - train_last_n_layers]:\n", " for p in layer.parameters():\n", " p.requires_grad = False\n", " return model\n", "\n", "\n", "def insert_space_every_other_except_cls(s: str) -> str:\n", " \"\"\"Add spaces between residues, keeping [CLS] intact.\"\"\"\n", " parts = s.split(\"[CLS]\")\n", " spaced = [\" \".join(list(part)) for part in parts]\n", " out = \" [CLS] \".join(spaced)\n", " return \" \".join(out.split())\n", "\n", "\n", "def set_seed(seed: int = 42):\n", " \"\"\"Set random seeds for reproducibility.\"\"\"\n", " torch.manual_seed(seed)\n", " torch.cuda.manual_seed_all(seed)\n", " torch.backends.cudnn.deterministic = True\n", " torch.backends.cudnn.benchmark = False\n", " np.random.seed(seed)\n", " random.seed(seed)\n", " os.environ[\"PYTHONHASHSEED\"] = str(seed)\n", "\n", "\n", "def compute_metrics(eval_pred):\n", " \"\"\"Metrics callback for Hugging Face Trainer.\"\"\"\n", " logits, labels = eval_pred\n", " probs = torch.softmax(torch.tensor(logits), dim=1).numpy()[:, 1]\n", " preds = np.argmax(logits, axis=1)\n", "\n", " return {\n", " \"precision\": precision_score(labels, preds),\n", " \"recall\": recall_score(labels, preds),\n", " \"f1_weighted\": f1_score(labels, preds, average=\"weighted\"),\n", " \"apr\": average_precision_score(labels, probs),\n", " \"balanced_accuracy\": balanced_accuracy_score(labels, preds),\n", " \"auc\": roc_auc_score(labels, probs),\n", " \"mcc\": matthews_corrcoef(labels, preds),\n", " }" ] }, { "cell_type": "markdown", "id": "45a75af2", 
"metadata": {}, "source": [ "## Load dataset" ] }, { "cell_type": "markdown", "id": "cef2bb2e", "metadata": {}, "source": [ "We will load the data and convert the sequences to the format required by AntiBERTy, which requires the special token `[CLS]` to be placed between the heavy and light chain sequences." ] }, { "cell_type": "code", "execution_count": null, "id": "69f14b31", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Loading CDR3 sequences...\n", "Total sequences: 15539\n", "Unique donors: 427\n", "Label counts: Counter({1: 8658, 0: 6881})\n" ] } ], "source": [ "MAX_LENGTH = 512 - 2 # AntiBERTy max length minus specials\n", "\n", "def load_data(scope: str = \"CDR3\", model_type: str = \"HL\"):\n", " if scope == \"FULL\":\n", " filename = \"S_FULL.parquet\"\n", " print(\"Loading full-length sequences...\")\n", " else:\n", " filename = \"S_CDR3.parquet\"\n", " print(\"Loading CDR3 sequences...\")\n", "\n", " path = os.path.join(DATA_DIR, filename)\n", " if not os.path.exists(path):\n", " raise FileNotFoundError(f\"Could not find {path}. Please check DATA_DIR and filename.\")\n", "\n", " df = pd.read_parquet(path)\n", "\n", " X = df[model_type].apply(lambda s: s[:MAX_LENGTH])\n", " X = X.str.replace(\"\", \"[CLS][CLS]\", regex=False)\n", " X = X.apply(insert_space_every_other_except_cls)\n", "\n", " y = np.isin(df[\"label\"], [\"S+\", \"S1+\", \"S2+\"]).astype(int)\n", " groups = df[\"subject\"].values\n", "\n", " print(f\"Total sequences: {len(X)}\")\n", " print(f\"Unique donors: {len(np.unique(groups))}\")\n", " print(\"Label counts:\", Counter(y))\n", "\n", " return X, y, groups, df\n", "\n", "X, y, y_groups, raw_df = load_data(SEQUENCE_SCOPE, MODEL_TYPE)" ] }, { "cell_type": "markdown", "id": "df11d2bc", "metadata": {}, "source": [ "## Create data splits\n", "\n", "We generate train, validation and test data splits using scikit-learn's `StratifiedGroupKFold`. 
One of four outer folds (nominally 25% of the dataset) is held out as the test set, and one of three inner folds (nominally 33% of the remaining training data) is used as the validation set. Because the splits are grouped by donor, the actual proportions can differ (see the sizes printed below)." ] }, { "cell_type": "code", "execution_count": 5, "id": "b8279936", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Using outer fold 1\n", "Outer train size: 12969, test size: 2570\n", "% positive train: 0.555, % positive test: 0.566\n", "Using inner fold 1 for train/val split\n", "Final sizes — train: 8423, val: 4546, test: 2570\n" ] } ], "source": [ "outer_cv = StratifiedGroupKFold(\n", " n_splits=4,\n", " shuffle=True,\n", " random_state=RANDOM_STATE_OUTER\n", ")\n", "\n", "inner_cv = StratifiedGroupKFold(\n", " n_splits=3,\n", " shuffle=True,\n", " random_state=RANDOM_STATE_INNER\n", ")\n", "\n", "# Use the first outer fold for this tutorial\n", "for fold_idx, (train_index, test_index) in enumerate(outer_cv.split(X, y, y_groups), start=1):\n", " print(f\"Using outer fold {fold_idx}\")\n", " X_train_all, X_test = X.iloc[train_index], X.iloc[test_index]\n", " y_train_all, y_test = y[train_index], y[test_index]\n", " y_groups_train = y_groups[train_index]\n", " break\n", "\n", "print(f\"Outer train size: {len(X_train_all)}, test size: {len(X_test)}\")\n", "print(f\"% positive train: {np.mean(y_train_all):.3f}, % positive test: {np.mean(y_test):.3f}\")\n", "\n", "# Inner split to create validation set\n", "for inner_idx, (inner_train_index, val_index) in enumerate(\n", " inner_cv.split(X_train_all, y_train_all, y_groups_train),\n", " start=1\n", "):\n", " print(f\"Using inner fold {inner_idx} for train/val split\")\n", " X_train = X_train_all.iloc[inner_train_index]\n", " y_train = y_train_all[inner_train_index]\n", " X_val = X_train_all.iloc[val_index]\n", " y_val = y_train_all[val_index]\n", " break\n", "\n", "print(f\"Final sizes — train: {len(X_train)}, val: {len(X_val)}, test: {len(X_test)}\")" ] }, { "cell_type": "markdown", "id": "e02c972c", "metadata": {}, "source": [ "## Build Hugging Face Datasets & tokenize" ] }, { 
"cell_type": "markdown", "id": "fc784501", "metadata": {}, "source": [ "We need to convert the different sets to the format required by Hugging Face." ] }, { "cell_type": "code", "execution_count": 6, "id": "eaeaa8a0", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "AntiBERTy model: /vast/palmer/home.mccleary/mw957/.conda/envs/bcrembed/lib/python3.12/site-packages/antiberty/trained_models/AntiBERTy_md_smooth\n", "AntiBERTy vocab: /vast/palmer/home.mccleary/mw957/.conda/envs/bcrembed/lib/python3.12/site-packages/antiberty/trained_models/vocab.txt\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/home/mw957/.conda/envs/bcrembed/lib/python3.12/site-packages/transformers/tokenization_utils_base.py:1601: FutureWarning: `clean_up_tokenization_spaces` was not set. It will be set to `True` by default. This behavior will be depracted in transformers v4.45, and will be then set to `False` by default. For more details check this issue: https://github.com/huggingface/transformers/issues/31884\n", " warnings.warn(\n", "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at /vast/palmer/home.mccleary/mw957/.conda/envs/bcrembed/lib/python3.12/site-packages/antiberty/trained_models/AntiBERTy_md_smooth and are newly initialized: ['classifier.bias', 'classifier.weight']\n", "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Model size: 25.76M parameters\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Map: 100%|██████████| 8423/8423 [00:02<00:00, 3736.40 examples/s]\n", "Map: 100%|██████████| 4546/4546 [00:01<00:00, 3854.12 examples/s]\n", "Map: 100%|██████████| 2570/2570 [00:00<00:00, 4093.28 examples/s]\n" ] } ], "source": [ "train_df = pd.DataFrame({\"sequence\": X_train.values, \"labels\": y_train})\n", "val_df = pd.DataFrame({\"sequence\": 
X_val.values, \"labels\": y_val})\n", "test_df = pd.DataFrame({\"sequence\": X_test.values, \"labels\": y_test})\n", "\n", "raw_datasets = DatasetDict({\n", " \"train\": Dataset.from_pandas(train_df.reset_index(drop=True)),\n", " \"validation\": Dataset.from_pandas(val_df.reset_index(drop=True)),\n", " \"test\": Dataset.from_pandas(test_df.reset_index(drop=True)),\n", "})\n", "\n", "model, tokenizer = load_antiberty_classifier(num_labels=2)\n", "model = freeze_antiberty_layers(model, train_last_n_layers=3)\n", "\n", "def preprocess_function(batch):\n", " encodings = tokenizer(\n", " batch[\"sequence\"],\n", " padding=\"max_length\",\n", " truncation=True,\n", " max_length=MAX_LENGTH,\n", " )\n", " encodings[\"labels\"] = batch[\"labels\"]\n", " return encodings\n", "\n", "tokenized_datasets = raw_datasets.map(\n", " preprocess_function,\n", " batched=True,\n", " remove_columns=[\"sequence\"],\n", ")" ] }, { "cell_type": "markdown", "id": "7572c373", "metadata": {}, "source": [ "## Set up the parameters for fine-tuning AntiBERTy" ] }, { "cell_type": "code", "execution_count": 7, "id": "3aea82b1", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/mw957/.conda/envs/bcrembed/lib/python3.12/site-packages/transformers/training_args.py:1525: FutureWarning: `evaluation_strategy` is deprecated and will be removed in version 4.46 of 🤗 Transformers. Use `eval_strategy` instead\n", " warnings.warn(\n", "Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. 
It is recommended to upgrade the kernel to the minimum version or higher.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Saving checkpoints to: ../models/S_antiBERTy_HL_fine_tuning_CDR3_Fold_1\n" ] } ], "source": [ "set_seed(1)\n", "\n", "FOLD_ID = 1\n", "OUT_PATH = os.path.join(OUTPUT_DIR, f\"{RUN_ID}_Fold_{FOLD_ID}\")\n", "os.makedirs(OUT_PATH, exist_ok=True)\n", "print(\"Saving checkpoints to:\", OUT_PATH)\n", "\n", "training_args = TrainingArguments(\n", " output_dir=OUT_PATH,\n", " evaluation_strategy=\"epoch\",\n", " save_strategy=\"epoch\",\n", " logging_strategy=\"epoch\",\n", " learning_rate=LR,\n", " per_device_train_batch_size=BATCH_SIZE,\n", " per_device_eval_batch_size=BATCH_SIZE,\n", " num_train_epochs=N_EPOCHS,\n", " warmup_ratio=0.0,\n", " load_best_model_at_end=True,\n", " metric_for_best_model=\"auc\",\n", " lr_scheduler_type=\"linear\",\n", " seed=1\n", ")\n", "\n", "trainer = Trainer(\n", " model=model,\n", " args=training_args,\n", " tokenizer=tokenizer,\n", " train_dataset=tokenized_datasets[\"train\"],\n", " eval_dataset=tokenized_datasets[\"validation\"],\n", " compute_metrics=compute_metrics,\n", ")" ] }, { "cell_type": "markdown", "id": "83fccb6d", "metadata": {}, "source": [ "## Fine-tune AntiBERTy on the specificity dataset" ] }, { "cell_type": "code", "execution_count": 8, "id": "c263f296", "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", "
\n", " \n", " \n", " [1320/1320 10:27, Epoch 10/10]\n", "
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
| Epoch | Training Loss | Validation Loss | Precision | Recall | F1 Weighted | Apr | Balanced Accuracy | Auc | Mcc |
|---|---|---|---|---|---|---|---|---|---|
| 1 | 0.682100 | 0.674185 | 0.571391 | 0.870452 | 0.512843 | 0.641419 | 0.535959 | 0.599263 | 0.097128 |
| 2 | 0.668300 | 0.668239 | 0.586726 | 0.841263 | 0.548888 | 0.662187 | 0.558285 | 0.619711 | 0.142084 |
| 3 | 0.651600 | 0.659277 | 0.617122 | 0.732107 | 0.594114 | 0.676344 | 0.588303 | 0.634982 | 0.184527 |
| 4 | 0.641700 | 0.654430 | 0.623169 | 0.731307 | 0.601285 | 0.688067 | 0.595238 | 0.647339 | 0.198104 |
| 5 | 0.630900 | 0.651899 | 0.628552 | 0.725310 | 0.606789 | 0.693455 | 0.600552 | 0.653557 | 0.207789 |
| 6 | 0.623200 | 0.657183 | 0.664266 | 0.590164 | 0.611359 | 0.695164 | 0.612686 | 0.654203 | 0.224292 |
| 7 | 0.616500 | 0.650168 | 0.643408 | 0.700520 | 0.618827 | 0.700375 | 0.612852 | 0.659554 | 0.229099 |
| 8 | 0.611000 | 0.650533 | 0.650999 | 0.664534 | 0.618953 | 0.701372 | 0.614419 | 0.660647 | 0.229431 |
| 9 | 0.606200 | 0.650401 | 0.648743 | 0.670532 | 0.618158 | 0.703339 | 0.613261 | 0.662185 | 0.227533 |
| 10 | 0.605900 | 0.649999 | 0.642884 | 0.695322 | 0.617356 | 0.703881 | 0.611475 | 0.662833 | 0.225944 |

" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "{'train_runtime': 628.6721,\n", " 'train_samples_per_second': 133.981,\n", " 'train_steps_per_second': 2.1,\n", " 'total_flos': 6568285780076400.0,\n", " 'train_loss': 0.6337348244406961,\n", " 'epoch': 10.0}" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "train_result = trainer.train()\n", "train_result.metrics" ] }, { "cell_type": "markdown", "id": "b1407d0e", "metadata": {}, "source": [ "## Evaluate the model on held-out test set" ] }, { "cell_type": "code", "execution_count": 9, "id": "26dda1e2", "metadata": {}, "outputs": [ { "data": { "text/html": [], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Test metrics:\n", "test_loss: 0.6636\n", "test_precision: 0.6531\n", "test_recall: 0.6527\n", "test_f1_weighted: 0.6074\n", "test_apr: 0.6931\n", "test_balanced_accuracy: 0.6005\n", "test_auc: 0.6438\n", "test_mcc: 0.2010\n", "test_runtime: 8.7197\n", "test_samples_per_second: 294.7350\n", "test_steps_per_second: 4.7020\n" ] } ], "source": [ "model.eval()\n", "test_outputs = trainer.predict(tokenized_datasets[\"test\"])\n", "test_metrics = test_outputs.metrics\n", "\n", "print(\"Test metrics:\")\n", "for k, v in test_metrics.items():\n", " print(f\"{k}: {v:.4f}\")" ] }, { "cell_type": "markdown", "id": "8adf9bda", "metadata": {}, "source": [ "## Using the fine-tuned AntiBERTy model with Amulety on epitope-specific sequences\n", "\n", "In this section we:\n", "1. Export the fine-tuned AntiBERTy classifier as a Hugging Face model.\n", "2. Register it as a **custom** model for Amulety.\n", "3. 
Use Amulety to embed a few example sequences with this fine-tuned encoder.\n", "\n", "This illustrates how you can fine-tune AntiBERTy for S-protein binding and then\n", "reuse the same model inside Amulety to generate embeddings for other datasets\n", "(e.g. epitope-specificity panels).\n" ] }, { "cell_type": "markdown", "id": "e560643d", "metadata": {}, "source": [ "We will start by saving the fine-tuned AntiBERTy classifier as a Hugging Face model." ] }, { "cell_type": "code", "execution_count": 10, "id": "1aa99997", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Saved fine-tuned model for Amulety at: ../models/S_antiBERTy_HL_fine_tuning_CDR3_amulety_custom\n" ] } ], "source": [ "from pathlib import Path\n", "\n", "# Directory where we will save the fine-tuned model for Amulety\n", "CUSTOM_MODEL_PATH = Path(OUTPUT_DIR) / f\"{RUN_ID}_amulety_custom\"\n", "CUSTOM_MODEL_PATH.mkdir(parents=True, exist_ok=True)\n", "\n", "# `trainer` already holds the fine-tuned best model because we used `load_best_model_at_end=True`\n", "# but to be explicit, we save the current `model` and `tokenizer`.\n", "model.save_pretrained(CUSTOM_MODEL_PATH)\n", "tokenizer.save_pretrained(CUSTOM_MODEL_PATH)\n", "\n", "print(\"Saved fine-tuned model for Amulety at:\", CUSTOM_MODEL_PATH)\n" ] }, { "cell_type": "markdown", "id": "55b8510a", "metadata": {}, "source": [ "We will then download a new dataset and sample 1000 cells for demonstration purposes." ] }, { "cell_type": "code", "execution_count": 33, "id": "bfba1acc", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/tmp/tmp.UnNt4ZCghI/ipykernel_1723782/2594638705.py:1: DtypeWarning: Columns (10) have mixed types. Specify dtype option on import or set low_memory=False.\n", " airr_demo = pd.read_csv(\"https://zenodo.org/records/17186858/files/ML_bcr_airr_dataset.tsv\", sep='\\t')\n" ] }, { "data": { "text/html": [ "

\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
| | sequence_id | sequence_vdj_aa | locus | cell_id | chain_type | v_call | v_call_family | j_call_family | mu_freq | junction_aa_length | isotype | source | subject | specificity | duplicate_count | productive | rev_comp | stop_codon | vj_in_frame |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 79 | 80_heavy | EVQLVESGGGLVQPGGSLRLSCVASGFTFSSYWMSWVRQAPGKGLE... | IGH | cell_80 | H | IGHV3-7 | IGHV3 | IGHJ4 | 0.003378 | 18.0 | NaN | OAS | OAS_King_Subject-BCP3 | unlabeled | 1 | True | False | False | True |
| 189 | 190_heavy | EVQLVQSGAEVKKPGESLKISCKGSAYSFTNYWIAWVRQMPGKGLE... | IGH | cell_190 | H | IGHV5-51 | IGHV5 | IGHJ3 | 0.013158 | 20.0 | NaN | OAS | OAS_King_Subject-BCP3 | unlabeled | 1 | True | False | False | True |
| 298 | 299_heavy | EVQLVESGGGLVKPGGSLRLSCSASRFTFSTYRMNWVRQAPGKGLE... | IGH | cell_299 | H | IGHV3-21 | IGHV3 | IGHJ2 | 0.023256 | 23.0 | NaN | OAS | OAS_King_Subject-BCP3 | unlabeled | 1 | True | False | False | True |
| 1084 | 1085_heavy | QVQLVQSGAEVREPGASVKVSCKASGYTFTIYDINWVRQAPGQGLE... | IGH | cell_1085 | H | IGHV1-8 | IGHV1 | IGHJ4 | 0.062706 | 13.0 | NaN | OAS | OAS_King_Subject-BCP3 | unlabeled | 1 | True | False | False | True |
| 2712 | 2713_heavy | EVQLVESGGGLVQRGGSLRLSCGASGFTFSSYNMNWVRQAPGKGLE... | IGH | cell_2713 | H | IGHV3-48 | IGHV3 | IGHJ3 | 0.032895 | 14.0 | NaN | OAS | OAS_King_Subject-BCP4 | unlabeled | 1 | True | False | False | True |
\n", "
" ], "text/plain": [ " sequence_id sequence_vdj_aa locus \\\n", "79 80_heavy EVQLVESGGGLVQPGGSLRLSCVASGFTFSSYWMSWVRQAPGKGLE... IGH \n", "189 190_heavy EVQLVQSGAEVKKPGESLKISCKGSAYSFTNYWIAWVRQMPGKGLE... IGH \n", "298 299_heavy EVQLVESGGGLVKPGGSLRLSCSASRFTFSTYRMNWVRQAPGKGLE... IGH \n", "1084 1085_heavy QVQLVQSGAEVREPGASVKVSCKASGYTFTIYDINWVRQAPGQGLE... IGH \n", "2712 2713_heavy EVQLVESGGGLVQRGGSLRLSCGASGFTFSSYNMNWVRQAPGKGLE... IGH \n", "\n", " cell_id chain_type v_call v_call_family j_call_family mu_freq \\\n", "79 cell_80 H IGHV3-7 IGHV3 IGHJ4 0.003378 \n", "189 cell_190 H IGHV5-51 IGHV5 IGHJ3 0.013158 \n", "298 cell_299 H IGHV3-21 IGHV3 IGHJ2 0.023256 \n", "1084 cell_1085 H IGHV1-8 IGHV1 IGHJ4 0.062706 \n", "2712 cell_2713 H IGHV3-48 IGHV3 IGHJ3 0.032895 \n", "\n", " junction_aa_length isotype source subject specificity \\\n", "79 18.0 NaN OAS OAS_King_Subject-BCP3 unlabeled \n", "189 20.0 NaN OAS OAS_King_Subject-BCP3 unlabeled \n", "298 23.0 NaN OAS OAS_King_Subject-BCP3 unlabeled \n", "1084 13.0 NaN OAS OAS_King_Subject-BCP3 unlabeled \n", "2712 14.0 NaN OAS OAS_King_Subject-BCP4 unlabeled \n", "\n", " duplicate_count productive rev_comp stop_codon vj_in_frame \n", "79 1 True False False True \n", "189 1 True False False True \n", "298 1 True False False True \n", "1084 1 True False False True \n", "2712 1 True False False True " ] }, "execution_count": 33, "metadata": {}, "output_type": "execute_result" } ], "source": [ "airr_demo = pd.read_csv(\"https://zenodo.org/records/17186858/files/ML_bcr_airr_dataset.tsv\", sep='\\t')\n", "# cells that have at least one H and one L\n", "mask_both = (\n", " airr_demo.groupby(\"cell_id\")[\"chain_type\"]\n", " .agg(lambda x: set(x))\n", " .pipe(lambda s: s[s.apply(lambda st: {\"H\", \"L\"}.issubset(st))])\n", ")\n", "cells_with_both = mask_both.index\n", "\n", "# sample up to 1000 such cells (no replacement)\n", "n_cells = min(1000, len(cells_with_both))\n", "sampled_cells = np.random.choice(cells_with_both, size=n_cells, 
replace=False)\n", "df_sampled = airr_demo[airr_demo[\"cell_id\"].isin(sampled_cells)].copy()\n", "df_sampled.head()" ] }, { "cell_type": "markdown", "id": "0e4ab14f", "metadata": {}, "source": [ "We will then use amulety to generate embedding vectors for these sequences using our fine-tuned AntiBERTy model." ] }, { "cell_type": "code", "execution_count": 31, "id": "5b316293", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/mw957/.conda/envs/bcrembed/lib/python3.12/site-packages/amulety/amulety.py:575: UserWarning: Custom protein language model might not understand paired chain relationships. Chain 'HL' will be processed as concatenated sequences, but results may be inaccurate.\n", " warnings.warn(\n", "Some weights of BertForMaskedLM were not initialized from the model checkpoint at ../models/S_antiBERTy_HL_fine_tuning_CDR3_amulety_custom and are newly initialized: ['cls.predictions.bias', 'cls.predictions.decoder.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight']\n", "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Custom model path: ../models/S_antiBERTy_HL_fine_tuning_CDR3_amulety_custom\n", "Embedding dimension: 512\n", "Max length used by tutorial: 510\n", "Embeddings shape: (1000, 513)\n" ] }, { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
| | cell_id | dim_1 | dim_2 | dim_3 | dim_4 | dim_5 | dim_6 | dim_7 | dim_8 | dim_9 | ... | dim_503 | dim_504 | dim_505 | dim_506 | dim_507 | dim_508 | dim_509 | dim_510 | dim_511 | dim_512 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | cell_100681 | -0.077534 | -0.109347 | -1.533826 | -0.553431 | 1.118874 | 0.011019 | 1.411259 | -1.877443 | 1.221428 | ... | -0.486696 | 0.669253 | -0.199235 | 0.142055 | -0.412014 | 0.974601 | -0.765339 | -0.230175 | 0.045813 | -0.22751 |
| 1 | cell_101224 | -0.077534 | -0.109347 | -1.533826 | -0.553431 | 1.118874 | 0.011019 | 1.411259 | -1.877443 | 1.221428 | ... | -0.486696 | 0.669253 | -0.199235 | 0.142055 | -0.412014 | 0.974601 | -0.765339 | -0.230175 | 0.045813 | -0.22751 |
| 2 | cell_101233 | -0.077534 | -0.109347 | -1.533826 | -0.553431 | 1.118874 | 0.011019 | 1.411259 | -1.877443 | 1.221428 | ... | -0.486696 | 0.669253 | -0.199235 | 0.142055 | -0.412014 | 0.974601 | -0.765339 | -0.230175 | 0.045813 | -0.22751 |
| 3 | cell_101360 | -0.077534 | -0.109347 | -1.533826 | -0.553431 | 1.118874 | 0.011019 | 1.411259 | -1.877443 | 1.221428 | ... | -0.486696 | 0.669253 | -0.199235 | 0.142055 | -0.412014 | 0.974601 | -0.765339 | -0.230175 | 0.045813 | -0.22751 |
| 4 | cell_102058 | -0.077534 | -0.109347 | -1.533826 | -0.553431 | 1.118874 | 0.011019 | 1.411259 | -1.877443 | 1.221428 | ... | -0.486696 | 0.669253 | -0.199235 | 0.142055 | -0.412014 | 0.974601 | -0.765339 | -0.230175 | 0.045813 | -0.22751 |
\n", "

5 rows × 513 columns

\n", "
" ], "text/plain": [ " cell_id dim_1 dim_2 dim_3 dim_4 dim_5 dim_6 \\\n", "0 cell_100681 -0.077534 -0.109347 -1.533826 -0.553431 1.118874 0.011019 \n", "1 cell_101224 -0.077534 -0.109347 -1.533826 -0.553431 1.118874 0.011019 \n", "2 cell_101233 -0.077534 -0.109347 -1.533826 -0.553431 1.118874 0.011019 \n", "3 cell_101360 -0.077534 -0.109347 -1.533826 -0.553431 1.118874 0.011019 \n", "4 cell_102058 -0.077534 -0.109347 -1.533826 -0.553431 1.118874 0.011019 \n", "\n", " dim_7 dim_8 dim_9 ... dim_503 dim_504 dim_505 dim_506 \\\n", "0 1.411259 -1.877443 1.221428 ... -0.486696 0.669253 -0.199235 0.142055 \n", "1 1.411259 -1.877443 1.221428 ... -0.486696 0.669253 -0.199235 0.142055 \n", "2 1.411259 -1.877443 1.221428 ... -0.486696 0.669253 -0.199235 0.142055 \n", "3 1.411259 -1.877443 1.221428 ... -0.486696 0.669253 -0.199235 0.142055 \n", "4 1.411259 -1.877443 1.221428 ... -0.486696 0.669253 -0.199235 0.142055 \n", "\n", " dim_507 dim_508 dim_509 dim_510 dim_511 dim_512 \n", "0 -0.412014 0.974601 -0.765339 -0.230175 0.045813 -0.22751 \n", "1 -0.412014 0.974601 -0.765339 -0.230175 0.045813 -0.22751 \n", "2 -0.412014 0.974601 -0.765339 -0.230175 0.045813 -0.22751 \n", "3 -0.412014 0.974601 -0.765339 -0.230175 0.045813 -0.22751 \n", "4 -0.412014 0.974601 -0.765339 -0.230175 0.045813 -0.22751 \n", "\n", "[5 rows x 513 columns]" ] }, "execution_count": 31, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from amulety import embed_airr\n", "\n", "CHAIN = MODEL_TYPE\n", "EMBED_DIM = int(model.config.hidden_size)\n", "\n", "print(\"Custom model path:\", CUSTOM_MODEL_PATH)\n", "print(\"Embedding dimension:\", EMBED_DIM)\n", "print(\"Max length used by tutorial:\", MAX_LENGTH)\n", "\n", "embeddings_df, meta_df = embed_airr(\n", " airr=df_sampled,\n", " chain=CHAIN,\n", " model=\"custom\",\n", " sequence_col=\"sequence_vdj_aa\",\n", " cell_id_col=\"cell_id\",\n", " batch_size=8,\n", " model_path=str(CUSTOM_MODEL_PATH),\n", " embedding_dimension=EMBED_DIM,\n", " 
 max_length=MAX_LENGTH,\n", " output_type=\"df\", # return embeddings as a DataFrame\n", " residue_level=False, # sequence-level embeddings\n", ")\n", "\n", "print(\"Embeddings shape:\", embeddings_df.shape)\n", "embeddings_df.head()" ] }, { "cell_type": "markdown", "id": "f59f4c0c", "metadata": {}, "source": [ "At this point:\n", "- `embeddings_df` contains one row per cell (paired heavy and light chains) and 512 embedding features\n", " (plus a `cell_id` column), generated by the **fine-tuned** AntiBERTy model via Amulety.\n", "\n", "You can now:\n", "- Project these embeddings with UMAP / t-SNE,\n", "- Train simple classifiers or regressors for predictions,\n", "- Or integrate them into larger downstream pipelines, all using the same\n", " Amulety interface you use for other models." ] } ], "metadata": { "kernelspec": { "display_name": "amulety", "language": "python", "name": "amulety" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.3" } }, "nbformat": 4, "nbformat_minor": 5 }