首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >我正在努力实现带有csv模块的Cython

我正在努力实现带有csv模块的Cython
EN

Stack Overflow用户
提问于 2021-06-29 00:12:33
回答 1查看 76关注 0票数 0

这里是原始的程序员。我的任务是清理以csv格式存储的医疗数据。

(当你阅读这篇文章时,请记住,我只是一个初学编程的人,所以感谢你的耐心)

简要介绍一下背景知识: data1是一种包含研究所需患者(约17,000名患者)的csv。PUF_ED包含数以百万计患者的急诊科数据。对于data1中的每个患者,我遍历PUF_ED,直到患者标识关键字(患者)与PUF_ED (i)中的标识关键字匹配,然后将患者数据重写到data2中,并在末尾附加来自PUF_ED的新数据。我知道熊猫的效率会高得多,但由于严格的截止日期,我没有时间学习熊猫并重写我所有的代码。我希望这里有人能帮我实现Cython。

代码语言:javascript
复制
import csv

def mortWriter():
    puf_ed = open('PUF_ED.csv', 'r')
    ed = csv.reader(puf_ed)
    csv_data1 = open('data1.csv', 'r')
    data1 = csv.reader(csv_data1)
    csv_data2 = open('data2.csv', 'w')
    data2 = csv.writer(csv_data2, 
lineterminator='\n')

    patNum=0
    for patient in data1:
        if patNum==0:
            data2.writerow(patient + ['EDDISP'])
            patNum+=1

        for i in ed:
            if patient[0] == i[0]:
                data2.writerow(patient + [i[12]])
                break

        puf_ed.seek(0)

    puf_ed.close()
    csv_data1.close()
    csv_data2.close()

在第一个for循环之前,我尝试将patient和i变量作为整数键入,如下所示:

代码语言:javascript
复制
cdef int patient
cdef int i 

但我收到错误:正在尝试索引非数组类型'int‘

当我尝试索引它们时,就像在我的代码中看到的那样。

我应该采取哪些后续步骤?非常感谢所有帮助我的人,我由衷的感谢。

EN

回答 1

Stack Overflow用户

发布于 2021-06-30 02:52:02

在您拿出“大炮”并编写自定义cython代码或使用pandas之前,我想看看您是否可以先改进基本python代码的算法。查看您的代码,您的mortWriter函数是以下步骤(忽略没有意义的patNum内容):

代码语言:javascript
复制
for each patient in the study:
    for each patient in the ED:
        if ED patient happens to match study patient:
            do_something()

这里有几件事会导致糟糕的性能:

  1. 在列表中查找对象是一个线性操作(也称为O(n)),因为您必须一次搜索一个患者,直到找到目标患者。你可以使用setdict来获得更快的、恒定时间(也称为O(1))的查找。
  2. 你应该尽量避免重复地从磁盘加载文件。这很慢。就总体速度而言,CPU操作比内存(RAM)操作快得多,内存(RAM)操作比磁盘/IO操作快得多。一遍又一遍地重置ED文件并遍历这是problematic.
  3. In的一般做法,相对于更小的列表而不是更大的列表,更喜欢重复迭代。在这种情况下,我不会为研究中的每个患者查看整个急诊科,而是颠倒这个过程:对于急诊室中的每个人,看看他们是否可以在较小的研究集中找到。这可以与点#2结合使用,因为它更容易将较小的研究集存储在内存中。

使用所有这些要点,我想出了以下代码:

代码语言:javascript
复制
import csv
import math
import numpy as np
import time

def write_random_pat_csv(file_name, max_num_pats, num_data_cols=20, min_id=0, max_id=100_000, prefix="UH"):
    num_pad_zeros = len(str(max_id - 1))
    pat_ids = np.random.randint(min_id, max_id, max_num_pats)
    pat_ids = np.unique(pat_ids)
    num_pats = len(pat_ids)
    pat_ids = [["UH" + str(i).zfill(num_pad_zeros)] for i in pat_ids]
    pat_data = np.random.rand(num_pats, num_data_cols)
    csv_file = open(file_name, "w")
    csv_data = csv.writer(csv_file, lineterminator="\n")
    for i in range(num_pats):
        csv_data.writerow(pat_ids[i] + list(pat_data[i]))
    csv_file.close()

def generate_random_pat_files():
    #num_study_pats = 17_000
    #num_ed_pats = 10_000_000
    num_study_pats = 170
    num_ed_pats = 100000
    write_random_pat_csv("data1.csv", num_study_pats)
    write_random_pat_csv("PUF_ED.csv", num_ed_pats)

def mort_writer_slow():
    puf_ed = open('PUF_ED.csv', 'r')
    ed = csv.reader(puf_ed)
    csv_data1 = open('data1.csv', 'r')
    data1 = csv.reader(csv_data1)
    csv_data2 = open('data2.csv', 'w')
    data2 = csv.writer(csv_data2, lineterminator='\n')

    for patient in data1:
        for i in ed:
            if patient[0] == i[0]:
                data2.writerow(patient + [i[12]])
                break

        puf_ed.seek(0)

    puf_ed.close()
    csv_data1.close()
    csv_data2.close()

def mort_writer_fast():
    puf_ed = open('PUF_ED.csv', 'r')
    csv_data1 = open('data1.csv', 'r')
    csv_data3 = open('data3.csv', 'w')

    ed = csv.reader(puf_ed)
    data1 = csv.reader(csv_data1)
    data3 = csv.writer(csv_data3, lineterminator='\n')

    pat_num = 0
    data1_pat_ids = set([row[0] for row in data1])
    data1_rows = {pat_id: [] for pat_id in data1_pat_ids}
    
    for i in ed:
        pat_id = i[0]
        if pat_id in data1_pat_ids:
            data1_rows[pat_id].append([i[12]])
    
    csv_data1.seek(0)
    for i in data1:
        pat_id = i[0]
        rows = data1_rows[pat_id]
        for row in rows:
            data3.writerow(i + row)

    puf_ed.close()
    csv_data1.close()
    csv_data3.close()

if __name__ == "__main__":
    #pass
    start = time.time()
    generate_random_pat_files()
    end = time.time()
    print(end - start)

    start = time.time()
    mort_writer_slow()
    end = time.time()
    print(end - start)

    start = time.time()
    mort_writer_fast()
    end = time.time()
    print(end - start)

这比原始代码快50-100倍,至少在我尝试的较小数据集上是这样。

票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/68166714

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档