首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >如何解决一个二次丢番图方程组,有效地解锁Python?

如何解决一个二次丢番图方程组,有效地解锁Python?
EN

Stack Overflow用户
提问于 2022-01-25 17:36:22
回答 1查看 109关注 0票数 0

我正在研究/评估求解二次丢番图方程组的技术方法。我的具体问题可以归结为以下两个步骤:

  1. 加载一个包含元组[sqrt(s), sqrt(t), sqrt(u), s, t, u, t+u, t+u-s, t-s]行的Textfile,其中每个元素都是一个整数。下面给出了该文件的节录。
  2. 对于该文件中的每一行:搜索一个整数四倍[w,x,y,z],它解决以下方程组:[x^2-w^2=s][y^2-w^2=t][z^2-y^2=u][z^2-w^2=t+u][z^2-x^2=t+u-s][y^2-x^2=t-s]

下面是对Textfile的删减:

代码语言:javascript
复制
520, 533, 756, 270400, 284089, 571536, 855625, 585225, 13689
672, 680, 153, 451584, 462400, 23409, 485809, 34225, 10816
756, 765, 520, 571536, 585225, 270400, 855625, 284089, 13689
612, 740, 2688, 374544, 547600, 7225344, 7772944, 7398400, 173056
644, 725, 2040, 414736, 525625, 4161600, 4687225, 4272489, 110889

到目前为止,我尝试的是使用z3求解器,它编译和运行,但不幸的是速度缓慢:

代码语言:javascript
复制
import pandas as pd
import sys
from z3 import Ints, solve

def main() -> int:
    df = pd.read_csv('tuples.txt', header=None)
    
    tuples = df.to_numpy()

    x, y, z, w = Ints('x y z w')
    for row in tuples:
        s=int(row[3])
        t=int(row[4])
        u=int(row[5])
        solve(x*x-w*w==s, y*y-w*w==t, z*z-y*y==u, w!=0)

    return 0

if __name__ == '__main__':
    sys.exit(main())

对于用Python解决这种diophantine系统的任何替代方法(最佳实践),我都会非常满意。

EN

回答 1

Stack Overflow用户

发布于 2022-01-25 17:38:05

在Python中为您创建了相当庞大但非常快速的解决方案。它应该比任何做类似事情的数学代码或z3-solver代码更快地解决问题。当然,在预计算阶段之后,只完成一次,然后可以在多次运行时重新使用(它们将所有计算数据保存到缓存文件中)。

下面的解进行两次预计算。第一个需要几分钟,它预先计算2.7 GB文件,这是一个巨大的方块过滤器。这个大小是可调整的,并且可以更小。此文件只计算一次(除非您更改设置),并在每次运行时重新使用。这种预计算是单核的(但经过一些努力,我可以使它成为多核)。

第二,预计算需要更多的时间,这个是多核的,它使用所有的CPU核.这种预计算会产生相当小的文件,即使对于较大的params值也不到1GB。此预计算表存储所有可能的带有整数边的毕达哥拉斯直角三角形。

对所有小于limit大小的导管进行预计算。将当前设置limit = 100_000更改为更大的设置,在您的情况下可能是100万。如果这个表太小,那么它将找不到一些解决办法的大导管。预计算表也存储在磁盘上,并在每次运行时重复使用(不再计算)。

第二次预计算计算下列直角三角形表。它遍历所有可能的第一个整数导尿管A(不超过极限),并找到所有可能的第二个整数导尿管B(不超过极限),使得A^2 + B^2 = C^2,其中C也是整数。然后,对于每个A,它存储了一组满足这个方程的B。C不被存储,因为它可以很容易地从A和B中计算出来。

为了快速搜索B,我建立了两个过滤器。例如,我们有任何整数K0和K1。我们可以很容易地看到,如果X是一个正方形,那么X% K0就是一个正方形,X% K1也是。因此,我们可以构建一个大小为K0的表,如果它是正方形的,那么tableX % K0是1,否则是0。这为我们提供了一个快速筛选器,用于删除所有这些绝对是非平方的X(即tableX % K0为0)。第二种K1滤波器用于第二阶段的额外滤波。

下面的Python代码可以直接运行,不需要依赖,它会自动从GitHub中获取S文件并将其缓存在磁盘上。

在完成上述两次预计算之后,所有的数千s/t/u解都在1-2秒内计算出来。最后,所有解决方案都以JSON格式存储以文件stu_solutions.100000

查找到的几乎解决方案(具有非整数Z)可以通过命令转储:

cat stu_solutions.100000 | grep false

“找到的精确解决方案”(带有整数Z)可以通过命令转储:

cat stu_solutions.100000 | grep true

该文件的其余行包含有错误的解决方案(如果表对于它们来说太小),或者在找不到w、x、y时,包含零的解决方案。如果出现错误,您必须通过设置更大的limit = ...来构建更大的表(第二次预计算)。

有必要设定至少和Max(Sqrt(s), Sqrt(t))一样大的限制。但最好把它设置成大几倍。越大就越有可能找到所有可能的解决方案。极限最多需要尽可能大的w

要按照Python代码运行,您必须安装一次PIP包python -m pip install numba numpy requests

在网上试试!

代码语言:javascript
复制
numba = None
import numba

import json, multiprocessing, time, timeit, os, math, numpy as np

if numba is None:
    class NumbaInt:
        def __getitem__(self, key):
            return None
    class numba:
        uint8, uint16, uint32, int64, uint64 = [NumbaInt() for i in range(5)]
        def njit(*pargs, **nargs):
            return lambda f: f
        def prange(*pargs):
            return range(*pargs)
        class types:
            class Tuple:
                def __init__(self, *nargs, **pargs):
                    pass
                def __call__(self, *nargs, **pargs):
                    pass
        class objmode:
            def __init__(self, *pargs, **nargs):
                pass
            def __enter__(self):
                return self
            def __exit__(self, ext, exv, tb):
                pass

@numba.njit(cache = True, parallel = True)
def create_filters():
    Ks = [np.uint32(e) for e in [2 * 2 * 3 * 5 * 7 * 11 * 13,    17 * 19 * 23 * 29 * 31 * 37]]
    filts = []
    for i in range(len(Ks)):
        K = Ks[i]
        filt = np.zeros((K,), dtype = np.uint8)
        block = 1 << 25
        nblocks = (K + block - 1) // block
        for j0 in numba.prange(nblocks):
            j = j0 * block
            a = np.arange(j, min(j + block, K)).astype(np.uint64)
            a *= a; a %= K
            filt[a] = 1
        idxs = np.flatnonzero(filt).astype(np.uint32)
        filts.append((K, filt, idxs))
        print(f'filter {i} ratio', round(len(filts[-1][2]) / K, 4))
    return filts

@numba.njit('u2[:, :, :](u4, u4[:])', cache = True, parallel = True, locals = dict(
    t = numba.uint32, s = numba.uint32, i = numba.uint32, j = numba.uint32))
def filter_chain(K, ix):
    assert len(ix) < (1 << 16)
    ix_rev = np.full((K,), len(ix), dtype = np.uint16)
    for i, e in enumerate(ix):
        ix_rev[e] = i
    r = np.zeros((len(ix), K, 2), dtype = np.uint16)
    
    print('filter chain pre-computing...')
    
    for i in numba.prange(K):
        if i % 5000 == 0 or i + 1 >= K:
            with numba.objmode():
                print(f'{i}/{K}, ', end = '', flush = True)
        for j, x in enumerate(ix):
            t, s = i, x
            while True:
                s += 2 * t + 1; s %= K
                t += 1
                if ix_rev[s] < len(ix):
                    assert t - i < (1 << 16)
                    assert t - i < K
                    r[j, i, 0] = ix_rev[s]
                    r[j, i, 1] = np.uint16(t - i)
                    break
    
    print()
    
    return r

def filter_chain_create_load(K, ix):
    fname = f'filter_chain.{K}'
    if not os.path.exists(fname):
        r = filter_chain(K, ix)
        with open(fname, 'wb') as f:
            f.write(r.tobytes())
    with open(fname, 'rb') as f:
        return np.copy(np.frombuffer(f.read(), dtype = np.uint16).reshape(len(ix), K, 2))

@numba.njit(
    #'void(i8, i8, u4, u1[:], u4[:], u2[:, :, :], u4, u1[:])',
    numba.types.Tuple([numba.uint64[:], numba.uint32[:]])(
        numba.int64, numba.int64, numba.uint32, numba.uint8[:],
        numba.uint32[:], numba.uint16[:, :, :], numba.uint32, numba.uint8[:]),
    cache = True, parallel = True,
    locals = dict(x = numba.uint64, Atpos = numba.uint64, Btpos = numba.uint64, bpos = numba.uint64))
def create_table(limit, cpu_count, k0, f0, fi0, fc0, k1, f1):
    print('Computing tables...')
    
    def gen_squares_candidates_A(cnt, lim, off, t, K, f, fi, fc):
        mark = np.zeros((np.int64(K),), dtype = np.uint8)
        while True:
            start_s = np.int64((np.int64(off) + np.int64(t) ** 2) % K)
            tK = np.uint32(np.int64(t) % np.int64(K))
            if mark[tK]:
                return np.zeros((0,), dtype = np.uint32)
            mark[tK] = 1
            if f[start_s]:
                break
            t += 1
        j = np.int64(np.searchsorted(fi, start_s))
        assert fi[j] == start_s
        r = np.zeros((np.int64(cnt),), dtype = np.uint32)
        r[0] = t
        rpos = np.int64(1)
        tK = np.uint32(np.int64(t) % np.int64(K))
        while True:
            j, dt = fc[j, tK]
            t += dt
            tK += dt
            if tK >= np.uint32(K):
                tK -= np.uint32(K)
            if t >= lim:
                return r[:rpos]
            if np.int64(rpos) >= np.int64(r.shape[0]):
                r = np.concatenate((r, np.zeros_like(r)), axis = 0)
            assert rpos < len(r)
            r[rpos] = t
            rpos += 1
    
    def gen_squares(cnt, lim, off, t, K, f, fi, fc, k1, f1):
        def is_square(x):
            assert x >= 0
            if not f1[np.int64(x) % np.uint32(k1)]:
                return False
            root = np.uint64(math.sqrt(np.float64(x)) + 0.5)
            return root * root == x
        rA = gen_squares_candidates_A(cnt, lim, off, t, K, f, fi, fc)
        r = np.zeros_like(rA)
        rpos = np.int64(0)
        for t in rA:
            if not is_square(np.int64(off) + np.int64(t) ** 2):
                continue
            assert np.int64(rpos) < np.int64(r.shape[0])
            r[rpos] = t
            rpos += 1
        return r[:rpos]
    
    with numba.objmode(gtb = 'f8'):
        gtb = time.time()
    
    search_start = 2
    cnt_limit = max(1 << 4, round(pow(limit, 0.66)))
    
    nblocks2 = cpu_count * 8
    nblocks = nblocks2 * 64
    block = (limit + nblocks - 1) // nblocks
    
    At = np.zeros((limit + 1,), dtype = np.uint64)
    Bt = np.zeros((0,), dtype = np.uint32)
    Atpos, Btpos = search_start + 1, 0
    
    with numba.objmode(tb = 'f8'):
        tb = time.time()
    for iMblock in range(0, nblocks, nblocks2):
        cur_blocks = min(nblocks, iMblock + nblocks2) - iMblock
        As = np.zeros((cur_blocks, block), dtype = np.uint64)
        As_size = np.zeros((cur_blocks,), dtype = np.uint64)
        Bs = np.zeros((cur_blocks, 1 << 16,), dtype = np.uint32)
        Bs_size = np.zeros((cur_blocks,), dtype = np.uint64)
        for iblock in numba.prange(cur_blocks):
            iblock0 = iMblock + iblock
            begin, end = max(search_start, iblock0 * block), min(limit, (iblock0 + 1) * block)
            begin = min(begin, end)
            #a = np.zeros((block,), dtype = np.uint64)
            #b = np.zeros((1 << 10,), dtype = np.uint32)
            bpos = 0
            for ix, x in enumerate(range(begin, end)):
                s = gen_squares(cnt_limit, limit, x ** 2, search_start, k0, f0, fi0, fc0, k1, f1)
                assert not (np.int64(bpos) + np.int64(s.shape[0]) > np.int64(Bs[iblock].shape[0]))
                #while np.int64(bpos) + np.int64(s.shape[0]) > np.int64(b.shape[0]):
                #    b = np.concatenate((b, np.zeros_like(b)), axis = 0)
                bpos_end = bpos + s.shape[0]
                Bs[iblock, bpos : bpos_end] = s
                As[iblock, ix] = bpos_end
                bpos = bpos_end
            As_size[iblock] = end - begin
            Bs_size[iblock] = bpos
        for iblock, (cA, cB) in enumerate(zip(As, Bs)):
            cA = cA[:As_size[iblock]]
            cB = cB[:Bs_size[iblock]]
            assert Atpos + cA.shape[0] <= At.shape[0]
            prevA = At[Atpos - 1]
            for e in cA:
                At[Atpos] = prevA + e
                Atpos += 1
            #while np.int64(Btpos) + np.int64(cB.shape[0]) > np.int64(Bt.shape[0]):
                #Bt = np.concatenate((Bt, np.zeros_like(Bt)), axis = 0)
                #Bt = np.concatenate((Bt, np.zeros(Bt.shape, dtype = np.uint32)), axis = 0)
            #assert np.int64(Btpos) + np.int64(cB.shape[0]) <= np.int64(Bt.shape[0])
            #assert cB.shape[0] > 0
            #Bt[Btpos : Btpos + cB.shape[0]] = cB
            Bt = np.concatenate((Bt, cB))
            #Btpos += cB.shape[0]
            #assert At[Atpos - 1] == Btpos
            assert At[Atpos - 1] == Bt.shape[0]
        with numba.objmode(tim = 'f8'):
            tim = max(0.001, round(time.time() - tb, 3))
        print(f'{str(min(limit, (iMblock + cur_blocks) * block) >> 10).rjust(len(str(limit >> 10)))}/{limit >> 10} K, ELA',
            round(tim / 60.0, 1), 'min, ETA', round((nblocks - (iMblock + cur_blocks)) * (tim / (iMblock + cur_blocks)) / 60.0, 1), 'min')
    
    assert Atpos == At.shape[0]
    
    with numba.objmode(gtb = 'f8'):
        gtb = time.time() - gtb
    
    print(f'Tables sizes: A {Atpos}, B {Bt.shape[0]}')
    print('Time elapsed computing tables:', round(gtb / 60.0, 1), 'min')
    
    return At, Bt
    
def table_create_load(limit, *pargs):
    fnameA = f'right_triangles_table.A.{limit}'
    fnameB = f'right_triangles_table.B.{limit}'
    if not os.path.exists(fnameA) or not os.path.exists(fnameB):
        A, B = create_table(limit, *pargs)
        with open(fnameA, 'wb') as f:
            f.write(A.tobytes())
        with open(fnameB, 'wb') as f:
            f.write(B.tobytes())
        del A, B
    with open(fnameA, 'rb') as f:
        A = np.copy(np.frombuffer(f.read(), dtype = np.uint64))
        assert A.shape[0] == limit + 1, (fnameA, A.shape[0], limit + 1)
    with open(fnameB, 'rb') as f:
        B = np.copy(np.frombuffer(f.read(), dtype = np.uint32))
        assert A[-1] == B.shape[0], (fnameB, A[-1], B.shape[0])
    print(f'Table A size {A.shape[0]}, B size {B.shape[0]}')
    return A, B

def find_solutions(tA, tB, stu):
    def is_square(x):
        root = np.uint64(math.sqrt(np.float64(x)) + 0.5)
        return bool(root * root == x), int(root)
    
    assert tA[-1] == tB.shape[0]
    
    fname = f'stu_solutions.{tA.shape[0] - 1}'
    with open(fname, 'w', encoding = 'utf-8') as fout:
        for s, t, u in stu:
            s, t, u = map(int, (s, t, u))
            r = {'stu': [s, t, u]}
            if s + 1 >= tA.shape[0]:
                r['err'] = f's {s} exceeds table A len {tA.shape[0]}'
            elif t + 1 >= tA.shape[0]:
                r['err'] = f't {t} exceeds table A len {tA.shape[0]}'
            else:
                r['res'] = []
                bs = tB[tA[s] : tA[s + 1]]
                ts = tB[tA[t] : tA[t + 1]]
                for w in np.intersect1d(bs, ts):
                    w = int(w)
                    x2 = s ** 2 + w ** 2
                    y2 = t ** 2 + w ** 2
                    x_isq, x = is_square(x2)
                    assert x_isq, (s, t, u, w, x2)
                    y_isq, y = is_square(y2)
                    assert y_isq, (s, t, u, w, x2, y2)
                    z2 = u ** 2 + y2
                    z_isq, z = is_square(z2)
                    r['res'].append({
                        'w': w,
                        'x': x,
                        'y': y,
                        'z2': z2,
                        'z2_is_square': z_isq,
                        'z': z if z_isq else math.sqrt(z2),
                    })
            fout.write(json.dumps(r, ensure_ascii = False) + '\n')
    
    print(f'STU solutions written to {fname}')

def solve(limit):
    import requests
    
    filts = create_filters()
    fc0 = filter_chain_create_load(filts[0][0], filts[0][2])
    
    tA, tB = table_create_load(limit, multiprocessing.cpu_count(),
        filts[0][0], filts[0][1], filts[0][2], fc0, filts[1][0], filts[1][1])
    
    # https://github.com/Sultanow/pythagorean/blob/main/data/pythagorean_stu_Arty_.txt?raw=true
    ifname = 'pythagorean_stu_Arty_.txt'
    iurl = f'https://github.com/Sultanow/pythagorean/blob/main/data/{ifname}?raw=true'
    if not os.path.exists(ifname):
        print(f'Downloading: {iurl}')
        data = requests.get(iurl).content
        with open(ifname, 'wb') as f:
            f.write(data)
    stu = []
    with open(ifname, 'r', encoding = 'utf-8') as f:
        for line in f:
            if not line.strip():
                continue
            if 'elapsed' in line:
                continue
            s, t, u, *_ = eval(f'[{line}]')
            stu.append([s, t, u])
    print(f'Read {len(stu)} s/t/u tuples')
    find_solutions(tA, tB, stu)
    
def main():
    limit = 100_000
    solve(limit)

if __name__ == '__main__':
    main()

输出:

代码语言:javascript
复制
filter 0 ratio 0.0224
filter 1 ratio 0.0199
Table A size 100001, B size 371720
Read 27060 s/t/u tuples
STU solutions written to stu_solutions.100000

对于50K限制,所有找到的几乎-解(其中只有Z不是整数)的例子:

代码语言:javascript
复制
{"stu": [3528, 37128, 10175], "res": [{"w": 31654, "x": 31850, "y": 48790, "z2": 2483994725, "z2_is_square": false, "z": 49839.69025786577}]}
{"stu": [7700, 12155, 5460], "res": [{"w": 10608, "x": 13108, "y": 16133, "z2": 290085289, "z2_is_square": false, "z": 17031.89035309939}]}
{"stu": [9405, 12155, 10608], "res": [{"w": 5460, "x": 10875, "y": 13325, "z2": 290085289, "z2_is_square": false, "z": 17031.89035309939}]}
{"stu": [11760, 18564, 31977], "res": [{"w": 13475, "x": 17885, "y": 22939, "z2": 1548726250, "z2_is_square": false, "z": 39353.8594041296}]}
{"stu": [14364, 18564, 13475], "res": [{"w": 31977, "x": 35055, "y": 36975, "z2": 1548726250, "z2_is_square": false, "z": 39353.8594041296}]}
{"stu": [15400, 24310, 10920], "res": [{"w": 21216, "x": 26216, "y": 32266, "z2": 1160341156, "z2_is_square": false, "z": 34063.78070619878}]}
{"stu": [18480, 29172, 13104], "res": [{"w": 21175, "x": 28105, "y": 36047, "z2": 1471101025, "z2_is_square": false, "z": 38354.93481939449}]}
{"stu": [18810, 24310, 21216], "res": [{"w": 10920, "x": 21750, "y": 26650, "z2": 1160341156, "z2_is_square": false, "z": 34063.78070619878}]}
{"stu": [21840, 43500, 30800], "res": [{"w": 14651, "x": 26299, "y": 45901, "z2": 3055541801, "z2_is_square": false, "z": 55276.95542448046}]}
{"stu": [22572, 29172, 21175], "res": [{"w": 13104, "x": 26100, "y": 31980, "z2": 1471101025, "z2_is_square": false, "z": 38354.93481939449}]}
{"stu": [23100, 36465, 16380], "res": [{"w": 31824, "x": 39324, "y": 48399, "z2": 2610767601, "z2_is_square": false, "z": 51095.67105929816}]}
{"stu": [23520, 37128, 63954], "res": [{"w": 26950, "x": 35770, "y": 45878, "z2": 6194905000, "z2_is_square": false, "z": 78707.7188082592}]}
{"stu": [28215, 36465, 31824], "res": [{"w": 16380, "x": 32625, "y": 39975, "z2": 2610767601, "z2_is_square": false, "z": 51095.67105929816}]}
{"stu": [30800, 48620, 21840], "res": [{"w": 42432, "x": 52432, "y": 64532, "z2": 4641364624, "z2_is_square": false, "z": 68127.56141239755}]}
{"stu": [36960, 37128, 31654], "res": [{"w": 10175, "x": 38335, "y": 38497, "z2": 2483994725, "z2_is_square": false, "z": 49839.69025786577}]}
{"stu": [37620, 43500, 14651], "res": [{"w": 30800, "x": 48620, "y": 53300, "z2": 3055541801, "z2_is_square": false, "z": 55276.95542448046}]}
{"stu": [37620, 48620, 42432], "res": [{"w": 21840, "x": 43500, "y": 53300, "z2": 4641364624, "z2_is_square": false, "z": 68127.56141239755}]}
票数 2
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/70853194

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档