
TensorFlow - Cast string to float is not supported error?

Stack Overflow user
Asked 2021-05-14 00:32:15
Answers: 1 · Views: 507 · Follows: 0 · Votes: 0

I tried to run one of the StellarGraph examples, but ran into a strange error:

tensorflow/core/framework/op_kernel.cc:1744] OP_REQUIRES failed at cast_op.cc:121 : Unimplemented: Cast string to float is not supported

This is the example code I used:

import pandas as pd
import numpy as np

import stellargraph as sg
from stellargraph.mapper import PaddedGraphGenerator
from stellargraph.layer import GCNSupervisedGraphClassification
from stellargraph import StellarGraph

from stellargraph import datasets

from sklearn import model_selection
from IPython.display import display, HTML

from tensorflow.keras import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.layers import Dense
from tensorflow.keras.losses import binary_crossentropy
from tensorflow.keras.callbacks import EarlyStopping
import tensorflow as tf
import matplotlib.pyplot as plt

dataset = datasets.MUTAG()
display(HTML(dataset.description))
graphs, graph_labels = dataset.load()

print(graphs[0].info())
print(graphs[1].info())

summary = pd.DataFrame(
    [(g.number_of_nodes(), g.number_of_edges()) for g in graphs],
    columns=["nodes", "edges"],
)
print(summary.describe().round(1))

generator = PaddedGraphGenerator(graphs=graphs)

def create_graph_classification_model(generator):
    gc_model = GCNSupervisedGraphClassification(
        layer_sizes=[64, 64],
        activations=["relu", "relu"],
        generator=generator,
        dropout=0.5,
    )
    x_inp, x_out = gc_model.in_out_tensors()
    predictions = Dense(units=32, activation="relu")(x_out)
    predictions = Dense(units=16, activation="relu")(predictions)
    predictions = Dense(units=1, activation="sigmoid")(predictions)

    # Let's create the Keras model and prepare it for training
    model = Model(inputs=x_inp, outputs=predictions)
    model.compile(optimizer=Adam(0.005), loss=binary_crossentropy, metrics=["acc"])

    return model

epochs = 200  # maximum number of training epochs
folds = 10  # the number of folds for k-fold cross validation
n_repeats = 5  # the number of repeats for repeated k-fold cross validation
es = EarlyStopping(
    monitor="val_loss", min_delta=0, patience=25, restore_best_weights=True
)

def train_fold(model, train_gen, test_gen, es, epochs):
    history = model.fit(
        train_gen, epochs=epochs, validation_data=[test_gen], verbose=0, callbacks=es,
    )
    # calculate performance on the test data and return along with history
    test_metrics = model.evaluate(test_gen, verbose=0)
    test_acc = test_metrics[model.metrics_names.index("acc")]

    return history, test_acc

def get_generators(train_index, test_index, graph_labels, batch_size):
    train_gen = generator.flow(
        train_index, targets=graph_labels.iloc[train_index].values, batch_size=batch_size
    )
    test_gen = generator.flow(
        test_index, targets=graph_labels.iloc[test_index].values, batch_size=batch_size
    )

    return train_gen, test_gen

test_accs = []

stratified_folds = model_selection.RepeatedStratifiedKFold(
    n_splits=folds, n_repeats=n_repeats
).split(graph_labels, graph_labels)

for i, (train_index, test_index) in enumerate(stratified_folds):
    print(f"Training and evaluating on fold {i+1} out of {folds * n_repeats}...")
    train_gen, test_gen = get_generators(
        train_index, test_index, graph_labels, batch_size=30
    )

    model = create_graph_classification_model(generator)

    history, acc = train_fold(model, train_gen, test_gen, es, epochs)

    test_accs.append(acc)

print(
    f"Accuracy over all folds mean: {np.mean(test_accs)*100:.3}% and std: {np.std(test_accs)*100:.2}%"
)

The full error message is:

2021-05-14 03:23:24.176132: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:176] None of the MLIR Optimization Passes are enabled (registered 2)
2021-05-14 03:23:24.982603: W tensorflow/core/framework/op_kernel.cc:1744] OP_REQUIRES failed at cast_op.cc:121 : Unimplemented: Cast string to float is not supported
Traceback (most recent call last):
  File "C:/Users/1/PycharmProjects/University Homework/exmpl.py", line 96, in <module>
    history, acc = train_fold(model, train_gen, test_gen, es, epochs)
  File "C:/Users/1/PycharmProjects/University Homework/exmpl.py", line 63, in train_fold
    history = model.fit(
  File "C:\Users\1\AppData\Local\Programs\Python\Python38\lib\site-packages\tensorflow\python\keras\engine\training.py", line 1183, in fit
    tmp_logs = self.train_function(iterator)
  File "C:\Users\1\AppData\Local\Programs\Python\Python38\lib\site-packages\tensorflow\python\eager\def_function.py", line 889, in __call__
    result = self._call(*args, **kwds)
  File "C:\Users\1\AppData\Local\Programs\Python\Python38\lib\site-packages\tensorflow\python\eager\def_function.py", line 950, in _call
    return self._stateless_fn(*args, **kwds)
  File "C:\Users\1\AppData\Local\Programs\Python\Python38\lib\site-packages\tensorflow\python\eager\function.py", line 3023, in __call__
    return graph_function._call_flat(
  File "C:\Users\1\AppData\Local\Programs\Python\Python38\lib\site-packages\tensorflow\python\eager\function.py", line 1960, in _call_flat
    return self._build_call_outputs(self._inference_function.call(
  File "C:\Users\1\AppData\Local\Programs\Python\Python38\lib\site-packages\tensorflow\python\eager\function.py", line 591, in call
    outputs = execute.execute(
  File "C:\Users\1\AppData\Local\Programs\Python\Python38\lib\site-packages\tensorflow\python\eager\execute.py", line 59, in quick_execute
    tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
tensorflow.python.framework.errors_impl.UnimplementedError:  Cast string to float is not supported
     [[node binary_crossentropy/Cast (defined at /Users/1/PycharmProjects/University Homework/exmpl.py:63) ]] [Op:__inference_train_function_1247]

Function call stack:
train_function

I can't find any place where a float is assigned a string value, so I have no idea what is going on here. Any help is greatly appreciated!
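The "float" here is not a variable in the script. The failing node is `binary_crossentropy/Cast`: Keras's binary cross-entropy casts the targets (`y_true`) to the float dtype of the predictions, so the string is most likely the label values passed in via `targets=graph_labels.iloc[...].values`. A sigmoid output trained with binary cross-entropy, `-(y*log(p) + (1-y)*log(1-p))`, also assumes targets coded as 0/1, so labels such as -1/1 would give a wrong loss even if they were numeric. The sketch below is a minimal pre-flight check under those assumptions; `require_numeric_binary_targets` is a hypothetical helper name, not part of StellarGraph or Keras.

```python
import numpy as np
import pandas as pd


def require_numeric_binary_targets(labels):
    """Hypothetical pre-flight check (not part of StellarGraph or Keras).

    Accepts a pandas Series or DataFrame of labels and returns a float32
    array of shape (n_samples, n_columns), raising early if the labels are
    not numeric 0/1 values suitable for a sigmoid + binary_crossentropy model.
    """
    frame = labels.to_frame() if isinstance(labels, pd.Series) else labels

    # binary_crossentropy casts y_true to the prediction dtype; strings cannot be cast.
    non_numeric = [str(dt) for dt in frame.dtypes if not pd.api.types.is_numeric_dtype(dt)]
    if non_numeric:
        raise TypeError(f"non-numeric label dtypes {non_numeric}; encode labels as 0/1 first")

    values = frame.to_numpy(dtype=np.float32)

    # Binary cross-entropy assumes targets in {0, 1}; e.g. -1/1 labels give a wrong loss.
    unexpected = set(np.unique(values).tolist()) - {0.0, 1.0}
    if unexpected:
        raise ValueError(f"binary targets must be 0/1, found {sorted(unexpected)}")
    return values
```

Called on `graph_labels` right after `dataset.load()`, a check like this would, if the labels really are strings as the error implies, raise a `TypeError` up front instead of failing deep inside `model.fit`.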


1 Answer

Stack Overflow user

Accepted answer

Posted 2021-06-01 10:49:14

Apparently, adding the line

graph_labels = pd.get_dummies(graph_labels, drop_first=True)

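before creating the PaddedGraphGenerator seems to solve the problem. The line one-hot encodes the labels and, with drop_first=True, keeps a single numeric 0/1 column, so binary_crossentropy no longer receives string targets.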
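The effect of that line can be checked on a small stand-in Series. This is a sketch under an assumption: that the MUTAG labels arrive as the strings "-1" and "1" (MUTAG's two classes are commonly labelled -1 and 1, and the error message indicates string dtype). The values below are made up for illustration, not the real dataset.

```python
import pandas as pd

# Stand-in for the MUTAG labels: ASSUMED to be the strings "-1" and "1".
graph_labels = pd.Series(["1", "-1", "1", "1", "-1"], name="label")

# One indicator column per class; drop_first=True drops the first sorted
# class ("-1"), leaving a single column that is 1 where the label is "1".
encoded = pd.get_dummies(graph_labels, drop_first=True)

# Older pandas returns uint8 and pandas >= 2.0 returns bool; both are numeric,
# but an explicit float32 cast matches the model's sigmoid output.
targets = encoded.astype("float32")
print(targets)

# Shape of what generator.flow(..., targets=...) would receive for two indices.
print(targets.iloc[[0, 1]].values.shape)
```

Because the encoded labels are a one-column DataFrame, `graph_labels.iloc[train_index].values` in the question's `get_generators` then yields a 2-D array of shape `(n, 1)`, which lines up with the `Dense(units=1, activation="sigmoid")` output. An explicit mapping such as `(graph_labels.astype(int) == 1).astype("float32")` is an alternative not taken from the answer that keeps a 1-D target; either way, the point is that the targets become numeric 0/1.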

Votes: 2
Original page content provided by Stack Overflow; translation support by Tencent Cloud Xiaowei's IT-domain engine.
Original link:

https://stackoverflow.com/questions/67527713
