-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdiffusion_signal_4.py
More file actions
118 lines (103 loc) · 11.4 KB
/
diffusion_signal_4.py
File metadata and controls
118 lines (103 loc) · 11.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import EMA
import torch.utils.data
from MLPDiffusion import *
# Abscissa: 800 evenly spaced samples mapping {1..800} onto [0, 1],
# shaped as a column vector so it can be concatenated with y below.
x = (torch.arange(1, 801).reshape(800, 1) - 1) / 799
# Raw signal: 800 hand-recorded measurements (unnormalized), in sample order.
lst=[147.741, 149.912, 149.407, 143.666, 145.694, 152.160, 143.794, 145.824, 144.888, 139.007, 145.912, 146.720, 147.598, 149.353, 150.959, 150.355, 135.796, 153.881, 144.815, 148.722, 140.750, 153.374, 146.950, 146.584, 151.527, 150.840, 148.849, 147.470, 145.923, 146.889, 147.034, 148.733, 147.643, 143.527, 150.587, 141.901, 147.540, 150.032, 142.701, 149.981, 146.962, 150.180, 139.487, 148.097, 150.173, 147.034, 141.807, 150.150, 148.457, 152.562, 149.416, 150.971, 144.991, 145.157, 150.696, 150.454, 147.079, 141.786, 141.411, 150.048, 142.118, 152.596, 148.003, 148.252, 148.624, 142.316, 148.677, 145.439, 144.154, 149.722, 147.375, 150.102, 145.668, 136.851, 139.936, 146.556, 148.534, 140.727, 140.635, 149.780, 149.840, 150.284, 144.063, 145.804, 134.899, 149.308, 142.406, 150.466, 151.156, 148.075, 141.614, 147.030, 147.509, 151.175, 152.114, 136.881, 143.802, 149.467, 145.558, 152.058, 139.366, 141.902, 148.539, 150.613, 150.893, 145.742, 145.985, 151.600, 151.243, 150.580, 135.247, 151.345, 149.299, 148.338, 149.222, 147.479, 147.767, 136.224, 151.600, 146.378, 151.191, 149.089, 139.519, 145.665, 143.314, 150.780, 147.440, 149.935, 138.983, 144.215, 149.490, 144.035, 148.114, 140.492, 141.722, 150.715, 150.517, 152.496, 150.745, 149.670, 152.597, 144.963, 145.117, 140.759, 155.108, 148.865, 149.426, 151.098, 147.363, 150.312, 145.158, 128.023, 148.074, 147.999, 147.036, 148.103, 147.873, 142.301, 144.098, 149.363, 149.441, 149.780, 151.697, 152.320, 151.143, 152.247, 149.372, 148.960, 152.353, 152.159, 150.022, 145.040, 147.292, 147.263, 153.294, 141.466, 150.804, 145.610, 146.022, 150.511, 147.562, 149.506, 153.100, 145.790, 151.962, 153.994, 149.066, 146.970, 151.455, 147.550, 152.357, 143.499, 151.097, 152.497, 149.006, 142.770, 144.885, 148.990, 148.925, 144.458, 150.742, 150.065, 151.872, 147.648, 154.372, 152.008, 150.486, 144.855, 148.064, 152.184, 140.785, 149.028, 149.626, 148.651, 150.797, 151.357, 152.541, 152.571, 153.347, 154.687, 143.746, 
147.275, 153.266, 151.020, 150.797, 150.333, 151.395, 145.174, 150.959, 149.927, 149.782, 152.073, 148.134, 145.642, 147.800, 149.532, 147.532, 148.636, 148.131, 150.592, 146.944, 153.888, 150.054, 150.815, 150.433, 151.175, 147.559, 151.070, 143.747, 152.143, 150.039, 150.291, 149.385, 153.815, 151.305, 149.422, 153.130, 153.375, 149.547, 154.254, 152.280, 151.294, 152.346, 155.219, 152.249, 151.220, 152.070, 152.445, 151.388, 155.402, 150.496, 155.905, 154.861, 153.020, 155.269, 153.949, 152.299, 148.305, 152.680, 151.538, 151.998, 155.463, 150.143, 153.752, 153.586, 154.978, 153.120, 153.772, 152.471, 155.197, 151.335, 150.567, 151.958, 154.638, 157.695, 155.763, 154.127, 150.422, 155.720, 155.648, 152.972, 156.382, 155.680, 154.466, 154.593, 157.116, 156.728, 156.818, 151.306, 155.209, 156.733, 154.006, 153.874, 155.097, 154.665, 154.199, 155.533, 155.232, 157.351, 154.666, 157.130, 154.805, 154.576, 159.350, 157.411, 159.444, 154.953, 158.647, 155.812, 158.957, 156.573, 157.152, 158.557, 157.302, 156.926, 157.388, 158.881, 159.212, 159.608, 158.396, 159.155, 158.198, 158.596, 158.132, 156.970, 159.004, 158.198, 159.075, 159.739, 159.360, 159.449, 160.383, 160.117, 159.817, 159.761, 160.765, 162.122, 162.300, 163.393, 162.417, 160.196, 161.987, 162.259, 160.915, 161.524, 162.343, 162.830, 162.993, 162.065, 162.559, 163.457, 163.783, 164.099, 163.965, 164.236, 164.951, 164.675, 166.070, 165.993, 166.568, 165.483, 167.445, 167.361, 168.424, 168.127, 168.536, 168.468, 169.250, 169.812, 170.427, 170.060, 171.475, 172.073, 172.634, 173.738, 174.486, 175.275, 176.438, 177.368, 178.758, 212.955, 182.431, 185.024, 188.769, 195.346, 212.890, 193.284, 187.789, 184.476, 181.921, 180.176, 178.579, 177.279, 176.491, 175.084, 174.473, 173.315, 172.585, 172.045, 171.700, 170.781, 170.127, 169.479, 169.238, 168.721, 168.232, 168.155, 168.208, 167.156, 166.480, 166.543, 165.504, 164.909, 164.850, 165.161, 165.113, 164.608, 164.535, 164.387, 163.577, 163.978, 163.787, 162.949, 
163.001, 162.891, 164.190, 163.256, 159.757, 161.731, 163.624, 161.124, 161.204, 159.736, 161.070, 160.737, 160.491, 159.059, 159.373, 159.844, 159.487, 155.268, 157.179, 159.057, 160.172, 158.851, 159.487, 158.958, 156.874, 159.466, 160.099, 157.377, 157.757, 156.603, 156.206, 154.062, 157.432, 157.367, 155.215, 156.231, 156.665, 154.183, 157.276, 157.687, 157.461, 154.923, 156.442, 158.527, 156.850, 156.554, 155.398, 155.885, 153.614, 156.250, 155.179, 157.150, 158.098, 156.719, 155.187, 156.604, 157.429, 152.826, 155.506, 154.501, 156.837, 153.468, 151.720, 152.339, 152.329, 154.516, 159.251, 153.294, 158.212, 154.967, 155.760, 150.488, 154.309, 156.134, 155.836, 149.917, 155.174, 153.097, 155.722, 154.249, 152.135, 155.510, 153.126, 153.962, 152.449, 153.558, 150.319, 149.510, 153.647, 152.426, 153.092, 154.739, 141.136, 153.202, 151.958, 153.582, 150.844, 147.229, 153.392, 153.824, 153.193, 155.663, 151.203, 150.469, 149.363, 154.941, 151.470, 151.012, 148.255, 150.445, 146.699, 151.088, 150.631, 150.013, 143.732, 149.190, 151.138, 150.461, 149.917, 149.325, 151.852, 149.795, 146.773, 142.513, 148.044, 156.092, 149.984, 155.387, 145.190, 148.452, 153.953, 150.361, 138.452, 151.802, 145.301, 152.615, 149.960, 152.933, 151.639, 153.243, 150.624, 153.091, 149.782, 151.880, 150.035, 153.982, 147.255, 149.665, 147.263, 149.899, 153.403, 147.715, 152.718, 150.227, 148.645, 150.786, 156.014, 149.027, 151.589, 153.205, 149.712, 151.056, 143.958, 151.401, 142.000, 146.356, 149.576, 152.605, 153.740, 154.460, 142.020, 153.195, 146.273, 150.504, 144.915, 144.111, 149.883, 144.998, 151.109, 150.463, 145.941, 149.981, 148.784, 149.861, 149.793, 152.704, 143.804, 153.496, 154.666, 149.086, 144.635, 152.086, 153.547, 147.053, 148.365, 145.241, 146.742, 154.134, 150.392, 145.066, 142.684, 143.487, 150.337, 152.477, 140.536, 149.041, 149.674, 142.015, 146.458, 145.850, 152.229, 145.644, 147.582, 150.376, 142.556, 140.570, 151.289, 145.879, 150.669, 147.918, 152.638, 144.606, 
144.539, 148.944, 151.019, 152.536, 145.460, 152.262, 148.566, 145.402, 144.952, 149.980, 147.926, 151.943, 149.956, 151.250, 150.908, 146.520, 148.897, 148.311, 138.873, 147.952, 151.138, 151.272, 151.158, 150.164, 150.000, 151.639, 149.286, 150.780, 148.098, 148.069, 152.707, 151.752, 152.135, 144.123, 142.145, 149.004, 147.354, 153.986, 145.044, 149.970, 148.776, 144.805, 143.747, 149.825, 147.619, 147.093, 146.977, 147.650, 148.214, 146.625, 144.055, 149.971, 123.598, 143.527, 144.628, 142.117, 148.765, 146.966, 149.979, 146.969, 149.789, 147.880, 143.849, 144.398, 146.475, 152.721, 152.664, 150.211, 144.630, 144.644, 150.949, 149.367, 147.359, 144.647, 143.657, 148.506, 152.987, 148.248, 141.436, 138.874, 148.204, 138.372, 143.431, 149.992, 147.910, 148.011, 148.275, 149.499, 149.205, 150.582, 149.351, 144.445, 147.014, 136.469, 150.595, 153.750, 144.890, 147.877, 138.343, 146.009, 147.967, 146.485, 144.948, 145.971, 141.168, 149.267, 141.635, 142.412, 148.063, 150.530, 146.877, 147.650, 147.203, 151.179, 152.295, 145.303, 147.572, 144.491, 142.364, 143.014, 143.225, 143.205, 143.243, 150.398, 147.349, 151.073, 141.537, 138.800, 151.869, 146.346, 149.535, 139.166, 148.198, 141.775, 150.236]
# Min-max normalize the signal into [0, 1] using fixed bounds matching the
# data's observed extremes, then shrink by 10x so the curve's amplitude is
# small relative to the x-range of [0, 1].
# NOTE: loop variable renamed from 'x' to avoid shadowing the module-level grid.
lst = [(v - 123.5976) / (212.9549 - 123.5976) for v in lst]
lst_tensor = torch.tensor(lst) / 10
y = lst_tensor.reshape(800, 1)
# (x, y) pairs, one row per sample: shape (800, 2).
data = torch.cat([x, y], 1)
# FIX: torch.Tensor(existing_tensor) is a deprecated copy idiom; make the
# float copy explicit instead (same values, same dtype).
dataset = data.clone().float()
data = data.T  # rows become (all-x, all-y) so ax.plot(*data) works below
# Preview the clean signal: unpack the two rows of `data` as the plot's
# x- and y-arguments.
fig, ax = plt.subplots()
xs_row, ys_row = data
ax.plot(xs_row, ys_row, color='red')
ax.axis('off')
# Noise schedule: betas follow a sigmoid ramp from ~1e-5 up to ~5e-3 over the
# diffusion chain, plus the per-step derived quantities used by q_x and the
# reverse sampler.
num_steps = 100
betas = torch.sigmoid(torch.linspace(-6, 6, num_steps)) * (0.5e-2 - 1e-5) + 1e-5
alphas = 1 - betas
alphas_prod = alphas.cumprod(dim=0)  # alpha-bar: running product along the chain
# alpha-bar shifted one step back (alpha-bar_{t-1}, with alpha-bar_{-1} := 1).
alphas_prod_p = torch.cat([torch.ones(1), alphas_prod[:-1]], dim=0)
alphas_bar_sqrt = alphas_prod.sqrt()
one_minus_alphas_bar_log = (1 - alphas_prod).log()
one_minus_alphas_bar_sqrt = (1 - alphas_prod).sqrt()
# Sanity check: every derived tensor has the same shape as the schedule.
_shapes = {alphas.shape, alphas_prod.shape, alphas_prod_p.shape,
           alphas_bar_sqrt.shape, one_minus_alphas_bar_log.shape,
           one_minus_alphas_bar_sqrt.shape}
assert len(_shapes) == 1
print("all the same shape:", betas.shape)
# Closed-form forward diffusion using the reparameterization trick:
#   x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps,  eps ~ N(0, I)
def q_x(x_0, t):
    """Sample x_t ~ q(x_t | x_0) at diffusion step t (reads the module-level schedule)."""
    eps = torch.randn_like(x_0)          # standard Gaussian noise, same shape as x_0
    signal_scale = alphas_bar_sqrt[t]            # sqrt(alpha_bar_t)
    noise_scale = one_minus_alphas_bar_sqrt[t]   # sqrt(1 - alpha_bar_t)
    return signal_scale * x_0 + noise_scale * eps
# Visualize the forward process: 20 snapshots of the dataset diffused to
# evenly spaced steps (0, 5, 10, ... for num_steps=100), on a 2x10 grid.
num_shows = 20
fig, axs = plt.subplots(2, 10, figsize=(28, 3))
plt.rc('text', color='blue')
for i in range(num_shows):
    j, k = divmod(i, 10)  # row/column position in the 2x10 grid
    step = i * num_steps // num_shows
    q_i = q_x(dataset, torch.tensor([step]))
    # Sort by the x-coordinate so the noisy points plot as a connected curve.
    order = torch.argsort(q_i[:, 0])
    axs[j, k].plot(q_i[:, 0][order], q_i[:, 1][order], color='red', lw=1)
    axs[j, k].set_axis_off()
    # FIX: raw string — '\m' in \mathbf is an invalid escape in a plain literal.
    axs[j, k].set_title(r'$q(\mathbf{x}_{' + str(step) + '})$')
plt.show()
# Training setup.
# FIX: the seed was previously assigned but never applied, so runs were not
# reproducible; feed it to the global RNG.
seed = 1234
torch.manual_seed(seed)
print('Training model...')
batch_size = 10  # 800 points total -> 80 mini-batches per epoch
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
num_epoch = 5000
plt.rc('text', color='blue')
model = MLPDiffusion(num_steps)  # noise-prediction network (defined in MLPDiffusion.py)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# Main training loop: minimize the diffusion noise-prediction loss, and every
# 100 epochs sample the full reverse chain to monitor progress.
for t in range(num_epoch):
    for batch_x in dataloader:  # FIX: enumerate index was unused, dropped
        # One optimization step on a random mini-batch.
        loss = diffusion_loss_fn(model, batch_x, alphas_bar_sqrt,
                                 one_minus_alphas_bar_sqrt, num_steps)
        optimizer.zero_grad()  # clear gradients from the previous step
        loss.backward()        # backprop the loss, storing gradients
        # Rescale gradients whose global norm exceeds 1 to avoid explosion.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.)
        optimizer.step()       # Adam update from the clipped gradients
    if t % 100 == 0:
        print(loss)
        # Sample the whole reverse chain x_T ... x_0 from pure noise.
        x_seq = p_sample_loop(model, dataset.shape, num_steps, betas,
                              one_minus_alphas_bar_sqrt)
        if t > 3000:
            # Plot 10 snapshots along the reverse chain (steps 10, 20, ..., 100).
            fig, axs = plt.subplots(1, 10, figsize=(28, 3))
            for i in range(1, 11):
                cur_x = x_seq[i * 10].detach()
                # Sort by x so the samples render as a curve.
                order = torch.argsort(cur_x[:, 0])
                axs[i - 1].plot(cur_x[:, 0][order], cur_x[:, 1][order],
                                color='red', lw=1)
                axs[i - 1].set_axis_off()
                # FIX: raw string avoids the invalid '\m' escape warning.
                axs[i - 1].set_title(r'$q(\mathbf{x}_{' + str(i * 10) + '})$')
            plt.show()
# NOTE: a commented-out block that re-plotted cur_x and saved 'save.png' was
# removed as dead code.