首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >加载tensorflow2.0模型时出现错误

加载tensorflow2.0模型时出现错误
EN

Stack Overflow用户
提问于 2020-01-15 20:52:32
回答 1查看 675关注 0票数 0

我正在学习一个简单的模型来执行线性回归,然后保存该模型

代码语言:javascript
复制
class NN(tf.keras.Model):
  def __init__(self):
    super(NN, self).__init__()
    L = 20
    self.W1 = tf.Variable(tf.random.truncated_normal([1, L], stddev=math.sqrt(3)))
    self.B1 = tf.Variable(tf.random.truncated_normal([1, L], stddev=1.0))
    self.W2 = tf.Variable(tf.random.truncated_normal([L, 1], stddev=math.sqrt(3/L)))
    self.B2 = tf.Variable(tf.zeros([1]))
  def call(self, inputs):
    Z1 = tf.matmul(inputs, self.W1) + self.B1
    Y1 = tf.nn.tanh(Z1)
    Y = tf.matmul(Y1, self.W2) + self.B2
    return Y

# The loss function to be optimized
def loss(model, X, Y_):
  error = model(X) - Y_
  return tf.reduce_mean(tf.square(error))

model = NN()
optimizer = tf.optimizers.Adam(learning_rate=0.001)
bsize = 20

# You can call this function in a loop to train the model, bsize samples at a time
def training_step(i):
  # read data
  x_batch, y_batch = func.next_batch(bsize)
  x_batch = np.reshape(x_batch, (bsize,1))
  y_batch = np.reshape(y_batch, (bsize,1))
  # compute training values
  loss_fn = lambda: loss(model, x_batch, y_batch)
  optimizer.minimize(loss_fn, [model.W1, model.B1, model.W2, model.B2])
  if i%5000 == 0:
    l = loss(model, x_batch, y_batch)
    print(str(i) + ": epoch: " + str(func._epochs_completed) + ": loss: " + str(l.numpy()))

for i in range(50001): 
  training_step(i)

# save the model
tf.saved_model.save(model, "my_file")

然后,我尝试使用以下几行tensorflow文档加载模型:

代码语言:javascript
复制
model = tf.saved_model.load("my_file")
f = model.signatures["serving_default"]
y = f(x)

然而,我得到了以下错误消息:

代码语言:javascript
复制
 f = model.signatures["serving_default"]
File "my_file/signature_serialization.py", line 195, in __getitem__
    return self._signatures[key]
KeyError: 'serving_default'

出什么问题了?为什么没有定义serving_default?

EN

回答 1

Stack Overflow用户

发布于 2020-01-16 15:19:57

我通过向tf.saved_model.save函数添加第三个参数解决了这个问题

代码语言:javascript
复制
tf.saved_model.save(model, "myfile", signatures=model.call.get_concrete_function(tf.TensorSpec(shape=[None,1], dtype=tf.float32, name="inp")))

并将@tf.function添加到call方法之上

代码语言:javascript
复制
class NN(tf.keras.Model):
  def __init__(self):
    super(NN, self).__init__()
    L = 20
    self.W1 = tf.Variable(tf.random.truncated_normal([1, L], stddev=math.sqrt(3)))
    self.B1 = tf.Variable(tf.random.truncated_normal([1, L], stddev=1.0))
    self.W2 = tf.Variable(tf.random.truncated_normal([L, 1], stddev=math.sqrt(3/L)))
    self.B2 = tf.Variable(tf.zeros([1]))
  @tf.function
  def call(self, X):
    Z1 = tf.matmul(X, self.W1) + self.B1
    Y1 = tf.nn.tanh(Z1)
    Y = tf.matmul(Y1, self.W2) + self.B2
    return Y
票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/59751851

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档