在dm-haiku中,神经网络的参数定义在字典中,其中键是模块(和子模块)名称。如果您想遍历这些值,有多种方法可以这样做,如“这 dm-haiku问题”中所示。然而,字典并不尊重模块的顺序,因此很难解析子模块。例如,如果我有两个linear层,每个层后面跟着一个mlp层,那么使用hk.data_structures.traverse(params)将(大致)返回:
['linear', 'linear_2', 'mlp/~/1', 'mlp/~/2'].不过,我希望它能回来:
['linear', 'mlp/~/1', 'linear_2', 'mlp/~/2'].我想要这种形式的原因是,如果创建一个可逆的神经网络,并想要逆转params被调用的顺序,为了其他目的(例如转移学习)而隔离替代部分,或者,一般地,想要更多地控制如何和在何处(重新)使用经过训练的参数。
为了解决这个问题,我采用了对名称进行正则化的方法,并将它们按我想要的顺序排列,然后使用hk.data_structures.filter(predicate, params)根据排序后的模块名称进行筛选。尽管如此,如果我每次想要这样做的时候都要重做一个regex,这是相当乏味的。
我想知道是否有一种方法可以将dm-haiku的params字典转换为具有层次结构和排序的pytree,从而使这更容易呢?我相信equinox会以这种方式处理参数(我很快会更深入地了解这一点),但是我想检查一下是否忽略了允许分组、反转和params字典的其他排列的简单方法?
发布于 2022-08-24 02:00:28
根据源代码,src/filtering.py#L42#L46 haiku使用dict的排序函数(haiku参数为香草dict,自0.0.6以来)用于hk.data_structures.traverse。因此,如果不修改函数本身,就无法获得所需的结果。顺便说一句,我不太明白你所说的“颠倒params的顺序”到底是什么意思。所有参数都在输入中一起传递,然后决定使用顺序的唯一因素是函数本身的体系结构,因此您应该手动反转前向传递,但不需要在params中更改某些内容。
https://stackoverflow.com/questions/72860276
复制相似问题