我正在与特拉克斯合作,这是一个由Google构建的框架,用于使用深度学习模型作为TensorFlow的替代方案。作为TensorFlow开发人员,我非常习惯于使用model.summary()方法(文档化的这里)来显示完整的模型摘要,例如:
model.summary()
Model: "model"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_1 (InputLayer) [(None, 16, 303)] 0
_________________________________________________________________
bidirectional (Bidirectional (None, 16, 256) 442368
_________________________________________________________________
time_distributed (TimeDistri (None, 16, 22) 5654
=================================================================
Total params: 448,022
Trainable params: 448,022
Non-trainable params: 0特拉克斯有类似的东西吗?
发布于 2022-07-28 08:24:53
目前,在Trax中似乎没有类似于.summary()的方法;最接近的是您可以打印模型。适应文档中的示例
from trax import layers as tl
model = tl.Serial(
tl.Embedding(vocab_size=8192, d_feature=256),
tl.Mean(axis=1), # Average on axis 1 (length of sentence).
tl.Dense(2), # Classify 2 classes.
)
print(model)结果:
Serial[
Embedding_8192_256
Mean
Dense_2
]虽然没有像Tensorflow的model.summary()那样详细,但是打印输出中仍然有有用的信息:注意嵌入层的参数包含在打印输出中;还请注意,如果将模型的最后一层更改为tl.Dense(3),相应的输出将更改为Dense_3。
https://stackoverflow.com/questions/73143800
复制相似问题