首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >神经网络-试图预测5+5= 10

神经网络-试图预测5+5= 10
EN

Stack Overflow用户
提问于 2020-05-22 10:25:05
回答 2查看 88关注 0票数 2

我正在学习神经网络,最近我有了这样的想法:尝试给出函数$f(x) = 2x$的NN训练数据。问题是,神经网络能否准确地预测它必须将输入数加倍才能给出正确的输出?

这只是一次“心理锻炼”,让我更好地理解NNs的工作原理。

我的Python代码不起作用,下面是我尝试过的:

神经网络类:

代码语言:javascript
复制
import numpy as np

class NeuralNetwork:
    def __init__(self, inputnodes, hiddennodes, outputnodes, learningrate):
        self.inodes = inputnodes
        self.hnodes = hiddennodes
        self.onodes = outputnodes

        self.lr = learningrate

        self.wih = np.random.normal(0.0, pow(self.inodes, -0.5), (self.hnodes, self.inodes))
        self.who = np.random.normal(0.0, pow(self.hnodes, -0.5), (self.onodes, self.hnodes))

    def train(self, inputs_list, targets_list):
        inputs = np.array(inputs_list, ndmin=2).T
        targets = np.array(targets_list, ndmin=2).T

        hidden_outputs = np.dot(self.wih, inputs)
        final_outputs = np.dot(self.who, hidden_outputs)

        output_errors = targets - final_outputs
        hidden_errors = np.dot(self.who.T, output_errors)

        self.who += self.lr * np.dot(
            (output_errors * final_outputs * (1.0 - final_outputs)),
            np.transpose(hidden_outputs)
        )

        self.wih += self.lr * np.dot(
            (hidden_errors * hidden_outputs * (1.0 - hidden_outputs)),
            np.transpose(inputs)
        )

    def query(self, inputs_list):
        inputs = np.array(inputs_list, ndmin=2).T
        hidden_outputs = np.dot(self.wih, inputs)
        final_outputs = np.dot(self.who, hidden_outputs)

        return final_outputs

培训网络并预测价值:

代码语言:javascript
复制
input_nodes = 1
hidden_nodes = 20
output_nodes = 1

learning_rate = 0.3

nn = NeuralNetwork(input_nodes, hidden_nodes, output_nodes, learning_rate)

for i in range(10):
    i += 1
    inputs = np.log(i)
    targets = np.log(2*i)
    nn.train(inputs, targets)

print(nn.query(np.asfarray([4])))

下面是我试图运行这段代码的输出:

代码语言:javascript
复制
x.py:26: RuntimeWarning: overflow encountered in multiply
  (output_errors * final_outputs * (1.0 - final_outputs)),  
x.py:31: RuntimeWarning: overflow encountered in multiply
  (hidden_errors * hidden_outputs * (1.0 - hidden_outputs)),  
[[nan]]

我真的不知道如何解释这个问题,也不知道我的设计对这个应用程序是否正确。任何帮助都将不胜感激。

谢谢。

EN

回答 2

Stack Overflow用户

回答已采纳

发布于 2020-05-22 14:31:13

一些建议:

由于兴趣函数(

票数 1
EN

Stack Overflow用户

发布于 2020-05-22 12:53:37

我认为你在人工神经网络体系结构中缺少了一个非常重要的部分/构建块,这个块被称为激活函数,它试图将输出在0,1或-1之间正常化,所以我认为在计算每个隐藏层输出后附加一个激活函数(这是非常重要的)可以解决这个问题,因为数据传播网络将保持规范化值,例如0,1之间,因此可能不会发生溢出。

备注

  1. 乙状结肠激活和tanh是最流行和适合您的问题
  2. 您的学习率可能稍大,尝试使用0.01

票数 -1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/61955266

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档