首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >二层神经网络中的Python错误

二层神经网络中的Python错误
EN

Stack Overflow用户
提问于 2018-02-08 21:46:13
回答 1查看 98关注 0票数 1

我正在尝试从头开始实现一个2层神经网络。但是有些地方不对劲。经过一些迭代,我的损失变成了nan

代码语言:javascript
复制
'''
We are implementing a two layer neural network.
'''
import numpy as np

x,y = np.random.rand(64,1000),np.random.randn(64,10)
w1,w2 = np.random.rand(1000,100),np.random.rand(100,10)
learning_rate = 1e-4
x -= np.mean(x,axis=0) #Normalizing the Training Data Set

for t in range(2000):
  h = np.maximum(0,x.dot(w1))    # Applying Relu Non linearity
  ypred = h.dot(w2) #Output of Hidden layer

  loss = np.square(ypred - y).sum()
  print('Step',t,'\tLoss:- ',loss)

  #Gradient Descent

  grad_ypred = 2.0 * (ypred - y)
  gradw2 = (h.transpose()).dot(grad_ypred)
  grad_h = grad_ypred.dot(w2.transpose())
  gradw1 = (x.transpose()).dot(grad_h*h*(1-h))

  w1 -= learning_rate*gradw1
  w2 -= learning_rate*gradw2

我还使用Softmax分类器和多类SVM损失实现了线性回归。同样的问题也会发生。请告诉我如何解决这个问题。

输出:

代码语言:javascript
复制
D:\Study Material\Python 3 Tutorial\PythonScripts\Machine Learning>python TwoLayerNeuralNet.py
Step 0  Loss:-  19436393.79233052
Step 1  Loss:-  236820315509427.38
Step 2  Loss:-  1.3887002186558748e+47
Step 3  Loss:-  1.868219503527502e+189
Step 4  Loss:-  inf
TwoLayerNeuralNet.py:23: RuntimeWarning: invalid value encountered in multiply
  gradw1 = (x.transpose()).dot(grad_h*h*(1-h))
TwoLayerNeuralNet.py:12: RuntimeWarning: invalid value encountered in maximum
  h = np.maximum(0,x.dot(w1))    # Applying Relu Non linearity
Step 5  Loss:-  nan
Step 6  Loss:-  nan
Step 7  Loss:-  nan
Step 8  Loss:-  nan
Step 9  Loss:-  nan
Step 10         Loss:-  nan
Step 11         Loss:-  nan
Step 12         Loss:-  nan
Step 13         Loss:-  nan
Step 14         Loss:-  nan
Step 15         Loss:-  nan
Step 16         Loss:-  nan
Step 17         Loss:-  nan
Step 18         Loss:-  nan
Step 19         Loss:-  nan
Step 20         Loss:-  nan
EN

回答 1

Stack Overflow用户

发布于 2018-02-08 23:39:58

因为你的损失变得太高了

尝尝这个

代码语言:javascript
复制
loss = np.square(ypred - y).mean()

如果仍然不起作用,试着把学习率降低到像1e-8这样的值。

观察损失是上升还是下降,如果损失正在减少这是好的,如果损失在增加这是一个坏信号,你可能想要考虑使用更好的数据集并检查权重更新。

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/48687009

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档