我正在用python编写一个用于科学计算的函数。此函数的一个参数表示实值输入参数。如果将复数值作为此参数传递,则函数的结果将不正确,因为我没有实现针对复数值输入情况所需的特殊注意事项,但函数将返回不正确的值,而不会出现错误或异常,因为函数中的每一行在语法方面都是有效的。
例如,请考虑如下函数:
import numpy as np
def foo(vara):
"""
This function evaluates the Foo formula for the real variable vara.
This function does not work for the complex variable vara because I am
too lazy to take care of the branch cut of the complex square-root function.
"""
if vara<0:
vv = -0.57386286*vara
else:
vv = 3.49604327*vara
return np.sqrt(vv)即使参数vara是复数,函数foo也会返回一个复数值,因为numpy.sqrt函数也是为复数参数定义的,但假设函数foo在实现时只考虑了实数参数,则返回的值将是不正确的。
如何在函数中检入参数是实值的,这样才能使函数抛出异常或错误退出?
请注意,我希望让该函数既适用于python的原生float类型,也适用于float类型元素的numpy数组。我只想禁止将函数与complex变量或complex元素的numpy数组一起使用。
(我考虑将1.0j与参数相乘,并检查结果的实部是否为零,但这看起来并不简洁。)
发布于 2015-09-03 04:51:51
如果你只想禁止复杂的数据类型,这样做是可行的:
import types
scalar_complex_types = [types.ComplexType, np.complex64, np.complex128]
def is_complex_sequence(vara):
return (hasattr(vara, '__iter__')
and any(isinstance(v, t) for v in vara for t in complex_types)
def is_complex_scalar(vara):
return any(isinstance(vara, t) for t in complex_types)然后在你的函数中你就可以..
if is_complex_scalar(vara) or is_complex_sequence(vara):
raise ValueError('Argument must not be a complex number')发布于 2015-09-03 06:42:53
(我是在回答我自己的问题。我不确定这是不是最好的方法,但我想留下一个我尝试记录的代码。)
根据polpak的回答,我编写了以下代码。我猜这将满足我提出的条件。该函数是学院式的,因为它拒绝除float scaler或float ndarray之外的任何其他类型的输入参数。(也许它甚至不接受所有类型的浮动ndarray。)特别地,它拒绝整数定标器和整数ndarray以及复数定标器和复数ndarray。
#!/usr/bin/python
import numpy as np
import types
def foo(vara):
"""vara must be a real-valued scaler or ndarray."""
real_types = [types.FloatType, np.float16, np.float32, np.float64, np.float128]
print '----------'
print 'vara:', vara
if isinstance(vara, np.ndarray):
if not any(vara.dtype==t for t in real_types):
print 'NG.'
print ' type(vara)=', type(vara)
print ' vara.dtype=', vara.dtype
# raise an error here
else:
print 'OK.'
print ' type(vara)=', type(vara)
print ' vara.dtype=', vara.dtype
else:
if not any(isinstance(vara, t) for t in real_types):
print 'NG.'
print ' type(vara)=', type(vara)
# raise an error here
else:
print 'OK.'
print ' type(vara)=', type(vara)
varalist=[3.0,
np.array([0.5, 0.2]),
np.array([3, 4, 1]),
np.array([3.4+1.2j, 0.8+0.7j]),
np.array([3.4+0.0j, 0.8+0.0j]),
np.array([1.3, 4.2, 5.9], dtype=complex),
np.array([1.3, 4.2, 5.9], dtype=complex).real ]
for vara in varalist:
foo(vara)此代码的输出如下所示。
$ ./main003.py
----------
vara: 3.0
OK.
type(vara)= <type 'float'>
----------
vara: [ 0.5 0.2]
OK.
type(vara)= <type 'numpy.ndarray'>
vara.dtype= float64
----------
vara: [3 4 1]
NG.
type(vara)= <type 'numpy.ndarray'>
vara.dtype= int64
----------
vara: [ 3.4+1.2j 0.8+0.7j]
NG.
type(vara)= <type 'numpy.ndarray'>
vara.dtype= complex128
----------
vara: [ 3.4+0.j 0.8+0.j]
NG.
type(vara)= <type 'numpy.ndarray'>
vara.dtype= complex128
----------
vara: [ 1.3+0.j 4.2+0.j 5.9+0.j]
NG.
type(vara)= <type 'numpy.ndarray'>
vara.dtype= complex128
----------
vara: [ 1.3 4.2 5.9]
OK.
type(vara)= <type 'numpy.ndarray'>
vara.dtype= float64https://stackoverflow.com/questions/32362707
复制相似问题