首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >将sess.run转换为pytorch

将sess.run转换为pytorch
EN

Stack Overflow用户
提问于 2020-04-27 17:44:47
回答 1查看 616关注 0票数 0

我正在尝试将一个代码从tf转换为pytorch。代码中我被卡住的部分是这个sess.run。据我所知,pytorch不需要它,但我找不到复制它的方法。我给你附上代码。

TF:

代码语言:javascript
复制
ebnos_db = np.linspace(1,6, 6)
bers_no_training = np.zeros(shape=[ebnos_db.shape[0]])
for j in range(epochs):
    for i in range(ebnos_db.shape[0]):
        ebno_db = ebnos_db[i]
        bers_no_training[i] += sess.run(ber, feed_dict={
            batch_size: samples,
            noise_var: ebnodb2noisevar(ebno_db, coderate)
        })
bers_no_training /= epochs

samples是一个int32,而ebnodb2noisevar()返回一个float32。

TF中的BER计算如下:

代码语言:javascript
复制
ber = tf.reduce_mean(tf.cast(tf.not_equal(x, x_hat), dtype=tf.float32))

在PT中:

代码语言:javascript
复制
wrong_bits = ( torch.eq(x, x_hat).type(torch.float32) * -1 ) + 1
ber = torch.mean(wrong_bits)

我认为误码率计算得很好,但主要的问题是我不知道如何将sess.run转换为PyTorch,也不完全了解它的功能。

有人能帮我吗?

谢谢

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2020-04-27 18:11:58

在PyTorch中也可以做到这一点,但在ber中更容易做到

代码语言:javascript
复制
ber = torch.mean((x != x_hat).float())

就足够了。

是的,PyTorch不需要它,因为它是基于动态图形构造的(与Tensorflow的静态方法不同)。

tensorflow中,sess.run用于向创建的图提供值;在这里,名为batch_sizetf.Placeholder (图中的变量,表示用户可以“注入”其数据的节点)将被提供给samples,而noise_var将被提供给ebnodb2noisevar(ebno_db, coderate)

将其转换为PyTorch通常很简单,因为您不需要使用session使用任何类似于图形的方法。只需使用您的神经网络(或类似)和正确的输入(如samplesnoise_var),您就很好了。您必须检查您的图(因此ber是如何从batch_sizenoise_var构造的),并在PyTorch中重新实现它。

此外,在深入研究框架之前,请查看PyTorch introductory tutorials以了解该框架。

票数 2
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/61455950

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档