首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >在python中,是否存在由因果关系影响生成的保存情节?

在python中,是否存在由因果关系影响生成的保存情节?
EN

Stack Overflow用户
提问于 2022-02-18 04:46:06
回答 1查看 189关注 0票数 0

下面是示例代码

代码语言:javascript
复制
import pandas as pd
from causalimpact import CausalImpact
data = pd.read_csv('https://raw.githubusercontent.com/WillianFuks/tfcausalimpact/master/tests/fixtures/arma_data.csv')[['y', 'X']]
data.iloc[70:, 0] += 5

pre_period = [0, 69]
post_period = [70, 99]

ci = CausalImpact(data, pre_period, post_period)
ci.plot()

我想写上面生成的绘图到html或至少保存为图像。是否有任何解决方案,因为ci.plot()类型是非类型的。https://github.com/WillianFuks/tfcausalimpact

EN

回答 1

Stack Overflow用户

发布于 2022-09-10 20:20:37

一个非常肮脏(但有效)的解决方案是:

  1. 使用

获取ci.plot()的代码

代码语言:javascript
复制
import inspect
print(inspect.getsource(ci.plot))

  1. 基于ci.plot()创建一个新函数,它实际上保存了绘图(或者很可能重写类的方法)。在我的例子中,它是带有新参数路径的函数plot2。它与原始函数的唯一区别是最后一行。

代码语言:javascript
复制
def plot2(self, path, panels=['original', 'pointwise', 'cumulative'], figsize=(15, 12)):
    """Plots inferences results related to causal impact analysis.

    Args
    ----
      panels: list.
        Indicates which plot should be considered in the graphics.
      figsize: tuple.
        Changes the size of the graphics plotted.

    Raises
    ------
      RuntimeError: if inferences were not computed yet.
    """
    plt = self._get_plotter()
    fig = plt.figure(figsize=figsize)
    if self.summary_data is None:
        raise RuntimeError('Please first run inferences before plotting results')

    valid_panels = ['original', 'pointwise', 'cumulative']
    for panel in panels:
        if panel not in valid_panels:
            raise ValueError(
                '"{}" is not a valid panel. Valid panels are: {}.'.format(
                    panel, ', '.join(['"{}"'.format(e) for e in valid_panels])
                )
            )

    # First points can be noisy due approximation techniques used in the likelihood
    # optimizaion process. We remove those points from the plots.
    llb = self.trained_model.filter_results.loglikelihood_burn
    inferences = self.inferences.iloc[llb:]

    intervention_idx = inferences.index.get_loc(self.post_period[0])
    n_panels = len(panels)
    ax = plt.subplot(n_panels, 1, 1)
    idx = 1

    if 'original' in panels:
        ax.plot(pd.concat([self.pre_data.iloc[llb:, 0], self.post_data.iloc[:, 0]]),
                'k', label='y')
        ax.plot(inferences['preds'], 'b--', label='Predicted')
        ax.axvline(inferences.index[intervention_idx - 1], c='k', linestyle='--')
        ax.fill_between(
            self.pre_data.index[llb:].union(self.post_data.index),
            inferences['preds_lower'],
            inferences['preds_upper'],
            facecolor='blue',
            interpolate=True,
            alpha=0.25
        )
        ax.grid(True, linestyle='--')
        ax.legend()
        if idx != n_panels:
            plt.setp(ax.get_xticklabels(), visible=False)
        idx += 1

    if 'pointwise' in panels:
        ax = plt.subplot(n_panels, 1, idx, sharex=ax)
        ax.plot(inferences['point_effects'], 'b--', label='Point Effects')
        ax.axvline(inferences.index[intervention_idx - 1], c='k', linestyle='--')
        ax.fill_between(
            inferences['point_effects'].index,
            inferences['point_effects_lower'],
            inferences['point_effects_upper'],
            facecolor='blue',
            interpolate=True,
            alpha=0.25
        )
        ax.axhline(y=0, color='k', linestyle='--')
        ax.grid(True, linestyle='--')
        ax.legend()
        if idx != n_panels:
            plt.setp(ax.get_xticklabels(), visible=False)
        idx += 1

    if 'cumulative' in panels:
        ax = plt.subplot(n_panels, 1, idx, sharex=ax)
        ax.plot(inferences['post_cum_effects'], 'b--',
                label='Cumulative Effect')
        ax.axvline(inferences.index[intervention_idx - 1], c='k', linestyle='--')
        ax.fill_between(
            inferences['post_cum_effects'].index,
            inferences['post_cum_effects_lower'],
            inferences['post_cum_effects_upper'],
            facecolor='blue',
            interpolate=True,
            alpha=0.25
        )
        ax.grid(True, linestyle='--')
        ax.axhline(y=0, color='k', linestyle='--')
        ax.legend()

    # Alert if points were removed due to loglikelihood burning data
    if llb > 0:
        text = ('Note: The first {} observations were removed due to approximate '
                'diffuse initialization.'.format(llb))
        fig.text(0.1, 0.01, text, fontsize='large')

    plt.savefig(path)

  1. 调用新函数:

代码语言:javascript
复制
plot2(ci, 'myplot.png')
票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/71168478

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档