首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >使用numpy.piecewise的问题

使用numpy.piecewise的问题
EN

Stack Overflow用户
提问于 2020-06-03 19:15:54
回答 1查看 498关注 0票数 0

1.核心问题和问题

下面我将提供一个可执行的示例,但首先让我向您介绍这个问题。

我正在使用来自scipy.integratesolve_ivp来解决初始值问题(见文件)。事实上,我必须给求解者打两次电话,一次向前,一次向后。(我将不得不深入到我的具体问题中去解释为什么这是必要的,但是请相信我-请相信我!

代码语言:javascript
复制
sol0 = solve_ivp(rhs,[0,-1e8],y0,rtol=10e-12,atol=10e-12,dense_output=True)
sol1 = solve_ivp(rhs,[0, 1e8],y0,rtol=10e-12,atol=10e-12,dense_output=True)

这里rhs是初值问题y(t) = rhs(t,y)的右手边函数。在我的例子中,y有六个组件-- y[0]y[5]y0=y(0)是初始条件。[0,±1e8]是各自的集成范围,一个向前,另一个向后。rtolatol是公差。

重要的是,您看到我标记了dense_output=True,这意味着求解器不仅返回数值网格上的解,而且作为插值函数sol0.sol(t)sol1.sol(t)返回。

我现在的主要目标是定义一个分段函数,比如,sol(t) ,它将值 sol0.sol(t) 用于 t<0 ,而值 sol1.sol(t) 用于E 237 t>=0**.。所以主要的问题是:我该怎么做?**

我认为numpy.piecewise应该是为我这样做的首选工具。但我在使用它时遇到了困难,正如你将在下面看到的,在这里我向你展示了我迄今所做的尝试。

2.示例代码

下面框中的代码解决了我的示例的初始值问题。大部分代码是rhs函数的定义,其细节对这个问题并不重要。

代码语言:javascript
复制
import numpy as np
from scipy.integrate import solve_ivp

# aux definitions and constants
sin=np.sin; cos=np.cos; tan=np.tan; sqrt=np.sqrt; pi=np.pi;  
c  = 299792458
Gm = 5.655090674872875e26    

# define right hand side function of initial value problem, y'(t) = rhs(t,y)
def rhs(t,y):
    p,e,i,Om,om,f = y
    sinf=np.sin(f); cosf=np.cos(f); Q=sqrt(p/Gm); opecf=1+e*cosf;        

    R = Gm**2/(c**2*p**3)*opecf**2*(3*(e**2 + 1) + 2*e*cosf - 4*e**2*cosf**2)
    S = Gm**2/(c**2*p**3)*4*opecf**3*e*sinf         

    rhs    = np.zeros(6)
    rhs[0] = 2*sqrt(p**3/Gm)/opecf*S
    rhs[1] = Q*(sinf*R + (2*cosf + e*(1 + cosf**2))/opecf*S)
    rhs[2] = 0
    rhs[3] = 0
    rhs[4] = Q/e*(-cosf*R + (2 + e*cosf)/opecf*sinf*S)
    rhs[5] = sqrt(Gm/p**3)*opecf**2 + Q/e*(cosf*R - (2 + e*cosf)/opecf*sinf*S)

    return rhs

# define initial values, y0
y0=[3.3578528933149297e13,0.8846,2.34921,3.98284,1.15715,0]

# integrate twice from t = 0, once backward in time (sol0) and once forward in time (sol1)
sol0 = solve_ivp(rhs,[0,-1e8],y0,rtol=10e-12,atol=10e-12,dense_output=True)
sol1 = solve_ivp(rhs,[0, 1e8],y0,rtol=10e-12,atol=10e-12,dense_output=True)

从这里可以分别用sol0.solsol1.sol来求解求解函数。作为一个例子,让我们绘制第四个组件:

代码语言:javascript
复制
from matplotlib import pyplot as plt

t0 = np.linspace(-1,0,500)*1e8
t1 = np.linspace( 0,1,500)*1e8
plt.plot(t0,sol0.sol(t0)[4])
plt.plot(t1,sol1.sol(t1)[4])
plt.title('plot 1')
plt.show()

3.未能建立分段功能

直接从sol0.sol sol1.sol和构建向量值分段函数

代码语言:javascript
复制
def sol(t): return np.piecewise(t,[t<0,t>=0],[sol0.sol,sol1.sol])
t = np.linspace(-1,1,1000)*1e8
print(sol(t))

这将导致./numpy/lib/function_base.py第628行分段中的以下错误:

代码语言:javascript
复制
TypeError: NumPy boolean array indexing assignment requires a 0 or 1-dimensional input, input has 2 dimensions

我不确定,但我确实认为这是因为以下原因:在零碎文件中,它提到了第三个论点:

可调用函数列表,f(x,*args,**kw)或标量 ……它应该以一个一维数组作为输入,并将一个一维数组或一个标量值作为输出。……

我想问题是,在我的例子中,解决方案有六个组件。因此,在时间网格上计算输出将是一个2d数组。有人能证实,这确实是问题所在吗?因为我认为这实际上限制了piecewise的有用性。

3.2尝试相同的方法,但只对一个组件(例如,第4部分)使用

代码语言:javascript
复制
def sol4(t): return np.piecewise(t,[t<0,t>=0],[sol0.sol(t)[4],sol1.sol(t)[4]])
t = np.linspace(-1,1,1000)*1e8
print(sol4(t))

这导致上述文件第624行中出现此错误:

代码语言:javascript
复制
ValueError: NumPy boolean array indexing assignment cannot assign 1000 input values to the 500 output values where the mask is true

与以前的错误相反,不幸的是,到目前为止,我还不知道为什么它不起作用。

3.3类似的尝试,但是第一次为第四个组件定义函数

代码语言:javascript
复制
def sol40(t): return sol0.sol(t)[4]
def sol41(t): return sol1.sol(t)[4]
def sol4(t): return np.piecewise(t,[t<0,t>=0],[sol40,sol41])
t = np.linspace(-1,1,1000)
plt.plot(t,sol4(t))
plt.title('plot 2')
plt.show()

现在,这不会导致错误,我可以生成一个图,但是这个图看起来不应该。应该像上面的第一幅图。另外,到目前为止,我还不知道发生了什么。

感谢你的帮助!

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2020-06-03 22:00:43

您可以查看numpy.piecewise源代码。这个函数没有什么特别之处,所以我建议手动完成所有操作。

代码语言:javascript
复制
def sol(t):
    ans = np.empty((6, len(t)))
    ans[:, t<0] = sol0.sol(t[t<0])
    ans[:, t>=0] = sol1.sol(t[t>=0])
    return ans

关于你失败的尝试。是的,piecewise函数返回一维数组。您的第二次尝试失败了,因为文档中说funclist参数应该是函数或标量的列表,但是您发送数组列表。与文档相反,它甚至适用于数组,您只需使用与t < 0t >= 0相同大小的数组,如:

代码语言:javascript
复制
def sol4(t): return np.piecewise(t,[t<0,t>=0],[sol0.sol(t[t<0])[4],sol1.sol(t[t>=0])[4]])
票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/62180930

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档