下面的代码片段根据提供的选项构造不同类型的神经网络输出。目前,我的代码中有一个很大的注释,它描述了所有选项应该做什么。我想学习如何使这段代码不那么复杂,或者把它的复杂性分解成更简单的部分。有什么建议吗?
#------------------------------------------------------------#
# NOTE: Meaning of all the options. #
# stagger_schedule=extended: We copy input vec to output. #
# stagger_schedule=external: We dont copy input to output. #
# -----------------------------------------------------------#
# do_backward_pass: We use the output of the backward LSTM #
# Default:True. #
# -----------------------------------------------------------#
# chop_bilstm: Should we chop the first and last vectors from#
# the sequence. Default:False #
#------------------------------------------------------------#
# extended_multiplicative: Multiply the forward and back LSTM#
# and concatenate the input embedding. #
# external_multiplicative: Multiply the forward and back LSTM#
# and but dont concatenate the input embedding. #
#------------------------------------------------------------#
if (self.prm('stagger_schedule') == 'extended'):
if self.prm('chop_bilstm'):
if self.prm('do_backward_pass'):
self.output_tv = T.concatenate(
[forward, backward, input_tv], axis=1)[1:-1]
pass
else:
self.output_tv = T.concatenate(
[forward, input_tv], axis=1)[1:-1]
pass
pass
else:
if self.prm('do_backward_pass'):
self.output_tv = T.concatenate(
[forward, backward, input_tv], axis=1)
pass
else:
self.output_tv = T.concatenate(
[forward, input_tv], axis=1)
pass
pass
pass
elif self.prm('stagger_schedule') == 'external':
if self.prm('chop_bilstm'):
if self.prm('do_backward_pass'):
self.output_tv = T.concatenate(
[forward[1:-1], backward[2:]], axis=1)
pass
else:
self.output_tv = forward[1:-1]
pass
pass
else:
if self.prm('do_backward_pass'):
self.output_tv = T.concatenate(
[forward, backward], axis=1)
pass
else:
self.output_tv = forward
pass
pass
pass
elif self.prm('stagger_schedule') == 'extended_multiplicative':
if self.prm('chop_bilstm') or (not self.prm('do_backward_pass')):
raise NotImplementedError()
self.output_tv = T.concatenate(
[forward * backward, input_tv], axis=1)
pass
elif self.prm('stagger_schedule') == 'external_multiplicative':
if self.prm('chop_bilstm') or (not self.prm('do_backward_pass')):
raise NotImplementedError()
self.output_tv = forward * backward
pass
else:
raise NotImplementedError()发布于 2016-01-17 15:02:27
下面是一些整理这些代码的方法:
pass语句字面上是一事无成,除了为未编写的代码提供占位符之外。如果您将它们全部删除,您将节省许多行,并能够在屏幕上安装更多的代码。extended分支下,除了删除self.output_tv的第一个和最后一个字符外,代码几乎是相同的。如果我们推迟到最后,我们可以有一组分支如下: If (self.prm('stagger_schedule') == 'extended'):if self.prm('do_backward_pass'):self.output_tv = T.concatenate(向前,向后,输入_电视,axis=1) of : self.output_tv = T.concatenate(前进,输入_电视,( axis=1)如果self.prm('chop_bilstm'):self.output_tv = self.output_tv1:-1 22行减少到7行,而这仅仅是2层深。external分支中的代码也可以进行类似的合并:您可以在两个子分支中执行相同的效果,但是使用略为截断的forward和backward变量。这里有另一个版本: elif self.prm('stagger_schedule') == 'external':if self.prm('chop_bilstm'):forward = forward1:-1 deep =反向2: if self.prm('do_backward_pass'):self.output_tv =T.concatenate([前进,向后,axis=1) ) is : self.output_tv =前进20行,减少到8行,这只有2个层次深。https://codereview.stackexchange.com/questions/117028
复制相似问题