我正在尝试在GPU GeForce RTX2060 8 8GB内存上训练OpenNMT-tf变压器模型。您可以看到步骤Here。
我已经创建了Anaconda虚拟环境,并使用以下推荐安装了tensorflow-gpu。
conda install tensorflow-gpu==2.2.0在运行上面的命令之后,conda env将处理所有的事情,并在env中安装cuda 10.1和cudnn 7.6.5。然后我使用下面的命令安装了兼容TF2.2 gpu的openNMT-TF2.10。
~/anaconda3/envs/nmt/bin/pip install openNMT-tf==2.10以上命令将在conda环境中安装openNMT。
当我尝试运行OpenNMT-tf文档中“Quicstart”页面上的命令时,它在创建vocab时识别出了GPU。但当我开始训练transformer模型时,它给出了以下cudnn错误。
tensorflow.python.framework.errors_impl.InternalError: 2 root error(s) found.
(0) Internal: cuDNN launch failure : input shape ([1,504,512,1])
[[node transformer_base/self_attention_decoder/self_attention_decoder_layer/transformer_layer_wrapper_12/layer_norm_14/FusedBatchNormV3 (defined at /site-packages/opennmt/layers/common.py:128) ]]
[[Func/gradients/global_norm/write_summary/summary_cond/then/_302/input/_893/_52]]
(1) Internal: cuDNN launch failure : input shape ([1,504,512,1])
[[node transformer_base/self_attention_decoder/self_attention_decoder_layer/transformer_layer_wrapper_12/layer_norm_14/FusedBatchNormV3 (defined at /site-packages/opennmt/layers/common.py:128) ]]
0 successful operations.
0 derived errors ignored. [Op:__inference__accumulate_next_33440]
Function call stack:
_accumulate_next -> _accumulate_next
2021-03-01 13:01:01.138811: I tensorflow/stream_executor/stream.cc:1990] [stream=0x560490f17b10,impl=0x560490f172c0] did not wait for [stream=0x5604906de830,impl=0x560490f17250]
2021-03-01 13:01:01.138856: I tensorflow/stream_executor/stream.cc:4938] [stream=0x560490f17b10,impl=0x560490f172c0] did not memcpy host-to-device; source: 0x7ff4467f8780
2021-03-01 13:01:01.138957: F tensorflow/core/common_runtime/gpu/gpu_util.cc:340] CPU->GPU Memcpy failed
Aborted (core dumped)如果有人能在这里指导那就太好了。
Ps。我不认为这是一个版本问题,因为我验证了openNMT-tf 2.10需要tensorflow 2.2,并且在安装tensorflow-gpu 2.2时,anaconda本身安装了cuda 10.1和cudnn 7.6.5 (默认情况下处理GPU依赖)。
发布于 2021-03-03 15:07:35
这是一个记忆问题。有些人在StackOverflow上提出了一些关于cudnn问题的建议。在运行此命令之前,请将环境变量'TF_FORCE_GPU_ALLOW_GROWTH‘设置为true。
import os
os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = "true"
os.system('onmt-main --model_type Transformer --config data.yml train --with_eval')我终于开始使用上面的脚本进行训练,它解决了我的问题。
https://stackoverflow.com/questions/66418806
复制相似问题