首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >加速GPT2 --优化tf.sess.run()的推理时间

加速GPT2 --优化tf.sess.run()的推理时间
EN

Stack Overflow用户
提问于 2021-03-16 08:39:23
回答 1查看 90关注 0票数 0

我正在尝试优化GPT2上的推理时间。在Google Colab上,调用脚本后生成样本的当前时间是55秒。我添加了时间戳,试图找出瓶颈所在。代码如下:

代码语言:javascript
复制
 for _ in range(nsamples // batch_size):
            out = sess.run(output, feed_dict={
                context: [context_tokens for _ in range(batch_size)]
            })[:, len(context_tokens):]
            for i in range(batch_size):
                generated += 1
                text = enc.decode(out[i])
                print("=" * 40 + " SAMPLE " + str(generated) + " " + "=" * 40)
                print(text)
        print("=" * 80)

这条线

代码语言:javascript
复制
out = sess.run(output, feed_dict={
                context: [context_tokens for _ in range(batch_size)]
            })[:, len(context_tokens):] 

这就是复杂性所在。有谁能帮我改进这段代码吗?非常感谢!

EN

回答 1

Stack Overflow用户

发布于 2021-07-27 09:54:51

在GPT2中,batch_size被设置为1,没有办法在不使进程崩溃的情况下更改它。所以"context_tokens for _ in range(batch_size)“的意思是"context_tokens for _ in range(1)”,意思是"context_tokens“,它不会有多大的速度提升,但实现起来是安全的,并且使查看代码变得更加明智。真正复杂的是,您的ram中有一个6 in的bohemoth,您将在该会话中访问它。

作为一个实际问题,您发送的令牌越少,这些令牌占用的处理越少,这部分的执行速度就会越快。因为每个令牌都需要通过GPT2 AI发送。但结果是,响应将变得越不“智能”。

顺便说一句,//是一个整数除法运算,因此nsamples // batch_size = nsamples/1 = nsamples大小。从我看到的情况来看,当我在print( nsamples )中打印它的值时,nsamples是1。因此for循环是一个项目另一个循环,这意味着可以删除该循环。

GPT2只是tensorflow的一个实现。查找:如何在tensorflow中创建图形;如何调用该图形的会话;如何使saver保存该会话中的变量,以及如何使用saver恢复会话。您将了解检查点、元文件和其他使您的文件更有意义的实现。

tensorflow模块可以在Lib,site-packages,tensorflow_core中找到(至少在AI地下城2的Henk717分支中)。大部分处理都在子目录python/ops和framework中进行。如果您的代码破坏了tf所期望的钩子,您将看到这些弹出窗口。

如果这个问题与AI地下城中的实现有关,那么我所能实现的最好的方法就是对generator.generate的递归调用,该调用是通过一次尝试退出的,除了KeyboardInterrupt:在生成每个令牌时,对每个令牌进行打印( token,end = '',flush = True)。这样,你就可以在AI生成令牌时查看每个令牌,而不是等待55秒才能发出ping声音。

此外,Cuda警告需要单引号,而不是双引号,因此,import os os.environ‘’TF_CPP_MIN_LOG_LEVEL‘= '3’而不是"3“,这将在导入tensorflow时消除cuda警告。

接下来是在tensorflow 1.5以上版本中实现GPT2时弹出的折旧。

要关闭这些,只需使用tfv = tf.compat.v1 tfv.set_verbosity(tfv.logging.Error)即可。您不需要导入警告。

即使这样,在tf初始化、样本初始生成和将模块加载到ram中之间也有很长的加载时间。我在model.shape_list(X)中添加了以下代码行打印(“_”,end ='',flush=True),并且至少对于正在构建的用于将其本地化到机器的模块,您可以查看排序的“进度条”。

票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/66647600

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档