这个可复制示例创建了一个基本的回归模型,预测MPG给定的马力(希望我可以只提供链接)。据我所知,这将将功能马力转换为模型的培训--也称为“模型内部”。这很有吸引力,因为模型在评分/推理期间(例如在部署之后)对原始数据进行了必要的转换(如果我误解了,请纠正我)。我想知道,当一个人拥有比自变量更多的东西时,这是如何实现的。这摘自上文引用的可复制代码:
horsepower_normalizer = tf.keras.layers.Normalization(input_shape=[1, ], axis=None)
horsepower_normalizer.adapt(horsepower)
horsepower_normalizer = tf.keras.layers.Normalization(input_shape=[1, ], axis=None)
horsepower_normalizer.adapt(horsepower)
horsepower_model = Sequential([
horsepower_normalizer,
layers.Dense(units=1)
])那么,假设我们有一个数字特性列表,X, Y, Z,可以在此基础上(例如通过functional )生成模型定义代码吗?任何指示都是非常受欢迎的。谢谢!
PS:
我目前正在努力学习Keras + TF,理想情况下,我希望正常化成为模式/培训的一部分。我使用了非常粗鲁的代码(需要改进!)按照这些方针:
train_data = pd.read_csv('train.csv')
val_data = pd.read_csv('val.csv')
target_name = 'ze_target'
y_train = train_data[target_name]
X_train = train_data.drop(target_name, axis=1)
y_val = train_data[target_name]
X_val = train_data.drop(target_name, axis=1)
def create_model():
model = Sequential()
model.add(Dense(20, input_dim=X.shape[1], activation='relu'))
model.add(Dense(20, input_dim=X.shape[1], activation='relu'))
model.add(Dense(20, input_dim=X.shape[1], activation='relu'))
model.add(Dense(1))
# Compile model
model.compile(optimizer=Adam(learning_rate=0.0001), loss = 'mse')
return model
model = create_model()
model.summary()
model.fit(X_train, y_train, validation_data=(X_val,y_val), batch_size=128, epochs=30)发布于 2022-06-25 17:47:35
您可以在tf.concat上使用axis=1并将三个特性连接起来,然后对三个特性使用tf.keras.layers.Normalization,如下所示,因为我们希望对三个特性进行规范化,确保设置input_shape=(3,)和axis=-1。
import tensorflow as tf
x = tf.random.uniform((100, 1))
y = tf.random.uniform((100, 1))
z = tf.random.uniform((100, 1))
xyz = tf.concat([x, y, z], 1)
horsepower_normalizer = tf.keras.layers.Normalization(input_shape=(3,), axis=-1)
horsepower_normalizer.adapt(xyz)
horsepower_model = tf.keras.models.Sequential([
horsepower_normalizer,
tf.keras.layers.Dense(units=1)
])
horsepower_model(xyz)输出:
<tf.Tensor: shape=(100, 1), dtype=float32, numpy=
array([[-0.17135675],
[-0.48248804],
[-2.2847023 ],
[-0.05702276],
[ 2.9332483 ],
[ 0.64826846],
[-2.1490448 ],
[-1.1697797 ],
[-0.01030668],
...
[-1.880199 ],
[ 1.2854142 ],
[-0.5471661 ]], dtype=float32)>https://stackoverflow.com/questions/72755165
复制相似问题