你好,像一个标题,我试着用合成软件包来表示时间序列GAN,我第一次想把整数加到输出也是数字,但它不是,输出数据是我使用ydata合成(https://github.com/ydataai/ydata-synthetic)的十进制数
这是我制作数据的代码,请帮助我
#导入练习所需的库
from os import path
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from ydata_synthetic.synthesizers import ModelParameters
from ydata_synthetic.preprocessing.timeseries import processed_stock
from ydata_synthetic.synthesizers.timeseries import TimeGAN
import torch
arr_data = np.random.randint(0,600000,(100,1))
#Specific to TimeGANs
#stock_data
seq_len=20
n_seq = 1 #number of columns
hidden_dim=24
gamma=1
noise_dim = 32
dim = 128
batch_size = len(arr_data) - seq_len
log_step = 100
learning_rate = 5e-4
gan_args = ModelParameters(batch_size=batch_size,
lr=learning_rate,
noise_dim=noise_dim,
layers_dim=dim)
lst_temp = []
for i in range(0,len(arr_data) - seq_len):
_x = arr_data[i:i+20]
lst_temp.append(_x)
tens_rand_data = torch.tensor(lst_temp)
lst_rand_data = tens_rand_data.numpy()
synth = TimeGAN(model_parameters=gan_args, hidden_dim=24, seq_len=seq_len, n_seq=n_seq, gamma=1)
synth.train(lst_rand_data, train_steps=10)
synth_data = synth.sample(len(lst_rand_data))
print(synth_data.shape)
cols = ['Car price']
for j, col in enumerate(cols):
df = pd.DataFrame({'Real': lst_rand_data[-1][:, j],'Synthetic': synth_data[-1][:, j]})
df.plot(title = "Car price",secondary_y='Synthetic data', style=['-', '--'])
print(df)发布于 2021-12-14 17:08:28
您的输入应该在安装到MinMaxScaler之前使用TimeGAN进行处理,并且您将始终收到0到1之间的十进制输出,这是因为它的生成器的最后一层激活了sigmoid。
您可以通过以下两种方式更改代码:
https://stackoverflow.com/questions/69316150
复制相似问题