首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >使用StellarGraph创建嵌入是不可复制的

使用StellarGraph创建嵌入是不可复制的
EN

Stack Overflow用户
提问于 2019-10-24 21:25:47
回答 1查看 822关注 0票数 1

我正在使用StellarGraph (一个奇妙的图形神经网络包),并试图为特定的图/特征集创建嵌入。不幸的是,每次创建/训练图形时,嵌入都是不同的,尽管每次都提供相同的信息。

是这个错误,还是我不正确地使用StellarGraph?

下面是说明这一问题的代码:

代码语言:javascript
复制
import networkx as nx
import random
import numpy as np
import pandas as pd
import keras
import stellargraph as sg
from stellargraph.mapper import GraphSAGELinkGenerator, GraphSAGENodeGenerator
from stellargraph.layer import GraphSAGE, link_classification
from stellargraph.data import UnsupervisedSampler

# Establish random seed
RANDOM_SEED = 42
random.seed(RANDOM_SEED)

# Create a graph from well-known karate club data
print(f"Creating graph")
graph = nx.karate_club_graph()

# Create features for each node
print(f"Creating features")
features = []
nodes = list(graph.nodes)
columns = ["c-" + str(x) for x in range(10)]
nodes.sort()
for node in nodes:
    f = {c: random.random() for c in columns}
    features.append(f)

features_df = pd.DataFrame(features)
print(f"features_df: \n{features_df}")

for i in range(2):
    print(f"----- Iteration: {i} -----")

    # Create the model and generators
    print(f"Creating the model and generators")
    Gs = sg.StellarGraph(graph, node_features=features_df)
    unsupervisedSamples = UnsupervisedSampler(Gs, nodes=graph.nodes(), length=5, number_of_walks=3, seed=RANDOM_SEED)
    train_gen = GraphSAGELinkGenerator(Gs, 50, [5, 5], seed=RANDOM_SEED).flow(unsupervisedSamples)
    graphsage = GraphSAGE(layer_sizes=[100, 100], generator=train_gen, bias=True, dropout=0.0, normalize="l2")
    x_inp_src, x_out_src = graphsage.node_model(flatten_output=False)
    x_inp_dst, x_out_dst = graphsage.node_model(flatten_output=False)

    x_inp = [x for ab in zip(x_inp_src, x_inp_dst) for x in ab]
    x_out = [x_out_src, x_out_dst]
    edge_embedding_method = "l2"
    prediction = link_classification(output_dim=1, output_act="sigmoid", edge_embedding_method=edge_embedding_method)(x_out)

    # Create and train the Keras model
    model = keras.Model(inputs=x_inp, outputs=prediction)
    learning_rate = 1e-2
    model.compile(
        optimizer=keras.optimizers.Adam(lr=learning_rate),
        loss=keras.losses.binary_crossentropy,
        metrics=[keras.metrics.binary_accuracy])

    _ = model.fit_generator(train_gen, epochs=5, verbose=2, use_multiprocessing=False, workers=1, shuffle=False)

    # Create the embeddings
    print(f"Creating the embeddings")
    nodes = list(graph.nodes)
    nodes.sort()
    print(f"Nodes: {nodes}")

    # Create a generator that serves up nodes for use in embedding prediction / creation
    node_gen = GraphSAGENodeGenerator(Gs, 50, [5, 5], seed=RANDOM_SEED).flow(nodes)

    embedding_model = keras.Model(inputs=x_inp_src, outputs=x_out_src)
    embeddings = embedding_model.predict_generator(node_gen, workers=4, verbose=1)
    embeddings = embeddings[:, 0, :]

    np.set_printoptions(threshold=10)
    print(f"embeddings: {embeddings.shape} \n{embeddings}")

在执行代码时,有许多调试(打印输出)语句。(输出示例如下所示)注意,尽管有相同的输入、图形配置、模型配置和随机值,但是嵌入是不同的。

代码语言:javascript
复制
----- Iteration: 0 -----
:
:
1/1 [==============================] - 0s 58ms/step
embeddings: (34, 100) 
[[-0.10566715  0.02253576 -0.18743701 ... -0.1028127   0.03689012
  -0.02482301]
 [-0.03171733  0.01606975 -0.08616363 ... -0.11775644  0.0429472
  -0.02371055]
 [-0.05802531  0.03910012 -0.10229243 ... -0.15050544  0.06637941
  -0.01950052]
 ...
 [ 0.03011296  0.08852117 -0.01836969 ... -0.154132    0.03844732
  -0.08643046]
 [ 0.01052345 -0.0123206   0.08913474 ... -0.11741614  0.03202919
  -0.04432516]
 [ 0.01951274  0.06263477  0.07959272 ... -0.10350229  0.05735112
  -0.0368157 ]]
:
:
----- Iteration: 1 -----
embeddings: (34, 100) 
[[ 0.11182436 -0.02642134  0.01168384 ...  0.10322241 -0.01680471
  -0.03918815]
 [ 0.02391489  0.02674667 -0.00091334 ...  0.12946768 -0.02389602
  -0.01414653]
 [ 0.08718258 -0.01711811 -0.05704292 ...  0.13477756 -0.00658288
  -0.05889895]
 ...
 [ 0.06843725 -0.13134597 -0.10870655 ...  0.11091235 -0.05146989
  -0.06138216]
 [-0.00593233 -0.05901312 -0.02113489 ... -0.01590953 -0.02516254
  -0.02280537]
 [ 0.00871993 -0.04059998 -0.07237951 ... -0.01590569 -0.00954109
  -0.01116194]]
EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2020-01-30 05:04:20

这以前是stellargraph中的一个bug,现在已经在v0.9.0 https://github.com/stellargraph/stellargraph/releases/tag/v0.9.0中解决了。

无监督的GraphSAGE现在已经更新和测试的重现性。确保所有种子都已设置,运行相同的管道应该提供可重复的嵌入。

目前,“确保所有种子都被设置”用于无监督的GraphSAGE意味着:

在构造random

  • providing和对象时,
    • 修复了这些外部包的种子:numpytensorflownumpy。这些类用于执行随机游动和邻里抽样。
票数 2
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/58549321

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档