我正在尝试将一个代码从tf转换为pytorch。代码中我被卡住的部分是这个sess.run。据我所知,pytorch不需要它,但我找不到复制它的方法。我给你附上代码。
TF:
ebnos_db = np.linspace(1,6, 6)
bers_no_training = np.zeros(shape=[ebnos_db.shape[0]])
for j in range(epochs):
for i in range(ebnos_db.shape[0]):
ebno_db = ebnos_db[i]
bers_no_training[i] += sess.run(ber, feed_dict={
batch_size: samples,
noise_var: ebnodb2noisevar(ebno_db, coderate)
})
bers_no_training /= epochssamples是一个int32,而ebnodb2noisevar()返回一个float32。
TF中的BER计算如下:
ber = tf.reduce_mean(tf.cast(tf.not_equal(x, x_hat), dtype=tf.float32))在PT中:
wrong_bits = ( torch.eq(x, x_hat).type(torch.float32) * -1 ) + 1
ber = torch.mean(wrong_bits)我认为误码率计算得很好,但主要的问题是我不知道如何将sess.run转换为PyTorch,也不完全了解它的功能。
有人能帮我吗?
谢谢
发布于 2020-04-27 18:11:58
在PyTorch中也可以做到这一点,但在ber中更容易做到
ber = torch.mean((x != x_hat).float())就足够了。
是的,PyTorch不需要它,因为它是基于动态图形构造的(与Tensorflow的静态方法不同)。
在tensorflow中,sess.run用于向创建的图提供值;在这里,名为batch_size的tf.Placeholder (图中的变量,表示用户可以“注入”其数据的节点)将被提供给samples,而noise_var将被提供给ebnodb2noisevar(ebno_db, coderate)。
将其转换为PyTorch通常很简单,因为您不需要使用session使用任何类似于图形的方法。只需使用您的神经网络(或类似)和正确的输入(如samples和noise_var),您就很好了。您必须检查您的图(因此ber是如何从batch_size和noise_var构造的),并在PyTorch中重新实现它。
此外,在深入研究框架之前,请查看PyTorch introductory tutorials以了解该框架。
https://stackoverflow.com/questions/61455950
复制相似问题