
ChaCha20 implementation (based on RFC 7539)

Code Review user
Asked on 2017-07-24 20:00:54
1 answer, 1.4K views, 0 followers, 4 votes

I'm rolling my own cryptography to get a better understanding of the subject, and came up with the ChaCha20 implementation pasted below.

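I verified my output against the test vectors listed in the RFC, as well as against the Bouncy Castle library (assuming it is correct), and so far everything looks accurate. I have also optimized things as much as I could without resorting to SIMD instructions or parallel processing; my tests show it running 35%-45% faster than BouncyCastle, which makes me feel good about the code.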
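The RFC's test vectors can also be run one primitive at a time. For example, RFC 7539 §2.1.1 gives a single quarter-round vector. The sketch below checks it with a standalone quarter round. The class name QuarterRoundCheck is hypothetical, and plain shifts stand in for the RotateLeft extension method, which the post does not show.

```csharp
public static class QuarterRoundCheck
{
    // 32-bit left rotation, valid for counts 1..31 (only 16, 12, 8 and 7 are used here).
    private static uint Rotl(uint value, int count) => (value << count) | (value >> (32 - count));

    // The ChaCha quarter round from RFC 7539 §2.1: add, xor, rotate by 16, 12, 8 and 7 bits.
    public static void QuarterRound(ref uint a, ref uint b, ref uint c, ref uint d)
    {
        unchecked
        {
            a += b; d ^= a; d = Rotl(d, 16);
            c += d; b ^= c; b = Rotl(b, 12);
            a += b; d ^= a; d = Rotl(d, 8);
            c += d; b ^= c; b = Rotl(b, 7);
        }
    }

    // Runs the RFC 7539 §2.1.1 test vector and reports whether all four output words match.
    public static bool MatchesRfcVector()
    {
        uint a = 0x11111111, b = 0x01020304, c = 0x9b8d6f43, d = 0x01234567;
        QuarterRound(ref a, ref b, ref c, ref d);
        return a == 0xea2a92f4U && b == 0xcb1cf8ceU && c == 0x4581472eU && d == 0x5881c4bbU;
    }
}
```

The same pattern, with fixed inputs and known outputs taken from the RFC, applies to the block function vector in §2.3.2.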

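Now that I've run out of ideas, I figure it's time to reach out to the community for help finding any remaining accuracy or performance issues that still need fixing.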

Usage:

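// The 12-argument constructor sets state words 4..15; words 0..3 default to DEFAULT_STATE0..3.
var chacha = new ChaCha20(
    // key: state words 4..11
    0X03020100U,
    0X07060504U,
    0X0B0A0908U,
    0X0F0E0D0CU,
    0X13121110U,
    0X17161514U,
    0X1B1A1918U,
    0X1F1E1D1CU,
    // block counter and nonce: state words 12..15 (all zero here)
    0X00000000U,
    0X00000000U,
    0X00000000U,
    0X00000000U
);

// encrypt or decrypt a file in place (its contents are overwritten)
chacha.Transform(@"C:\Temp\SomeFile.txt");

// encrypt or decrypt an array in place
var data = new byte[128];

chacha.Transform(data);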
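Judging from Transform, which derives the starting block counter from m_stateC and m_stateD, and BlockRound, which overwrites those two words for every block, the twelve arguments appear to be eight key words followed by a two-word block counter and a two-word nonce. RFC 7539 reads all of these as little-endian 32-bit words. Under that assumption, a hypothetical helper (ChaChaWords.FromKeyAndNonce is not part of the posted code) could derive the twelve words from raw bytes:

```csharp
using System;
using System.Buffers.Binary;

public static class ChaChaWords
{
    // Hypothetical helper. Assumption: the 12 constructor arguments are
    // key (8 words), then counter and nonce (4 words), all little-endian.
    public static uint[] FromKeyAndNonce(byte[] key, byte[] counterAndNonce)
    {
        if (key == null || key.Length != 32)
            throw new ArgumentException("Key must be 32 bytes.", nameof(key));
        if (counterAndNonce == null || counterAndNonce.Length != 16)
            throw new ArgumentException("Counter and nonce must be 16 bytes.", nameof(counterAndNonce));

        var words = new uint[12];
        for (int i = 0; i < 8; i++)   // state words 4..11: the key
            words[i] = BinaryPrimitives.ReadUInt32LittleEndian(key.AsSpan(i * 4, 4));
        for (int i = 0; i < 4; i++)   // state words 12..15: counter and nonce
            words[8 + i] = BinaryPrimitives.ReadUInt32LittleEndian(counterAndNonce.AsSpan(i * 4, 4));
        return words;
    }
}
```

With key[i] = (byte)i and an all-zero second argument, it returns 0x03020100 through 0x1F1E1D1C followed by four zero words, which are the same values passed to the constructor above.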

Code:

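using System;
using System.IO;
using System.Runtime.CompilerServices;

// ToUInt64, CleaveHigh, CleaveLow, VectorXor, GetBytes, ReverseBytes and RotateLeft are
// extension methods from the asker's own code base; they are not included in the post.

/// <remarks>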
/// https://cr.yp.to/chacha/chacha-20080128.pdf
/// https://cr.yp.to/snuffle/spec.pdf
/// https://eprint.iacr.org/2013/759.pdf
/// https://www.rfc-editor.org/rfc/rfc7539
/// http://loup-vaillant.fr/tutorials/chacha20-design
/// </remarks>
public class ChaCha20
{
    private const int BLOCK_SIZE_IN_BYTES = (STATE_SIZE_IN_BYTES * 16);
    private const int STATE_SIZE_IN_BYTES = sizeof(uint);

    [CLSCompliant(false)]
    public const uint DEFAULT_STATE0 = 0x61707865U;
    [CLSCompliant(false)]
    public const uint DEFAULT_STATE1 = 0x3320646EU;
    [CLSCompliant(false)]
    public const uint DEFAULT_STATE2 = 0x79622D32U;
    [CLSCompliant(false)]
    public const uint DEFAULT_STATE3 = 0x6B206574U;

    [CLSCompliant(false)]
    protected readonly uint m_state0, m_state1, m_state2, m_state3,
                            m_state4, m_state5, m_state6, m_state7,
                            m_state8, m_state9, m_stateA, m_stateB,
                            m_stateC, m_stateD, m_stateE, m_stateF;

    /// <summary>
    /// Initializes a new instance of the <see cref="ChaCha20"/> class to the initial state indicated by sixteen 32-bit unsigned integers.
    /// </summary>
    [CLSCompliant(false)]
    public ChaCha20(
        uint state0, uint state1, uint state2, uint state3,
        uint state4, uint state5, uint state6, uint state7,
        uint state8, uint state9, uint stateA, uint stateB,
        uint stateC, uint stateD, uint stateE, uint stateF
    ) {
        m_state0 = state0;
        m_state1 = state1;
        m_state2 = state2;
        m_state3 = state3;
        m_state4 = state4;
        m_state5 = state5;
        m_state6 = state6;
        m_state7 = state7;
        m_state8 = state8;
        m_state9 = state9;
        m_stateA = stateA;
        m_stateB = stateB;
        m_stateC = stateC;
        m_stateD = stateD;
        m_stateE = stateE;
        m_stateF = stateF;
    }
    /// <summary>
    /// Initializes a new instance of the <see cref="ChaCha20"/> class to the initial state indicated by twelve 32-bit unsigned integers.
    /// </summary>
    [CLSCompliant(false)]
    public ChaCha20(
        uint state4, uint state5, uint state6, uint state7,
        uint state8, uint state9, uint stateA, uint stateB,
        uint stateC, uint stateD, uint stateE, uint stateF
    ) : this(
        DEFAULT_STATE0, DEFAULT_STATE1, DEFAULT_STATE2, DEFAULT_STATE3,
        state4, state5, state6, state7,
        state8, state9, stateA, stateB,
        stateC, stateD, stateE, stateF
    ) { }

    public void Transform(Stream source, Stream destination) {
        var dataBuffer = new byte[BLOCK_SIZE_IN_BYTES];
        var keyStreamBuffer = new byte[BLOCK_SIZE_IN_BYTES];
        var keyStreamPosition = m_stateD.ToUInt64(m_stateC);
        var numBytesRead = 0;

        while (unchecked(BLOCK_SIZE_IN_BYTES - 1) < (numBytesRead = source.Read(dataBuffer, 0, BLOCK_SIZE_IN_BYTES))) {
            if (source == destination) {
                destination.Position -= BLOCK_SIZE_IN_BYTES;
            }

            BlockRound(this, keyStreamPosition++, keyStreamBuffer, 0UL); // get next key stream chunk
            dataBuffer.VectorXor(keyStreamBuffer, BLOCK_SIZE_IN_BYTES, 0UL, 0UL); // xor data with key stream
            destination.Write(dataBuffer, 0, BLOCK_SIZE_IN_BYTES); // write transformed data to destination
        }

        if (numBytesRead != 0) {
            if (source == destination) {
                destination.Position -= numBytesRead;
            }

            BlockRound(this, keyStreamPosition++, keyStreamBuffer, 0UL); // get next key stream chunk
            dataBuffer.VectorXor(keyStreamBuffer, unchecked((ulong)numBytesRead), 0UL, 0UL); // xor data with key stream
            destination.Write(dataBuffer, 0, numBytesRead); // write transformed data to destination
        }
    }
    public void Transform(Stream stream) {
        Transform(stream, stream);
    }
    public void Transform(byte[] data) {
        using (var memoryStream = new MemoryStream(data, true)) {
            Transform(memoryStream);
        }
    }
    public void Transform(string fileName) {
        using (var fileStream = new FileStream(fileName, FileMode.Open, FileAccess.ReadWrite, FileShare.None)) {
            Transform(fileStream);
        }
    }

    /// <summary>
    /// Fills an array of bytes with the key stream calculated from the specified <see cref="ChaCha20"/> instance and iv.
    /// </summary>
    [CLSCompliant(false)]
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    public static void BlockRound(ChaCha20 chacha20, ulong iv, byte[] destination, ulong destinationOffset) {
        var counterLow = checked((uint)iv.CleaveHigh());
        var counterHigh = checked((uint)iv.CleaveLow());

        var tState0 = chacha20.m_state0;
        var tState1 = chacha20.m_state1;
        var tState2 = chacha20.m_state2;
        var tState3 = chacha20.m_state3;
        var tState4 = chacha20.m_state4;
        var tState5 = chacha20.m_state5;
        var tState6 = chacha20.m_state6;
        var tState7 = chacha20.m_state7;
        var tState8 = chacha20.m_state8;
        var tState9 = chacha20.m_state9;
        var tStateA = chacha20.m_stateA;
        var tStateB = chacha20.m_stateB;
        var tStateC = counterLow;
        var tStateD = counterHigh;
        var tStateE = chacha20.m_stateE;
        var tStateF = chacha20.m_stateF;

        for (var i = 0; i < 10; i++) {
            DoubleRound(
                ref tState0, ref tState1, ref tState2, ref tState3,
                ref tState4, ref tState5, ref tState6, ref tState7,
                ref tState8, ref tState9, ref tStateA, ref tStateB,
                ref tStateC, ref tStateD, ref tStateE, ref tStateF
            );
        }

        unchecked {
            tState0 += chacha20.m_state0;
            tState1 += chacha20.m_state1;
            tState2 += chacha20.m_state2;
            tState3 += chacha20.m_state3;
            tState4 += chacha20.m_state4;
            tState5 += chacha20.m_state5;
            tState6 += chacha20.m_state6;
            tState7 += chacha20.m_state7;
            tState8 += chacha20.m_state8;
            tState9 += chacha20.m_state9;
            tStateA += chacha20.m_stateA;
            tStateB += chacha20.m_stateB;
            tStateC += counterLow;
            tStateD += counterHigh;
            tStateE += chacha20.m_stateE;
            tStateF += chacha20.m_stateF;
        }

        if (BitConverter.IsLittleEndian) {
            tState0.GetBytes(destination, (destinationOffset + (STATE_SIZE_IN_BYTES * 0)));
            tState1.GetBytes(destination, (destinationOffset + (STATE_SIZE_IN_BYTES * 1)));
            tState2.GetBytes(destination, (destinationOffset + (STATE_SIZE_IN_BYTES * 2)));
            tState3.GetBytes(destination, (destinationOffset + (STATE_SIZE_IN_BYTES * 3)));
            tState4.GetBytes(destination, (destinationOffset + (STATE_SIZE_IN_BYTES * 4)));
            tState5.GetBytes(destination, (destinationOffset + (STATE_SIZE_IN_BYTES * 5)));
            tState6.GetBytes(destination, (destinationOffset + (STATE_SIZE_IN_BYTES * 6)));
            tState7.GetBytes(destination, (destinationOffset + (STATE_SIZE_IN_BYTES * 7)));
            tState8.GetBytes(destination, (destinationOffset + (STATE_SIZE_IN_BYTES * 8)));
            tState9.GetBytes(destination, (destinationOffset + (STATE_SIZE_IN_BYTES * 9)));
            tStateA.GetBytes(destination, (destinationOffset + (STATE_SIZE_IN_BYTES * 10)));
            tStateB.GetBytes(destination, (destinationOffset + (STATE_SIZE_IN_BYTES * 11)));
            tStateC.GetBytes(destination, (destinationOffset + (STATE_SIZE_IN_BYTES * 12)));
            tStateD.GetBytes(destination, (destinationOffset + (STATE_SIZE_IN_BYTES * 13)));
            tStateE.GetBytes(destination, (destinationOffset + (STATE_SIZE_IN_BYTES * 14)));
            tStateF.GetBytes(destination, (destinationOffset + (STATE_SIZE_IN_BYTES * 15)));
        }
        else {
            tState0.ReverseBytes().GetBytes(destination, (destinationOffset + (STATE_SIZE_IN_BYTES * 0)));
            tState1.ReverseBytes().GetBytes(destination, (destinationOffset + (STATE_SIZE_IN_BYTES * 1)));
            tState2.ReverseBytes().GetBytes(destination, (destinationOffset + (STATE_SIZE_IN_BYTES * 2)));
            tState3.ReverseBytes().GetBytes(destination, (destinationOffset + (STATE_SIZE_IN_BYTES * 3)));
            tState4.ReverseBytes().GetBytes(destination, (destinationOffset + (STATE_SIZE_IN_BYTES * 4)));
            tState5.ReverseBytes().GetBytes(destination, (destinationOffset + (STATE_SIZE_IN_BYTES * 5)));
            tState6.ReverseBytes().GetBytes(destination, (destinationOffset + (STATE_SIZE_IN_BYTES * 6)));
            tState7.ReverseBytes().GetBytes(destination, (destinationOffset + (STATE_SIZE_IN_BYTES * 7)));
            tState8.ReverseBytes().GetBytes(destination, (destinationOffset + (STATE_SIZE_IN_BYTES * 8)));
            tState9.ReverseBytes().GetBytes(destination, (destinationOffset + (STATE_SIZE_IN_BYTES * 9)));
            tStateA.ReverseBytes().GetBytes(destination, (destinationOffset + (STATE_SIZE_IN_BYTES * 10)));
            tStateB.ReverseBytes().GetBytes(destination, (destinationOffset + (STATE_SIZE_IN_BYTES * 11)));
            tStateC.ReverseBytes().GetBytes(destination, (destinationOffset + (STATE_SIZE_IN_BYTES * 12)));
            tStateD.ReverseBytes().GetBytes(destination, (destinationOffset + (STATE_SIZE_IN_BYTES * 13)));
            tStateE.ReverseBytes().GetBytes(destination, (destinationOffset + (STATE_SIZE_IN_BYTES * 14)));
            tStateF.ReverseBytes().GetBytes(destination, (destinationOffset + (STATE_SIZE_IN_BYTES * 15)));
        }
    }
    /// <summary>
    /// Executes eight QuarterRound operations (four "column rounds" followed by four "diagonal rounds") on the sixteen specified state words.
    /// </summary>
    [CLSCompliant(false)]
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    public static void DoubleRound(
        ref uint state0, ref uint state1, ref uint state2, ref uint state3,
        ref uint state4, ref uint state5, ref uint state6, ref uint state7,
        ref uint state8, ref uint state9, ref uint stateA, ref uint stateB,
        ref uint stateC, ref uint stateD, ref uint stateE, ref uint stateF
    ) {
        QuarterRound(ref state0, ref state4, ref state8, ref stateC);
        QuarterRound(ref state1, ref state5, ref state9, ref stateD);
        QuarterRound(ref state2, ref state6, ref stateA, ref stateE);
        QuarterRound(ref state3, ref state7, ref stateB, ref stateF);
        QuarterRound(ref state0, ref state5, ref stateA, ref stateF);
        QuarterRound(ref state1, ref state6, ref stateB, ref stateC);
        QuarterRound(ref state2, ref state7, ref state8, ref stateD);
        QuarterRound(ref state3, ref state4, ref state9, ref stateE);
    }
    /// <summary>
    /// Executes the basic operation of the ChaCha20 algorithm which "mixes" four state variables per invocation.
    /// </summary>
    [CLSCompliant(false)]
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    public static void QuarterRound(ref uint a, ref uint b, ref uint c, ref uint d) {
        d = (d ^= unchecked(a += b)).RotateLeft(16);
        b = (b ^= unchecked(c += d)).RotateLeft(12);
        d = (d ^= unchecked(a += b)).RotateLeft(8);
        b = (b ^= unchecked(c += d)).RotateLeft(7);
    }
}

1 Answer

Code Review user

Answered on 2017-07-25 09:26:48

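Short answer: use arrays instead of many parameters, and don't try to beat the compiler with your own optimizations. It is your friend; your job is only to help it.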

Let's assume performance is already good enough and talk about readability.

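First of all, why isn't the class sealed, given that it has protected fields? I don't see any extension point in your class. Let's mark it sealed, change protected to private, and open it up later if you ever need to.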

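I prefer public/internal/protected/private ordering, but the reverse is also common and fine; just pick a sensible order and stick to it. I also don't like SNAKE_CASE constants, but for private const fields that's no problem; just don't use it for public ones! Also: why are those fields public at all?!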

A method with 16 parameters is a nightmare to write, test and call. Fortunately, we have the right data structure for this: an array.

using System;
using System.Linq;

public sealed class ChaCha20
{
    private const int NUMBER_OF_STATES = 16;
    private const int MINIMUM_NUMBER_OF_USER_DEFINED_STATES = NUMBER_OF_STATES - 4;
    private const int STATE_SIZE_IN_BYTES = sizeof(uint);
    private const int BLOCK_SIZE_IN_BYTES = STATE_SIZE_IN_BYTES * NUMBER_OF_STATES;

    // The four "expand 32-byte k" constants, used when the caller supplies only 12 words.
    private static readonly uint[] DEFAULT_STATES
        = new uint[] { 0x61707865U, 0x3320646EU, 0x79622D32U, 0x6B206574U };

    private readonly uint[] _states;

    public ChaCha20(params uint[] states)
    {
        if (states == null)
            throw new ArgumentNullException(nameof(states));

        if (states.Length == NUMBER_OF_STATES)
            _states = (uint[])states.Clone(); // copy, so the caller cannot change our state later
        else if (states.Length == MINIMUM_NUMBER_OF_USER_DEFINED_STATES)
            _states = DEFAULT_STATES.Concat(states).ToArray(); // Concat returns IEnumerable<uint>
        else
            throw new ArgumentException("Invalid number of...");
    }
}

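There is only one constructor, it is easy to write, and you can call it either with an array (if the initialization values are stored somewhere) or with the previous syntax (see params).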

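Now, if you apply the same concept everywhere, you'll see the code shrink and simplify considerably (I'd even say it will be much faster than it is now).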

Take this line, for example:

tState0.GetBytes(destination, (destinationOffset + (STATE_SIZE_IN_BYTES * 0)));

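It is repeated 16 times (plus the duplication for the little-endian and big-endian cases), but with an array it can be reduced to a simple for loop. There is less room for mistakes, and the same applies to the other functions: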

// one loop replaces the sixteen hand-written calls
for (int i = 0; i < NUMBER_OF_STATES; ++i)
    states[i].GetBytes(destination, destinationOffset + (ulong)(STATE_SIZE_IN_BYTES * i));

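Is a loop bad (for performance)? The compiler will decide whether it is better to unroll the loop for you (the number of iterations is known at compile time).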
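The endianness branch can go away as well. RFC 7539 serializes each state word in little-endian order, and on .NET Core 2.1 or later BinaryPrimitives.WriteUInt32LittleEndian (System.Buffers.Binary) writes little-endian bytes whatever the host's byte order. The sketch below assumes that API in place of the asker's GetBytes and ReverseBytes extensions; the class name KeyStreamSerializer and the int offset are illustrative choices, not part of the posted code.

```csharp
using System;
using System.Buffers.Binary;

public static class KeyStreamSerializer
{
    private const int NUMBER_OF_STATES = 16;
    private const int STATE_SIZE_IN_BYTES = sizeof(uint);

    // Writes the 16 state words as little-endian bytes (RFC 7539 serialization).
    // BinaryPrimitives handles the byte order on little- and big-endian hosts alike,
    // so no BitConverter.IsLittleEndian branch is needed.
    public static void WriteBlock(uint[] states, byte[] destination, int destinationOffset)
    {
        // AsSpan throws if the destination is too short for a whole 64-byte block.
        var span = destination.AsSpan(destinationOffset, NUMBER_OF_STATES * STATE_SIZE_IN_BYTES);
        for (int i = 0; i < NUMBER_OF_STATES; ++i)
            BinaryPrimitives.WriteUInt32LittleEndian(span.Slice(i * STATE_SIZE_IN_BYTES), states[i]);
    }
}
```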

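You're putting [MethodImpl(MethodImplOptions.AggressiveInlining)] on your methods. Don't. The compiler usually knows better than we do whether a method should be inlined and what the consequences are. For example, inlining a call to a method that contains a long loop (which itself calls other methods) may have the opposite effect. Or maybe not? Unless you run serious benchmarks on different CPU architectures, it's a big guess. Let the compiler do its job.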

This line:

d = (d ^= unchecked(a += b)).RotateLeft(16);

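is too complicated! It does too many things at once. Concise code is not a synonym for optimized code. Write it out step by step and you'll even notice that the expression contains a useless assignment (d ^=), which hopefully the compiler optimizes away.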

unchecked
{ 
    a += b;
}

d = (d ^ a).RotateLeft(16);

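Is that so bad? Move it into a separate method; again, the compiler will inline it for you. Then, once you've finished converting to arrays, you probably won't even have four such calls, because you'll have a loop (which the compiler can optimize as it sees fit).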
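Taken together, the array-based structure described above might look like the following sketch: a 16-word state, the quarter round written out step by step in its own method, and the column/diagonal index pattern driven from a table inside a loop. It assumes .NET Core 3.0 or later for System.Numerics.BitOperations.RotateLeft, which stands in for the asker's RotateLeft extension. The names ChaChaBlock and Rounds are hypothetical, and the per-block counter handling from Transform is left out.

```csharp
using System.Numerics;

public static class ChaChaBlock
{
    // Index quadruples for one double round: four column rounds, then four diagonal rounds.
    private static readonly int[,] Rounds =
    {
        { 0, 4,  8, 12 }, { 1, 5,  9, 13 }, { 2, 6, 10, 14 }, { 3, 7, 11, 15 },
        { 0, 5, 10, 15 }, { 1, 6, 11, 12 }, { 2, 7,  8, 13 }, { 3, 4,  9, 14 },
    };

    // One quarter round on four words of the working state, one step per statement.
    private static void QuarterRound(uint[] x, int a, int b, int c, int d)
    {
        unchecked
        {
            x[a] += x[b]; x[d] = BitOperations.RotateLeft(x[d] ^ x[a], 16);
            x[c] += x[d]; x[b] = BitOperations.RotateLeft(x[b] ^ x[c], 12);
            x[a] += x[b]; x[d] = BitOperations.RotateLeft(x[d] ^ x[a], 8);
            x[c] += x[d]; x[b] = BitOperations.RotateLeft(x[b] ^ x[c], 7);
        }
    }

    // Computes the 16 output words of one block from a 16-word input state.
    public static uint[] Block(uint[] input)
    {
        var x = (uint[])input.Clone();               // the input itself is left untouched
        for (int i = 0; i < 10; i++)                 // 10 double rounds = 20 rounds
            for (int r = 0; r < 8; r++)
                QuarterRound(x, Rounds[r, 0], Rounds[r, 1], Rounds[r, 2], Rounds[r, 3]);

        for (int i = 0; i < 16; i++)
            x[i] = unchecked(x[i] + input[i]);       // add the original input words back in
        return x;
    }
}
```

Fed the input state from RFC 7539 §2.3.2 (key bytes 00 through 1f, block counter 1, nonce 00:00:00:09:00:00:00:4a:00:00:00:00), it should reproduce that section's output, whose first word is 0xe4e7f110. Whether the loops get unrolled or QuarterRound gets inlined is left to the runtime, in line with the advice above.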

Votes: 2
Original page content provided by Code Review. Translation support provided by Tencent Cloud Xiaowei's IT-domain engine.
Original link:

https://codereview.stackexchange.com/questions/171065
