How do I change:

BERT_MODEL = "https://tfhub.dev/google/bert_multi_cased_L-12_H-768_A-12/1"

def create_tokenizer_from_hub_module():
    """Get the vocab file and casing info from the Hub module."""
    with tf.Graph().as_default():
        bert_module = hub.Module(BERT_MODEL)
        tokenization_info = bert_module(signature="tokenization_info", as_dict=True)
        with tf.Session() as sess:
            vocab_file, do_lower_case = sess.run([tokenization_info["vocab_file"],
                                                  tokenization_info["do_lower_case"]])
    return bert.tokenization.FullTokenizer(
        vocab_file=vocab_file, do_lower_case=do_lower_case)

tokenizer = create_tokenizer_from_hub_module()

so that I can load a local BERT model without the hub.Module() call, since it does not work with a local path?
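I downloaded a different TF1 pretrained model from another website, unzipped it, and stored it in /test/module/.

If I change the line above to BERT_MODEL = "/test/module", how do I need to change the rest? I now get a string error because tokenization_info = bert_module(signature="tokenization_info", as_dict=True) does not work.

I am new to TF. Note that I need to use TF1, not TF2.

Note: with the suggestion below, I get: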
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-11-a98e44536f87> in <module>()
      9     return vocab_file, do_lower_case
     10
---> 11 print(get_bert_tokenizer_info("/tmp/local_copy"))
     12 # Will print: (b'/tmp/local_copy/assets/vocab.txt', False)

4 frames
/usr/local/lib/python3.6/dist-packages/tensorflow_hub/registry.py in __call__(self, *args, **kwargs)
     43     raise RuntimeError(
     44         "Missing implementation that supports: %s(*%r, **%r)" % (
---> 45             self._name, args, kwargs))
     46
     47

RuntimeError: Missing implementation that supports: loader(*('/tmp/local_copy',), **{})

Posted on 2020-03-23 18:44:45
hub.Module works with local, uncompressed paths, so you can change BERT_MODEL to a local path and reuse the same code.

Example:

Create a local copy of the module:
mkdir /tmp/local_copy
wget "https://tfhub.dev/google/bert_multi_cased_L-12_H-768_A-12/1?tf-hub-format=compressed" -O "module.tar.gz"
tar -C /tmp/local_copy -xzvf module.tar.gz

Use the local copy of the module:
import tensorflow as tf
import tensorflow_hub as hub

def get_bert_tokenizer_info(bert_module):
    """Get the vocab file and casing info from the Hub module."""
    with tf.Graph().as_default():
        # bert_module arrives as a path (or URL) and is rebound to the loaded module.
        bert_module = hub.Module(bert_module)
        tokenization_info = bert_module(signature="tokenization_info", as_dict=True)
        with tf.Session() as sess:
            vocab_file, do_lower_case = sess.run([tokenization_info["vocab_file"],
                                                  tokenization_info["do_lower_case"]])
    return vocab_file, do_lower_case

print(get_bert_tokenizer_info("/tmp/local_copy"))
# Will print: (b'/tmp/local_copy/assets/vocab.txt', False)

https://stackoverflow.com/questions/60797100
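If the call fails with "RuntimeError: Missing implementation that supports: loader(...)", as in the note added to the question, one likely cause is that the directory passed to hub.Module is not laid out as a TF1 Hub module. That can happen if the download and extraction were never run in the same environment, if the archive unpacked into a nested folder, or if the download is a plain checkpoint rather than a Hub module. The sketch below is a diagnostic only, and it rests on an assumption: that a TF1-format Hub module directory holds a file named tfhub_module.pb at its top level. The helper name find_hub_module_dirs is made up for this illustration and is not part of tensorflow_hub.

```python
import os

# Assumption: a TF1-format Hub module keeps its definition in this file,
# directly inside the module directory.
MODULE_PROTO = "tfhub_module.pb"


def find_hub_module_dirs(root):
    """Return, sorted, every directory under `root` that directly contains MODULE_PROTO.

    Hypothetical helper for diagnosis only; it does not load anything.
    """
    matches = []
    for dirpath, _dirnames, filenames in os.walk(root):
        if MODULE_PROTO in filenames:
            matches.append(dirpath)
    return sorted(matches)


if __name__ == "__main__":
    # An empty list means no TF1 Hub module was found under this path.
    print(find_hub_module_dirs("/tmp/local_copy"))
```

If the list is empty, the path probably does not contain a Hub module at all. If it contains a nested directory rather than the root itself, passing that nested directory to hub.Module may be what is needed.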