首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >使用StellarGraph的“监视您的步进”模型不适用于GPU。

使用StellarGraph的“监视您的步进”模型不适用于GPU。
EN

Stack Overflow用户
提问于 2020-11-02 13:08:32
回答 1查看 596关注 0票数 1

我正在尝试用WatchYourStep算法来训练一个使用StellarGraph的大图嵌入。

由于某种原因,模型只在一个上训练,而不使用GPU

使用:

  • TensorFlow-gpu 2.3.1
  • 有2个GPU,Cuda10.1
  • 运行在nvidia-码头容器中。(tf.debugging.set_log_device_placement(True))
  • I试图在with tf.device('/GPU:0'):
  • I下运行,已经尝试用tf.distribute.MirroredStrategy().
  • Tried运行它来卸载tensorflow并重新安装tensorflow-gpu.

尽管如此,在运行nvidia时,我没有看到GPU上有任何活动,而且培训非常缓慢。

如何调试?

代码语言:javascript
复制
def watch_your_step_model():
    '''use the config to geenrate the WatchYourStep model'''
    cfg = load_config()
    generator           = generator_for_watch_your_step()
    num_walks           = cfg['num_walks']
    embedding_dimension = cfg['embedding_dimension']
    learning_rate       = cfg['learning_rate']
    
    wys = WatchYourStep(
        generator,
        num_walks=num_walks,
        embedding_dimension=embedding_dimension,
        attention_regularizer=regularizers.l2(0.5),
    )
    
    x_in, x_out = wys.in_out_tensors()
    model = Model(inputs=x_in, outputs=x_out)
    model.compile(loss=graph_log_likelihood, optimizer=optimizers.Adam(learning_rate))
    return model, generator, wys

def train_watch_your_step_model(epochs = 3000):
    cfg = load_config()
    batch_size      = cfg['batch_size']
    steps_per_epoch = cfg['steps_per_epoch']
    callbacks, checkpoint_file = watch_your_step_callbacks(cfg)
    
    # strategy = tf.distribute.MirroredStrategy()
    # print('Number of devices: {}'.format(strategy.num_replicas_in_sync))
    # with strategy.scope():
    
    model, generator, wys = watch_your_step_model()

    train_gen = generator.flow(batch_size=batch_size, num_parallel_calls=8)
    train_gen.prefetch(20480000)

    history = model.fit(
        train_gen, 
        epochs=epochs, 
        verbose=1, 
        steps_per_epoch=steps_per_epoch,
        callbacks = callbacks
    )
     
    copy_last_trained_wys_weights_to_data()
    
    return history, checkpoint_file

with tf.device('/GPU:0'):
    train_watch_your_step_model()
EN

回答 1

Stack Overflow用户

发布于 2021-01-04 17:17:49

我只是按照下面的说明:https://github.com/stellargraph/stellargraph/issues/546

对我起作用了。

基本上,您必须从stellargraph编辑文件setup.py,并删除tensorflow要求(第25和27行https://github.com/stellargraph/stellargraph/blob/develop/setup.py)。

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/64646197

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档