我在用CNN模型进行图像分类。我想尝试使用决策树或增强树模型的图像分类。我发现Tensorflow是增强树模型,但我无法理解如何将图像作为模型的输入。如果你知道如何使用tf.boosted树训练图像分类,请指导我。
发布于 2019-09-19 10:01:42
最简单的方法是使用扁平图像特征向量作为输入。您可以使用我用mnist数据集测试过的以下代码示例(修改为2类)。请注意,tensorflow BoostedTreesClassifier的当前实现不支持多类分类器的剪枝。
import tensorflow as tf
tf.enable_eager_execution()
mnist = tf.keras.datasets.mnist
(x_train, y_train),(x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
for i in range(len(y_train)):
if y_train[i] is not 0:
y_train[i] = 1
for i in range(len(y_test)):
if y_test[i] is not 0:
y_test[i] = 1
NUM_FEATURES = 28*28
train_input_fn = tf.estimator.inputs.numpy_input_fn({'x': x_train.reshape(-1,NUM_FEATURES)}, y_train, batch_size=128, num_epochs=5, shuffle=True)
eval_input_fn = tf.estimator.inputs.numpy_input_fn({'x': x_test.reshape(-1,NUM_FEATURES)}, y_test, batch_size=128, num_epochs=1, shuffle=False)
features = [tf.feature_column.numeric_column("x", shape=(NUM_FEATURES,))]
est = tf.estimator.BoostedTreesClassifier(features, n_batches_per_layer=1)
est.train(train_input_fn)
results = est.evaluate(eval_input_fn)
print('Accuracy : ', results['accuracy'])https://stackoverflow.com/questions/58005117
复制相似问题