
Multiple regression with Trax

Stack Overflow user
Asked 2020-08-28 06:20:02
Answers: 1 · Views: 104 · Followers: 0 · Votes: 1

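How do I set up a multivariate regression problem with Trax?

The code below raises an AssertionError from the L2Loss object: Invalid shape (16, 2); expected (16,).

Here is my attempt to adapt the sentiment analysis example to a regression problem: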

import os
import trax
from trax import layers as tl
from trax.supervised import training
import numpy
import random


#train_stream = trax.data.TFDS('imdb_reviews', keys=('text', 'label'), train=True)()
#eval_stream = trax.data.TFDS('imdb_reviews', keys=('text', 'label'), train=False)()


def generate_samples():
    # (text, lat/lon)
    data= [
        ("Aberdeen MS",  numpy.array((33.824742, -88.554591)) ),
        ("Aberdeen SD", numpy.array((45.463186, -98.471033))),
        ("Aberdeen WA", numpy.array((46.976432, -123.795781))),
        ("Amite City LA", numpy.array((30.733723, -90.5208))),
        ("Amory MS", numpy.array((33.984789, -88.48001))),
        ("Amouli AS", numpy.array((-14.26556, -170.589772))),
        ("Amsterdam NY", numpy.array((42.953149, -74.19505)))
    ]
    for i in range(1024*8):
        yield random.choice(data)


train_stream = generate_samples()
eval_stream = generate_samples()

model = tl.Serial(
    tl.Embedding(vocab_size=8192, d_feature=256),
    tl.Mean(axis=1),  # Average on axis 1 (length of sentence).
    tl.Dense(2),      # Regress to lat/lon
#    tl.LogSoftmax()   # Produce log-probabilities.
)

# You can print model structure.
print(model)



print(next(train_stream))  # See one example.

data_pipeline = trax.data.Serial(
    trax.data.Tokenize(vocab_file='en_8k.subword', keys=[0]),
    trax.data.Shuffle(),
#    trax.data.FilterByLength(max_length=2048, length_keys=[0]),
    trax.data.BucketByLength(boundaries=[   8, 128,],
                             batch_sizes=[256,   64, 4],
                             length_keys=[0]),
    trax.data.AddLossWeights()
  )

train_batches_stream = data_pipeline(train_stream)
eval_batches_stream = data_pipeline(eval_stream)
example_batch = next(train_batches_stream)
print(f'shapes = {[x.shape for x in example_batch]}')  # Check the shapes.



# Training task.
train_task = training.TrainTask(
    labeled_data=train_batches_stream,
#    loss_layer=tl.CrossEntropyLoss(),
    loss_layer=tl.L2Loss(),
    optimizer=trax.optimizers.Adam(0.01),
    n_steps_per_checkpoint=500,
)

# Evaluation task.
eval_task = training.EvalTask(
    labeled_data=eval_batches_stream,
    metrics=[tl.L2Loss(),],
    n_eval_batches=20  # For less variance in eval numbers.
)
# Training loop saves checkpoints to output_dir.
output_dir = os.path.expanduser('~/output_dir/')
training_loop = training.Loop(model,
                              train_task,
                              eval_tasks=[eval_task],
                              output_dir=output_dir)
# Run 2000 steps (batches).
training_loop.run(2000)

1 Answer

Stack Overflow user

Answered 2020-08-28 08:17:38

The problem is probably in the generate_samples() generator: it yields only 1024*8 (= 8192) samples. If I replace the line

for i in range(1024*8):

with

while True:

so that an infinite number of samples is generated, your example works on my machine.
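For reference, a minimal sketch of the generator with that one-line change applied. To keep it dependency-free, the coordinates are plain tuples rather than the `numpy.array` values used in the question, and the data list is shortened to two entries. The `DATA` name and the `itertools.islice` check are illustrative additions, not part of the original code.

```python
import itertools
import random

# Two of the (text, (lat, lon)) pairs from the question.
# Assumption: plain tuples stand in for the numpy arrays used there.
DATA = [
    ("Aberdeen MS", (33.824742, -88.554591)),
    ("Aberdeen SD", (45.463186, -98.471033)),
]


def generate_samples():
    """Yield random (text, lat/lon) pairs forever instead of stopping at 8192."""
    while True:
        yield random.choice(DATA)


# Drawing more than 8192 samples no longer exhausts the stream.
samples = list(itertools.islice(generate_samples(), 10000))
print(len(samples))
```

Because the generator never terminates, anything that consumes it must stop on its own, as the training loop does after the requested number of steps.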

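Because generate_samples() yields only 8192 samples, train_batches_stream produces only 32 batches of 256 samples, so you can train for at most 32 steps. You are asking for 2000 steps, however.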
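The arithmetic behind that limit can be checked with a few lines. This sketch assumes, as the answer does, that every sample lands in the first bucket with batch size 256. The helper name `max_full_batches` is hypothetical.

```python
def max_full_batches(n_samples: int, batch_size: int) -> int:
    """Number of complete batches a finite stream of n_samples can supply."""
    return n_samples // batch_size


n_samples = 1024 * 8        # samples yielded by the original generator
batch_size = 256            # batch size of the first bucket in BucketByLength
n_steps_requested = 2000    # argument passed to training_loop.run()

n_batches = max_full_batches(n_samples, batch_size)
print(n_batches)                       # 32
print(n_batches < n_steps_requested)   # True: the stream runs out long before step 2000
```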

Votes: 1
Original page content provided by Stack Overflow. Translation support provided by Tencent Cloud Xiaowei's IT-domain engine.
Original link:

https://stackoverflow.com/questions/63624890
