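I want to incorporate label encoding into a scikit-learn pipeline. Unfortunately, LabelEncoder() does not work with the Pipeline API, so that is not an option right now. I tried writing my own class that calls .map() to map categories to labels: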

    from sklearn.base import TransformerMixin
    from sklearn.base import BaseEstimator

    class RatingEncoder(BaseEstimator, TransformerMixin):
        """Takes in dataframe, converts all categorical Ratings columns into numerical Ratings columns
        via label-encoding"""

        def __init__(self):
            pass

        def fit(self, df, y=None):
            # Nothing to learn: the mapping is fixed.
            return self

        def transform(self, df, y=None):
            """Transform all of the categorical ratings columns into numerical ratings columns"""
            # Work on a copy so the caller's dataframe is not modified in place.
            df = df.copy()
            for feature in df.columns:
                df[feature] = df[feature].map({
                    "Po": 1,
                    "Fa": 2,
                    "TA": 3,
                    "Gd": 4,
                    "Ex": 5,
                })
            return df

Then I set up the following pipeline:
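
    import numpy as np
    from sklearn.compose import make_column_transformer
    from sklearn.impute import SimpleImputer

    def select_numeric_features(df):
        return df.select_dtypes(include=np.number).columns

    def select_categorical_features(df):
        return df.select_dtypes(exclude=np.number).columns

    def select_rated_features(df):
        rated_features = []
        for column in df:
            # This criterion determines if a column is a 'rated column'
            if any(df[column] == 'TA'):
                rated_features.append(column)
        return rated_features

    pipeline = make_column_transformer(
        (RatingEncoder(), select_rated_features),
        (SimpleImputer(strategy='constant', fill_value='None'), select_categorical_features),
        (SimpleImputer(strategy='constant', fill_value=0), select_numeric_features),
        remainder='passthrough'
    )

The problem is that after the RatingEncoder() step, the categorical "rating" columns should have become numeric columns. However, that change is not visible to the column-selection part of the column transformer, so select_numeric_features and select_categorical_features pick up the "rating" columns incorrectly, as if they had never been mapped from categories to values. Basically, the column transformer does not work with columns that were updated midway through the pipeline. Is there a workaround? Alternatively, is there a simpler way to do label encoding with the Pipeline API?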
Posted on 2019-12-20 14:20:44
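LabelEncoder is meant for encoding labels, that is, y (the target). If you want to encode the data (i.e., X), you can use a OneHotEncoder or an OrdinalEncoder, both of which integrate easily into a scikit-learn Pipeline.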
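To illustrate the difference, here is a minimal sketch. The arrays `y` and `X` are made-up toy data, not taken from the question: LabelEncoder expects a single 1-D array of targets, while OrdinalEncoder expects a 2-D feature matrix and learns one mapping per column.

```python
import numpy as np
from sklearn.preprocessing import LabelEncoder, OrdinalEncoder

# Hypothetical target labels (1-D) and feature matrix (2-D), for illustration only.
y = np.array(["cat", "dog", "cat", "bird"])
X = np.array([["Gd", "red"], ["TA", "blue"], ["Gd", "blue"], ["Ex", "red"]])

# LabelEncoder works on a single 1-D array of targets.
y_encoded = LabelEncoder().fit_transform(y)

# OrdinalEncoder works on a 2-D feature matrix, with a separate mapping per column.
X_encoded = OrdinalEncoder().fit_transform(X)

print(y_encoded)  # integer codes; classes are sorted alphabetically
print(X_encoded)  # float codes, one column of codes per input column
```

Note that by default both encoders assign codes in sorted order of the category values, which is not necessarily a meaningful ranking.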
In your case, it looks like you want to ordinal-encode your data.

    from sklearn.pipeline import make_pipeline
    from sklearn.impute import SimpleImputer
    from sklearn.preprocessing import OrdinalEncoder

    # Impute missing categories first, then map each category to an integer code.
    preprocessor = make_pipeline(
        SimpleImputer(strategy="constant", fill_value="missing"),
        OrdinalEncoder()
    )
    # X_train is your training data.
    preprocessor.fit_transform(X_train)

If your classifier is not tree-based (for example, a linear model), you could use a OneHotEncoder instead of the OrdinalEncoder.
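As a rough sketch of how this could be applied to the question's setup (not a definitive implementation): the rating order is taken from the question's mapping, while the toy DataFrame and its column names (`KitchenQual`, `Street`, `LotArea`) are assumptions made up for illustration. Each transformer in a column transformer receives the original input columns, so the rated columns are removed from the categorical list by hand.

```python
import numpy as np
import pandas as pd
from sklearn.compose import make_column_transformer
from sklearn.impute import SimpleImputer
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import OrdinalEncoder

# Rating order from the question's mapping, worst to best.
RATING_ORDER = ["Po", "Fa", "TA", "Gd", "Ex"]

# Hypothetical toy data; the column names are invented for this example.
X_train = pd.DataFrame({
    "KitchenQual": ["Gd", "TA", "Ex", "Po"],
    "Street": ["Pave", np.nan, "Grvl", "Pave"],
    "LotArea": [8450.0, np.nan, 11250.0, 9550.0],
})

# Decide the three disjoint column groups up front, on the raw data.
rated = [c for c in X_train.columns if (X_train[c] == "TA").any()]
categorical = [
    c for c in X_train.select_dtypes(exclude=np.number).columns
    if c not in rated
]
numeric = list(X_train.select_dtypes(include=np.number).columns)

preprocessor = make_column_transformer(
    # One explicit category list per rated column keeps Po < Fa < TA < Gd < Ex.
    (OrdinalEncoder(categories=[RATING_ORDER] * len(rated)), rated),
    # Other categorical columns: impute, then encode in default (sorted) order.
    (make_pipeline(
        SimpleImputer(strategy="constant", fill_value="None"),
        OrdinalEncoder(),
    ), categorical),
    # Numeric columns: fill missing values with 0.
    (SimpleImputer(strategy="constant", fill_value=0), numeric),
)

# Output columns follow the transformer order: rated, categorical, numeric.
X_encoded = preprocessor.fit_transform(X_train)
print(X_encoded)
```

OrdinalEncoder codes start at 0, so the ratings come out as 0 to 4 rather than the 1 to 5 used in the question. The ordering is preserved, which is usually what matters for ordinal features.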
https://stackoverflow.com/questions/59423266