
How to get the maximum output for the user from a multi-class classification TensorFlow model

Stack Overflow user
Asked on 2022-02-08 06:14:45
2 answers · 130 views · 0 followers · 1 vote

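I am building a multi-class classification model that predicts a disease from 17 input symptoms. As output I receive an array of zeros and ones (because I one-hot encoded the labels to make the model work).

I tried to reverse the label encoding to get the predicted disease's label as the final output, but instead I received a series of labels that do not look right. I suspect I did something wrong in the one-hot encoding part. Please help.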

import pandas as pd
from keras.models import Sequential
from keras.layers import Dense
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from sklearn.preprocessing import OrdinalEncoder
from keras.utils import np_utils

dataset = pd.read_csv('dataset.csv') #source https://www.kaggle.com/itachi9604/disease-symptom-description-dataset

symptoms = dataset.drop('Disease',axis=1).fillna('Absent')
disease=dataset['Disease']

X = symptoms.values.astype(str)
y = disease.values.astype(str)

X_train, X_test, y_train, y_test = train_test_split(X,y,test_size=0.25,random_state=101)

#One-hot encoding based on https://machinelearningmastery.com/why-one-hot-encode-data-in-machine-learning/
ordinal_encoder = OrdinalEncoder()
ordinal_encoder.fit(X_train)

X_train = ordinal_encoder.transform(X_train)
X_test = ordinal_encoder.transform(X_test)

label_encoder = LabelEncoder()
label_encoder.fit(y_train)
y_train = label_encoder.transform(y_train)
y_test = label_encoder.transform(y_test)

# convert integers to dummy variables (i.e. one hot encoded)

y_train = np_utils.to_categorical(y_train)
y_test = np_utils.to_categorical(y_test)

model = Sequential()
model.add(Dense(8, input_dim=17, activation='relu'))
model.add(Dense(41, activation='softmax'))
# Compile model
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])

model.fit(x=X_train, 
          y=y_train, 
          epochs=200,
          validation_data=(X_test, y_test), verbose=1
          )

predictions = (model.predict(X_test[:1]) > 0.5).astype("int32")

predictions = label_encoder.inverse_transform(predictions.reshape(-1))

print(predictions)
### The output I receive:
['(vertigo) Paroymsal  Positional Vertigo'
 '(vertigo) Paroymsal  Positional Vertigo'
 '(vertigo) Paroymsal  Positional Vertigo'
 '(vertigo) Paroymsal  Positional Vertigo'
 '(vertigo) Paroymsal  Positional Vertigo'
 '(vertigo) Paroymsal  Positional Vertigo'
 '(vertigo) Paroymsal  Positional Vertigo'
 '(vertigo) Paroymsal  Positional Vertigo'
 '(vertigo) Paroymsal  Positional Vertigo'
 '(vertigo) Paroymsal  Positional Vertigo'
 '(vertigo) Paroymsal  Positional Vertigo'
 '(vertigo) Paroymsal  Positional Vertigo'
 '(vertigo) Paroymsal  Positional Vertigo'
 '(vertigo) Paroymsal  Positional Vertigo'
 '(vertigo) Paroymsal  Positional Vertigo'
 '(vertigo) Paroymsal  Positional Vertigo'
 '(vertigo) Paroymsal  Positional Vertigo'
 '(vertigo) Paroymsal  Positional Vertigo'
 '(vertigo) Paroymsal  Positional Vertigo'
 '(vertigo) Paroymsal  Positional Vertigo'
 '(vertigo) Paroymsal  Positional Vertigo'
 '(vertigo) Paroymsal  Positional Vertigo'
 '(vertigo) Paroymsal  Positional Vertigo'
 '(vertigo) Paroymsal  Positional Vertigo'
 '(vertigo) Paroymsal  Positional Vertigo'
 '(vertigo) Paroymsal  Positional Vertigo'
 '(vertigo) Paroymsal  Positional Vertigo'
 '(vertigo) Paroymsal  Positional Vertigo'
 '(vertigo) Paroymsal  Positional Vertigo'
 '(vertigo) Paroymsal  Positional Vertigo'
 '(vertigo) Paroymsal  Positional Vertigo' 'AIDS'
 '(vertigo) Paroymsal  Positional Vertigo'
 '(vertigo) Paroymsal  Positional Vertigo'
 '(vertigo) Paroymsal  Positional Vertigo'
 '(vertigo) Paroymsal  Positional Vertigo'
 '(vertigo) Paroymsal  Positional Vertigo'
 '(vertigo) Paroymsal  Positional Vertigo'
 '(vertigo) Paroymsal  Positional Vertigo'
 '(vertigo) Paroymsal  Positional Vertigo'
 '(vertigo) Paroymsal  Positional Vertigo']
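Why the output looks like this: `predictions` has shape (1, 41), and `reshape(-1)` flattens it into 41 separate values, each 0 or 1. `inverse_transform` then treats every value as a class index, so each 0 becomes `classes_[0]` (the vertigo label, which sorts first because `(` precedes letters) and the single 1 becomes `classes_[1]` (`'AIDS'`), no matter which disease was actually predicted. The minimal sketch below reproduces this with four hypothetical labels standing in for the real 41; it assumes only scikit-learn and NumPy.

```python
import numpy as np
from sklearn.preprocessing import LabelEncoder

# Hypothetical stand-ins for the dataset's disease names. LabelEncoder sorts them,
# so classes_ is ['(vertigo) ...', 'AIDS', 'Acne', 'Allergy'] ('(' sorts before letters)
diseases = ["AIDS", "Acne", "(vertigo) Paroymsal  Positional Vertigo", "Allergy"]
label_encoder = LabelEncoder().fit(diseases)

# A thresholded prediction for ONE sample: shape (1, n_classes), with the 1 in the 'Acne' column
thresholded = np.array([[0, 0, 1, 0]], dtype="int32")

# What the question does: flatten, so every 0/1 value is treated as a class index
wrong = label_encoder.inverse_transform(thresholded.reshape(-1))
# wrong holds one label per column: classes_[0] for each 0 and classes_[1] for the 1

# What was intended: the column index of the winning class, one per row
right = label_encoder.inverse_transform(np.argmax(thresholded, axis=1))
```

The flattened version returns as many labels as there are classes, which is exactly the 41-entry list above; the argmax version returns one label per input row.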

2 Answers

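Stack Overflow user

Posted on 2022-02-15 11:20:15

Just use argmax() to invert the one-hot encoding: the index of the largest value in each row is the integer class label, which LabelEncoder.inverse_transform maps back to the disease name.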

import numpy as np
predictions = label_encoder.inverse_transform(np.argmax(model.predict(X_test[:1]), axis=1))  # argmax over the class axis of the raw probabilities
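A caveat on applying argmax after the `> 0.5` threshold from the question: if no class probability exceeds 0.5, the thresholded row is all zeros and argmax silently returns index 0. The sketch below compares both orders using three hypothetical labels and made-up softmax outputs; it assumes only scikit-learn and NumPy.

```python
import numpy as np
from sklearn.preprocessing import LabelEncoder

# Hypothetical labels (sorted by LabelEncoder: 'AIDS', 'Acne', 'Allergy')
label_encoder = LabelEncoder().fit(["AIDS", "Acne", "Allergy"])

# Made-up softmax outputs for two samples (each row sums to 1)
probs = np.array([[0.10, 0.80, 0.10],    # confident: 'Acne'
                  [0.30, 0.25, 0.45]])   # no class above 0.5, but 'Allergy' is most likely

# Thresholding first turns the second row into all zeros, so argmax falls back to index 0
thresholded = (probs > 0.5).astype("int32")
from_thresholded = label_encoder.inverse_transform(np.argmax(thresholded, axis=1))

# argmax on the raw probabilities always picks the most likely class
from_probs = label_encoder.inverse_transform(np.argmax(probs, axis=1))
```

Skipping the threshold and calling argmax on the output of `model.predict` avoids the silent fallback to the first class.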
Votes: 0

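Stack Overflow user

Posted on 2022-02-22 13:46:52

I think I've figured it out. The problem was that I used ordinal encoding for categorical values (the symptoms) that are nominal, not ordinal.

The encoding option that suits my case is scikit-learn's OneHotEncoder.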
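To illustrate the difference on symptom columns like the dataset's, here is a small sketch with two hypothetical columns and made-up symptom names. `OrdinalEncoder` maps each column to integers in sorted order, which a dense layer then treats as magnitudes; `OneHotEncoder` gives each (column, symptom) pair its own binary column. The default sparse output is converted with `.toarray()` so the sketch does not depend on the `sparse`/`sparse_output` parameter name, which differs between scikit-learn versions.

```python
import pandas as pd
from sklearn.preprocessing import OneHotEncoder, OrdinalEncoder

# Hypothetical mini-dataset: each column holds a symptom name, empty slots filled with 'Missing'
X = pd.DataFrame({"Symptom_1": ["itching", "chills", "vomiting"],
                  "Symptom_2": ["skin_rash", "Missing", "skin_rash"]})

# Ordinal encoding: one integer per column, implying an order
# (chills=0 < itching=1 < vomiting=2) that has no medical meaning
ordinal = OrdinalEncoder().fit_transform(X)

# One-hot encoding: 3 + 2 = 5 binary columns, one per (column, symptom) pair, no implied order
one_hot = OneHotEncoder().fit_transform(X).toarray()
```

With ordinal codes, the network sees "vomiting" as twice "itching"; with one-hot columns, each symptom is an independent input.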

Thanks, everyone, for the attention to this question; it motivated me to keep digging.

In the end, my current solution looks like this:

import numpy as np
import pandas as pd
from keras.models import Sequential
from keras.layers import Dense
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import OneHotEncoder


dataset = pd.read_csv('dataset.csv') #source https://www.kaggle.com/itachi9604/disease-symptom-description-dataset

X = dataset.drop('Disease', axis=1)
y = dataset['Disease']

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, random_state=101)

X_encoder = OneHotEncoder(categories='auto',
                          drop='first',            # k-1 dummies per column; drop=None keeps all k
                          sparse_output=False,     # named sparse= in scikit-learn < 1.2
                          handle_unknown='error')  # raises if transform() meets a category unseen in fit()

X_encoder.fit(X_train.fillna('Missing'))

X_train = X_encoder.transform(X_train.fillna('Missing'))
X_test = X_encoder.transform(X_test.fillna('Missing'))


# No drop on the target: every disease needs its own column (and softmax unit)
y_encoder = OneHotEncoder(categories='auto',
                          drop=None,
                          sparse_output=False,     # named sparse= in scikit-learn < 1.2
                          handle_unknown='error')

# OneHotEncoder expects a 2-D array; -1 avoids hard-coding the split sizes (3690 / 1230)
y_train = y_train.values.reshape(-1, 1)
y_test = y_test.values.reshape(-1, 1)

y_encoder.fit(y_train)

y_train = y_encoder.transform(y_train)
y_test = y_encoder.transform(y_test)

model = Sequential()
model.add(Dense(8, input_dim=X_train.shape[1], activation='relu'))  # number of one-hot symptom columns
model.add(Dense(y_train.shape[1], activation='softmax'))           # one output unit per disease
# Compile model
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])

model.fit(x=X_train,
          y=y_train,
          epochs=600,
          validation_data=(X_test, y_test), verbose=1
          )
# X_test[10:11] takes one slice of data: a single row of symptoms, kept 2-D
probs = model.predict(X_test[10:11])
# Turn the most probable class into a one-hot row so inverse_transform can decode it
predictions = y_encoder.inverse_transform(np.eye(y_train.shape[1])[np.argmax(probs, axis=1)])

print(predictions)


##Output: 

array([['Acne']], dtype=object)
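One point worth checking when encoding the target: `drop='first'` is harmless for input features, but on the labels it turns the first disease (in sorted order) into an all-zero row. A softmax layer then has no output unit for that disease, and categorical cross-entropy gives its samples zero loss. The sketch below compares the two encodings on three hypothetical diseases and decodes a made-up probability vector via argmax; it assumes only scikit-learn and NumPy.

```python
import numpy as np
from sklearn.preprocessing import OneHotEncoder

# Hypothetical target column (OneHotEncoder wants a 2-D array); sorted categories: AIDS, Acne, Allergy
y = np.array(["Acne", "AIDS", "Allergy", "Acne"]).reshape(-1, 1)

dropped = OneHotEncoder(drop="first").fit(y)
full = OneHotEncoder().fit(y)

y_dropped = dropped.transform(y).toarray()  # k-1 = 2 columns; 'AIDS' becomes [0, 0]
y_full = full.transform(y).toarray()        # k = 3 columns; every disease has its own 1

# Decoding a (made-up) softmax output: argmax over the full encoding, then map back
probs = np.array([[0.2, 0.7, 0.1]])
one_hot = np.eye(len(full.categories_[0]))[np.argmax(probs, axis=1)]
decoded = full.inverse_transform(one_hot)   # 2-D array of labels, shape (1, 1)
```

Keeping all k columns for the target and sizing the softmax layer from `y_train.shape[1]` keeps the two consistent.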
Votes: 0
Original page content provided by Stack Overflow; translation support provided by Tencent Cloud Xiaowei's IT-domain engine.
Original link:

https://stackoverflow.com/questions/71029317
