首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >Tensorflow如何将hub.Module()更改为本地文件夹

Tensorflow如何将hub.Module()更改为本地文件夹
EN

Stack Overflow用户
提问于 2020-03-22 16:02:07
回答 1查看 594关注 0票数 0

如何更改:

代码语言:javascript
复制
BERT_MODEL = "https://tfhub.dev/google/bert_multi_cased_L-12_H-768_A-12/1"

def create_tokenizer_from_hub_module():
  """Get the vocab file and casing info from the Hub module."""
  with tf.Graph().as_default():
    bert_module = hub.Module(BERT_MODEL)
    tokenization_info = bert_module(signature="tokenization_info", as_dict=True)
    with tf.Session() as sess:
      vocab_file, do_lower_case = sess.run([tokenization_info["vocab_file"],
                                            tokenization_info["do_lower_case"]])

  return bert.tokenization.FullTokenizer(
      vocab_file=vocab_file, do_lower_case=do_lower_case)

tokenizer = create_tokenizer_from_hub_module()

这样我就可以在没有hub.Module()调用的情况下加载本地BERT模型,因为它不能与本地路径一起工作。

我从另一个网站下载了一个不同的TF1预训练模型,将其解压缩并存储在/test/module/中。

如果我更改了上面的BERT_MODEL = "/test/module",我需要如何更改其余的内容?我现在收到字符串错误,因为tokenization_info = bert_module(signature="tokenization_info", as_dict=True)不能工作。

我是TF的新手,注意我需要使用TF1,而不是TF2。

注意:关于下面的建议,我得到了:

代码语言:javascript
复制
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-11-a98e44536f87> in <module>()
      9   return vocab_file, do_lower_case
     10 
---> 11 print(get_bert_tokenizer_info("/tmp/local_copy"))
     12 # Will print: (b'/tmp/local_copy/assets/vocab.txt', False)

4 frames
/usr/local/lib/python3.6/dist-packages/tensorflow_hub/registry.py in __call__(self, *args, **kwargs)
     43     raise RuntimeError(
     44         "Missing implementation that supports: %s(*%r, **%r)" % (
---> 45             self._name, args, kwargs))
     46 
     47 

RuntimeError: Missing implementation that supports: loader(*('/tmp/local_copy',), **{})
EN

回答 1

Stack Overflow用户

发布于 2020-03-23 18:44:45

hub.Module使用本地未压缩路径,因此您可以将BERT_MODEL更改为其他路径并重用相同的代码。

示例:

创建模块的本地副本:

代码语言:javascript
复制
mkdir /tmp/local_copy
wget "https://tfhub.dev/google/bert_multi_cased_L-12_H-768_A-12/1?tf-hub-format=compressed" -O "module.tar.gz"
tar -C /tmp/local_copy -xzvf module.tar.gz

使用模块的本地副本:

代码语言:javascript
复制
import tensorflow as tf
import tensorflow_hub as hub

def get_bert_tokenizer_info(bert_module):
  """Get the vocab file and casing info from the Hub module."""
  with tf.Graph().as_default():
    bert_module = hub.Module(bert_module)
    tokenization_info = bert_module(signature="tokenization_info", as_dict=True)
    with tf.Session() as sess:
      vocab_file, do_lower_case = sess.run([tokenization_info["vocab_file"],
                                            tokenization_info["do_lower_case"]])
  return vocab_file, do_lower_case

print(get_bert_tokenizer_info("/tmp/local_copy"))
# Will print: (b'/tmp/local_copy/assets/vocab.txt', False)
票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/60797100

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档