I am currently working on a research project involving neural networks that operate on EEG datasets. I am using the BCICIV 2a dataset, which consists of a series of files containing trial data from subjects. Each file holds a set of 25 channels and a very long signal array of ~600,000 time steps. I have been writing code to preprocess this data into something that can be fed into a neural network, but I have run into some efficiency problems. So far I have written code that locates all the trials of a file within the array and then tries to extract each one as a 3D NumPy array stored inside a larger array. However, when I run this code it is absurdly slow. I am not very familiar with NumPy; most of my experience up to this point is in C. My intention is to write the preprocessed result out to a separate file that can be loaded later, to avoid redoing the preprocessing. From a C perspective, all that should be needed is moving pointers around to format the data appropriately, so I am not sure why NumPy is this slow. Any advice would be very helpful: currently it takes about 2 minutes to extract 1 trial from 1 file, and with 288 trials per file and 9 files, this will take far longer than I would like. I am not very confident in my knowledge of how to exploit NumPy's efficiency advantages over generic lists. Thanks!
import glob, os
import numpy as np
import mne

DURATION = 313
XDIM = 7
YDIM = 6
IGNORE = ('EOG-left', 'EOG-central', 'EOG-right')

def getIndex(raw, tagIndex):
    return int(raw.annotations[tagIndex]['onset'] * 250)

def isEvent(raw, tagIndex, events):
    for event in events:
        if raw.annotations[tagIndex]['description'] == event:
            return True
    return False

def getSlice1D(raw, channel, dur, index):
    if type(channel) == int:
        channel = raw.ch_names[channel]
    return raw[channel][0][0][index:index+dur]

def getSliceFull(raw, dur, index):
    trial = np.zeros((XDIM, YDIM, dur))
    for channel in raw.ch_names:
        if not channel in IGNORE:
            x, y = convertIndices(channel)
            trial[x][y] = getSlice1D(raw, channel, dur, index)
    return trial

def convertIndices(channel):
    xDict = {'EEG-Fz':3, 'EEG-0':1, 'EEG-1':2, 'EEG-2':3, 'EEG-3':4, 'EEG-4':5, 'EEG-5':0, 'EEG-C3':1, 'EEG-6':2, 'EEG-Cz':3, 'EEG-7':4, 'EEG-C4':5, 'EEG-8':6, 'EEG-9':1, 'EEG-10':2, 'EEG-11':3, 'EEG-12':4, 'EEG-13':5, 'EEG-14':2, 'EEG-Pz':3, 'EEG-15':4, 'EEG-16':3}
    yDict = {'EEG-Fz':0, 'EEG-0':1, 'EEG-1':1, 'EEG-2':1, 'EEG-3':1, 'EEG-4':1, 'EEG-5':2, 'EEG-C3':2, 'EEG-6':2, 'EEG-Cz':2, 'EEG-7':2, 'EEG-C4':2, 'EEG-8':2, 'EEG-9':3, 'EEG-10':3, 'EEG-11':3, 'EEG-12':3, 'EEG-13':3, 'EEG-14':4, 'EEG-Pz':4, 'EEG-15':4, 'EEG-16':5}
    return xDict[channel], yDict[channel]

data_files = glob.glob('../datasets/BCICIV_2a_gdf/*.gdf')

try:
    raw = mne.io.read_raw_gdf(data_files[0], verbose='ERROR')
except IndexError:
    print("No data files found")

event_times = []
for i in range(len(raw.annotations)):
    if isEvent(raw, i, ('769', '770', '771', '772')):
        event_times.append(getIndex(raw, i))

data = np.empty((len(event_times), XDIM, YDIM, DURATION))
print(len(event_times))
for i, event in enumerate(event_times):
    data[i] = getSliceFull(raw, DURATION, event)

EDIT: I wanted to come back and add more detail about the structure of the dataset. There is a 25 × ~600,000 array that holds the data, plus a much shorter annotations object that contains event markers and relates those markers to times in the larger array. Particular events indicate motor imagery cues, which are the trials my network is being trained on, and I try to extract a 3D slice that includes the relevant channels, appropriately formatted, with a time dimension I found to be 313 time steps long. The annotations give me the relevant time steps to look at. The profiling that Ian recommended shows that the main computation time is spent in the getSlice1D() function, specifically when I index into the raw object. The code that extracts the event times from the annotations is negligible.
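Since the profiling points at per-channel, per-trial indexing into the raw object, a common fix is to pull the whole (n_channels, n_times) array out once (e.g. with MNE's `raw.get_data()`) and then cut every trial window with a single advanced-indexing operation. A minimal sketch of the idea, using a random stand-in array instead of a real recording (the sizes and event times here are made up):

```python
import numpy as np

DURATION = 313

# Stand-in for the full recording; with MNE this would be raw.get_data(),
# called exactly once, instead of indexing raw[channel] per channel per trial.
n_channels, n_times = 25, 10_000
data_2d = np.random.default_rng(0).standard_normal((n_channels, n_times))
event_times = np.array([100, 1500, 4200])  # made-up trial onsets, in samples

# Index array of shape (n_trials, DURATION): row i holds the sample
# indices of trial i's window, built by broadcasting.
windows = event_times[:, None] + np.arange(DURATION)[None, :]

# One advanced-indexing operation extracts every trial at once.
trials = data_2d[:, windows]        # shape (n_channels, n_trials, DURATION)
trials = trials.transpose(1, 0, 2)  # shape (n_trials, n_channels, DURATION)
```

The (XDIM, YDIM) spatial layout can then be filled per trial from these slices, and the finished array written out with `np.save` and read back with `np.load`, so the preprocessing only ever has to run once.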
Posted on 2021-03-03 04:46:22
This is a partial answer (the formatting in the comments is a bit of a mess), but:
def getIndex(raw, tagIndex):
    return int(raw.annotations[tagIndex]['onset'] * 250)

def isEvent(raw, tagIndex, events):
    for event in events:
        if raw.annotations[tagIndex]['description'] == event:
            return True
    return False

for i in range(len(raw.annotations)):
    if isEvent(raw, i, ('769', '770', '771', '772')):
        event_times.append(getIndex(raw, i))

Notice how you are iterating over i one index at a time. What you can do instead is:
def isEvent(raw_annotations_desc, raw_annotations_onset, events):
    flag_container = []
    for event in events:  # Iterate through all the events
        # Do a vectorized comparison across all the indices at once
        flag_container.append(raw_annotations_desc == event)
    # At this point flag_container is of shape (len(events), len(raw_annotations_desc))
    """
    Assuming I understand correctly: for a given index, if ANY of the
    events matches, we return True and take the index, correct?
    def getIndex(raw, tagIndex):
        return int(raw.annotations[tagIndex]['onset']*250)
    """
    flag_container = np.asarray(flag_container)  # Change raw list to np array
    # Python treats False as 0 and True as 1, so summing over the event axis
    # gives, per index, how many events matched there.
    # Shape is now (len(raw_annotations_desc),)
    flag_container = flag_container.sum(0)
    # Stack the indices on top because we will use them later
    flag_container = np.stack((np.arange(len(raw_annotations_desc)), flag_container))
    # Now flag_container has 2 rows: the index AND the number of True in that column.
    # Keep every column whose count is greater than 0 (i.e. at least one match)
    flag_container = flag_container[:, flag_container[1, :] > 0]
    # Now an array of shape (2, x <= len(raw_annotations_desc)); we only care
    # about the indices, not the match counts, so slice out row 0
    flag_container = flag_container[0, :]
    return (raw_annotations_onset[flag_container] * 250).astype(int)

Something like that :) — it should make things a fair bit faster.
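The same filtering collapses to one call with `np.isin`, which performs the membership test for all descriptions at once. A short sketch with made-up stand-in arrays (in MNE, `raw.annotations.description` and `raw.annotations.onset` would play these roles):

```python
import numpy as np

# Hypothetical stand-ins for raw.annotations.description / .onset
desc = np.array(['769', '1023', '770', '32766', '772'])
onset = np.array([10.0, 12.5, 20.0, 30.0, 41.2])

# One vectorized membership test replaces the per-index Python loop.
mask = np.isin(desc, ('769', '770', '771', '772'))
event_times = (onset[mask] * 250).astype(int)  # onsets of cue events, in samples
```

`mask` selects exactly the annotations whose description is one of the cue codes, and the onset-to-sample conversion is applied to the whole selection in one shot.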
https://stackoverflow.com/questions/66446115