-
Notifications
You must be signed in to change notification settings - Fork 19.7k
Expand file tree
/
Copy pathdropout_rnn_cell.py
More file actions
53 lines (43 loc) · 2.08 KB
/
dropout_rnn_cell.py
File metadata and controls
53 lines (43 loc) · 2.08 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
from keras.src import backend
from keras.src import ops
class DropoutRNNCell:
    """Mixin that supplies cached dropout masks to RNN cells.

    This class is not a standalone RNN cell. It is meant to be combined with
    an RNN cell class through multiple inheritance. Any cell that mixes in
    this class is expected to expose the following attributes:

    - `dropout`: a float in the range `[0, 1]`.
        Dropout rate for the input tensor.
    - `recurrent_dropout`: a float in the range `[0, 1]`.
        Dropout rate for the recurrent connections.
    - `seed_generator`: an instance of `backend.random.SeedGenerator`.

    Masks are created lazily on first request and cached on the instance, so
    the same mask is reused for every step until it is explicitly reset.
    """

    def get_dropout_mask(self, step_input):
        """Return the input dropout mask, creating and caching it if needed.

        Args:
            step_input: Tensor used as the template for the mask; the mask
                is built from `ops.ones_like(step_input)`.

        Returns:
            The cached mask. `None` when no mask is cached and
            `self.dropout` is not greater than 0.
        """
        # The mixin defines no `__init__`, so the cache slot is created on
        # first use.
        if not hasattr(self, "_dropout_mask"):
            self._dropout_mask = None

        cached = self._dropout_mask
        # A cached mask wins; `self.dropout` is only consulted when there is
        # nothing cached yet.
        if cached is not None or not (self.dropout > 0):
            return cached

        self._dropout_mask = backend.random.dropout(
            ops.ones_like(step_input),
            rate=self.dropout,
            seed=self.seed_generator,
        )
        return self._dropout_mask

    def get_recurrent_dropout_mask(self, step_input):
        """Return the recurrent dropout mask, creating and caching it if needed.

        Args:
            step_input: Tensor used as the template for the mask; the mask
                is built from `ops.ones_like(step_input)`.

        Returns:
            The cached mask. `None` when no mask is cached and
            `self.recurrent_dropout` is not greater than 0.
        """
        # Same lazy-slot approach as `get_dropout_mask`, with a separate
        # cache attribute and rate.
        if not hasattr(self, "_recurrent_dropout_mask"):
            self._recurrent_dropout_mask = None

        cached = self._recurrent_dropout_mask
        if cached is not None or not (self.recurrent_dropout > 0):
            return cached

        self._recurrent_dropout_mask = backend.random.dropout(
            ops.ones_like(step_input),
            rate=self.recurrent_dropout,
            seed=self.seed_generator,
        )
        return self._recurrent_dropout_mask

    def reset_dropout_mask(self):
        """Clear the cached input dropout mask, if any.

        The RNN layer invokes this in its `call()` method so that the cached
        mask is cleared after calling `cell.call()`. The mask should be
        shared across all timesteps within the same batch, but must not be
        carried over between batches.
        """
        self._dropout_mask = None

    def reset_recurrent_dropout_mask(self):
        """Clear the cached recurrent dropout mask, if any."""
        self._recurrent_dropout_mask = None