-
Notifications
You must be signed in to change notification settings - Fork 752
Expand file tree
/
Copy pathpy_metric.py
More file actions
224 lines (185 loc) · 6.86 KB
/
py_metric.py
File metadata and controls
224 lines (185 loc) · 6.86 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
# coding=utf-8
# Copyright 2020 The TF-Agents Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Base class for Python metrics."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import abc
from typing import Any, Optional, Sequence, Text, Union
from absl import logging
import numpy as np
import six
import tensorflow as tf # pylint: disable=g-explicit-tensorflow-version-import
from tf_agents.metrics import tf_metric
from tf_agents.trajectories import trajectory as traj
from tf_agents.typing import types
from tf_agents.utils import common
# Forward reference to the `PyMetric` class, which is defined further down in
# this module; lets annotations name it before the class statement has run.
PyMetricType = types.ForwardRef('PyMetric') # pylint: disable=invalid-name
# Anything accepted as a step metric: a TF step metric or a Python metric.
MetricType = Union[tf_metric.TFStepMetric, PyMetricType]
def run_summaries(
    metrics: Sequence[PyMetricType],
    session: Optional[tf.compat.v1.Session] = None,
):
  """Execute summary ops for py_metrics.

  Every metric's summary op is run in a single `session.run` call, feeding
  each metric's summary placeholder with the metric's current `result()`.

  Args:
    metrics: The `PyMetric` objects whose summaries should be written. Any
      iterable is accepted; it is materialized once before use.
    session: A TensorFlow session-like object. If it is not provided, it will
      use the current TensorFlow session context manager.

  Raises:
    RuntimeError: If .tf_summaries() was not previously called on any of the
      `metrics`.
    AttributeError: If session is not provided and there is no default session
      provided by a context manager.
  """
  if session is None:
    session = tf.compat.v1.get_default_session()
    if session is None:
      raise AttributeError(
          'No TensorFlow session-like object was provided, and none '
          "could be retrieved using 'tf.get_default_session()'."
      )

  # `metrics` is walked several times below; materialize it so that a
  # one-shot iterable (e.g. a generator) is not exhausted by the validation
  # pass and then silently produce an empty run.
  metrics = list(metrics)

  for metric in metrics:
    if metric.summary_op is None:
      raise RuntimeError(
          'metric.tf_summaries() must be called on py_metric '
          '{} before attempting to run '
          'summaries.'.format(metric.name)
      )

  summary_ops = [metric.summary_op for metric in metrics]
  feed_dict = {
      metric.summary_placeholder: metric.result() for metric in metrics
  }
  session.run(summary_ops, feed_dict=feed_dict)
@six.add_metaclass(abc.ABCMeta)
class PyMetric(tf.Module):
  """Defines the interface for metrics.

  Subclasses implement `reset` and `result`, and provide a `call` method that
  `__call__` forwards its arguments to. A metric can also build a TF summary
  op (`tf_summaries`) that is fed through a placeholder with the value
  returned by `result()`.
  """

  def __init__(self, name: Text, prefix: Text = 'Metrics'):
    """Creates a metric.

    Args:
      name: Name of the metric; forwarded to the `tf.Module` constructor.
      prefix: Scope combined with the metric name to build log and summary
        tags.
    """
    # Keep the explicit two-argument `super`: `six.add_metaclass` re-creates
    # the class, so zero-argument `super()` would still refer to the class
    # object that existed before decoration.
    super(PyMetric, self).__init__(name)
    self._prefix = prefix
    # Both are created lazily: the placeholder on first access of
    # `summary_placeholder`, the op by `tf_summaries`.
    self._summary_placeholder = None
    self._summary_op = None

  @property
  def prefix(self) -> Text:
    """Prefix for the metric."""
    return self._prefix

  @abc.abstractmethod
  def reset(self):
    """Resets internal stat gathering variables used to compute the metric."""

  @abc.abstractmethod
  def result(self) -> Any:
    """Evaluates the current value of the metric."""

  def log(self):
    """Logs the current result at INFO level, tagged with prefix and name."""
    tag = common.join_scope(self.prefix, self.name)
    logging.info('%s = %s', tag, self.result())

  def tf_summaries(
      self,
      train_step: Optional[types.Int] = None,
      step_metrics: Sequence[MetricType] = (),
  ) -> tf.Operation:
    """Build TF summary op and placeholder for this metric.

    To execute the op, call py_metric.run_summaries.

    Args:
      train_step: Step counter for training iterations. If None, `step=None`
        is passed to `tf.compat.v2.summary.scalar`, which then falls back to
        the default summary step.
      step_metrics: Step values to plot as X axis in addition to global_step.

    Returns:
      The summary op.

    Raises:
      RuntimeError: If this method has already been called (it can only be
        called once).
      ValueError: If any item in step_metrics is not of type PyMetric or
        tf_metric.TFStepMetric.
    """
    if self.summary_op is not None:
      raise RuntimeError('metric.tf_summaries() can only be called once.')

    tag = common.join_scope(self.prefix, self.name)
    summaries = [
        tf.compat.v2.summary.scalar(
            name=tag, data=self.summary_placeholder, step=train_step
        )
    ]

    prefix = self.prefix
    if prefix:
      prefix += '_'

    for step_metric in step_metrics:
      # Validate the type before touching any attribute, so that an
      # unsupported object always yields the documented ValueError instead of
      # an AttributeError (or being silently skipped on a name match).
      if not isinstance(step_metric, (PyMetric, tf_metric.TFStepMetric)):
        raise ValueError(
            'step_metric is not PyMetric or TFStepMetric: {}'.format(
                step_metric
            )
        )
      # Skip plotting the metrics against itself.
      if self.name == step_metric.name:
        continue
      step_tag = '{}vs_{}/{}'.format(prefix, step_metric.name, self.name)
      if isinstance(step_metric, PyMetric):
        # Python metrics are fed at run time through their own placeholder.
        step_tensor = step_metric.summary_placeholder
      else:
        step_tensor = step_metric.result()
      summaries.append(
          tf.compat.v2.summary.scalar(
              name=step_tag, data=self.summary_placeholder, step=step_tensor
          )
      )

    self._summary_op = tf.group(*summaries)
    return self._summary_op

  @property
  def summary_placeholder(self) -> tf.Tensor:
    """TF placeholder to be used for the result of this metric.

    The placeholder is created on first access; its dtype and shape are taken
    from the value `result()` returns at that moment.
    """
    if self._summary_placeholder is None:
      result = self.result()
      # Plain Python values carry no dtype/shape; route them through numpy.
      if not isinstance(result, (np.ndarray, np.generic)):
        result = np.array(result)
      self._summary_placeholder = tf.compat.v1.placeholder(
          tf.as_dtype(result.dtype),
          shape=result.shape,
          name='{}_ph'.format(self.name),
      )
    return self._summary_placeholder

  @property
  def summary_op(self) -> Optional[tf.Operation]:
    """TF summary op for this metric, or None before `tf_summaries` is called."""
    return self._summary_op

  @staticmethod
  def aggregate(metrics: Sequence[PyMetricType]) -> types.Float:
    """Aggregates a list of metrics.

    The default behaviour is to return the average of the metrics.

    Args:
      metrics: a list of metrics, of the same class.

    Returns:
      The result of aggregating this metric.
    """
    return np.mean([metric.result() for metric in metrics])

  def __call__(self, *args):
    """Method to update the metric contents.

    To change the behavior of this function, override the call method.

    Different subclasses might use this differently. For instance, the
    PyStepMetric takes in a trajectory, while the CounterMetric takes no
    parameters.

    Args:
      *args: See call method of subclass for specific arguments.
    """
    self.call(*args)
class PyStepMetric(PyMetric):
  """Interface for metrics that are updated from trajectories."""

  @abc.abstractmethod
  def call(self, trajectory: traj.Trajectory):
    """Updates the metric with the contents of a trajectory.

    Args:
      trajectory: The `trajectory.Trajectory` to process.
    """