首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >在skflow中添加正则化器

在skflow中添加正则化器
EN

Stack Overflow用户
提问于 2016-04-13 11:51:59
回答 1查看 509关注 0票数 3

我最近从tensorflow转到了skflow。在tensorflow中,我们会将我们的lambda*tf.nn.l2_loss(权重)添加到我们的损失中。现在,我在skflow中有了以下代码:

代码语言:javascript
复制
def deep_psi(X, y):
    layers = skflow.ops.dnn(X, [5, 10, 20, 10, 5], keep_prob=0.5)
    preds, loss = skflow.models.logistic_regression(layers, y)
    return preds, loss

def exp_decay(global_step):
    return tf.train.exponential_decay(learning_rate=0.01,
                                      global_step=global_step,
                                      decay_steps=1000,
                                      decay_rate=0.005)

deep_cd = skflow.TensorFlowEstimator(model_fn=deep_psi,
                                    n_classes=2,
                                    steps=10000,
                                    batch_size=10,
                                    learning_rate=exp_decay,
                                    verbose=True,)

我怎么在这里加一个正规化器?伊利亚暗示了一些这里的东西,但我搞不懂。

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2016-04-13 17:17:03

您仍然可以向损失中添加其他组件,只需从dnn / logistic_regression中检索权重并将它们添加到损失中:

代码语言:javascript
复制
def regularize_loss(loss, weights, lambda):
    for weight in weights:
        loss = loss + lambda * tf.nn.l2_loss(weight)
    return loss    


def deep_psi(X, y):
    layers = skflow.ops.dnn(X, [5, 10, 20, 10, 5], keep_prob=0.5)
    preds, loss = skflow.models.logistic_regression(layers, y)

    weights = []
    for layer in range(5): # n layers you passed to dnn
        weights.append(tf.get_variable("dnn/layer%d/linear/Matrix" % layer))
        # biases are also available at dnn/layer%d/linear/Bias
    weights.append(tf.get_variable('logistic_regression/weights'))

    return preds, regularize_loss(loss, weights, lambda)
代码语言:javascript
复制

注意,变量的路径可以是在这里发现的

另外,我们希望增加对所有层的正则化支持,包括变量(如dnnconv2dfully_connected),因此下周Tensorflow的夜间构建应该有类似于dnn(.., regularize=tf.contrib.layers.l2_regularizer(lambda))的内容。当这种情况发生时,我会更新这个答案。

票数 3
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/36597519

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档