-
Notifications
You must be signed in to change notification settings - Fork 19.7k
Expand file tree
/
Copy pathoperation.py
More file actions
290 lines (252 loc) · 10.5 KB
/
operation.py
File metadata and controls
290 lines (252 loc) · 10.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
import inspect
import textwrap
from keras.src import backend
from keras.src import dtype_policies
from keras.src import tree
from keras.src.api_export import keras_export
from keras.src.backend.common.keras_tensor import any_symbolic_tensors
from keras.src.ops.node import Node
from keras.src.utils import python_utils
from keras.src.utils import traceback_utils
from keras.src.utils.naming import auto_name
@keras_export("keras.Operation")
class Operation:
    """Base class for callable operations.

    Calling an instance dispatches to one of three methods:

    - `symbolic_call()` when any argument is a symbolic tensor. It infers the
      output spec via `compute_output_spec()` and records a `Node`.
    - `quantized_call()` when the dtype policy is a `QuantizedDTypePolicy`.
    - `call()` otherwise.

    Subclasses are expected to implement `call()` (and `quantized_call()`
    where relevant); both raise `NotImplementedError` here.

    The class also records simple constructor arguments in `__new__` so that
    `get_config()` can be auto-generated for subclasses that do not override
    it.

    Args:
        dtype: Value passed to `dtype_policies.get()` to resolve the dtype
            policy of the operation.
        name: String name of the operation. Must not contain `/`. When
            `None`, a name is generated from the class name via `auto_name`.
    """

    def __init__(self, dtype=None, name=None):
        if name is None:
            name = auto_name(self.__class__.__name__)
        if not isinstance(name, str) or "/" in name:
            raise ValueError(
                "Argument `name` must be a string and "
                "cannot contain character `/`. "
                f"Received: name={name} (of type {type(name)})"
            )
        self._dtype_policy = dtype_policies.get(dtype)
        self.name = name
        # Nodes are appended by the `Node` constructor (see
        # `symbolic_call()`), not by this class directly.
        self._inbound_nodes = []
        self._outbound_nodes = []

    @traceback_utils.filter_traceback
    def __call__(self, *args, **kwargs):
        """Runs the operation, symbolically or eagerly.

        Args:
            *args: Positional arguments forwarded to the selected call method.
            **kwargs: Keyword arguments forwarded to the selected call method.

        Returns:
            Whatever the selected method (`symbolic_call()`,
            `quantized_call()` or `call()`) returns.
        """
        filtering_enabled = traceback_utils.is_traceback_filtering_enabled()

        # Pick the implementation once; the choice does not depend on whether
        # traceback filtering is enabled.
        if any_symbolic_tensors(args, kwargs):
            call_fn = self.symbolic_call
        elif isinstance(
            self._dtype_policy, dtype_policies.QuantizedDTypePolicy
        ):
            call_fn = self.quantized_call
        else:
            call_fn = self.call

        if filtering_enabled:
            # Wrap the call to provide helpful info in case of exception.
            call_fn = traceback_utils.inject_argument_info_in_traceback(
                call_fn,
                object_name=(f"{self.__class__.__name__}.call()"),
            )
        return call_fn(*args, **kwargs)

    def symbolic_call(self, *args, **kwargs):
        """Infers the output spec and records the call in the ops graph.

        Returns:
            The result of `compute_output_spec(*args, **kwargs)`.
        """
        # Perform shape/dtype inference.
        outputs = self.compute_output_spec(*args, **kwargs)
        # Record a new node in the operations graph.
        # The Node wires itself to inbound and outbound ops. The
        # Node constructor updates this op's self._inbound_nodes,
        # sets _keras_history on the outputs, and adds itself to the
        # `_outbound_nodes` of the ops that produced the inputs to this
        # call.
        Node(
            operation=self, call_args=args, call_kwargs=kwargs, outputs=outputs
        )
        return outputs

    def call(self, *args, **kwargs):
        """Eager implementation of the operation. Must be overridden."""
        raise NotImplementedError

    def quantized_call(self, *args, **kwargs):
        """Implementation used when the dtype policy is quantized.

        Must be overridden by subclasses that support a
        `QuantizedDTypePolicy`.
        """
        raise NotImplementedError

    def compute_output_spec(self, *args, **kwargs):
        """Infers the output spec by delegating to the backend.

        Returns:
            The result of `backend.compute_output_spec(self.call, ...)`.

        Raises:
            TypeError: Propagated unchanged when raised during inference.
            RuntimeError: For any other exception raised during inference.
                The message embeds the original error and the original
                traceback is attached.
        """
        try:
            return backend.compute_output_spec(self.call, *args, **kwargs)
        except TypeError:
            # `TypeError`s are surfaced as-is rather than wrapped.
            raise
        except Exception as e:
            new_e = RuntimeError(
                "Could not automatically infer the output shape / dtype of "
                f"'{self.name}' (of type {self.__class__.__name__}). "
                f"Either the `{self.__class__.__name__}.call()` method "
                "is incorrect, or you need to implement the "
                f"`{self.__class__.__name__}.compute_output_spec() / "
                "compute_output_shape()` method. "
                f"Error encountered:\n\n{e}"
            )
            # Keep the original traceback but hide the chained context, since
            # the original message is already embedded in the new one.
            raise new_e.with_traceback(e.__traceback__) from None

    def __new__(cls, *args, **kwargs):
        """We override __new__ to save serializable constructor arguments.

        These arguments are used to auto-generate an object serialization
        config, which enables user-created subclasses to be serializable
        out of the box in most cases without forcing the user
        to manually implement `get_config()`.
        """
        # Generate a config to be returned by default by `get_config()`.
        # Positional values are mapped onto the named parameters of
        # `__init__` (skipping `self`). NOTE: `zip` stops at the shorter
        # sequence, so positional values without a matching named parameter
        # are not recorded.
        arg_names = inspect.getfullargspec(cls.__init__).args
        kwargs.update(zip(arg_names[1 : len(args) + 1], args))
        instance = super(Operation, cls).__new__(cls)

        # For safety, we only rely on auto-configs for a small set of
        # serializable types.
        supported_types = (str, int, float, bool, type(None))
        try:
            auto_config = all(
                isinstance(value, supported_types)
                for value in tree.flatten(kwargs)
            )
        except TypeError:
            auto_config = False

        try:
            instance._lock = False
            if auto_config:
                # Imported here rather than at module top, as in the
                # original code.
                from keras.src.saving import serialization_lib

                instance._auto_config = serialization_lib.SerializableDict(
                    **kwargs
                )
            else:
                instance._auto_config = None
            instance._lock = True
        except RecursionError:
            # Setting an instance attribute in __new__ has the potential
            # to trigger an infinite recursion if a subclass overrides
            # setattr in an unsafe way.
            pass
        return instance

    @python_utils.default
    def get_config(self):
        """Returns the config of the object.

        An object config is a Python dictionary (serializable)
        containing the information needed to re-instantiate it.

        Returns:
            A dictionary. When a subclass overrides `get_config()`, only
            `{"name": ...}` is returned here for the subclass to extend.
            Otherwise the constructor arguments recorded in `__new__` are
            merged in.

        Raises:
            NotImplementedError: If the subclass does not override
                `get_config()` and no auto-config was recorded for the
                instance.
        """
        config = {
            "name": self.name,
        }
        if not python_utils.is_default(self.get_config):
            # In this case the subclass implements get_config()
            return config

        # In this case the subclass doesn't implement get_config():
        # Let's see if we can autogenerate it.
        if getattr(self, "_auto_config", None) is not None:
            xtra_args = set(config.keys())
            config.update(self._auto_config.config)
            # Remove args non explicitly supported: base keys (i.e. `name`)
            # are dropped when `__init__` neither names them nor accepts
            # `**kwargs`.
            argspec = inspect.getfullargspec(self.__init__)
            if argspec.varkw != "kwargs":
                for key in xtra_args.difference(argspec.args[1:]):
                    config.pop(key, None)
            return config
        else:
            raise NotImplementedError(
                textwrap.dedent(
                    f"""
        Object {self.__class__.__name__} was created by passing
        non-serializable argument values in `__init__()`,
        and therefore the object must override `get_config()` in
        order to be serializable. Please implement `get_config()`.

        Example:

        class CustomLayer(keras.layers.Layer):
            def __init__(self, arg1, arg2, **kwargs):
                super().__init__(**kwargs)
                self.arg1 = arg1
                self.arg2 = arg2

            def get_config(self):
                config = super().get_config()
                config.update({{
                    "arg1": self.arg1,
                    "arg2": self.arg2,
                }})
                return config"""
                )
            )

    @classmethod
    def from_config(cls, config):
        """Creates an operation from its config.

        This method is the reverse of `get_config`,
        capable of instantiating the same operation from the config
        dictionary. It does not handle layer connectivity
        (handled by Network), nor weights (handled by `set_weights`).

        Args:
            config: A Python dictionary, typically the
                output of get_config.

        Returns:
            An instance of `cls`, built as `cls(**config)`.

        Raises:
            TypeError: If `cls(**config)` raises any exception. The original
                exception is attached as the cause.
        """
        try:
            return cls(**config)
        except Exception as e:
            raise TypeError(
                f"Error when deserializing class '{cls.__name__}' using "
                f"config={config}.\n\nException encountered: {e}"
            ) from e

    def __repr__(self):
        return f"<Operation name={self.name}>"

    @property
    def input(self):
        """Retrieves the input tensor(s) of a symbolic operation.

        Only returns the tensor(s) corresponding to the *first time*
        the operation was called.

        Returns:
            Input tensor or list of input tensors.
        """
        return self._get_node_attribute_at_index(0, "input_tensors", "input")

    @property
    def output(self):
        """Retrieves the output tensor(s) of a symbolic operation.

        Only returns the tensor(s) corresponding to the *first time*
        the operation was called.

        Returns:
            Output tensor or list of output tensors.
        """
        return self._get_node_attribute_at_index(0, "output_tensors", "output")

    def _get_node_attribute_at_index(self, node_index, attr, attr_name):
        """Private utility to retrieve an attribute (e.g. inputs) from a node.

        This is used to implement the properties:

        - output
        - input

        Args:
            node_index: Integer index of the node from which
                to retrieve the attribute.
            attr: Exact node attribute name.
            attr_name: Human-readable attribute name, for error messages.

        Returns:
            The operation's attribute `attr` at the node of index
            `node_index`. A single-element list is unwrapped to its element.

        Raises:
            ValueError: If the operation has no inbound nodes, or if
                `node_index` is not smaller than the number of inbound nodes.
        """
        if not self._inbound_nodes:
            raise ValueError(
                f"The layer {self.name} has never been called "
                f"and thus has no defined {attr_name}."
            )
        if len(self._inbound_nodes) <= node_index:
            raise ValueError(
                f"Asked to get {attr_name} at node "
                f"{node_index}, but the operation has only "
                f"{len(self._inbound_nodes)} inbound nodes."
            )
        values = getattr(self._inbound_nodes[node_index], attr)
        if isinstance(values, list) and len(values) == 1:
            return values[0]
        return values

    # Hooks for backend layer classes
    def _post_build(self):
        """Can be overridden for per backend post build actions."""
        pass

    def _setattr_hook(self, name, value):
        """Can be overridden for per backend attribute-setting actions.

        Args:
            name: Name of the attribute.
            value: Value of the attribute.

        Returns:
            A `(name, value)` tuple. The default implementation returns both
            unchanged.
        """
        return name, value

    def _post_track_variable(self, variable):
        """Can be overridden for per backend post track actions."""
        pass

    def _post_untrack_variable(self, variable):
        """Can be overridden for per backend post untrack actions."""
        pass