首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >linear流动线性回归误差放大

linear流动线性回归误差放大
EN

Stack Overflow用户
提问于 2017-05-28 04:15:14
回答 2查看 757关注 0票数 2

我试图用tensorflow拟合一个非常简单的线性回归模型。然而,损失(均方误差)爆炸,而不是减少到零。

首先,我生成数据:

代码语言:javascript
复制
x_data = np.random.uniform(high=10,low=0,size=100)
y_data = 3.5 * x_data -4 + np.random.normal(loc=0, scale=2,size=100)

然后,我定义了计算图:

代码语言:javascript
复制
X = tf.placeholder(dtype=tf.float32, shape=100)
Y = tf.placeholder(dtype=tf.float32, shape=100)
m = tf.Variable(1.0)
c = tf.Variable(1.0)
Ypred = m*X + c
loss = tf.reduce_mean(tf.square(Ypred - Y))
optimizer = tf.train.GradientDescentOptimizer(learning_rate=.1)
train = optimizer.minimize(loss)

最后,运行它100个历元:

代码语言:javascript
复制
steps = {}
steps['m'] = []
steps['c'] = []

losses=[]

for k in range(100):
    _m = session.run(m)
    _c = session.run(c)
    _l = session.run(loss, feed_dict={X: x_data, Y:y_data})
    session.run(train, feed_dict={X: x_data, Y:y_data})
    steps['m'].append(_m)
    steps['c'].append(_c)
    losses.append(_l)

然而,当我策划损失时,我得到:

完整的代码也可以找到这里

EN

回答 2

Stack Overflow用户

回答已采纳

发布于 2017-05-28 19:36:33

每当你看到你的成本随着时间的增加而单调增加,那就是一个确定的标志,你的学习率太高了。每次重复以你的学习率乘以1/10重新运行你的训练,直到成本函数随着时间的增加而明显减少。

票数 5
EN

Stack Overflow用户

发布于 2017-05-28 19:30:14

学习率太高;0.001起作用很好:

代码语言:javascript
复制
x_data = np.random.uniform(high=10,low=0,size=100)
y_data = 3.5 * x_data -4 + np.random.normal(loc=0, scale=2,size=100)
X = tf.placeholder(dtype=tf.float32, shape=100)
Y = tf.placeholder(dtype=tf.float32, shape=100)
m = tf.Variable(1.0)
c = tf.Variable(1.0)
Ypred = m*X + c
loss = tf.reduce_mean(tf.square(Ypred - Y))
optimizer = tf.train.GradientDescentOptimizer(learning_rate=.001)
train = optimizer.minimize(loss)
init = tf.global_variables_initializer()
with tf.Session() as session:
    session.run(init)
    steps = {}
    steps['m'] = []
    steps['c'] = []

    losses=[]

    for k in range(100):
    _m = session.run(m)
    _c = session.run(c)
    _l = session.run(loss, feed_dict={X: x_data, Y:y_data})
    session.run(train, feed_dict={X: x_data, Y:y_data})
    steps['m'].append(_m)
    steps['c'].append(_c)
    losses.append(_l)

plt.plot(losses)
plt.savefig('loss.png')

(可能有用的参考资料:https://gist.github.com/fuglede/ad04ce38e80887ddcbeb6b81e97bbfbc)

票数 2
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/44223756

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档