首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >一致性指标计算

一致性指标计算
EN

Code Review用户
提问于 2018-08-21 15:13:42
回答 1查看 1.4K关注 0票数 3

我试图为生存分析计算一个定制的一致性指数。下面是我的密码。它对于小型输入数据文件运行良好,但在具有100万行(>30分钟)的dataframe上运行非常慢。

代码语言:javascript
复制
import pandas as pd

def c_index1(y_pred, events, times):
    df = pd.DataFrame(data={'proba':y_pred, 'event':events, 'time':times})
    n_total_correct = 0
    n_total_comparable = 0
    df = df.sort_values(by=['time'])
    for i, row in df.iterrows():
        if row['event'] == 1:
            comparable_rows = df[(df['event'] == 0) & (df['time'] > row['time'])]
            n_correct_rows = len(comparable_rows[comparable_rows['proba'] < row['proba']])
            n_total_correct += n_correct_rows
            n_total_comparable += len(comparable_rows)

    return n_total_correct / n_total_comparable if n_total_comparable else None


c = c_index([0.1, 0.3, 0.67, 0.45, 0.56], [1.0,0.0,1.0,0.0,1.0], [3.1,4.5,6.7,5.2,3.4])
print(c) # print 0.5

每一行(如有关系.):

  • 如果该行的事件为1:检索其的所有可比较行
    1. 索引较大(避免重复计算),
    2. 事件为0,并且
    3. 时间大于当前行的时间。在可比行中,概率小于当前行的行是正确的预测。

我想是因为for循环的缘故,它很慢。我该怎么加速呢?

EN

回答 1

Code Review用户

发布于 2018-08-22 12:11:10

在您能够将操作向量化之前,您将不会得到戏剧性的加速,但以下是一些提示

迭代

之前的

索引

而不是

代码语言:javascript
复制
for i, row in df.iterrows():
    if row['event'] == 1:

如果你这样做了

代码语言:javascript
复制
for i, row in df[df['event'] == 1].rows():

您将对较少的行进行迭代。

迭代

一般来说,itertuplesiterrows

comparable_rows

对于comparable_rows,您只对proba和长度感兴趣,所以您最好把它变成一个系列,甚至更好的是,一个numpy数组。

测试(df['event'] == 0)在迭代过程中不会改变,因此可以在循环之外定义一个df2 = df[df['event'] == 0]

n_correct_rows

而不是len(comparable_rows[comparable_rows['proba'] < row['proba']]),您可以使用True == 1(comparable_rows['proba'] < row.proba).sum()这一事实

结果

代码语言:javascript
复制
def c_index3(y_pred, events, times):
    df = pd.DataFrame(data={'proba':y_pred, 'event':events, 'time':times})
    n_total_correct = 0
    n_total_comparable = 0
    df = df.sort_values(by=['time'])
    df2 = df.loc[df['event'] == 0]
    for row in df[df['event'] == 1].itertuples():
        comparable_rows = df2.loc[(df2['time'] > row.time), 'proba'].values
        n_correct_rows = (comparable_rows < row.proba).sum()
        n_total_correct += n_correct_rows
        n_total_comparable += len(comparable_rows)

    return n_total_correct / n_total_comparable if n_total_comparable else N

时间

代码语言:javascript
复制
data = ([0.1, 0.3, 0.67, 0.45, 0.56], [1.0,0.0,1.0,0.0,1.0], [3.1,4.5,6.7,5.2,3.4])
%timeit c_index1(*data)

5.17 ms±33.6 ms/环S(平均±std )。dev.7次运行中,每一次循环100次)

代码语言:javascript
复制
%timeit c_index3(*data)

3.77ms±160 ms /环S(平均±std )。dev.7次运行中,每一次循环100次)

票数 2
EN
页面原文内容由Code Review提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://codereview.stackexchange.com/questions/202140

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档