-
Notifications
You must be signed in to change notification settings - Fork 19.7k
Expand file tree
/
Copy pathstateless_scope.py
More file actions
105 lines (86 loc) · 3.61 KB
/
stateless_scope.py
File metadata and controls
105 lines (86 loc) · 3.61 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
from keras.src.api_export import keras_export
from keras.src.backend.common import global_state
@keras_export("keras.StatelessScope")
class StatelessScope:
    """Scope to prevent any update to Keras Variables.

    The values of variables to be used inside the scope
    should be passed via the `state_mapping` argument, a
    list of tuples `(k, v)` where `k` is a `KerasVariable`
    and `v` is the intended value for this variable
    (a backend tensor).

    Updated values can be collected on scope exit via
    `value = scope.get_current_value(variable)`. No updates
    will be applied in-place to any variables for the duration
    of the scope.

    Example:

    ```python
    state_mapping = [(k, ops.ones(k.shape, k.dtype)) for k in model.weights]
    with keras.StatelessScope(state_mapping) as scope:
        outputs = model.some_function(inputs)

    # All model variables remain unchanged. Their new values can be
    # collected via:
    for k in model.weights:
        new_value = scope.get_current_value(k)
        print(f"New value for {k}: {new_value}")
    ```

    Args:
        state_mapping: Optional iterable of `(variable, value)` pairs.
            Each `variable` must be a `KerasVariable`. Each `value` is
            either another `KerasVariable` (its current value is used) or
            anything `backend.convert_to_tensor` accepts. In both cases the
            value is converted to the variable's dtype and must have the
            variable's shape.
        collect_losses: Stored as `self.collect_losses`. It is not read
            inside this class; presumably callers consult it before calling
            `add_loss` (TODO confirm).
        initialize_variables: If `True`, exiting the outermost scope calls
            `initialize_all_variables()`, so variables created while the
            scope was active get initialized.

    Raises:
        ValueError: If an entry of `state_mapping` has a non-`KerasVariable`
            key, or a value whose shape differs from the variable's shape.
    """

    def __init__(
        self,
        state_mapping=None,
        collect_losses=False,
        initialize_variables=True,
    ):
        # Imported lazily, presumably to avoid a circular import with the
        # backend package (TODO confirm).
        from keras.src import backend
        from keras.src.backend.common.variables import KerasVariable

        self.collect_losses = collect_losses
        self.initialize_variables = initialize_variables
        self.losses = []
        # Maps `id(variable)` -> current value. Keying on `id()` means
        # lookups go by object identity, not by variable equality.
        self.state_mapping = {}
        # Any falsy input (None, [], ...) means "no initial values".
        state_mapping = state_mapping or ()
        for k, v in state_mapping:
            if not isinstance(k, KerasVariable):
                raise ValueError(
                    "Invalid reference variable in StatelessScope: "
                    "all keys in argument `state_mapping` must be "
                    f"KerasVariable instances. Received instead: {k}"
                )
            if isinstance(v, KerasVariable):
                # Take the other variable's current value, cast to the
                # dtype of the variable it stands in for.
                v = backend.cast(v.value, dtype=k.dtype)
            else:
                v = backend.convert_to_tensor(v, dtype=k.dtype)
            if k.shape != v.shape:
                raise ValueError(
                    "Invalid variable value in StatelessScope: "
                    "all values in argument `state_mapping` must be tensors "
                    "with a shape that matches the corresponding variable "
                    f"shape. For variable {k}, received invalid value {v} "
                    f"with shape {v.shape}."
                )
            self.state_mapping[id(k)] = v

    def __enter__(self):
        """Activate this scope and return it.

        The previously active scope (or None) is remembered so that
        `__exit__` can restore it. This lets scopes be nested.
        """
        self.original_scope = get_stateless_scope()
        global_state.set_global_attribute("stateless_scope", self)
        return self

    def add_loss(self, loss):
        """Append `loss` to `self.losses`."""
        self.losses.append(loss)

    def add_update(self, update):
        """Record a new value for a variable without modifying the variable.

        Args:
            update: A `(variable, value)` pair. The value is stored under
                `id(variable)` and overwrites any earlier value.
        """
        variable, value = update
        self.state_mapping[id(variable)] = value

    def get_current_value(self, variable):
        """Return the value tracked for `variable`, or None if it has none."""
        return self.state_mapping.get(id(variable), None)

    def __exit__(self, *args, **kwargs):
        """Deactivate this scope and restore the enclosing one.

        Exceptions raised inside the `with` body are not suppressed, because
        this method returns None.
        """
        global_state.set_global_attribute(
            "stateless_scope", self.original_scope
        )
        if self.original_scope is None and self.initialize_variables:
            # We're back in eager scope;
            # if any variables were created within the stateless
            # scope, we initialize them here.
            from keras.src.backend.common.variables import (
                initialize_all_variables,
            )

            initialize_all_variables()
def in_stateless_scope():
    """Report whether a `StatelessScope` is currently active.

    Returns:
        True if the global `"stateless_scope"` attribute holds a
        non-None value, False otherwise.
    """
    active = global_state.get_global_attribute("stateless_scope")
    return active is not None
def get_stateless_scope():
    """Return the currently active `StatelessScope`.

    Returns:
        The object stored under the global `"stateless_scope"` attribute.
        That is the innermost entered scope, or None when no scope is
        active.
    """
    current_scope = global_state.get_global_attribute("stateless_scope")
    return current_scope