首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >如何在python中确保输入参数不是复数值而是实数值

如何在python中确保输入参数不是复数值而是实数值
EN

Stack Overflow用户
提问于 2015-09-03 04:41:14
回答 2查看 405关注 0票数 1

我正在用python编写一个用于科学计算的函数。此函数的一个参数表示实值输入参数。如果将复数值作为此参数传递,则函数的结果将不正确,因为我没有实现针对复数值输入情况所需的特殊注意事项,但函数将返回不正确的值,而不会出现错误或异常,因为函数中的每一行在语法方面都是有效的。

例如,请考虑如下函数:

代码语言:javascript
复制
import numpy as np
def foo(vara):
    """
    This function evaluates the Foo formula for the real variable vara.
    This function does not work for the complex variable vara because I am 
    too lazy to take care of the branch cut of the complex square-root function.
    """
    if vara<0:
        vv = -0.57386286*vara
    else:
        vv =  3.49604327*vara
    return np.sqrt(vv)

即使参数vara是复数,函数foo也会返回一个复数值,因为numpy.sqrt函数也是为复数参数定义的,但假设函数foo在实现时只考虑了实数参数,则返回的值将是不正确的。

如何在函数中检入参数是实值的,这样才能使函数抛出异常或错误退出?

请注意,我希望让该函数既适用于python的原生float类型,也适用于float类型元素的numpy数组。我只想禁止将函数与complex变量或complex元素的numpy数组一起使用。

(我考虑将1.0j与参数相乘,并检查结果的实部是否为零,但这看起来并不简洁。)

EN

回答 2

Stack Overflow用户

发布于 2015-09-03 04:51:51

如果你只想禁止复杂的数据类型,这样做是可行的:

代码语言:javascript
复制
import types

scalar_complex_types = [types.ComplexType, np.complex64, np.complex128]

def is_complex_sequence(vara):
    return (hasattr(vara, '__iter__') 
             and any(isinstance(v, t) for v in vara for t in complex_types)

def is_complex_scalar(vara):
    return any(isinstance(vara, t) for t in complex_types)

然后在你的函数中你就可以..

代码语言:javascript
复制
if is_complex_scalar(vara) or is_complex_sequence(vara):
    raise ValueError('Argument must not be a complex number')
票数 1
EN

Stack Overflow用户

发布于 2015-09-03 06:42:53

(我是在回答我自己的问题。我不确定这是不是最好的方法,但我想留下一个我尝试记录的代码。)

根据polpak的回答,我编写了以下代码。我猜这将满足我提出的条件。该函数是学院式的,因为它拒绝除float scaler或float ndarray之外的任何其他类型的输入参数。(也许它甚至不接受所有类型的浮动ndarray。)特别地,它拒绝整数定标器和整数ndarray以及复数定标器和复数ndarray。

代码语言:javascript
复制
#!/usr/bin/python

import numpy as np
import types

def foo(vara):
    """vara must be a real-valued scaler or ndarray."""

    real_types = [types.FloatType, np.float16, np.float32, np.float64, np.float128]
    print '----------'
    print 'vara:', vara
    if isinstance(vara, np.ndarray):
        if not any(vara.dtype==t for t in real_types):
            print 'NG.'
            print '   type(vara)=', type(vara)
            print '   vara.dtype=', vara.dtype
            # raise an error here
        else:
            print 'OK.'
            print '   type(vara)=', type(vara)
            print '   vara.dtype=', vara.dtype
    else:
        if not any(isinstance(vara, t) for t in real_types):
            print 'NG.'
            print '   type(vara)=', type(vara)
            # raise an error here
        else:
            print 'OK.'
            print '   type(vara)=', type(vara)


varalist=[3.0, 
          np.array([0.5, 0.2]), 
          np.array([3, 4, 1]), 
          np.array([3.4+1.2j, 0.8+0.7j]),
          np.array([3.4+0.0j, 0.8+0.0j]),
          np.array([1.3, 4.2, 5.9], dtype=complex),
          np.array([1.3, 4.2, 5.9], dtype=complex).real ]

for vara  in varalist:
    foo(vara)

此代码的输出如下所示。

代码语言:javascript
复制
$ ./main003.py 
----------
vara: 3.0
OK.
   type(vara)= <type 'float'>
----------
vara: [ 0.5  0.2]
OK.
   type(vara)= <type 'numpy.ndarray'>
   vara.dtype= float64
----------
vara: [3 4 1]
NG.
   type(vara)= <type 'numpy.ndarray'>
   vara.dtype= int64
----------
vara: [ 3.4+1.2j  0.8+0.7j]
NG.
   type(vara)= <type 'numpy.ndarray'>
   vara.dtype= complex128
----------
vara: [ 3.4+0.j  0.8+0.j]
NG.
   type(vara)= <type 'numpy.ndarray'>
   vara.dtype= complex128
----------
vara: [ 1.3+0.j  4.2+0.j  5.9+0.j]
NG.
   type(vara)= <type 'numpy.ndarray'>
   vara.dtype= complex128
----------
vara: [ 1.3  4.2  5.9]
OK.
   type(vara)= <type 'numpy.ndarray'>
   vara.dtype= float64
票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/32362707

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档