-
Notifications
You must be signed in to change notification settings - Fork 19.7k
Expand file tree
/
Copy pathmetric.py
More file actions
247 lines (204 loc) · 8.11 KB
/
metric.py
File metadata and controls
247 lines (204 loc) · 8.11 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
from keras.src import backend
from keras.src import initializers
from keras.src import ops
from keras.src.api_export import keras_export
from keras.src.saving.keras_saveable import KerasSaveable
from keras.src.utils.naming import auto_name
from keras.src.utils.tracking import Tracker
@keras_export(["keras.Metric", "keras.metrics.Metric"])
class Metric(KerasSaveable):
    """Encapsulates metric logic and state.

    Args:
        name: (Optional) string name of the metric instance. When omitted, a
            name is derived from the class name via `auto_name()`.
        dtype: (Optional) data type of the metric result. Defaults to
            `keras.backend.floatx()` when not provided.

    Example:

    ```python
    m = SomeMetric(...)
    for input in ...:
        m.update_state(input)
    print('Final result: ', m.result())
    ```

    Usage with `compile()` API:

    ```python
    model = keras.Sequential()
    model.add(keras.layers.Dense(64, activation='relu'))
    model.add(keras.layers.Dense(64, activation='relu'))
    model.add(keras.layers.Dense(10, activation='softmax'))

    model.compile(optimizer=keras.optimizers.RMSprop(0.01),
                  loss=keras.losses.CategoricalCrossentropy(),
                  metrics=[keras.metrics.CategoricalAccuracy()])

    data = np.random.random((1000, 32))
    labels = np.random.random((1000, 10))

    model.fit(data, labels, epochs=10)
    ```

    To be implemented by subclasses:

    * `__init__()`: All state variables should be created in this method by
      calling `self.add_variable()` like: `self.var = self.add_variable(...)`
    * `update_state()`: Has all updates to the state variables like:
      `self.var.assign(...)`.
    * `result()`: Computes and returns a scalar value or a dict of scalar values
      for the metric from the state variables.

    Example subclass implementation:

    ```python
    class BinaryTruePositives(Metric):

        def __init__(self, name='binary_true_positives', **kwargs):
            super().__init__(name=name, **kwargs)
            self.true_positives = self.add_variable(
                shape=(),
                initializer='zeros',
                name='true_positives'
            )

        def update_state(self, y_true, y_pred, sample_weight=None):
            y_true = ops.cast(y_true, "bool")
            y_pred = ops.cast(y_pred, "bool")

            values = ops.logical_and(
                ops.equal(y_true, True), ops.equal(y_pred, True))
            values = ops.cast(values, self.dtype)
            if sample_weight is not None:
                sample_weight = ops.cast(sample_weight, self.dtype)
                sample_weight = ops.broadcast_to(
                    sample_weight, ops.shape(values)
                )
                values = ops.multiply(values, sample_weight)
            self.true_positives.assign(self.true_positives + ops.sum(values))

        def result(self):
            return self.true_positives
    ```
    """

    def __init__(self, dtype=None, name=None):
        self.name = name or auto_name(self.__class__.__name__)
        self._dtype = dtype or backend.floatx()
        self._metrics = []
        self._variables = []
        # `__setattr__` only routes values through the tracker once
        # `_tracker` exists, so the attributes above are stored untracked.
        # The tracker is configured to collect `backend.Variable` instances
        # into `self._variables` and nested `Metric` instances into
        # `self._metrics`.
        self._tracker = Tracker(
            {
                "variables": (
                    lambda x: isinstance(x, backend.Variable),
                    self._variables,
                ),
                "metrics": (lambda x: isinstance(x, Metric), self._metrics),
            }
        )

    def reset_state(self):
        """Reset all of the metric state variables.

        This function is called between epochs/steps,
        when a metric is evaluated during training.
        """
        # Every entry of `self.variables` (including those of tracked
        # sub-metrics) is set to zeros, regardless of the initializer it was
        # created with. Subclasses whose initial state is not all-zeros
        # should override this method.
        for v in self.variables:
            v.assign(ops.zeros(v.shape, dtype=v.dtype))

    def update_state(self, *args, **kwargs):
        """Accumulate statistics for the metric."""
        raise NotImplementedError

    def _validate_stateless_variables(self, metric_variables):
        """Raise `ValueError` unless `metric_variables` matches `variables` 1:1."""
        expected = len(self.variables)
        if len(metric_variables) != expected:
            raise ValueError(
                "Argument `metric_variables` must be a list of tensors "
                f"corresponding 1:1 to {self.__class__.__name__}().variables. "
                f"Received list with length {len(metric_variables)}, but "
                f"expected {expected} variables."
            )

    def _read_stateless_values(self, scope):
        """Return `scope`'s current value for each of `self.variables`.

        Where `scope.get_current_value()` returns `None` for a variable, the
        variable object itself is returned in its slot.
        """
        values = []
        for v in self.variables:
            new_v = scope.get_current_value(v)
            values.append(v if new_v is None else new_v)
        return values

    def stateless_update_state(self, metric_variables, *args, **kwargs):
        """Run `update_state()` against explicitly supplied variable values.

        `update_state(*args, **kwargs)` is executed inside a
        `backend.StatelessScope` whose `state_mapping` pairs each of
        `self.variables` with the matching entry of `metric_variables`. Each
        variable's value is then read back from that scope.

        Args:
            metric_variables: Sequence with one value per entry of
                `self.variables`, in the same order.
            *args: Forwarded to `update_state()`.
            **kwargs: Forwarded to `update_state()`.

        Returns:
            A list with one entry per variable, in `self.variables` order.

        Raises:
            ValueError: If `len(metric_variables) != len(self.variables)`.
        """
        self._validate_stateless_variables(metric_variables)
        mapping = list(zip(self.variables, metric_variables))
        with backend.StatelessScope(state_mapping=mapping) as scope:
            self.update_state(*args, **kwargs)
        return self._read_stateless_values(scope)

    def result(self):
        """Compute the current metric value.

        Returns:
            A scalar tensor, or a dictionary of scalar tensors.
        """
        raise NotImplementedError

    def stateless_result(self, metric_variables):
        """Compute `result()` against explicitly supplied variable values.

        Args:
            metric_variables: Sequence with one value per entry of
                `self.variables`, in the same order.

        Returns:
            Whatever `result()` returns, evaluated inside a
            `backend.StatelessScope` whose `state_mapping` pairs each of
            `self.variables` with the matching entry of `metric_variables`.

        Raises:
            ValueError: If `len(metric_variables) != len(self.variables)`.
        """
        self._validate_stateless_variables(metric_variables)
        mapping = list(zip(self.variables, metric_variables))
        with backend.StatelessScope(state_mapping=mapping):
            res = self.result()
        return res

    def stateless_reset_state(self):
        """Run `reset_state()` in a fresh `backend.StatelessScope`.

        Returns:
            A list with one entry per variable, in `self.variables` order,
            holding the value the scope recorded for it (or the variable
            object itself when the scope has none).
        """
        with backend.StatelessScope() as scope:
            self.reset_state()
        return self._read_stateless_values(scope)

    @property
    def dtype(self):
        """The metric's dtype (`floatx()` unless given to `__init__`)."""
        return self._dtype

    def _obj_type(self):
        # Type tag string; presumably consumed by `KerasSaveable` — confirm.
        return "Metric"

    def add_variable(
        self, shape, initializer, dtype=None, aggregation="sum", name=None
    ):
        """Create a non-trainable state variable and track it on this metric.

        Args:
            shape: Shape of the variable.
            initializer: An initializer or identifier, resolved through
                `initializers.get()`.
            dtype: Passed through to `backend.Variable` as-is. Note that it
                does not fall back to `self.dtype`; when `None`, the choice is
                left to `backend.Variable`.
            aggregation: Passed through to `backend.Variable`. Defaults to
                `"sum"`.
            name: Optional variable name.

        Returns:
            The created `backend.Variable`.

        Raises:
            RuntimeError: If `super().__init__()` was not called.
        """
        self._check_super_called()
        # "/" in the metric name is replaced, presumably because name scopes
        # use it as a path separator.
        with backend.name_scope(self.name.replace("/", ">"), caller=self):
            initializer = initializers.get(initializer)
            variable = backend.Variable(
                initializer=initializer,
                shape=shape,
                dtype=dtype,
                trainable=False,
                aggregation=aggregation,
                name=name,
            )
        # Registered explicitly here to prevent double-tracking when the
        # caller then assigns the variable to an attribute, which routes it
        # through `__setattr__` -> `self._tracker.track`.
        self._tracker.add_to_store("variables", variable)
        return variable

    def add_weight(
        self, shape=(), initializer=None, dtype=None, name=None,
        aggregation="sum",
    ):
        """Backwards-compatible alias for `add_variable()`.

        `aggregation` is appended after the original parameters, so existing
        positional calls keep working. It is forwarded to `add_variable()`
        with the same default.
        """
        return self.add_variable(
            shape=shape,
            initializer=initializer,
            dtype=dtype,
            aggregation=aggregation,
            name=name,
        )

    @property
    def variables(self):
        """State variables of this metric and of all tracked sub-metrics.

        This metric's own variables come first, then each tracked sub-metric's
        `variables` in tracking order.
        """
        variables = list(self._variables)
        for metric in self._metrics:
            # Use the sub-metric's `variables` property (not `_variables`) so
            # metrics nested more than one level deep are included. This lets
            # `reset_state()` and the stateless APIs see them too.
            variables.extend(metric.variables)
        return variables

    def __call__(self, *args, **kwargs):
        """Call `update_state(*args, **kwargs)`, then return `result()`."""
        self._check_super_called()
        self.update_state(*args, **kwargs)
        return self.result()

    def get_config(self):
        """Return the serializable config of the metric."""
        return {"name": self.name, "dtype": self.dtype}

    @classmethod
    def from_config(cls, config):
        """Instantiate the metric from a `get_config()`-style dict."""
        return cls(**config)

    def __setattr__(self, name, value):
        # Once `_tracker` exists, pass every assigned value through it. Per
        # the configuration in `__init__`, it collects Variables and Metrics.
        if hasattr(self, "_tracker"):
            value = self._tracker.track(value)
        return super().__setattr__(name, value)

    def _check_super_called(self):
        """Raise `RuntimeError` if `Metric.__init__()` never ran."""
        if not hasattr(self, "_tracker"):
            raise RuntimeError(
                "You forgot to call `super().__init__()` "
                "in the `__init__()` method. Go add it!"
            )

    def __repr__(self):
        return f"<{self.__class__.__name__} name={self.name}>"

    def __str__(self):
        return self.__repr__()