首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >tensorflow_probability子类化JointDistributionNamed __init__行为

tensorflow_probability子类化JointDistributionNamed __init__行为
EN

Stack Overflow用户
提问于 2019-12-13 19:01:26
回答 1查看 53关注 0票数 0

我正在尝试在tensorflow_probability库(tensorflow v2.0.0,tensorflow_probability v0.8.0)中从JointDistributionNamed创建一个派生类。然而,super().__init__函数以一种我不理解的奇怪方式运行。也许我只是错误地使用了super(),但对于其他类,它似乎可以像我期望的那样工作。不管怎样,这里有一个例子:

代码语言:javascript
复制
from tensorflow_probability import distributions as tfd

models = {'normal': tfd.Normal(loc=0, scale=1)}
joint = tfd.JointDistributionNamed(models) # Works perfectly fine
print("joint:",joint) 

class Test(tfd.JointDistributionNamed):
    def __init__(self,name,models):
        self.myname = name
        self.models = models
        super().__init__(models) #(1) Works
        #super().__init__(self.models) #(2) Doesn't work

t = Test('hello',models)
print("t:", t)

models赋给super().__init__时的行为是不同的,无论我是直接传入models还是先将其赋给self.models。为什么?在后一种情况下,我得到以下错误:

代码语言:javascript
复制
Traceback (most recent call last):
  File "test_jointdistnamed.py", line 18, in <module>
    t = Test('hello',models)
  File "</home/farmer/anaconda3/envs/tensorflow/lib/python3.7/site-packages/decorator.py:decorator-gen-244>", line 2, in __init__
  File "/home/farmer/anaconda3/envs/tensorflow/lib/python3.7/site-packages/tensorflow_probability/python/distributions/distribution.py", line 276, in wrapped_init
    default_init(self_, *args, **kwargs)
  File "test_jointdistnamed.py", line 16, in __init__
    super().__init__(self.models) #doesn't work
  File "</home/farmer/anaconda3/envs/tensorflow/lib/python3.7/site-packages/decorator.py:decorator-gen-138>", line 2, in __init__
  File "/home/farmer/anaconda3/envs/tensorflow/lib/python3.7/site-packages/tensorflow_probability/python/distributions/distribution.py", line 276, in wrapped_init
    default_init(self_, *args, **kwargs)
  File "/home/farmer/anaconda3/envs/tensorflow/lib/python3.7/site-packages/tensorflow_probability/python/distributions/joint_distribution_named.py", line 170, in __init__
    model, validate_args, name or 'JointDistributionNamed')
  File "</home/farmer/anaconda3/envs/tensorflow/lib/python3.7/site-packages/decorator.py:decorator-gen-70>", line 2, in __init__
  File "/home/farmer/anaconda3/envs/tensorflow/lib/python3.7/site-packages/tensorflow_probability/python/distributions/distribution.py", line 276, in wrapped_init
    default_init(self_, *args, **kwargs)
  File "/home/farmer/anaconda3/envs/tensorflow/lib/python3.7/site-packages/tensorflow_probability/python/distributions/joint_distribution_sequential.py", line 211, in __init__
    self._model_unflatten(self._model_flatten(model))
  File "/home/farmer/anaconda3/envs/tensorflow/lib/python3.7/site-packages/tensorflow_probability/python/distributions/joint_distribution_named.py", line 186, in _model_unflatten
    return type(self.model)(**kwargs)
TypeError: __init__() got an unexpected keyword argument 'normal'

由于某种原因,它就像是试图在类层次结构中的某个地方解压models字典。但是,为什么这会因为我是否首先赋值给self而不同呢?无论哪种方式,我传递的不是对完全相同的字典的引用吗?有什么不同?这是一个奇怪的bug,还是我做错了什么?如果我编写自己的简单定制类而不是JointDistributionNamed,同样的事情似乎也可以很好地工作。

EN

回答 1

Stack Overflow用户

发布于 2019-12-17 03:19:12

我的直觉是与字段的tf.Module依赖跟踪包装器对象有关。什么是type(self.models)?如果您更改为self._models = self._no_dependency(models),它可以工作吗?

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/59321225

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档