
Large TensorFlow recommender model with a ScaNN index memory bottleneck

Stack Overflow user
Asked on 2022-08-03 16:13:32
1 answer · 180 views · 0 followers · 3 votes

I have a relatively large TF retrieval model built with the TFRS library. It uses a ScaNN layer for indexing the recommendations. When I try to save the model via the model.save() method, I run into a host-memory problem. I am running the official TF 2.9.1 Docker container with TFRS on a VM in the cloud, with 28 GB of RAM available while trying to save the model.

Here is the quickstart example:

Basically, we create the first embedding model:

user_model = tf.keras.Sequential([
    tf.keras.layers.StringLookup(
        vocabulary=unique_user_ids, mask_token=None),
    # We add an additional embedding to account for unknown tokens.
    tf.keras.layers.Embedding(len(unique_user_ids) + 1, embedding_dimension)
])
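
The quickstart also builds a matching movie tower and a retrieval task; the movie_model and task referenced below come from there. A sketch along the lines of the TFRS quickstart, assuming the usual unique_movie_titles vocabulary and a movies dataset of candidate titles:

movie_model = tf.keras.Sequential([
    tf.keras.layers.StringLookup(
        vocabulary=unique_movie_titles, mask_token=None),
    tf.keras.layers.Embedding(len(unique_movie_titles) + 1, embedding_dimension)
])

# The retrieval task bundles the training loss with FactorizedTopK metrics.
task = tfrs.tasks.Retrieval(
    metrics=tfrs.metrics.FactorizedTopK(
        movies.batch(128).map(movie_model)
    )
)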

Then we create the model:

class MovielensModel(tfrs.Model):

  def __init__(self, user_model, movie_model, task):
    super().__init__()
    self.movie_model: tf.keras.Model = movie_model
    self.user_model: tf.keras.Model = user_model
    self.task: tf.keras.layers.Layer = task

  def compute_loss(self, features: Dict[Text, tf.Tensor], training=False) -> tf.Tensor:
    # We pick out the user features and pass them into the user model.
    user_embeddings = self.user_model(features["user_id"])
    # And pick out the movie features and pass them into the movie model,
    # getting embeddings back.
    positive_movie_embeddings = self.movie_model(features["movie_title"])

    # The task computes the loss and the metrics.
    return self.task(user_embeddings, positive_movie_embeddings)
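
Before any indexing happens, the model is compiled and trained, roughly as in the TFRS quickstart (ratings is the batched MovieLens interactions dataset; the optimizer settings are the quickstart's, not mine):

model = MovielensModel(user_model, movie_model, task)
model.compile(optimizer=tf.keras.optimizers.Adagrad(learning_rate=0.5))

# Train on batched {"user_id", "movie_title"} examples.
model.fit(ratings.batch(4096), epochs=3)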

Next, we create the ScaNN indexing layer:

scann_index = tfrs.layers.factorized_top_k.ScaNN(model.user_model)

scann_index.index_from_dataset(
  tf.data.Dataset.zip((movies.batch(100), movies.batch(100).map(model.movie_model)))
)

# Get recommendations.
_, titles = scann_index(tf.constant(["42"]))
print(f"Recommendations for user 42: {titles[0, :3]}")

Finally, the model is sent off to be saved. The namespace_whitelist option is needed because the ScaNN ops live outside TensorFlow's default op namespace:

import os
import tempfile

# Export the query model.
with tempfile.TemporaryDirectory() as tmp:
   path = os.path.join(tmp, "model")

   # Save the index.
   tf.saved_model.save(
      scann_index,
      path,
      options=tf.saved_model.SaveOptions(namespace_whitelist=["Scann"])
   )

   # Load it back; can also be done in TensorFlow Serving.
   loaded = tf.saved_model.load(path)

   # Pass a user id in, get top predicted movie titles back.
   scores, titles = loaded(["42"])

   print(f"Recommendations: {titles[0][:3]}")

This is the part where the problem occurs:

   # Save the index.
   tf.saved_model.save(
      scann_index,
      path,
      options=tf.saved_model.SaveOptions(namespace_whitelist=["Scann"])
   )

I'm not sure whether there is a memory leak or something else, but when I train the model on 5M+ records, I can watch the host's memory spike to 100%, at which point the process gets killed. If I train on a smaller dataset there is no problem, so I know the code itself is fine.

Can anyone suggest how to get around this memory bottleneck when saving a large ScaNN retrieval model, so that I can eventually load it back for inference?


1 Answer

Stack Overflow user

Answered on 2022-08-22 09:23:08

I assume you are saving the TF model after training is complete. For the saved model, you only need the trained weights from the model.

You can try the following code:

# Imports assumed by this snippet (they are not shown in the original answer):
import pandas as pd
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import metrics
import scann  # note: `scann.Scann` below looks like a custom wrapper; the public scann package exposes a builder-style API

# `df`, `embedding`, `seed`, the scann_* hyperparameters, `unique_sku_list`
# and `ss_model_dir` are assumed to be defined elsewhere in the author's code.
sku_ids = df['SKU_ID']
sku_ids_list = sku_ids.to_list()

# Build a distance matrix over the SKU embeddings.
q = embedding(sku_ids, output_mode='distance_matrix')
dist_mat = tf.cast(q, tf.float32)

tree = scann.Scann(n_tables=scann_tables_file_name,
                   n_clusters_per_table=scann_clusters_file_name,
                   dimension=embedding_dimensions,
                   space_type=dist_mat.dtype,
                   metric_type=tf.float32,
                   random_seed=seed,
                   transport_dtype=tf.float32,
                   symmetrize_query_and_dataset=True,
                   num_neighbors_per_table=scann_tables_number_of_neighbors)

q = tree.build_index(dist_mat)
p = tree.run(dist_mat)

# A small scoring head trained on the index-transformed features.
model = keras.models.Sequential([
    keras.layers.Dense(1, use_bias=False, activation='linear', dtype=tf.float32),
    keras.layers.Activation('sigmoid')
])

model.compile(
    keras.optimizers.Adam(1e-3),
    'binary_crossentropy', metrics=[metrics.BinaryAccuracy()])

# Manual batching loop: one model.fit() call per slice of the distance matrix.
idx = -1
number_of_epochs = 10
optimizer = keras.optimizers.Adam(1e-3)
optimizer_state = None
random_seed = seed
callbacks = [
    keras.callbacks.EarlyStopping(
        monitor='binary_accuracy', mode='max',
        patience=10, restore_best_weights=True)]
batch_size = 1000
total_records = len(sku_ids)
epochs = number_of_epochs
epochs_completed = 0
while epochs_completed < epochs:
    idx += 1
    if idx * batch_size >= total_records:
        idx = 0
        epochs_completed += 1
        optimizer_state = None
    print("training epoch: {}".format(idx))
    q_ = tree.transform(dist_mat[idx * batch_size : (idx + 1) * batch_size])
    p_ = tree.transform(dist_mat)
    y = p_[:, :, 0]
    print("callbacks: {}".format(callbacks))
    print("model compile: {}".format(model.compile))
    # The original snippet passed `shuffle` and `initial_epoch` twice, which
    # is a Python SyntaxError; each keyword is given once here.
    model.fit(q_, y, epochs=1, batch_size=batch_size,
              callbacks=callbacks,
              validation_split=0.2,
              verbose=0,
              shuffle=True,
              initial_epoch=0,
              steps_per_epoch=None,
              validation_steps=None,
              validation_batch_size=None,
              validation_freq=1,
              class_weight=None,
              max_queue_size=10,
              workers=1,
              use_multiprocessing=False)

sku_ids_tensor = tf.constant(sku_ids_list,
                             shape=[len(sku_ids_list), 1],
                             dtype=tf.int64)

print("sku_ids_tensor shape: {}".format(sku_ids_tensor.shape))
tree_tensor = tree.transform(dist_mat)
print("tree_tensor shape: {}".format(tree_tensor.shape))
predictions = tf.constant(tf.sigmoid(model.predict(tree_tensor)),
                          dtype=tf.float32)
print("predictions shape: {}".format(predictions.shape))
# Cast the int64 ids to float32 so they can be concatenated with the scores.
recommendations = tf.concat(
    [tf.cast(sku_ids_tensor, tf.float32), predictions], axis=1)
print("recommendations shape: {}".format(recommendations.shape))

retrieval_user_sku_recommendations = []

for u in unique_sku_list:
    print("u: {}".format(u))
    user_skus = sku_ids[sku_ids.isin([u])]
    print("user_skus: {}".format(user_skus))
    user_sku_id = user_skus.index[0]
    print("user_sku_id: {}".format(user_sku_id))
    # Boolean-mask the tensor with a NumPy mask; numpy-style boolean
    # indexing is not supported on a tf.Tensor directly.
    user_sku_recommendations = tf.boolean_mask(
        recommendations, sku_ids.isin([u]).to_numpy())
    print("user_sku_recommendations: {}".format(user_sku_recommendations))
    retrieval_user_sku_recommendations.append(user_sku_recommendations)

retrieval_skus_df = pd.DataFrame(sku_ids_list, columns=['SKU_ID'])
retrieval_skus_df['SKU_ID'] = retrieval_skus_df['SKU_ID'].astype(int)
retrieval_skus_df.head()

user_sku_recommendations_list = []
for sku in retrieval_skus_df['SKU_ID']:
    for u in unique_sku_list:
        print("sku: {}".format(sku))
        print("u: {}".format(u))
        if sku == u:
            user_skus = sku_ids[sku_ids.isin([sku])]
            user_sku_id = user_skus.index[0]
            user_sku_recommendations = tf.boolean_mask(
                recommendations, sku_ids.isin([sku]).to_numpy())
            user_sku_recommendations_list.append(user_sku_recommendations)

# Save only the small scoring model; the ScaNN index itself is never
# serialized, which keeps the save path light on memory.
tf.saved_model.save(model, ss_model_dir)
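
To connect this back to the question's TFRS setup, the same "save only the trained weights" idea can be sketched as below. This is an illustration under assumptions: the checkpoint path and the rebuild-at-serving-time step are mine, not part of the original answer.

# Save only the trained tower variables; the checkpoint stays small because
# it holds embeddings and dense weights, not the ScaNN search structures.
model.save_weights("/tmp/towers_ckpt")

# At serving time: recreate the model, restore the weights, then re-index.
fresh = MovielensModel(user_model, movie_model, task)
fresh.load_weights("/tmp/towers_ckpt")

scann_index = tfrs.layers.factorized_top_k.ScaNN(fresh.user_model)
scann_index.index_from_dataset(
    tf.data.Dataset.zip((movies.batch(100),
                         movies.batch(100).map(fresh.movie_model)))
)
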
Votes: 0
Original content provided by Stack Overflow: https://stackoverflow.com/questions/73224541