首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >将BERT模型作为泡菜文件保存在磁盘上

将BERT模型作为泡菜文件保存在磁盘上
EN

Stack Overflow用户
提问于 2020-01-23 15:21:14
回答 1查看 1.7K关注 0票数 1

我已经设法让伯特模型工作在约翰斯诺实验室-火花-nlp库。我能够将“经过训练的模型”保存在磁盘上,如下所示。

拟合模型

代码语言:javascript
复制
df_bert_trained = bert_pipeline.fit(textRDD)

df_bert=df_bert_trained.transform(textRDD)

保存模型

代码语言:javascript
复制
df_bert_trained.write().overwrite().save("/home/XX/XX/trained_model")

然而,

首先,根据https://nlp.johnsnowlabs.com/docs/en/concepts这里的文档,可以将模型加载为

代码语言:javascript
复制
EmbeddingsHelper.load(path, spark, format, reference, dims, caseSensitive) 

但在我看来,变量“引用”在这一点上所代表的是什么还不清楚。

第二,是否有人设法将BERT嵌入保存为python中的泡菜文件?

EN

回答 1

Stack Overflow用户

发布于 2020-02-14 11:06:22

在Spark,伯特作为一个预先训练的模型。这意味着它已经是一个模型,经过了培训,适合等,并保存在正确的格式。

这是说,没有理由再适合或保存它。但是,一旦您将DataFrame转换为一个新的DataFrame,并为每个令牌提供了BERT嵌入,您就可以保存它的结果。

示例:

使用Spark包在火花壳中启动星火会话

代码语言:javascript
复制
spark-shell --packages JohnSnowLabs:spark-nlp:2.4.0
代码语言:javascript
复制
import com.johnsnowlabs.nlp.annotators._
import com.johnsnowlabs.nlp.base._

val documentAssembler = new DocumentAssembler()
      .setInputCol("text")
      .setOutputCol("document")

    val sentence = new SentenceDetector()
      .setInputCols("document")
      .setOutputCol("sentence")

    val tokenizer = new Tokenizer()
      .setInputCols(Array("sentence"))
      .setOutputCol("token")

    // Download and load the pretrained BERT model
    val embeddings = BertEmbeddings.pretrained(name = "bert_base_cased", lang = "en")
      .setInputCols("sentence", "token")
      .setOutputCol("embeddings")
      .setCaseSensitive(true)
      .setPoolingLayer(0)

    val pipeline = new Pipeline()
      .setStages(Array(
        documentAssembler,
        sentence,
        tokenizer,
        embeddings
      ))

// Test and transform

   val testData = Seq(
      "I like pancakes in the summer. I hate ice cream in winter.",
      "If I had asked people what they wanted, they would have said faster horses"
    ).toDF("text")

    val predictionDF = pipeline.fit(testData).transform(testData)

predictionDF是一个DataFrame,它包含BERT嵌入数据集中的每个令牌。BertEmbeddings的预训练模型来自TF集线器,这意味着它们与谷歌发布的训练前的权重完全相同。所有5种型号都有:

  • bert_base_cased (en)
  • bert_base_uncased (en)
  • bert_large_cased (en)
  • bert_large_uncased (en)
  • bert_multi_cased (xx)

如果你有任何问题或问题,请告诉我,我会更新我的答案。

参考资料

票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/59881819

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档