首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >为什么我的线性回归不是那么简单?

为什么我的线性回归不是那么简单?
EN

Stack Overflow用户
提问于 2020-04-15 09:01:26
回答 1查看 69关注 0票数 0

我是tensorflow-2的新手,我开始了我的学习曲线,使用以下简单的线性回归模型:

代码语言:javascript
复制
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt


# Make data
num_samples, w, b = 20, 0.5, 2
xs = np.asarray(range(num_samples))
ys = np.asarray([x*w + b + np.random.normal() for x in range(num_samples)])
xts = tf.convert_to_tensor(xs, dtype=tf.float32)
yts = tf.convert_to_tensor(xs, dtype=tf.float32)
plt.plot(xs, ys, 'ro')

class Linear(tf.keras.Model):
    def __init__(self, name='linear', **kwargs):
        super().__init__(name='linear', **kwargs)
        self.w = tf.Variable(0, True, name="w", dtype=tf.float32)
        self.b = tf.Variable(1, True, name="b", dtype=tf.float32)   

    def call(self, inputs):
        return self.w*inputs + self.b

class Custom(tf.keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs=None):
        if epoch % 20 == 0:
            preds = self.model.predict(xts)
            plt.plot(xs, preds, label='{} {:7.2f}'.format(epoch, logs['loss']))
            print('The average loss for epoch {} is .'.format(epoch, logs['loss']))

x = tf.keras.Input(dtype=tf.float32, shape=[])
#model = tf.keras.Sequential([tf.keras.layers.Dense(units=1, input_shape=[1])])
model = Linear()
optimizer = tf.keras.optimizers.RMSprop(learning_rate=0.001)
model.compile(optimizer=optimizer, loss='MSE')
model.fit(x=xts, y=yts, verbose=1, batch_size=4, epochs=250, callbacks=[Custom()])

plt.legend()
plt.show()

由于我不明白的原因,我的模型似乎不能拟合曲线。我也尝试了keras.layers.Dense(1),我得到了同样的结果。而且,结果似乎并不对应于适当的损失函数,因为在时期120左右,模型的损失应该比250年的损失小。

你能帮我弄明白我做错了什么吗?非常感谢!

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2020-04-15 16:12:44

您的代码中有一个小错误,因为xtsyts彼此相同,即您编写了

代码语言:javascript
复制
xts = tf.convert_to_tensor(xs, dtype=tf.float32)
yts = tf.convert_to_tensor(xs, dtype=tf.float32)

而不是

代码语言:javascript
复制
xts = tf.convert_to_tensor(xs, dtype=tf.float32)
yts = tf.convert_to_tensor(ys, dtype=tf.float32)

这就是为什么损失是没有意义的。一旦修复了这个问题,结果就像预期的一样,请参见下面的图表。

票数 2
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/61219672

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档