首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >Tensorflow LSTM尽管网络大,样本量小,数据预处理,但仍未学习

Tensorflow LSTM尽管网络大,样本量小,数据预处理,但仍未学习
EN

Stack Overflow用户
提问于 2020-07-12 12:55:01
回答 1查看 237关注 0票数 1

我有以下神经网络:

代码语言:javascript
复制
model = Sequential()

model.add(LSTM(50, activation='relu', return_sequences=True, input_shape=X_train.shape[1:]))
model.add(Dropout(0.2))

model.add(LSTM(100, activation='relu', return_sequences=True))
model.add(Dropout(0.2))

model.add(LSTM(150, activation='relu'))
model.add(Dropout(0.3))

model.add(Dense(10))
model.add(Dropout(0.3))

model.add(Dense(2, activation='softmax'))  # Activation_layer

opt = Adam(lr=1e-3, decay=1e-6)

model.compile(optimizer=opt, loss='sparse_categorical_crossentropy', metrics=['accuracy'])

网络将按顺序提供数据,并试图将数据分类为1或0。

其中一个样本的示例:

X:

代码语言:javascript
复制
[[0.56450562 0.69825955 0.57768099 0.69077864]
 [0.58818427 0.70355375 0.61725885 0.30270281]
 [0.57407927 0.72501532 0.59603936 0.29196058]
 [0.56501804 0.69072662 0.59064673 0.66034622]
 [0.56552001 0.70354009 0.59136487 0.1586415 ]
 [0.56501496 0.68205159 0.57877241 0.62252169]
 [0.54535762 0.67067675 0.58414928 0.9077868 ]
 [0.56197241 0.71226839 0.5920788  0.1339519 ]
 [0.57308813 0.70469134 0.59749238 0.27085101]
 [0.56146488 0.69258436 0.58377929 0.7065891 ]
 [0.55943607 0.69106406 0.59569036 0.69378783]
 [0.5670203  0.68271571 0.58702014 0.70585781]
 [0.58320254 0.71228948 0.60867704 0.19280208]
 [0.56904526 0.71490986 0.59027546 0.35757948]
 [0.56398908 0.67858148 0.58197139 0.75064535]
 [0.57005691 0.7062191  0.60363236 0.38345417]
 [0.5705625  0.70394121 0.58630169 0.19171352]
 [0.56145905 0.69106039 0.58340288 0.76821359]
 [0.55183665 0.68991404 0.5935228  0.53419864]
 [0.56549613 0.68800419 0.58013082 0.74470123]
 [0.54926442 0.67315638 0.58336904 0.77819332]
 [0.56802882 0.71842805 0.60222782 0.12845991]
 [0.59591035 0.70927878 0.61161172 0.68023463]
 [0.56904526 0.713053   0.58773435 0.20017562]
 [0.58321778 0.69939555 0.61194041 0.47063807]
 [0.57814777 0.71113559 0.58991151 0.62149082]
 [0.56044844 0.69257776 0.58738045 0.39285414]
 [0.56853912 0.70091102 0.59713724 0.21938703]
 [0.56398364 0.69939514 0.59316136 0.43031303]
 [0.56701957 0.69901619 0.5935228  0.39333831]
 [0.56701916 0.68082684 0.58701647 0.84346823]
 [0.57765044 0.70812209 0.60147335 0.38961049]
 [0.58975543 0.71340576 0.6050683  0.61008348]
 [0.57207508 0.70280098 0.59821004 0.44573693]
 [0.56702537 0.71035313 0.59424384 0.30333905]
 [0.58417429 0.69901619 0.60288387 0.7210835 ]
 [0.56400225 0.70128289 0.59028243 0.42721302]
 [0.5725759  0.70241467 0.60000056 0.22784863]
 [0.57055816 0.69561772 0.59136355 0.66855609]
 [0.58766922 0.70995564 0.60538235 0.71163122]
 [0.57206444 0.69788453 0.59567842 0.707679  ]
 [0.5775922  0.70956495 0.60249313 0.32745877]
 [0.57407031 0.6997696  0.57952909 0.54327415]
 [0.55346759 0.69223554 0.58920848 0.27867972]
 [0.58612784 0.7031614  0.617901   0.76338596]
 [0.58659902 0.72005896 0.60604811 0.48696192]
 [0.57004823 0.70539865 0.59173347 0.47288217]
 [0.57405756 0.7023936  0.59030119 0.49981083]
 [0.55801818 0.68813345 0.58564415 0.38486918]
 [0.55900944 0.69300306 0.58527681 0.41875207]
 [0.56351994 0.68585174 0.58239563 0.70965566]
 [0.5509523  0.69524821 0.59280378 0.46280846]
 [0.56753474 0.69713124 0.59172507 0.29915786]
 [0.56753451 0.69939326 0.5978358  0.59996518]
 [0.56954889 0.69109776 0.57734904 0.27905973]
 [0.55595081 0.68429475 0.59424321 0.86881108]
 [0.57005376 0.71486763 0.60215717 0.20096972]
 [0.57509255 0.70467308 0.59028491 0.29196681]
 [0.5584625  0.68958804 0.59028342 0.24039387]
 [0.57005412 0.70203582 0.5964024  0.59344888]]

是:

代码语言:javascript
复制
1

我面临的问题是,损失开始于0.69左右,从未显著减少(略有波动),损失和验证损失都保持在0.5左右。

到目前为止我尝试过的是:

result.

  • Preprocessed

  • 检查了NaN的或值<0或>1的训练和验证数据,发现

  • 显着地减少了样本大小(降到50个样本),并且网络应该足够大到足以过载,但是遗憾的是,仍然是相同的

->以完全不同的方式使用sigmoid激活而不是softmax来对学习rate

  • Removing进行分类,这是alpha=0.05

F 218使用的第二个密集层

  • >F 218

尽管数据几乎是随机的,但是一个足够大的网络不应该很容易地适应50个或更少的样本吗?

EN

回答 1

Stack Overflow用户

发布于 2020-07-16 09:15:18

2项建议:

  • 看来您有二进制分类问题(0或1),也许您可以尝试二进制交叉熵损失?
  • 您是否使用to_categorical之类的方法来对标签进行一次热编码?

其他一些因素有时会极大地影响您没有提到的尝试/更改的准确性:

使用不同optimizers

  • Exploring不同架构的
  • :您是否考虑过一种CNN模型?或者您在不同的体系结构上测试过,有些比其他的更好吗?
票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/62861355

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档