
Implementation of the Viterbi algorithm, with unit tests

Code Review user
Asked 2018-01-11 22:33:36
Answers: 2 · Views: 1K · Followers: 0 · Votes: 9

algorithm

Introduction

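A system can be in one of N distinct, unobservable states (that is, we never know the actual state of the system). The system also has a finite set of possible observable "outputs", which depend on the system's actual (unobservable) state.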

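The input of the Viterbi algorithm is a time-ordered list of observations; the algorithm computes the most likely sequence of hidden states (one state per time frame) that could have produced them.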

Besides the list of observations, the following are also provided:

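  • The initial probability distribution of the (unobservable) states.
  • For each state, the probability of transitioning to each state (including itself).
  • For each state, the probability of each observation being observed in that state.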

For more details, see e.g. Wikipedia.
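For reference, the recurrence that initialize(), nextStep() and finish() implement can be written as follows. This is a sketch in the usual HMM notation, which the post itself does not use: π(s) is the initial distribution, a_{s′s} the transition probability from s′ to s, b_s(o) the emission probability of o in state s, V_t(s) corresponds to stateProbsForObservations and ψ_t(s) to previousStatesForObservations.

```latex
\begin{aligned}
V_0(s) &= \pi(s)\, b_s(o_0) \\
V_t(s) &= \max_{s'} \bigl[ V_{t-1}(s')\, a_{s's} \bigr]\, b_s(o_t), \qquad t = 1, \dots, T-1 \\
\psi_t(s) &= \operatorname*{arg\,max}_{s'} \; V_{t-1}(s')\, a_{s's} \\
\hat{s}_{T-1} &= \operatorname*{arg\,max}_{s} \; V_{T-1}(s), \qquad \hat{s}_{t-1} = \psi_t(\hat{s}_t)
\end{aligned}
```

As a worked check against the wikipediaSample test below: V_0 = (HEALTHY 0.3, FEVER 0.04), V_1 = (0.084, 0.027), V_2 = (0.00588, 0.01512). FEVER wins at the last step, and following ψ back gives HEALTHY, HEALTHY, FEVER, which is the expected result of that test.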

Some notes about the design

Besides implementing the algorithm correctly, I had the following goals in mind:

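  • The interface should be easy to use and hard to misuse.
  • The data should be validated before the computation starts.
  • The state of the algorithm should be observable after each step (hence the nextStep, getProbabilitiesForObservations and getPreviousStatesObservations methods).
  • However, getting the final result should also be easy; this is what the calculate method does.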

Goals of this review

While any suggestions/comments are welcome, these are the points I am most interested in.

Implementation

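  • Do you see any errors in the implementation of the Viterbi algorithm? (In other words, valid inputs that produce wrong results.)
  • Can you come up with any invalid inputs that the validation does not detect?
  • Is there a better way to list all the enum constants corresponding to the states/observations (in particular, one that does not require a collection with at least one element)?
  • Can the API be improved? (E.g. to make it more intuitive, easier to use, less error-prone, etc.)
  • What do you think about where the validation happens? What would be the benefit of putting it into the model instead of the machine? (I put it into the machine because I wanted the model to be as dumb as possible and keep all the logic in the machine.)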

Tests

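  • Is there a way to better organize the state/observation enums that belong to the test cases? (Unfortunately, in Java it is not possible to define enums inside methods. Even if it were, it would not be a complete solution, since some enums are shared between test cases.)
  • Can you suggest more test cases for the algorithm itself? (For now, I don't test the "stepping" logic; that will come later.)
  • Do you see any redundant tests?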

Code

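Note: I included only the relevant parts. A complete working version can be found on GitHub. The code uses Guava as an external library (and JUnit / Hamcrest for the tests).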

Implementation

public static class ViterbiModel<S extends Enum<S>, T extends Enum<T>> {
    public final ImmutableMap<S, Double> initialDistributions;
    public final ImmutableTable<S, S, Double> transitionProbabilities;
    public final ImmutableTable<S, T, Double> emissionProbabilities;

    private ViterbiModel(ImmutableMap<S, Double> initialDistributions, 
            ImmutableTable<S, S, Double> transitionProbabilities,
            ImmutableTable<S, T, Double> emissionProbabilities) {
        this.initialDistributions = checkNotNull(initialDistributions);
        this.transitionProbabilities = checkNotNull(transitionProbabilities);
        this.emissionProbabilities = checkNotNull(emissionProbabilities);
    }

    public static <S extends Enum<S>, T extends Enum<T>> Builder<S, T> builder() {
        return new Builder<>();
    }

    public static class Builder<S extends Enum<S>, T extends Enum<T>> {
        private ImmutableMap<S, Double> initialDistributions;
        private ImmutableTable.Builder<S, S, Double> transitionProbabilities = ImmutableTable.builder();
        private ImmutableTable.Builder<S, T, Double> emissionProbabilities = ImmutableTable.builder();

        public ViterbiModel<S, T> build() {
            return new ViterbiModel<S, T>(immutableEnumMap(initialDistributions), transitionProbabilities.build(), emissionProbabilities.build());
        }

        public Builder<S, T> withInitialDistributions(ImmutableMap<S, Double> initialDistributions) {
            this.initialDistributions = initialDistributions;
            return this;
        }

        public Builder<S, T> withTransitionProbability(S src, S dest, Double prob) {
            transitionProbabilities.put(src, dest, prob);
            return this;
        }

        public Builder<S, T> withEmissionProbability(S state, T emission, Double prob) {
            emissionProbabilities.put(state, emission, prob);
            return this;
        }
    }
}

public static class ViterbiMachine<S extends Enum<S>, T extends Enum<T>> {
    private final List<S> possibleStates;
    private final List<T> possibleObservations;

    private final ViterbiModel<S, T> model;
    private final ImmutableList<T> observations;

    private Table<S, Integer, Double> stateProbsForObservations = HashBasedTable.create();
    private Table<S, Integer, Optional<S>> previousStatesForObservations = HashBasedTable.create();

    private int step;

    public ViterbiMachine(ViterbiModel<S, T> model, ImmutableList<T> observations) {
        this.model = checkNotNull(model);
        this.observations = checkNotNull(observations);

        try {
            possibleStates = ImmutableList.copyOf(getPossibleStates());
        } catch (IllegalStateException ise) {
            throw new IllegalArgumentException("empty states enum, or no explicit initial distribution provided", ise);
        }

        try {
            possibleObservations = ImmutableList.copyOf(getPossibleObservations());
        } catch (IllegalStateException ise) {
            throw new IllegalArgumentException("empty observations enum, or no explicit observations provided", ise);
        }

        validate();
        initialize();
    }

    private void validate() {
        if (model.initialDistributions.size() != possibleStates.size()) {
            throw new IllegalArgumentException("model.initialDistributions.size() = " + model.initialDistributions.size());
        }
        double sumInitProbs = 0.0;
        for (double prob: model.initialDistributions.values()) {
            sumInitProbs += prob;
        }
        if (!doublesEqual(sumInitProbs, 1.0)) {
            throw new IllegalArgumentException("the sum of initial distributions should be 1.0, was " + sumInitProbs);
        }
        if (observations.size() < 1) {
            // should not happen (observations size already checked when retrieving possible enum values),
            // only added for the sake of completeness
            throw new IllegalArgumentException("at least one observation should be provided, " + observations.size() + " given");
        }
        if (model.transitionProbabilities.size() < 1) {
            throw new IllegalArgumentException("at least one transition probability should be provided, " + model.transitionProbabilities.size() + " given");
        }
        for (S row : possibleStates) {
            double sumRowProbs = 0.0;
            for (double prob : rowOrDefault(model.transitionProbabilities, row, ImmutableMap.<S, Double>of()).values()) {
                sumRowProbs += prob;
            }
            if (!doublesEqual(sumRowProbs, 1.0)) {
                throw new IllegalArgumentException("sum of transition probabilities for each state should be one, was " + sumRowProbs + " for state " + row);
            }
        }
        if (model.emissionProbabilities.size() < 1) {
            throw new IllegalArgumentException("at least one emission probability should be provided, " + model.emissionProbabilities.size() + " given");
        }
        for (S row : possibleStates) {
            double sumRowProbs = 0.0;
            for (double prob : rowOrDefault(model.emissionProbabilities, row, ImmutableMap.<T, Double>of()).values()) {
                sumRowProbs += prob;
            }
            if (!doublesEqual(sumRowProbs, 1.0)) {
                throw new IllegalArgumentException("sum of emission probabilities for each state should be one, was " + sumRowProbs + " for state " + row);
            }
        }
    }

    private static <S, T, V> V getOrDefault(Table<S, T, V> table, S key1, T key2, V defaultValue) {
        V ret = table.get(key1, key2);
        if (ret == null) {
            ret = defaultValue;
        }
        return ret;
    }

    private static <S, T, V> Map<T, V> rowOrDefault(Table<S, T, V> table, S key, Map<T, V> defaultValue) {
        Map<T, V> ret = table.row(key);
        if (ret == null) {
            ret = defaultValue;
        }
        return ret;
    }

    private void initialize() {
        final T firstObservation = observations.get(0);
        for (S state : possibleStates) {
            stateProbsForObservations.put(state, 0, model.initialDistributions.getOrDefault(state, 0.0) * getOrDefault(model.emissionProbabilities, state, firstObservation, 0.0));
            previousStatesForObservations.put(state, 0, Optional.<S>empty());
        }

        step = 1;
    }

    public void nextStep() {
        if (step >= observations.size()) {
            throw new IllegalStateException("already finished last step");
        }

        for (S state : possibleStates) {
            double maxProb = 0.0;
            Optional<S> prevStateWithMaxProb = Optional.empty();
            for (S state2 : possibleStates) {
                double prob = getOrDefault(stateProbsForObservations, state2, step - 1, 0.0) * getOrDefault(model.transitionProbabilities, state2, state, 0.0);
                if (prob > maxProb) {
                    maxProb = prob;
                    prevStateWithMaxProb = Optional.of(state2);
                }
            }
            stateProbsForObservations.put(state, step, maxProb * getOrDefault(model.emissionProbabilities, state, observations.get(step), 0.0));
            previousStatesForObservations.put(state, step, prevStateWithMaxProb);
        }

        ++step;
    }

    public ImmutableTable<S, Integer, Double> getProbabilitiesForObservations() {
        return ImmutableTable.copyOf(stateProbsForObservations);
    }

    public ImmutableTable<S, Integer, Optional<S>> getPreviousStatesObservations() {
        return ImmutableTable.copyOf(previousStatesForObservations);
    }

    public List<S> finish() {
        if (step != observations.size()) {
            throw new IllegalStateException("step = " + step);
        }

        S stateWithMaxProb = possibleStates.get(0);
        double maxProb = stateProbsForObservations.get(stateWithMaxProb, observations.size() - 1);
        for (S state : possibleStates) {
            double prob = stateProbsForObservations.get(state, observations.size() - 1);
            if (prob > maxProb) {
                maxProb = prob;
                stateWithMaxProb = state;
            }
        }

        List<S> result = new ArrayList<>();

        for (int i = observations.size() - 1; i >= 0; --i) {
            result.add(stateWithMaxProb);
            stateWithMaxProb = previousStatesForObservations.get(stateWithMaxProb, i).orElse(null);
        }

        return Lists.reverse(result);
    }

    public List<S> calculate() {
        for (int i = 0; i < observations.size() - 1; ++i) {
            nextStep();
        }
        return finish();
    }

    private S[] getPossibleStates() {
        return getEnumsFromIterator(model.initialDistributions.keySet().iterator());
    }

    private T[] getPossibleObservations() {
        return getEnumsFromIterator(observations.iterator());
    }

    private static <X extends Enum<X>> X[] getEnumsFromIterator(Iterator<X> it) {
        if (!it.hasNext()) {
            throw new IllegalStateException("iterator should have at least one element");
        }
        Enum<X> val1 = it.next();
        return val1.getDeclaringClass().getEnumConstants();
    }

    private static boolean doublesEqual(double d1, double d2) {
        return Math.abs(d1 - d2) < 0.0000001;
    }
}

Tests

public class ViterbiTest {

    @Rule
    public ExpectedException thrown = ExpectedException.none(); 

    enum ZeroStatesZeroObservationsState { };
    enum ZeroStatesZeroObservationsObservation { };

    @Test
    public void zeroStatesZeroObservationsIsNotOk() {
        ViterbiModel<ZeroStatesZeroObservationsState, ZeroStatesZeroObservationsObservation> model = ViterbiModel.<ZeroStatesZeroObservationsState, ZeroStatesZeroObservationsObservation>builder()
                .withInitialDistributions(ImmutableMap.<ZeroStatesZeroObservationsState, Double>builder()
                        .build())
                .build();

        ImmutableList<ZeroStatesZeroObservationsObservation> observations = ImmutableList.of();

        thrown.expect(IllegalArgumentException.class);
        thrown.expectMessage("empty states enum, or no explicit initial distribution provided");
        new ViterbiMachine<>(model, observations);
    }

    enum ZeroStatesOneObservationState { };
    enum ZeroStatesOneObservationObservation { OBSERVATION0 };

    @Test
    public void zeroStatesOneObservationIsNotOk() {
        ViterbiModel<ZeroStatesOneObservationState, ZeroStatesOneObservationObservation> model = ViterbiModel.<ZeroStatesOneObservationState, ZeroStatesOneObservationObservation>builder()
                .withInitialDistributions(ImmutableMap.<ZeroStatesOneObservationState, Double>builder()
                        .build())
                .build();

        ImmutableList<ZeroStatesOneObservationObservation> observations = ImmutableList.of();

        thrown.expect(IllegalArgumentException.class);
        thrown.expectMessage("empty states enum, or no explicit initial distribution provided");
        new ViterbiMachine<>(model, observations);
    }

    enum OneStateZeroObservationsState { STATE0 };
    enum OneStateZeroObservationsObservation { };

    @Test
    public void oneStateZeroObservationsIsNotOk() {
        ViterbiModel<OneStateZeroObservationsState, OneStateZeroObservationsObservation> model = ViterbiModel.<OneStateZeroObservationsState, OneStateZeroObservationsObservation>builder()
                .withInitialDistributions(ImmutableMap.<OneStateZeroObservationsState, Double>builder()
                        .put(OneStateZeroObservationsState.STATE0, 1.0)
                        .build())
                .build();

        ImmutableList<OneStateZeroObservationsObservation> observations = ImmutableList.of();

        thrown.expect(IllegalArgumentException.class);
        thrown.expectMessage("empty observations enum, or no explicit observations provided");
        new ViterbiMachine<>(model, observations);
    }

    enum OneStateOneObservationState { STATE0 };
    enum OneStateOneObservationObservation { OBSERVATION0 };

    @Test
    public void oneStateOneObservationIsOk() {
        ViterbiModel<OneStateOneObservationState, OneStateOneObservationObservation> model = ViterbiModel.<OneStateOneObservationState, OneStateOneObservationObservation>builder()
                .withInitialDistributions(ImmutableMap.<OneStateOneObservationState, Double>builder()
                        .put(OneStateOneObservationState.STATE0, 1.0)
                        .build())
                .withTransitionProbability(OneStateOneObservationState.STATE0, OneStateOneObservationState.STATE0, 1.0)
                .withEmissionProbability(OneStateOneObservationState.STATE0, OneStateOneObservationObservation.OBSERVATION0, 1.0)
                .build();

        ImmutableList<OneStateOneObservationObservation> observations = ImmutableList.of(OneStateOneObservationObservation.OBSERVATION0);

        ViterbiMachine<OneStateOneObservationState, OneStateOneObservationObservation> machine = new ViterbiMachine<>(model, observations);
        List<OneStateOneObservationState> states = machine.calculate();
        final List<OneStateOneObservationState> expected = ImmutableList.of(OneStateOneObservationState.STATE0);
        assertThat(states, is(expected));
    }

    @Test
    public void oneStateOneObservationMissingInitialDistributionIsNotOk() {
        ViterbiModel<OneStateOneObservationState, OneStateOneObservationObservation> model = ViterbiModel.<OneStateOneObservationState, OneStateOneObservationObservation>builder()
                .withInitialDistributions(ImmutableMap.<OneStateOneObservationState, Double>builder()
                        .build())
                .withTransitionProbability(OneStateOneObservationState.STATE0, OneStateOneObservationState.STATE0, 1.0)
                .withEmissionProbability(OneStateOneObservationState.STATE0, OneStateOneObservationObservation.OBSERVATION0, 1.0)
                .build();

        ImmutableList<OneStateOneObservationObservation> observations = ImmutableList.of(OneStateOneObservationObservation.OBSERVATION0);

        thrown.expect(IllegalArgumentException.class);
        thrown.expectMessage("empty states enum, or no explicit initial distribution provided");
        new ViterbiMachine<>(model, observations);
    }

    @Test
    public void oneStateOneObservationMissingObservationsIsNotOk() {
        ViterbiModel<OneStateOneObservationState, OneStateOneObservationObservation> model = ViterbiModel.<OneStateOneObservationState, OneStateOneObservationObservation>builder()
                .withInitialDistributions(ImmutableMap.<OneStateOneObservationState, Double>builder()
                        .put(OneStateOneObservationState.STATE0, 1.0)
                        .build())
                .withTransitionProbability(OneStateOneObservationState.STATE0, OneStateOneObservationState.STATE0, 1.0)
                .withEmissionProbability(OneStateOneObservationState.STATE0, OneStateOneObservationObservation.OBSERVATION0, 1.0)
                .build();

        ImmutableList<OneStateOneObservationObservation> observations = ImmutableList.of();

        thrown.expect(IllegalArgumentException.class);
        thrown.expectMessage("empty observations enum, or no explicit observations provided");
        new ViterbiMachine<>(model, observations);
    }

    @Test
    public void oneStateOneObservationSumInitialDistribNotOneIsNotOk() {
        ViterbiModel<OneStateOneObservationState, OneStateOneObservationObservation> model = ViterbiModel.<OneStateOneObservationState, OneStateOneObservationObservation>builder()
                .withInitialDistributions(ImmutableMap.<OneStateOneObservationState, Double>builder()
                        .put(OneStateOneObservationState.STATE0, 1.1)
                        .build())
                .withTransitionProbability(OneStateOneObservationState.STATE0, OneStateOneObservationState.STATE0, 1.0)
                .withEmissionProbability(OneStateOneObservationState.STATE0, OneStateOneObservationObservation.OBSERVATION0, 1.0)
                .build();

        ImmutableList<OneStateOneObservationObservation> observations = ImmutableList.of(OneStateOneObservationObservation.OBSERVATION0);

        thrown.expect(IllegalArgumentException.class);
        thrown.expectMessage("the sum of initial distributions should be 1.0, was 1.1");
        new ViterbiMachine<>(model, observations);
    }

    @Test
    public void oneStateOneObservationNoTransitionProbabilitiesIsNotOk() {
        ViterbiModel<OneStateOneObservationState, OneStateOneObservationObservation> model = ViterbiModel.<OneStateOneObservationState, OneStateOneObservationObservation>builder()
                .withInitialDistributions(ImmutableMap.<OneStateOneObservationState, Double>builder()
                        .put(OneStateOneObservationState.STATE0, 1.0)
                        .build())
                .withEmissionProbability(OneStateOneObservationState.STATE0, OneStateOneObservationObservation.OBSERVATION0, 1.0)
                .build();

        ImmutableList<OneStateOneObservationObservation> observations = ImmutableList.of(OneStateOneObservationObservation.OBSERVATION0);

        thrown.expect(IllegalArgumentException.class);
        thrown.expectMessage("at least one transition probability should be provided, 0 given");
        new ViterbiMachine<>(model, observations);
    }

    @Test
    public void oneStateOneObservationSumTransitionProbabilitiesNotOneIsNotOk() {
        ViterbiModel<OneStateOneObservationState, OneStateOneObservationObservation> model = ViterbiModel.<OneStateOneObservationState, OneStateOneObservationObservation>builder()
                .withInitialDistributions(ImmutableMap.<OneStateOneObservationState, Double>builder()
                        .put(OneStateOneObservationState.STATE0, 1.0)
                        .build())
                .withTransitionProbability(OneStateOneObservationState.STATE0, OneStateOneObservationState.STATE0, 1.1)
                .withEmissionProbability(OneStateOneObservationState.STATE0, OneStateOneObservationObservation.OBSERVATION0, 1.0)
                .build();

        ImmutableList<OneStateOneObservationObservation> observations = ImmutableList.of(OneStateOneObservationObservation.OBSERVATION0);

        thrown.expect(IllegalArgumentException.class);
        thrown.expectMessage("sum of transition probabilities for each state should be one, was 1.1 for state STATE0");
        new ViterbiMachine<>(model, observations);
    }

    @Test
    public void oneStateOneObservationZeroEmissionProbabilitiesIsNotOk() {
        ViterbiModel<OneStateOneObservationState, OneStateOneObservationObservation> model = ViterbiModel.<OneStateOneObservationState, OneStateOneObservationObservation>builder()
                .withInitialDistributions(ImmutableMap.<OneStateOneObservationState, Double>builder()
                        .put(OneStateOneObservationState.STATE0, 1.0)
                        .build())
                .withTransitionProbability(OneStateOneObservationState.STATE0, OneStateOneObservationState.STATE0, 1.0)
                .build();

        ImmutableList<OneStateOneObservationObservation> observations = ImmutableList.of(OneStateOneObservationObservation.OBSERVATION0);

        thrown.expect(IllegalArgumentException.class);
        thrown.expectMessage("at least one emission probability should be provided, 0 given");
        new ViterbiMachine<>(model, observations);
    }

    @Test
    public void oneStateOneObservationSumEmissionProbabilitiesNotOneIsNotOk() {
        ViterbiModel<OneStateOneObservationState, OneStateOneObservationObservation> model = ViterbiModel.<OneStateOneObservationState, OneStateOneObservationObservation>builder()
                .withInitialDistributions(ImmutableMap.<OneStateOneObservationState, Double>builder()
                        .put(OneStateOneObservationState.STATE0, 1.0)
                        .build())
                .withTransitionProbability(OneStateOneObservationState.STATE0, OneStateOneObservationState.STATE0, 1.0)
                .withEmissionProbability(OneStateOneObservationState.STATE0, OneStateOneObservationObservation.OBSERVATION0, 1.1)
                .build();

        ImmutableList<OneStateOneObservationObservation> observations = ImmutableList.of(OneStateOneObservationObservation.OBSERVATION0);

        thrown.expect(IllegalArgumentException.class);
        thrown.expectMessage("sum of emission probabilities for each state should be one, was 1.1 for state STATE0");
        new ViterbiMachine<>(model, observations);
    }

    enum OneStateTwoObservationsState { STATE0 };
    enum OneStateTwoObservationsObservation { OBSERVATION0, OBSERVATION1 };

    @Test
    public void oneStateTwoObservationsIsOk() {
        ViterbiModel<OneStateTwoObservationsState, OneStateTwoObservationsObservation> model = ViterbiModel.<OneStateTwoObservationsState, OneStateTwoObservationsObservation>builder()
                .withInitialDistributions(ImmutableMap.<OneStateTwoObservationsState, Double>builder()
                        .put(OneStateTwoObservationsState.STATE0, 1.0)
                        .build())
                .withTransitionProbability(OneStateTwoObservationsState.STATE0, OneStateTwoObservationsState.STATE0, 1.0)
                .withEmissionProbability(OneStateTwoObservationsState.STATE0, OneStateTwoObservationsObservation.OBSERVATION0, 0.4)
                .withEmissionProbability(OneStateTwoObservationsState.STATE0, OneStateTwoObservationsObservation.OBSERVATION1, 0.6)
                .build();

        ImmutableList<OneStateTwoObservationsObservation> observations = ImmutableList.of(OneStateTwoObservationsObservation.OBSERVATION1, OneStateTwoObservationsObservation.OBSERVATION1);

        ViterbiMachine<OneStateTwoObservationsState, OneStateTwoObservationsObservation> machine = new ViterbiMachine<>(model, observations);
        List<OneStateTwoObservationsState> states = machine.calculate();
        final List<OneStateTwoObservationsState> expected = ImmutableList.of(OneStateTwoObservationsState.STATE0, OneStateTwoObservationsState.STATE0);
        assertThat(states, is(expected));
    }

    enum TwoStatesOneObservationState { STATE0, STATE1 };
    enum TwoStatesOneObservationObservation { OBSERVATION0 };

    @Test
    public void twoStatesOneObservationIsOk() {
        ViterbiModel<TwoStatesOneObservationState, TwoStatesOneObservationObservation> model = ViterbiModel.<TwoStatesOneObservationState, TwoStatesOneObservationObservation>builder()
                .withInitialDistributions(ImmutableMap.<TwoStatesOneObservationState, Double>builder()
                        .put(TwoStatesOneObservationState.STATE0, 0.6)
                        .put(TwoStatesOneObservationState.STATE1, 0.4)
                        .build())
                .withTransitionProbability(TwoStatesOneObservationState.STATE0, TwoStatesOneObservationState.STATE0, 0.7)
                .withTransitionProbability(TwoStatesOneObservationState.STATE0, TwoStatesOneObservationState.STATE1, 0.3)
                .withTransitionProbability(TwoStatesOneObservationState.STATE1, TwoStatesOneObservationState.STATE0, 0.4)
                .withTransitionProbability(TwoStatesOneObservationState.STATE1, TwoStatesOneObservationState.STATE1, 0.6)
                .withEmissionProbability(TwoStatesOneObservationState.STATE0, TwoStatesOneObservationObservation.OBSERVATION0, 1.0)
                .withEmissionProbability(TwoStatesOneObservationState.STATE1, TwoStatesOneObservationObservation.OBSERVATION0, 1.0)
                .build();

        ImmutableList<TwoStatesOneObservationObservation> observations = ImmutableList.of(TwoStatesOneObservationObservation.OBSERVATION0, TwoStatesOneObservationObservation.OBSERVATION0);

        ViterbiMachine<TwoStatesOneObservationState, TwoStatesOneObservationObservation> machine = new ViterbiMachine<>(model, observations);
        List<TwoStatesOneObservationState> states = machine.calculate();
        final List<TwoStatesOneObservationState> expected = ImmutableList.of(TwoStatesOneObservationState.STATE0, TwoStatesOneObservationState.STATE0);
        assertThat(states, is(expected));
    }

    @Test
    public void twoStatesOneObservationTransitionsOmittedForOneStateIsNotOk() {
        ViterbiModel<TwoStatesOneObservationState, TwoStatesOneObservationObservation> model = ViterbiModel.<TwoStatesOneObservationState, TwoStatesOneObservationObservation>builder()
                .withInitialDistributions(ImmutableMap.<TwoStatesOneObservationState, Double>builder()
                        .put(TwoStatesOneObservationState.STATE0, 0.6)
                        .put(TwoStatesOneObservationState.STATE1, 0.4)
                        .build())
                .withTransitionProbability(TwoStatesOneObservationState.STATE0, TwoStatesOneObservationState.STATE0, 0.7)
                .withTransitionProbability(TwoStatesOneObservationState.STATE0, TwoStatesOneObservationState.STATE1, 0.3)
                .withEmissionProbability(TwoStatesOneObservationState.STATE0, TwoStatesOneObservationObservation.OBSERVATION0, 1.0)
                .withEmissionProbability(TwoStatesOneObservationState.STATE1, TwoStatesOneObservationObservation.OBSERVATION0, 1.0)
                .build();

        ImmutableList<TwoStatesOneObservationObservation> observations = ImmutableList.of(TwoStatesOneObservationObservation.OBSERVATION0, TwoStatesOneObservationObservation.OBSERVATION0);

        thrown.expect(IllegalArgumentException.class);
        thrown.expectMessage("sum of transition probabilities for each state should be one, was 0.0 for state STATE1");
        new ViterbiMachine<>(model, observations);
    }

    @Test
    public void twoStatesOneObservationEmissionsOmittedForOneStateIsNotOk() {
        ViterbiModel<TwoStatesOneObservationState, TwoStatesOneObservationObservation> model = ViterbiModel.<TwoStatesOneObservationState, TwoStatesOneObservationObservation>builder()
                .withInitialDistributions(ImmutableMap.<TwoStatesOneObservationState, Double>builder()
                        .put(TwoStatesOneObservationState.STATE0, 0.6)
                        .put(TwoStatesOneObservationState.STATE1, 0.4)
                        .build())
                .withTransitionProbability(TwoStatesOneObservationState.STATE0, TwoStatesOneObservationState.STATE0, 0.7)
                .withTransitionProbability(TwoStatesOneObservationState.STATE0, TwoStatesOneObservationState.STATE1, 0.3)
                .withTransitionProbability(TwoStatesOneObservationState.STATE1, TwoStatesOneObservationState.STATE0, 0.4)
                .withTransitionProbability(TwoStatesOneObservationState.STATE1, TwoStatesOneObservationState.STATE1, 0.6)
                .withEmissionProbability(TwoStatesOneObservationState.STATE0, TwoStatesOneObservationObservation.OBSERVATION0, 1.0)
                .build();

        ImmutableList<TwoStatesOneObservationObservation> observations = ImmutableList.of(TwoStatesOneObservationObservation.OBSERVATION0, TwoStatesOneObservationObservation.OBSERVATION0);

        thrown.expect(IllegalArgumentException.class);
        thrown.expectMessage("sum of emission probabilities for each state should be one, was 0.0 for state STATE1");
        new ViterbiMachine<>(model, observations);
    }

    enum TwoStatesTwoObservationsState { STATE0, STATE1 };
    enum TwoStatesTwoObservationsObservation { OBSERVATION0, OBSERVATION1 };

    @Test
    public void twoStatesTwoObservationsIsOk() {
        ViterbiModel<TwoStatesTwoObservationsState, TwoStatesTwoObservationsObservation> model = ViterbiModel.<TwoStatesTwoObservationsState, TwoStatesTwoObservationsObservation>builder()
                .withInitialDistributions(ImmutableMap.<TwoStatesTwoObservationsState, Double>builder()
                        .put(TwoStatesTwoObservationsState.STATE0, 0.6)
                        .put(TwoStatesTwoObservationsState.STATE1, 0.4)
                        .build())
                .withTransitionProbability(TwoStatesTwoObservationsState.STATE0, TwoStatesTwoObservationsState.STATE0, 0.7)
                .withTransitionProbability(TwoStatesTwoObservationsState.STATE0, TwoStatesTwoObservationsState.STATE1, 0.3)
                .withTransitionProbability(TwoStatesTwoObservationsState.STATE1, TwoStatesTwoObservationsState.STATE0, 0.4)
                .withTransitionProbability(TwoStatesTwoObservationsState.STATE1, TwoStatesTwoObservationsState.STATE1, 0.6)
                .withEmissionProbability(TwoStatesTwoObservationsState.STATE0, TwoStatesTwoObservationsObservation.OBSERVATION0, 0.6)
                .withEmissionProbability(TwoStatesTwoObservationsState.STATE0, TwoStatesTwoObservationsObservation.OBSERVATION1, 0.4)
                .withEmissionProbability(TwoStatesTwoObservationsState.STATE1, TwoStatesTwoObservationsObservation.OBSERVATION0, 0.6)
                .withEmissionProbability(TwoStatesTwoObservationsState.STATE1, TwoStatesTwoObservationsObservation.OBSERVATION1, 0.4)
                .build();

        ImmutableList<TwoStatesTwoObservationsObservation> observations = ImmutableList.of(TwoStatesTwoObservationsObservation.OBSERVATION0, TwoStatesTwoObservationsObservation.OBSERVATION0);

        ViterbiMachine<TwoStatesTwoObservationsState, TwoStatesTwoObservationsObservation> machine = new ViterbiMachine<>(model, observations);
        List<TwoStatesTwoObservationsState> states = machine.calculate();
        final List<TwoStatesTwoObservationsState> expected = ImmutableList.of(TwoStatesTwoObservationsState.STATE0, TwoStatesTwoObservationsState.STATE0);
        assertThat(states, is(expected));
    }


    enum WikipediaState { HEALTHY, FEVER };
    enum WikipediaObservation { OK, COLD, DIZZY };

    @Test
    public void wikipediaSample() {
        ViterbiModel<WikipediaState, WikipediaObservation> model = ViterbiModel.<WikipediaState, WikipediaObservation>builder()
                .withInitialDistributions(ImmutableMap.<WikipediaState, Double>builder()
                        .put(WikipediaState.HEALTHY, 0.6)
                        .put(WikipediaState.FEVER, 0.4)
                        .build())
                .withTransitionProbability(WikipediaState.HEALTHY, WikipediaState.HEALTHY, 0.7)
                .withTransitionProbability(WikipediaState.HEALTHY, WikipediaState.FEVER, 0.3)
                .withTransitionProbability(WikipediaState.FEVER, WikipediaState.HEALTHY, 0.4)
                .withTransitionProbability(WikipediaState.FEVER, WikipediaState.FEVER, 0.6)
                .withEmissionProbability(WikipediaState.HEALTHY, WikipediaObservation.OK, 0.5)
                .withEmissionProbability(WikipediaState.HEALTHY, WikipediaObservation.COLD, 0.4)
                .withEmissionProbability(WikipediaState.HEALTHY, WikipediaObservation.DIZZY, 0.1)
                .withEmissionProbability(WikipediaState.FEVER, WikipediaObservation.OK, 0.1)
                .withEmissionProbability(WikipediaState.FEVER, WikipediaObservation.COLD, 0.3)
                .withEmissionProbability(WikipediaState.FEVER, WikipediaObservation.DIZZY, 0.6)
                .build();

        ImmutableList<WikipediaObservation> observations = ImmutableList.of(WikipediaObservation.OK, WikipediaObservation.COLD, WikipediaObservation.DIZZY);

        ViterbiMachine<WikipediaState, WikipediaObservation> machine = new ViterbiMachine<>(model, observations);
        List<WikipediaState> states = machine.calculate();
        final List<WikipediaState> expected = ImmutableList.of(WikipediaState.HEALTHY, WikipediaState.HEALTHY, WikipediaState.FEVER);
        assertThat(states, is(expected));
    }
// ... SNIP
}

One more comment about the API

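This API may look verbose, but it is the best one I have been able to come up with so far. I tried more concise approaches before, but they were more error-prone and harder to manage with a large number (more than 4-5) of states/observations.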

For reference, here are my previous attempts at the API:

public static int [] viterbi(int numStates, int numObservations,
        double [] initialDistrib,
        double [][] transitionProbs, double [][] emissionProbs,
        int [] observations)  // --> causes huge/unmanageable arrays


public static List<String> viterbi(Set<String> states,
        Set<String> emissions,
        Map<Key<String>, Double> transitionProbs,
        Map<Key<String>, Double> emissionProbs,
        Map<String, Double> initProbs,
        List<String> observations) // --> a bit better, but not type safe

2 answers

Code Review user

Answered 2019-04-03 17:02:16

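If you updated the repo to include a make/ant/maven/gradle build file, I would be able to change and run your code easily. Since I cannot reproduce your build environment, I can only make some general comments.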

Don't roll your own builders

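Consider using Google's CallBuilder library to save a lot of boilerplate code. It lets you create a builder simply by annotating a constructor. You may need to implement a custom "style" class to replicate the exact behavior of your hand-written builder, but I think it is worth it. Generating builders saves a lot of repetitive, error-prone code and helps enforce a consistent builder interface throughout the project.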

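In fact, writing CallBuilder style classes for all the Guava data structures would be a very cool and useful project, but that is beyond the scope of this algorithm.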

Make ViterbiModel's constructor more accepting

Something like:

private ViterbiModel(Map<? extends S, Double> initialDistributions, 
            Table<? extends S, ? extends S, Double> transitionProbabilities,
            Table<? extends S, ? extends T, Double> emissionProbabilities)

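Then, inside the constructor, use the ImmutableMap.copyOf and ImmutableTable.copyOf methods to create and store immutable copies. The same changes need to be propagated to the builder accordingly.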
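A JDK-only sketch of the idea (the class AcceptingModel is hypothetical, and nested java.util.Maps stand in for Guava's Table): Map.copyOf (Java 10+) plays the role of ImmutableMap.copyOf here. It accepts maps whose keys are subtypes of S or T, rejects null keys and values, and returns an unmodifiable copy, so later changes to the caller's maps do not leak into the model.

```java
import java.util.HashMap;
import java.util.Map;

// Hypothetical JDK-only stand-in for ViterbiModel: the constructor accepts maps
// whose keys are subtypes of S/T and stores unmodifiable copies of them.
final class AcceptingModel<S, T> {
    final Map<S, Double> initialDistributions;
    final Map<S, Map<S, Double>> transitionProbabilities;
    final Map<S, Map<T, Double>> emissionProbabilities;

    AcceptingModel(Map<? extends S, Double> initialDistributions,
                   Map<? extends S, ? extends Map<? extends S, Double>> transitionProbabilities,
                   Map<? extends S, ? extends Map<? extends T, Double>> emissionProbabilities) {
        // Map.copyOf rejects null keys/values and is unaffected by later changes to its argument
        this.initialDistributions = Map.copyOf(initialDistributions);
        this.transitionProbabilities = deepCopy(transitionProbabilities);
        this.emissionProbabilities = deepCopy(emissionProbabilities);
    }

    // Copies a map of rows into an unmodifiable map of unmodifiable rows.
    private static <R, C> Map<R, Map<C, Double>> deepCopy(
            Map<? extends R, ? extends Map<? extends C, Double>> source) {
        Map<R, Map<C, Double>> copy = new HashMap<>();
        source.forEach((row, columns) -> copy.put(row, Map.copyOf(columns)));
        return Map.copyOf(copy);
    }
}
```

With this signature, maps keyed by String can be passed to an AcceptingModel<CharSequence, CharSequence>, which is the kind of flexibility the answer is after.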

Create a ViterbiObservations class

It should hold the list of observations and provide a builder. This is for consistency with the ViterbiModel class.
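A minimal sketch of what such a class might look like. The method names add, asList, size and get are assumptions, and java.util.List.copyOf stands in for Guava's ImmutableList; the constructor already enforces "at least one observation", so the machine no longer has to.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical sketch of the suggested ViterbiObservations class:
// an immutable, non-empty list of observations with its own builder.
final class ViterbiObservations<T> {
    private final List<T> observations;

    private ViterbiObservations(List<T> observations) {
        if (observations.isEmpty()) {
            throw new IllegalArgumentException("at least one observation should be provided");
        }
        this.observations = List.copyOf(observations); // also rejects null elements
    }

    static <T> Builder<T> builder() {
        return new Builder<>();
    }

    List<T> asList() {
        return observations; // already unmodifiable
    }

    int size() {
        return observations.size();
    }

    T get(int index) {
        return observations.get(index);
    }

    static final class Builder<T> {
        private final List<T> observations = new ArrayList<>();

        Builder<T> add(T observation) {
            observations.add(observation);
            return this;
        }

        ViterbiObservations<T> build() {
            return new ViterbiObservations<>(observations);
        }
    }
}
```

Usage would mirror the model builder, e.g. ViterbiObservations.<String>builder().add("OK").add("COLD").build().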

Perform validation in the constructors

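Validate the ViterbiModel and ViterbiObservations objects in their respective constructors. Failing early is an important way of communicating with the user here: if they can create a ViterbiModel without any exception being thrown, it should be valid.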
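A sketch of a self-validating model, assuming nested java.util.Maps in place of the Guava tables (ValidatedModel is a hypothetical name, and only shallow copies are made). It also rejects individual probabilities outside [0, 1]. The posted validate() only checks sums, so, for example, emission probabilities of -0.5 and 1.5 for a single state would pass it.

```java
import java.util.Map;
import java.util.Set;

// Hypothetical sketch: a model that validates itself in its constructor, so any
// instance that exists is consistent ("fail early"). Inner maps are not copied.
final class ValidatedModel<S, T> {
    private static final double EPSILON = 1e-7; // same tolerance as the posted doublesEqual

    final Map<S, Double> initial;
    final Map<S, Map<S, Double>> transitions;
    final Map<S, Map<T, Double>> emissions;

    ValidatedModel(Map<S, Double> initial,
                   Map<S, Map<S, Double>> transitions,
                   Map<S, Map<T, Double>> emissions) {
        this.initial = Map.copyOf(initial);
        this.transitions = Map.copyOf(transitions);
        this.emissions = Map.copyOf(emissions);
        Set<S> states = this.initial.keySet();
        if (states.isEmpty()) {
            throw new IllegalArgumentException("at least one state should be provided");
        }
        requireDistribution(this.initial, "initial distribution");
        for (S state : states) {
            Map<S, Double> row = this.transitions.getOrDefault(state, Map.of());
            if (!states.containsAll(row.keySet())) {
                throw new IllegalArgumentException("transitions from " + state + " lead to unknown states");
            }
            requireDistribution(row, "transition probabilities from " + state);
            requireDistribution(this.emissions.getOrDefault(state, Map.of()), "emission probabilities of " + state);
        }
    }

    // Each value must lie in [0, 1] and the values must sum to 1 (within EPSILON).
    private static void requireDistribution(Map<?, Double> probabilities, String what) {
        double sum = 0.0;
        for (double p : probabilities.values()) {
            if (p < 0.0 || p > 1.0) {
                throw new IllegalArgumentException(what + ": probability out of range: " + p);
            }
            sum += p;
        }
        if (Math.abs(sum - 1.0) >= EPSILON) {
            throw new IllegalArgumentException("sum of " + what + " should be 1.0, was " + sum);
        }
    }
}
```

With validation in the constructor, the machine can assume it always receives a consistent model.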

Be more accepting of generic types

You should have something like

ViterbiMachine(ViterbiModel<S, T> model, ImmutableList<? extends T> observations)

because it should be possible to feed a sequence of child-type observations to a model built on the parent type.
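A minimal sketch of the idea. It only matters once the Enum bound is relaxed (with T extends Enum<T>, a subtype of T is essentially T itself), and the wildcard has to sit on the observation list, List<? extends T>, for a model over a parent type to accept a list of a subtype. Symptom, MildSymptom, WildcardDemo and emissionProduct are hypothetical names, and java.util.Map/List stand in for the Guava types:

```java
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Hypothetical parent observation type and an enum that implements it.
interface Symptom { }

enum MildSymptom implements Symptom { OK, COLD }

final class WildcardDemo {
    private WildcardDemo() { }

    // Emission probabilities are keyed by the parent type T; the observation list
    // may hold any subtype of T, hence List<? extends T> rather than List<T>.
    static <T> double emissionProduct(Map<T, Double> emissions, List<? extends T> observations) {
        double product = 1.0;
        for (T observation : observations) {
            product *= emissions.getOrDefault(observation, 0.0);
        }
        return product;
    }
}
```

Here T is inferred as Symptom from the map, and a List<MildSymptom> is accepted for the observations.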

Extend ImmutableTable

The getOrDefault and rowOrDefault methods you wrote are nice. However, they belong in the table class itself, so extend ImmutableTable into a class that has these methods.
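A sketch of the same two helpers as a small standalone utility over nested java.util.Maps (the NestedMaps class is hypothetical, not Guava API). As far as I know, ImmutableTable's constructor is not public, so subclassing it outside Guava is not an option; a static helper or a wrapper class is a workable alternative. Note also that Guava's Table.row() returns an empty map, not null, for an absent row key, so the null check in the posted rowOrDefault never triggers.

```java
import java.util.Map;

// Hypothetical helper standing in for a "table with defaults", written against
// nested java.util.Maps instead of Guava's Table.
final class NestedMaps {
    private NestedMaps() { } // no instances

    // Returns table[row][column], or defaultValue if either key is absent.
    static <R, C, V> V getOrDefault(Map<R, ? extends Map<C, V>> table, R row, C column, V defaultValue) {
        Map<C, V> columns = table.get(row);
        return columns == null ? defaultValue : columns.getOrDefault(column, defaultValue);
    }

    // Returns the row for the given key, or an empty map if the row is absent.
    static <R, C, V> Map<C, V> rowOrEmpty(Map<R, ? extends Map<C, V>> table, R row) {
        Map<C, V> columns = table.get(row);
        return columns == null ? Map.<C, V>of() : columns;
    }
}
```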

Inline the initialize() method

It is not clear why this is not part of the constructor.

Create a utility class

Some of your smaller functions have little to do with ViterbiMachine. Move them to another class.
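A possible shape for such a class, assuming the helpers in question are doublesEqual and the repeated probability-summing loops in validate(). The class name Probabilities and its methods are hypothetical:

```java
import java.util.Collection;

// Hypothetical utility class for the numeric helpers that do not depend on
// any ViterbiMachine state.
final class Probabilities {
    static final double EPSILON = 1e-7; // tolerance used by the posted doublesEqual

    private Probabilities() { } // no instances

    static boolean approximatelyEqual(double a, double b) {
        return Math.abs(a - b) < EPSILON;
    }

    // Replaces the hand-written summing loops in validate().
    static double sum(Collection<Double> probabilities) {
        double total = 0.0;
        for (double p : probabilities) {
            total += p;
        }
        return total;
    }

    // True if the values sum to 1.0 within EPSILON.
    static boolean sumsToOne(Collection<Double> probabilities) {
        return approximatelyEqual(sum(probabilities), 1.0);
    }
}
```

Each of the sum checks in validate() would then reduce to a single call such as Probabilities.sumsToOne(model.initialDistributions.values()).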

Don't force S, T to be enum types

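I don't see why these have to be enums. Wouldn't someone want to create a ViterbiMachine where, say, the states are integers and the outputs are strings? Surely your code could allow that.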
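A minimal sketch of a Viterbi function with unconstrained type parameters (GenericViterbi and mostLikelyPath are hypothetical names; java.util collections replace the Guava types). States are passed explicitly as a list, so nothing has to be derived from an enum class. The predecessor search starts below zero, so a predecessor is always recorded. As far as I can tell, the posted nextStep() stores Optional.empty() when every candidate probability is 0.0 (e.g. after an observation no state can emit, or after floating-point underflow on a long sequence), and finish() can then hit a NullPointerException while backtracking.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Hypothetical sketch: Viterbi over arbitrary state/observation types S and T.
// Missing probabilities are treated as 0.0; validation is omitted.
final class GenericViterbi {
    private GenericViterbi() { }

    static <S, T> List<S> mostLikelyPath(List<S> states,
                                         Map<S, Double> initial,
                                         Map<S, Map<S, Double>> transitions,
                                         Map<S, Map<T, Double>> emissions,
                                         List<? extends T> observations) {
        if (states.isEmpty() || observations.isEmpty()) {
            throw new IllegalArgumentException("need at least one state and one observation");
        }
        List<Map<S, Double>> best = new ArrayList<>(); // best.get(t).get(s): best path prob ending in s
        List<Map<S, S>> back = new ArrayList<>();      // back.get(t).get(s): predecessor of s (t >= 1)

        Map<S, Double> first = new HashMap<>();
        for (S s : states) {
            first.put(s, initial.getOrDefault(s, 0.0) * prob(emissions, s, observations.get(0)));
        }
        best.add(first);
        back.add(new HashMap<>());

        for (int t = 1; t < observations.size(); t++) {
            Map<S, Double> previous = best.get(t - 1);
            Map<S, Double> current = new HashMap<>();
            Map<S, S> pointers = new HashMap<>();
            for (S s : states) {
                S argMax = states.get(0);
                double max = -1.0; // below any probability, so a predecessor is always recorded
                for (S p : states) {
                    double candidate = previous.get(p) * prob(transitions, p, s);
                    if (candidate > max) {
                        max = candidate;
                        argMax = p;
                    }
                }
                current.put(s, max * prob(emissions, s, observations.get(t)));
                pointers.put(s, argMax);
            }
            best.add(current);
            back.add(pointers);
        }

        // Pick the most probable final state, then follow the back-pointers.
        int last = observations.size() - 1;
        S state = states.get(0);
        for (S s : states) {
            if (best.get(last).get(s) > best.get(last).get(state)) {
                state = s;
            }
        }
        List<S> path = new ArrayList<>();
        path.add(state);
        for (int t = last; t > 0; t--) {
            state = back.get(t).get(state);
            path.add(state);
        }
        Collections.reverse(path);
        return path;
    }

    // table[row][column], or 0.0 if either key is missing.
    private static <K, V> double prob(Map<K, Map<V, Double>> table, K row, V column) {
        return table.getOrDefault(row, Map.of()).getOrDefault(column, 0.0);
    }
}
```

Called with the numbers from wikipediaSample, using String states and observations, it should return [Healthy, Healthy, Fever].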

Votes: 3

Code Review user

Answered 2019-06-08 12:56:54

Benjamin's review is already good; here are a few nitpicks on top of it:

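  • Use the standard library's Objects.requireNonNull instead of Guava's checkNotNull.
  • The way you obtain the enum values feels a bit contrived. Consider using values() instead. It is somewhat annoying that Java's generics are so brittle that you cannot "just" call X.values() (which is guaranteed to exist).
  • You could also replace the ImmutableMap dependency with the standard library (compare this answer). That would let you use EnumMap for better performance.
  • You could also replace the tables stateProbsForObservations and previousStatesForObservations with maps of type Map<S, Double[]> and Map<S, Optional<S>[]>, respectively. These can again be backed by EnumMap, further reducing the memory footprint and improving performance. Again, for most uses this gain is negligible.
  • I don't like the use of exceptions for flow control and validation in ViterbiMachine's constructor. To avoid it, you can explicitly check the preconditions of what you are doing, instead of relying on downstream methods to fail with some exception. YMMV :)
  • I don't like that the library does not expose getOrDefault and rowOrDefault, but that is not something you can fix :/
  • Nothing in the API suggests that calculate will throw an IllegalStateException once nextStep() has been called. I would try hard to avoid making result retrieval prone to IllegalStateExceptions.
  • I would also like to be able to call calculate() multiple times, but that is because I like caching, smart, lazy calculator classes. I just enjoy implementing those...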
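A sketch that combines several of these points (EnumViterbiMachine and Health are hypothetical names; validation is left out for brevity). The state enum is passed as a Class token, so Class.getEnumConstants() lists the constants without needing a sample element. The trellis lives in EnumMaps of arrays, and calculate() resumes from the current step and caches its result, so it works after manual nextStep() calls and can be called repeatedly. Because Java cannot directly create a generic array such as Optional<S>[], predecessors are stored as ordinals in an int[] instead.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.EnumMap;
import java.util.List;
import java.util.Map;

// Hypothetical state enum for the example.
enum Health { HEALTHY, FEVER }

// Hypothetical sketch: Viterbi machine driven by a Class token and EnumMaps.
// Missing probabilities are treated as 0.0; validation is omitted.
final class EnumViterbiMachine<S extends Enum<S>, T> {
    private final S[] states;                         // in ordinal order
    private final Map<S, Map<S, Double>> transitions;
    private final Map<S, Map<T, Double>> emissions;
    private final List<T> observations;
    private final Map<S, double[]> probs;             // probs.get(s)[t]
    private final Map<S, int[]> previous;             // ordinal of best predecessor, -1 at t = 0
    private int step = 1;
    private List<S> result;                           // cached by calculate()

    EnumViterbiMachine(Class<S> stateType, Map<S, Double> initial,
                       Map<S, Map<S, Double>> transitions,
                       Map<S, Map<T, Double>> emissions, List<T> observations) {
        this.states = stateType.getEnumConstants(); // no sample element needed
        if (states.length == 0 || observations.isEmpty()) {
            throw new IllegalArgumentException("need at least one state and one observation");
        }
        this.transitions = transitions;
        this.emissions = emissions;
        this.observations = List.copyOf(observations);
        this.probs = new EnumMap<>(stateType);
        this.previous = new EnumMap<>(stateType);
        int n = this.observations.size();
        for (S s : states) {
            double[] column = new double[n];
            column[0] = initial.getOrDefault(s, 0.0) * emission(s, this.observations.get(0));
            probs.put(s, column);
            int[] back = new int[n];
            back[0] = -1;
            previous.put(s, back);
        }
    }

    void nextStep() {
        if (step >= observations.size()) {
            throw new IllegalStateException("already finished last step");
        }
        for (S s : states) {
            int bestPrevious = 0;
            double max = -1.0; // below any probability, so a predecessor is always chosen
            for (S p : states) {
                double candidate = probs.get(p)[step - 1]
                        * transitions.getOrDefault(p, Map.of()).getOrDefault(s, 0.0);
                if (candidate > max) {
                    max = candidate;
                    bestPrevious = p.ordinal();
                }
            }
            probs.get(s)[step] = max * emission(s, observations.get(step));
            previous.get(s)[step] = bestPrevious;
        }
        step++;
    }

    List<S> calculate() {
        if (result == null) {
            while (step < observations.size()) {
                nextStep(); // resumes wherever manual stepping stopped
            }
            int last = observations.size() - 1;
            S state = states[0];
            for (S s : states) {
                if (probs.get(s)[last] > probs.get(state)[last]) {
                    state = s;
                }
            }
            List<S> path = new ArrayList<>();
            for (int t = last; t >= 0; t--) {
                path.add(state);
                if (t > 0) {
                    state = states[previous.get(state)[t]];
                }
            }
            Collections.reverse(path);
            result = List.copyOf(path);
        }
        return result;
    }

    private double emission(S state, T observation) {
        return emissions.getOrDefault(state, Map.of()).getOrDefault(observation, 0.0);
    }
}
```

With the wikipediaSample numbers, calling nextStep() once and then calculate() should yield [HEALTHY, HEALTHY, FEVER], and a second calculate() returns the same cached list.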
Votes: 2
Original page content provided by Code Review. Translation support provided by Tencent Cloud Xiaowei's IT-domain engine.
Original link:

https://codereview.stackexchange.com/questions/184896
