
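Today's quick update: using the ESM-2 protein language model to predict protein-protein interactions.
First, some background knowledge:
Background point 1:
BERT and the Masked-Language-Model (MLM) loss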


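The two positions marked with arrows are the ones where the model (any model will do, the ESM-2 protein language model being one option) has to guess which amino acid belongs there.
If the model can recover the hidden letters from their left and right neighbors, it has learned the "grammar and semantics" of the sequence (structural and functional information).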
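To make the idea concrete, here is a minimal sketch of how such a loss could be computed by hand: the mean negative log-probability that the model assigns to the true residue at each hidden position. The function name `mlm_loss` and all probabilities below are made up for illustration; they are not outputs of ESM-2.

```python
import math

def mlm_loss(predicted_probs, true_residues, masked_positions):
    """Mean negative log-likelihood of the true residue at each masked position."""
    losses = []
    for pos in masked_positions:
        p_true = predicted_probs[pos][true_residues[pos]]
        losses.append(-math.log(p_true))
    return sum(losses) / len(losses)

# Toy sequence "MKV" with positions 0 and 2 hidden from the model.
true_residues = ["M", "K", "V"]
masked_positions = [0, 2]

# Hypothetical model outputs: probability given to each candidate residue.
confident = {0: {"M": 0.9, "K": 0.05, "V": 0.05}, 2: {"M": 0.1, "K": 0.1, "V": 0.8}}
unsure = {0: {"M": 0.4, "K": 0.3, "V": 0.3}, 2: {"M": 0.3, "K": 0.4, "V": 0.3}}

print("confident model:", mlm_loss(confident, true_residues, masked_positions))
print("unsure model:   ", mlm_loss(unsure, true_residues, masked_positions))
```

The confident model gets the lower loss, which is exactly the signal the rest of this post relies on.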



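That is what the MLM loss means.
Now for the main part:
1. Background
In molecular biology, predicting which two proteins will "hold hands" is very important. Here we use Meta AI's ESM-2 model and compute the average masked-language-model (MLM) loss over a pair of protein sequences to judge whether they are likely to interact.
Intuitively: if two proteins really do interact, the model should be more confident and make fewer mistakes when guessing the hidden amino acids, so the loss is lower; otherwise the loss is higher. So we only need to pick out the protein pairs with the lowest loss to quickly narrow down the combinations most likely to bind.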
2. The hands-on part
2.1 Import dependencies
import numpy as np
from scipy.optimize import linear_sum_assignment
from transformers import AutoTokenizer, EsmForMaskedLM
import torch
2.2 Define the model and the protein tokenizer
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
model = EsmForMaskedLM.from_pretrained("facebook/esm2_t6_8M_UR50D")

2.3 Move the model to the GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
2.4 Define some protein sequences (for testing)
all_proteins = [
"MEESQSELNIDPPLSQETFSELWNLLPENNVLSSELCPAVDELLLPESVVNWLDEDSDDAPRMPATSAPTAPGPAPSWPLSSSVPSPKTYPGTYGFRLGFLHSGTAKSVTWTYSPLLNKLFCQLAKTCPVQLWVSSPPPPNTCVRAMAIYKKSEFVTEVVRRCPHHERCSDSSDGLAPPQHLIRVEGNLRAKYLDDRNTFRHSVVVPYEPPEVGSDYTTIHYNYMCNSSCMGGMNRRPILTIITLEDSSGNVLGRNSFEVRVCACPGRDRRTEEENFHKKGEPCPEPPPGSTKRALPPSTSSSPPQKKKPLDGEYFTLQIRGRERYEMFRNLNEALELKDAQSGKEPGGSRAHSSHLKAKKGQSTSRHKKLMFKREGLDSD",
"MCNTNMSVPTDGAVTTSQIPASEQETLVRPKPLLLKLLKSVGAQKDTYTMKEVLFYLGQYIMTKRLYDEKQQHIVYCSNDLLGDLFGVPSFSVKEHRKIYTMIYRNLVVVNQQESSDSGTSVSENRCHLEGGSDQKDLVQELQEEKPSSSHLVSRPSTSSRRRAISETEENSDELSGERQRKRHKSDSISLSFDESLALCVIREICCERSSSSESTGTPSNPDLDAGVSEHSGDWLDQDSVSDQFSVEFEVESLDSEDYSLSEEGQELSDEDDEVYQVTVYQAGESDTDSFEEDPEISLADYWKCTSCNEMNPPLPSHCNRCWALRENWLPEDKGKDKGEISEKAKLENSTQAEEGFDVPDCKKTIVNDSRESCVEENDDKITQASQSQESEDYSQPSTSSSIIYSSQEDVKEFEREETQDKEESVESSLPLNAIEPCVICQGRPKNGCIVHGKTGHLMACFTCAKKLKKRNKPCPVCRQPIQMIVLTYFP",
"MNRGVPFRHLLLVLQLALLPAATQGKKVVLGKKGDTVELTCTASQKKSIQFHWKNSNQIKILGNQGSFLTKGPSKLNDRADSRRSLWDQGNFPLIIKNLKIEDSDTYICEVEDQKEEVQLLVFGLTANSDTHLLQGQSLTLTLESPPGSSPSVQCRSPRGKNIQGGKTLSVSQLELQDSGTWTCTVLQNQKKVEFKIDIVVLAFQKASSIVYKKEGEQVEFSFPLAFTVEKLTGSGELWWQAERASSSKSWITFDLKNKEVSVKRVTQDPKLQMGKKLPLHLTLPQALPQYAGSGNLTLALEAKTGKLHQEVNLVVMRATQLQKNLTCEVWGPTSPKLMLSLKLENKEAKVSKREKAVWVLNPEAGMWQCLLSDSGQVLLESNIKVLPTWSTPVQPMALIVLGGVAGLLLFIGLGIFFCVRCRHRRRQAERMSQIKRLLSEKKTCQCPHRFQKTCSPI",
"MRVKEKYQHLWRWGWKWGTMLLGILMICSATEKLWVTVYYGVPVWKEATTTLFCASDAKAYDTEVHNVWATHACVPTDPNPQEVVLVNVTENFNMWKNDMVEQMHEDIISLWDQSLKPCVKLTPLCVSLKCTDLGNATNTNSSNTNSSSGEMMMEKGEIKNCSFNISTSIRGKVQKEYAFFYKLDIIPIDNDTTSYTLTSCNTSVITQACPKVSFEPIPIHYCAPAGFAILKCNNKTFNGTGPCTNVSTVQCTHGIRPVVSTQLLLNGSLAEEEVVIRSANFTDNAKTIIVQLNQSVEINCTRPNNNTRKSIRIQRGPGRAFVTIGKIGNMRQAHCNISRAKWNATLKQIASKLREQFGNNKTIIFKQSSGGDPEIVTHSFNCGGEFFYCNSTQLFNSTWFNSTWSTEGSNNTEGSDTITLPCRIKQFINMWQEVGKAMYAPPISGQIRCSSNITGLLLTRDGGNNNNGSEIFRPGGGDMRDNWRSELYKYKVVKIEPLGVAPTKAKRRVVQREKRAVGIGALFLGFLGAAGSTMGARSMTLTVQARQLLSGIVQQQNNLLRAIEAQQHLLQLTVWGIKQLQARILAVERYLKDQQLLGIWGCSGKLICTTAVPWNASWSNKSLEQIWNNMTWMEWDREINNYTSLIHSLIEESQNQQEKNEQELLELDKWASLWNWFNITNWLWYIKIFIMIVGGLVGLRIVFAVLSIVNRVRQGYSPLSFQTHLPTPRGPDRPEGIEEEGGERDRDRSIRLVNGSLALIWDDLRSLCLFSYHRLRDLLLIVTRIVELLGRRGWEALKYWWNLLQYWSQELKNSAVSLLNATAIAVAEGTDRVIEVVQGACRAIRHIPRRIRQGLERILL",
"MALWMRLLPLLALLALWGPDPAAAFVNQHLCGSHLVEALYLVCGERGFFYTPKTRREAEDLQVGQVELGGGPGAGSLQPLALEGSLQKRGIVEQCCTSICSLYQLENYCN",
"MATGGRRGAAAAPLLVAVAALLLGAAGHLYPGEVCPGMDIRNNLTRLHELENCSVIEGHLQILLMFKTRPEDFRDLSFPKLIMITDYLLLFRVYGLESLKDLFPNLTVIRGSRLFFNYALVIFEMVHLKELGLYNLMNITRGSVRIEKNNELCYLATIDWSRILDSVEDNYIVLNKDDNEECGDICPGTAKGKTNCPATVINGQFVERCWTHSHCQKVCPTICKSHGCTAEGLCCHSECLGNCSQPDDPTKCVACRNFYLDGRCVETCPPPYYHFQDWRCVNFSFCQDLHHKCKNSRRQGCHQYVIHNNKCIPECPSGYTMNSSNLLCTPCLGPCPKVCHLLEGEKTIDSVTSAQELRGCTVINGSLIINIRGGNNLAAELEANLGLIEEISGYLKIRRSYALVSLSFFRKLRLIRGETLEIGNYSFYALDNQNLRQLWDWSKHNLTITQGKLFFHYNPKLCLSEIHKMEEVSGTKGRQERNDIALKTNGDQASCENELLKFSYIRTSFDKILLRWEPYWPPDFRDLLGFMLFYKEAPYQNVTEFDGQDACGSNSWTVVDIDPPLRSNDPKSQNHPGWLMRGLKPWTQYAIFVKTLVTFSDERRTYGAKSDIIYVQTDATNPSVPLDPISVSNSSSQIILKWKPPSDPNGNITHYLVFWERQAEDSELFELDYCLKGLKLPSRTWSPPFESEDSQKHNQSEYEDSAGECCSCPKTDSQILKELEESSFRKTFEDYLHNVVFVPRKTSSGTGAEDPRPSRKRRSLGDVGNVTVAVPTVAAFPNTSSTSVPTSPEEHRPFEKVVNKESLVISGLRHFTGYRIELQACNQDTPEERCSVAAYVSARTMPEAKADDIVGPVTHEIFENNVVHLMWQEPKEPNGLIVLYEVSYRRYGDEELHLCVSRKHFALERGCRLRGLSPGNYSVRIRATSLAGNGSWTEPTYFYVTDYLDVPSNIAKIIIGPLIFVFLFSVVIGSIYLFLRKRQPDGPLGPLYASSNPEYLSASDVFPCSVYVPDEWEVSR"
]
2.5 Define the MLM loss function
BATCH_SIZE = 2
NUM_MASKS = 10
P_MASK = 0.15
# Function to compute MLM loss for a batch of protein pairs
def compute_mlm_loss_batch(pairs):
    avg_losses = []
    for _ in range(NUM_MASKS):
        # Tokenize the concatenated protein pairs
        inputs = tokenizer(pairs, return_tensors="pt", truncation=True, padding=True, max_length=1022)
        # Move input tensors to GPU if available
        inputs = {k: v.to(device) for k, v in inputs.items()}
        # Get the mask token ID
        mask_token_id = tokenizer.mask_token_id
        # Clone input IDs for labels
        labels = inputs["input_ids"].clone()
        # Randomly mask 15% of the residues for each sequence in the batch
        for idx in range(inputs["input_ids"].shape[0]):
            mask_indices = np.random.choice(inputs["input_ids"].shape[1], size=int(P_MASK * inputs["input_ids"].shape[1]), replace=False)
            inputs["input_ids"][idx, mask_indices] = mask_token_id
            # Ignore every position that was not masked when computing the loss
            labels[idx, [i for i in range(inputs["input_ids"].shape[1]) if i not in mask_indices]] = -100
        # Compute the MLM loss (no gradients are needed for scoring)
        with torch.no_grad():
            outputs = model(**inputs, labels=labels)
        avg_losses.append(outputs.loss.item())
    # Return the average loss for the batch
    return sum(avg_losses) / NUM_MASKS
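Why is BATCH_SIZE set to 2? Because each call processes two concatenated protein pairs at a time.
Why is NUM_MASKS set to 10? Because the masking is repeated 10 times and the losses are averaged.
Why is P_MASK 0.15? The paper discusses this probability: if it is set too high, too much of the protein sequence is masked; if it is set too low, too little is masked. Either way the model cannot make good use of the context in the sequence itself.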
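The masking step can also be looked at in isolation. The sketch below works on a plain list of token IDs and uses only the standard library; the helper name `mask_sequence`, the toy token IDs, and the value 32 standing in for the mask token ID are assumptions for illustration, not values read from the tokenizer.

```python
import random

IGNORE_INDEX = -100  # label value that the loss skips, as in the code above

def mask_sequence(token_ids, mask_token_id, p_mask=0.15, seed=None):
    """Return (masked_inputs, labels) for one tokenized sequence."""
    rng = random.Random(seed)
    num_to_mask = int(p_mask * len(token_ids))
    masked_positions = set(rng.sample(range(len(token_ids)), num_to_mask))
    masked_inputs = [mask_token_id if i in masked_positions else t
                     for i, t in enumerate(token_ids)]
    # Only masked positions keep their true ID; everything else is ignored.
    labels = [t if i in masked_positions else IGNORE_INDEX
              for i, t in enumerate(token_ids)]
    return masked_inputs, labels

# Hypothetical token IDs for a 20-residue sequence.
token_ids = list(range(4, 24))
masked_inputs, labels = mask_sequence(token_ids, mask_token_id=32, p_mask=0.15, seed=0)
print(masked_inputs)
print(labels)
```

With a 20-token sequence and P_MASK = 0.15, only a handful of positions are hidden in each round, which is why the loss is averaged over NUM_MASKS rounds.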
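Note that a larger model with a longer context window may give better results, but it also costs more compute. If you want a larger model and a longer context window, consider other ESM-2 models such as esm2_t36_3B_UR50D. You can also try adjusting max_length in the code above. For the long proteins chosen here, the context window above is not enough; a larger model with a longer context window, or shorter proteins, would almost certainly give better results. What is shown here is only an appetizer.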
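A quick way to see whether a pair fits is to count how many residues truncation would cut off. This is a rough sketch: the helper `residues_lost` is hypothetical, the sequence lengths are invented, and it assumes the tokenizer adds two special tokens (start and end) that count toward max_length.

```python
def residues_lost(len_a, len_b, max_length=1022, num_special_tokens=2):
    """How many residues of a concatenated pair are cut off by truncation."""
    budget = max_length - num_special_tokens  # room left for residues
    return max(0, len_a + len_b - budget)

# Hypothetical sequence lengths, for illustration only.
print(residues_lost(393, 491))   # short pair
print(residues_lost(393, 1382))  # pair with a very long partner
```

Any pair that reports a non-zero value is scored on a truncated input, so the second protein is only partly visible to the model.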
2.6 Build the loss matrix
# Compute loss matrix
loss_matrix = np.zeros((len(all_proteins), len(all_proteins)))
for i in range(len(all_proteins)):
    for j in range(i+1, len(all_proteins), BATCH_SIZE):  # to avoid self-pairing and use batches
        pairs = [all_proteins[i] + all_proteins[k] for k in range(j, min(j+BATCH_SIZE, len(all_proteins)))]
        # Note: every pair in the batch receives the batch's mean loss
        batch_loss = compute_mlm_loss_batch(pairs)
        for k in range(len(pairs)):
            loss_matrix[i, j+k] = batch_loss
            loss_matrix[j+k, i] = batch_loss  # the matrix is symmetric
# Set the diagonal of the loss matrix to a large value to prevent self-pairings
np.fill_diagonal(loss_matrix, np.inf)
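All proteins are paired up two by two, compute_mlm_loss_batch is called batch by batch to compute the average MLM loss, and the results are written into a symmetric loss matrix. Finally the main diagonal (a protein paired with itself) is set to infinity so that the later search for the lowest-loss pairing never picks a self-pairing.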
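The matrix-filling pattern is easier to see without the model in the way. In this sketch each pair is scored on its own (in the batched code above, pairs that share a batch share one averaged loss). `build_loss_matrix` and `dummy_score` are hypothetical names, and the dummy scorer is only a placeholder so the example runs; it is not an MLM loss.

```python
import math

def build_loss_matrix(proteins, score_pair):
    """Fill a symmetric matrix with one score per unordered pair; diagonal is inf."""
    n = len(proteins)
    matrix = [[math.inf] * n for _ in range(n)]
    for i in range(n):
        for j in range(i + 1, n):
            loss = score_pair(proteins[i], proteins[j])
            matrix[i][j] = loss
            matrix[j][i] = loss  # same value mirrored across the diagonal
    return matrix

# Placeholder scorer so the sketch runs without a model: NOT a real MLM loss.
def dummy_score(seq_a, seq_b):
    return abs(len(seq_a) - len(seq_b)) + 1.0

toy_proteins = ["MKV", "MKVLA", "MK"]
toy_matrix = build_loss_matrix(toy_proteins, dummy_score)
for row in toy_matrix:
    print(row)
```

Swapping `dummy_score` for a function that wraps the real loss computation would give one independent score per pair, at the cost of more forward passes.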
2.7 Find the optimal pairing
# Use the linear assignment problem to find the optimal pairing based on MLM loss
rows, cols = linear_sum_assignment(loss_matrix)
optimal_pairs = list(zip(rows, cols))
print(optimal_pairs)
Finally, we use the linear_sum_assignment function to find the best pairing of the proteins.
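To see what the assignment step is optimizing, here is a brute-force stand-in that tries every row-to-column permutation on a tiny matrix. It is only feasible for a handful of proteins and is not how SciPy solves the problem; the function name and the cost values are made up for illustration.

```python
from itertools import permutations

def brute_force_assignment(cost):
    """Try every row-to-column permutation and keep the cheapest one."""
    n = len(cost)
    best_total, best_perm = float("inf"), None
    for perm in permutations(range(n)):
        total = sum(cost[i][perm[i]] for i in range(n))
        if total < best_total:
            best_total, best_perm = total, perm
    return list(enumerate(best_perm)), best_total

inf = float("inf")
# Made-up symmetric costs with an infinite diagonal to forbid self-pairing.
toy_cost = [
    [inf, 1.0, 4.0, 3.0],
    [1.0, inf, 2.0, 5.0],
    [4.0, 2.0, inf, 1.5],
    [3.0, 5.0, 1.5, inf],
]
toy_pairs, toy_total = brute_force_assignment(toy_cost)
print(toy_pairs, toy_total)
```

Each row is matched to exactly one column, so every protein shows up twice in the result, once as a row and once as a column. In this toy case the matches happen to be mutual, but in general an assignment is a permutation, and row i choosing column j does not guarantee that row j chooses column i.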
Output: the printed list of (row, column) index pairs.

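2.8 A worked example
Now you can pick a target protein and compute its binding sites by following this tutorial. Once the binding sites are determined, you can move on to the RFDiffusion notebook and design several binding partners for your protein. With a few binders designed, use the code above to concatenate your target protein with each potential binder and compute the MLM loss to test which one has the highest binding affinity. Keep in mind that the longer the proteins, the larger the context window you need. Below is an example of how to rank the binding partners of a fixed protein with this approach: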
from transformers import AutoModelForMaskedLM, AutoTokenizer
import torch
# Load the base model and tokenizer
base_model_path = "facebook/esm2_t12_35M_UR50D"
model = AutoModelForMaskedLM.from_pretrained(base_model_path)
tokenizer = AutoTokenizer.from_pretrained(base_model_path)
# Ensure the model is in evaluation mode
model.eval()
# Define the protein of interest and its potential binders
protein_of_interest = "MLTEVMEVWHGLVIAVVSLFLQACFLTAINYLLSRHMAHKSEQILKAASLQVPRPSPGHHHPPAVKEMKETQTERDIPMSDSLYRHDSDTPSDSLDSSCSSPPACQATEDVDYTQVVFSDPGELKNDSPLDYENIKEITDYVNVNPERHKPSFWYFVNPALSEPAEYDQVAM"
potential_binders = [
"MASPGSGFWSFGSEDGSGDSENPGTARAWCQVAQKFTGGIGNKLCALLYGDAEKPAESGGSQPPRAAARKAACACDQKPCSCSKVDVNYAFLHATDLLPACDGERPTLAFLQDVMNILLQYVVKSFDRSTKVIDFHYPNELLQEYNWELADQPQNLEEILMHCQTTLKYAIKTGHPRYFNQLSTGLDMVGLAADWLTSTANTNMFTYEIAPVFVLLEYVTLKKMREIIGWPGGSGDGIFSPGGAISNMYAMMIARFKMFPEVKEKGMAALPRLIAFTSEHSHFSLKKGAAALGIGTDSVILIKCDERGKMIPSDLERRILEAKQKGFVPFLVSATAGTTVYGAFDPLLAVADICKKYKIWMHVDAAWGGGLLMSRKHKWKLSGVERANSVTWNPHKMMGVPLQCSALLVREEGLMQNCNQMHASYLFQQDKHYDLSYDTGDKALQCGRHVDVFKLWLMWRAKGTTGFEAHVDKCLELAEYLYNIIKNREGYEMVFDGKPQHTNVCFWYIPPSLRTLEDNEERMSRLSKVAPVIKARMMEYGTTMVSYQPLGDKVNFFRMVISNPAATHQDIDFLIEEIERLGQDL",
"MAAGVAGWGVEAEEFEDAPDVEPLEPTLSNIIEQRSLKWIFVGGKGGVGKTTCSCSLAVQLSKGRESVLIISTDPAHNISDAFDQKFSKVPTKVKGYDNLFAMEIDPSLGVAELPDEFFEEDNMLSMGKKMMQEAMSAFPGIDEAMSYAEVMRLVKGMNFSVVVFDTAPTGHTLRLLNFPTIVERGLGRLMQIKNQISPFISQMCNMLGLGDMNADQLASKLEETLPVIRSVSEQFKDPEQTTFICVCIAEFLSLYETERLIQELAKCKIDTHNIIVNQLVFPDPEKPCKMCEARHKIQAKYLDQMEDLYEDFHIVKLPLLPHEVRGADKVNTFSALLLEPYKPPSAQ",
"EKTGLSIRGAQEEDPPDPQLMRLDNMLLAEGVSGPEKGGGSAAAAAAAAASGGSSDNSIEHSDYRAKLTQIRQIYHTELEKYEQACNEFTTHVMNLLREQSRTRPISPKEIERMVGIIHRKFSSIQMQLKQSTCEAVMILRSRFLDARRKRRNFSKQATEILNEYFYSHLSNPYPSEEAKEELAKKCSITVSQSLVKDPKERGSKGSDIQPTSVVSNWFGNKRIRYKKNIGKFQEEANLYAAKTAVTAAHAVAAAVQNNQTNSPTTPNSGSSGSFNLPNSGDMFMNMQSLNGDSYQGSQVGANVQSQVDTLRHVINQTGGYSDGLGGNSLYSPHNLNANGGWQDATTPSSVTSPTEGPGSVHSDTSN"
] # Add potential binding sequences here
def compute_mlm_loss(protein, binder, iterations=3):
    total_loss = 0.0
    mask_rate = 0.15  # For instance, masking 15% of the sequence
    # Concatenate protein sequences with a separator
    concatenated_sequence = protein + ":" + binder
    for _ in range(iterations):
        # Mask a subset of amino acids in the concatenated sequence (excluding the separator)
        tokens = list(concatenated_sequence)
        num_mask = int(len(tokens) * mask_rate)
        # Exclude the separator from potential mask indices
        available_indices = [i for i, token in enumerate(tokens) if token != ":"]
        probs = torch.ones(len(available_indices))
        mask_indices = torch.multinomial(probs, num_mask, replacement=False)
        for idx in mask_indices:
            tokens[available_indices[idx]] = tokenizer.mask_token
        masked_sequence = "".join(tokens)
        inputs = tokenizer(masked_sequence, return_tensors="pt", truncation=True, max_length=1024, padding='max_length')
        # Labels hold the original residues at masked positions and -100 (ignored) everywhere else
        labels = tokenizer(concatenated_sequence, return_tensors="pt", truncation=True, max_length=1024, padding='max_length')["input_ids"]
        labels[inputs["input_ids"] != tokenizer.mask_token_id] = -100
        # Compute the MLM loss
        with torch.no_grad():
            outputs = model(**inputs, labels=labels)
        total_loss += outputs.loss.item()
    # Return the average loss
    return total_loss / iterations
# Compute MLM loss for each potential binder
mlm_losses = {}
for binder in potential_binders:
    loss = compute_mlm_loss(protein_of_interest, binder)
    mlm_losses[binder] = loss
# Rank binders based on MLM loss
ranked_binders = sorted(mlm_losses, key=mlm_losses.get)
print("Ranking of Potential Binders:")
for idx, binder in enumerate(ranked_binders, 1):
    print(f"{idx}. {binder} - MLM Loss: {mlm_losses[binder]}")
Output: the ranked list of binders with their MLM losses.
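From the output you can see the MLM loss of each sequence pair; the lower the loss, the more likely the pair interacts strongly?!
We can also borrow a trick used in the literature for complexes: insert a long linker of about 20 glycines (G) between the two protein sequences in place of the colon separator. This may give slightly different predictions and is worth a try. We could also preferentially mask the binding sites predicted by the binding-site prediction model mentioned above, since protein language models are known to pay more attention to binding residues and other special residues. Another variant is to apply masking to only one of the two sequences in each concatenated pair that represents a potential complex. All of these tweaks may improve the method. Try the variants and see which works best! I leave the testing to you, dear readers~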
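Two of these variants, the glycine linker and masking only one partner, can be sketched on plain strings. The helper names and the toy sequences are made up for illustration, and the positions returned are string indices; after tokenization they would need to be shifted by one if the tokenizer prepends a start token, which is an assumption to check against your tokenizer.

```python
import random

def join_with_linker(target, binder, linker_length=20):
    """Concatenate two sequences with a poly-glycine linker in between."""
    return target + "G" * linker_length + binder

def binder_only_mask_positions(target, binder, linker_length=20, mask_rate=0.15, seed=None):
    """Pick positions to mask from the binder part only (0-based, in the joined string)."""
    rng = random.Random(seed)
    binder_start = len(target) + linker_length
    candidates = range(binder_start, binder_start + len(binder))
    num_mask = int(len(binder) * mask_rate)
    return sorted(rng.sample(candidates, num_mask))

# Toy sequences, not real proteins.
toy_target = "MKTAYIAKQR"
toy_binder = "ACDEFGHIKLMNPQRSTVWY"
joined = join_with_linker(toy_target, toy_binder)
positions = binder_only_mask_positions(toy_target, toy_binder, seed=1)
print(joined)
print(positions)
```

Because the target and the linker are never in the candidate set, the loss then measures only how well the binder is predicted given an unmasked target.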
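2.9 Finally, visualization
We can also build a protein-protein interaction (PPI) network with this approach. Set a threshold on the pairwise MLM loss and build a graph from it: if the loss of a pair is below the threshold, add an edge between the two proteins. This gives a quick approximation of the interaction network and surfaces candidate interactions. It can be implemented as follows.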
import networkx as nx
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM
import plotly.graph_objects as go
from ipywidgets import interact
from ipywidgets import widgets
# Check if CUDA is available and set the default device accordingly
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Load the pretrained (or fine-tuned) ESM-2 model and tokenizer
model_name = "facebook/esm2_t6_8M_UR50D" # You can change this to your fine-tuned model
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForMaskedLM.from_pretrained(model_name)
# Send the model to the device (GPU or CPU)
model.to(device)
# Ensure the model is in evaluation mode
model.eval()
# Define Protein Sequences (Replace with your list)
all_proteins = [
"MFLSILVALCLWLHLALGVRGAPCEAVRIPMCRHMPWNITRMPNHLHHSTQENAILAIEQYEELVDVNCSAVLRFFLCAMYAPICTLEFLHDPIKPCKSVCQRARDDCEPLMKMYNHSWPESLACDELPVYDRGVCISPEAIVTDLPEDVKWIDITPDMMVQERPLDVDCKRLSPDRCKCKKVKPTLATYLSKNYSYVIHAKIKAVQRSGCNEVTTVVDVKEIFKSSSPIPRTQVPLITNSSCQCPHILPHQDVLIMCYEWRSRMMLLENCLVEKWRDQLSKRSIQWEERLQEQRRTVQDKKKTAGRTSRSNPPKPKGKPPAPKPASPKKNIKTRSAQKRTNPKRV",
"MDAVEPGGRGWASMLACRLWKAISRALFAEFLATGLYVFFGVGSVMRWPTALPSVLQIAITFNLVTAMAVQVTWKASGAHANPAVTLAFLVGSHISLPRAVAYVAAQLVGATVGAALLYGVMPGDIRETLGINVVRNSVSTGQAVAVELLLTLQLVLCVFASTDSRQTSGSPATMIGISVALGHLIGIHFTGCSMNPARSFGPAIIIGKFTVHWVFWVGPLMGALLASLIYNFVLFPDTKTLAQRLAILTGTVEVGTGAGAGAEPLKKESQPGSGAVEMESV",
"MKFLLDILLLLPLLIVCSLESFVKLFIPKRRKSVTGEIVLITGAGHGIGRLTAYEFAKLKSKLVLWDINKHGLEETAAKCKGLGAKVHTFVVDCSNREDIYSSAKKVKAEIGDVSILVNNAGVVYTSDLFATQDPQIEKTFEVNVLAHFWTTKAFLPAMTKNNHGHIVTVASAAGHVSVPFLLAYCSSKFAAVGFHKTLTDELAALQITGVKTTCLCPNFVNTGFIKNPSTSLGPTLEPEEVVNRLMHGILTEQKMIFIPSSIAFLTTLERILPERFLAVLKRKISVKFDAVIGYKMKAQ",
"MAAAVPRRPTQQGTVTFEDVAVNFSQEEWCLLSEAQRCLYRDVMLENLALISSLGCWCGSKDEEAPCKQRISVQRESQSRTPRAGVSPKKAHPCEMCGLILEDVFHFADHQETHHKQKLNRSGACGKNLDDTAYLHQHQKQHIGEKFYRKSVREASFVKKRKLRVSQEPFVFREFGKDVLPSSGLCQEEAAVEKTDSETMHGPPFQEGKTNYSCGKRTKAFSTKHSVIPHQKLFTRDGCYVCSDCGKSFSRYVSFSNHQRDHTAKGPYDCGECGKSYSRKSSLIQHQRVHTGQTAYPCEECGKSFSQKGSLISHQLVHTGEGPYECRECGKSFGQKGNLIQHQQGHTGERAYHCGECGKSFRQKFCFINHQRVHTGERPYKCGECGKSFGQKGNLVHHQRGHTGERPYECKECGKSFRYRSHLTEHQRLHTGERPYNCRECGKLFNRKYHLLVHERVHTGERPYACEVCGKLFGNKHSVTIHQRIHTGERPYECSECGKSFLSSSALHVHKRVHSGQKPYKCSECGKSFSECSSLIKHRRIHTGERPYECTKCGKTFQRSSTLLHHQSSHRRKAL",
"MGQPWAAGSTDGAPAQLPLVLTALWAAAVGLELAYVLVLGPGPPPLGPLARALQLALAAFQLLNLLGNVGLFLRSDPSIRGVMLAGRGLGQGWAYCYQCQSQVPPRSGHCSACRVCILRRDHHCRLLGRCVGFGNYRPFLCLLLHAAGVLLHVSVLLGPALSALLRAHTPLHMAALLLLPWLMLLTGRVSLAQFALAFVTDTCVAGALLCGAGLLFHGMLLLRGQTTWEWARGQHSYDLGPCHNLQAALGPRWALVWLWPFLASPLPGDGITFQTTADVGHTAS",
"MGLRIHFVVDPHGWCCMGLIVFVWLYNIVLIPKIVLFPHYEEGHIPGILIIIFYGISIFCLVALVRASITDPGRLPENPKIPHGEREFWELCNKCNLMRPKRSHHCSRCGHCVRRMDHHCPWINNCVGEDNHWLFLQLCFYTELLTCYALMFSFCHYYYFLPLKKRNLDLFVFRHELAIMRLAAFMGITMLVGITGLFYTQLIGIITDTTSIEKMSNCCEDISRPRKPWQQTFSEVFGTRWKILWFIPFRQRQPLRVPYHFANHV",
"MLLLGAVLLLLALPGHDQETTTQGPGVLLPLPKGACTGWMAGIPGHPGHNGAPGRDGRDGTPGEKGEKGDPGLIGPKGDIGETGVPGAEGPRGFPGIQGRKGEPGEGAYVYRSAFSVGLETYVTIPNMPIRFTKIFYNQQNHYDGSTGKFHCNIPGLYYFAYHITVYMKDVKVSLFKKDKAMLFTYDQYQENNVDQASGSVLLHLEVGDQVWLQVYGEGERNGLYADNDNDSTFTGFLLYHDTN",
"MGLLAFLKTQFVLHLLVGFVFVVSGLVINFVQLCTLALWPVSKQLYRRLNCRLAYSLWSQLVMLLEWWSCTECTLFTDQATVERFGKEHAVIILNHNFEIDFLCGWTMCERFGVLGSSKVLAKKELLYVPLIGWTWYFLEIVFCKRKWEEDRDTVVEGLRRLSDYPEYMWFLLYCEGTRFTETKHRVSMEVAAAKGLPVLKYHLLPRTKGFTTAVKCLRGTVAAVYDVTLNFRGNKNPSLLGILYGKKYEADMCVRRFPLEDIPLDEKEAAQWLHKLYQEKDALQEIYNQKGMFPGEQFKPARRPWTLLNFLSWATILLSPLFSFVLGVFASGSPLLILTFLGFVGAASFGVRRLIGVTEIEKGSSYGNQEFKKKE",
"MDLAGLLKSQFLCHLVFCYVFIASGLIINTIQLFTLLLWPINKQLFRKINCRLSYCISSQLVMLLEWWSGTECTIFTDPRAYLKYGKENAIVVLNHKFEIDFLCGWSLSERFGLLGGSKVLAKKELAYVPIIGWMWYFTEMVFCSRKWEQDRKTVATSLQHLRDYPEKYFFLIHCEGTRFTEKKHEISMQVARAKGLPRLKHHLLPRTKGFAITVRSLRNVVSAVYDCTLNFRNNENPTLLGVLNGKKYHADLYVRRIPLEDIPEDDDECSAWLHKLYQEKDAFQEEYYRTGTFPETPMVPPRRPWTLVNWLFWASLVLYPFFQFLVSMIRSGSSLTLASFILVFFVASVGVRWMIGVTEIDKGSAYGNSDSKQKLND",
"MALLLCFVLLCGVVDFARSLSITTPEEMIEKAKGETAYLPCKFTLSPEDQGPLDIEWLISPADNQKVDQVIILYSGDKIYDDYYPDLKGRVHFTSNDLKSGDASINVTNLQLSDIGTYQCKVKKAPGVANKKIHLVVLVKPSGARCYVDGSEEIGSDFKIKCEPKEGSLPLQYEWQKLSDSQKMPTSWLAEMTSSVISVKNASSEYSGTYSCTVRNRVGSDQCLLRLNVVPPSNKAGLIAGAIIGTLLALALIGLIIFCCRKKRREEKYEKEVHHDIREDVPPPKSRTSTARSYIGSNHSSLGSMSPSNMEGYSKTQYNQVPSEDFERTPQSPTLPPAKVAAPNLSRMGAIPVMIPAQSKDGSIV",
"MSYVFVNDSSQTNVPLLQACIDGDFNYSKRLLESGFDPNIRDSRGRTGLHLAAARGNVDICQLLHKFGADLLATDYQGNTALHLCGHVDTIQFLVSNGLKIDICNHQGATPLVLAKRRGVNKDVIRLLESLEEQEVKGFNRGTHSKLETMQTAESESAMESHSLLNPNLQQGEGVLSSFRTTWQEFVEDLGFWRVLLLIFVIALLSLGIAYYVSGVLPFVENQPELVH",
"MRVAGAAKLVVAVAVFLLTFYVISQVFEIKMDASLGNLFARSALDTAARSTKPPRYKCGISKACPEKHFAFKMASGAANVVGPKICLEDNVLMSGVKNNVGRGINVALANGKTGEVLDTKYFDMWGGDVAPFIEFLKAIQDGTIVLMGTYDDGATKLNDEARRLIADLGSTSITNLGFRDNWVFCGGKGIKTKSPFEQHIKNNKDTNKYEGWPEVVEMEGCIPQKQD",
"MAPAAATGGSTLPSGFSVFTTLPDLLFIFEFIFGGLVWILVASSLVPWPLVQGWVMFVSVFCFVATTTLIILYIIGAHGGETSWVTLDAAYHCTAALFYLSASVLEALATITMQDGFTYRHYHENIAAVVFSYIATLLYVVHAVFSLIRWKSS",
"MRLQGAIFVLLPHLGPILVWLFTRDHMSGWCEGPRMLSWCPFYKVLLLVQTAIYSVVGYASYLVWKDLGGGLGWPLALPLGLYAVQLTISWTVLVLFFTVHNPGLALLHLLLLYGLVVSTALIWHPINKLAALLLLPYLAWLTVTSALTYHLWRDSLCPVHQPQPTEKSD",
"MEESVVRPSVFVVDGQTDIPFTRLGRSHRRQSCSVARVGLGLLLLLMGAGLAVQGWFLLQLHWRLGEMVTRLPDGPAGSWEQLIQERRSHEVNPAAHLTGANSSLTGSGGPLLWETQLGLAFLRGLSYHDGALVVTKAGYYYIYSKVQLGGVGCPLGLASTITHGLYKRTPRYPEELELLVSQQSPCGRATSSSRVWWDSSFLGGVVHLEAGEKVVVRVLDERLVRLRDGTRSYFGAFMV"
]
def compute_average_mlm_loss(protein1, protein2, iterations=10):
    total_loss = 0.0
    connector = "G" * 25  # Connector sequence of G's
    mask_prob = 0.55
    concatenated_sequence = protein1 + connector + protein2
    for _ in range(iterations):
        inputs = tokenizer(concatenated_sequence, return_tensors="pt", padding=True, truncation=True, max_length=1024)
        inputs = {k: v.to(device) for k, v in inputs.items()}  # Send inputs to the device
        # Keep the original token IDs as labels before masking
        labels = inputs["input_ids"].clone()
        mask_indices = torch.rand(labels.shape, device=device) < mask_prob
        # Never mask the special tokens at the start and end of the sequence
        mask_indices[:, 0] = False
        mask_indices[:, -1] = False
        # Locate the connector 'G's; the +1 accounts for the leading <cls> token
        connector_length = len(tokenizer.encode(connector, add_special_tokens=False))
        start_connector = 1 + len(tokenizer.encode(protein1, add_special_tokens=False))
        end_connector = start_connector + connector_length
        # Avoid masking the connector 'G's
        mask_indices[0, start_connector:end_connector] = False
        # Apply the mask to the input IDs and score only the masked positions
        inputs["input_ids"][mask_indices] = tokenizer.mask_token_id
        labels[~mask_indices] = -100
        with torch.no_grad():
            outputs = model(**inputs, labels=labels)
        total_loss += outputs.loss.item()
    return total_loss / iterations
# Compute all average losses to determine the maximum threshold for the slider
all_losses = []
for i, protein1 in enumerate(all_proteins):
    for j, protein2 in enumerate(all_proteins[i+1:], start=i+1):
        avg_loss = compute_average_mlm_loss(protein1, protein2)
        all_losses.append(avg_loss)
# Set the maximum threshold to the maximum loss computed
max_threshold = max(all_losses)
print(f"Maximum loss (maximum threshold for slider): {max_threshold}")
def plot_graph(threshold):
    G = nx.Graph()
    # Add all protein nodes to the graph
    for i, protein in enumerate(all_proteins):
        G.add_node(f"protein {i+1}")
    # Loop through all pairs of proteins and look up their average MLM loss
    loss_idx = 0  # Index to keep track of the position in the all_losses list
    for i, protein1 in enumerate(all_proteins):
        for j, protein2 in enumerate(all_proteins[i+1:], start=i+1):
            avg_loss = all_losses[loss_idx]
            loss_idx += 1
            # Add an edge if the loss is below the threshold
            if avg_loss < threshold:
                G.add_edge(f"protein {i+1}", f"protein {j+1}", weight=round(avg_loss, 3))
    # 3D Network Plot
    # Adjust the k parameter to bring nodes closer. This might require some experimentation to find the right value.
    k_value = 2  # Lower value will bring nodes closer together
    pos = nx.spring_layout(G, dim=3, seed=42, k=k_value)
    edge_x = []
    edge_y = []
    edge_z = []
    for edge in G.edges():
        x0, y0, z0 = pos[edge[0]]
        x1, y1, z1 = pos[edge[1]]
        edge_x.extend([x0, x1, None])
        edge_y.extend([y0, y1, None])
        edge_z.extend([z0, z1, None])
    edge_trace = go.Scatter3d(x=edge_x, y=edge_y, z=edge_z, mode='lines', line=dict(width=0.5, color='grey'))
    node_x = []
    node_y = []
    node_z = []
    node_text = []
    for node in G.nodes():
        x, y, z = pos[node]
        node_x.append(x)
        node_y.append(y)
        node_z.append(z)
        node_text.append(node)
    node_trace = go.Scatter3d(x=node_x, y=node_y, z=node_z, mode='markers', marker=dict(size=5), hoverinfo='text', hovertext=node_text)
    layout = go.Layout(title='Protein Interaction Graph', title_x=0.5, scene=dict(xaxis=dict(showbackground=False), yaxis=dict(showbackground=False), zaxis=dict(showbackground=False)))
    fig = go.Figure(data=[edge_trace, node_trace], layout=layout)
    fig.show()
# Create an interactive slider for the threshold value with a default of 8.25 (capped at the maximum loss)
interact(plot_graph, threshold=widgets.FloatSlider(min=0.0, max=max_threshold, step=0.05, value=min(8.25, max_threshold)))
Output: the graph can be explored interactively by changing the threshold.
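The thresholding step on its own, without any plotting library, might look like the sketch below. The function names are hypothetical and the loss values are invented for illustration; they are not results from the model.

```python
def edges_below_threshold(pair_losses, threshold):
    """Keep the pairs whose loss is below the threshold, sorted by loss."""
    edges = [(a, b, loss) for (a, b), loss in pair_losses.items() if loss < threshold]
    return sorted(edges, key=lambda edge: edge[2])

def adjacency(edges):
    """Turn an edge list into an undirected adjacency mapping."""
    neighbours = {}
    for a, b, _ in edges:
        neighbours.setdefault(a, set()).add(b)
        neighbours.setdefault(b, set()).add(a)
    return neighbours

# Made-up loss values, for illustration only.
pair_losses = {
    ("protein 1", "protein 2"): 2.1,
    ("protein 1", "protein 3"): 2.9,
    ("protein 2", "protein 3"): 2.4,
}
toy_edges = edges_below_threshold(pair_losses, threshold=2.5)
print(toy_edges)
print(adjacency(toy_edges))
```

Lowering the threshold removes edges one by one, which is what the slider does visually: the strictest thresholds leave only the lowest-loss candidate pairs connected.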

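3. Wrap-up
Hopefully this gives you some new ideas, so that your PPI research has more directions and methods to draw on. Thanks for reading. Your likes and shares keep me updating!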