首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >为什么nltk.align.bleu_score.bleu会出错?

为什么nltk.align.bleu_score.bleu会出错?
EN

Stack Overflow用户
提问于 2016-01-01 14:40:08
回答 1查看 2.6K关注 0票数 3

当我计算BLEU对汉语句子的分数时,我发现了零值。

候选句是c,两个引语是r1r2

代码语言:javascript
复制
c=[u'\u9274\u4e8e', u'\u7f8e\u56fd', u'\u96c6', u'\u7ecf\u6d4e', u'\u4e0e', u'\u8d38\u6613', u'\u6700\u5927', u'\u56fd\u4e8e', u'\u4e00\u8eab', u'\uff0c', u'\u4e0a\u8ff0', u'\u56e0\u7d20', u'\u76f4\u63a5', u'\u5f71\u54cd', u'\u7740', u'\u4e16\u754c', u'\u8d38\u6613', u'\u3002']

r1 = [u'\u8fd9\u4e9b', u'\u76f4\u63a5', u'\u5f71\u54cd', u'\u5168\u7403', u'\u8d38\u6613', u'\u548c', u'\u7f8e\u56fd', u'\u662f', u'\u4e16\u754c', u'\u4e0a', u'\u6700\u5927', u'\u7684', u'\u5355\u4e00', u'\u7684', u'\u7ecf\u6d4e', u'\u548c', u'\u8d38\u6613\u5546', u'\u3002']

r2=[u'\u8fd9\u4e9b', u'\u76f4\u63a5', u'\u5f71\u54cd', u'\u5168\u7403', u'\u8d38\u6613', u'\uff0c', u'\u56e0\u4e3a', u'\u7f8e\u56fd', u'\u662f', u'\u4e16\u754c', u'\u4e0a', u'\u6700\u5927', u'\u7684', u'\u5355\u4e00', u'\u7684', u'\u7ecf\u6d4e\u4f53', u'\u548c', u'\u8d38\u6613\u5546', u'\u3002']

守则是:

代码语言:javascript
复制
weights = [0.1, 0.8, 0.05, 0.05]
print nltk.align.bleu_score.bleu(c, [r1, r2], weights)

但我得到了一个结果,0。当我进入bleu过程时,我发现

代码语言:javascript
复制
try:
    s = math.fsum(w * math.log(p_n) for w, p_n in zip(weights, p_ns))
except ValueError:
    # some p_ns is 0
    return 0

上面的程序转到except ValueError。但是,我不知道为什么会返回一个错误。如果我尝试其他句子,我可以得到一个非零值。

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2016-01-01 20:48:17

似乎您在NLTK实现中遇到了错误!这个try-exceptscore.py#L76是错的

In Long:

首先,让我们来看看BLEU分数中的p_n意味着什么:

注意到

  • Papineni公式是基于语料库级别的BLEU评分,而本地实现使用的是一个句子级BLEU评分( NLTK的流血边缘版本包含了一个实现,该实现遵循Papineni的论文来计算语料库级别。
  • 在多引用BLEU中,Count_match(ngram)是基于具有更高计数的引用(参见score.py#L270)。

所以默认的BLEU分数使用的是n=4,其中包含一克到四克。对于每个纳克,让我们计算一下p_n

代码语言:javascript
复制
>>> from collections import Counter
>>> from nltk import ngrams
>>> hyp = u"鉴于 美国 集 经济 与 贸易 最大 国于 一身 , 上述 因素 直接 影响 着 世界 贸易 。".split()
>>> ref1 = u"这些 直接 影响 全球 贸易 和 美国 是 世界 上 最大 的 单一 的 经济 和 贸易商 。".split()
>>> ref2 = u"这些 直接 影响 全球 贸易 和 美国 是 世界 上 最大 的 单一 的 经济 和 贸易商 。".split()
# Calculate p_1, p_2, p_3 and p_4
>>> from nltk.translate.bleu_score import _modified_precision
>>> p_1 = _modified_precision([ref1, ref2], hyp, 1)
>>> p_2 = _modified_precision([ref1, ref2], hyp, 2)
>>> p_3 = _modified_precision([ref1, ref2], hyp, 3)
>>> p_4 = _modified_precision([ref1, ref2], hyp, 4)
>>> p_1, p_2, p_3, p_4
(Fraction(4, 9), Fraction(1, 17), Fraction(0, 1), Fraction(0, 1))

请注意BLEU评分中最新版本的_modified_precision,因为https://github.com/nltk/nltk/pull/1229使用的是Fraction而不是float输出。所以现在,我们可以清楚地看到分子和分母。

因此,现在让我们验证一下_modified_precision的输出是否为unigram。在假设中,粗体词出现在参考文献中:

  • <代码>E 235E 136E142E 237E 138E 239E 140E142e 241E 142。

有9个令牌重叠,其中1个是重复的,发生两次。

代码语言:javascript
复制
>>> from collections import Counter
>>> ref1_unigram_counts = Counter(ngrams(ref1, 1))
>>> ref2_unigram_counts = Counter(ngrams(ref2, 1))
>>> hyp_unigram_counts = Counter(ngrams(hyp,1))
>>> for overlaps in set(hyp_unigram_counts.keys()).intersection(ref1_unigram_counts.keys()):
...     print " ".join(overlaps)
... 
美国
直接
经济
影响
。
最大
世界
贸易
>>> overlap_counts = Counter({ng:hyp_unigram_counts[ng] for ng in set(hyp_unigram_counts.keys()).intersection(ref1_unigram_counts.keys())})
>>> overlap_counts
Counter({(u'\u8d38\u6613',): 2, (u'\u7f8e\u56fd',): 1, (u'\u76f4\u63a5',): 1, (u'\u7ecf\u6d4e',): 1, (u'\u5f71\u54cd',): 1, (u'\u3002',): 1, (u'\u6700\u5927',): 1, (u'\u4e16\u754c',): 1})

现在,让我们看看这些重叠词在引用中发生了多少次。将来自不同引用的“组合”计数器的值作为p_1公式的分子。如果两个引用中都出现相同的单词,则取最大计数。

代码语言:javascript
复制
>>> overlap_counts_in_ref1 = Counter({ng:ref1_unigram_counts[ng] for ng in set(hyp_unigram_counts.keys()).intersection(ref1_unigram_counts.keys())})
>>> overlap_counts_in_ref2 = Counter({ng:ref2_unigram_counts[ng] for ng in set(hyp_unigram_counts.keys()).intersection(ref1_unigram_counts.keys())})
>>> overlap_counts_in_ref1
Counter({(u'\u7f8e\u56fd',): 1, (u'\u76f4\u63a5',): 1, (u'\u7ecf\u6d4e',): 1, (u'\u5f71\u54cd',): 1, (u'\u3002',): 1, (u'\u6700\u5927',): 1, (u'\u4e16\u754c',): 1, (u'\u8d38\u6613',): 1})
>>> overlap_counts_in_ref2
Counter({(u'\u7f8e\u56fd',): 1, (u'\u76f4\u63a5',): 1, (u'\u7ecf\u6d4e',): 1, (u'\u5f71\u54cd',): 1, (u'\u3002',): 1, (u'\u6700\u5927',): 1, (u'\u4e16\u754c',): 1, (u'\u8d38\u6613',): 1})
>>> overlap_counts_in_ref1_ref2 = Counter()
>>> numerator = overlap_counts_in_ref1_ref2
>>> 
>>> for c in [overlap_counts_in_ref1, overlap_counts_in_ref2]:
...     for k in c:
...             numerator[k] = max(numerator.get(k,0), c[k])
... 
>>> numerator
Counter({(u'\u7f8e\u56fd',): 1, (u'\u76f4\u63a5',): 1, (u'\u7ecf\u6d4e',): 1, (u'\u5f71\u54cd',): 1, (u'\u3002',): 1, (u'\u6700\u5927',): 1, (u'\u4e16\u754c',): 1, (u'\u8d38\u6613',): 1})
>>> sum(numerator.values())
8

现在对于分母来说,这仅仅是否定。在假设中出现的单位数:

代码语言:javascript
复制
>>> hyp_unigram_counts
Counter({(u'\u8d38\u6613',): 2, (u'\u4e0e',): 1, (u'\u7f8e\u56fd',): 1, (u'\u56fd\u4e8e',): 1, (u'\u7740',): 1, (u'\u7ecf\u6d4e',): 1, (u'\u5f71\u54cd',): 1, (u'\u56e0\u7d20',): 1, (u'\u4e16\u754c',): 1, (u'\u3002',): 1, (u'\u4e00\u8eab',): 1, (u'\u6700\u5927',): 1, (u'\u9274\u4e8e',): 1, (u'\u4e0a\u8ff0',): 1, (u'\u96c6',): 1, (u'\u76f4\u63a5',): 1, (u'\uff0c',): 1})
>>> sum(hyp_unigram_counts.values())
18

因此得到的分数是8/18 -> 4/9,我们的_modified_precision函数签出了。

现在让我们来看看BLEU的完整公式:

从公式中,我们只考虑现在求和的指数,即exp(...)。它也可以简化为各种p_n的对数之和,如我们先前计算的,即sum(log(p_n))。这就是在NLTK中实现它的方式,请参阅score.py#L79

现在,忽略BP,让我们考虑一下p_n之和,并考虑它们各自的权重:

代码语言:javascript
复制
>>> from fractions import Fraction
>>> from math import log
>>> log(Fraction(4, 9))
-0.8109302162163288
>>> log(Fraction(1, 17))
-2.833213344056216
>>> log(Fraction(0, 1))
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
ValueError: math domain error

啊哈!这就是出现错误的地方,日志的和在通过ValueError时会返回一个math.fsum()。

代码语言:javascript
复制
>>> try:
...     sum(log(pi) for pi in (Fraction(4, 9), Fraction(1, 17), Fraction(0, 1), Fraction(0, 1)))
... except ValueError:
...     0
... 
0

要纠正实现,the try-except应该是:

代码语言:javascript
复制
s = []
# Calculates the overall modified precision for all ngrams.
# by summing the the product of the weights and the respective log *p_n*
for w, p_n in zip(weights, p_ns)):
    try:
        s.append(w * math.log(p_n))
    except ValueError:
        # some p_ns is 0
        s.append(0)
 return sum(s)

参考资料:

公式来自http://lotus.kuee.kyoto-u.ac.jp/WAT/papers/submissions/W15/W15-5009.pdf,描述了BLEU的一些敏感问题。

票数 4
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/34557078

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档