首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >Tensorflow.js数据集到张量?

Tensorflow.js数据集到张量?
EN

Stack Overflow用户
提问于 2019-03-02 12:50:18
回答 2查看 1.7K关注 0票数 0

tf.data.Dataset中的底层“数据示例”是平面数组时,有没有推荐/有效的方法将Dataset转换为Tensor

我使用tf.data.csv来读取和解析CSV,但随后希望使用Tensorflow.js核心API将数据处理为tf.Tensors

EN

回答 2

Stack Overflow用户

发布于 2019-03-02 15:26:46

tf.data.Dataset.iterator()返回一个迭代器的承诺。

代码语言:javascript
复制
const it = await flattenedDataset.iterator()
   const t = []
   // read only the data for the first 5 rows
   // all the data need not to be read once 
   // since it will consume a lot of memory
   for (let i = 0; i < 5; i++) {
        let e = await it.next()
      t.push(...e.value)
   }
  tf.concat(await t, 0)

使用for await of

代码语言:javascript
复制
const asyncIterable = {
  [Symbol.asyncIterator]() {
    return {
      i: 0,
      async next() {
        if (this.i < 5) {
          this.i++
          const e = await it.next()
          return Promise.resolve({ value: e.value, done: false });
        }

        return Promise.resolve({ done: true });
      }
    };
  }
};

  const t = []
  for await (let e of asyncIterable) {
        if(e) {
          t.push(e)
        }
   }

代码语言:javascript
复制
const csvUrl =
'https://storage.googleapis.com/tfjs-examples/multivariate-linear-regression/data/boston-housing-train.csv';

(async function run() {
   // We want to predict the column "medv", which represents a median value of
   // a home (in $1000s), so we mark it as a label.
   const csvDataset = tf.data.csv(
     csvUrl, {
       columnConfigs: {
         medv: {
           isLabel: true
         }
       }
     });

   // Number of features is the number of column names minus one for the label
   // column.
   const numOfFeatures = (await csvDataset.columnNames()).length - 1;

   // Prepare the Dataset for training.
   const flattenedDataset =
     csvDataset
     .map(([rawFeatures, rawLabel]) =>
       // Convert rows from object form (keyed by column name) to array form.
       [...Object.values(rawFeatures), ...Object.values(rawLabel)])
   			.batch(1)
  
	const it = await flattenedDataset.iterator()
  const asyncIterable = {
  [Symbol.asyncIterator]() {
    return {
      i: 0,
      async next() {
        if (this.i < 5) {
          this.i++
          const e = await it.next()
          return Promise.resolve({ value: e.value, done: false });
        }

        return Promise.resolve({ done: true });
      }
    };
  }
};
  
  const t = []
  for await (let e of asyncIterable) {
    	if(e) {
          t.push(e)
        }
   }
  console.log(tf.concat(t, 0).shape)
})()
代码语言:javascript
复制
<html>
  <head>
    <!-- Load TensorFlow.js -->
    <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@0.14.1"> </script>
  </head>

  <body>
  </body>
</html>

票数 0
EN

Stack Overflow用户

发布于 2019-03-03 05:24:30

请注意,通常不推荐使用此工作流,因为将主JavaScript内存中的所有数据具体化可能不适用于大型CSV数据集。

您可以使用tf.data.Dataset对象的toArray()方法。例如:

代码语言:javascript
复制
  const csvUrl =
'https://storage.googleapis.com/tfjs-examples/multivariate-linear-regression/data/boston-housing-train.csv';

  const csvDataset = tf.data.csv(
     csvUrl, {
       columnConfigs: {
         medv: {
           isLabel: true
         }
       }
     }).batch(4);

  const tensors = await csvDataset.toArray();
  console.log(tensors.length);
  console.log(tensors[0][0]);
票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/54955341

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档