我目前正在尝试使用vowpal wabbit来熟悉上下文盗贼,但我在使用数字特征时遇到了一些问题。
基本上,我的强盗应该根据两个数字特征(估计的数据速率和信息年龄)在两个动作(动作1=发送数据,动作2=空闲)之间做出决定。到目前为止,这是我的简单代码:
import pandas as pd
from vowpalwabbit import pyvw
train_data = [{'action': 2, 'cost': 0.1, 'probability': 0.5, 'feature1': 3, 'feature2': 10},
{'action': 1, 'cost': 9.99, 'probability': 0.6, 'feature1': 3, 'feature2': 10},
{'action': 1, 'cost': 0.1, 'probability': 0.2, 'feature1': 29, 'feature2': 90},
{'action': 2, 'cost': 9.99, 'probability': 0.3, 'feature1': 29, 'feature2': 90}]
train_df = pd.DataFrame(train_data)
train_df['index'] = range(1, len(train_df) + 1)
train_df = train_df.set_index("index")
test_data = [{'feature1': 29, 'feature2': 90},
{'feature1': 3, 'feature2': 10}]
test_df = pd.DataFrame(test_data)
test_df['index'] = range(1, len(test_df) + 1)
test_df = test_df.set_index("index")
vw = pyvw.vw("--cb 2")
for i in train_df.index:
action = int(train_df.loc[i, "action"])
cost = train_df.loc[i, "cost"]
probability = train_df.loc[i, "probability"]
feature1 = train_df.loc[i, "feature1"]
feature2 = train_df.loc[i, "feature2"]
# Construct the example in the required vw format.
learn_example = str(action) + ":" + str(cost) + ":" + str(probability) + " | rate:" + str(feature1) + " aoi:" + str(feature2)
vw.learn(learn_example)
for j in test_df.index:
feature1 = test_df.loc[j, "feature1"]
feature2 = test_df.loc[j, "feature2"]
#test_example = "| " + str(feature1) + " " + str(feature2)
test_example = "| rate:" + str(feature1) + " aoi:" + str(feature2)
choice = vw.predict(test_example)
print(j, choice)输出为:
1 1
2 1通常,根据我的训练数据的成本结构,我预计第一次预测的输出是动作1,第二次预测的输出是动作2。当我将特征值更改为字符(例如,在官方教程-> VW Tutorial中,"a“代表高数据速率,"b”代表低数据速率)并调整训练/测试字符串时,我得到了正确的预测,因此我认为问题是功能的错误实现:值对。
有人知道我的代码中的错误吗?
发布于 2019-12-14 01:59:14
感谢您提出这个问题。
在VW中,当您使用具有数字值(名称:值)的功能时,根据命名空间和功能名称计算哈希。因此,无论使用什么数值,大众在该功能的模型中只有一个权重。
相反,如果要对特征使用分类值(即rate=low、rate=high),则每个分类值都会在模型中获得一个权重。
因为当使用数值时,每个特征只有一个权重,与分类情况相比,大众在特征和正确动作之间映射的自由度较少,因此它需要更多数据。事实上,如果您将训练数据集复制75次,并使用总共300个示例进行训练,那么您的预测将是正确的。或者,使用你最初的4个例子的分类特征足以让大众正确地预测你的行动。下面是你可以格式化它的方式:
2:0.1:0.5 | rate=3 aoi=10
1:9.99:0.6 | rate=3 aoi=10
1:0.1:0.2 | rate=29 aoi=90
2:9.99:0.3 | rate=29 aoi=90
| rate=29 aoi=90
| rate=3 aoi=10注意: VW文本格式=中的实际上不是像:那样的特殊字符。=只是功能名称的一部分,它使人类读者更容易将其识别为具有该名称的分类功能。如果有一个功能rate具有两个可能的值low或high,则可以将其作为rate=low或rate=high传递。用其他东西替换=会改变散列,但学习不会改变,因为它们仍然被解释为不同的特征。因此,您还可以使用rate-high、rate_high等。
https://stackoverflow.com/questions/59037746
复制相似问题