Error: "Unknown category '2' encountered. Set `add_nan=True` to allow unknown categories" while creating a time series dataset in PyTorch Forecasting.
training = TimeSeriesDataSet(
    train,
    time_idx="index",
    target=dni,
    group_ids=["Solar Zenith Angle", "Relative Humidity", "Dew Point", "Temperature", "Precipitable Water", "Wind Speed"],
    min_encoder_length=max_encoder_length // 2,  # keep encoder length long (as it is in the validation set)
    max_encoder_length=max_encoder_length,
    min_prediction_length=1,
    max_prediction_length=max_prediction_length,
    static_reals=["Wind Direction"],
    time_varying_known_reals=["index", "Solar Zenith Angle", "Relative Humidity", "Dew Point", "Temperature", "Precipitable Water"],
    # time_varying_unknown_categoricals=[],
    time_varying_unknown_reals=[dhi, dni, ghi],
    categorical_encoders={data.columns[2]: NaNLabelEncoder(add_nan=True)},
    target_normalizer=GroupNormalizer(
        groups=["Solar Zenith Angle", "Relative Humidity", "Dew Point", "Temperature", "Precipitable Water", "Wind Speed"],
        transformation="softplus",
    ),  # use softplus and normalize by group
    add_relative_time_idx=True,
    add_target_scales=True,
    add_encoder_length=True,
)
Posted on 2022-03-14 16:45:45
Try adding pytorch_forecasting.data.encoders.NaNLabelEncoder(add_nan=True) for each affected column, as in this example:
import pytorch_forecasting
from pytorch_forecasting import TimeSeriesDataSet

max_prediction_length = 1
max_encoder_length = 27
training = TimeSeriesDataSet(
    sales_train,
    time_idx='dayofyear',
    target="QTT",
    group_ids=['S100', 'I100', 'C100', 'C101'],
    min_encoder_length=0,
    max_encoder_length=max_encoder_length,
    min_prediction_length=1,
    max_prediction_length=max_prediction_length,
    static_categoricals=[],
    static_reals=['S100', 'I100', 'C100', 'C101'],
    time_varying_known_categoricals=[],
    time_varying_known_reals=['DATE'],
    time_varying_unknown_categoricals=[],
    time_varying_unknown_reals=['DATE'],
    # one encoder per group column, each allowing unknown categories
    categorical_encoders={
        'S100': pytorch_forecasting.data.encoders.NaNLabelEncoder(add_nan=True),
        'I100': pytorch_forecasting.data.encoders.NaNLabelEncoder(add_nan=True),
        'C100': pytorch_forecasting.data.encoders.NaNLabelEncoder(add_nan=True),
        'C101': pytorch_forecasting.data.encoders.NaNLabelEncoder(add_nan=True),
    },
    add_relative_time_idx=True,
    add_target_scales=True,
    add_encoder_length=True,
    allow_missing_timesteps=True,
)
print('Executado')
Posted on 2022-04-05 09:47:45
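A numeric feature in your dataset may actually be stored as a string type. When Pandas reads a CSV file, it infers each column's type, and a column ends up as strings (object dtype) if any value in it cannot be parsed as a number, unless you specify the type explicitly.
In my case, I had forgotten to convert the target variable to a numeric type. Changing the variable's type to np.float64 resolved the problem immediately.
I hope you find my experience useful.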
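For anyone who wants to check this on their own data, here is a minimal sketch of the dtype check and conversion. The CSV content and the column names (`DNI`, `Temperature`) are made up for illustration and are not taken from the question's dataset. The assumption is that a stray non-numeric entry (here `--`) is what keeps the column from being parsed as numbers.

```python
import io

import numpy as np
import pandas as pd

# Hypothetical data: the "--" entry prevents pandas from parsing DNI as numbers.
csv_text = "index,DNI,Temperature\n0,512,21.5\n1,--,22.0\n2,498,22.4\n"
raw = pd.read_csv(io.StringIO(csv_text))

# Inspect dtypes first: a non-numeric dtype on a column that should be
# real-valued is the warning sign.
print(raw.dtypes)

clean = raw.copy()
# Entries that cannot be parsed become NaN; the column is then float64.
clean["DNI"] = pd.to_numeric(clean["DNI"], errors="coerce").astype(np.float64)
print(clean.dtypes)
```

Rows that turn into NaN during the conversion will most likely still need to be dropped or filled before the frame is passed to TimeSeriesDataSet.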
https://stackoverflow.com/questions/71098518