首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >如何改进策略梯度的tensorflow 2.0代码?

如何改进策略梯度的tensorflow 2.0代码?
EN

Data Science用户
提问于 2019-12-23 21:45:41
回答 1查看 113关注 0票数 1

我在网上重新创建了一些代码,用于使用策略梯度解决土匪问题。示例是在tensorflow 1.0中,因此我使用急切的执行和梯度磁带使用tensorflow 2.0重新创建了它,但是,在训练模型时,我必须将权重张量转换为numpy数组,更新权重,然后从numpy数组重新分配tf.Variable。我觉得这不是表演,我可以找到更好的方法。完整代码在这里,https://github.com/entrpn/reinforcement-learning/blob/master/tf2_rl/bandits.py

我希望改进的主要代码如下:

代码语言:javascript
复制
def train(agent,action,reward, learning_rate=0.001):
    with tf.GradientTape() as t:
        current_loss = loss(agent(action),reward)
    dW = t.gradient(current_loss,[agent.weights])
    weights_as_np = agent.weights.numpy()
    responsible_weight = agent.weights[action]
    responsible_weight_dw = np.array(dW)[0][action]

    weights_as_np[action] = weights_as_np[action] - learning_rate*responsible_weight_dw

    agent.weights.assign(tf.Variable(weights_as_np))
EN

回答 1

Data Science用户

发布于 2020-05-25 20:07:22

如果可能的话,试着只使用tensorflow函数(例如,将.numpy()部件放出来),并在train()函数的顶部添加一个@tf.function装饰器。

@tf.function的作用是将整个函数转换为tensorflow op。整个函数的执行速度将比一般的Python函数快一个数量级。

票数 1
EN
页面原文内容由Data Science提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://datascience.stackexchange.com/questions/65354

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档