
Getting a ValueError when applying transfer learning in federated learning

Stack Overflow user
Asked 2022-05-22 07:05:52
Answers: 1 · Views: 46 · Followers: 0 · Votes: 2

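I want to use a pre-trained model in federated learning, as shown in the code below.

First I build my model and set its weights, then I freeze the convolutional layers and remove the last 4 layers.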

def create_keras_model():
    model = Sequential()
    model.add(Conv2D(16, kernel_size=(3,3), strides=(1,1), padding='same', activation='relu', input_shape=(226,232,1)))
    model.add(MaxPooling2D((2,2), strides=(2,2), padding='same'))
    model.add(Conv2D(32, kernel_size=(3,3), strides=(1,1), padding='same', activation='relu'))
    model.add(MaxPooling2D((2,2), strides=(2,2), padding='same'))
    model.add(Conv2D(64, kernel_size=(3,3), strides=(1,1), padding='same', activation='relu'))
    model.add(MaxPooling2D((2,2), strides=(2,2), padding='same'))
    model.add(Conv2D(64, kernel_size=(3,3), strides=(1,1), padding='same', activation='relu'))
    model.add(MaxPooling2D((2,2), strides=(2,2), padding='same'))
    model.add(Conv2D(128, kernel_size=(3,3), strides=(1,1), padding='same', activation='relu'))
    model.add(MaxPooling2D((2,2), strides=(2,2), padding='same'))
    model.add(Conv2D(128, kernel_size=(3,3), strides=(1,1), padding='same', activation='relu'))
    model.add(MaxPooling2D((2,2), strides=(2,2), padding='same'))
    model.add(Conv2D(256, kernel_size=(3,3), strides=(1,1), padding='same', activation='relu'))
    model.add(MaxPooling2D((2,2), strides=(2,2), padding='same'))
    model.add(Flatten())
    model.add(Dense(256, activation='relu'))
    model.add(Dense(128, activation='relu'))
    model.add(Dense(64, activation='relu'))
    model.add(Dense(10, activation='softmax'))
    return model

keras_model = create_keras_model()

server_state=FileCheckpointManager(root_dir= '/content/drive/MyDrive',
    prefix=  'federated_clustering',
    step= 1,
    keep_total= 1,
    keep_first= True).load_checkpoint(structure=server_state,round_num=10)

keras_model.set_weights(server_state)

for layer in keras_model.layers[:-4]:
  layer.trainable = False

model_pre = Model(inputs=keras_model.input,outputs=keras_model.layers[14].output)

Next, I build the new model.

def create_keras_model1():
    model = Sequential()
    model.add(model_pre)
    model.add(Dense(256, activation='relu'))
    model.add(Dense(128, activation='relu'))
    model.add(Dense(64, activation='relu'))
    model.add(Dense(2, activation='softmax'))
    return model

def model_fn():
  # We _must_ create a new model here, and _not_ capture it from an external
  # scope. TFF will call this within different graph contexts.
  keras_model = create_keras_model1()
  return tff.learning.from_keras_model(
      keras_model,
      input_spec=preprocessed_example_dataset.element_spec,
      loss=tf.keras.losses.SparseCategoricalCrossentropy(),
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])

However, when I call tff.learning.build_federated_averaging_process, I get a ValueError:

iterative_process = tff.learning.build_federated_averaging_process(
    model_fn,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.001),
    server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.0))

ValueError: Your Layer or Model is in an invalid state. This can happen for the following cases:
 1. You might be interleaving estimator/non-estimator models or interleaving models/layers made in tf.compat.v1.Graph.as_default() with models/layers created outside of it. Converting a model to an estimator (via model_to_estimator) invalidates all models/layers made before the conversion (even if they were not the model converted to an estimator). Similarly, making a layer or a model inside a a tf.compat.v1.Graph invalidates all layers/models you previously made outside of the graph.
2. You might be using a custom keras layer implementation with  custom __init__ which didn't call super().__init__.  Please check the implementation of <class 'keras.engine.functional.Functional'> and its bases.

Please help me fix this.


1 Answer

Stack Overflow user

Answered 2022-05-23 09:32:13

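As the code comment says, you need to create a new Keras model inside each call to model_fn. Your create_keras_model1() adds model_pre, which was already built earlier, outside model_fn; that is most likely the core problem.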
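One way to follow that advice is to rebuild the convolutional base inside the same function that builds the new head, so that no Keras object from the outer scope is captured. The following is only a minimal sketch under assumptions: `create_transfer_model`, `INPUT_SHAPE`, and `CONV_FILTERS` are hypothetical names introduced here, and the layer sizes are copied from the question. It does not load the pre-trained weights.

```python
import tensorflow as tf

# Values copied from the question's create_keras_model(); the names are hypothetical.
INPUT_SHAPE = (226, 232, 1)
CONV_FILTERS = (16, 32, 64, 64, 128, 128, 256)


def create_transfer_model():
    """Builds the frozen conv base and the new 2-class head from scratch.

    Every layer is created inside this function, so calling it from
    model_fn does not reuse any layer or model built elsewhere.
    """
    model = tf.keras.Sequential()
    model.add(tf.keras.Input(shape=INPUT_SHAPE))

    # Convolutional base, frozen at construction time with trainable=False.
    for filters in CONV_FILTERS:
        model.add(tf.keras.layers.Conv2D(
            filters, kernel_size=(3, 3), strides=(1, 1), padding='same',
            activation='relu', trainable=False))
        model.add(tf.keras.layers.MaxPooling2D(
            (2, 2), strides=(2, 2), padding='same'))
    model.add(tf.keras.layers.Flatten())

    # New trainable head, matching create_keras_model1() in the question.
    for units in (256, 128, 64):
        model.add(tf.keras.layers.Dense(units, activation='relu'))
    model.add(tf.keras.layers.Dense(2, activation='softmax'))
    return model
```

In the question's model_fn, `keras_model = create_keras_model1()` would then become `keras_model = create_transfer_model()`. The pre-trained weights still have to be supplied separately: a Keras `set_weights` call inside model_fn may not carry over into TFF's graph context, so the checkpointed values would probably need to be written into the server state after `iterative_process.initialize()`. TFF releases from around the time of the question offered `tff.learning.state_with_new_model_weights` for this, but check the documentation for your version.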

Votes: 0
Original page content provided by Stack Overflow.
Original link:

https://stackoverflow.com/questions/72335489
