-
Notifications
You must be signed in to change notification settings - Fork 19.7k
Expand file tree
/
Copy pathkeras_tensor.py
More file actions
313 lines (223 loc) · 8.76 KB
/
keras_tensor.py
File metadata and controls
313 lines (223 loc) · 8.76 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
from keras.src import tree
from keras.src.api_export import keras_export
from keras.src.utils.naming import auto_name
@keras_export("keras.KerasTensor")
class KerasTensor:
    """Symbolic tensor -- encapsulates a shape and a dtype.

    You can use `KerasTensor` instances to build computation
    graphs of Keras operations, such as `keras.Function`
    objects or Functional `keras.models.Model` objects.

    Example:

    >>> x = keras.KerasTensor(shape=(3, 4), dtype="float32")
    >>> x.shape
    (3, 4)
    >>> x.dtype
    float32

    Calling a Keras operation (including a layer or a model)
    on a `KerasTensor` instance will return another `KerasTensor`
    instance with the appropriate shape and dtype. This is
    called a "symbolic call" (since there is no actual data
    involved). The computation of the correct output shape and
    dtype is called "static shape inference".

    Args:
        shape: Shape of the tensor; normalized with
            `backend.standardize_shape`.
        dtype: Dtype of the tensor; normalized with
            `backend.standardize_dtype`. Defaults to `"float32"`.
        sparse: Whether the tensor is sparse. Stored as-is.
        record_history: Stored as-is on the instance; how it is
            consumed is not visible in this class.
        name: Optional name. When falsy, a name is generated with
            `auto_name` from the class name.
    """

    def __init__(
        self,
        shape,
        dtype="float32",
        sparse=False,
        record_history=True,
        name=None,
    ):
        # Imported lazily, presumably to avoid a circular import with
        # the backend package.
        from keras.src import backend

        self.shape = backend.standardize_shape(shape)
        self.dtype = backend.standardize_dtype(dtype)
        self.sparse = sparse
        self.name = name or auto_name(self.__class__.__name__)
        self.record_history = record_history

    @property
    def ndim(self):
        """Number of dimensions, i.e. `len(self.shape)`."""
        return len(self.shape)

    def reshape(self, newshape):
        """Symbolically reshape this tensor via `ops.Reshape`."""
        from keras.src import ops

        return ops.Reshape(newshape)(self)

    def squeeze(self, axis=None):
        """Symbolically squeeze this tensor via `ops.Squeeze`."""
        from keras.src import ops

        return ops.Squeeze(axis)(self)

    def __array__(self, dtype=None, copy=None):
        # NumPy may invoke this hook with a positional `dtype`
        # (e.g. `np.asarray(x, dtype=...)`) and, since NumPy 2, a
        # `copy` keyword. Accepting both ensures every conversion
        # attempt surfaces the explanatory ValueError below instead of
        # a signature TypeError.
        raise ValueError(
            "A KerasTensor is symbolic: it's a placeholder for a shape "
            "and a dtype. It doesn't have any actual numerical value. "
            "You cannot convert it to a NumPy array."
        )

    def __jax_array__(self):
        # Hook consulted by JAX when converting objects to arrays;
        # always refuses, since there is no concrete data to convert.
        raise ValueError(
            "A KerasTensor cannot be used as input to a JAX function. "
            "A KerasTensor is a symbolic placeholder for a shape and dtype, "
            "used when constructing Keras Functional models "
            "or Keras Functions. You can only use it as input to a Keras layer "
            "or a Keras operation (from the namespaces `keras.layers` "
            "and `keras.operations`). "
            "You are likely doing something like:\n\n"
            "```\n"
            "x = Input(...)\n"
            "...\n"
            "jax_fn(x)  # Invalid.\n"
            "```\n\n"
            "What you should do instead is wrap `jax_fn` in a layer:\n\n"
            "```\n"
            "class MyLayer(Layer):\n"
            "    def call(self, x):\n"
            "        return jax_fn(x)\n\n"
            "x = MyLayer()(x)\n"
            "```\n"
        )

    def __tf_tensor__(self, dtype=None, name=None):
        # Hook consulted by TensorFlow when converting objects to
        # tensors; always refuses, since there is no concrete data.
        raise ValueError(
            "A KerasTensor cannot be used as input to a TensorFlow function. "
            "A KerasTensor is a symbolic placeholder for a shape and dtype, "
            "used when constructing Keras Functional models "
            "or Keras Functions. You can only use it as input to a Keras layer "
            "or a Keras operation (from the namespaces `keras.layers` "
            "and `keras.operations`). "
            "You are likely doing something like:\n\n"
            "```\n"
            "x = Input(...)\n"
            "...\n"
            "tf_fn(x)  # Invalid.\n"
            "```\n\n"
            "What you should do instead is wrap `tf_fn` in a layer:\n\n"
            "```\n"
            "class MyLayer(Layer):\n"
            "    def call(self, x):\n"
            "        return tf_fn(x)\n\n"
            "x = MyLayer()(x)\n"
            "```\n"
        )

    def __repr__(self):
        return (
            f"<KerasTensor shape={self.shape}, dtype={self.dtype}, "
            f"sparse={self.sparse}, name={self.name}>"
        )

    def __iter__(self):
        raise NotImplementedError(
            "Iterating over a symbolic KerasTensor is not supported."
        )

    def __bool__(self):
        # A symbolic tensor has no value, so truthiness is undefined.
        raise TypeError("A symbolic KerasTensor cannot be used as a boolean.")

    # ------------------------------------------------------------------
    # Operator overloads. Each one builds the matching Keras op and runs
    # its symbolic call; the reflected (`__r*__`) variants swap operand
    # order so `other` is the left-hand side.
    # ------------------------------------------------------------------

    def __add__(self, other):
        from keras.src import ops

        return ops.Add().symbolic_call(self, other)

    def __radd__(self, other):
        from keras.src import ops

        return ops.Add().symbolic_call(other, self)

    def __sub__(self, other):
        from keras.src import ops

        return ops.Subtract().symbolic_call(self, other)

    def __rsub__(self, other):
        from keras.src import ops

        return ops.Subtract().symbolic_call(other, self)

    def __mul__(self, other):
        from keras.src import ops

        return ops.Multiply().symbolic_call(self, other)

    def __rmul__(self, other):
        from keras.src import ops

        return ops.Multiply().symbolic_call(other, self)

    def __matmul__(self, other):
        from keras.src import ops

        return ops.Matmul().symbolic_call(self, other)

    def __rmatmul__(self, other):
        from keras.src import ops

        return ops.Matmul().symbolic_call(other, self)

    # `__div__`/`__rdiv__` are Python 2 hooks: Python 3's `/` operator
    # dispatches to `__truediv__`/`__rtruediv__`, so these are only
    # reached when called explicitly. Kept for backward compatibility.
    def __div__(self, other):
        from keras.src import ops

        return ops.Divide().symbolic_call(self, other)

    def __rdiv__(self, other):
        from keras.src import ops

        return ops.Divide().symbolic_call(other, self)

    def __truediv__(self, other):
        from keras.src import ops

        return ops.TrueDivide().symbolic_call(self, other)

    def __rtruediv__(self, other):
        from keras.src import ops

        return ops.TrueDivide().symbolic_call(other, self)

    def __neg__(self):
        from keras.src import ops

        return ops.Negative().symbolic_call(self)

    def __abs__(self):
        from keras.src import ops

        return ops.Absolute().symbolic_call(self)

    def __pow__(self, other):
        from keras.src import ops

        return ops.Power().symbolic_call(self, other)

    def __rpow__(self, other):
        from keras.src import ops

        return ops.Power().symbolic_call(other, self)

    def __floordiv__(self, other):
        from keras.src import ops

        return ops.FloorDivide().symbolic_call(self, other)

    def __rfloordiv__(self, other):
        from keras.src import ops

        return ops.FloorDivide().symbolic_call(other, self)

    def __mod__(self, other):
        from keras.src import ops

        return ops.Mod().symbolic_call(self, other)

    def __rmod__(self, other):
        from keras.src import ops

        return ops.Mod().symbolic_call(other, self)

    # NOTE(review): `__eq__` is deliberately not overridden alongside
    # `__ne__`, presumably so instances stay hashable. Defining `__eq__`
    # without `__hash__` makes Python set `__hash__` to None.
    def __lt__(self, other):
        from keras.src import ops

        return ops.Less().symbolic_call(self, other)

    def __le__(self, other):
        from keras.src import ops

        return ops.LessEqual().symbolic_call(self, other)

    def __gt__(self, other):
        from keras.src import ops

        return ops.Greater().symbolic_call(self, other)

    def __ge__(self, other):
        from keras.src import ops

        return ops.GreaterEqual().symbolic_call(self, other)

    def __ne__(self, other):
        from keras.src import ops

        return ops.NotEqual().symbolic_call(self, other)

    # Bitwise operators are mapped to *logical* Keras ops.
    def __and__(self, other):
        from keras.src import ops

        return ops.LogicalAnd().symbolic_call(self, other)

    def __rand__(self, other):
        from keras.src import ops

        return ops.LogicalAnd().symbolic_call(other, self)

    def __or__(self, other):
        from keras.src import ops

        return ops.LogicalOr().symbolic_call(self, other)

    def __ror__(self, other):
        from keras.src import ops

        return ops.LogicalOr().symbolic_call(other, self)

    def __invert__(self):
        from keras.src import ops

        return ops.LogicalNot().symbolic_call(self)

    def __xor__(self, other):
        from keras.src import ops

        return ops.LogicalXor().symbolic_call(self, other)

    def __rxor__(self, other):
        from keras.src import ops

        return ops.LogicalXor().symbolic_call(other, self)

    def __getitem__(self, key):
        from keras.src import ops

        return ops.GetItem().symbolic_call(self, key)
def any_symbolic_tensors(args=None, kwargs=None):
    """Return whether any leaf of `args` or `kwargs` is a `KerasTensor`.

    Args:
        args: Positional arguments to scan (typically a tuple). `None`
            is treated as an empty tuple.
        kwargs: Keyword arguments to scan (typically a dict). `None`
            is treated as an empty dict.

    Returns:
        `True` if any leaf produced by `tree.flatten((args, kwargs))` is
        a `KerasTensor` instance, `False` otherwise.
    """
    # Compare against `None` explicitly instead of using `args or ()`.
    # Truthiness calls `bool(args)`, which raises for a bare KerasTensor
    # (its `__bool__` raises TypeError) and for multi-element array-likes.
    if args is None:
        args = ()
    if kwargs is None:
        kwargs = {}
    return any(
        isinstance(x, KerasTensor) for x in tree.flatten((args, kwargs))
    )
@keras_export(["keras.utils.is_keras_tensor", "keras.backend.is_keras_tensor"])
def is_keras_tensor(x):
    """Check whether `x` is a symbolic Keras tensor.

    Symbolic tensors, such as those created via `Input()`, are
    placeholders that carry only a shape and a dtype, with no numerical
    data attached. They can be used to build Functional models but
    cannot take part in actual computations.

    Args:
        x: Any Python object.

    Returns:
        `True` if `x` is an instance of `KerasTensor` (or a subclass),
        `False` otherwise.
    """
    symbolic_types = (KerasTensor,)
    return isinstance(x, symbolic_types)