Embedding with a fine-tuned custom model using AMULETY

In this tutorial we will show how to fine-tune the AntiBERTy antibody language model to predict binding to the SARS-CoV-2 Spike protein (S), as published in Wang et al. 2025.

The tutorial goes through the following steps:

  • Loading antibody sequences and S-binding labels

  • Formatting sequences for AntiBERTy

  • Using a grouped, stratified cross-validation split

  • Fine-tuning AntiBERTy with Hugging Face Trainer

  • Evaluating with AUC, MCC, balanced accuracy, etc.

  • Using AMULETY to embed new sequences with the fine-tuned model.

Set-up

We recommend running this tutorial on a GPU with at least 25 GB of memory. On such a setup, the full notebook typically completes in about 15 minutes.

First, download a parquet file containing the training data:

The dataset contains the following columns:

  • HL or H: amino-acid sequences, either paired heavy and light chains (HL) or heavy chain only (H).

  • label: antigen-binding annotation. The values "S+", "S1+" and "S2+" are treated as positive; all other values are treated as negative.

  • subject: donor / study ID for grouped cross-validation.
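As a minimal sketch of how these columns are used later in load_data, the toy example below (made-up labels and donor IDs, not taken from the dataset) maps the label column to a binary target with np.isin and keeps subject as the grouping variable for cross-validation.

```python
import numpy as np

# Toy stand-ins for the parquet columns (hypothetical values, for illustration only)
labels = np.array(["S+", "S1+", "S-", "S2+", "unknown"])
subjects = np.array(["donor_A", "donor_A", "donor_B", "donor_C", "donor_C"])

# Positive class: any of the S-binding labels; every other value is negative
POSITIVE_LABELS = ["S+", "S1+", "S2+"]
y = np.isin(labels, POSITIVE_LABELS).astype(int)

print(y)  # -> [1 1 0 1 0]
# Donors act as groups: all sequences of one donor must land in the same split
print(np.unique(subjects))
```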

Install dependencies (run once per session)

[1]:
#!pip install -q antiberty transformers datasets scikit-learn biopython pyarrow amulety

Imports

[1]:
import os
import random
from collections import Counter

import numpy as np
import pandas as pd
import torch

from sklearn.metrics import (
    precision_score, recall_score, f1_score,
    matthews_corrcoef, roc_auc_score,
    average_precision_score, balanced_accuracy_score
)

from sklearn.model_selection import StratifiedGroupKFold
from datasets import Dataset, DatasetDict, ClassLabel
import transformers
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    TrainingArguments,
    Trainer
)

import antiberty

print("PyTorch:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device
PyTorch: 2.3.0
CUDA available: True
[1]:
device(type='cuda')

Configuration

If needed, update DATA_DIR to the directory containing the downloaded parquet file, and OUTPUT_DIR to the directory where models and logs should be saved.

[2]:
# Which column in the parquet to use as sequences
MODEL_TYPE = "HL"

# Which dataset variant to use: "CDR3" or "FULL"
SEQUENCE_SCOPE = "CDR3"

# Path to your data directory
DATA_DIR = "../data/"

# Path where models and logs will be saved
OUTPUT_DIR = "../models/"

# Training hyperparameters
BATCH_SIZE = 64
LR = 1e-5
N_EPOCHS = 10

RANDOM_STATE_OUTER = 7 if SEQUENCE_SCOPE == "CDR3" else 9
RANDOM_STATE_INNER = 1

RUN_ID = f"S_antiBERTy_{MODEL_TYPE}_fine_tuning_{SEQUENCE_SCOPE}"
print("Run ID:", RUN_ID)
print("Data dir:", DATA_DIR)
print("Output dir:", OUTPUT_DIR)
Run ID: S_antiBERTy_HL_fine_tuning_CDR3
Data dir: ../data/
Output dir: ../models/

Helper functions: model loading, freezing, formatting, metrics

[ ]:
def get_antiberty_paths():
    """Locate AntiBERTy model + vocab from the antiberty package."""
    project_path = os.path.dirname(os.path.realpath(antiberty.__file__))
    trained_dir = os.path.join(project_path, "trained_models")
    model_dir = os.path.join(trained_dir, "AntiBERTy_md_smooth")
    vocab = os.path.join(trained_dir, "vocab.txt")
    print("AntiBERTy model:", model_dir)
    print("AntiBERTy vocab:", vocab)
    return model_dir, vocab


def load_antiberty_classifier(num_labels: int = 2):
    """Load AntiBERTy as a sequence-classification model + tokenizer."""
    model_dir, vocab = get_antiberty_paths()
    tokenizer = transformers.BertTokenizer(
        vocab_file=vocab,
        do_lower_case=False
    )
    model = AutoModelForSequenceClassification.from_pretrained(
        model_dir,
        num_labels=num_labels
    )
    model.to(device)
    size = sum(p.numel() for p in model.parameters())
    print(f"Model size: {size/1e6:.2f}M parameters")
    return model, tokenizer


def freeze_antiberty_layers(model, train_last_n_layers: int = 3):
    """Freeze embeddings and early encoder layers of AntiBERTy."""
    for p in model.bert.embeddings.parameters():
        p.requires_grad = False

    total_layers = len(model.bert.encoder.layer)  # AntiBERTy has 8 layers
    for layer in model.bert.encoder.layer[: total_layers - train_last_n_layers]:
        for p in layer.parameters():
            p.requires_grad = False
    return model


def insert_space_every_other_except_cls(s: str) -> str:
    """Add spaces between residues, keeping [CLS] intact."""
    parts = s.split("[CLS]")
    spaced = [" ".join(list(part)) for part in parts]
    out = " [CLS] ".join(spaced)
    return " ".join(out.split())


def set_seed(seed: int = 42):
    """Set random seeds for reproducibility."""
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(seed)
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)


def compute_metrics(eval_pred):
    """Metrics callback for Hugging Face Trainer."""
    logits, labels = eval_pred
    probs = torch.softmax(torch.tensor(logits), dim=1).numpy()[:, 1]
    preds = np.argmax(logits, axis=1)

    return {
        "precision": precision_score(labels, preds),
        "recall": recall_score(labels, preds),
        "f1_weighted": f1_score(labels, preds, average="weighted"),
        "apr": average_precision_score(labels, probs),
        "balanced_accuracy": balanced_accuracy_score(labels, preds),
        "auc": roc_auc_score(labels, probs),
        "mcc": matthews_corrcoef(labels, preds),
    }
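Why insert spaces between residues? AntiBERTy is loaded here with a Hugging Face BertTokenizer, which splits text on whitespace (keeping special tokens such as [CLS] intact) and then applies greedy longest-match-first WordPiece to each word. The pure-Python sketch below is a simplified re-implementation of that WordPiece step, not the transformers code itself. It assumes a vocabulary of single amino acids plus special tokens and no "##" continuation pieces, which is our reading of AntiBERTy's vocab.txt (check the file to confirm). Under that assumption, an unspaced sequence collapses into a single [UNK] token.

```python
# Simplified sketch of greedy longest-match-first WordPiece, as used by BertTokenizer.
# Assumption: the vocabulary has single residues and special tokens, but no "##" pieces.
SPECIAL_TOKENS = {"[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]"}
VOCAB = SPECIAL_TOKENS | set("ACDEFGHIKLMNPQRSTVWY")


def wordpiece(word, vocab=VOCAB, unk="[UNK]"):
    """Split one whitespace-delimited word into vocabulary pieces, or return [unk]."""
    pieces, start = [], 0
    while start < len(word):
        end, match = len(word), None
        # Try the longest remaining substring first, then shrink it from the right
        while start < end:
            candidate = word[start:end]
            if start > 0:
                candidate = "##" + candidate  # non-initial pieces need a "##" prefix
            if candidate in vocab:
                match = candidate
                break
            end -= 1
        if match is None:
            return [unk]  # any uncovered position turns the whole word into [UNK]
        pieces.append(match)
        start = end
    return pieces


def tokenize(text):
    """Whitespace split followed by WordPiece on each word."""
    return [piece for word in text.split() for piece in wordpiece(word)]


print(tokenize("CARDYW"))                   # -> ['[UNK]']
print(tokenize("C A R D Y W"))              # -> ['C', 'A', 'R', 'D', 'Y', 'W']
print(tokenize("C A R [CLS] [CLS] C Q Q"))  # residues and [CLS] each become one token
```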

Load dataset

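We load the data and convert the sequences into the format AntiBERTy expects: each sequence is truncated to MAX_LENGTH characters, the <cls><cls> separator that joins the heavy and light chains is replaced by the [CLS][CLS] tokens from AntiBERTy's vocabulary, and residues are separated by spaces so that the tokenizer reads each amino acid as its own token.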

[ ]:
MAX_LENGTH = 512 - 2  # AntiBERTy max length minus specials

def load_data(scope: str = "CDR3", model_type: str = "HL"):
    if scope == "FULL":
        filename = "S_FULL.parquet"
        print("Loading full-length sequences...")
    else:
        filename = "S_CDR3.parquet"
        print("Loading CDR3 sequences...")

    path = os.path.join(DATA_DIR, filename)
    if not os.path.exists(path):
        raise FileNotFoundError(f"Could not find {path}. Please check DATA_DIR and filename.")

    df = pd.read_parquet(path)

    X = df[model_type].apply(lambda s: s[:MAX_LENGTH])
    X = X.str.replace("<cls><cls>", "[CLS][CLS]", regex=False)
    X = X.apply(insert_space_every_other_except_cls)

    y = np.isin(df["label"], ["S+", "S1+", "S2+"]).astype(int)
    groups = df["subject"].values

    print(f"Total sequences: {len(X)}")
    print(f"Unique donors: {len(np.unique(groups))}")
    print("Label counts:", Counter(y))

    return X, y, groups, df

X, y, y_groups, raw_df = load_data(SEQUENCE_SCOPE, MODEL_TYPE)
Loading CDR3 sequences...
Total sequences: 15539
Unique donors: 427
Label counts: Counter({1: 8658, 0: 6881})
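The three transformation steps in load_data can be traced on a single toy string. The sketch below copies insert_space_every_other_except_cls from above and applies the same truncate, replace and space steps to a hypothetical paired CDR3 string, assuming (as the replace call implies) that heavy and light chains are joined by <cls><cls>.

```python
def insert_space_every_other_except_cls(s: str) -> str:
    """Copy of the helper above: space-separate residues, keeping [CLS] intact."""
    parts = s.split("[CLS]")
    spaced = [" ".join(list(part)) for part in parts]
    out = " [CLS] ".join(spaced)
    return " ".join(out.split())  # collapse the double space left between [CLS] tokens


MAX_LENGTH = 510

# Hypothetical heavy + light CDR3 pair joined by the <cls><cls> separator
raw = "CARDYW<cls><cls>CQQYNSYPLTF"

step1 = raw[:MAX_LENGTH]                               # character-level truncation
step2 = step1.replace("<cls><cls>", "[CLS][CLS]")      # map separator to vocab tokens
step3 = insert_space_every_other_except_cls(step2)     # one residue per token
print(step3)  # -> C A R D Y W [CLS] [CLS] C Q Q Y N S Y P L T F
```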

Create data splits

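We generate train, validation and test splits with scikit-learn's StratifiedGroupKFold, which keeps all sequences from the same donor (subject) in one fold while trying to preserve the class balance. One of four outer folds serves as the test set, and one of three inner folds of the remaining data serves as the validation set. Because whole donors are assigned to folds, the sequence-level proportions only approximate 25% and 33%: in this run the test set holds 2,570 of 15,539 sequences (about 17%) and the validation set 4,546 of the 12,969 outer-training sequences (about 35%).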

[5]:
outer_cv = StratifiedGroupKFold(
    n_splits=4,
    shuffle=True,
    random_state=RANDOM_STATE_OUTER
)

inner_cv = StratifiedGroupKFold(
    n_splits=3,
    shuffle=True,
    random_state=RANDOM_STATE_INNER
)

# Use the first outer fold for this tutorial
for fold_idx, (train_index, test_index) in enumerate(outer_cv.split(X, y, y_groups), start=1):
    print(f"Using outer fold {fold_idx}")
    X_train_all, X_test = X.iloc[train_index], X.iloc[test_index]
    y_train_all, y_test = y[train_index], y[test_index]
    y_groups_train = y_groups[train_index]
    break

print(f"Outer train size: {len(X_train_all)}, test size: {len(X_test)}")
print(f"% positive train: {np.mean(y_train_all):.3f}, % positive test: {np.mean(y_test):.3f}")

# Inner split to create validation set
for inner_idx, (inner_train_index, val_index) in enumerate(
    inner_cv.split(X_train_all, y_train_all, y_groups_train),
    start=1
):
    print(f"Using inner fold {inner_idx} for train/val split")
    X_train = X_train_all.iloc[inner_train_index]
    y_train = y_train_all[inner_train_index]
    X_val = X_train_all.iloc[val_index]
    y_val = y_train_all[val_index]
    break

print(f"Final sizes — train: {len(X_train)}, val: {len(X_val)}, test: {len(X_test)}")
Using outer fold 1
Outer train size: 12969, test size: 2570
% positive train: 0.555, % positive test: 0.566
Using inner fold 1 for train/val split
Final sizes — train: 8423, val: 4546, test: 2570
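Because the point of grouping by subject is to keep every donor's sequences on one side of each split, a quick leakage check is cheap insurance. The hypothetical helper below (not part of scikit-learn) raises an error if any donor ID occurs in two splits; the toy arrays stand in for the notebook's group labels.

```python
import numpy as np


def assert_no_donor_overlap(groups_a, groups_b):
    """Raise ValueError if any donor ID appears in both splits (hypothetical helper)."""
    shared = np.intersect1d(np.unique(groups_a), np.unique(groups_b))
    if shared.size > 0:
        raise ValueError(f"{shared.size} donor(s) appear in both splits: {shared[:5]}")
    return True


# Toy example with hypothetical donor IDs: no donor is shared
train_groups = np.array(["d1", "d1", "d2", "d3"])
test_groups = np.array(["d4", "d4", "d5"])
print(assert_no_donor_overlap(train_groups, test_groups))  # -> True
```

In the notebook, assert_no_donor_overlap(y_groups_train, y_groups[test_index]) checks the outer split, and assert_no_donor_overlap(y_groups_train[inner_train_index], y_groups_train[val_index]) checks the inner one.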

Build Hugging Face Datasets & tokenize

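We convert the train, validation and test splits into a Hugging Face DatasetDict, load AntiBERTy with a newly initialized two-class classification head, freeze the embeddings and all but the last three encoder layers, and tokenize the space-separated sequences (padded or truncated to MAX_LENGTH tokens).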

[6]:
train_df = pd.DataFrame({"sequence": X_train.values, "labels": y_train})
val_df   = pd.DataFrame({"sequence": X_val.values,   "labels": y_val})
test_df  = pd.DataFrame({"sequence": X_test.values,  "labels": y_test})

raw_datasets = DatasetDict({
    "train": Dataset.from_pandas(train_df.reset_index(drop=True)),
    "validation": Dataset.from_pandas(val_df.reset_index(drop=True)),
    "test": Dataset.from_pandas(test_df.reset_index(drop=True)),
})

model, tokenizer = load_antiberty_classifier(num_labels=2)
model = freeze_antiberty_layers(model, train_last_n_layers=3)

def preprocess_function(batch):
    encodings = tokenizer(
        batch["sequence"],
        padding="max_length",
        truncation=True,
        max_length=MAX_LENGTH,
    )
    encodings["labels"] = batch["labels"]
    return encodings

tokenized_datasets = raw_datasets.map(
    preprocess_function,
    batched=True,
    remove_columns=["sequence"],
)
AntiBERTy model: /vast/palmer/home.mccleary/mw957/.conda/envs/bcrembed/lib/python3.12/site-packages/antiberty/trained_models/AntiBERTy_md_smooth
AntiBERTy vocab: /vast/palmer/home.mccleary/mw957/.conda/envs/bcrembed/lib/python3.12/site-packages/antiberty/trained_models/vocab.txt
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at /vast/palmer/home.mccleary/mw957/.conda/envs/bcrembed/lib/python3.12/site-packages/antiberty/trained_models/AntiBERTy_md_smooth and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Model size: 25.76M parameters
Map: 100%|██████████| 8423/8423 [00:02<00:00, 3736.40 examples/s]
Map: 100%|██████████| 4546/4546 [00:01<00:00, 3854.12 examples/s]
Map: 100%|██████████| 2570/2570 [00:00<00:00, 4093.28 examples/s]
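freeze_antiberty_layers freezes the embedding layer and every encoder layer except the last train_last_n_layers; the pooler and the newly added classifier head are never frozen. The slicing logic can be checked without loading the model: the hypothetical helper below applies the same slice to layer indices. With AntiBERTy's 8 encoder layers and train_last_n_layers=3, layers 0 to 4 are frozen and layers 5 to 7 are fine-tuned. In the notebook itself, sum(p.numel() for p in model.parameters() if p.requires_grad) gives the resulting number of trainable parameters.

```python
def frozen_layer_indices(total_layers: int, train_last_n_layers: int) -> list:
    """Encoder layer indices left frozen by freeze_antiberty_layers (hypothetical helper)."""
    all_layers = list(range(total_layers))
    # Same slice as in freeze_antiberty_layers: everything except the last n layers
    return all_layers[: total_layers - train_last_n_layers]


print(frozen_layer_indices(8, 3))  # -> [0, 1, 2, 3, 4]
print(frozen_layer_indices(8, 0))  # all 8 layers frozen, only pooler + classifier train
```

One consequence of reusing the slice as-is: if train_last_n_layers exceeds the number of layers, the stop index becomes negative and Python counts from the end, so frozen_layer_indices(8, 10) returns [0, 1, 2, 3, 4, 5] rather than freezing nothing. Clamping the stop index with max(total_layers - train_last_n_layers, 0) would avoid this.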

Set up the parameters for fine-tuning AntiBERTy

[7]:
set_seed(1)

FOLD_ID = 1
OUT_PATH = os.path.join(OUTPUT_DIR, f"{RUN_ID}_Fold_{FOLD_ID}")
os.makedirs(OUT_PATH, exist_ok=True)
print("Saving checkpoints to:", OUT_PATH)

training_args = TrainingArguments(
    output_dir=OUT_PATH,
    eval_strategy="epoch",  # called evaluation_strategy in older transformers releases
    save_strategy="epoch",
    logging_strategy="epoch",
    learning_rate=LR,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    num_train_epochs=N_EPOCHS,
    warmup_ratio=0.0,
    load_best_model_at_end=True,
    metric_for_best_model="auc",
    lr_scheduler_type="linear",
    seed=1
)

trainer = Trainer(
    model=model,
    args=training_args,
    tokenizer=tokenizer,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    compute_metrics=compute_metrics,
)
Saving checkpoints to: ../models/S_antiBERTy_HL_fine_tuning_CDR3_Fold_1

Fine-tune AntiBERTy on the specificity dataset

[8]:
train_result = trainer.train()
train_result.metrics
[1320/1320 10:27, Epoch 10/10]
Epoch  Train loss  Val loss  Precision  Recall    F1 (weighted)  APR       Bal. acc.  AUC       MCC
1      0.682100    0.674185  0.571391   0.870452  0.512843       0.641419  0.535959   0.599263  0.097128
2      0.668300    0.668239  0.586726   0.841263  0.548888       0.662187  0.558285   0.619711  0.142084
3      0.651600    0.659277  0.617122   0.732107  0.594114       0.676344  0.588303   0.634982  0.184527
4      0.641700    0.654430  0.623169   0.731307  0.601285       0.688067  0.595238   0.647339  0.198104
5      0.630900    0.651899  0.628552   0.725310  0.606789       0.693455  0.600552   0.653557  0.207789
6      0.623200    0.657183  0.664266   0.590164  0.611359       0.695164  0.612686   0.654203  0.224292
7      0.616500    0.650168  0.643408   0.700520  0.618827       0.700375  0.612852   0.659554  0.229099
8      0.611000    0.650533  0.650999   0.664534  0.618953       0.701372  0.614419   0.660647  0.229431
9      0.606200    0.650401  0.648743   0.670532  0.618158       0.703339  0.613261   0.662185  0.227533
10     0.605900    0.649999  0.642884   0.695322  0.617356       0.703881  0.611475   0.662833  0.225944

[8]:
{'train_runtime': 628.6721,
 'train_samples_per_second': 133.981,
 'train_steps_per_second': 2.1,
 'total_flos': 6568285780076400.0,
 'train_loss': 0.6337348244406961,
 'epoch': 10.0}
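The step count in the progress bar follows from the split sizes. Assuming a single GPU and the Trainer defaults of no gradient accumulation and no dropping of the last partial batch, each epoch over 8,423 training sequences with a batch size of 64 takes ceil(8423 / 64) steps. The short calculation below reproduces the 1,320 steps and, approximately, the reported throughput.

```python
import math

# Values from this notebook run
n_train, batch_size, n_epochs = 8423, 64, 10
runtime_s = 628.6721

steps_per_epoch = math.ceil(n_train / batch_size)  # the last partial batch still counts
total_steps = steps_per_epoch * n_epochs
print(steps_per_epoch, total_steps)  # -> 132 1320

# Training samples processed per second over the whole run
print(round(n_train * n_epochs / runtime_s, 3))
```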

Evaluate the model on the held-out test set

[9]:
model.eval()
test_outputs = trainer.predict(tokenized_datasets["test"])
test_metrics = test_outputs.metrics

print("Test metrics:")
for k, v in test_metrics.items():
    print(f"{k}: {v:.4f}")
Test metrics:
test_loss: 0.6636
test_precision: 0.6531
test_recall: 0.6527
test_f1_weighted: 0.6074
test_apr: 0.6931
test_balanced_accuracy: 0.6005
test_auc: 0.6438
test_mcc: 0.2010
test_runtime: 8.7197
test_samples_per_second: 294.7350
test_steps_per_second: 4.7020
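To make the threshold-based metrics easier to interpret: balanced accuracy is the mean of the per-class recalls (0.5 is chance level), and MCC ranges from -1 to 1 with 0 corresponding to chance-level prediction, so a test MCC of about 0.20 indicates a modest but real signal. The NumPy sketch below recomputes both from a confusion matrix on toy labels. It mirrors, but is not, the scikit-learn implementation used in compute_metrics, and it also shows the softmax step that turns the two logits into the class-1 probability used for AUC and APR.

```python
import numpy as np


def binary_confusion(y_true, y_pred):
    """Return TP, FP, TN, FN counts for 0/1 labels."""
    y_true, y_pred = np.asarray(y_true), np.asarray(y_pred)
    tp = int(np.sum((y_true == 1) & (y_pred == 1)))
    fp = int(np.sum((y_true == 0) & (y_pred == 1)))
    tn = int(np.sum((y_true == 0) & (y_pred == 0)))
    fn = int(np.sum((y_true == 1) & (y_pred == 0)))
    return tp, fp, tn, fn


def balanced_accuracy(y_true, y_pred):
    tp, fp, tn, fn = binary_confusion(y_true, y_pred)
    return 0.5 * (tp / (tp + fn) + tn / (tn + fp))  # mean of sensitivity and specificity


def mcc(y_true, y_pred):
    tp, fp, tn, fn = binary_confusion(y_true, y_pred)
    denom = np.sqrt(float((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn)))
    return 0.0 if denom == 0 else (tp * tn - fp * fn) / denom


def positive_class_probability(logits):
    """Softmax over two logits per row, keeping class 1 (as in compute_metrics)."""
    logits = np.asarray(logits, dtype=float)
    shifted = logits - logits.max(axis=1, keepdims=True)  # numerical stability
    exp = np.exp(shifted)
    return exp[:, 1] / exp.sum(axis=1)


# Toy labels: 2 of 3 positives and 2 of 3 negatives are predicted correctly
y_true = [1, 1, 1, 0, 0, 0]
y_pred = [1, 1, 0, 0, 0, 1]
print(round(balanced_accuracy(y_true, y_pred), 4))  # -> 0.6667
print(round(mcc(y_true, y_pred), 4))                # -> 0.3333
print(positive_class_probability([[0.0, 0.0], [0.0, 2.0]]))
```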

Using the fine-tuned AntiBERTy model with Amulety on epitope-specific sequences

In this section we:

  1. Export the fine-tuned AntiBERTy classifier as a Hugging Face model.

  2. Point Amulety to it as a custom model.

  3. Use Amulety to embed a few example sequences with this fine-tuned encoder.

This illustrates how you can fine-tune AntiBERTy for S-protein binding and then reuse the same model inside Amulety to generate embeddings for other datasets (e.g. epitope-specificity panels).

We will start by saving the fine-tuned AntiBERTy classifier as a Hugging Face model.

[10]:
from pathlib import Path

# Directory where we will save the fine-tuned model for Amulety
CUSTOM_MODEL_PATH = Path(OUTPUT_DIR) / f"{RUN_ID}_amulety_custom"
CUSTOM_MODEL_PATH.mkdir(parents=True, exist_ok=True)

# With load_best_model_at_end=True, the Trainer has already loaded the best checkpoint
# (highest validation AUC) into `model`, so we save that model together with its tokenizer.
model.save_pretrained(CUSTOM_MODEL_PATH)
tokenizer.save_pretrained(CUSTOM_MODEL_PATH)

print("Saved fine-tuned model for Amulety at:", CUSTOM_MODEL_PATH)

Saved fine-tuned model for Amulety at: ../models/S_antiBERTy_HL_fine_tuning_CDR3_amulety_custom

Next, we download a separate AIRR-formatted BCR dataset and, for demonstration purposes, randomly sample up to 1,000 cells that have at least one heavy and one light chain.

[33]:
airr_demo = pd.read_csv(
    "https://zenodo.org/records/17186858/files/ML_bcr_airr_dataset.tsv",
    sep="\t",
    low_memory=False,  # read in one pass so column 10 gets a single inferred dtype
)
# cells that have at least one H and one L
mask_both = (
    airr_demo.groupby("cell_id")["chain_type"]
      .agg(lambda x: set(x))
      .pipe(lambda s: s[s.apply(lambda st: {"H", "L"}.issubset(st))])
)
cells_with_both = mask_both.index

# sample up to 1000 such cells (no replacement)
n_cells = min(1000, len(cells_with_both))
sampled_cells = np.random.choice(cells_with_both, size=n_cells, replace=False)
df_sampled = airr_demo[airr_demo["cell_id"].isin(sampled_cells)].copy()
df_sampled.head()
[33]:
sequence_id sequence_vdj_aa locus cell_id chain_type v_call v_call_family j_call_family mu_freq junction_aa_length isotype source subject specificity duplicate_count productive rev_comp stop_codon vj_in_frame
79 80_heavy EVQLVESGGGLVQPGGSLRLSCVASGFTFSSYWMSWVRQAPGKGLE... IGH cell_80 H IGHV3-7 IGHV3 IGHJ4 0.003378 18.0 NaN OAS OAS_King_Subject-BCP3 unlabeled 1 True False False True
189 190_heavy EVQLVQSGAEVKKPGESLKISCKGSAYSFTNYWIAWVRQMPGKGLE... IGH cell_190 H IGHV5-51 IGHV5 IGHJ3 0.013158 20.0 NaN OAS OAS_King_Subject-BCP3 unlabeled 1 True False False True
298 299_heavy EVQLVESGGGLVKPGGSLRLSCSASRFTFSTYRMNWVRQAPGKGLE... IGH cell_299 H IGHV3-21 IGHV3 IGHJ2 0.023256 23.0 NaN OAS OAS_King_Subject-BCP3 unlabeled 1 True False False True
1084 1085_heavy QVQLVQSGAEVREPGASVKVSCKASGYTFTIYDINWVRQAPGQGLE... IGH cell_1085 H IGHV1-8 IGHV1 IGHJ4 0.062706 13.0 NaN OAS OAS_King_Subject-BCP3 unlabeled 1 True False False True
2712 2713_heavy EVQLVESGGGLVQRGGSLRLSCGASGFTFSSYNMNWVRQAPGKGLE... IGH cell_2713 H IGHV3-48 IGHV3 IGHJ3 0.032895 14.0 NaN OAS OAS_King_Subject-BCP4 unlabeled 1 True False False True
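The cell-filtering step above groups rows by cell_id, collects the chain types of each cell into a set, and keeps a cell only if that set contains both "H" and "L". The toy frame below (hypothetical cells) shows the same logic in isolation.

```python
import pandas as pd

# Toy AIRR-like table: cell_a is paired, cell_b is heavy-only, cell_c has two light chains
toy = pd.DataFrame({
    "cell_id":    ["cell_a", "cell_a", "cell_b", "cell_c", "cell_c", "cell_c"],
    "chain_type": ["H",      "L",      "H",      "H",      "L",      "L"],
})

# Same logic as the notebook: chain types per cell, then keep cells with both H and L
chains_per_cell = toy.groupby("cell_id")["chain_type"].agg(lambda x: set(x))
has_both = chains_per_cell.apply(lambda chains: {"H", "L"}.issubset(chains))
cells_with_both = chains_per_cell[has_both].index

print(list(cells_with_both))  # -> ['cell_a', 'cell_c']
```

Note that cell_c, with one heavy and two light chains, passes the filter, so a sampled cell can contribute more than two rows to df_sampled.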

Next, we use Amulety to generate embedding vectors for the sampled cells with our fine-tuned AntiBERTy model.
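Setting residue_level=False in the call below asks Amulety for one vector per sequence rather than one per residue. How Amulety reduces residue-level hidden states to a single vector is not shown in this tutorial; a common choice for BERT-style encoders is to average the final hidden states over real (non-padding) tokens. The NumPy sketch below illustrates that masked mean pooling with made-up arrays; treat it as an illustration of the idea, not as Amulety's implementation.

```python
import numpy as np


def masked_mean_pool(hidden_states, attention_mask):
    """Average token vectors per sequence, ignoring padded positions.

    hidden_states: (batch, seq_len, dim); attention_mask: (batch, seq_len) of 0/1.
    """
    mask = attention_mask[:, :, None].astype(hidden_states.dtype)  # (batch, seq_len, 1)
    summed = (hidden_states * mask).sum(axis=1)                     # (batch, dim)
    counts = np.clip(mask.sum(axis=1), 1e-9, None)                  # avoid division by zero
    return summed / counts


# Two toy sequences of length 3 with 2-dimensional hidden states; the second is padded
hidden = np.array([
    [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
    [[2.0, 2.0], [4.0, 4.0], [100.0, 100.0]],  # last position is padding
])
mask = np.array([[1, 1, 1], [1, 1, 0]])

# Row 1 averages all three positions; row 2 ignores the padded value of 100
print(masked_mean_pool(hidden, mask))
```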

[31]:
from amulety import embed_airr

CHAIN = MODEL_TYPE
EMBED_DIM = int(model.config.hidden_size)

print("Custom model path:", CUSTOM_MODEL_PATH)
print("Embedding dimension:", EMBED_DIM)
print("Max length used by tutorial:", MAX_LENGTH)

embeddings_df, meta_df = embed_airr(
    airr=df_sampled,
    chain=CHAIN,
    model="custom",
    sequence_col="sequence_vdj_aa",
    cell_id_col="cell_id",
    batch_size=8,
    model_path=str(CUSTOM_MODEL_PATH),
    embedding_dimension=EMBED_DIM,
    max_length=MAX_LENGTH,
    output_type="df",       # return embeddings as a DataFrame
    residue_level=False,    # sequence-level embeddings
)

print("Embeddings shape:", embeddings_df.shape)
embeddings_df.head()
/home/mw957/.conda/envs/bcrembed/lib/python3.12/site-packages/amulety/amulety.py:575: UserWarning: Custom protein language model might not understand paired chain relationships. Chain 'HL' will be processed as concatenated sequences, but results may be inaccurate.
  warnings.warn(
Some weights of BertForMaskedLM were not initialized from the model checkpoint at ../models/S_antiBERTy_HL_fine_tuning_CDR3_amulety_custom and are newly initialized: ['cls.predictions.bias', 'cls.predictions.decoder.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Custom model path: ../models/S_antiBERTy_HL_fine_tuning_CDR3_amulety_custom
Embedding dimension: 512
Max length used by tutorial: 510
Embeddings shape: (1000, 513)
[31]:
cell_id dim_1 dim_2 dim_3 dim_4 dim_5 dim_6 dim_7 dim_8 dim_9 ... dim_503 dim_504 dim_505 dim_506 dim_507 dim_508 dim_509 dim_510 dim_511 dim_512
0 cell_100681 -0.077534 -0.109347 -1.533826 -0.553431 1.118874 0.011019 1.411259 -1.877443 1.221428 ... -0.486696 0.669253 -0.199235 0.142055 -0.412014 0.974601 -0.765339 -0.230175 0.045813 -0.22751
1 cell_101224 -0.077534 -0.109347 -1.533826 -0.553431 1.118874 0.011019 1.411259 -1.877443 1.221428 ... -0.486696 0.669253 -0.199235 0.142055 -0.412014 0.974601 -0.765339 -0.230175 0.045813 -0.22751
2 cell_101233 -0.077534 -0.109347 -1.533826 -0.553431 1.118874 0.011019 1.411259 -1.877443 1.221428 ... -0.486696 0.669253 -0.199235 0.142055 -0.412014 0.974601 -0.765339 -0.230175 0.045813 -0.22751
3 cell_101360 -0.077534 -0.109347 -1.533826 -0.553431 1.118874 0.011019 1.411259 -1.877443 1.221428 ... -0.486696 0.669253 -0.199235 0.142055 -0.412014 0.974601 -0.765339 -0.230175 0.045813 -0.22751
4 cell_102058 -0.077534 -0.109347 -1.533826 -0.553431 1.118874 0.011019 1.411259 -1.877443 1.221428 ... -0.486696 0.669253 -0.199235 0.142055 -0.412014 0.974601 -0.765339 -0.230175 0.045813 -0.22751

5 rows × 513 columns
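The five rows previewed above have identical values in every displayed dimension, which suggests that different cells received the same embedding. Before using the output downstream, it is worth counting how many distinct vectors were produced. The hypothetical helper below does this on a toy frame shaped like embeddings_df (a cell_id column followed by dim_* columns); in the notebook you would call count_distinct_embeddings(embeddings_df). If only one or a handful of distinct vectors come back, one possible cause (an assumption, not verified here) is that the sequences reached the tokenizer without spaces between residues and were mapped to [UNK], so comparing the input against the formatting used for fine-tuning is a reasonable first check.

```python
import numpy as np
import pandas as pd


def count_distinct_embeddings(embeddings_df, id_col="cell_id", decimals=6):
    """Number of distinct embedding rows after rounding, ignoring the ID column (hypothetical helper)."""
    values = embeddings_df.drop(columns=[id_col]).to_numpy(dtype=float)
    return len(np.unique(np.round(values, decimals), axis=0))


# Toy frame mimicking embeddings_df: cell_a and cell_b share a vector, cell_c differs
toy = pd.DataFrame({
    "cell_id": ["cell_a", "cell_b", "cell_c"],
    "dim_1": [0.1, 0.1, 0.5],
    "dim_2": [-0.2, -0.2, 0.3],
})
n_distinct = count_distinct_embeddings(toy)
print(f"{n_distinct} distinct vectors for {len(toy)} cells")  # -> 2 distinct vectors for 3 cells
```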

At this point:

  • embeddings_df contains one row per cell (1,000 rows): a cell_id column followed by 512 embedding dimensions (dim_1 to dim_512), generated by the fine-tuned AntiBERTy encoder via Amulety.

You can now:

  • Project these embeddings with UMAP / t-SNE,

  • Train simple classifiers or regressors for predictions,

  • Or integrate them into larger downstream pipelines, all using the same Amulety interface you use for other models.
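As a dependency-light first look before reaching for UMAP or t-SNE, a linear 2-D projection with PCA can be computed directly from the embedding matrix with NumPy's SVD. The sketch below uses random data in place of real embeddings; with the notebook's output you would pass embeddings_df.drop(columns=["cell_id"]).to_numpy(). The helper name pca_2d is hypothetical.

```python
import numpy as np


def pca_2d(x):
    """Project rows of x onto their first two principal components via SVD."""
    centered = x - x.mean(axis=0, keepdims=True)
    # Rows of vt are principal directions, ordered by decreasing singular value
    _, singular_values, vt = np.linalg.svd(centered, full_matrices=False)
    coords = centered @ vt[:2].T
    total = np.sum(singular_values**2)
    # Explained-variance ratio; undefined (all zeros here) if every row is identical
    explained = singular_values**2 / total if total > 0 else np.zeros_like(singular_values)
    return coords, explained[:2]


rng = np.random.default_rng(0)
fake_embeddings = rng.normal(size=(100, 512))  # stand-in for a (cells x 512) matrix
coords, explained = pca_2d(fake_embeddings)
print(coords.shape, explained.round(3))
```

The zero-variance guard matters for this notebook in particular: if every cell received the same vector, the centered matrix is all zeros and there is no variance to explain.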