首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >锈蚀中的快速近似sin/cos函数

锈蚀中的快速近似sin/cos函数
EN

Code Review用户
提问于 2023-03-22 18:43:16
回答 1查看 169关注 0票数 3

在过去的一个月左右,我一直试图创建一个非常快,平台不可知论,自动矢量化的sin/cos函数的乐趣。我最初从斯利弗氏快速罪函数开始,并将它从库中分离出来,使其更易于扩展,并进一步优化了它。

在我的优化过程中,我一直在密切关注在多个平台上生成的程序集和llvm,以确保某些平台上的进度不会导致其他平台上的回归。现在,这比SLEEF和Intel SVML更快,然而,它们通常用于更精确的结果。

下面是当前的代码:

代码语言:javascript
复制
#![feature(core_intrinsics)]
#![no_std]

use core::intrinsics::*;

/// Inputs valid between [-2^23, 2^23].
/// Precision can set between 0 and 3, with 0 being the fastest and least
/// precise, and 3 being the slowest and most precise.
/// If COS is set to true, the period is offset by PI/2.
///
/// As the inputs get further from 0, the accuracy gets continuously worse
/// due to nature of the fast range reduction.
///
/// This function should auto vectorize under LLVM with opt-level=3.
///
/// The coefficient constants were derived from the constants defined here:
/// https://publik-void.github.io/sin-cos-approximations/#_cos_abs_error_minimized_degree_6
#[inline]
pub unsafe fn sin_fast_approx(x: f32) -> f32 {
    let pi_multiples = fadd_fast(
        fmul_fast(x, core::f32::consts::FRAC_1_PI),
        if COS { 0.0_f32 } else { -0.5_f32 },
    );
    let rounded_multiples = nearbyintf32(pi_multiples);
    let pi_fraction = pi_multiples - rounded_multiples;
    let fraction_squared = pi_fraction * pi_fraction;

    let coeffs = {
        const COEFF_TABLE: [f32; 14] = [
            -4.0_f32,
            0.9719952_f32,
            3.5838444_f32,
            -4.8911867_f32,
            0.99940324_f32,
            -1.2221271_f32,
            4.0412836_f32,
            -4.933938_f32,
            0.9999933_f32,
            0.2196968_f32,
            -1.3318802_f32,
            4.058412_f32,
            -4.934793_f32,
            0.99999994_f32,
        ];

        let shifted_degree = PRECISION + 1;
        let slice_start = (((shifted_degree * shifted_degree) + shifted_degree) / 2) - 1;
        let slice_end = slice_start + PRECISION + 2;
        &COEFF_TABLE[slice_start..slice_end]
    };

    let mut output = coeffs[0];
    for i in 1..coeffs.len() {
        output = fadd_fast(fmul_fast(fraction_squared, output), coeffs[i]);
    }

    let parity_sign = (rounded_multiples.to_int_unchecked::() as u32) << 31_u32;
    f32::from_bits(output.to_bits() ^ parity_sign)
}

编译器资源管理器上的代码和内置基准测试 (请记住,用于运行基准测试的硬件在测试之间并不一致。我建议查看llvm来检查当前的体系结构以解释结果。)

我一直在关注的主要问题是分离的截断和舍入,如果可能的话,尝试将它们结合起来。我试着添加和减去2^23,在这两者之间移动奇偶校验位,但这在英特尔处理器上速度更快。我试着把操作的各个部分结合起来,但这样做的速度要慢一些。

我能够做的是,使用特定于平台的指令,将轮和int转换合并到一个vcvtps2pd中,并将其转换回一个vcvtpd2ps,但我似乎无法让LLVM在任何情况下生成它。如果我能够生成这个,对于拥有6-8个周期延迟的vroundps指令的英特尔CPU来说,这将是一个相当大的性能提升。

有什么方法可以使这个更快,更准确(而不牺牲性能),或者更干净?

EN

回答 1

Code Review用户

回答已采纳

发布于 2023-03-25 18:31:09

我会听一些大人物的话做这个老派的。阿波罗11号制导计算机逼近正弦 (其作者包括玛格丽特·汉密尔顿)计算了一个三项多项式,.7853134·x - .3216147·x³+ 0.036551·x⁵(接近但与泰勒级数近似不完全相同),它运行在一台比你的USB端口更强大的计算机上,并且足够精确地将宇宙飞船送上月球。

但那早在我的时代之前。至少她有个FPU!在80年代,人们常用的方法是保持一个1/8的单位圆的表格,然后旋转或反射查找该表的实际sin或cos值,使用差示和/或倒转标牌。

对于kick,我继续并实现了一个类似的解决方案,它在第一个octant中保留一个sin值的const表,使用三角函数恒等式将所有其他值映射到第一个octant,并将圆圈映射到表中最近的离散条目。

我怀疑这是否真的会更快,因为它有不可预测的分支,而一个更大的、具有可接受的精确性的表会导致大量缓存丢失。此外,Rust不支持conststatic计算中的浮点运算,因此您需要在源中生成一个非常大的数组表达式,或者使用lazy_static。就像我说的为了踢。

不过,从理论上讲,由于8:1映射提取了正弦位和两位精度,所以您可以使用“仅”16 MiB的内存“只”使用“只”4 Mi条目来覆盖任何MiB的24位尾数限制。想知道玛格丽特·汉密尔顿会怎么说。

在没有FPU的计算机上,程序员实际上所做的就是计算一个只使用整数数学的固定精度表。如果你能做到这一点,你实际上可以把它变成一个const fn

代码语言:javascript
复制
use std::f64::consts::TAU;

const SIN_TABLE_ENTRIES : usize = 32;
const SIN_OCTANT : [f64; SIN_TABLE_ENTRIES + 1] = [ 
    0.0,  0.024541228522912288, 0.049067674327418015, 0.07356456359966743,
    0.0980171403295606, 0.1224106751992162, 0.14673047445536175, 0.17096188876030122,
    0.19509032201612825, 0.2191012401568698, 0.24298017990326387, 0.26671275747489837,
    0.29028467725446233, 0.3136817403988915, 0.33688985339222005, 0.3598950365349881,
    0.3826834323650898, 0.40524131400498986, 0.4275550934302821, 0.44961132965460654,
    0.47139673682599764, 0.49289819222978404, 0.5141027441932217, 0.5349976198870972,
    0.5555702330196022, 0.5758081914178453, 0.5956993044924334, 0.6152315905806268,
    0.6343932841636455, 0.6531728429537768, 0.6715589548470183, 0.6895405447370668,
    0.0 // THis entry is a workaround for calculating a modulus of SIN_TABLE_ENTRIES instead of 0,
 ];


pub fn sin_approx(theta : f64) -> f64 {
    let approx_gradient = (theta*8.0*SIN_TABLE_ENTRIES as f64/TAU).round() as i64 % (SIN_TABLE_ENTRIES*8) as i64;
    let normalized = if approx_gradient < 0 {
            (approx_gradient + 8*SIN_TABLE_ENTRIES as i64) as usize 
        } else {
            approx_gradient as usize
        };

    if normalized < SIN_TABLE_ENTRIES {
        SIN_OCTANT[normalized]
    } else if normalized < 2*SIN_TABLE_ENTRIES {
        let complement = SIN_OCTANT[2*SIN_TABLE_ENTRIES - normalized];
        (1.0 - complement*complement).sqrt()
    } else if normalized < 3*SIN_TABLE_ENTRIES {
        let complement = SIN_OCTANT[normalized-2*SIN_TABLE_ENTRIES];
        (1.0 - complement*complement).sqrt()
    } else if normalized < 4*SIN_TABLE_ENTRIES {
        SIN_OCTANT[4*SIN_TABLE_ENTRIES - normalized]
    } else if normalized < 5*SIN_TABLE_ENTRIES {
        -SIN_OCTANT[normalized - 4*SIN_TABLE_ENTRIES]
    } else if normalized < 6*SIN_TABLE_ENTRIES {
        let complement = SIN_OCTANT[6*SIN_TABLE_ENTRIES-normalized];
        -(1.0 - complement*complement).sqrt()
    } else if normalized < 7*SIN_TABLE_ENTRIES {
        let complement = SIN_OCTANT[normalized-6*SIN_TABLE_ENTRIES];
        -(1.0 - complement*complement).sqrt()
    } else if normalized < 8*SIN_TABLE_ENTRIES {
        -SIN_OCTANT[8*SIN_TABLE_ENTRIES-normalized]
    } else {
        panic!("The normalized index was not normalized!")
    }
}

还有一个快速测试:

代码语言:javascript
复制
pub fn main() {
    (-24 as i32..=24).map(move|n|{n as f64 * TAU/12.0})
                     .map(move|x|{ ( x, x.sin(), sin_approx(x) ) })
                     .for_each(move|(theta, x1, x2)| {
                          println!("sin {} ≈ {} ≈ {}", theta, x1, x2); } )
}

更新

与实际的代码评审相比,这个答案更像是一种不同的解决方案,所以我希望您不会认为这个站点是典型的。我想调查一些关于你的答案的事情,最初没有时间去完成,我发现的是,LGTM,运送它。

我可能推荐的一件事是,在我看来,如果您将系数数组的计算写成类似于以match PRECISION结尾的_ => !unreachable()块的话,它会更干净。这使得某个人将PRECISION设置为5的错误很快就会失败,并在代码中说明了失败的原因。但是,已经有一条注释解释了PRECISION的允许值。让系数const和浮点数字的const数组成为可能也会更好,而不是使用浮点运算的const表达式。但是正如您所知道的,Rust阻止您使用外部函数的const属性。

票数 1
EN
页面原文内容由Code Review提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://codereview.stackexchange.com/questions/284119

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档