首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >数据集预处理中的NumPy效率

数据集预处理中的NumPy效率
EN

Stack Overflow用户
提问于 2021-03-03 03:45:01
回答 1查看 59关注 0票数 0

我目前正在从事一个研究项目,该项目与使用在脑电图数据集上操作的神经网络有关。我使用的是BCICIV 2a数据集,它由一系列文件组成,其中包含来自受试者的试验数据。每个文件包含一组25个通道和一个非常长的~600000个时间步长的信号阵列。我一直在编写代码,将这些数据预处理成可以传递到神经网络中的东西,但遇到了一些效率问题。目前,我已经编写了代码来确定文件中所有试验在数组中的位置,然后尝试提取存储在另一个数组中的3D NumPy数组。然而,当我尝试运行这段代码时,它慢得离谱。我对NumPy不是很熟悉,在这一点上我的大部分经验是在C中。我的目的是将预处理的结果写到一个可以加载的单独文件中,以避免预处理。从C语言的角度来看,所有需要做的就是移动指针来适当地格式化数据,所以我不确定为什么NumPy这么慢。任何建议都将是非常有帮助的,因为目前对于1个文件,提取1个试验需要大约2分钟,一个文件中有288个试验,9个文件,这将比我希望的花费更长的时间。我对如何很好地利用NumPy对泛型列表的效率改进的知识并不是很满意。谢谢!

代码语言:javascript
复制
import glob, os
import numpy as np
import mne

DURATION = 313
XDIM = 7
YDIM = 6
IGNORE = ('EOG-left', 'EOG-central', 'EOG-right')

def getIndex(raw, tagIndex):
    return int(raw.annotations[tagIndex]['onset']*250)

def isEvent(raw, tagIndex, events):
    for event in events:
        if (raw.annotations[tagIndex]['description'] == event):
            return True
    return False

def getSlice1D(raw, channel, dur, index):
    if (type(channel) == int):
        channel = raw.ch_names[channel]
    return raw[channel][0][0][index:index+dur]

def getSliceFull(raw, dur, index):
    trial = np.zeros((XDIM, YDIM, dur))
    for channel in raw.ch_names:
        if not channel in IGNORE:
            x, y = convertIndices(channel)
            trial[x][y] = getSlice1D(raw, channel, dur, index)
    return trial

def convertIndices(channel):
    xDict = {'EEG-Fz':3, 'EEG-0':1, 'EEG-1':2, 'EEG-2':3, 'EEG-3':4, 'EEG-4':5, 'EEG-5':0, 'EEG-C3':1, 'EEG-6':2, 'EEG-Cz':3, 'EEG-7':4, 'EEG-C4':5, 'EEG-8':6, 'EEG-9':1, 'EEG-10':2, 'EEG-11':3, 'EEG-12':4, 'EEG-13':5, 'EEG-14':2, 'EEG-Pz':3, 'EEG-15':4, 'EEG-16':3}
    yDict = {'EEG-Fz':0, 'EEG-0':1, 'EEG-1':1, 'EEG-2':1, 'EEG-3':1, 'EEG-4':1, 'EEG-5':2, 'EEG-C3':2, 'EEG-6':2, 'EEG-Cz':2, 'EEG-7':2, 'EEG-C4':2, 'EEG-8':2, 'EEG-9':3, 'EEG-10':3, 'EEG-11':3, 'EEG-12':3, 'EEG-13':3, 'EEG-14':4, 'EEG-Pz':4, 'EEG-15':4, 'EEG-16':5}
    return xDict[channel], yDict[channel]

data_files = glob.glob('../datasets/BCICIV_2a_gdf/*.gdf')

try:
    raw = mne.io.read_raw_gdf(data_files[0], verbose='ERROR')
except IndexError:
    print("No data files found")

event_times = []

for i in range(len(raw.annotations)):
    if (isEvent(raw, i, ('769', '770', '771', '772'))):
        event_times.append(getIndex(raw, i))

data = np.empty((len(event_times), XDIM, YDIM, DURATION))

print(len(event_times))

for i, event in enumerate(event_times):
    data[i] = getSliceFull(raw, DURATION, event)

编辑:我想回来添加更多关于数据集结构的细节。有一个25x~600000的数组,其中包含数据和一个更短的注释对象,该对象包含事件标记并将这些标记与较大数组中的时间相关联。特定事件表明运动图像提示,这是我的网络正在训练的试验,我尝试提取3D切片,其中包括使用时间维度适当格式化的相关通道,发现该时间维度有313个时间步长。注释为我提供了相关的时间步长来进行研究。Ian推荐的性能分析结果显示,主要计算时间位于getSlice1D()函数中。特别是当我索引到原始对象时。从注释中提取事件时间的代码可以忽略不计。

EN

回答 1

Stack Overflow用户

发布于 2021-03-03 04:46:22

这是部分答案,注释中的格式是一种垃圾,但是

代码语言:javascript
复制
def getIndex(raw, tagIndex):
    return int(raw.annotations[tagIndex]['onset']*250)


def isEvent(raw, tagIndex, events):
    for event in events:
        if (raw.annotations[tagIndex]['description'] == event):
            return True
    return False

for i in range(len(raw.annotations)):
    if (isEvent(raw, i, ('769', '770', '771', '772'))):
        event_times.append(getIndex(raw, i))

请注意您是如何迭代I的。您可以做的是

代码语言:javascript
复制
def isEvent(raw_annotations_desc, raw_annotations_onset, events):
    flag_container = []

    for event in events:    # Iterate through all the events
        # Do a vectorized comparison across all the indices
        flag_container.append(raw_annotations_desc == event)
    # At this point flag_container will be of shape (|events|, len(raw_annotations_desc) 

    """
    Assuming I understand correctly, for a given index if  
        ANY of the events is true, we return true and get the index, correct?
    def getIndex(raw, tagIndex):
        return int(raw.annotations[tagIndex]['onset']*250)
    """
    flag_container = np.asarray(flag_container)  # Change raw list to np array
    
    # Python treats False as 0 and True as 1. So, we sum over the cols 
    # and we now have an array of shape (1, len(raw_annotations))
    flag_container = flag_container.sum(1)  

    # Add indices because we will use these later
    flag_container = np.asarray(np.arange(len(raw_annotations)), flag_container)

    # Almost there. Now, flag_container has 2 cols: the index AND the number of True in a given row
    
    # Gets us all the indices where the sum was greater than 1 (aka one positive)
    
    flag_container = flag_container[flag_container[1,:] > 0]  

    # Now, an array of shape (2, x <= len(raw_annotations_desc))
    flag_container = flag_container[0, :]  # We only care about the indices, not the actual count of positives now so we slice out the 0th-col

    return int(raw_annotations_onset[flag_container] * 250)

类似这样的东西:),这应该会让事情变得更快一些

票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/66446115

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档