首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >使用Nolearn训练神经网络时引发的KeyError

使用Nolearn训练神经网络时引发的KeyError
EN

Stack Overflow用户
提问于 2015-12-02 05:59:57
回答 1查看 331关注 0票数 0

我正在尝试在虹膜数据集上训练神经网络。我找到了一个使用nolearn的神经网络教程,讲师使用的是mnist数据集。我试图“模仿”相同的算法,但出现了一个错误。代码如下:

代码语言:javascript
复制
 # Sklearn libraries
from sklearn.preprocessing import LabelEncoder

# Lasagne
import lasagne
from lasagne import layers
from lasagne.updates import nesterov_momentum
from nolearn.lasagne import NeuralNet

# Pandas and Numpy
import pandas as pd
import numpy as np


def load_data():
    labelenc = LabelEncoder()

# Loading of dataset
iris = pd.read_csv('/home/gunslinger/Desktop/IrisDataset.csv', header=None)
iris.iloc[:, 4] = labelenc.fit_transform(iris.iloc[:, 4])

iris = iris.iloc[np.random.permutation(np.arange(len(iris)))]

# Initialization
X = iris.iloc[:, :4]
y = iris.iloc[:, 4]
X = X.astype(np.float32)
y = y.astype(np.int32)

X_train = X[:100]
X_valid = X[100:125]
X_test  = X[125:150]

y_train = y[:100]
y_valid = y[100:125]
y_test  = y[125:150]

return dict(
    X_train=X_train,
    y_train=y_train,
    X_valid=X_valid,
    y_valid=y_valid,
    X_test=X_test,
    y_test=y_test,
)


def nn_func(data):
net1 = NeuralNet(
    layers=[('input', layers.InputLayer),
            ('hidden', layers.DenseLayer),
            ('output', layers.DenseLayer)
            ],
    # Layer parameters:
    input_shape=(None, 4),
    hidden_num_units=5,
    output_nonlinearity=lasagne.nonlinearities.softmax,
    output_num_units=3,

    # Optimization method:
    update=nesterov_momentum,
    update_learning_rate=0.01,
    update_momentum=0.9,

    max_epochs=10,
    verbose=1,
)

net1.fit(data['X_train'], data['y_train'])


def main():
data = load_data()
print("Got %i testing datasets." % len(data['X_train']))
nn_func(data)

if __name__ == '__main__':
main()

当我运行代码时得到的错误是:http://pastebin.com/9eccuzEQ

有一个与此非常相似的问题。然而,为他解决了问题的东西,对我来说却不是。

EN

回答 1

Stack Overflow用户

发布于 2015-12-04 05:18:27

我自己解决了这个问题。数据类型必须是numpy数组,而不是pandas数据帧。

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/34031008

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档