首页
学习
活动
专区
圈层
工具
发布
社区首页 >专栏 >蛋白质语言模型微调cookbook

蛋白质语言模型微调cookbook

作者头像
Tom2Code
发布2026-04-17 17:29:06
发布2026-04-17 17:29:06
1320
举报

首先介绍一下数据集的来源:来自于2021年的一篇bib的论文

一.背景介绍

数据集介绍

细胞毒性T细胞在适应性免疫系统中扮演着至关重要的角色,通过寻找、结合并杀灭表面呈现外来抗原的细胞。因此,更好地理解T细胞免疫力将极大地帮助开发新的癌症免疫疗法和针对危及生命的病原体(如SARS-CoV-2/COVID-19病毒)的疫苗。设计此类靶向疗法的核心在于计算方法,用以预测哪些非自身肽段最有可能引发T细胞反应。人体白细胞抗原(HLA)是一类存在于所有人体有核细胞表面的多态性蛋白,它们将外来抗原呈递给T细胞受体(TCRs)。预测MHC-I结合肽段的免疫原性对于理解T细胞导向的适应性免疫的分子规则至关重要。然而,尽管现有的HLA-肽段结合预测工具种类繁多,但它们不足以推断免疫原性,因为它们不能模拟哪些肽段将触发T细胞反应。此外,历史上的免疫原性预测方法一直面临挑战,部分原因是训练数据集规模较小且对HLA等位基因的考虑有限。

数据集详情

用于DeepImmuno模型的初步训练和验证的数据集来源于免疫表位数据库(Immune Epitope Database, IEDB),分析了截至2020年8月13日的超过9000个已测试的免疫原性分子测定

研究人员采用严格的数据清洗策略来仅保留信息量高的预测结果:

  1. 数据实例必须明确匹配以下关键词:线性表位T细胞测定MHC I类人类任何疾病
  2. 丢弃了没有明确的4位MHC等位基因的数据实例。
  3. 移除了所有冗余的肽段–MHC等位基因实例(同一肽段与不同HLA等位基因结合被视为不同的实例)。
  4. 移除了所有缺乏明确实验信息(测试受试者人数、有反应受试者人数)或测试受试者少于四人的阴性肽段,因为这些数据在人类群体水平上可能信息不足。
  5. 最终,保留了长度为9聚体和10聚体的肽段进行训练,因为这两种长度覆盖了所有数据实例的97.5%,并且是MHC I类结合肽段的主要长度。

经过处理后,最终数据集中保留了8971个数据实例,其中4059个为阳性反应实例,其余4912个为阴性实例

此外,研究人员选择了三个独立的、先前已验证的免疫原性肽段集合来进行系统基准测试:

  1. 408个登革病毒阳性实例,用于初始验证不同的预测方法。
  2. 608个来自肿瘤新抗原选择联盟 (TESLA) 的经过实验测试的肿瘤特异性新抗原,作为独立的测试数据集。
  3. 100个SARS-CoV-2肽段,这些肽段在康复期和未暴露的受试者中测试了免疫原性,作为另一个独立的测试数据集。

上图就是我们本次要使用的数据集

二.基于ESM2蛋白质语言模型的微调

2.1训练目标:

学习类型:监督学习

学习任务:二元分类

评估指标:ROC曲线下面积(AUC)

利用蛋白质语言模型预测肽的免疫原性

2.2代码

安装依赖:

代码语言:javascript
复制
# the transformers library should be pre-installed in Colab (v4.39.3)
!pip install -q -U peft==0.10.0
!pip install -q -U accelerate==0.29.2
!pip install -q -U datasets==2.18.0
!pip install -q -U evaluate==0.4.1

导入依赖:

代码语言:javascript
复制
import random
import os
import torch

import numpy as np
import pandas as pd

from sklearn.model_selection import train_test_split
from sklearn.metrics import (
    accuracy_score, 
    precision_score, 
    recall_score, 
    roc_auc_score,
    average_precision_score
)

from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    DataCollatorWithPadding,
    EsmForSequenceClassification,
    Trainer,
    TrainingArguments,
    EarlyStoppingCallback,
    get_scheduler
)

from peft import get_peft_model, LoraConfig, PeftModel
from datasets import Dataset, DatasetDict
from evaluate import load

加载数据:

代码语言:javascript
复制
df=pd.read_csv("https://raw.githubusercontent.com/frankligy/DeepImmuno/main/reproduce/data/remove0123_sample100.csv")
print(len(df))
df.head()

输出:

数据集统计:

代码语言:javascript
复制
df['immunogenicity'].value_counts()

输出:

长度统计:

代码语言:javascript
复制
# 1) 新增序列长度列(去掉可能的空白,再算长度)
df["pep_len"] = df["peptide"].astype(str).str.strip().str.len()

# 2) 看头部 + 基本统计
print(df[["peptide", "pep_len"]].head())
print("\n样本数:", len(df))
print("\n长度统计:")
print(df["pep_len"].describe())

# 3) 各长度的频数分布
print("\n各长度频数:")
print(df["pep_len"].value_counts().sort_index())

# 4) 若有缺失或非字符串,顺手查一下
print("\n缺失/异常条目数量:", df["peptide"].isna().sum())

输出:

定义标签

代码语言:javascript
复制
# create a new column 'labels' to contain binary labels
df['labels'] = df['immunogenicity'].apply(lambda x: 0 if x == 'Negative' else 1)
df['labels'].value_counts(normalize=True)

输出:

数据集划分:

代码语言:javascript
复制
train_df, val_df = train_test_split(
    df,
    test_size=0.2,
    stratify=df['labels'],
    random_state=42)

print(len(train_df))
print(len(val_df))

输出:

转化数据集:

代码语言:javascript
复制
# 转化数据集成为huggingface数据集
train_dataset = Dataset.from_pandas(train_df)
val_dataset = Dataset.from_pandas(val_df)
dataset_dict = DatasetDict({
    'train': train_dataset,
    'validation': val_dataset
})

加载tokenizer,我们本次使用esm2_t6_8M_UR50D蛋白质语言模型

其他模型:

esm2全家族模型

代码语言:javascript
复制
# 加载tokenizer并对数据进行分词
model_checkpoint = "facebook/esm2_t6_8M_UR50D"
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

def tokenize(examples, max_length=1023):
    text = examples["peptide"]
    encoding = tokenizer(text, truncation=True, max_length=max_length)
    encoding["labels"] = examples["labels"]
    return encoding

encoded_dataset = dataset_dict.map(
    tokenize,
    batched=True,
    num_proc=os.cpu_count(),
    remove_columns=dataset_dict["train"].column_names
)

encoded_dataset.set_format("torch")

指定任务头:

代码语言:javascript
复制
# load model checkpoint for classification
model = EsmForSequenceClassification.from_pretrained(
 model_checkpoint,
 num_labels=2
)

微调,使用LoRA

代码语言:javascript
复制
# configure model for LoRA fine-tuning
peft_config = LoraConfig(
    task_type="SEQ_CLS",
    inference_mode=False,
    bias="none",
    r=8, # rank number
    lora_alpha=16, # scaling factor)
    lora_dropout=0.2, # dropout prob
    target_modules=[ # which layers to apply LoRA
        "query",
        "key",
        "value"
    ],
    modules_to_save=['classifier'] # ensures that the fine-tuned classifier head is saved when calling trainer.save_model later
)

model = get_peft_model(model, peft_config)

# adjust dropout in the classifier head
model.base_model.model.classifier.modules_to_save.default.dropout.p = 0.25

这里我们制定了bias为none,这样不会对原模型的bias做适配,这样最省参数,task_type设置为'SEQ_CLS',为序列分类任务,PEFT 会按此设置一些默认行为。

LoRA的具体流程大家可以参考下图,如果不理解可以进一步学习和讨论:

lora的示意图

统计训练参数:

代码语言:javascript
复制
# show amount of trainable parameters
def print_trainable_parameters(model):
    """
    Prints the number of trainable parameters in the model.
    """
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        all_param += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()
    print(
        f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}"
    )

print_trainable_parameters(model)

输出:

所以这样微调的话,需要训练的参数是非常少的。

设置训练参数,并开始训练:

代码语言:javascript
复制
# configure training args
num_train_epochs = 10
batch_size = 16
learning_rate = 1e-3

args = TrainingArguments(
    seed=42,
    fp16=True,
    output_dir='./results',
    evaluation_strategy = "steps",
    num_train_epochs=num_train_epochs,
    learning_rate=learning_rate,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=64,
    gradient_accumulation_steps=4,
    # gradient_checkpointing=True,
    logging_steps=50,
    eval_steps=50,
    weight_decay=0.1,
    warmup_ratio=0.1,
    lr_scheduler_type='cosine',
    metric_for_best_model="auc_roc",
    load_best_model_at_end=True,
    report_to='none'# Disable Weights & Biases logging
)

# define metrics to compute during training
def compute_metrics(eval_pred):
    logits, labels = eval_pred
    softmax = torch.nn.Softmax(dim=1)
    probabilities = softmax(torch.tensor(logits)).numpy()
    predictions = np.argmax(probabilities, axis=1)
    probabilities_pos_class = probabilities[:, 1]

    accuracy = accuracy_score(labels, predictions)
    precision = precision_score(labels, predictions, zero_division=0)
    recall = recall_score(labels, predictions, zero_division=0)
    auc = roc_auc_score(labels, probabilities_pos_class)
    auc_pr = average_precision_score(labels, probabilities_pos_class)

    return {
        "accuracy": accuracy,
        "precision": precision,
        "recall": recall,
        "auc_roc": auc,
        "auc_pr": auc_pr
    }

# define early stopping
early_stopping_callback = EarlyStoppingCallback(
    early_stopping_patience=3,
    early_stopping_threshold=0.01
)

# create trainer
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=encoded_dataset['train'],
    eval_dataset=encoded_dataset['validation'],
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
    callbacks=[early_stopping_callback]
)

# train model
trainer.train()

输出:

training loss曲线

validate loss曲线

验证:

代码语言:javascript
复制
# evaluate model on validation set
eval_dict = trainer.evaluate()
eval_dict

输出:

进行预测:

在训练的时候采用了早停机制,本次训练在双T4的GPU上运行,如果读者想要训练更好的模型,可以修改参数,进一步提高模型的性能。

三.结尾

我们先把包含肽段序列和正/负标签的数据读入并分成训练集与验证集;再用 ESM 的分词器把序列转成模型能理解的“数字语言”;接着在预训练好的 ESM 上挂一个小小的分类头,用 LoRA 这种“只微调少量新参数、不改动大模型主体”的方式完成训练;训练过程中同步监控常见评估指标并启用早停,自动选出表现最好的版本。最终,我们得到一个基于蛋白质语言模型的免疫原性预测模型。

如果大家对plm,protein language model感兴趣,也可以阅读下图所示最新的一期论文, 使用esmc作为基础模型进行蛋白质基础语言模型,进行蛋白质与脂质的结合表征,以及结合预测。

https://www.biorxiv.org/content/10.1101/2025.09.09.675043v1.full

最后,您的点赞,是Tom持续更新的动力,thanks

本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2025-09-24,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 Tom的小院 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档