首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >Tensorflow sigmoid回归保持线性

Tensorflow sigmoid回归保持线性
EN

Stack Overflow用户
提问于 2021-09-25 15:48:11
回答 1查看 57关注 0票数 0

我正在尝试使用张量流使一个简单的神经网络拟合一个简单的函数,我知道我使用的结构和参数在MatLab中实现了这一点,但我需要将其移植到另一种语言(目前是Python语言,但后来是c++)。正因为如此,我正在努力寻找一个好的神经网络库,我认为应该是TensorFlow,但事实证明它非常挑剔。以下是代码的重要部分

代码语言:javascript
复制
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers
import matplotlib.pyplot as plt

f1 = lambda x: ((x < .5) * np.power(x, 2) + (x > .5) * x) * 2 -1
x = np.linspace(-1, 1, 180).reshape(-1,1)
y = f1(x).reshape(-1, 1)
model = keras.models.Sequential()
model.add(layers.Dense(100, activation=tf.keras.activations.sigmoid, input_shape=(1,)))
model.add(layers.Dense(1, activation=tf.keras.activations.linear))
model.compile(loss=tf.keras.losses.mean_squared_error, optimizer=tf.keras.optimizers.SGD(0.001), metrics=[tf.keras.losses.mean_squared_error])
model.fit(x, y, batch_size=1, epochs=3)
xtest = np.linspace(-1, 1, 100).reshape(-1, 1)
ytest = model.predict(xtest)

plt.scatter(x, y)
plt.plot(xtest, ytest)
plt.show()

当绘制预期的图和预测值时,它会产生This Plot,其中点是期望函数,实线是预测值,我不确定我做错了什么。

网络必须是由100个sigmoid激活的神经元组成的层,然后是线性输出层,即使我改变了时代的数量和批量大小,网络仍然线性训练。任何帮助都将不胜感激

EN

回答 1

Stack Overflow用户

发布于 2021-09-25 16:15:04

你可以尝试改变激活和增加训练时间,我将优化器改为Adam,将训练周期增加到30,得到了这个结果。

代码:

代码语言:javascript
复制
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers
import matplotlib.pyplot as plt

f1 = lambda x: ((x < .5) * np.power(x, 2) + (x > .5) * x) * 2 -1
x = np.linspace(-1, 1, 180).reshape(-1,1)
y = f1(x).reshape(-1, 1)
model = keras.models.Sequential()
model.add(layers.Dense(100, activation=tf.keras.activations.sigmoid, input_shape=(1,)))
model.add(layers.Dense(1, activation=tf.keras.activations.linear))
model.compile(loss=tf.keras.losses.mean_squared_error, 
              optimizer=tf.keras.optimizers.Adam(0.01), 
              metrics=[tf.keras.losses.mean_squared_error])
model.fit(x, y, batch_size=1, epochs=30)
xtest = np.linspace(-1, 1, 100).reshape(-1, 1)
ytest = model.predict(xtest)

plt.scatter(x, y)
plt.plot(xtest, ytest)
plt.show()
票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/69327735

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档