首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >N元组网络的Keras (稀疏输入)

N元组网络的Keras (稀疏输入)
EN

Stack Overflow用户
提问于 2021-01-14 16:06:50
回答 1查看 99关注 0票数 0

我试着用keras训练N元组网络。N元组网络只是一个热激活模式的稀疏数组.假设棋盘上有64个正方形,每个正方形包含可能的N类棋子,所以总是有64个激活的棋盘,用于64*N可能的参数,并存储为2d数组64。或者每一个可能的2x2平方,所以N^4可能对每个这样的方格进行配置。这类网络是线性的,并将输出1值。培训是一个很好的,老SGD和类似的。

我使用c++中的代码、使用查找表和求和成功地训练了网络。但是我试着去做keras,因为keras允许不同的优化算法,GPU的使用等等。首先,我把2d数组变成了大矢量,但很快它就变得不切实际了。有数千个可能的参数,其中只有少数(固定)的1,其余的是零。

我想知道在keras (或类似的库)中是否可以使用这样的训练数据: 13,16,11,11,5,...,3,其中这些数字将是索引,而不是使用0,0,0,1,0,0,0,......,1,0,0,0,....,1,0,0,0,.

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2021-01-14 23:45:33

您可以使用,tf.sparse.SparseTensor(...),然后设置sparse=True,用于Tf.keras.Input(.)

代码语言:javascript
复制
def sparse_one_hot(y):

  m = len(y)
  n_classes = len(tf.unique(tf.squeeze(y))[0])

  dim2 = tf.range(m, dtype='int64')[:, None]
  indices = tf.concat([y, dim2], axis=1)

  ones = tf.ones(shape=(m, ), dtype='float32')

  sparse_y = tf.sparse.SparseTensor(indices, ones, dense_shape=(m, n_classes))
  
  return sparse_y
代码语言:javascript
复制
import tensorflow as tf

y = tf.random.uniform(shape=(10, 1), minval=0, maxval=4, dtype=tf.int64)

sparse_y = sparse_one_hot(y) # sparse_y.values, sparse_y.indices

# set sparse=True, for Input
# tf.keras.Input(..., sparse=True, ...)
票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/65722516

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档