-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathlosses.py
More file actions
20 lines (14 loc) · 528 Bytes
/
losses.py
File metadata and controls
20 lines (14 loc) · 528 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
import tensorflow as tf
import tensorflow.contrib.gan as tfgan
import configuration
# Project configuration object, created once when this module is imported.
# get_generator_loss reads its `is_training` flag to decide whether the
# generator loss is wrapped with the KL-divergence term.
conf = configuration.config()
def generator_loss_with_kl_KL_divergence(
    loss_fn,
    kl_tensor_name="Generator/KL_divergence/KL_loss:0",
    graph=None,
):
    """Wrap a generator loss so that it also adds a KL-divergence term.

    The KL term is not computed here: it is fetched by name from a TensorFlow
    graph, so a tensor with that name must already exist (presumably created
    while building the generator -- confirm against the model code) by the
    time the returned loss function is called.

    Args:
        loss_fn: Generator loss callable taking the GAN model as its first
            positional argument plus optional keyword arguments.
        kl_tensor_name: Name of the KL-loss tensor to look up. Defaults to the
            name this module has always used.
        graph: Graph in which to look the tensor up. When ``None`` (the
            default), ``tf.get_default_graph()`` is consulted at call time,
            matching the previous behavior.

    Returns:
        A callable ``new_loss_fn(gan_model, **kargs)`` that returns
        ``kl_loss + loss_fn(gan_model, **kargs)``.

    Raises:
        Any error from ``get_tensor_by_name`` (e.g. when no tensor with the
        given name exists) propagates to the caller of ``new_loss_fn``.
    """
    def new_loss_fn(gan_model, **kargs):
        # Resolve the graph lazily so the lookup uses whichever graph is the
        # default at the moment the loss is actually built.
        lookup_graph = tf.get_default_graph() if graph is None else graph
        kl_loss = lookup_graph.get_tensor_by_name(kl_tensor_name)
        # Forward keyword arguments (e.g. TF-GAN's ``add_summaries``) instead
        # of silently dropping them.
        return kl_loss + loss_fn(gan_model, **kargs)

    return new_loss_fn
def get_generator_loss(loss_fn):
    """Pick the generator loss function for the current mode.

    If ``conf.is_training`` is truthy, ``loss_fn`` is wrapped by
    ``generator_loss_with_kl_KL_divergence`` so the KL term is added on top
    of it; otherwise ``loss_fn`` itself is returned unchanged.
    """
    # Outside of training the plain loss is used as-is.
    if not conf.is_training:
        return loss_fn
    return generator_loss_with_kl_KL_divergence(loss_fn)