首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >自建神经网络

自建神经网络
EN

Stack Overflow用户
提问于 2022-08-17 17:41:50
回答 1查看 64关注 0票数 0

我一直试图建立一个简单的神经网络自己(3层)来预测MNIST数据集。我在网上引用了一些代码,写了一些我自己的代码,代码运行时没有任何错误,但是学习过程有问题。受过训练的网络总是给我错误的预测,不管我输入什么,一两个类总是有很高的概率。我试着找出问题,但几天内没有取得任何进展。有人能告诉我我哪里做错了吗?

代码语言:javascript
复制
import numpy as np
from PIL import Image
import os
np.set_printoptions(formatter={'float_kind':'{:f}'.format})
def init_setup():
    #three layers perception
    w1=np.random.randn(10,784)-0.8
    b1=np.random.rand(10,1)-0.8
    #second layer
    w2=np.random.randn(10,10)-0.8
    b2=np.random.randn(10,1)-0.8
    #third layer
    w3=np.random.randn(10,10)-0.8
    b3=np.random.randn(10,1)-0.8
    return w1,b1,w2,b2,w3,b3
def activate(A):
    # use ReLU function as the activation function
    Z=np.maximum(0,A)
    return Z
def softmax(Z):
    return np.exp(Z)/np.sum(np.exp(Z))

def forward_propagation(A,w1,b1,w2,b2,w3,b3):
    # input A :(784,1)-> A1: (10,1) ->A2: (10,1) -> prob: (10,1)
    z1=w1@A+b1
    A1=activate(z1)
    z2=w2@A1+b2
    A2=activate(z2)
    z3=w3@A2+b3
    prob=softmax(z3)

    return z1,A1,z2,A2,z3,prob
def one_hot(Y:np.ndarray)->np.ndarray:

    one_hot=np.zeros((10, 1)).astype(int)
    
    one_hot[Y]=1
    return one_hot

def back_propagation(A,z1,A1:np.ndarray,z2,A2:np.ndarray,z3,prob,w1,w2:np.ndarray,w3,Y:np.ndarray,lr:float):

    m=1/Y.size

    dz3=prob-Y 

    dw3=m*dz3@A2.T

    db3= dz3
    dz2=ReLU_deriv(z2)*w3.T@dz3
    dw2 =  dz2@A1.T
    db2 =  dz2
    dz1=ReLU_deriv(z1)*w2.T@dz2
    dw1 = dz1@A.T
    db1 =  dz1
    return db1,dw1,dw2,db2,dw3,db3
def ReLU_deriv(Z):
    Z[Z>0]=1
    Z[Z<=0]=0
    return Z 
def step(lr,w1,b1,w2,b2,w3,b3,dw1,db1,dw2,db2,dw3,db3):
    w1 = w1 - lr * dw1

    b1 = b1 - lr * db1    
    w2 = w2 - lr * dw2  
    b2 = b2 - lr * db2
    w3 = w3 - lr * dw3 
    b3 = b3 - lr * db3       
    return w1,b1,w2,b2,w3,b3

把功能放在一起

代码语言:javascript
复制
def learn():
    lr=0.5
    dir=r'C:\Users\Desktop\MNIST - JPG - training\{}'
    w1,b1,w2,b2,w3,b3=init_setup()
    for e in range(10):
        if e%3 == 0:
            lr=lr/10
        for num in range(10):
            Y=one_hot(num)
            # print(Y)
            path=dir.format(str(num))
            for i in os.listdir(path):
                img=Image.open(path+'\\'+i)
                A=np.asarray(img)
                A=A.reshape(-1,1) 
                z1,A1,z2,A2,z3,prob=forward_propagation(A,w1,b1,w2,b2,w3,b3)
                # print('loss='+str(np.sum(np.abs(Y-prob))))
                db1,dw1,dw2,db2,dw3,db3=back_propagation(A,z1,A1,z2,A2,z3,prob,w1,w2,w3,Y,lr)
                w1,b1,w2,b2,w3,b3=step(lr,w1,b1,w2,b2,w3,b3,dw1,db1,dw2,db2,dw3,db3)
    return  w1,b1,w2,b2,w3,b3
optimize_params=learn()
w1,b1,w2,b2,w3,b3=optimize_params
img=Image.open(r'C:\Users\Desktop\MNIST - JPG - training\2\5.jpg')
A=np.asarray(img)
A=A.reshape(-1,1)
z1,A1,z2,A2,z3,prob=forward_propagation(A,w1,b1,w2,b2,w3,b3)
print(prob)
print(np.argmax(prob))

在运行了学习功能之后,网络给了我这样的东西

代码语言:javascript
复制
>>>[[0.040939]
    [0.048695]
    [0.048555]
    [0.054962]
    [0.060614]
    [0.066957]
    [0.086470]
    [0.117370]
    [0.163163]
    [0.312274]]
>>>9

结果显然是错误的,真正的标签应该是2,但正如我们在prob上看到的那样,2类的值极低,所以我相信在学习过程中一定有问题。但是我一点也不知道,有人能给我一些提示吗?

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2022-08-17 21:53:26

当前代码仅在标签0和1上进行训练。

代码语言:javascript
复制
for num in range(2):

因此,您的模型没有办法“知道”任何其他标签。

现在,您的模型是以一种非常有序的方式进行培训的,因此,您的模型偏向于上的类。因为这是它在训练中看到的最后一次。您应该在每个时代洗牌您的培训数据,而不是按网络类提供数据。

票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/73392810

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档