码迷,mamicode.com
首页 > 其他好文 > 详细

GAN原理手写数据集生成

时间:2021-06-20 18:06:47      阅读:0      评论:0      收藏:0      [点我收藏+]

标签:set   layer   buffer   bat   cal   mode   mic   dataset   tensor   

GAN原理介绍

  • GAN 来源于博弈论中的零和博弈,博弈双方,分别为生成模型与判别模型。
  • 生成模型G捕捉样本数据的分布,用服从某一分布例如正太,高斯分布的噪声z来生成一个类似真实训练数据的样本,追求的效果是越像真实越好。
  • 判别模型是一个二分类器,判别样本来自于训练数据还是真实数据的概率。如果来自于真实样本输出大概率,如果来自于训练数据,输出小概率。

实例demo

  • 以造小狗的假图片为例。首先生成小狗图片的模型,称之为generator,还有一个判断小狗图片是否是真假的判别模型 discrimator。
  • 首先输入一个的噪声,然后送入生成器,生成器的生成假图
  • 把真图与假图。进行拼接,然后打上标签,真图标签是1,假图标签是0,送入鉴别器,鉴别器输出属于真实样本与训练样本的概率。

技术图片

实际GAN(手写数据集为例)

数据预处理

导入函数库

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import matplotlib.pyplot as plt

划分数据集

(train_image,train_labels),_=keras.datasets.mnist.load_data()

数据类型转换于归一化[-1,1]

train_images=2*tf.cast(train_image,tf.float32)/255.-1

expand_dims 设置通道,-1 加一维

train_images=2*tf.cast(train_image,tf.float32)/255.-1
# expand_dims 设置通道,-1 加一维
train_images=tf.expand_dims(train_images,-1)
train_images.shape

常用参数设置与数据集生成

Batch_Size=256
# 每回使用256
Buffer_Size=60000 #乱序范围
# 构建demo使用的数据集
dataset=tf.data.Dataset.from_tensor_slices(train_images).shuffle(Buffer_Size).batch(Batch_Size)

GAN模型的生成器与鉴别器构建

def generator_model():
  # 第一层
  model=tf.keras.Sequential()
  # 100->256
  # 第一层
  model.add(layers.Dense(256,input_shape=(100,),use_bias=False))
  model.add(layers.BatchNormalization())
  model.add(layers.LeakyReLU())
  # 第二层
  #256->512
  model.add(layers.Dense(512,use_bias=False))
  model.add(layers.BatchNormalization())
  model.add(layers.LeakyReLU())
  # 第三层
  #512->28*28
  model.add(layers.Dense(28*28,use_bias=False,activation="tanh"))
  model.add(layers.BatchNormalization())
  model.add(layers.Reshape([28,28,1]))
  
  return model

定义判别器

def discriminator_model():
  model=tf.keras.Sequential()
  # 第一层
  model.add(layers.Flatten())
  model.add(layers.Dense(512,use_bias=False))
  model.add(layers.BatchNormalization())
  model.add(layers.LeakyReLU())
  # 第二层
  model.add(layers.Dense(512, use_bias=False))
  model.add(layers.BatchNormalization())
  model.add(layers.LeakyReLU())
  model.add(layers.Dense(1))
  # 输出一个值
  return model

设置损失函数

定义优化器

generator_opt=tf.keras.optimizers.Adam(0.0001)
discriminator_opt=tf.keras.optimizers.Adam(0.0001)
cross_entropy=keras.losses.BinaryCrossentropy(from_logits=True)

计算判别器损失

def discriminator_loss(real_out,fake_out):
  real_loss=cross_entropy(tf.ones_like(real_out),real_out)
  fake_loss=cross_entropy(tf.zeros_like(fake_out),fake_out)
  return real_loss+fake_loss

计算生成器损失

def generator_loss(fake_out):
    fake_loss = cross_entropy(tf.ones_like(fake_out), fake_out)
    return fake_loss

定义训练step

Epochs=100
input_dim=100
num_exp_to_generate=16
# 生成16*100
seed=tf.random.normal([num_exp_to_generate,input_dim])
# 定义训练步骤
generator=generator_model()
discriminator=discriminator_model()
def train_step(images):
  noise=tf.random.normal([Batch_Size,input_dim])
  with tf.GradientTape() as gen_tape,tf.GradientTape() as dis_tape:
    real_out=discriminator(images)
    gen_img=generator(noise)
    fake_out=discriminator(gen_img)
    dis_loss=discriminator_loss(real_out,fake_out)
    gen_loss=generator_loss(fake_out) 
    # 梯度下降参数计算
    gen_gard=gen_tape.gradient(gen_loss,generator.trainable_variables)
    dis_gard=dis_tape.gradient(dis_loss,discriminator.trainable_variables)
    # 进行参数更新,并反传
    discriminator_opt.apply_gradients(zip(dis_gard,discriminator.trainable_variables))
    generator_opt.apply_gradients(zip(gen_gard, generator.trainable_variables))
# 绘制训练后的图
def genrate_plot_image(gen_model,test_noise):
    pre_images=gen_model(test_noise,training=False)
    fig=plt.figure(figsize=(8,6))
    for i in range(pre_images.shape[0]):
        plt.subplot(4,4,i+1)
        plt.imshow((pre_images[i,:,:,0]+1)/2*255.)
        plt.axis(‘off‘)
    plt.show()

epoch设置

def train(dataset,epochs):
for epoch in range(epochs):
print(epoch)
for image_batch in dataset:
train_step(image_batch)
print(epoch)
if epoch%10==0:
print(epoch)
genrate_plot_image(generator,seed)

main函数

if __name__ == ‘__main__‘:
    train(dataset,500)
0
0
0

技术图片

1
1
2
2
3
3
4
4
5
5
6
6
7
7
8
8
9
9
10
10
10

技术图片

11
11
12
12
13
13
14
14
15
15
16
16
17
17
18
18
19
19
20
20
20

技术图片

21
21
22
22
23
23
24
24
25
25
26
26
27
27
28
28
29
29
30
30
30

技术图片

31
31
32
32
33
33
34
34
35
35
36
36
37
37
38
38
39
39
40
40
40

技术图片

41
41
42
42
43
43
44
44
45
45
46
46
47
47
48
48
49
49
50
50
50

技术图片

51
51
52
52
53
53
54
54
55
55
56
56
57
57
58
58
59
59
60
60
60

技术图片

61
61
62
62
63
63
64
64
65
65
66
66
67
67
68
68
69
69
70
70
70

技术图片

71
71
72
72
73
73
74
74
75
75
76
76
77
77
78
78
79
79
80
80
80

技术图片

81
81
82
82
83
83
84
84
85
85
86
86
87
87
88
88
89
89
90
90
90

技术图片

91
91
92
92
93
93
94
94
95
95
96
96
97
97
98
98
99
99
100
100
100

技术图片

101
101
102
102
103
103
104
104
105
105
106
106
107
107
108
108
109
109
110
110
110

技术图片

111
111
112
112
113
113
114
114
115
115
116
116
117
117
118
118
119
119
120
120
120

技术图片

121
121
122
122
123
123
124
124
125
125
126
126
127
127
128
128
129
129
130
130
130

技术图片

131
131
132
132
133
133
134
134
135
135
136
136
137
137
138
138
139
139
140
140
140

技术图片

141
141
142
142
143
143
144
144
145
145
146
146
147
147
148
148
149
149
150
150
150

技术图片

151
151
152
152
153
153
154
154
155
155
156
156
157
157
158
158
159
159
160
160
160

技术图片

161
161
162
162
163
163
164
164
165
165
166
166
167
167
168
168
169
169
170
170
170

技术图片

171
171
172
172
173
173
174
174
175
175
176
176
177
177
178
178
179
179
180
180
180

技术图片

181
181
182
182
183
183
184
184
185
185
186
186
187
187
188
188
189
189
190
190
190

技术图片

191
191
192
192
193
193
194
194
195
195
196
196
197
197
198
198
199
199
200
200
200

技术图片

201
201
202
202
203
203
204
204
205
205
206
206
207
207
208
208
209
209
210
210
210

技术图片

211
211
212
212
213
213
214
214
215
215
216
216
217
217
218
218
219
219
220
220
220

技术图片

221
221
222
222
223
223
224
224
225
225
226
226
227
227
228
228
229
229
230
230
230

技术图片

231
231
232
232
233
233
234
234
235
235
236
236
237
237
238
238
239
239
240
240
240

技术图片

241
241
242
242
243
243
244
244
245
245
246
246
247
247
248
248
249
249
250
250
250

技术图片

251
251
252
252
253
253
254
254
255
255
256
256
257
257
258
258
259
259
260
260
260

技术图片

261
261
262
262
263
263
264
264
265
265
266
266
267
267
268
268
269
269
270
270
270

技术图片

271
271
272
272
273
273
274
274
275
275
276
276
277
277
278
278
279
279
280
280
280

技术图片

281
281
282
282
283
283
284
284
285
285
286
286
287
287
288
288
289
289
290
290
290

技术图片

291
291
292
292
293
293
294
294
295
295
296
296
297
297
298
298
299
299
300
300
300

技术图片

301
301
302
302
303
303
304
304
305
305
306
306
307
307
308
308
309
309
310
310
310

技术图片

311
311
312
312
313
313
314
314
315
315
316
316
317
317
318
318
319
319
320
320
320

技术图片

321
321
322
322
323
323
324
324
325
325
326
326
327
327
328
328
329
329
330
330
330

技术图片

331
331
332
332
333
333
334
334
335
335
336
336
337
337
338
338
339
339
340
340
340

技术图片

341
341
342
342
343
343
344
344
345
345
346
346
347
347
348
348
349
349
350
350
350

技术图片

351
351
352
352
353
353
354
354
355
355
356
356
357
357
358
358
359
359
360
360
360

技术图片

361
361
362
362
363
363
364
364
365
365
366
366
367
367
368
368
369
369
370
370
370

技术图片

371
371
372
372
373
373
374
374
375
375
376
376
377
377
378
378
379
379
380
380
380

技术图片

381
381
382
382
383
383
384
384
385
385
386
386
387
387
388
388
389
389
390
390
390

技术图片

391
391
392
392
393
393
394
394
395
395
396
396
397
397
398
398
399
399
400
400
400

技术图片

401
401
402
402
403
403
404
404
405
405
406
406
407
407
408
408
409
409
410
410
410

技术图片

411
411
412
412
413
413
414
414
415
415
416
416
417
417
418
418
419
419
420
420
420

技术图片

421
421
422
422
423
423
424
424
425
425
426
426
427
427
428
428
429
429
430
430
430

技术图片

431
431
432
432
433
433
434
434
435
435
436
436
437
437
438
438
439
439
440
440
440

技术图片

441
441
442
442
443
443
444
444
445
445
446
446
447
447
448
448
449
449
450
450
450

技术图片

451
451
452
452
453
453
454
454
455
455
456
456
457
457
458
458
459
459
460
460
460

技术图片

461
461
462
462
463
463
464
464
465
465
466
466
467
467
468
468
469
469
470
470
470

技术图片

471
471
472
472
473
473
474
474
475
475
476
476
477
477
478
478
479
479
480
480
480

技术图片

481
481
482
482
483
483
484
484
485
485
486
486
487
487
488
488
489
489
490
490
490

技术图片

491
491
492
492
493
493
494
494
495
495
496
496
497
497
498
498
499
499

生成器描述

generator.summary()
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense (Dense)                (None, 256)               25600     
_________________________________________________________________
batch_normalization (BatchNo (None, 256)               1024      
_________________________________________________________________
leaky_re_lu (LeakyReLU)      (None, 256)               0         
_________________________________________________________________
dense_1 (Dense)              (None, 512)               131072    
_________________________________________________________________
batch_normalization_1 (Batch (None, 512)               2048      
_________________________________________________________________
leaky_re_lu_1 (LeakyReLU)    (None, 512)               0         
_________________________________________________________________
dense_2 (Dense)              (None, 784)               401408    
_________________________________________________________________
batch_normalization_2 (Batch (None, 784)               3136      
_________________________________________________________________
reshape (Reshape)            (None, 28, 28, 1)         0         
=================================================================
Total params: 564,288
Trainable params: 561,184
Non-trainable params: 3,104
_________________________________________________________________

GAN原理手写数据集生成

标签:set   layer   buffer   bat   cal   mode   mic   dataset   tensor   

原文地址:https://www.cnblogs.com/hufeng2021/p/14906175.html

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!