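I am trying to get word embeddings for clinical data using microsoft/pubmedbert. I have 3.6 million rows of text. Converting 10k rows of text to vectors takes about 30 minutes, so the full 3.6 million rows would take about 180 hours (roughly 8 days).
Is there any way to speed this up?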
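The estimate above can be checked with a few lines of arithmetic. This is only a sketch that uses the figures stated in the question (3.6 million rows, 10k rows per 30 minutes); the variable names are made up for illustration.

```python
# Figures taken from the question; the names are illustrative only.
total_rows = 3_600_000
rows_per_batch = 10_000
minutes_per_batch = 30

batches = total_rows / rows_per_batch            # 360 batches of 10k rows
total_hours = batches * minutes_per_batch / 60   # minutes converted to hours
total_days = total_hours / 24

print(f"{total_hours:.0f} hours, about {total_days:.1f} days")
# prints: 180 hours, about 7.5 days
```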
My code:
import re

import numpy as np
import pandas as pd
from transformers import AutoTokenizer
from transformers import pipeline

model_name = "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext"
tokenizer = AutoTokenizer.from_pretrained(model_name)
classifier = pipeline('feature-extraction', model=model_name, tokenizer=tokenizer)

def lambda_func(row):
    tokens = tokenizer(row['notetext'])
    # Notes over the 512-token model limit are cut to their first 512 word-boundary pieces
    if len(tokens['input_ids']) > 512:
        tokens = re.split(r'\b', row['notetext'])
        tokens = [t for t in tokens if len(t) > 0]
        row['notetext'] = ''.join(tokens[:512])
    # Keep only the vector of the first token ([CLS]) as the embedding of the note
    row['vectors'] = classifier(row['notetext'])[0][0]
    return row

def process(progress_notes):
    # Runs the model once per row, which is the slow part
    progress_notes = progress_notes.apply(lambda_func, axis=1)
    return progress_notes

progress_notes = process(progress_notes)
vectors_breadth = 768
vectors_length = len(progress_notes)
vectors_2d = np.reshape(progress_notes['vectors'].to_list(), (vectors_length, vectors_breadth))
vectors_df = pd.DataFrame(vectors_2d)

My progress_notes data looks like this:
progress_notes = pd.DataFrame({
    'id': [1, 2, 3],
    'progressnotetype': ['Nursing Note', 'Nursing Note', 'Administration Note'],
    'notetext': [
        'Patient\'s skin is grossly intact with exception of skin tear to r inner elbow and r lateral lower leg',
        'Patient with history of Afib with RVR. Patient is incontinent of bowel and bladder.',
        'Give 2 tablet by mouth every 4 hours as needed for Mild to moderate Pain Not to exceed 3 grams in 24 hours',
    ],
})

Note: 1) I am running the code on an AWS EC2 r5.8xlarge instance (32 CPUs). I tried using multiprocessing, but the code deadlocks because all of my CPU cores are already taken up.
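One possible explanation for the multiprocessing problem, not confirmed anywhere in the question, is thread oversubscription: PyTorch on CPU starts roughly one compute thread per core by default, so several worker processes each try to use all 32 cores. The sketch below caps the threads per worker with torch.set_num_threads. The helper names (threads_per_worker, init_worker, make_pool) are hypothetical, and it assumes PyTorch is the backend behind the pipeline.

```python
import multiprocessing as mp
from concurrent.futures import ProcessPoolExecutor


def threads_per_worker(total_cores: int, n_workers: int) -> int:
    """Split the available cores evenly, giving each worker at least one thread."""
    return max(1, total_cores // n_workers)


def init_worker(n_threads: int) -> None:
    """Runs once in every worker process before it receives any work."""
    import torch  # imported here so each worker configures its own copy

    torch.set_num_threads(n_threads)


def make_pool(n_workers: int, total_cores: int) -> ProcessPoolExecutor:
    """Build a process pool whose workers together use about total_cores threads."""
    return ProcessPoolExecutor(
        max_workers=n_workers,
        mp_context=mp.get_context("spawn"),  # start fresh processes instead of forking
        initializer=init_worker,
        initargs=(threads_per_worker(total_cores, n_workers),),
    )


if __name__ == "__main__":
    # Assumption: 8 workers on the 32-core instance mentioned in the note.
    print(threads_per_worker(32, 8))
```

Each worker would still have to load its own copy of the model, so memory use grows with the number of workers. Whether this resolves the deadlock described in the note would need to be tested on the actual instance.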
Posted on 2022-12-02 22:02:10
Why not just do the following? apply is not the fastest approach.
import pandas as pd
from sentence_transformers import SentenceTransformer

sbert_model = SentenceTransformer('microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext')
# encode() takes the whole collection of texts at once and batches it internally;
# the Series is converted to a plain list, which is the input type encode() documents
document_embeddings = sbert_model.encode(pd.Series(['hello', 'cell type', 'protein']).tolist())
document_embeddings

You will get output like the following:

array([[ 0.06255245,  0.14945783, -0.06224129, ..., -0.11892398,
        -0.0507343 ,  0.0153866 ],
       [-0.17571464,  0.03554079, -0.04899959, ..., -0.24369009,
        -0.00672011,  0.04914075],
       [-0.22093703, -0.03271236, -0.08943298, ..., -0.21335356,
         0.11418738, -0.09207606]], dtype=float32)

https://stackoverflow.com/questions/65494850
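For 3.6 million rows it may help to feed the texts to the encoder in large chunks, so partial results can be saved along the way instead of holding everything until the end. The following is a minimal sketch of that idea, not part of the original answer. The helpers iter_chunks and encode_in_chunks are hypothetical, and fake_encode is a stand-in so the example runs without downloading a model.

```python
from typing import Callable, Iterator, List, Sequence, Tuple


def iter_chunks(texts: Sequence[str], chunk_size: int) -> Iterator[Tuple[int, Sequence[str]]]:
    """Yield (start_index, chunk) pairs that cover texts in order."""
    for start in range(0, len(texts), chunk_size):
        yield start, texts[start:start + chunk_size]


def encode_in_chunks(
    texts: Sequence[str],
    encode_fn: Callable[[List[str]], object],
    chunk_size: int = 10_000,
) -> Iterator[Tuple[int, object]]:
    """Encode texts chunk by chunk so each result can be stored as it arrives."""
    for start, chunk in iter_chunks(texts, chunk_size):
        yield start, encode_fn(list(chunk))


def fake_encode(batch: List[str]) -> List[List[float]]:
    """Stand-in encoder (assumption): returns one 1-dimensional vector per text."""
    return [[float(len(text))] for text in batch]


notes = ["skin intact", "history of Afib", "give 2 tablets", "as needed", "mild pain"]
results = list(encode_in_chunks(notes, fake_encode, chunk_size=2))
print([start for start, _ in results])
# prints: [0, 2, 4]
```

With the real model, encode_fn could be something like lambda batch: sbert_model.encode(batch, batch_size=64, show_progress_bar=False), and texts could be progress_notes['notetext'].tolist(); the batch size of 64 is an arbitrary starting point to tune. Note also that SentenceTransformer, when given a plain BERT checkpoint such as this one, applies mean pooling over tokens by default as far as I know, so its vectors will differ from the first-token ([CLS]) vectors taken in the question's code.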