File: | .cache/bazel/_bazel_alan/39be661231df2a680c9b74265384c13c/execroot/org_tensorflow/tensorflow/python/eager/pywrap_tfe_src.cc |
Warning: | line 4260, column 19 PyObject ownership leak with reference count of 1 |
Press '?' to see keyboard shortcuts
Keyboard shortcuts:
1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. | |||
2 | ||||
3 | Licensed under the Apache License, Version 2.0 (the "License"); | |||
4 | you may not use this file except in compliance with the License. | |||
5 | You may obtain a copy of the License at | |||
6 | ||||
7 | http://www.apache.org/licenses/LICENSE-2.0 | |||
8 | ||||
9 | Unless required by applicable law or agreed to in writing, software | |||
10 | distributed under the License is distributed on an "AS IS" BASIS, | |||
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
12 | See the License for the specific language governing permissions and | |||
13 | limitations under the License. | |||
14 | ==============================================================================*/ | |||
15 | ||||
16 | #include <atomic> | |||
17 | #include <cstring> | |||
18 | #include <unordered_map> | |||
19 | ||||
20 | #include "absl/debugging/leak_check.h" | |||
21 | #include "absl/strings/str_cat.h" | |||
22 | #include "absl/types/variant.h" | |||
23 | #include "tensorflow/c/c_api.h" | |||
24 | #include "tensorflow/c/c_api_internal.h" | |||
25 | #include "tensorflow/c/eager/c_api.h" | |||
26 | #include "tensorflow/c/eager/c_api_internal.h" | |||
27 | #include "tensorflow/c/eager/tape.h" | |||
28 | #include "tensorflow/c/eager/tfe_context_internal.h" | |||
29 | #include "tensorflow/c/eager/tfe_op_internal.h" | |||
30 | #include "tensorflow/c/eager/tfe_tensorhandle_internal.h" | |||
31 | #include "tensorflow/c/tf_status.h" | |||
32 | #include "tensorflow/core/framework/types.pb.h" | |||
33 | #include "tensorflow/core/lib/core/errors.h" | |||
34 | #include "tensorflow/core/lib/gtl/cleanup.h" | |||
35 | #include "tensorflow/core/lib/gtl/compactptrset.h" | |||
36 | #include "tensorflow/core/lib/gtl/flatmap.h" | |||
37 | #include "tensorflow/core/lib/gtl/flatset.h" | |||
38 | #include "tensorflow/core/lib/strings/strcat.h" | |||
39 | #include "tensorflow/core/lib/strings/stringprintf.h" | |||
40 | #include "tensorflow/core/platform/casts.h" | |||
41 | #include "tensorflow/core/platform/errors.h" | |||
42 | #include "tensorflow/core/platform/mutex.h" | |||
43 | #include "tensorflow/core/platform/protobuf.h" | |||
44 | #include "tensorflow/core/platform/status.h" | |||
45 | #include "tensorflow/core/platform/types.h" | |||
46 | #include "tensorflow/core/profiler/lib/traceme.h" | |||
47 | #include "tensorflow/core/util/managed_stack_trace.h" | |||
48 | #include "tensorflow/python/eager/pywrap_gradient_exclusions.h" | |||
49 | #include "tensorflow/python/eager/pywrap_tensor.h" | |||
50 | #include "tensorflow/python/eager/pywrap_tfe.h" | |||
51 | #include "tensorflow/python/lib/core/py_util.h" | |||
52 | #include "tensorflow/python/lib/core/safe_ptr.h" | |||
53 | #include "tensorflow/python/util/stack_trace.h" | |||
54 | #include "tensorflow/python/util/util.h" | |||
55 | ||||
56 | using tensorflow::Status; | |||
57 | using tensorflow::string; | |||
58 | using tensorflow::strings::Printf; | |||
59 | ||||
60 | namespace { | |||
61 | // NOTE: Items are retrieved from and returned to these unique_ptrs, and they | |||
62 | // act as arenas. This is important if the same thread requests 2 items without | |||
63 | // releasing one. | |||
64 | // The following sequence of events on the same thread will still succeed: | |||
65 | // - GetOp <- Returns existing. | |||
66 | // - GetOp <- Allocates and returns a new pointer. | |||
67 | // - ReleaseOp <- Sets the item in the unique_ptr. | |||
68 | // - ReleaseOp <- Sets the item in the unique_ptr, deleting the old one. | |||
69 | // This occurs when a PyFunc kernel is run. This behavior makes it safe in that | |||
70 | // case, as well as the case where python decides to reuse the underlying | |||
71 | // C++ thread in 2 python threads case. | |||
72 | struct OpDeleter { | |||
73 | void operator()(TFE_Op* op) const { TFE_DeleteOp(op); } | |||
74 | }; | |||
75 | thread_local std::unordered_map<TFE_Context*, | |||
76 | std::unique_ptr<TFE_Op, OpDeleter>> | |||
77 | thread_local_eager_operation_map; // NOLINT | |||
78 | thread_local std::unique_ptr<TF_Status> thread_local_tf_status = // NOLINT | |||
79 | nullptr; | |||
80 | ||||
81 | std::unique_ptr<TFE_Op, OpDeleter> ReleaseThreadLocalOp(TFE_Context* ctx) { | |||
82 | auto it = thread_local_eager_operation_map.find(ctx); | |||
83 | if (it == thread_local_eager_operation_map.end()) { | |||
84 | return nullptr; | |||
85 | } | |||
86 | return std::move(it->second); | |||
87 | } | |||
88 | ||||
89 | TFE_Op* GetOp(TFE_Context* ctx, const char* op_or_function_name, | |||
90 | const char* raw_device_name, TF_Status* status) { | |||
91 | auto op = ReleaseThreadLocalOp(ctx); | |||
92 | if (!op) { | |||
93 | op.reset(tensorflow::wrap(tensorflow::unwrap(ctx)->CreateOperation())); | |||
94 | } | |||
95 | status->status = | |||
96 | tensorflow::unwrap(op.get())->Reset(op_or_function_name, raw_device_name); | |||
97 | if (!status->status.ok()) { | |||
98 | op.reset(); | |||
99 | } | |||
100 | return op.release(); | |||
101 | } | |||
102 | ||||
103 | void ReturnOp(TFE_Context* ctx, TFE_Op* op) { | |||
104 | if (op) { | |||
105 | tensorflow::unwrap(op)->Clear(); | |||
106 | thread_local_eager_operation_map[ctx].reset(op); | |||
107 | } | |||
108 | } | |||
109 | ||||
110 | TF_Status* ReleaseThreadLocalStatus() { | |||
111 | if (thread_local_tf_status == nullptr) { | |||
112 | return nullptr; | |||
113 | } | |||
114 | return thread_local_tf_status.release(); | |||
115 | } | |||
116 | ||||
// Identifies one op input whose dtype is governed by a type attr.
struct InputInfo {
  InputInfo(int i, bool is_list) : i{i}, is_list{is_list} {}

  int i;                 // Index of the input in OpDef::input_arg.
  bool is_list = false;  // True when that input arg has a number_attr.
};
123 | ||||
124 | // Takes in output gradients, returns input gradients. | |||
125 | typedef std::function<PyObject*(PyObject*, const std::vector<int64_t>&)> | |||
126 | PyBackwardFunction; | |||
127 | ||||
128 | using AttrToInputsMap = | |||
129 | tensorflow::gtl::FlatMap<string, | |||
130 | tensorflow::gtl::InlinedVector<InputInfo, 4>>; | |||
131 | ||||
132 | tensorflow::gtl::FlatMap<string, AttrToInputsMap*>* GetAllAttrToInputsMaps() { | |||
133 | static auto* all_attr_to_input_maps = | |||
134 | new tensorflow::gtl::FlatMap<string, AttrToInputsMap*>; | |||
135 | return all_attr_to_input_maps; | |||
136 | } | |||
137 | ||||
138 | // This function doesn't use a lock, since we depend on the GIL directly. | |||
139 | AttrToInputsMap* GetAttrToInputsMapHoldingGIL(const tensorflow::OpDef& op_def) { | |||
140 | #if PY_MAJOR_VERSION3 >= 3 && PY_MINOR_VERSION8 >= 4 | |||
141 | DCHECK(PyGILState_Check())while (false && (PyGILState_Check())) ::tensorflow::internal ::LogMessageFatal("tensorflow/python/eager/pywrap_tfe_src.cc" , 141) | |||
142 | << "This function needs to hold the GIL when called."; | |||
143 | #endif | |||
144 | auto* all_attr_to_input_maps = GetAllAttrToInputsMaps(); | |||
145 | auto* output = | |||
146 | tensorflow::gtl::FindPtrOrNull(*all_attr_to_input_maps, op_def.name()); | |||
147 | if (output != nullptr) { | |||
148 | return output; | |||
149 | } | |||
150 | ||||
151 | std::unique_ptr<AttrToInputsMap> m(new AttrToInputsMap); | |||
152 | ||||
153 | // Store a list of InputIndex -> List of corresponding inputs. | |||
154 | for (int i = 0; i < op_def.input_arg_size(); i++) { | |||
155 | if (!op_def.input_arg(i).type_attr().empty()) { | |||
156 | auto it = m->find(op_def.input_arg(i).type_attr()); | |||
157 | if (it == m->end()) { | |||
158 | it = m->insert({op_def.input_arg(i).type_attr(), {}}).first; | |||
159 | } | |||
160 | it->second.emplace_back(i, !op_def.input_arg(i).number_attr().empty()); | |||
161 | } | |||
162 | } | |||
163 | ||||
164 | auto* retval = m.get(); | |||
165 | (*all_attr_to_input_maps)[op_def.name()] = m.release(); | |||
166 | ||||
167 | return retval; | |||
168 | } | |||
169 | ||||
170 | // This function doesn't use a lock, since we depend on the GIL directly. | |||
171 | tensorflow::gtl::FlatMap< | |||
172 | string, tensorflow::gtl::FlatMap<string, tensorflow::DataType>*>* | |||
173 | GetAllAttrToDefaultsMaps() { | |||
174 | static auto* all_attr_to_defaults_maps = new tensorflow::gtl::FlatMap< | |||
175 | string, tensorflow::gtl::FlatMap<string, tensorflow::DataType>*>; | |||
176 | return all_attr_to_defaults_maps; | |||
177 | } | |||
178 | ||||
179 | tensorflow::gtl::FlatMap<string, tensorflow::DataType>* | |||
180 | GetAttrToDefaultsMapHoldingGIL(const tensorflow::OpDef& op_def) { | |||
181 | #if PY_MAJOR_VERSION3 >= 3 && PY_MINOR_VERSION8 >= 4 | |||
182 | DCHECK(PyGILState_Check())while (false && (PyGILState_Check())) ::tensorflow::internal ::LogMessageFatal("tensorflow/python/eager/pywrap_tfe_src.cc" , 182) | |||
183 | << "This function needs to hold the GIL when called."; | |||
184 | #endif | |||
185 | auto* all_attr_to_defaults_maps = GetAllAttrToDefaultsMaps(); | |||
186 | auto* output = | |||
187 | tensorflow::gtl::FindPtrOrNull(*all_attr_to_defaults_maps, op_def.name()); | |||
188 | if (output != nullptr) { | |||
189 | return output; | |||
190 | } | |||
191 | ||||
192 | auto* new_map = new tensorflow::gtl::FlatMap<string, tensorflow::DataType>; | |||
193 | ||||
194 | for (const auto& attr : op_def.attr()) { | |||
195 | if (attr.type() == "type" && attr.has_default_value()) { | |||
196 | new_map->insert({attr.name(), attr.default_value().type()}); | |||
197 | } | |||
198 | } | |||
199 | ||||
200 | (*all_attr_to_defaults_maps)[op_def.name()] = new_map; | |||
201 | ||||
202 | return new_map; | |||
203 | } | |||
204 | ||||
205 | struct FastPathOpExecInfo { | |||
206 | TFE_Context* ctx; | |||
207 | const char* device_name; | |||
208 | ||||
209 | bool run_callbacks; | |||
210 | bool run_post_exec_callbacks; | |||
211 | bool run_gradient_callback; | |||
212 | ||||
213 | // The op name of the main op being executed. | |||
214 | PyObject* name; | |||
215 | // The op type name of the main op being executed. | |||
216 | PyObject* op_name; | |||
217 | PyObject* callbacks; | |||
218 | ||||
219 | // All the args passed into the FastPathOpExecInfo. | |||
220 | PyObject* args; | |||
221 | ||||
222 | // DTypes can come from another input that has the same attr. So build that | |||
223 | // map. | |||
224 | const AttrToInputsMap* attr_to_inputs_map; | |||
225 | const tensorflow::gtl::FlatMap<string, tensorflow::DataType>* default_dtypes; | |||
226 | tensorflow::gtl::FlatMap<string, tensorflow::DataType> cached_dtypes; | |||
227 | }; | |||
228 | ||||
229 | #define PARSE_VALUE(fn_name, type, check_fn, parse_fn) \ | |||
230 | bool fn_name(const string& key, PyObject* py_value, TF_Status* status, \ | |||
231 | type* value) { \ | |||
232 | if (check_fn(py_value)) { \ | |||
233 | *value = static_cast<type>(parse_fn(py_value)); \ | |||
234 | return true; \ | |||
235 | } else { \ | |||
236 | TF_SetStatus(status, TF_INVALID_ARGUMENT, \ | |||
237 | tensorflow::strings::StrCat( \ | |||
238 | "Expecting " #type " value for attr ", key, ", got ", \ | |||
239 | py_value->ob_type->tp_name) \ | |||
240 | .c_str()); \ | |||
241 | return false; \ | |||
242 | } \ | |||
243 | } | |||
244 | ||||
245 | #if PY_MAJOR_VERSION3 >= 3 | |||
246 | PARSE_VALUE(ParseIntValue, int, PyLong_Check, PyLong_AsLong) | |||
247 | PARSE_VALUE(ParseInt64Value, int64_t, PyLong_Check, PyLong_AsLongLong) | |||
248 | #else | |||
249 | PARSE_VALUE(ParseIntValue, int, PyInt_Check, PyInt_AsLong) | |||
250 | #endif | |||
251 | PARSE_VALUE(ParseFloatValue, float, PyFloat_Check, PyFloat_AsDouble) | |||
252 | #undef PARSE_VALUE | |||
253 | ||||
254 | #if PY_MAJOR_VERSION3 < 3 | |||
255 | bool ParseInt64Value(const string& key, PyObject* py_value, TF_Status* status, | |||
256 | int64_t* value) { | |||
257 | if (PyInt_Check(py_value)) { | |||
258 | *value = static_cast<int64_t>(PyInt_AsLong(py_value)); | |||
259 | return true; | |||
260 | } else if (PyLong_Check(py_value)((((((PyObject*)(py_value))->ob_type))->tp_flags & ( (1UL << 24))) != 0)) { | |||
261 | *value = static_cast<int64_t>(PyLong_AsLong(py_value)); | |||
262 | return true; | |||
263 | } | |||
264 | TF_SetStatus( | |||
265 | status, TF_INVALID_ARGUMENT, | |||
266 | tensorflow::strings::StrCat("Expecting int or long value for attr ", key, | |||
267 | ", got ", py_value->ob_type->tp_name) | |||
268 | .c_str()); | |||
269 | return false; | |||
270 | } | |||
271 | #endif | |||
272 | ||||
273 | Py_ssize_t TensorShapeNumDims(PyObject* value) { | |||
274 | const auto size = PySequence_Size(value); | |||
275 | if (size == -1) { | |||
276 | // TensorShape.__len__ raises an error in the scenario where the shape is an | |||
277 | // unknown, which needs to be cleared. | |||
278 | // TODO(nareshmodi): ensure that this is actually a TensorShape. | |||
279 | PyErr_Clear(); | |||
280 | } | |||
281 | return size; | |||
282 | } | |||
283 | ||||
284 | bool IsInteger(PyObject* py_value) { | |||
285 | #if PY_MAJOR_VERSION3 >= 3 | |||
286 | return PyLong_Check(py_value)((((((PyObject*)(py_value))->ob_type))->tp_flags & ( (1UL << 24))) != 0); | |||
287 | #else | |||
288 | return PyInt_Check(py_value) || PyLong_Check(py_value)((((((PyObject*)(py_value))->ob_type))->tp_flags & ( (1UL << 24))) != 0); | |||
289 | #endif | |||
290 | } | |||
291 | ||||
292 | // This function considers a Dimension._value of None to be valid, and sets the | |||
293 | // value to be -1 in that case. | |||
294 | bool ParseDimensionValue(const string& key, PyObject* py_value, | |||
295 | TF_Status* status, int64_t* value) { | |||
296 | if (IsInteger(py_value)) { | |||
297 | return ParseInt64Value(key, py_value, status, value); | |||
298 | } | |||
299 | ||||
300 | tensorflow::Safe_PyObjectPtr dimension_value( | |||
301 | PyObject_GetAttrString(py_value, "_value")); | |||
302 | if (dimension_value == nullptr) { | |||
303 | PyErr_Clear(); | |||
304 | TF_SetStatus( | |||
305 | status, TF_INVALID_ARGUMENT, | |||
306 | tensorflow::strings::StrCat("Expecting a Dimension for attr ", key, | |||
307 | ", got ", py_value->ob_type->tp_name) | |||
308 | .c_str()); | |||
309 | return false; | |||
310 | } | |||
311 | ||||
312 | if (dimension_value.get() == Py_None(&_Py_NoneStruct)) { | |||
313 | *value = -1; | |||
314 | return true; | |||
315 | } | |||
316 | ||||
317 | return ParseInt64Value(key, dimension_value.get(), status, value); | |||
318 | } | |||
319 | ||||
320 | bool ParseStringValue(const string& key, PyObject* py_value, TF_Status* status, | |||
321 | tensorflow::StringPiece* value) { | |||
322 | if (PyBytes_Check(py_value)((((((PyObject*)(py_value))->ob_type))->tp_flags & ( (1UL << 27))) != 0)) { | |||
323 | Py_ssize_t size = 0; | |||
324 | char* buf = nullptr; | |||
325 | if (PyBytes_AsStringAndSize(py_value, &buf, &size) < 0) return false; | |||
326 | *value = tensorflow::StringPiece(buf, size); | |||
327 | return true; | |||
328 | } | |||
329 | #if PY_MAJOR_VERSION3 >= 3 | |||
330 | if (PyUnicode_Check(py_value)((((((PyObject*)(py_value))->ob_type))->tp_flags & ( (1UL << 28))) != 0)) { | |||
331 | Py_ssize_t size = 0; | |||
332 | const char* buf = PyUnicode_AsUTF8AndSize(py_value, &size); | |||
333 | if (buf == nullptr) return false; | |||
334 | *value = tensorflow::StringPiece(buf, size); | |||
335 | return true; | |||
336 | } | |||
337 | #endif | |||
338 | TF_SetStatus( | |||
339 | status, TF_INVALID_ARGUMENT, | |||
340 | tensorflow::strings::StrCat("Expecting a string value for attr ", key, | |||
341 | ", got ", py_value->ob_type->tp_name) | |||
342 | .c_str()); | |||
343 | return false; | |||
344 | } | |||
345 | ||||
346 | bool ParseBoolValue(const string& key, PyObject* py_value, TF_Status* status, | |||
347 | unsigned char* value) { | |||
348 | *value = PyObject_IsTrue(py_value); | |||
349 | return true; | |||
350 | } | |||
351 | ||||
352 | // The passed in py_value is expected to be an object of the python type | |||
353 | // dtypes.DType or an int. | |||
354 | bool ParseTypeValue(const string& key, PyObject* py_value, TF_Status* status, | |||
355 | int* value) { | |||
356 | if (IsInteger(py_value)) { | |||
357 | return ParseIntValue(key, py_value, status, value); | |||
358 | } | |||
359 | ||||
360 | tensorflow::Safe_PyObjectPtr py_type_enum( | |||
361 | PyObject_GetAttrString(py_value, "_type_enum")); | |||
362 | if (py_type_enum == nullptr) { | |||
363 | PyErr_Clear(); | |||
364 | TF_SetStatus( | |||
365 | status, TF_INVALID_ARGUMENT, | |||
366 | tensorflow::strings::StrCat("Expecting a DType.dtype for attr ", key, | |||
367 | ", got ", py_value->ob_type->tp_name) | |||
368 | .c_str()); | |||
369 | return false; | |||
370 | } | |||
371 | ||||
372 | return ParseIntValue(key, py_type_enum.get(), status, value); | |||
373 | } | |||
374 | ||||
375 | bool SetOpAttrList(TFE_Context* ctx, TFE_Op* op, const char* key, | |||
376 | PyObject* py_list, TF_AttrType type, | |||
377 | tensorflow::gtl::FlatMap<string, int64_t>* attr_list_sizes, | |||
378 | TF_Status* status) { | |||
379 | if (!PySequence_Check(py_list)) { | |||
380 | TF_SetStatus( | |||
381 | status, TF_INVALID_ARGUMENT, | |||
382 | tensorflow::strings::StrCat("Expecting sequence value for attr ", key, | |||
383 | ", got ", py_list->ob_type->tp_name) | |||
384 | .c_str()); | |||
385 | return false; | |||
386 | } | |||
387 | const int num_values = PySequence_Size(py_list); | |||
388 | if (attr_list_sizes != nullptr) (*attr_list_sizes)[key] = num_values; | |||
389 | ||||
390 | #define PARSE_LIST(c_type, parse_fn) \ | |||
391 | std::unique_ptr<c_type[]> values(new c_type[num_values]); \ | |||
392 | for (int i = 0; i < num_values; ++i) { \ | |||
393 | tensorflow::Safe_PyObjectPtr py_value(PySequence_ITEM(py_list, i)( (((PyObject*)(py_list))->ob_type)->tp_as_sequence-> sq_item(py_list, i) )); \ | |||
394 | if (!parse_fn(key, py_value.get(), status, &values[i])) return false; \ | |||
395 | } | |||
396 | ||||
397 | if (type == TF_ATTR_STRING) { | |||
398 | std::unique_ptr<const void*[]> values(new const void*[num_values]); | |||
399 | std::unique_ptr<size_t[]> lengths(new size_t[num_values]); | |||
400 | for (int i = 0; i < num_values; ++i) { | |||
401 | tensorflow::StringPiece value; | |||
402 | tensorflow::Safe_PyObjectPtr py_value(PySequence_ITEM(py_list, i)( (((PyObject*)(py_list))->ob_type)->tp_as_sequence-> sq_item(py_list, i) )); | |||
403 | if (!ParseStringValue(key, py_value.get(), status, &value)) return false; | |||
404 | values[i] = value.data(); | |||
405 | lengths[i] = value.size(); | |||
406 | } | |||
407 | TFE_OpSetAttrStringList(op, key, values.get(), lengths.get(), num_values); | |||
408 | } else if (type == TF_ATTR_INT) { | |||
409 | PARSE_LIST(int64_t, ParseInt64Value); | |||
410 | TFE_OpSetAttrIntList(op, key, values.get(), num_values); | |||
411 | } else if (type == TF_ATTR_FLOAT) { | |||
412 | PARSE_LIST(float, ParseFloatValue); | |||
413 | TFE_OpSetAttrFloatList(op, key, values.get(), num_values); | |||
414 | } else if (type == TF_ATTR_BOOL) { | |||
415 | PARSE_LIST(unsigned char, ParseBoolValue); | |||
416 | TFE_OpSetAttrBoolList(op, key, values.get(), num_values); | |||
417 | } else if (type == TF_ATTR_TYPE) { | |||
418 | PARSE_LIST(int, ParseTypeValue); | |||
419 | TFE_OpSetAttrTypeList(op, key, | |||
420 | reinterpret_cast<const TF_DataType*>(values.get()), | |||
421 | num_values); | |||
422 | } else if (type == TF_ATTR_SHAPE) { | |||
423 | // Make one pass through the input counting the total number of | |||
424 | // dims across all the input lists. | |||
425 | int total_dims = 0; | |||
426 | for (int i = 0; i < num_values; ++i) { | |||
427 | tensorflow::Safe_PyObjectPtr py_value(PySequence_ITEM(py_list, i)( (((PyObject*)(py_list))->ob_type)->tp_as_sequence-> sq_item(py_list, i) )); | |||
428 | if (py_value.get() != Py_None(&_Py_NoneStruct)) { | |||
429 | if (!PySequence_Check(py_value.get())) { | |||
430 | TF_SetStatus( | |||
431 | status, TF_INVALID_ARGUMENT, | |||
432 | tensorflow::strings::StrCat( | |||
433 | "Expecting None or sequence value for element", i, | |||
434 | " of attr ", key, ", got ", py_value->ob_type->tp_name) | |||
435 | .c_str()); | |||
436 | return false; | |||
437 | } | |||
438 | const auto size = TensorShapeNumDims(py_value.get()); | |||
439 | if (size >= 0) { | |||
440 | total_dims += size; | |||
441 | } | |||
442 | } | |||
443 | } | |||
444 | // Allocate a buffer that can fit all of the dims together. | |||
445 | std::unique_ptr<int64_t[]> buffer(new int64_t[total_dims]); | |||
446 | // Copy the input dims into the buffer and set dims to point to | |||
447 | // the start of each list's dims. | |||
448 | std::unique_ptr<const int64_t*[]> dims(new const int64_t*[num_values]); | |||
449 | std::unique_ptr<int[]> num_dims(new int[num_values]); | |||
450 | int64_t* offset = buffer.get(); | |||
451 | for (int i = 0; i < num_values; ++i) { | |||
452 | tensorflow::Safe_PyObjectPtr py_value(PySequence_ITEM(py_list, i)( (((PyObject*)(py_list))->ob_type)->tp_as_sequence-> sq_item(py_list, i) )); | |||
453 | if (py_value.get() == Py_None(&_Py_NoneStruct)) { | |||
454 | dims[i] = nullptr; | |||
455 | num_dims[i] = -1; | |||
456 | } else { | |||
457 | const auto size = TensorShapeNumDims(py_value.get()); | |||
458 | if (size == -1) { | |||
459 | dims[i] = nullptr; | |||
460 | num_dims[i] = -1; | |||
461 | continue; | |||
462 | } | |||
463 | dims[i] = offset; | |||
464 | num_dims[i] = size; | |||
465 | for (int j = 0; j < size; ++j) { | |||
466 | tensorflow::Safe_PyObjectPtr inner_py_value( | |||
467 | PySequence_ITEM(py_value.get(), j)( (((PyObject*)(py_value.get()))->ob_type)->tp_as_sequence ->sq_item(py_value.get(), j) )); | |||
468 | if (inner_py_value.get() == Py_None(&_Py_NoneStruct)) { | |||
469 | *offset = -1; | |||
470 | } else if (!ParseDimensionValue(key, inner_py_value.get(), status, | |||
471 | offset)) { | |||
472 | return false; | |||
473 | } | |||
474 | ++offset; | |||
475 | } | |||
476 | } | |||
477 | } | |||
478 | TFE_OpSetAttrShapeList(op, key, dims.get(), num_dims.get(), num_values, | |||
479 | status); | |||
480 | if (!status->status.ok()) return false; | |||
481 | } else if (type == TF_ATTR_FUNC) { | |||
482 | std::unique_ptr<const TFE_Op*[]> funcs(new const TFE_Op*[num_values]); | |||
483 | for (int i = 0; i < num_values; ++i) { | |||
484 | tensorflow::Safe_PyObjectPtr py_value(PySequence_ITEM(py_list, i)( (((PyObject*)(py_list))->ob_type)->tp_as_sequence-> sq_item(py_list, i) )); | |||
485 | // Allow: | |||
486 | // (1) String function name, OR | |||
487 | // (2) A Python object with a .name attribute | |||
488 | // (A crude test for being a | |||
489 | // tensorflow.python.framework.function._DefinedFunction) | |||
490 | // (which is what the various "defun" or "Defun" decorators do). | |||
491 | // And in the future also allow an object that can encapsulate | |||
492 | // the function name and its attribute values. | |||
493 | tensorflow::StringPiece func_name; | |||
494 | if (!ParseStringValue(key, py_value.get(), status, &func_name)) { | |||
495 | PyObject* name_attr = PyObject_GetAttrString(py_value.get(), "name"); | |||
496 | if (name_attr == nullptr || | |||
497 | !ParseStringValue(key, name_attr, status, &func_name)) { | |||
498 | TF_SetStatus( | |||
499 | status, TF_INVALID_ARGUMENT, | |||
500 | tensorflow::strings::StrCat( | |||
501 | "unable to set function value attribute from a ", | |||
502 | py_value.get()->ob_type->tp_name, | |||
503 | " object. If you think this is an error, please file an " | |||
504 | "issue at " | |||
505 | "https://github.com/tensorflow/tensorflow/issues/new") | |||
506 | .c_str()); | |||
507 | return false; | |||
508 | } | |||
509 | } | |||
510 | funcs[i] = TFE_NewOp(ctx, func_name.data(), status); | |||
511 | if (!status->status.ok()) return false; | |||
512 | } | |||
513 | TFE_OpSetAttrFunctionList(op, key, funcs.get(), num_values); | |||
514 | if (!status->status.ok()) return false; | |||
515 | } else { | |||
516 | TF_SetStatus(status, TF_UNIMPLEMENTED, | |||
517 | tensorflow::strings::StrCat("Attr ", key, | |||
518 | " has unhandled list type ", type) | |||
519 | .c_str()); | |||
520 | return false; | |||
521 | } | |||
522 | #undef PARSE_LIST | |||
523 | return true; | |||
524 | } | |||
525 | ||||
526 | TFE_Op* GetFunc(TFE_Context* ctx, const tensorflow::NameAttrList& func, | |||
527 | TF_Status* status) { | |||
528 | TFE_Op* func_op = TFE_NewOp(ctx, func.name().data(), status); | |||
529 | for (const auto& attr : func.attr()) { | |||
530 | if (!status->status.ok()) return nullptr; | |||
531 | SetOpAttrValueScalar(ctx, func_op, attr.second, attr.first.data(), status); | |||
532 | if (!status->status.ok()) return nullptr; | |||
533 | } | |||
534 | return func_op; | |||
535 | } | |||
536 | ||||
537 | void SetOpAttrListDefault( | |||
538 | TFE_Context* ctx, TFE_Op* op, const tensorflow::OpDef::AttrDef& attr, | |||
539 | const char* key, TF_AttrType type, | |||
540 | tensorflow::gtl::FlatMap<string, int64_t>* attr_list_sizes, | |||
541 | TF_Status* status) { | |||
542 | if (type == TF_ATTR_STRING) { | |||
543 | int num_values = attr.default_value().list().s_size(); | |||
544 | std::unique_ptr<const void*[]> values(new const void*[num_values]); | |||
545 | std::unique_ptr<size_t[]> lengths(new size_t[num_values]); | |||
546 | (*attr_list_sizes)[key] = num_values; | |||
547 | for (int i = 0; i < num_values; i++) { | |||
548 | const string& v = attr.default_value().list().s(i); | |||
549 | values[i] = v.data(); | |||
550 | lengths[i] = v.size(); | |||
551 | } | |||
552 | TFE_OpSetAttrStringList(op, key, values.get(), lengths.get(), num_values); | |||
553 | } else if (type == TF_ATTR_INT) { | |||
554 | int num_values = attr.default_value().list().i_size(); | |||
555 | std::unique_ptr<int64_t[]> values(new int64_t[num_values]); | |||
556 | (*attr_list_sizes)[key] = num_values; | |||
557 | for (int i = 0; i < num_values; i++) { | |||
558 | values[i] = attr.default_value().list().i(i); | |||
559 | } | |||
560 | TFE_OpSetAttrIntList(op, key, values.get(), num_values); | |||
561 | } else if (type == TF_ATTR_FLOAT) { | |||
562 | int num_values = attr.default_value().list().f_size(); | |||
563 | std::unique_ptr<float[]> values(new float[num_values]); | |||
564 | (*attr_list_sizes)[key] = num_values; | |||
565 | for (int i = 0; i < num_values; i++) { | |||
566 | values[i] = attr.default_value().list().f(i); | |||
567 | } | |||
568 | TFE_OpSetAttrFloatList(op, key, values.get(), num_values); | |||
569 | } else if (type == TF_ATTR_BOOL) { | |||
570 | int num_values = attr.default_value().list().b_size(); | |||
571 | std::unique_ptr<unsigned char[]> values(new unsigned char[num_values]); | |||
572 | (*attr_list_sizes)[key] = num_values; | |||
573 | for (int i = 0; i < num_values; i++) { | |||
574 | values[i] = attr.default_value().list().b(i); | |||
575 | } | |||
576 | TFE_OpSetAttrBoolList(op, key, values.get(), num_values); | |||
577 | } else if (type == TF_ATTR_TYPE) { | |||
578 | int num_values = attr.default_value().list().type_size(); | |||
579 | std::unique_ptr<int[]> values(new int[num_values]); | |||
580 | (*attr_list_sizes)[key] = num_values; | |||
581 | for (int i = 0; i < num_values; i++) { | |||
582 | values[i] = attr.default_value().list().type(i); | |||
583 | } | |||
584 | TFE_OpSetAttrTypeList(op, key, | |||
585 | reinterpret_cast<const TF_DataType*>(values.get()), | |||
586 | attr.default_value().list().type_size()); | |||
587 | } else if (type == TF_ATTR_SHAPE) { | |||
588 | int num_values = attr.default_value().list().shape_size(); | |||
589 | (*attr_list_sizes)[key] = num_values; | |||
590 | int total_dims = 0; | |||
591 | for (int i = 0; i < num_values; ++i) { | |||
592 | if (!attr.default_value().list().shape(i).unknown_rank()) { | |||
593 | total_dims += attr.default_value().list().shape(i).dim_size(); | |||
594 | } | |||
595 | } | |||
596 | // Allocate a buffer that can fit all of the dims together. | |||
597 | std::unique_ptr<int64_t[]> buffer(new int64_t[total_dims]); | |||
598 | // Copy the input dims into the buffer and set dims to point to | |||
599 | // the start of each list's dims. | |||
600 | std::unique_ptr<const int64_t*[]> dims(new const int64_t*[num_values]); | |||
601 | std::unique_ptr<int[]> num_dims(new int[num_values]); | |||
602 | int64_t* offset = buffer.get(); | |||
603 | for (int i = 0; i < num_values; ++i) { | |||
604 | const auto& shape = attr.default_value().list().shape(i); | |||
605 | if (shape.unknown_rank()) { | |||
606 | dims[i] = nullptr; | |||
607 | num_dims[i] = -1; | |||
608 | } else { | |||
609 | for (int j = 0; j < shape.dim_size(); j++) { | |||
610 | *offset = shape.dim(j).size(); | |||
611 | ++offset; | |||
612 | } | |||
613 | } | |||
614 | } | |||
615 | TFE_OpSetAttrShapeList(op, key, dims.get(), num_dims.get(), num_values, | |||
616 | status); | |||
617 | } else if (type == TF_ATTR_FUNC) { | |||
618 | int num_values = attr.default_value().list().func_size(); | |||
619 | (*attr_list_sizes)[key] = num_values; | |||
620 | std::unique_ptr<const TFE_Op*[]> funcs(new const TFE_Op*[num_values]); | |||
621 | for (int i = 0; i < num_values; i++) { | |||
622 | funcs[i] = GetFunc(ctx, attr.default_value().list().func(i), status); | |||
623 | } | |||
624 | TFE_OpSetAttrFunctionList(op, key, funcs.get(), num_values); | |||
625 | } else { | |||
626 | TF_SetStatus(status, TF_UNIMPLEMENTED, | |||
627 | "Lists of tensors are not yet implemented for default valued " | |||
628 | "attributes for an operation."); | |||
629 | } | |||
630 | } | |||
631 | ||||
632 | bool SetOpAttrScalar(TFE_Context* ctx, TFE_Op* op, const char* key, | |||
633 | PyObject* py_value, TF_AttrType type, | |||
634 | tensorflow::gtl::FlatMap<string, int64_t>* attr_list_sizes, | |||
635 | TF_Status* status) { | |||
636 | if (type == TF_ATTR_STRING) { | |||
637 | tensorflow::StringPiece value; | |||
638 | if (!ParseStringValue(key, py_value, status, &value)) return false; | |||
639 | TFE_OpSetAttrString(op, key, value.data(), value.size()); | |||
640 | } else if (type == TF_ATTR_INT) { | |||
641 | int64_t value; | |||
642 | if (!ParseInt64Value(key, py_value, status, &value)) return false; | |||
643 | TFE_OpSetAttrInt(op, key, value); | |||
644 | // attr_list_sizes is set for all int attributes (since at this point we are | |||
645 | // not aware if that attribute might be used to calculate the size of an | |||
646 | // output list or not). | |||
647 | if (attr_list_sizes != nullptr) (*attr_list_sizes)[key] = value; | |||
648 | } else if (type == TF_ATTR_FLOAT) { | |||
649 | float value; | |||
650 | if (!ParseFloatValue(key, py_value, status, &value)) return false; | |||
651 | TFE_OpSetAttrFloat(op, key, value); | |||
652 | } else if (type == TF_ATTR_BOOL) { | |||
653 | unsigned char value; | |||
654 | if (!ParseBoolValue(key, py_value, status, &value)) return false; | |||
655 | TFE_OpSetAttrBool(op, key, value); | |||
656 | } else if (type == TF_ATTR_TYPE) { | |||
657 | int value; | |||
658 | if (!ParseTypeValue(key, py_value, status, &value)) return false; | |||
659 | TFE_OpSetAttrType(op, key, static_cast<TF_DataType>(value)); | |||
660 | } else if (type == TF_ATTR_SHAPE) { | |||
661 | if (py_value == Py_None(&_Py_NoneStruct)) { | |||
662 | TFE_OpSetAttrShape(op, key, nullptr, -1, status); | |||
663 | } else { | |||
664 | if (!PySequence_Check(py_value)) { | |||
665 | TF_SetStatus(status, TF_INVALID_ARGUMENT, | |||
666 | tensorflow::strings::StrCat( | |||
667 | "Expecting None or sequence value for attr", key, | |||
668 | ", got ", py_value->ob_type->tp_name) | |||
669 | .c_str()); | |||
670 | return false; | |||
671 | } | |||
672 | const auto num_dims = TensorShapeNumDims(py_value); | |||
673 | if (num_dims == -1) { | |||
674 | TFE_OpSetAttrShape(op, key, nullptr, -1, status); | |||
675 | return true; | |||
676 | } | |||
677 | std::unique_ptr<int64_t[]> dims(new int64_t[num_dims]); | |||
678 | for (int i = 0; i < num_dims; ++i) { | |||
679 | tensorflow::Safe_PyObjectPtr inner_py_value( | |||
680 | PySequence_ITEM(py_value, i)( (((PyObject*)(py_value))->ob_type)->tp_as_sequence-> sq_item(py_value, i) )); | |||
681 | if (inner_py_value.get() == Py_None(&_Py_NoneStruct)) { | |||
682 | dims[i] = -1; | |||
683 | } else if (!ParseDimensionValue(key, inner_py_value.get(), status, | |||
684 | &dims[i])) { | |||
685 | return false; | |||
686 | } | |||
687 | } | |||
688 | TFE_OpSetAttrShape(op, key, dims.get(), num_dims, status); | |||
689 | } | |||
690 | if (!status->status.ok()) return false; | |||
691 | } else if (type == TF_ATTR_FUNC) { | |||
692 | // Allow: | |||
693 | // (1) String function name, OR | |||
694 | // (2) A Python object with a .name attribute | |||
695 | // (A crude test for being a | |||
696 | // tensorflow.python.framework.function._DefinedFunction) | |||
697 | // (which is what the various "defun" or "Defun" decorators do). | |||
698 | // And in the future also allow an object that can encapsulate | |||
699 | // the function name and its attribute values. | |||
700 | tensorflow::StringPiece func_name; | |||
701 | if (!ParseStringValue(key, py_value, status, &func_name)) { | |||
702 | PyObject* name_attr = PyObject_GetAttrString(py_value, "name"); | |||
703 | if (name_attr == nullptr || | |||
704 | !ParseStringValue(key, name_attr, status, &func_name)) { | |||
705 | TF_SetStatus( | |||
706 | status, TF_INVALID_ARGUMENT, | |||
707 | tensorflow::strings::StrCat( | |||
708 | "unable to set function value attribute from a ", | |||
709 | py_value->ob_type->tp_name, | |||
710 | " object. If you think this is an error, please file an issue " | |||
711 | "at https://github.com/tensorflow/tensorflow/issues/new") | |||
712 | .c_str()); | |||
713 | return false; | |||
714 | } | |||
715 | } | |||
716 | TF_SetStatus(status, TF_OK, ""); | |||
717 | TFE_OpSetAttrFunctionName(op, key, func_name.data(), func_name.size()); | |||
718 | } else { | |||
719 | TF_SetStatus( | |||
720 | status, TF_UNIMPLEMENTED, | |||
721 | tensorflow::strings::StrCat("Attr ", key, " has unhandled type ", type) | |||
722 | .c_str()); | |||
723 | return false; | |||
724 | } | |||
725 | return true; | |||
726 | } | |||
727 | ||||
728 | void SetOpAttrScalarDefault( | |||
729 | TFE_Context* ctx, TFE_Op* op, const tensorflow::AttrValue& default_value, | |||
730 | const char* attr_name, | |||
731 | tensorflow::gtl::FlatMap<string, int64_t>* attr_list_sizes, | |||
732 | TF_Status* status) { | |||
733 | SetOpAttrValueScalar(ctx, op, default_value, attr_name, status); | |||
734 | if (default_value.value_case() == tensorflow::AttrValue::kI) { | |||
735 | (*attr_list_sizes)[attr_name] = default_value.i(); | |||
736 | } | |||
737 | } | |||
738 | ||||
739 | // start_index is the index at which the Tuple/List attrs will start getting | |||
740 | // processed. | |||
741 | void SetOpAttrs(TFE_Context* ctx, TFE_Op* op, PyObject* attrs, int start_index, | |||
742 | TF_Status* out_status) { | |||
743 | if (attrs == Py_None(&_Py_NoneStruct)) return; | |||
744 | Py_ssize_t len = PyTuple_GET_SIZE(attrs)(((PyVarObject*)(((static_cast<void> (0)), (PyTupleObject *)(attrs))))->ob_size) - start_index; | |||
745 | if ((len & 1) != 0) { | |||
746 | TF_SetStatus(out_status, TF_INVALID_ARGUMENT, | |||
747 | "Expecting attrs tuple to have even length."); | |||
748 | return; | |||
749 | } | |||
750 | // Parse attrs | |||
751 | for (Py_ssize_t i = 0; i < len; i += 2) { | |||
752 | PyObject* py_key = PyTuple_GET_ITEM(attrs, start_index + i)(((static_cast<void> (0)), (PyTupleObject *)(attrs))-> ob_item[start_index + i]); | |||
753 | PyObject* py_value = PyTuple_GET_ITEM(attrs, start_index + i + 1)(((static_cast<void> (0)), (PyTupleObject *)(attrs))-> ob_item[start_index + i + 1]); | |||
754 | #if PY_MAJOR_VERSION3 >= 3 | |||
755 | const char* key = PyBytes_Check(py_key)((((((PyObject*)(py_key))->ob_type))->tp_flags & (( 1UL << 27))) != 0) ? PyBytes_AsString(py_key) | |||
756 | : PyUnicode_AsUTF8(py_key); | |||
757 | #else | |||
758 | const char* key = PyBytes_AsString(py_key); | |||
759 | #endif | |||
760 | unsigned char is_list = 0; | |||
761 | const TF_AttrType type = TFE_OpGetAttrType(op, key, &is_list, out_status); | |||
762 | if (!out_status->status.ok()) return; | |||
763 | if (is_list != 0) { | |||
764 | if (!SetOpAttrList(ctx, op, key, py_value, type, nullptr, out_status)) | |||
765 | return; | |||
766 | } else { | |||
767 | if (!SetOpAttrScalar(ctx, op, key, py_value, type, nullptr, out_status)) | |||
768 | return; | |||
769 | } | |||
770 | } | |||
771 | } | |||
772 | ||||
773 | // This function will set the op attrs required. If an attr has the value of | |||
774 | // None, then it will read the AttrDef to get the default value and set that | |||
775 | // instead. Any failure in this function will simply fall back to the slow | |||
776 | // path. | |||
777 | void SetOpAttrWithDefaults( | |||
778 | TFE_Context* ctx, TFE_Op* op, const tensorflow::OpDef::AttrDef& attr, | |||
779 | const char* attr_name, PyObject* attr_value, | |||
780 | tensorflow::gtl::FlatMap<string, int64_t>* attr_list_sizes, | |||
781 | TF_Status* status) { | |||
782 | unsigned char is_list = 0; | |||
783 | const TF_AttrType type = TFE_OpGetAttrType(op, attr_name, &is_list, status); | |||
784 | if (!status->status.ok()) return; | |||
785 | if (attr_value == Py_None(&_Py_NoneStruct)) { | |||
786 | if (is_list != 0) { | |||
787 | SetOpAttrListDefault(ctx, op, attr, attr_name, type, attr_list_sizes, | |||
788 | status); | |||
789 | } else { | |||
790 | SetOpAttrScalarDefault(ctx, op, attr.default_value(), attr_name, | |||
791 | attr_list_sizes, status); | |||
792 | } | |||
793 | } else { | |||
794 | if (is_list != 0) { | |||
795 | SetOpAttrList(ctx, op, attr_name, attr_value, type, attr_list_sizes, | |||
796 | status); | |||
797 | } else { | |||
798 | SetOpAttrScalar(ctx, op, attr_name, attr_value, type, attr_list_sizes, | |||
799 | status); | |||
800 | } | |||
801 | } | |||
802 | } | |||
803 | ||||
804 | PyObject* GetPythonObjectFromInt(int num) { | |||
805 | #if PY_MAJOR_VERSION3 >= 3 | |||
806 | return PyLong_FromLong(num); | |||
807 | #else | |||
808 | return PyInt_FromLong(num); | |||
809 | #endif | |||
810 | } | |||
811 | ||||
812 | // Python subclass of Exception that is created on not ok Status. | |||
813 | tensorflow::mutex exception_class_mutex(tensorflow::LINKER_INITIALIZED); | |||
814 | PyObject* exception_class TF_GUARDED_BY(exception_class_mutex)__attribute__((guarded_by(exception_class_mutex))) = nullptr; | |||
815 | ||||
816 | // Python subclass of Exception that is created to signal fallback. | |||
817 | PyObject* fallback_exception_class = nullptr; | |||
818 | ||||
819 | // Python function that returns input gradients given output gradients. | |||
820 | PyObject* gradient_function = nullptr; | |||
821 | ||||
822 | // Python function that returns output gradients given input gradients. | |||
823 | PyObject* forward_gradient_function = nullptr; | |||
824 | ||||
825 | static std::atomic<int64_t> _uid; | |||
826 | ||||
827 | } // namespace | |||
828 | ||||
829 | TF_Status* GetStatus() { | |||
830 | TF_Status* maybe_status = ReleaseThreadLocalStatus(); | |||
831 | if (maybe_status) { | |||
832 | TF_SetStatus(maybe_status, TF_OK, ""); | |||
833 | return maybe_status; | |||
834 | } else { | |||
835 | return TF_NewStatus(); | |||
836 | } | |||
837 | } | |||
838 | ||||
839 | void ReturnStatus(TF_Status* status) { | |||
840 | TF_SetStatus(status, TF_OK, ""); | |||
841 | thread_local_tf_status.reset(status); | |||
842 | } | |||
843 | ||||
844 | void TFE_Py_Execute(TFE_Context* ctx, const char* device_name, | |||
845 | const char* op_name, TFE_InputTensorHandles* inputs, | |||
846 | PyObject* attrs, TFE_OutputTensorHandles* outputs, | |||
847 | TF_Status* out_status) { | |||
848 | TFE_Py_ExecuteCancelable(ctx, device_name, op_name, inputs, attrs, | |||
849 | /*cancellation_manager=*/nullptr, outputs, | |||
850 | out_status); | |||
851 | } | |||
852 | ||||
853 | void TFE_Py_ExecuteCancelable(TFE_Context* ctx, const char* device_name, | |||
854 | const char* op_name, | |||
855 | TFE_InputTensorHandles* inputs, PyObject* attrs, | |||
856 | TFE_CancellationManager* cancellation_manager, | |||
857 | TFE_OutputTensorHandles* outputs, | |||
858 | TF_Status* out_status) { | |||
859 | tensorflow::profiler::TraceMe activity( | |||
860 | "TFE_Py_ExecuteCancelable", tensorflow::profiler::TraceMeLevel::kInfo); | |||
861 | ||||
862 | TFE_Op* op = GetOp(ctx, op_name, device_name, out_status); | |||
863 | ||||
864 | auto cleaner = tensorflow::gtl::MakeCleanup([ctx, op] { ReturnOp(ctx, op); }); | |||
865 | if (!out_status->status.ok()) return; | |||
866 | ||||
867 | tensorflow::unwrap(op)->SetStackTrace(tensorflow::GetStackTrace( | |||
868 | tensorflow::StackTrace::kStackTraceInitialSize)); | |||
869 | ||||
870 | for (int i = 0; i < inputs->size() && out_status->status.ok(); ++i) { | |||
871 | TFE_OpAddInput(op, inputs->at(i), out_status); | |||
872 | } | |||
873 | if (cancellation_manager && out_status->status.ok()) { | |||
874 | TFE_OpSetCancellationManager(op, cancellation_manager, out_status); | |||
875 | } | |||
876 | if (out_status->status.ok()) { | |||
877 | SetOpAttrs(ctx, op, attrs, 0, out_status); | |||
878 | } | |||
879 | Py_BEGIN_ALLOW_THREADS{ PyThreadState *_save; _save = PyEval_SaveThread();; | |||
880 | ||||
881 | int num_outputs = outputs->size(); | |||
882 | ||||
883 | if (out_status->status.ok()) { | |||
884 | TFE_Execute(op, outputs->data(), &num_outputs, out_status); | |||
885 | } | |||
886 | ||||
887 | if (out_status->status.ok()) { | |||
888 | outputs->resize(num_outputs); | |||
889 | } else { | |||
890 | TF_SetStatus(out_status, TF_GetCode(out_status), | |||
891 | tensorflow::strings::StrCat(TF_Message(out_status), | |||
892 | " [Op:", op_name, "]") | |||
893 | .c_str()); | |||
894 | } | |||
895 | ||||
896 | Py_END_ALLOW_THREADSPyEval_RestoreThread(_save); }; | |||
897 | } | |||
898 | ||||
899 | PyObject* TFE_Py_RegisterExceptionClass(PyObject* e) { | |||
900 | tensorflow::mutex_lock l(exception_class_mutex); | |||
901 | if (exception_class != nullptr) { | |||
902 | Py_DECREF(exception_class)_Py_DECREF(((PyObject*)(exception_class))); | |||
903 | } | |||
904 | if (PyObject_IsSubclass(e, PyExc_Exception) <= 0) { | |||
905 | exception_class = nullptr; | |||
906 | PyErr_SetString(PyExc_TypeError, | |||
907 | "TFE_Py_RegisterExceptionClass: " | |||
908 | "Registered class should be subclass of Exception."); | |||
909 | return nullptr; | |||
910 | } | |||
911 | ||||
912 | Py_INCREF(e)_Py_INCREF(((PyObject*)(e))); | |||
913 | exception_class = e; | |||
914 | Py_RETURN_NONEreturn _Py_INCREF(((PyObject*)((&_Py_NoneStruct)))), (& _Py_NoneStruct); | |||
915 | } | |||
916 | ||||
917 | PyObject* TFE_Py_RegisterFallbackExceptionClass(PyObject* e) { | |||
918 | if (fallback_exception_class != nullptr) { | |||
919 | Py_DECREF(fallback_exception_class)_Py_DECREF(((PyObject*)(fallback_exception_class))); | |||
920 | } | |||
921 | if (PyObject_IsSubclass(e, PyExc_Exception) <= 0) { | |||
922 | fallback_exception_class = nullptr; | |||
923 | PyErr_SetString(PyExc_TypeError, | |||
924 | "TFE_Py_RegisterFallbackExceptionClass: " | |||
925 | "Registered class should be subclass of Exception."); | |||
926 | return nullptr; | |||
927 | } else { | |||
928 | Py_INCREF(e)_Py_INCREF(((PyObject*)(e))); | |||
929 | fallback_exception_class = e; | |||
930 | Py_RETURN_NONEreturn _Py_INCREF(((PyObject*)((&_Py_NoneStruct)))), (& _Py_NoneStruct); | |||
931 | } | |||
932 | } | |||
933 | ||||
934 | PyObject* TFE_Py_RegisterGradientFunction(PyObject* e) { | |||
935 | if (gradient_function != nullptr) { | |||
936 | Py_DECREF(gradient_function)_Py_DECREF(((PyObject*)(gradient_function))); | |||
937 | } | |||
938 | if (!PyCallable_Check(e)) { | |||
939 | gradient_function = nullptr; | |||
940 | PyErr_SetString(PyExc_TypeError, | |||
941 | "TFE_Py_RegisterGradientFunction: " | |||
942 | "Registered object should be function."); | |||
943 | return nullptr; | |||
944 | } else { | |||
945 | Py_INCREF(e)_Py_INCREF(((PyObject*)(e))); | |||
946 | gradient_function = e; | |||
947 | Py_RETURN_NONEreturn _Py_INCREF(((PyObject*)((&_Py_NoneStruct)))), (& _Py_NoneStruct); | |||
948 | } | |||
949 | } | |||
950 | ||||
951 | PyObject* TFE_Py_RegisterJVPFunction(PyObject* e) { | |||
952 | if (forward_gradient_function != nullptr) { | |||
953 | Py_DECREF(forward_gradient_function)_Py_DECREF(((PyObject*)(forward_gradient_function))); | |||
954 | } | |||
955 | if (!PyCallable_Check(e)) { | |||
956 | forward_gradient_function = nullptr; | |||
957 | PyErr_SetString(PyExc_TypeError, | |||
958 | "TFE_Py_RegisterJVPFunction: " | |||
959 | "Registered object should be function."); | |||
960 | return nullptr; | |||
961 | } else { | |||
962 | Py_INCREF(e)_Py_INCREF(((PyObject*)(e))); | |||
963 | forward_gradient_function = e; | |||
964 | Py_RETURN_NONEreturn _Py_INCREF(((PyObject*)((&_Py_NoneStruct)))), (& _Py_NoneStruct); | |||
965 | } | |||
966 | } | |||
967 | ||||
968 | void RaiseFallbackException(const char* message) { | |||
969 | if (fallback_exception_class != nullptr) { | |||
970 | PyErr_SetString(fallback_exception_class, message); | |||
971 | return; | |||
972 | } | |||
973 | ||||
974 | PyErr_SetString( | |||
975 | PyExc_RuntimeError, | |||
976 | tensorflow::strings::StrCat( | |||
977 | "Fallback exception type not set, attempting to fallback due to ", | |||
978 | message) | |||
979 | .data()); | |||
980 | } | |||
981 | ||||
982 | // Format and return `status`' error message with the attached stack trace if | |||
983 | // available. `status` must have an error. | |||
984 | std::string FormatErrorStatusStackTrace(const tensorflow::Status& status) { | |||
985 | tensorflow::DCheckPyGilState(); | |||
986 | DCHECK(!status.ok())while (false && (!status.ok())) ::tensorflow::internal ::LogMessageFatal("tensorflow/python/eager/pywrap_tfe_src.cc" , 986); | |||
987 | ||||
988 | if (status.stack_trace().empty()) return status.error_message(); | |||
989 | ||||
990 | const std::vector<tensorflow::StackFrame>& stack_trace = status.stack_trace(); | |||
991 | ||||
992 | PyObject* linecache = PyImport_ImportModule("linecache"); | |||
993 | PyObject* getline = | |||
994 | PyObject_GetAttr(linecache, PyUnicode_FromString("getline")); | |||
995 | DCHECK(getline)while (false && (getline)) ::tensorflow::internal::LogMessageFatal ("tensorflow/python/eager/pywrap_tfe_src.cc", 995); | |||
996 | ||||
997 | std::ostringstream result; | |||
998 | result << "Exception originated from\n\n"; | |||
999 | ||||
1000 | for (const tensorflow::StackFrame& stack_frame : stack_trace) { | |||
1001 | PyObject* line_str_obj = PyObject_CallFunction( | |||
1002 | getline, const_cast<char*>("si"), stack_frame.file_name.c_str(), | |||
1003 | stack_frame.line_number); | |||
1004 | tensorflow::StringPiece line_str = TFE_GetPythonString(line_str_obj); | |||
1005 | tensorflow::str_util::RemoveWhitespaceContext(&line_str); | |||
1006 | result << " File \"" << stack_frame.file_name << "\", line " | |||
1007 | << stack_frame.line_number << ", in " << stack_frame.function_name | |||
1008 | << '\n'; | |||
1009 | ||||
1010 | if (!line_str.empty()) result << " " << line_str << '\n'; | |||
1011 | Py_XDECREF(line_str_obj)_Py_XDECREF(((PyObject*)(line_str_obj))); | |||
1012 | } | |||
1013 | ||||
1014 | Py_DecRef(getline); | |||
1015 | Py_DecRef(linecache); | |||
1016 | ||||
1017 | result << '\n' << status.error_message(); | |||
1018 | return result.str(); | |||
1019 | } | |||
1020 | ||||
1021 | int MaybeRaiseExceptionFromTFStatus(TF_Status* status, PyObject* exception) { | |||
1022 | if (status->status.ok()) return 0; | |||
1023 | const char* msg = TF_Message(status); | |||
1024 | if (exception == nullptr) { | |||
1025 | tensorflow::mutex_lock l(exception_class_mutex); | |||
1026 | if (exception_class != nullptr) { | |||
1027 | tensorflow::Safe_PyObjectPtr payloads(PyDict_New()); | |||
1028 | for (const auto& payload : | |||
1029 | tensorflow::errors::GetPayloads(status->status)) { | |||
1030 | PyDict_SetItem(payloads.get(), | |||
1031 | PyBytes_FromString(payload.first.c_str()), | |||
1032 | PyBytes_FromString(payload.second.c_str())); | |||
1033 | } | |||
1034 | tensorflow::Safe_PyObjectPtr val(Py_BuildValue( | |||
1035 | "siO", FormatErrorStatusStackTrace(status->status).c_str(), | |||
1036 | TF_GetCode(status), payloads.get())); | |||
1037 | if (PyErr_Occurred()) { | |||
1038 | // NOTE: This hides the actual error (i.e. the reason `status` was not | |||
1039 | // TF_OK), but there is nothing we can do at this point since we can't | |||
1040 | // generate a reasonable error from the status. | |||
1041 | // Consider adding a message explaining this. | |||
1042 | return -1; | |||
1043 | } | |||
1044 | PyErr_SetObject(exception_class, val.get()); | |||
1045 | return -1; | |||
1046 | } else { | |||
1047 | exception = PyExc_RuntimeError; | |||
1048 | } | |||
1049 | } | |||
1050 | // May be update already set exception. | |||
1051 | PyErr_SetString(exception, msg); | |||
1052 | return -1; | |||
1053 | } | |||
1054 | ||||
1055 | int MaybeRaiseExceptionFromStatus(const tensorflow::Status& status, | |||
1056 | PyObject* exception) { | |||
1057 | if (status.ok()) return 0; | |||
1058 | const char* msg = status.error_message().c_str(); | |||
1059 | if (exception == nullptr) { | |||
1060 | tensorflow::mutex_lock l(exception_class_mutex); | |||
1061 | if (exception_class != nullptr) { | |||
1062 | tensorflow::Safe_PyObjectPtr payloads(PyDict_New()); | |||
1063 | for (const auto& element : tensorflow::errors::GetPayloads(status)) { | |||
1064 | PyDict_SetItem(payloads.get(), | |||
1065 | PyBytes_FromString(element.first.c_str()), | |||
1066 | PyBytes_FromString(element.second.c_str())); | |||
1067 | } | |||
1068 | tensorflow::Safe_PyObjectPtr val( | |||
1069 | Py_BuildValue("siO", FormatErrorStatusStackTrace(status).c_str(), | |||
1070 | status.code(), payloads.get())); | |||
1071 | PyErr_SetObject(exception_class, val.get()); | |||
1072 | return -1; | |||
1073 | } else { | |||
1074 | exception = PyExc_RuntimeError; | |||
1075 | } | |||
1076 | } | |||
1077 | // May be update already set exception. | |||
1078 | PyErr_SetString(exception, msg); | |||
1079 | return -1; | |||
1080 | } | |||
1081 | ||||
1082 | const char* TFE_GetPythonString(PyObject* o) { | |||
1083 | #if PY_MAJOR_VERSION3 >= 3 | |||
1084 | if (PyBytes_Check(o)((((((PyObject*)(o))->ob_type))->tp_flags & ((1UL << 27))) != 0)) { | |||
1085 | return PyBytes_AsString(o); | |||
1086 | } else { | |||
1087 | return PyUnicode_AsUTF8(o); | |||
1088 | } | |||
1089 | #else | |||
1090 | return PyBytes_AsString(o); | |||
1091 | #endif | |||
1092 | } | |||
1093 | ||||
1094 | int64_t get_uid() { return _uid++; } | |||
1095 | ||||
1096 | PyObject* TFE_Py_UID() { return PyLong_FromLongLong(get_uid()); } | |||
1097 | ||||
1098 | void TFE_DeleteContextCapsule(PyObject* context) { | |||
1099 | TFE_Context* ctx = | |||
1100 | reinterpret_cast<TFE_Context*>(PyCapsule_GetPointer(context, nullptr)); | |||
1101 | auto op = ReleaseThreadLocalOp(ctx); | |||
1102 | op.reset(); | |||
1103 | TFE_DeleteContext(ctx); | |||
1104 | } | |||
1105 | ||||
1106 | static int64_t MakeInt(PyObject* integer) { | |||
1107 | #if PY_MAJOR_VERSION3 >= 3 | |||
1108 | return PyLong_AsLong(integer); | |||
1109 | #else | |||
1110 | return PyInt_AsLong(integer); | |||
1111 | #endif | |||
1112 | } | |||
1113 | ||||
1114 | static int64_t FastTensorId(PyObject* tensor) { | |||
1115 | if (EagerTensor_CheckExact(tensor)) { | |||
1116 | return PyEagerTensor_ID(tensor); | |||
1117 | } | |||
1118 | PyObject* id_field = PyObject_GetAttrString(tensor, "_id"); | |||
1119 | if (id_field == nullptr) { | |||
1120 | return -1; | |||
1121 | } | |||
1122 | int64_t id = MakeInt(id_field); | |||
1123 | Py_DECREF(id_field)_Py_DECREF(((PyObject*)(id_field))); | |||
1124 | return id; | |||
1125 | } | |||
1126 | ||||
1127 | namespace tensorflow { | |||
1128 | DataType PyTensor_DataType(PyObject* tensor) { | |||
1129 | if (EagerTensor_CheckExact(tensor)) { | |||
1130 | return PyEagerTensor_Dtype(tensor); | |||
1131 | } else { | |||
1132 | #if PY_MAJOR_VERSION3 < 3 | |||
1133 | // Python 2.x: | |||
1134 | static PyObject* dtype_attr = PyString_InternFromString("dtype"); | |||
1135 | static PyObject* type_enum_attr = PyString_InternFromString("_type_enum"); | |||
1136 | #else | |||
1137 | // Python 3.x: | |||
1138 | static PyObject* dtype_attr = PyUnicode_InternFromString("dtype"); | |||
1139 | static PyObject* type_enum_attr = PyUnicode_InternFromString("_type_enum"); | |||
1140 | #endif | |||
1141 | Safe_PyObjectPtr dtype_field(PyObject_GetAttr(tensor, dtype_attr)); | |||
1142 | if (!dtype_field) { | |||
1143 | return DT_INVALID; | |||
1144 | } | |||
1145 | ||||
1146 | Safe_PyObjectPtr enum_field( | |||
1147 | PyObject_GetAttr(dtype_field.get(), type_enum_attr)); | |||
1148 | if (!enum_field) { | |||
1149 | return DT_INVALID; | |||
1150 | } | |||
1151 | ||||
1152 | return static_cast<DataType>(MakeInt(enum_field.get())); | |||
1153 | } | |||
1154 | } | |||
1155 | } // namespace tensorflow | |||
1156 | ||||
1157 | class PyTapeTensor { | |||
1158 | public: | |||
1159 | PyTapeTensor(int64_t id, tensorflow::DataType dtype, | |||
1160 | const tensorflow::TensorShape& shape) | |||
1161 | : id_(id), dtype_(dtype), shape_(shape) {} | |||
1162 | PyTapeTensor(int64_t id, tensorflow::DataType dtype, PyObject* shape) | |||
1163 | : id_(id), dtype_(dtype), shape_(shape) { | |||
1164 | Py_INCREF(absl::get<1>(shape_))_Py_INCREF(((PyObject*)(absl::get<1>(shape_)))); | |||
1165 | } | |||
1166 | PyTapeTensor(const PyTapeTensor& other) { | |||
1167 | id_ = other.id_; | |||
1168 | dtype_ = other.dtype_; | |||
1169 | shape_ = other.shape_; | |||
1170 | if (shape_.index() == 1) { | |||
1171 | Py_INCREF(absl::get<1>(shape_))_Py_INCREF(((PyObject*)(absl::get<1>(shape_)))); | |||
1172 | } | |||
1173 | } | |||
1174 | ||||
1175 | ~PyTapeTensor() { | |||
1176 | if (shape_.index() == 1) { | |||
1177 | Py_DECREF(absl::get<1>(shape_))_Py_DECREF(((PyObject*)(absl::get<1>(shape_)))); | |||
1178 | } | |||
1179 | } | |||
1180 | PyObject* GetShape() const; | |||
1181 | PyObject* GetPyDType() const { return PyLong_FromLong(dtype_); } | |||
1182 | int64_t GetID() const { return id_; } | |||
1183 | tensorflow::DataType GetDType() const { return dtype_; } | |||
1184 | ||||
1185 | PyObject* OnesLike() const; | |||
1186 | PyObject* ZerosLike() const; | |||
1187 | ||||
1188 | private: | |||
1189 | int64_t id_; | |||
1190 | tensorflow::DataType dtype_; | |||
1191 | ||||
1192 | // Note that if shape_.index() == 1, meaning shape_ contains a PyObject, that | |||
1193 | // PyObject is the tensor itself. This is used to support tf.shape(tensor) for | |||
1194 | // partially-defined shapes and tf.zeros_like(tensor) for variant-dtype | |||
1195 | // tensors. | |||
1196 | absl::variant<tensorflow::TensorShape, PyObject*> shape_; | |||
1197 | }; | |||
1198 | ||||
// Forward declaration. Because the function is `static`, its definition must
// be elsewhere in this translation unit. Judging by the signature it builds
// the PyTapeTensor for `tensor`. NOTE(review): confirm against the definition.
static PyTapeTensor TapeTensorFromTensor(PyObject* tensor);
1200 | ||||
1201 | class PyVSpace : public tensorflow::eager::VSpace<PyObject, PyBackwardFunction, | |||
1202 | PyTapeTensor> { | |||
1203 | public: | |||
1204 | explicit PyVSpace(PyObject* py_vspace) : py_vspace_(py_vspace) { | |||
1205 | Py_INCREF(py_vspace_)_Py_INCREF(((PyObject*)(py_vspace_))); | |||
1206 | } | |||
1207 | ||||
1208 | tensorflow::Status Initialize() { | |||
1209 | num_elements_ = PyObject_GetAttrString(py_vspace_, "num_elements_fn"); | |||
1210 | if (num_elements_ == nullptr) { | |||
1211 | return tensorflow::errors::InvalidArgument("invalid vspace"); | |||
1212 | } | |||
1213 | aggregate_fn_ = PyObject_GetAttrString(py_vspace_, "aggregate_fn"); | |||
1214 | if (aggregate_fn_ == nullptr) { | |||
1215 | return tensorflow::errors::InvalidArgument("invalid vspace"); | |||
1216 | } | |||
1217 | zeros_fn_ = PyObject_GetAttrString(py_vspace_, "zeros_fn"); | |||
1218 | if (zeros_fn_ == nullptr) { | |||
1219 | return tensorflow::errors::InvalidArgument("invalid vspace"); | |||
1220 | } | |||
1221 | zeros_like_fn_ = PyObject_GetAttrString(py_vspace_, "zeros_like_fn"); | |||
1222 | if (zeros_like_fn_ == nullptr) { | |||
1223 | return tensorflow::errors::InvalidArgument("invalid vspace"); | |||
1224 | } | |||
1225 | ones_fn_ = PyObject_GetAttrString(py_vspace_, "ones_fn"); | |||
1226 | if (ones_fn_ == nullptr) { | |||
1227 | return tensorflow::errors::InvalidArgument("invalid vspace"); | |||
1228 | } | |||
1229 | ones_like_fn_ = PyObject_GetAttrString(py_vspace_, "ones_like_fn"); | |||
1230 | if (ones_like_fn_ == nullptr) { | |||
1231 | return tensorflow::errors::InvalidArgument("invalid vspace"); | |||
1232 | } | |||
1233 | graph_shape_fn_ = PyObject_GetAttrString(py_vspace_, "graph_shape_fn"); | |||
1234 | if (graph_shape_fn_ == nullptr) { | |||
1235 | return tensorflow::errors::InvalidArgument("invalid vspace"); | |||
1236 | } | |||
1237 | return tensorflow::Status::OK(); | |||
1238 | } | |||
1239 | ||||
1240 | ~PyVSpace() override { | |||
1241 | Py_XDECREF(num_elements_)_Py_XDECREF(((PyObject*)(num_elements_))); | |||
1242 | Py_XDECREF(aggregate_fn_)_Py_XDECREF(((PyObject*)(aggregate_fn_))); | |||
1243 | Py_XDECREF(zeros_fn_)_Py_XDECREF(((PyObject*)(zeros_fn_))); | |||
1244 | Py_XDECREF(zeros_like_fn_)_Py_XDECREF(((PyObject*)(zeros_like_fn_))); | |||
1245 | Py_XDECREF(ones_fn_)_Py_XDECREF(((PyObject*)(ones_fn_))); | |||
1246 | Py_XDECREF(ones_like_fn_)_Py_XDECREF(((PyObject*)(ones_like_fn_))); | |||
1247 | Py_XDECREF(graph_shape_fn_)_Py_XDECREF(((PyObject*)(graph_shape_fn_))); | |||
1248 | ||||
1249 | Py_DECREF(py_vspace_)_Py_DECREF(((PyObject*)(py_vspace_))); | |||
1250 | } | |||
1251 | ||||
1252 | int64_t NumElements(PyObject* tensor) const final { | |||
1253 | if (EagerTensor_CheckExact(tensor)) { | |||
1254 | return PyEagerTensor_NumElements(tensor); | |||
1255 | } | |||
1256 | PyObject* arglist = | |||
1257 | Py_BuildValue("(O)", reinterpret_cast<PyObject*>(tensor)); | |||
1258 | PyObject* result = PyEval_CallObject(num_elements_, arglist)PyEval_CallObjectWithKeywords(num_elements_, arglist, (PyObject *)__null); | |||
1259 | Py_DECREF(arglist)_Py_DECREF(((PyObject*)(arglist))); | |||
1260 | if (result == nullptr) { | |||
1261 | // The caller detects whether a python exception has been raised. | |||
1262 | return -1; | |||
1263 | } | |||
1264 | int64_t r = MakeInt(result); | |||
1265 | Py_DECREF(result)_Py_DECREF(((PyObject*)(result))); | |||
1266 | return r; | |||
1267 | } | |||
1268 | ||||
1269 | PyObject* AggregateGradients( | |||
1270 | tensorflow::gtl::ArraySlice<PyObject*> gradient_tensors) const final { | |||
1271 | PyObject* list = PyList_New(gradient_tensors.size()); | |||
1272 | for (int i = 0; i < gradient_tensors.size(); ++i) { | |||
1273 | // Note: stealing a reference to the gradient tensors. | |||
1274 | CHECK(gradient_tensors[i] != nullptr)if ((__builtin_expect(!(gradient_tensors[i] != nullptr), 0))) ::tensorflow::internal::LogMessageFatal("tensorflow/python/eager/pywrap_tfe_src.cc" , 1274) << "Check failed: " "gradient_tensors[i] != nullptr" " "; | |||
1275 | CHECK(gradient_tensors[i] != Py_None)if ((__builtin_expect(!(gradient_tensors[i] != (&_Py_NoneStruct )), 0))) ::tensorflow::internal::LogMessageFatal("tensorflow/python/eager/pywrap_tfe_src.cc" , 1275) << "Check failed: " "gradient_tensors[i] != Py_None" " "; | |||
1276 | PyList_SET_ITEM(list, i,PyList_SetItem(list, i, reinterpret_cast<PyObject*>(gradient_tensors [i])) | |||
1277 | reinterpret_cast<PyObject*>(gradient_tensors[i]))PyList_SetItem(list, i, reinterpret_cast<PyObject*>(gradient_tensors [i])); | |||
1278 | } | |||
1279 | PyObject* arglist = Py_BuildValue("(O)", list); | |||
1280 | CHECK(arglist != nullptr)if ((__builtin_expect(!(arglist != nullptr), 0))) ::tensorflow ::internal::LogMessageFatal("tensorflow/python/eager/pywrap_tfe_src.cc" , 1280) << "Check failed: " "arglist != nullptr" " "; | |||
1281 | PyObject* result = PyEval_CallObject(aggregate_fn_, arglist)PyEval_CallObjectWithKeywords(aggregate_fn_, arglist, (PyObject *)__null); | |||
1282 | Py_DECREF(arglist)_Py_DECREF(((PyObject*)(arglist))); | |||
1283 | Py_DECREF(list)_Py_DECREF(((PyObject*)(list))); | |||
1284 | return result; | |||
1285 | } | |||
1286 | ||||
1287 | int64_t TensorId(PyObject* tensor) const final { | |||
1288 | return FastTensorId(tensor); | |||
1289 | } | |||
1290 | ||||
1291 | void MarkAsResult(PyObject* gradient) const final { Py_INCREF(gradient)_Py_INCREF(((PyObject*)(gradient))); } | |||
1292 | ||||
1293 | PyObject* Ones(PyObject* shape, PyObject* dtype) const { | |||
1294 | if (PyErr_Occurred()) { | |||
1295 | return nullptr; | |||
1296 | } | |||
1297 | PyObject* arg_list = Py_BuildValue("OO", shape, dtype); | |||
1298 | PyObject* result = PyEval_CallObject(ones_fn_, arg_list)PyEval_CallObjectWithKeywords(ones_fn_, arg_list, (PyObject * )__null); | |||
1299 | Py_DECREF(arg_list)_Py_DECREF(((PyObject*)(arg_list))); | |||
1300 | return result; | |||
1301 | } | |||
1302 | ||||
1303 | PyObject* OnesLike(PyObject* tensor) const { | |||
1304 | if (PyErr_Occurred()) { | |||
1305 | return nullptr; | |||
1306 | } | |||
1307 | return PyObject_CallFunctionObjArgs(ones_like_fn_, tensor, NULL__null); | |||
1308 | } | |||
1309 | ||||
1310 | // Builds a tensor filled with ones with the same shape and dtype as `t`. | |||
1311 | Status BuildOnesLike(const PyTapeTensor& t, | |||
1312 | PyObject** result) const override { | |||
1313 | *result = t.OnesLike(); | |||
1314 | return Status::OK(); | |||
1315 | } | |||
1316 | ||||
1317 | PyObject* Zeros(PyObject* shape, PyObject* dtype) const { | |||
1318 | if (PyErr_Occurred()) { | |||
1319 | return nullptr; | |||
1320 | } | |||
1321 | PyObject* arg_list = Py_BuildValue("OO", shape, dtype); | |||
1322 | PyObject* result = PyEval_CallObject(zeros_fn_, arg_list)PyEval_CallObjectWithKeywords(zeros_fn_, arg_list, (PyObject * )__null); | |||
1323 | Py_DECREF(arg_list)_Py_DECREF(((PyObject*)(arg_list))); | |||
1324 | return result; | |||
1325 | } | |||
1326 | ||||
1327 | PyObject* ZerosLike(PyObject* tensor) const { | |||
1328 | if (PyErr_Occurred()) { | |||
1329 | return nullptr; | |||
1330 | } | |||
1331 | return PyObject_CallFunctionObjArgs(zeros_like_fn_, tensor, NULL__null); | |||
1332 | } | |||
1333 | ||||
1334 | PyObject* GraphShape(PyObject* tensor) const { | |||
1335 | PyObject* arg_list = Py_BuildValue("(O)", tensor); | |||
1336 | PyObject* result = PyEval_CallObject(graph_shape_fn_, arg_list)PyEval_CallObjectWithKeywords(graph_shape_fn_, arg_list, (PyObject *)__null); | |||
1337 | Py_DECREF(arg_list)_Py_DECREF(((PyObject*)(arg_list))); | |||
1338 | return result; | |||
1339 | } | |||
1340 | ||||
1341 | tensorflow::Status CallBackwardFunction( | |||
1342 | const string& op_type, PyBackwardFunction* backward_function, | |||
1343 | const std::vector<int64_t>& unneeded_gradients, | |||
1344 | tensorflow::gtl::ArraySlice<PyObject*> output_gradients, | |||
1345 | absl::Span<PyObject*> result) const final { | |||
1346 | PyObject* grads = PyTuple_New(output_gradients.size()); | |||
1347 | for (int i = 0; i < output_gradients.size(); ++i) { | |||
1348 | if (output_gradients[i] == nullptr) { | |||
1349 | Py_INCREF(Py_None)_Py_INCREF(((PyObject*)((&_Py_NoneStruct)))); | |||
1350 | PyTuple_SET_ITEM(grads, i, Py_None)PyTuple_SetItem(grads, i, (&_Py_NoneStruct)); | |||
1351 | } else { | |||
1352 | PyTuple_SET_ITEM(grads, i,PyTuple_SetItem(grads, i, reinterpret_cast<PyObject*>(output_gradients [i])) | |||
1353 | reinterpret_cast<PyObject*>(output_gradients[i]))PyTuple_SetItem(grads, i, reinterpret_cast<PyObject*>(output_gradients [i])); | |||
1354 | } | |||
1355 | } | |||
1356 | PyObject* py_result = (*backward_function)(grads, unneeded_gradients); | |||
1357 | Py_DECREF(grads)_Py_DECREF(((PyObject*)(grads))); | |||
1358 | if (py_result == nullptr) { | |||
1359 | return tensorflow::errors::Internal("gradient function threw exceptions"); | |||
1360 | } | |||
1361 | PyObject* seq = | |||
1362 | PySequence_Fast(py_result, "expected a sequence of gradients"); | |||
1363 | if (seq == nullptr) { | |||
1364 | return tensorflow::errors::InvalidArgument( | |||
1365 | "gradient function did not return a list"); | |||
1366 | } | |||
1367 | int len = PySequence_Fast_GET_SIZE(seq)(((((((PyObject*)(seq))->ob_type))->tp_flags & ((1UL << 25))) != 0) ? ((static_cast<void> (0)), (((PyVarObject *)(seq))->ob_size)) : (((PyVarObject*)(((static_cast<void > (0)), (PyTupleObject *)(seq))))->ob_size)); | |||
1368 | if (len != result.size()) { | |||
1369 | return tensorflow::errors::Internal( | |||
1370 | "Recorded operation '", op_type, | |||
1371 | "' returned too few gradients. Expected ", result.size(), | |||
1372 | " but received ", len); | |||
1373 | } | |||
1374 | PyObject** seq_array = PySequence_Fast_ITEMS(seq)(((((((PyObject*)(seq))->ob_type))->tp_flags & ((1UL << 25))) != 0) ? ((PyListObject *)(seq))->ob_item : ((PyTupleObject *)(seq))->ob_item); | |||
1375 | VLOG(1)(__builtin_expect(!!(!(([](int level, const char* fname) { static const bool vmodule_activated = ::tensorflow::internal::LogMessage ::VmoduleActivated(fname, level); return vmodule_activated; } )(1, "tensorflow/python/eager/pywrap_tfe_src.cc"))), 1)) ? (void )0 : ::tensorflow::internal::Voidifier() & ::tensorflow:: internal::LogMessage("tensorflow/python/eager/pywrap_tfe_src.cc" , 1375, tensorflow::INFO) << "Gradient length is " << len; | |||
1376 | for (int i = 0; i < len; ++i) { | |||
1377 | PyObject* item = seq_array[i]; | |||
1378 | if (item == Py_None(&_Py_NoneStruct)) { | |||
1379 | result[i] = nullptr; | |||
1380 | } else { | |||
1381 | Py_INCREF(item)_Py_INCREF(((PyObject*)(item))); | |||
1382 | result[i] = item; | |||
1383 | } | |||
1384 | } | |||
1385 | Py_DECREF(seq)_Py_DECREF(((PyObject*)(seq))); | |||
1386 | Py_DECREF(py_result)_Py_DECREF(((PyObject*)(py_result))); | |||
1387 | return tensorflow::Status::OK(); | |||
1388 | } | |||
1389 | ||||
1390 | void DeleteGradient(PyObject* tensor) const final { Py_XDECREF(tensor)_Py_XDECREF(((PyObject*)(tensor))); } | |||
1391 | ||||
1392 | PyTapeTensor TapeTensorFromGradient(PyObject* tensor) const final { | |||
1393 | return TapeTensorFromTensor(tensor); | |||
1394 | } | |||
1395 | ||||
1396 | private: | |||
1397 | PyObject* py_vspace_; | |||
1398 | ||||
1399 | PyObject* num_elements_; | |||
1400 | PyObject* aggregate_fn_; | |||
1401 | PyObject* zeros_fn_; | |||
1402 | PyObject* zeros_like_fn_; | |||
1403 | PyObject* ones_fn_; | |||
1404 | PyObject* ones_like_fn_; | |||
1405 | PyObject* graph_shape_fn_; | |||
1406 | }; | |||
1407 | PyVSpace* py_vspace = nullptr; | |||
1408 | ||||
1409 | bool HasAccumulator(); | |||
1410 | ||||
1411 | PyObject* TFE_Py_RegisterVSpace(PyObject* e) { | |||
1412 | if (py_vspace != nullptr) { | |||
1413 | if (HasAccumulator()) { | |||
1414 | // Accumulators reference py_vspace, so we can't swap it out while one is | |||
1415 | // active. This is unlikely to ever happen. | |||
1416 | MaybeRaiseExceptionFromStatus( | |||
1417 | tensorflow::errors::Internal( | |||
1418 | "Can't change the vspace implementation while a " | |||
1419 | "forward accumulator is active."), | |||
1420 | nullptr); | |||
1421 | } | |||
1422 | delete py_vspace; | |||
1423 | } | |||
1424 | ||||
1425 | py_vspace = new PyVSpace(e); | |||
1426 | auto status = py_vspace->Initialize(); | |||
1427 | if (MaybeRaiseExceptionFromStatus(status, nullptr)) { | |||
1428 | delete py_vspace; | |||
1429 | return nullptr; | |||
1430 | } | |||
1431 | ||||
1432 | Py_RETURN_NONEreturn _Py_INCREF(((PyObject*)((&_Py_NoneStruct)))), (& _Py_NoneStruct); | |||
1433 | } | |||
1434 | ||||
1435 | PyObject* PyTapeTensor::GetShape() const { | |||
1436 | if (shape_.index() == 0) { | |||
1437 | auto& shape = absl::get<0>(shape_); | |||
1438 | PyObject* py_shape = PyTuple_New(shape.dims()); | |||
1439 | for (int i = 0; i < shape.dims(); ++i) { | |||
1440 | PyTuple_SET_ITEM(py_shape, i, PyLong_FromLong(shape.dim_size(i)))PyTuple_SetItem(py_shape, i, PyLong_FromLong(shape.dim_size(i ))); | |||
1441 | } | |||
1442 | ||||
1443 | return py_shape; | |||
1444 | } | |||
1445 | ||||
1446 | return py_vspace->GraphShape(absl::get<1>(shape_)); | |||
1447 | } | |||
1448 | ||||
1449 | PyObject* PyTapeTensor::OnesLike() const { | |||
1450 | if (shape_.index() == 1) { | |||
1451 | PyObject* tensor = absl::get<1>(shape_); | |||
1452 | return py_vspace->OnesLike(tensor); | |||
1453 | } | |||
1454 | PyObject* py_shape = GetShape(); | |||
1455 | PyObject* dtype_field = GetPyDType(); | |||
1456 | PyObject* result = py_vspace->Ones(py_shape, dtype_field); | |||
1457 | Py_DECREF(dtype_field)_Py_DECREF(((PyObject*)(dtype_field))); | |||
1458 | Py_DECREF(py_shape)_Py_DECREF(((PyObject*)(py_shape))); | |||
1459 | return result; | |||
1460 | } | |||
1461 | ||||
1462 | PyObject* PyTapeTensor::ZerosLike() const { | |||
1463 | if (GetDType() == tensorflow::DT_RESOURCE) { | |||
1464 | // Gradient functions for ops which return resource tensors accept | |||
1465 | // None. This is the behavior of py_vspace->Zeros, but checking here avoids | |||
1466 | // issues with ZerosLike. | |||
1467 | Py_RETURN_NONEreturn _Py_INCREF(((PyObject*)((&_Py_NoneStruct)))), (& _Py_NoneStruct); | |||
1468 | } | |||
1469 | if (shape_.index() == 1) { | |||
1470 | PyObject* tensor = absl::get<1>(shape_); | |||
1471 | return py_vspace->ZerosLike(tensor); | |||
1472 | } | |||
1473 | PyObject* py_shape = GetShape(); | |||
1474 | PyObject* dtype_field = GetPyDType(); | |||
1475 | PyObject* result = py_vspace->Zeros(py_shape, dtype_field); | |||
1476 | Py_DECREF(dtype_field)_Py_DECREF(((PyObject*)(dtype_field))); | |||
1477 | Py_DECREF(py_shape)_Py_DECREF(((PyObject*)(py_shape))); | |||
1478 | return result; | |||
1479 | } | |||
1480 | ||||
1481 | // Keeps track of all variables that have been accessed during execution. | |||
1482 | class VariableWatcher { | |||
1483 | public: | |||
1484 | VariableWatcher() {} | |||
1485 | ||||
1486 | ~VariableWatcher() { | |||
1487 | for (const IdAndVariable& v : watched_variables_) { | |||
1488 | Py_DECREF(v.variable)_Py_DECREF(((PyObject*)(v.variable))); | |||
1489 | } | |||
1490 | } | |||
1491 | ||||
1492 | int64_t WatchVariable(PyObject* v) { | |||
1493 | tensorflow::Safe_PyObjectPtr handle(PyObject_GetAttrString(v, "handle")); | |||
1494 | if (handle == nullptr) { | |||
1495 | return -1; | |||
1496 | } | |||
1497 | int64_t id = FastTensorId(handle.get()); | |||
1498 | ||||
1499 | tensorflow::mutex_lock l(watched_variables_mu_); | |||
1500 | auto insert_result = watched_variables_.emplace(id, v); | |||
1501 | ||||
1502 | if (insert_result.second) { | |||
1503 | // Only increment the reference count if we aren't already watching this | |||
1504 | // variable. | |||
1505 | Py_INCREF(v)_Py_INCREF(((PyObject*)(v))); | |||
1506 | } | |||
1507 | ||||
1508 | return id; | |||
1509 | } | |||
1510 | ||||
1511 | PyObject* GetVariablesAsPyTuple() { | |||
1512 | tensorflow::mutex_lock l(watched_variables_mu_); | |||
1513 | PyObject* result = PyTuple_New(watched_variables_.size()); | |||
1514 | Py_ssize_t pos = 0; | |||
1515 | for (const IdAndVariable& id_and_variable : watched_variables_) { | |||
1516 | PyTuple_SET_ITEM(result, pos++, id_and_variable.variable)PyTuple_SetItem(result, pos++, id_and_variable.variable); | |||
1517 | Py_INCREF(id_and_variable.variable)_Py_INCREF(((PyObject*)(id_and_variable.variable))); | |||
1518 | } | |||
1519 | return result; | |||
1520 | } | |||
1521 | ||||
1522 | private: | |||
1523 | // We store an IdAndVariable in the map since the map needs to be locked | |||
1524 | // during insert, but should not call back into python during insert to avoid | |||
1525 | // deadlocking with the GIL. | |||
1526 | struct IdAndVariable { | |||
1527 | int64_t id; | |||
1528 | PyObject* variable; | |||
1529 | ||||
1530 | IdAndVariable(int64_t id, PyObject* variable) | |||
1531 | : id(id), variable(variable) {} | |||
1532 | }; | |||
1533 | struct CompareById { | |||
1534 | bool operator()(const IdAndVariable& lhs, const IdAndVariable& rhs) const { | |||
1535 | return lhs.id < rhs.id; | |||
1536 | } | |||
1537 | }; | |||
1538 | ||||
1539 | tensorflow::mutex watched_variables_mu_; | |||
1540 | std::set<IdAndVariable, CompareById> watched_variables_ | |||
1541 | TF_GUARDED_BY(watched_variables_mu_)__attribute__((guarded_by(watched_variables_mu_))); | |||
1542 | }; | |||
1543 | ||||
1544 | class GradientTape | |||
1545 | : public tensorflow::eager::GradientTape<PyObject, PyBackwardFunction, | |||
1546 | PyTapeTensor> { | |||
1547 | public: | |||
1548 | explicit GradientTape(bool persistent, bool watch_accessed_variables) | |||
1549 | : tensorflow::eager::GradientTape<PyObject, PyBackwardFunction, | |||
1550 | PyTapeTensor>(persistent), | |||
1551 | watch_accessed_variables_(watch_accessed_variables) {} | |||
1552 | ||||
1553 | virtual ~GradientTape() {} | |||
1554 | ||||
1555 | void VariableAccessed(PyObject* v) { | |||
1556 | if (watch_accessed_variables_) { | |||
1557 | WatchVariable(v); | |||
1558 | } | |||
1559 | } | |||
1560 | ||||
1561 | void WatchVariable(PyObject* v) { | |||
1562 | int64_t id = variable_watcher_.WatchVariable(v); | |||
1563 | ||||
1564 | if (!PyErr_Occurred()) { | |||
1565 | this->Watch(id); | |||
1566 | } | |||
1567 | } | |||
1568 | ||||
1569 | PyObject* GetVariablesAsPyTuple() { | |||
1570 | return variable_watcher_.GetVariablesAsPyTuple(); | |||
1571 | } | |||
1572 | ||||
1573 | private: | |||
1574 | bool watch_accessed_variables_; | |||
1575 | VariableWatcher variable_watcher_; | |||
1576 | }; | |||
1577 | ||||
1578 | typedef tensorflow::eager::ForwardAccumulator<PyObject, PyBackwardFunction, | |||
1579 | PyTapeTensor> | |||
1580 | ForwardAccumulator; | |||
1581 | ||||
1582 | // Incremented when a GradientTape or accumulator is newly added to a set, and | |||
1583 | // used to enforce an ordering between them. | |||
1584 | std::atomic_uint_fast64_t tape_nesting_id_counter(0); | |||
1585 | ||||
1586 | typedef struct { | |||
1587 | PyObject_HEADPyObject ob_base; | |||
1588 | /* Type-specific fields go here. */ | |||
1589 | GradientTape* tape; | |||
1590 | // A nesting order between GradientTapes and ForwardAccumulators, used to | |||
1591 | // ensure that GradientTapes do not watch the products of outer | |||
1592 | // ForwardAccumulators. | |||
1593 | int64_t nesting_id; | |||
1594 | } TFE_Py_Tape; | |||
1595 | ||||
1596 | static void TFE_Py_Tape_Delete(PyObject* tape) { | |||
1597 | delete reinterpret_cast<TFE_Py_Tape*>(tape)->tape; | |||
1598 | Py_TYPE(tape)(((PyObject*)(tape))->ob_type)->tp_free(tape); | |||
1599 | } | |||
1600 | ||||
1601 | static PyTypeObject TFE_Py_Tape_Type = { | |||
1602 | PyVarObject_HEAD_INIT(nullptr, 0){ { 1, nullptr }, 0 }, "tfe.Tape", /* tp_name */ | |||
1603 | sizeof(TFE_Py_Tape), /* tp_basicsize */ | |||
1604 | 0, /* tp_itemsize */ | |||
1605 | &TFE_Py_Tape_Delete, /* tp_dealloc */ | |||
1606 | #if PY_VERSION_HEX((3 << 24) | (8 << 16) | (5 << 8) | (0xF << 4) | (0 << 0)) < 0x03080000 | |||
1607 | nullptr, /* tp_print */ | |||
1608 | #else | |||
1609 | 0, /* tp_vectorcall_offset */ | |||
1610 | #endif | |||
1611 | nullptr, /* tp_getattr */ | |||
1612 | nullptr, /* tp_setattr */ | |||
1613 | nullptr, /* tp_reserved */ | |||
1614 | nullptr, /* tp_repr */ | |||
1615 | nullptr, /* tp_as_number */ | |||
1616 | nullptr, /* tp_as_sequence */ | |||
1617 | nullptr, /* tp_as_mapping */ | |||
1618 | nullptr, /* tp_hash */ | |||
1619 | nullptr, /* tp_call */ | |||
1620 | nullptr, /* tp_str */ | |||
1621 | nullptr, /* tp_getattro */ | |||
1622 | nullptr, /* tp_setattro */ | |||
1623 | nullptr, /* tp_as_buffer */ | |||
1624 | Py_TPFLAGS_DEFAULT( 0 | (1UL << 18) | 0), /* tp_flags */ | |||
1625 | "TFE_Py_Tape objects", /* tp_doc */ | |||
1626 | }; | |||
1627 | ||||
1628 | typedef struct { | |||
1629 | PyObject_HEADPyObject ob_base; | |||
1630 | /* Type-specific fields go here. */ | |||
1631 | ForwardAccumulator* accumulator; | |||
1632 | // A nesting order between GradientTapes and ForwardAccumulators, used to | |||
1633 | // ensure that GradientTapes do not watch the products of outer | |||
1634 | // ForwardAccumulators. | |||
1635 | int64_t nesting_id; | |||
1636 | } TFE_Py_ForwardAccumulator; | |||
1637 | ||||
1638 | static void TFE_Py_ForwardAccumulatorDelete(PyObject* accumulator) { | |||
1639 | delete reinterpret_cast<TFE_Py_ForwardAccumulator*>(accumulator)->accumulator; | |||
1640 | Py_TYPE(accumulator)(((PyObject*)(accumulator))->ob_type)->tp_free(accumulator); | |||
1641 | } | |||
1642 | ||||
1643 | static PyTypeObject TFE_Py_ForwardAccumulator_Type = { | |||
1644 | PyVarObject_HEAD_INIT(nullptr, 0){ { 1, nullptr }, 0 }, "ForwardAccumulator", /* tp_name */ | |||
1645 | sizeof(TFE_Py_ForwardAccumulator), /* tp_basicsize */ | |||
1646 | 0, /* tp_itemsize */ | |||
1647 | &TFE_Py_ForwardAccumulatorDelete, /* tp_dealloc */ | |||
1648 | #if PY_VERSION_HEX((3 << 24) | (8 << 16) | (5 << 8) | (0xF << 4) | (0 << 0)) < 0x03080000 | |||
1649 | nullptr, /* tp_print */ | |||
1650 | #else | |||
1651 | 0, /* tp_vectorcall_offset */ | |||
1652 | #endif | |||
1653 | nullptr, /* tp_getattr */ | |||
1654 | nullptr, /* tp_setattr */ | |||
1655 | nullptr, /* tp_reserved */ | |||
1656 | nullptr, /* tp_repr */ | |||
1657 | nullptr, /* tp_as_number */ | |||
1658 | nullptr, /* tp_as_sequence */ | |||
1659 | nullptr, /* tp_as_mapping */ | |||
1660 | nullptr, /* tp_hash */ | |||
1661 | nullptr, /* tp_call */ | |||
1662 | nullptr, /* tp_str */ | |||
1663 | nullptr, /* tp_getattro */ | |||
1664 | nullptr, /* tp_setattro */ | |||
1665 | nullptr, /* tp_as_buffer */ | |||
1666 | Py_TPFLAGS_DEFAULT( 0 | (1UL << 18) | 0), /* tp_flags */ | |||
1667 | "TFE_Py_ForwardAccumulator objects", /* tp_doc */ | |||
1668 | }; | |||
1669 | ||||
1670 | typedef struct { | |||
1671 | PyObject_HEADPyObject ob_base; | |||
1672 | /* Type-specific fields go here. */ | |||
1673 | VariableWatcher* variable_watcher; | |||
1674 | } TFE_Py_VariableWatcher; | |||
1675 | ||||
1676 | static void TFE_Py_VariableWatcher_Delete(PyObject* variable_watcher) { | |||
1677 | delete reinterpret_cast<TFE_Py_VariableWatcher*>(variable_watcher) | |||
1678 | ->variable_watcher; | |||
1679 | Py_TYPE(variable_watcher)(((PyObject*)(variable_watcher))->ob_type)->tp_free(variable_watcher); | |||
1680 | } | |||
1681 | ||||
1682 | static PyTypeObject TFE_Py_VariableWatcher_Type = { | |||
1683 | PyVarObject_HEAD_INIT(nullptr, 0){ { 1, nullptr }, 0 }, "tfe.VariableWatcher", /* tp_name */ | |||
1684 | sizeof(TFE_Py_VariableWatcher), /* tp_basicsize */ | |||
1685 | 0, /* tp_itemsize */ | |||
1686 | &TFE_Py_VariableWatcher_Delete, /* tp_dealloc */ | |||
1687 | #if PY_VERSION_HEX((3 << 24) | (8 << 16) | (5 << 8) | (0xF << 4) | (0 << 0)) < 0x03080000 | |||
1688 | nullptr, /* tp_print */ | |||
1689 | #else | |||
1690 | 0, /* tp_vectorcall_offset */ | |||
1691 | #endif | |||
1692 | nullptr, /* tp_getattr */ | |||
1693 | nullptr, /* tp_setattr */ | |||
1694 | nullptr, /* tp_reserved */ | |||
1695 | nullptr, /* tp_repr */ | |||
1696 | nullptr, /* tp_as_number */ | |||
1697 | nullptr, /* tp_as_sequence */ | |||
1698 | nullptr, /* tp_as_mapping */ | |||
1699 | nullptr, /* tp_hash */ | |||
1700 | nullptr, /* tp_call */ | |||
1701 | nullptr, /* tp_str */ | |||
1702 | nullptr, /* tp_getattro */ | |||
1703 | nullptr, /* tp_setattro */ | |||
1704 | nullptr, /* tp_as_buffer */ | |||
1705 | Py_TPFLAGS_DEFAULT( 0 | (1UL << 18) | 0), /* tp_flags */ | |||
1706 | "TFE_Py_VariableWatcher objects", /* tp_doc */ | |||
1707 | }; | |||
1708 | ||||
1709 | // Note: in the current design no mutex is needed here because of the python | |||
1710 | // GIL, which is always held when any TFE_Py_* methods are called. We should | |||
1711 | // revisit this if/when decide to not hold the GIL while manipulating the tape | |||
1712 | // stack. | |||
1713 | tensorflow::gtl::CompactPointerSet<TFE_Py_Tape*>* GetTapeSet() { | |||
1714 | thread_local std::unique_ptr<tensorflow::gtl::CompactPointerSet<TFE_Py_Tape*>> | |||
1715 | tape_set = nullptr; | |||
1716 | if (tape_set == nullptr) { | |||
1717 | tape_set.reset(new tensorflow::gtl::CompactPointerSet<TFE_Py_Tape*>); | |||
1718 | } | |||
1719 | return tape_set.get(); | |||
1720 | } | |||
1721 | ||||
1722 | tensorflow::gtl::CompactPointerSet<TFE_Py_VariableWatcher*>* | |||
1723 | GetVariableWatcherSet() { | |||
1724 | thread_local std::unique_ptr< | |||
1725 | tensorflow::gtl::CompactPointerSet<TFE_Py_VariableWatcher*>> | |||
1726 | variable_watcher_set = nullptr; | |||
1727 | if (variable_watcher_set == nullptr) { | |||
1728 | variable_watcher_set.reset( | |||
1729 | new tensorflow::gtl::CompactPointerSet<TFE_Py_VariableWatcher*>); | |||
1730 | } | |||
1731 | return variable_watcher_set.get(); | |||
1732 | } | |||
1733 | ||||
1734 | // A linked hash set, where iteration is in insertion order. | |||
1735 | // | |||
1736 | // Nested accumulators rely on op recording happening in insertion order, so an | |||
1737 | // unordered data structure like CompactPointerSet is not suitable. Outer | |||
1738 | // accumulators need to observe operations first so they know to watch the inner | |||
1739 | // accumulator's jvp computation. | |||
1740 | // | |||
1741 | // Not thread safe. | |||
1742 | class AccumulatorSet { | |||
1743 | public: | |||
1744 | // Returns true if `element` was newly inserted, false if it already exists. | |||
1745 | bool insert(TFE_Py_ForwardAccumulator* element) { | |||
1746 | if (map_.find(element) != map_.end()) { | |||
1747 | return false; | |||
1748 | } | |||
1749 | ListType::iterator it = ordered_.insert(ordered_.end(), element); | |||
1750 | map_.insert(std::make_pair(element, it)); | |||
1751 | return true; | |||
1752 | } | |||
1753 | ||||
1754 | void erase(TFE_Py_ForwardAccumulator* element) { | |||
1755 | MapType::iterator existing = map_.find(element); | |||
1756 | if (existing == map_.end()) { | |||
1757 | return; | |||
1758 | } | |||
1759 | ListType::iterator list_position = existing->second; | |||
1760 | map_.erase(existing); | |||
1761 | ordered_.erase(list_position); | |||
1762 | } | |||
1763 | ||||
1764 | bool empty() const { return ordered_.empty(); } | |||
1765 | ||||
1766 | size_t size() const { return ordered_.size(); } | |||
1767 | ||||
1768 | private: | |||
1769 | typedef std::list<TFE_Py_ForwardAccumulator*> ListType; | |||
1770 | typedef tensorflow::gtl::FlatMap<TFE_Py_ForwardAccumulator*, | |||
1771 | ListType::iterator> | |||
1772 | MapType; | |||
1773 | ||||
1774 | public: | |||
1775 | typedef ListType::const_iterator const_iterator; | |||
1776 | typedef ListType::const_reverse_iterator const_reverse_iterator; | |||
1777 | ||||
1778 | const_iterator begin() const { return ordered_.begin(); } | |||
1779 | const_iterator end() const { return ordered_.end(); } | |||
1780 | ||||
1781 | const_reverse_iterator rbegin() const { return ordered_.rbegin(); } | |||
1782 | const_reverse_iterator rend() const { return ordered_.rend(); } | |||
1783 | ||||
1784 | private: | |||
1785 | MapType map_; | |||
1786 | ListType ordered_; | |||
1787 | }; | |||
1788 | ||||
1789 | AccumulatorSet* GetAccumulatorSet() { | |||
1790 | thread_local std::unique_ptr<AccumulatorSet> accumulator_set{nullptr}; | |||
1791 | if (accumulator_set == nullptr) { | |||
1792 | accumulator_set.reset(new AccumulatorSet); | |||
1793 | } | |||
1794 | return accumulator_set.get(); | |||
1795 | } | |||
1796 | ||||
1797 | inline bool HasAccumulator() { return !GetAccumulatorSet()->empty(); } | |||
1798 | ||||
1799 | inline bool HasGradientTape() { return !GetTapeSet()->empty(); } | |||
1800 | ||||
1801 | inline bool HasAccumulatorOrTape() { | |||
1802 | return HasGradientTape() || HasAccumulator(); | |||
1803 | } | |||
1804 | ||||
// A safe copy of a set, used for tapes and accumulators. The copy is not
// affected by other python threads changing the set of active tapes.
//
// The copy holds one strong reference per member for its whole lifetime
// (taken in the constructor, released in the destructor).
template <typename ContainerType>
class SafeSetCopy {
 public:
  explicit SafeSetCopy(const ContainerType& to_copy) : set_copy_(to_copy) {
    for (auto* member : set_copy_) {
      Py_INCREF(member);
    }
  }

  ~SafeSetCopy() {
    for (auto* member : set_copy_) {
      Py_DECREF(member);
    }
  }

  // A member-wise copy would not take new references, so both objects would
  // release the same references and drop members below their real refcount.
  SafeSetCopy(const SafeSetCopy&) = delete;
  SafeSetCopy& operator=(const SafeSetCopy&) = delete;

  typename ContainerType::const_iterator begin() const {
    return set_copy_.begin();
  }

  typename ContainerType::const_iterator end() const { return set_copy_.end(); }

  bool empty() const { return set_copy_.empty(); }
  size_t size() const { return set_copy_.size(); }

 protected:
  ContainerType set_copy_;
};
1834 | ||||
1835 | class SafeTapeSet | |||
1836 | : public SafeSetCopy<tensorflow::gtl::CompactPointerSet<TFE_Py_Tape*>> { | |||
1837 | public: | |||
1838 | SafeTapeSet() | |||
1839 | : SafeSetCopy<tensorflow::gtl::CompactPointerSet<TFE_Py_Tape*>>( | |||
1840 | *GetTapeSet()) {} | |||
1841 | }; | |||
1842 | ||||
1843 | class SafeAccumulatorSet : public SafeSetCopy<AccumulatorSet> { | |||
1844 | public: | |||
1845 | SafeAccumulatorSet() : SafeSetCopy<AccumulatorSet>(*GetAccumulatorSet()) {} | |||
1846 | ||||
1847 | typename AccumulatorSet::const_reverse_iterator rbegin() const { | |||
1848 | return set_copy_.rbegin(); | |||
1849 | } | |||
1850 | ||||
1851 | typename AccumulatorSet::const_reverse_iterator rend() const { | |||
1852 | return set_copy_.rend(); | |||
1853 | } | |||
1854 | }; | |||
1855 | ||||
1856 | class SafeVariableWatcherSet | |||
1857 | : public SafeSetCopy< | |||
1858 | tensorflow::gtl::CompactPointerSet<TFE_Py_VariableWatcher*>> { | |||
1859 | public: | |||
1860 | SafeVariableWatcherSet() | |||
1861 | : SafeSetCopy< | |||
1862 | tensorflow::gtl::CompactPointerSet<TFE_Py_VariableWatcher*>>( | |||
1863 | *GetVariableWatcherSet()) {} | |||
1864 | }; | |||
1865 | ||||
// Returns a pointer to the current thread's "recording stopped" flag, which
// starts out false.
bool* ThreadTapeIsStopped() {
  thread_local bool stopped = false;
  return &stopped;
}
1870 | ||||
1871 | void TFE_Py_TapeSetStopOnThread() { *ThreadTapeIsStopped() = true; } | |||
1872 | ||||
1873 | void TFE_Py_TapeSetRestartOnThread() { *ThreadTapeIsStopped() = false; } | |||
1874 | ||||
1875 | PyObject* TFE_Py_TapeSetIsStopped() { | |||
1876 | if (*ThreadTapeIsStopped()) { | |||
1877 | Py_RETURN_TRUEreturn _Py_INCREF(((PyObject*)(((PyObject *) &_Py_TrueStruct )))), ((PyObject *) &_Py_TrueStruct); | |||
1878 | } | |||
1879 | Py_RETURN_FALSEreturn _Py_INCREF(((PyObject*)(((PyObject *) &_Py_FalseStruct )))), ((PyObject *) &_Py_FalseStruct); | |||
1880 | } | |||
1881 | ||||
1882 | PyObject* TFE_Py_TapeSetNew(PyObject* persistent, | |||
1883 | PyObject* watch_accessed_variables) { | |||
1884 | TFE_Py_Tape_Type.tp_new = PyType_GenericNew; | |||
1885 | if (PyType_Ready(&TFE_Py_Tape_Type) < 0) return nullptr; | |||
1886 | TFE_Py_Tape* tape = PyObject_NEW(TFE_Py_Tape, &TFE_Py_Tape_Type)( (TFE_Py_Tape *) PyObject_Init( (PyObject *) PyObject_Malloc ( ( (&TFE_Py_Tape_Type)->tp_basicsize ) ), (&TFE_Py_Tape_Type )) ); | |||
1887 | tape->tape = new GradientTape(persistent == Py_True((PyObject *) &_Py_TrueStruct), | |||
1888 | watch_accessed_variables == Py_True((PyObject *) &_Py_TrueStruct)); | |||
1889 | Py_INCREF(tape)_Py_INCREF(((PyObject*)(tape))); | |||
1890 | tape->nesting_id = tape_nesting_id_counter.fetch_add(1); | |||
1891 | GetTapeSet()->insert(tape); | |||
1892 | return reinterpret_cast<PyObject*>(tape); | |||
1893 | } | |||
1894 | ||||
1895 | void TFE_Py_TapeSetAdd(PyObject* tape) { | |||
1896 | Py_INCREF(tape)_Py_INCREF(((PyObject*)(tape))); | |||
1897 | TFE_Py_Tape* tfe_tape = reinterpret_cast<TFE_Py_Tape*>(tape); | |||
1898 | if (!GetTapeSet()->insert(tfe_tape).second) { | |||
1899 | // Already exists in the tape set. | |||
1900 | Py_DECREF(tape)_Py_DECREF(((PyObject*)(tape))); | |||
1901 | } else { | |||
1902 | tfe_tape->nesting_id = tape_nesting_id_counter.fetch_add(1); | |||
1903 | } | |||
1904 | } | |||
1905 | ||||
1906 | PyObject* TFE_Py_TapeSetIsEmpty() { | |||
1907 | if (*ThreadTapeIsStopped() || !HasAccumulatorOrTape()) { | |||
1908 | Py_RETURN_TRUEreturn _Py_INCREF(((PyObject*)(((PyObject *) &_Py_TrueStruct )))), ((PyObject *) &_Py_TrueStruct); | |||
1909 | } | |||
1910 | Py_RETURN_FALSEreturn _Py_INCREF(((PyObject*)(((PyObject *) &_Py_FalseStruct )))), ((PyObject *) &_Py_FalseStruct); | |||
1911 | } | |||
1912 | ||||
1913 | void TFE_Py_TapeSetRemove(PyObject* tape) { | |||
1914 | auto* stack = GetTapeSet(); | |||
1915 | stack->erase(reinterpret_cast<TFE_Py_Tape*>(tape)); | |||
1916 | // We kept a reference to the tape in the set to ensure it wouldn't get | |||
1917 | // deleted under us; cleaning it up here. | |||
1918 | Py_DECREF(tape)_Py_DECREF(((PyObject*)(tape))); | |||
1919 | } | |||
1920 | ||||
1921 | static std::vector<int64_t> MakeIntList(PyObject* list) { | |||
1922 | if (list == Py_None(&_Py_NoneStruct)) { | |||
1923 | return {}; | |||
1924 | } | |||
1925 | PyObject* seq = PySequence_Fast(list, "expected a sequence"); | |||
1926 | if (seq == nullptr) { | |||
1927 | return {}; | |||
1928 | } | |||
1929 | int len = PySequence_Size(list); | |||
1930 | PyObject** seq_array = PySequence_Fast_ITEMS(seq)(((((((PyObject*)(seq))->ob_type))->tp_flags & ((1UL << 25))) != 0) ? ((PyListObject *)(seq))->ob_item : ((PyTupleObject *)(seq))->ob_item); | |||
1931 | std::vector<int64_t> tensor_ids; | |||
1932 | tensor_ids.reserve(len); | |||
1933 | for (int i = 0; i < len; ++i) { | |||
1934 | PyObject* item = seq_array[i]; | |||
1935 | #if PY_MAJOR_VERSION3 >= 3 | |||
1936 | if (PyLong_Check(item)((((((PyObject*)(item))->ob_type))->tp_flags & ((1UL << 24))) != 0)) { | |||
1937 | #else | |||
1938 | if (PyLong_Check(item)((((((PyObject*)(item))->ob_type))->tp_flags & ((1UL << 24))) != 0) || PyInt_Check(item)) { | |||
1939 | #endif | |||
1940 | int64_t id = MakeInt(item); | |||
1941 | tensor_ids.push_back(id); | |||
1942 | } else { | |||
1943 | tensor_ids.push_back(-1); | |||
1944 | } | |||
1945 | } | |||
1946 | Py_DECREF(seq)_Py_DECREF(((PyObject*)(seq))); | |||
1947 | return tensor_ids; | |||
1948 | } | |||
1949 | ||||
1950 | // Fill `tensor_ids` and `dtypes` from `tensors`, none of which may be | |||
1951 | // null. Returns true on success and false on a Python exception. | |||
1952 | bool TensorShapesAndDtypes(PyObject* tensors, std::vector<int64_t>* tensor_ids, | |||
1953 | std::vector<tensorflow::DataType>* dtypes) { | |||
1954 | tensorflow::Safe_PyObjectPtr seq( | |||
1955 | PySequence_Fast(tensors, "expected a sequence")); | |||
1956 | if (seq == nullptr) { | |||
1957 | return false; | |||
1958 | } | |||
1959 | int len = PySequence_Fast_GET_SIZE(seq.get())(((((((PyObject*)(seq.get()))->ob_type))->tp_flags & ((1UL << 25))) != 0) ? ((static_cast<void> (0)), (((PyVarObject*)(seq.get()))->ob_size)) : (((PyVarObject* )(((static_cast<void> (0)), (PyTupleObject *)(seq.get() ))))->ob_size)); | |||
1960 | PyObject** seq_array = PySequence_Fast_ITEMS(seq.get())(((((((PyObject*)(seq.get()))->ob_type))->tp_flags & ((1UL << 25))) != 0) ? ((PyListObject *)(seq.get()))-> ob_item : ((PyTupleObject *)(seq.get()))->ob_item); | |||
1961 | tensor_ids->reserve(len); | |||
1962 | dtypes->reserve(len); | |||
1963 | for (int i = 0; i < len; ++i) { | |||
1964 | PyObject* item = seq_array[i]; | |||
1965 | tensor_ids->push_back(FastTensorId(item)); | |||
1966 | dtypes->push_back(tensorflow::PyTensor_DataType(item)); | |||
1967 | } | |||
1968 | return true; | |||
1969 | } | |||
1970 | ||||
1971 | bool TapeCouldPossiblyRecord(PyObject* tensors) { | |||
1972 | if (tensors == Py_None(&_Py_NoneStruct)) { | |||
1973 | return false; | |||
1974 | } | |||
1975 | if (*ThreadTapeIsStopped()) { | |||
1976 | return false; | |||
1977 | } | |||
1978 | if (!HasAccumulatorOrTape()) { | |||
1979 | return false; | |||
1980 | } | |||
1981 | return true; | |||
1982 | } | |||
1983 | ||||
1984 | bool CouldBackprop() { return !*ThreadTapeIsStopped() && HasGradientTape(); } | |||
1985 | ||||
1986 | bool CouldForwardprop() { return !*ThreadTapeIsStopped() && HasAccumulator(); } | |||
1987 | ||||
1988 | PyObject* TFE_Py_TapeSetShouldRecordBackprop(PyObject* tensors) { | |||
1989 | if (!TapeCouldPossiblyRecord(tensors) || !CouldBackprop()) { | |||
1990 | Py_RETURN_FALSEreturn _Py_INCREF(((PyObject*)(((PyObject *) &_Py_FalseStruct )))), ((PyObject *) &_Py_FalseStruct); | |||
1991 | } | |||
1992 | // TODO(apassos) consider not building a list and changing the API to check | |||
1993 | // each tensor individually. | |||
1994 | std::vector<int64_t> tensor_ids; | |||
1995 | std::vector<tensorflow::DataType> dtypes; | |||
1996 | if (!TensorShapesAndDtypes(tensors, &tensor_ids, &dtypes)) { | |||
1997 | return nullptr; | |||
1998 | } | |||
1999 | auto tape_set = *GetTapeSet(); | |||
2000 | for (TFE_Py_Tape* tape : tape_set) { | |||
2001 | if (tape->tape->ShouldRecord(tensor_ids, dtypes)) { | |||
2002 | Py_RETURN_TRUEreturn _Py_INCREF(((PyObject*)(((PyObject *) &_Py_TrueStruct )))), ((PyObject *) &_Py_TrueStruct); | |||
2003 | } | |||
2004 | } | |||
2005 | ||||
2006 | Py_RETURN_FALSEreturn _Py_INCREF(((PyObject*)(((PyObject *) &_Py_FalseStruct )))), ((PyObject *) &_Py_FalseStruct); | |||
2007 | } | |||
2008 | ||||
2009 | PyObject* TFE_Py_ForwardAccumulatorPushState() { | |||
2010 | auto forward_accumulators = *GetAccumulatorSet(); | |||
2011 | for (TFE_Py_ForwardAccumulator* accumulator : forward_accumulators) { | |||
2012 | accumulator->accumulator->PushState(); | |||
2013 | } | |||
2014 | Py_RETURN_NONEreturn _Py_INCREF(((PyObject*)((&_Py_NoneStruct)))), (& _Py_NoneStruct); | |||
2015 | } | |||
2016 | ||||
2017 | PyObject* TFE_Py_ForwardAccumulatorPopState() { | |||
2018 | auto forward_accumulators = *GetAccumulatorSet(); | |||
2019 | for (TFE_Py_ForwardAccumulator* accumulator : forward_accumulators) { | |||
2020 | accumulator->accumulator->PopState(); | |||
2021 | } | |||
2022 | Py_RETURN_NONEreturn _Py_INCREF(((PyObject*)((&_Py_NoneStruct)))), (& _Py_NoneStruct); | |||
2023 | } | |||
2024 | ||||
2025 | PyObject* TFE_Py_TapeSetPossibleGradientTypes(PyObject* tensors) { | |||
2026 | if (!TapeCouldPossiblyRecord(tensors)) { | |||
2027 | return GetPythonObjectFromInt(0); | |||
2028 | } | |||
2029 | std::vector<int64_t> tensor_ids; | |||
2030 | std::vector<tensorflow::DataType> dtypes; | |||
2031 | if (!TensorShapesAndDtypes(tensors, &tensor_ids, &dtypes)) { | |||
2032 | return nullptr; | |||
2033 | } | |||
2034 | ||||
2035 | // If there is a persistent tape watching, or if there are multiple tapes | |||
2036 | // watching, we'll return immediately indicating that higher-order tape | |||
2037 | // gradients are possible. | |||
2038 | bool some_tape_watching = false; | |||
2039 | if (CouldBackprop()) { | |||
2040 | auto tape_set = *GetTapeSet(); | |||
2041 | for (TFE_Py_Tape* tape : tape_set) { | |||
2042 | if (tape->tape->ShouldRecord(tensor_ids, dtypes)) { | |||
2043 | if (tape->tape->IsPersistent() || some_tape_watching) { | |||
2044 | // Either this is the second tape watching, or this tape is | |||
2045 | // persistent: higher-order gradients are possible. | |||
2046 | return GetPythonObjectFromInt(2); | |||
2047 | } | |||
2048 | some_tape_watching = true; | |||
2049 | } | |||
2050 | } | |||
2051 | } | |||
2052 | if (CouldForwardprop()) { | |||
2053 | auto forward_accumulators = *GetAccumulatorSet(); | |||
2054 | for (TFE_Py_ForwardAccumulator* accumulator : forward_accumulators) { | |||
2055 | if (accumulator->accumulator->ShouldRecord(tensor_ids, dtypes)) { | |||
2056 | if (some_tape_watching) { | |||
2057 | // This is the second tape watching: higher-order gradients are | |||
2058 | // possible. Note that there's no equivalent of persistence for | |||
2059 | // forward-mode. | |||
2060 | return GetPythonObjectFromInt(2); | |||
2061 | } | |||
2062 | some_tape_watching = true; | |||
2063 | } | |||
2064 | } | |||
2065 | } | |||
2066 | if (some_tape_watching) { | |||
2067 | // There's exactly one non-persistent tape. The user can request first-order | |||
2068 | // gradients but won't be able to get higher-order tape gradients. | |||
2069 | return GetPythonObjectFromInt(1); | |||
2070 | } else { | |||
2071 | // There are no tapes. The user can't request tape gradients. | |||
2072 | return GetPythonObjectFromInt(0); | |||
2073 | } | |||
2074 | } | |||
2075 | ||||
2076 | void TFE_Py_TapeWatch(PyObject* tape, PyObject* tensor) { | |||
2077 | if (!CouldBackprop()) { | |||
2078 | return; | |||
2079 | } | |||
2080 | int64_t tensor_id = FastTensorId(tensor); | |||
2081 | if (PyErr_Occurred()) { | |||
2082 | return; | |||
2083 | } | |||
2084 | reinterpret_cast<TFE_Py_Tape*>(tape)->tape->Watch(tensor_id); | |||
2085 | } | |||
2086 | ||||
2087 | bool ListContainsNone(PyObject* list) { | |||
2088 | if (list == Py_None(&_Py_NoneStruct)) return true; | |||
2089 | tensorflow::Safe_PyObjectPtr seq( | |||
2090 | PySequence_Fast(list, "expected a sequence")); | |||
2091 | if (seq == nullptr) { | |||
2092 | return false; | |||
2093 | } | |||
2094 | ||||
2095 | int len = PySequence_Size(list); | |||
2096 | PyObject** seq_array = PySequence_Fast_ITEMS(seq.get())(((((((PyObject*)(seq.get()))->ob_type))->tp_flags & ((1UL << 25))) != 0) ? ((PyListObject *)(seq.get()))-> ob_item : ((PyTupleObject *)(seq.get()))->ob_item); | |||
2097 | for (int i = 0; i < len; ++i) { | |||
2098 | PyObject* item = seq_array[i]; | |||
2099 | if (item == Py_None(&_Py_NoneStruct)) return true; | |||
2100 | } | |||
2101 | ||||
2102 | return false; | |||
2103 | } | |||
2104 | ||||
2105 | // As an optimization, the tape generally keeps only the shape and dtype of | |||
2106 | // tensors, and uses this information to generate ones/zeros tensors. However, | |||
2107 | // some tensors require OnesLike/ZerosLike because their gradients do not match | |||
2108 | // their inference shape/dtype. | |||
2109 | bool DTypeNeedsHandleData(tensorflow::DataType dtype) { | |||
2110 | return dtype == tensorflow::DT_VARIANT || dtype == tensorflow::DT_RESOURCE; | |||
2111 | } | |||
2112 | ||||
2113 | static PyTapeTensor TapeTensorFromTensor(PyObject* tensor) { | |||
2114 | if (EagerTensor_CheckExact(tensor)) { | |||
2115 | tensorflow::ImmediateExecutionTensorHandle* handle = | |||
2116 | tensorflow::unwrap(EagerTensor_Handle(tensor)); | |||
2117 | int64_t id = PyEagerTensor_ID(tensor); | |||
2118 | tensorflow::DataType dtype = | |||
2119 | static_cast<tensorflow::DataType>(handle->DataType()); | |||
2120 | if (DTypeNeedsHandleData(dtype)) { | |||
2121 | return PyTapeTensor(id, dtype, tensor); | |||
2122 | } | |||
2123 | ||||
2124 | tensorflow::TensorShape tensor_shape; | |||
2125 | int num_dims; | |||
2126 | tensorflow::Status status = handle->NumDims(&num_dims); | |||
2127 | if (status.ok()) { | |||
2128 | for (int i = 0; i < num_dims; ++i) { | |||
2129 | int64_t dim_size; | |||
2130 | status = handle->Dim(i, &dim_size); | |||
2131 | if (!status.ok()) break; | |||
2132 | tensor_shape.AddDim(dim_size); | |||
2133 | } | |||
2134 | } | |||
2135 | ||||
2136 | if (MaybeRaiseExceptionFromStatus(status, nullptr)) { | |||
2137 | return PyTapeTensor(id, static_cast<tensorflow::DataType>(0), | |||
2138 | tensorflow::TensorShape({})); | |||
2139 | } else { | |||
2140 | return PyTapeTensor(id, dtype, tensor_shape); | |||
2141 | } | |||
2142 | } | |||
2143 | int64_t id = FastTensorId(tensor); | |||
2144 | if (PyErr_Occurred()) { | |||
2145 | return PyTapeTensor(id, static_cast<tensorflow::DataType>(0), | |||
2146 | tensorflow::TensorShape({})); | |||
2147 | } | |||
2148 | PyObject* dtype_object = PyObject_GetAttrString(tensor, "dtype"); | |||
2149 | PyObject* dtype_enum = PyObject_GetAttrString(dtype_object, "_type_enum"); | |||
2150 | Py_DECREF(dtype_object)_Py_DECREF(((PyObject*)(dtype_object))); | |||
2151 | tensorflow::DataType dtype = | |||
2152 | static_cast<tensorflow::DataType>(MakeInt(dtype_enum)); | |||
2153 | Py_DECREF(dtype_enum)_Py_DECREF(((PyObject*)(dtype_enum))); | |||
2154 | if (PyErr_Occurred()) { | |||
2155 | return PyTapeTensor(id, static_cast<tensorflow::DataType>(0), | |||
2156 | tensorflow::TensorShape({})); | |||
2157 | } | |||
2158 | static char _shape_tuple[] = "_shape_tuple"; | |||
2159 | tensorflow::Safe_PyObjectPtr shape_tuple( | |||
2160 | PyObject_CallMethod(tensor, _shape_tuple, nullptr)); | |||
2161 | if (PyErr_Occurred()) { | |||
2162 | return PyTapeTensor(id, static_cast<tensorflow::DataType>(0), | |||
2163 | tensorflow::TensorShape({})); | |||
2164 | } | |||
2165 | ||||
2166 | if (ListContainsNone(shape_tuple.get()) || DTypeNeedsHandleData(dtype)) { | |||
2167 | return PyTapeTensor(id, dtype, tensor); | |||
2168 | } | |||
2169 | ||||
2170 | auto l = MakeIntList(shape_tuple.get()); | |||
2171 | // Replace -1, which represents accidental Nones which can occur in graph mode | |||
2172 | // and can cause errors in shape construction with 0s. | |||
2173 | for (auto& c : l) { | |||
2174 | if (c < 0) { | |||
2175 | c = 0; | |||
2176 | } | |||
2177 | } | |||
2178 | tensorflow::TensorShape shape(l); | |||
2179 | return PyTapeTensor(id, dtype, shape); | |||
2180 | } | |||
2181 | ||||
2182 | // Populates output_info from output_seq, which must come from PySequence_Fast. | |||
2183 | // | |||
2184 | // Does not take ownership of output_seq. Returns true on success and false if a | |||
2185 | // Python exception has been set. | |||
2186 | bool TapeTensorsFromTensorSequence(PyObject* output_seq, | |||
2187 | std::vector<PyTapeTensor>* output_info) { | |||
2188 | Py_ssize_t output_len = PySequence_Fast_GET_SIZE(output_seq)(((((((PyObject*)(output_seq))->ob_type))->tp_flags & ((1UL << 25))) != 0) ? ((static_cast<void> (0)), (((PyVarObject*)(output_seq))->ob_size)) : (((PyVarObject *)(((static_cast<void> (0)), (PyTupleObject *)(output_seq ))))->ob_size)); | |||
2189 | PyObject** output_seq_array = PySequence_Fast_ITEMS(output_seq)(((((((PyObject*)(output_seq))->ob_type))->tp_flags & ((1UL << 25))) != 0) ? ((PyListObject *)(output_seq))-> ob_item : ((PyTupleObject *)(output_seq))->ob_item); | |||
2190 | output_info->reserve(output_len); | |||
2191 | for (Py_ssize_t i = 0; i < output_len; ++i) { | |||
2192 | output_info->push_back(TapeTensorFromTensor(output_seq_array[i])); | |||
2193 | if (PyErr_Occurred() != nullptr) { | |||
2194 | return false; | |||
2195 | } | |||
2196 | } | |||
2197 | return true; | |||
2198 | } | |||
2199 | ||||
2200 | std::vector<int64_t> MakeTensorIDList(PyObject* tensors) { | |||
2201 | PyObject* seq = PySequence_Fast(tensors, "expected a sequence"); | |||
2202 | if (seq == nullptr) { | |||
2203 | return {}; | |||
2204 | } | |||
2205 | int len = PySequence_Fast_GET_SIZE(seq)(((((((PyObject*)(seq))->ob_type))->tp_flags & ((1UL << 25))) != 0) ? ((static_cast<void> (0)), (((PyVarObject *)(seq))->ob_size)) : (((PyVarObject*)(((static_cast<void > (0)), (PyTupleObject *)(seq))))->ob_size)); | |||
2206 | PyObject** seq_array = PySequence_Fast_ITEMS(seq)(((((((PyObject*)(seq))->ob_type))->tp_flags & ((1UL << 25))) != 0) ? ((PyListObject *)(seq))->ob_item : ((PyTupleObject *)(seq))->ob_item); | |||
2207 | std::vector<int64_t> list; | |||
2208 | list.reserve(len); | |||
2209 | for (int i = 0; i < len; ++i) { | |||
2210 | PyObject* tensor = seq_array[i]; | |||
2211 | list.push_back(FastTensorId(tensor)); | |||
2212 | if (PyErr_Occurred()) { | |||
2213 | Py_DECREF(seq)_Py_DECREF(((PyObject*)(seq))); | |||
2214 | return list; | |||
2215 | } | |||
2216 | } | |||
2217 | Py_DECREF(seq)_Py_DECREF(((PyObject*)(seq))); | |||
2218 | return list; | |||
2219 | } | |||
2220 | ||||
2221 | void TFE_Py_TapeVariableAccessed(PyObject* variable) { | |||
2222 | if (!CouldBackprop()) { | |||
2223 | return; | |||
2224 | } | |||
2225 | for (TFE_Py_Tape* tape : SafeTapeSet()) { | |||
2226 | tape->tape->VariableAccessed(variable); | |||
2227 | } | |||
2228 | } | |||
2229 | ||||
2230 | void TFE_Py_TapeWatchVariable(PyObject* tape, PyObject* variable) { | |||
2231 | if (!CouldBackprop()) { | |||
2232 | return; | |||
2233 | } | |||
2234 | reinterpret_cast<TFE_Py_Tape*>(tape)->tape->WatchVariable(variable); | |||
2235 | } | |||
2236 | ||||
2237 | PyObject* TFE_Py_TapeWatchedVariables(PyObject* tape) { | |||
2238 | return reinterpret_cast<TFE_Py_Tape*>(tape)->tape->GetVariablesAsPyTuple(); | |||
2239 | } | |||
2240 | ||||
2241 | PyObject* TFE_Py_VariableWatcherNew() { | |||
2242 | TFE_Py_VariableWatcher_Type.tp_new = PyType_GenericNew; | |||
2243 | if (PyType_Ready(&TFE_Py_VariableWatcher_Type) < 0) return nullptr; | |||
2244 | TFE_Py_VariableWatcher* variable_watcher = | |||
2245 | PyObject_NEW(TFE_Py_VariableWatcher, &TFE_Py_VariableWatcher_Type)( (TFE_Py_VariableWatcher *) PyObject_Init( (PyObject *) PyObject_Malloc ( ( (&TFE_Py_VariableWatcher_Type)->tp_basicsize ) ), ( &TFE_Py_VariableWatcher_Type)) ); | |||
2246 | variable_watcher->variable_watcher = new VariableWatcher(); | |||
2247 | Py_INCREF(variable_watcher)_Py_INCREF(((PyObject*)(variable_watcher))); | |||
2248 | GetVariableWatcherSet()->insert(variable_watcher); | |||
2249 | return reinterpret_cast<PyObject*>(variable_watcher); | |||
2250 | } | |||
2251 | ||||
2252 | void TFE_Py_VariableWatcherRemove(PyObject* variable_watcher) { | |||
2253 | auto* stack = GetVariableWatcherSet(); | |||
2254 | stack->erase(reinterpret_cast<TFE_Py_VariableWatcher*>(variable_watcher)); | |||
2255 | // We kept a reference to the variable watcher in the set to ensure it | |||
2256 | // wouldn't get deleted under us; cleaning it up here. | |||
2257 | Py_DECREF(variable_watcher)_Py_DECREF(((PyObject*)(variable_watcher))); | |||
2258 | } | |||
2259 | ||||
2260 | void TFE_Py_VariableWatcherVariableAccessed(PyObject* variable) { | |||
2261 | for (TFE_Py_VariableWatcher* variable_watcher : SafeVariableWatcherSet()) { | |||
2262 | variable_watcher->variable_watcher->WatchVariable(variable); | |||
2263 | } | |||
2264 | } | |||
2265 | ||||
2266 | PyObject* TFE_Py_VariableWatcherWatchedVariables(PyObject* variable_watcher) { | |||
2267 | return reinterpret_cast<TFE_Py_VariableWatcher*>(variable_watcher) | |||
2268 | ->variable_watcher->GetVariablesAsPyTuple(); | |||
2269 | } | |||
2270 | ||||
2271 | namespace { | |||
2272 | std::vector<tensorflow::DataType> MakeTensorDtypeList(PyObject* tensors) { | |||
2273 | PyObject* seq = PySequence_Fast(tensors, "expected a sequence"); | |||
2274 | if (seq == nullptr) { | |||
2275 | return {}; | |||
2276 | } | |||
2277 | int len = PySequence_Fast_GET_SIZE(seq)(((((((PyObject*)(seq))->ob_type))->tp_flags & ((1UL << 25))) != 0) ? ((static_cast<void> (0)), (((PyVarObject *)(seq))->ob_size)) : (((PyVarObject*)(((static_cast<void > (0)), (PyTupleObject *)(seq))))->ob_size)); | |||
2278 | PyObject** seq_array = PySequence_Fast_ITEMS(seq)(((((((PyObject*)(seq))->ob_type))->tp_flags & ((1UL << 25))) != 0) ? ((PyListObject *)(seq))->ob_item : ((PyTupleObject *)(seq))->ob_item); | |||
2279 | std::vector<tensorflow::DataType> list; | |||
2280 | list.reserve(len); | |||
2281 | for (int i = 0; i < len; ++i) { | |||
2282 | PyObject* tensor = seq_array[i]; | |||
2283 | list.push_back(tensorflow::PyTensor_DataType(tensor)); | |||
2284 | } | |||
2285 | Py_DECREF(seq)_Py_DECREF(((PyObject*)(seq))); | |||
2286 | return list; | |||
2287 | } | |||
2288 | ||||
2289 | PyObject* ForwardAccumulatorDeleteGradient(PyObject* tensor_id, | |||
2290 | PyObject* weak_tensor_ref) { | |||
2291 | int64_t parsed_tensor_id = MakeInt(tensor_id); | |||
2292 | for (TFE_Py_ForwardAccumulator* accumulator : *GetAccumulatorSet()) { | |||
2293 | accumulator->accumulator->DeleteGradient(parsed_tensor_id); | |||
2294 | } | |||
2295 | Py_DECREF(weak_tensor_ref)_Py_DECREF(((PyObject*)(weak_tensor_ref))); | |||
2296 | Py_DECREF(tensor_id)_Py_DECREF(((PyObject*)(tensor_id))); | |||
2297 | Py_INCREF(Py_None)_Py_INCREF(((PyObject*)((&_Py_NoneStruct)))); | |||
2298 | return Py_None(&_Py_NoneStruct); | |||
2299 | } | |||
2300 | ||||
2301 | static PyMethodDef forward_accumulator_delete_gradient_method_def = { | |||
2302 | "ForwardAccumulatorDeleteGradient", ForwardAccumulatorDeleteGradient, | |||
2303 | METH_O0x0008, "ForwardAccumulatorDeleteGradient"}; | |||
2304 | ||||
2305 | void RegisterForwardAccumulatorCleanup(PyObject* tensor, int64_t tensor_id) { | |||
2306 | tensorflow::Safe_PyObjectPtr callback( | |||
2307 | PyCFunction_New(&forward_accumulator_delete_gradient_method_def,PyCFunction_NewEx((&forward_accumulator_delete_gradient_method_def ), (PyLong_FromLong(tensor_id)), __null) | |||
2308 | PyLong_FromLong(tensor_id))PyCFunction_NewEx((&forward_accumulator_delete_gradient_method_def ), (PyLong_FromLong(tensor_id)), __null)); | |||
2309 | // We need to keep a reference to the weakref active if we want our callback | |||
2310 | // called. The callback itself now owns the weakref object and the tensor ID | |||
2311 | // object. | |||
2312 | PyWeakref_NewRef(tensor, callback.get()); | |||
2313 | } | |||
2314 | ||||
2315 | void TapeSetRecordBackprop( | |||
2316 | const string& op_type, const std::vector<PyTapeTensor>& output_info, | |||
2317 | const std::vector<int64_t>& input_ids, | |||
2318 | const std::vector<tensorflow::DataType>& input_dtypes, | |||
2319 | const std::function<PyBackwardFunction*()>& backward_function_getter, | |||
2320 | const std::function<void(PyBackwardFunction*)>& backward_function_killer, | |||
2321 | tensorflow::uint64 max_gradient_tape_id) { | |||
2322 | if (!CouldBackprop()) { | |||
2323 | return; | |||
2324 | } | |||
2325 | for (TFE_Py_Tape* tape : SafeTapeSet()) { | |||
2326 | if (tape->nesting_id < max_gradient_tape_id) { | |||
2327 | tape->tape->RecordOperation(op_type, output_info, input_ids, input_dtypes, | |||
2328 | backward_function_getter, | |||
2329 | backward_function_killer); | |||
2330 | } | |||
2331 | } | |||
2332 | } | |||
2333 | ||||
2334 | bool TapeSetRecordForwardprop( | |||
2335 | const string& op_type, PyObject* output_seq, | |||
2336 | const std::vector<PyTapeTensor>& output_info, PyObject* input_tensors, | |||
2337 | const std::vector<int64_t>& input_ids, | |||
2338 | const std::vector<tensorflow::DataType>& input_dtypes, | |||
2339 | const std::function<PyBackwardFunction*()>& backward_function_getter, | |||
2340 | const std::function<void(PyBackwardFunction*)>& backward_function_killer, | |||
2341 | const tensorflow::eager::ForwardFunction<PyObject>* forward_function, | |||
2342 | PyObject* forwardprop_output_indices, | |||
2343 | tensorflow::uint64* max_gradient_tape_id) { | |||
2344 | *max_gradient_tape_id = std::numeric_limits<tensorflow::uint64>::max(); | |||
2345 | if (!CouldForwardprop()) { | |||
2346 | return true; | |||
2347 | } | |||
2348 | auto accumulator_set = SafeAccumulatorSet(); | |||
2349 | tensorflow::Safe_PyObjectPtr input_seq( | |||
2350 | PySequence_Fast(input_tensors, "expected a sequence of tensors")); | |||
2351 | if (input_seq == nullptr || PyErr_Occurred()) return false; | |||
2352 | Py_ssize_t input_len = PySequence_Fast_GET_SIZE(input_seq.get())(((((((PyObject*)(input_seq.get()))->ob_type))->tp_flags & ((1UL << 25))) != 0) ? ((static_cast<void> (0)), (((PyVarObject*)(input_seq.get()))->ob_size)) : ((( PyVarObject*)(((static_cast<void> (0)), (PyTupleObject * )(input_seq.get()))))->ob_size)); | |||
2353 | PyObject** output_seq_array = PySequence_Fast_ITEMS(output_seq)(((((((PyObject*)(output_seq))->ob_type))->tp_flags & ((1UL << 25))) != 0) ? ((PyListObject *)(output_seq))-> ob_item : ((PyTupleObject *)(output_seq))->ob_item); | |||
2354 | for (int i = 0; i < output_info.size(); ++i) { | |||
2355 | RegisterForwardAccumulatorCleanup(output_seq_array[i], | |||
2356 | output_info[i].GetID()); | |||
2357 | } | |||
2358 | if (forwardprop_output_indices != nullptr && | |||
2359 | forwardprop_output_indices != Py_None(&_Py_NoneStruct)) { | |||
2360 | tensorflow::Safe_PyObjectPtr indices_fast(PySequence_Fast( | |||
2361 | forwardprop_output_indices, "Expected a sequence of indices")); | |||
2362 | if (indices_fast == nullptr || PyErr_Occurred()) { | |||
2363 | return false; | |||
2364 | } | |||
2365 | if (PySequence_Fast_GET_SIZE(indices_fast.get())(((((((PyObject*)(indices_fast.get()))->ob_type))->tp_flags & ((1UL << 25))) != 0) ? ((static_cast<void> (0)), (((PyVarObject*)(indices_fast.get()))->ob_size)) : ( ((PyVarObject*)(((static_cast<void> (0)), (PyTupleObject *)(indices_fast.get()))))->ob_size)) != | |||
2366 | accumulator_set.size()) { | |||
2367 | MaybeRaiseExceptionFromStatus( | |||
2368 | tensorflow::errors::Internal( | |||
2369 | "Accumulators were added or removed from the active set " | |||
2370 | "between packing and unpacking."), | |||
2371 | nullptr); | |||
2372 | } | |||
2373 | PyObject** indices_fast_array = PySequence_Fast_ITEMS(indices_fast.get())(((((((PyObject*)(indices_fast.get()))->ob_type))->tp_flags & ((1UL << 25))) != 0) ? ((PyListObject *)(indices_fast .get()))->ob_item : ((PyTupleObject *)(indices_fast.get()) )->ob_item); | |||
2374 | Py_ssize_t accumulator_index = 0; | |||
2375 | for (AccumulatorSet::const_reverse_iterator it = accumulator_set.rbegin(); | |||
2376 | it != accumulator_set.rend(); ++it, ++accumulator_index) { | |||
2377 | tensorflow::Safe_PyObjectPtr jvp_index_seq( | |||
2378 | PySequence_Fast(indices_fast_array[accumulator_index], | |||
2379 | "Expected a sequence of jvp indices.")); | |||
2380 | if (jvp_index_seq == nullptr || PyErr_Occurred()) { | |||
2381 | return false; | |||
2382 | } | |||
2383 | Py_ssize_t num_jvps = PySequence_Fast_GET_SIZE(jvp_index_seq.get())(((((((PyObject*)(jvp_index_seq.get()))->ob_type))->tp_flags & ((1UL << 25))) != 0) ? ((static_cast<void> (0)), (((PyVarObject*)(jvp_index_seq.get()))->ob_size)) : (((PyVarObject*)(((static_cast<void> (0)), (PyTupleObject *)(jvp_index_seq.get()))))->ob_size)); | |||
2384 | PyObject** jvp_index_seq_array = | |||
2385 | PySequence_Fast_ITEMS(jvp_index_seq.get())(((((((PyObject*)(jvp_index_seq.get()))->ob_type))->tp_flags & ((1UL << 25))) != 0) ? ((PyListObject *)(jvp_index_seq .get()))->ob_item : ((PyTupleObject *)(jvp_index_seq.get() ))->ob_item); | |||
2386 | for (Py_ssize_t jvp_index = 0; jvp_index < num_jvps; ++jvp_index) { | |||
2387 | PyObject* tuple = jvp_index_seq_array[jvp_index]; | |||
2388 | int64_t primal_tensor_id = | |||
2389 | output_info[MakeInt(PyTuple_GetItem(tuple, 0))].GetID(); | |||
2390 | (*it)->accumulator->Watch( | |||
2391 | primal_tensor_id, | |||
2392 | output_seq_array[MakeInt(PyTuple_GetItem(tuple, 1))]); | |||
2393 | } | |||
2394 | } | |||
2395 | } else { | |||
2396 | std::vector<PyTapeTensor> input_info; | |||
2397 | input_info.reserve(input_len); | |||
2398 | PyObject** input_seq_array = PySequence_Fast_ITEMS(input_seq.get())(((((((PyObject*)(input_seq.get()))->ob_type))->tp_flags & ((1UL << 25))) != 0) ? ((PyListObject *)(input_seq .get()))->ob_item : ((PyTupleObject *)(input_seq.get()))-> ob_item); | |||
2399 | for (Py_ssize_t i = 0; i < input_len; ++i) { | |||
2400 | input_info.push_back(TapeTensorFromTensor(input_seq_array[i])); | |||
2401 | } | |||
2402 | for (TFE_Py_ForwardAccumulator* accumulator : accumulator_set) { | |||
2403 | tensorflow::Status status = accumulator->accumulator->Accumulate( | |||
2404 | op_type, input_info, output_info, input_ids, input_dtypes, | |||
2405 | forward_function, backward_function_getter, backward_function_killer); | |||
2406 | if (PyErr_Occurred()) return false; // Don't swallow Python exceptions. | |||
2407 | if (MaybeRaiseExceptionFromStatus(status, nullptr)) { | |||
2408 | return false; | |||
2409 | } | |||
2410 | if (accumulator->accumulator->BusyAccumulating()) { | |||
2411 | // Ensure inner accumulators don't see outer accumulators' jvps. This | |||
2412 | // mostly happens on its own, with some potentially surprising | |||
2413 | // exceptions, so the blanket policy is for consistency. | |||
2414 | *max_gradient_tape_id = accumulator->nesting_id; | |||
2415 | break; | |||
2416 | } | |||
2417 | } | |||
2418 | } | |||
2419 | return true; | |||
2420 | } | |||
2421 | ||||
2422 | PyObject* TangentsAsPyTuple(const std::vector<PyObject*>& input_tangents) { | |||
2423 | PyObject* py_input_tangents = PyTuple_New(input_tangents.size()); | |||
2424 | for (int i = 0; i < input_tangents.size(); ++i) { | |||
2425 | PyObject* element; | |||
2426 | if (input_tangents[i] == nullptr) { | |||
2427 | element = Py_None(&_Py_NoneStruct); | |||
2428 | } else { | |||
2429 | element = input_tangents[i]; | |||
2430 | } | |||
2431 | Py_INCREF(element)_Py_INCREF(((PyObject*)(element))); | |||
2432 | PyTuple_SET_ITEM(py_input_tangents, i, element)PyTuple_SetItem(py_input_tangents, i, element); | |||
2433 | } | |||
2434 | return py_input_tangents; | |||
2435 | } | |||
2436 | ||||
2437 | tensorflow::Status ParseTangentOutputs( | |||
2438 | PyObject* user_output, std::vector<PyObject*>* output_tangents) { | |||
2439 | if (user_output == Py_None(&_Py_NoneStruct)) { | |||
2440 | // No connected gradients. | |||
2441 | return tensorflow::Status::OK(); | |||
2442 | } | |||
2443 | tensorflow::Safe_PyObjectPtr fast_result( | |||
2444 | PySequence_Fast(user_output, "expected a sequence of forward gradients")); | |||
2445 | if (fast_result == nullptr) { | |||
2446 | return tensorflow::errors::InvalidArgument( | |||
2447 | "forward gradient function did not return a sequence."); | |||
2448 | } | |||
2449 | int len = PySequence_Fast_GET_SIZE(fast_result.get())(((((((PyObject*)(fast_result.get()))->ob_type))->tp_flags & ((1UL << 25))) != 0) ? ((static_cast<void> (0)), (((PyVarObject*)(fast_result.get()))->ob_size)) : ( ((PyVarObject*)(((static_cast<void> (0)), (PyTupleObject *)(fast_result.get()))))->ob_size)); | |||
2450 | PyObject** fast_result_array = PySequence_Fast_ITEMS(fast_result.get())(((((((PyObject*)(fast_result.get()))->ob_type))->tp_flags & ((1UL << 25))) != 0) ? ((PyListObject *)(fast_result .get()))->ob_item : ((PyTupleObject *)(fast_result.get())) ->ob_item); | |||
2451 | output_tangents->reserve(len); | |||
2452 | for (int i = 0; i < len; ++i) { | |||
2453 | PyObject* item = fast_result_array[i]; | |||
2454 | if (item == Py_None(&_Py_NoneStruct)) { | |||
2455 | output_tangents->push_back(nullptr); | |||
2456 | } else { | |||
2457 | Py_INCREF(item)_Py_INCREF(((PyObject*)(item))); | |||
2458 | output_tangents->push_back(item); | |||
2459 | } | |||
2460 | } | |||
2461 | return tensorflow::Status::OK(); | |||
2462 | } | |||
2463 | ||||
2464 | // Calls the registered forward_gradient_function, computing `output_tangents` | |||
2465 | // from `input_tangents`. `output_tangents` must not be null. | |||
2466 | // | |||
2467 | // `op_name`, `attrs`, `inputs`, and `results` describe the operation for which | |||
2468 | // the forward function is being called. | |||
2469 | tensorflow::Status CallJVPFunction(PyObject* op_name, PyObject* attrs, | |||
2470 | PyObject* inputs, PyObject* results, | |||
2471 | const std::vector<PyObject*>& input_tangents, | |||
2472 | std::vector<PyObject*>* output_tangents, | |||
2473 | bool use_batch) { | |||
2474 | if (forward_gradient_function == nullptr) { | |||
2475 | return tensorflow::errors::Internal( | |||
2476 | "No forward gradient function registered."); | |||
2477 | } | |||
2478 | tensorflow::Safe_PyObjectPtr py_input_tangents( | |||
2479 | TangentsAsPyTuple(input_tangents)); | |||
2480 | ||||
2481 | // Normalize the input sequence to a tuple so it works with function | |||
2482 | // caching; otherwise it may be an opaque _InputList object. | |||
2483 | tensorflow::Safe_PyObjectPtr input_tuple(PySequence_Tuple(inputs)); | |||
2484 | PyObject* to_batch = (use_batch) ? Py_True((PyObject *) &_Py_TrueStruct) : Py_False((PyObject *) &_Py_FalseStruct); | |||
2485 | tensorflow::Safe_PyObjectPtr callback_args( | |||
2486 | Py_BuildValue("OOOOOO", op_name, attrs, input_tuple.get(), results, | |||
2487 | py_input_tangents.get(), to_batch)); | |||
2488 | tensorflow::Safe_PyObjectPtr py_result( | |||
2489 | PyObject_CallObject(forward_gradient_function, callback_args.get())); | |||
2490 | if (py_result == nullptr || PyErr_Occurred()) { | |||
2491 | return tensorflow::errors::Internal( | |||
2492 | "forward gradient function threw exceptions"); | |||
2493 | } | |||
2494 | return ParseTangentOutputs(py_result.get(), output_tangents); | |||
2495 | } | |||
2496 | ||||
2497 | // Like CallJVPFunction, but calls a pre-bound forward function. | |||
2498 | // These are passed in from a record_gradient argument. | |||
2499 | tensorflow::Status CallOpSpecificJVPFunction( | |||
2500 | PyObject* op_specific_forward_function, | |||
2501 | const std::vector<PyObject*>& input_tangents, | |||
2502 | std::vector<PyObject*>* output_tangents) { | |||
2503 | tensorflow::Safe_PyObjectPtr py_input_tangents( | |||
2504 | TangentsAsPyTuple(input_tangents)); | |||
2505 | ||||
2506 | tensorflow::Safe_PyObjectPtr py_result(PyObject_CallObject( | |||
2507 | op_specific_forward_function, py_input_tangents.get())); | |||
2508 | if (py_result == nullptr || PyErr_Occurred()) { | |||
2509 | return tensorflow::errors::Internal( | |||
2510 | "forward gradient function threw exceptions"); | |||
2511 | } | |||
2512 | return ParseTangentOutputs(py_result.get(), output_tangents); | |||
2513 | } | |||
2514 | ||||
2515 | bool ParseOpTypeString(PyObject* op_type, string* op_type_string) { | |||
2516 | if (PyBytes_Check(op_type)((((((PyObject*)(op_type))->ob_type))->tp_flags & ( (1UL << 27))) != 0)) { | |||
2517 | *op_type_string = PyBytes_AsString(op_type); | |||
2518 | } else if (PyUnicode_Check(op_type)((((((PyObject*)(op_type))->ob_type))->tp_flags & ( (1UL << 28))) != 0)) { | |||
2519 | #if PY_MAJOR_VERSION3 >= 3 | |||
2520 | *op_type_string = PyUnicode_AsUTF8(op_type); | |||
2521 | #else | |||
2522 | PyObject* py_str = PyUnicode_AsUTF8String(op_type); | |||
2523 | if (py_str == nullptr) { | |||
2524 | return false; | |||
2525 | } | |||
2526 | *op_type_string = PyBytes_AS_STRING(py_str)((static_cast<void> (0)), (((PyBytesObject *)(py_str))-> ob_sval)); | |||
2527 | Py_DECREF(py_str)_Py_DECREF(((PyObject*)(py_str))); | |||
2528 | #endif | |||
2529 | } else { | |||
2530 | PyErr_SetString(PyExc_RuntimeError, "op_type should be a string."); | |||
2531 | return false; | |||
2532 | } | |||
2533 | return true; | |||
2534 | } | |||
2535 | ||||
2536 | bool TapeSetRecordOperation( | |||
2537 | PyObject* op_type, PyObject* input_tensors, PyObject* output_tensors, | |||
2538 | const std::vector<int64_t>& input_ids, | |||
2539 | const std::vector<tensorflow::DataType>& input_dtypes, | |||
2540 | const std::function<PyBackwardFunction*()>& backward_function_getter, | |||
2541 | const std::function<void(PyBackwardFunction*)>& backward_function_killer, | |||
2542 | const tensorflow::eager::ForwardFunction<PyObject>* forward_function) { | |||
2543 | std::vector<PyTapeTensor> output_info; | |||
2544 | tensorflow::Safe_PyObjectPtr output_seq(PySequence_Fast( | |||
2545 | output_tensors, "expected a sequence of integer tensor ids")); | |||
2546 | if (PyErr_Occurred() || | |||
2547 | !TapeTensorsFromTensorSequence(output_seq.get(), &output_info)) { | |||
2548 | return false; | |||
2549 | } | |||
2550 | string op_type_str; | |||
2551 | if (!ParseOpTypeString(op_type, &op_type_str)) { | |||
2552 | return false; | |||
2553 | } | |||
2554 | tensorflow::uint64 max_gradient_tape_id; | |||
2555 | if (!TapeSetRecordForwardprop( | |||
2556 | op_type_str, output_seq.get(), output_info, input_tensors, input_ids, | |||
2557 | input_dtypes, backward_function_getter, backward_function_killer, | |||
2558 | forward_function, nullptr /* No special-cased jvps. */, | |||
2559 | &max_gradient_tape_id)) { | |||
2560 | return false; | |||
2561 | } | |||
2562 | TapeSetRecordBackprop(op_type_str, output_info, input_ids, input_dtypes, | |||
2563 | backward_function_getter, backward_function_killer, | |||
2564 | max_gradient_tape_id); | |||
2565 | return true; | |||
2566 | } | |||
2567 | } // namespace | |||
2568 | ||||
2569 | PyObject* TFE_Py_TapeSetRecordOperation(PyObject* op_type, | |||
2570 | PyObject* output_tensors, | |||
2571 | PyObject* input_tensors, | |||
2572 | PyObject* backward_function, | |||
2573 | PyObject* forward_function) { | |||
2574 | if (!HasAccumulatorOrTape() || *ThreadTapeIsStopped()) { | |||
2575 | Py_RETURN_NONEreturn _Py_INCREF(((PyObject*)((&_Py_NoneStruct)))), (& _Py_NoneStruct); | |||
2576 | } | |||
2577 | std::vector<int64_t> input_ids = MakeTensorIDList(input_tensors); | |||
2578 | if (PyErr_Occurred()) return nullptr; | |||
2579 | ||||
2580 | std::vector<tensorflow::DataType> input_dtypes = | |||
2581 | MakeTensorDtypeList(input_tensors); | |||
2582 | if (PyErr_Occurred()) return nullptr; | |||
2583 | ||||
2584 | std::function<PyBackwardFunction*()> backward_function_getter( | |||
2585 | [backward_function]() { | |||
2586 | Py_INCREF(backward_function)_Py_INCREF(((PyObject*)(backward_function))); | |||
2587 | PyBackwardFunction* function = new PyBackwardFunction( | |||
2588 | [backward_function](PyObject* out_grads, | |||
2589 | const std::vector<int64_t>& unused) { | |||
2590 | return PyObject_CallObject(backward_function, out_grads); | |||
2591 | }); | |||
2592 | return function; | |||
2593 | }); | |||
2594 | std::function<void(PyBackwardFunction*)> backward_function_killer( | |||
2595 | [backward_function](PyBackwardFunction* py_backward_function) { | |||
2596 | Py_DECREF(backward_function)_Py_DECREF(((PyObject*)(backward_function))); | |||
2597 | delete py_backward_function; | |||
2598 | }); | |||
2599 | ||||
2600 | if (forward_function == Py_None(&_Py_NoneStruct)) { | |||
2601 | if (!TapeSetRecordOperation( | |||
2602 | op_type, input_tensors, output_tensors, input_ids, input_dtypes, | |||
2603 | backward_function_getter, backward_function_killer, | |||
2604 | nullptr /* No special-cased forward function */)) { | |||
2605 | return nullptr; | |||
2606 | } | |||
2607 | } else { | |||
2608 | tensorflow::eager::ForwardFunction<PyObject> wrapped_forward_function( | |||
2609 | [forward_function](const std::vector<PyObject*>& input_tangents, | |||
2610 | std::vector<PyObject*>* output_tangents, | |||
2611 | bool use_batch = false) { | |||
2612 | return CallOpSpecificJVPFunction(forward_function, input_tangents, | |||
2613 | output_tangents); | |||
2614 | }); | |||
2615 | if (!TapeSetRecordOperation( | |||
2616 | op_type, input_tensors, output_tensors, input_ids, input_dtypes, | |||
2617 | backward_function_getter, backward_function_killer, | |||
2618 | &wrapped_forward_function)) { | |||
2619 | return nullptr; | |||
2620 | } | |||
2621 | } | |||
2622 | Py_RETURN_NONEreturn _Py_INCREF(((PyObject*)((&_Py_NoneStruct)))), (& _Py_NoneStruct); | |||
2623 | } | |||
2624 | ||||
2625 | PyObject* TFE_Py_TapeSetRecordOperationForwardprop( | |||
2626 | PyObject* op_type, PyObject* output_tensors, PyObject* input_tensors, | |||
2627 | PyObject* backward_function, PyObject* forwardprop_output_indices) { | |||
2628 | if (!HasAccumulator() || *ThreadTapeIsStopped()) { | |||
2629 | Py_RETURN_NONEreturn _Py_INCREF(((PyObject*)((&_Py_NoneStruct)))), (& _Py_NoneStruct); | |||
2630 | } | |||
2631 | std::vector<int64_t> input_ids = MakeTensorIDList(input_tensors); | |||
2632 | if (PyErr_Occurred()) return nullptr; | |||
2633 | ||||
2634 | std::vector<tensorflow::DataType> input_dtypes = | |||
2635 | MakeTensorDtypeList(input_tensors); | |||
2636 | if (PyErr_Occurred()) return nullptr; | |||
2637 | ||||
2638 | std::function<PyBackwardFunction*()> backward_function_getter( | |||
2639 | [backward_function]() { | |||
2640 | Py_INCREF(backward_function)_Py_INCREF(((PyObject*)(backward_function))); | |||
2641 | PyBackwardFunction* function = new PyBackwardFunction( | |||
2642 | [backward_function](PyObject* out_grads, | |||
2643 | const std::vector<int64_t>& unused) { | |||
2644 | return PyObject_CallObject(backward_function, out_grads); | |||
2645 | }); | |||
2646 | return function; | |||
2647 | }); | |||
2648 | std::function<void(PyBackwardFunction*)> backward_function_killer( | |||
2649 | [backward_function](PyBackwardFunction* py_backward_function) { | |||
2650 | Py_DECREF(backward_function)_Py_DECREF(((PyObject*)(backward_function))); | |||
2651 | delete py_backward_function; | |||
2652 | }); | |||
2653 | std::vector<PyTapeTensor> output_info; | |||
2654 | tensorflow::Safe_PyObjectPtr output_seq(PySequence_Fast( | |||
2655 | output_tensors, "expected a sequence of integer tensor ids")); | |||
2656 | if (PyErr_Occurred() || | |||
2657 | !TapeTensorsFromTensorSequence(output_seq.get(), &output_info)) { | |||
2658 | return nullptr; | |||
2659 | } | |||
2660 | string op_type_str; | |||
2661 | if (!ParseOpTypeString(op_type, &op_type_str)) { | |||
2662 | return nullptr; | |||
2663 | } | |||
2664 | tensorflow::uint64 max_gradient_tape_id; | |||
2665 | if (!TapeSetRecordForwardprop( | |||
2666 | op_type_str, output_seq.get(), output_info, input_tensors, input_ids, | |||
2667 | input_dtypes, backward_function_getter, backward_function_killer, | |||
2668 | nullptr /* no special-cased forward function */, | |||
2669 | forwardprop_output_indices, &max_gradient_tape_id)) { | |||
2670 | return nullptr; | |||
2671 | } | |||
2672 | Py_RETURN_NONEreturn _Py_INCREF(((PyObject*)((&_Py_NoneStruct)))), (& _Py_NoneStruct); | |||
2673 | } | |||
2674 | ||||
2675 | PyObject* TFE_Py_TapeSetRecordOperationBackprop(PyObject* op_type, | |||
2676 | PyObject* output_tensors, | |||
2677 | PyObject* input_tensors, | |||
2678 | PyObject* backward_function) { | |||
2679 | if (!CouldBackprop()) { | |||
2680 | Py_RETURN_NONEreturn _Py_INCREF(((PyObject*)((&_Py_NoneStruct)))), (& _Py_NoneStruct); | |||
2681 | } | |||
2682 | std::vector<int64_t> input_ids = MakeTensorIDList(input_tensors); | |||
2683 | if (PyErr_Occurred()) return nullptr; | |||
2684 | ||||
2685 | std::vector<tensorflow::DataType> input_dtypes = | |||
2686 | MakeTensorDtypeList(input_tensors); | |||
2687 | if (PyErr_Occurred()) return nullptr; | |||
2688 | ||||
2689 | std::function<PyBackwardFunction*()> backward_function_getter( | |||
2690 | [backward_function]() { | |||
2691 | Py_INCREF(backward_function)_Py_INCREF(((PyObject*)(backward_function))); | |||
2692 | PyBackwardFunction* function = new PyBackwardFunction( | |||
2693 | [backward_function](PyObject* out_grads, | |||
2694 | const std::vector<int64_t>& unused) { | |||
2695 | return PyObject_CallObject(backward_function, out_grads); | |||
2696 | }); | |||
2697 | return function; | |||
2698 | }); | |||
2699 | std::function<void(PyBackwardFunction*)> backward_function_killer( | |||
2700 | [backward_function](PyBackwardFunction* py_backward_function) { | |||
2701 | Py_DECREF(backward_function)_Py_DECREF(((PyObject*)(backward_function))); | |||
2702 | delete py_backward_function; | |||
2703 | }); | |||
2704 | std::vector<PyTapeTensor> output_info; | |||
2705 | tensorflow::Safe_PyObjectPtr output_seq(PySequence_Fast( | |||
2706 | output_tensors, "expected a sequence of integer tensor ids")); | |||
2707 | if (PyErr_Occurred() || | |||
2708 | !TapeTensorsFromTensorSequence(output_seq.get(), &output_info)) { | |||
2709 | return nullptr; | |||
2710 | } | |||
2711 | string op_type_str; | |||
2712 | if (!ParseOpTypeString(op_type, &op_type_str)) { | |||
2713 | return nullptr; | |||
2714 | } | |||
2715 | TapeSetRecordBackprop(op_type_str, output_info, input_ids, input_dtypes, | |||
2716 | backward_function_getter, backward_function_killer, | |||
2717 | // No filtering based on relative ordering with forward | |||
2718 | // accumulators. | |||
2719 | std::numeric_limits<tensorflow::uint64>::max()); | |||
2720 | Py_RETURN_NONEreturn _Py_INCREF(((PyObject*)((&_Py_NoneStruct)))), (& _Py_NoneStruct); | |||
2721 | } | |||
2722 | ||||
2723 | void TFE_Py_TapeSetDeleteTrace(int64_t tensor_id) { | |||
2724 | for (TFE_Py_Tape* tape : *GetTapeSet()) { | |||
2725 | tape->tape->DeleteTrace(tensor_id); | |||
2726 | } | |||
2727 | } | |||
2728 | ||||
2729 | std::vector<PyObject*> MakeTensorList(PyObject* tensors) { | |||
2730 | PyObject* seq = PySequence_Fast(tensors, "expected a sequence"); | |||
2731 | if (seq == nullptr) { | |||
2732 | return {}; | |||
2733 | } | |||
2734 | int len = PySequence_Fast_GET_SIZE(seq)(((((((PyObject*)(seq))->ob_type))->tp_flags & ((1UL << 25))) != 0) ? ((static_cast<void> (0)), (((PyVarObject *)(seq))->ob_size)) : (((PyVarObject*)(((static_cast<void > (0)), (PyTupleObject *)(seq))))->ob_size)); | |||
2735 | PyObject** seq_array = PySequence_Fast_ITEMS(seq)(((((((PyObject*)(seq))->ob_type))->tp_flags & ((1UL << 25))) != 0) ? ((PyListObject *)(seq))->ob_item : ((PyTupleObject *)(seq))->ob_item); | |||
2736 | std::vector<PyObject*> list(seq_array, seq_array + len); | |||
2737 | Py_DECREF(seq)_Py_DECREF(((PyObject*)(seq))); | |||
2738 | return list; | |||
2739 | } | |||
2740 | ||||
2741 | PyObject* TFE_Py_TapeGradient(PyObject* tape, PyObject* target, | |||
2742 | PyObject* sources, PyObject* output_gradients, | |||
2743 | PyObject* sources_raw, | |||
2744 | PyObject* unconnected_gradients, | |||
2745 | TF_Status* status) { | |||
2746 | TFE_Py_Tape* tape_obj = reinterpret_cast<TFE_Py_Tape*>(tape); | |||
2747 | if (!tape_obj->tape->IsPersistent()) { | |||
2748 | auto* tape_set = GetTapeSet(); | |||
2749 | if (tape_set->find(tape_obj) != tape_set->end()) { | |||
2750 | PyErr_SetString(PyExc_RuntimeError, | |||
2751 | "gradient() cannot be invoked within the " | |||
2752 | "GradientTape context (i.e., while operations are being " | |||
2753 | "recorded). Either move the call to gradient() to be " | |||
2754 | "outside the 'with tf.GradientTape' block, or " | |||
2755 | "use a persistent tape: " | |||
2756 | "'with tf.GradientTape(persistent=true)'"); | |||
2757 | return nullptr; | |||
2758 | } | |||
2759 | } | |||
2760 | ||||
2761 | std::vector<int64_t> target_vec = MakeTensorIDList(target); | |||
2762 | if (PyErr_Occurred()) { | |||
2763 | return nullptr; | |||
2764 | } | |||
2765 | std::vector<int64_t> sources_vec = MakeTensorIDList(sources); | |||
2766 | if (PyErr_Occurred()) { | |||
2767 | return nullptr; | |||
2768 | } | |||
2769 | tensorflow::gtl::FlatSet<int64_t> sources_set(sources_vec.begin(), | |||
2770 | sources_vec.end()); | |||
2771 | ||||
2772 | tensorflow::Safe_PyObjectPtr seq = | |||
2773 | tensorflow::make_safe(PySequence_Fast(target, "expected a sequence")); | |||
2774 | int len = PySequence_Fast_GET_SIZE(seq.get())(((((((PyObject*)(seq.get()))->ob_type))->tp_flags & ((1UL << 25))) != 0) ? ((static_cast<void> (0)), (((PyVarObject*)(seq.get()))->ob_size)) : (((PyVarObject* )(((static_cast<void> (0)), (PyTupleObject *)(seq.get() ))))->ob_size)); | |||
2775 | PyObject** seq_array = PySequence_Fast_ITEMS(seq.get())(((((((PyObject*)(seq.get()))->ob_type))->tp_flags & ((1UL << 25))) != 0) ? ((PyListObject *)(seq.get()))-> ob_item : ((PyTupleObject *)(seq.get()))->ob_item); | |||
2776 | std::unordered_map<int64_t, PyTapeTensor> source_tensors_that_are_targets; | |||
2777 | for (int i = 0; i < len; ++i) { | |||
2778 | int64_t target_id = target_vec[i]; | |||
2779 | if (sources_set.find(target_id) != sources_set.end()) { | |||
2780 | auto tensor = seq_array[i]; | |||
2781 | source_tensors_that_are_targets.insert( | |||
2782 | std::make_pair(target_id, TapeTensorFromTensor(tensor))); | |||
2783 | } | |||
2784 | if (PyErr_Occurred()) { | |||
2785 | return nullptr; | |||
2786 | } | |||
2787 | } | |||
2788 | if (PyErr_Occurred()) { | |||
2789 | return nullptr; | |||
2790 | } | |||
2791 | ||||
2792 | std::vector<PyObject*> outgrad_vec; | |||
2793 | if (output_gradients != Py_None(&_Py_NoneStruct)) { | |||
2794 | outgrad_vec = MakeTensorList(output_gradients); | |||
2795 | if (PyErr_Occurred()) { | |||
2796 | return nullptr; | |||
2797 | } | |||
2798 | for (PyObject* tensor : outgrad_vec) { | |||
2799 | // Calling the backward function will eat a reference to the tensors in | |||
2800 | // outgrad_vec, so we need to increase their reference count. | |||
2801 | Py_INCREF(tensor)_Py_INCREF(((PyObject*)(tensor))); | |||
2802 | } | |||
2803 | } | |||
2804 | std::vector<PyObject*> result(sources_vec.size()); | |||
2805 | status->status = tape_obj->tape->ComputeGradient( | |||
2806 | *py_vspace, target_vec, sources_vec, source_tensors_that_are_targets, | |||
2807 | outgrad_vec, absl::MakeSpan(result)); | |||
2808 | if (!status->status.ok()) { | |||
2809 | if (PyErr_Occurred()) { | |||
2810 | // Do not propagate the erroneous status as that would swallow the | |||
2811 | // exception which caused the problem. | |||
2812 | status->status = tensorflow::Status::OK(); | |||
2813 | } | |||
2814 | return nullptr; | |||
2815 | } | |||
2816 | ||||
2817 | bool unconnected_gradients_zero = | |||
2818 | strcmp(TFE_GetPythonString(unconnected_gradients), "zero") == 0; | |||
2819 | std::vector<PyObject*> sources_obj; | |||
2820 | if (unconnected_gradients_zero) { | |||
2821 | // Uses the "raw" sources here so it can properly make a zeros tensor even | |||
2822 | // if there are resource variables as sources. | |||
2823 | sources_obj = MakeTensorList(sources_raw); | |||
2824 | } | |||
2825 | ||||
2826 | if (!result.empty()) { | |||
2827 | PyObject* py_result = PyList_New(result.size()); | |||
2828 | tensorflow::gtl::FlatSet<PyObject*> seen_results(result.size()); | |||
2829 | for (int i = 0; i < result.size(); ++i) { | |||
2830 | if (result[i] == nullptr) { | |||
2831 | if (unconnected_gradients_zero) { | |||
2832 | // generate a zeros tensor in the shape of sources[i] | |||
2833 | tensorflow::DataType dtype = | |||
2834 | tensorflow::PyTensor_DataType(sources_obj[i]); | |||
2835 | PyTapeTensor tensor = | |||
2836 | PyTapeTensor(sources_vec[i], dtype, sources_obj[i]); | |||
2837 | result[i] = tensor.ZerosLike(); | |||
2838 | } else { | |||
2839 | Py_INCREF(Py_None)_Py_INCREF(((PyObject*)((&_Py_NoneStruct)))); | |||
2840 | result[i] = Py_None(&_Py_NoneStruct); | |||
2841 | } | |||
2842 | } else if (seen_results.find(result[i]) != seen_results.end()) { | |||
2843 | Py_INCREF(result[i])_Py_INCREF(((PyObject*)(result[i]))); | |||
2844 | } | |||
2845 | seen_results.insert(result[i]); | |||
2846 | PyList_SET_ITEM(py_result, i, reinterpret_cast<PyObject*>(result[i]))PyList_SetItem(py_result, i, reinterpret_cast<PyObject*> (result[i])); | |||
2847 | } | |||
2848 | return py_result; | |||
2849 | } | |||
2850 | return PyList_New(0); | |||
2851 | } | |||
2852 | ||||
2853 | PyObject* TFE_Py_ForwardAccumulatorNew(bool use_batch) { | |||
2854 | TFE_Py_ForwardAccumulator_Type.tp_new = PyType_GenericNew; | |||
2855 | if (PyType_Ready(&TFE_Py_ForwardAccumulator_Type) < 0) return nullptr; | |||
2856 | TFE_Py_ForwardAccumulator* accumulator = | |||
2857 | PyObject_NEW(TFE_Py_ForwardAccumulator, &TFE_Py_ForwardAccumulator_Type)( (TFE_Py_ForwardAccumulator *) PyObject_Init( (PyObject *) PyObject_Malloc ( ( (&TFE_Py_ForwardAccumulator_Type)->tp_basicsize ) ) , (&TFE_Py_ForwardAccumulator_Type)) ); | |||
2858 | if (py_vspace == nullptr) { | |||
2859 | MaybeRaiseExceptionFromStatus( | |||
2860 | tensorflow::errors::Internal( | |||
2861 | "ForwardAccumulator requires a PyVSpace to be registered."), | |||
2862 | nullptr); | |||
2863 | } | |||
2864 | accumulator->accumulator = new ForwardAccumulator(*py_vspace, use_batch); | |||
2865 | return reinterpret_cast<PyObject*>(accumulator); | |||
2866 | } | |||
2867 | ||||
2868 | PyObject* TFE_Py_ForwardAccumulatorSetAdd(PyObject* accumulator) { | |||
2869 | TFE_Py_ForwardAccumulator* c_accumulator( | |||
2870 | reinterpret_cast<TFE_Py_ForwardAccumulator*>(accumulator)); | |||
2871 | c_accumulator->nesting_id = tape_nesting_id_counter.fetch_add(1); | |||
2872 | if (GetAccumulatorSet()->insert(c_accumulator)) { | |||
2873 | Py_INCREF(accumulator)_Py_INCREF(((PyObject*)(accumulator))); | |||
2874 | Py_RETURN_NONEreturn _Py_INCREF(((PyObject*)((&_Py_NoneStruct)))), (& _Py_NoneStruct); | |||
2875 | } else { | |||
2876 | MaybeRaiseExceptionFromStatus( | |||
2877 | tensorflow::errors::Internal( | |||
2878 | "A ForwardAccumulator was added to the active set twice."), | |||
2879 | nullptr); | |||
2880 | return nullptr; | |||
2881 | } | |||
2882 | } | |||
2883 | ||||
2884 | void TFE_Py_ForwardAccumulatorSetRemove(PyObject* accumulator) { | |||
2885 | GetAccumulatorSet()->erase( | |||
2886 | reinterpret_cast<TFE_Py_ForwardAccumulator*>(accumulator)); | |||
2887 | Py_DECREF(accumulator)_Py_DECREF(((PyObject*)(accumulator))); | |||
2888 | } | |||
2889 | ||||
2890 | void TFE_Py_ForwardAccumulatorWatch(PyObject* accumulator, PyObject* tensor, | |||
2891 | PyObject* tangent) { | |||
2892 | int64_t tensor_id = FastTensorId(tensor); | |||
2893 | reinterpret_cast<TFE_Py_ForwardAccumulator*>(accumulator) | |||
2894 | ->accumulator->Watch(tensor_id, tangent); | |||
2895 | RegisterForwardAccumulatorCleanup(tensor, tensor_id); | |||
2896 | } | |||
2897 | ||||
2898 | // Returns a new reference to the JVP Tensor. | |||
2899 | PyObject* TFE_Py_ForwardAccumulatorJVP(PyObject* accumulator, | |||
2900 | PyObject* tensor) { | |||
2901 | PyObject* jvp = reinterpret_cast<TFE_Py_ForwardAccumulator*>(accumulator) | |||
2902 | ->accumulator->FetchJVP(FastTensorId(tensor)); | |||
2903 | if (jvp == nullptr) { | |||
2904 | jvp = Py_None(&_Py_NoneStruct); | |||
2905 | } | |||
2906 | Py_INCREF(jvp)_Py_INCREF(((PyObject*)(jvp))); | |||
2907 | return jvp; | |||
2908 | } | |||
2909 | ||||
2910 | PyObject* TFE_Py_PackJVPs(PyObject* tensors) { | |||
2911 | if (!TapeCouldPossiblyRecord(tensors)) { | |||
2912 | tensorflow::Safe_PyObjectPtr empty_tuple(PyTuple_New(0)); | |||
2913 | tensorflow::Safe_PyObjectPtr empty_list(PyList_New(0)); | |||
2914 | return PyTuple_Pack(2, empty_tuple.get(), empty_list.get()); | |||
2915 | } | |||
2916 | auto accumulators = *GetAccumulatorSet(); | |||
2917 | tensorflow::Safe_PyObjectPtr tensors_fast( | |||
2918 | PySequence_Fast(tensors, "Expected a sequence of input Tensors.")); | |||
2919 | if (tensors_fast == nullptr || PyErr_Occurred()) { | |||
2920 | return nullptr; | |||
2921 | } | |||
2922 | std::vector<int64_t> augmented_input_ids; | |||
2923 | int len = PySequence_Fast_GET_SIZE(tensors_fast.get())(((((((PyObject*)(tensors_fast.get()))->ob_type))->tp_flags & ((1UL << 25))) != 0) ? ((static_cast<void> (0)), (((PyVarObject*)(tensors_fast.get()))->ob_size)) : ( ((PyVarObject*)(((static_cast<void> (0)), (PyTupleObject *)(tensors_fast.get()))))->ob_size)); | |||
2924 | PyObject** tensors_fast_array = PySequence_Fast_ITEMS(tensors_fast.get())(((((((PyObject*)(tensors_fast.get()))->ob_type))->tp_flags & ((1UL << 25))) != 0) ? ((PyListObject *)(tensors_fast .get()))->ob_item : ((PyTupleObject *)(tensors_fast.get()) )->ob_item); | |||
2925 | for (Py_ssize_t position = 0; position < len; ++position) { | |||
2926 | PyObject* input = tensors_fast_array[position]; | |||
2927 | if (input == Py_None(&_Py_NoneStruct)) { | |||
2928 | continue; | |||
2929 | } | |||
2930 | tensorflow::DataType input_dtype(tensorflow::PyTensor_DataType(input)); | |||
2931 | if (input_dtype == tensorflow::DT_INVALID) { | |||
2932 | return nullptr; | |||
2933 | } | |||
2934 | augmented_input_ids.push_back(FastTensorId(input)); | |||
2935 | } | |||
2936 | if (PyErr_Occurred()) { | |||
2937 | return nullptr; | |||
2938 | } | |||
2939 | // Find the innermost accumulator such that all outer accumulators are | |||
2940 | // recording. Any more deeply nested accumulators will not have their JVPs | |||
2941 | // saved. | |||
2942 | AccumulatorSet::const_iterator innermost_all_recording = accumulators.begin(); | |||
2943 | for (; innermost_all_recording != accumulators.end(); | |||
2944 | ++innermost_all_recording) { | |||
2945 | if ((*innermost_all_recording)->accumulator->BusyAccumulating()) { | |||
2946 | break; | |||
2947 | } | |||
2948 | } | |||
2949 | AccumulatorSet::const_reverse_iterator reverse_innermost_all_recording( | |||
2950 | innermost_all_recording); | |||
2951 | ||||
2952 | bool saving_jvps = false; | |||
2953 | tensorflow::Safe_PyObjectPtr all_indices(PyTuple_New(accumulators.size())); | |||
2954 | std::vector<PyObject*> new_tensors; | |||
2955 | Py_ssize_t accumulator_index = 0; | |||
2956 | // Start with the innermost accumulators to give outer accumulators a chance | |||
2957 | // to find their higher-order JVPs. | |||
2958 | for (AccumulatorSet::const_reverse_iterator it = accumulators.rbegin(); | |||
2959 | it != accumulators.rend(); ++it, ++accumulator_index) { | |||
2960 | std::vector<int64_t> new_input_ids; | |||
2961 | std::vector<std::pair<int64_t, int64_t>> accumulator_indices; | |||
2962 | if (it == reverse_innermost_all_recording) { | |||
2963 | saving_jvps = true; | |||
2964 | } | |||
2965 | if (saving_jvps) { | |||
2966 | for (int input_index = 0; input_index < augmented_input_ids.size(); | |||
2967 | ++input_index) { | |||
2968 | int64_t existing_input = augmented_input_ids[input_index]; | |||
2969 | PyObject* jvp = (*it)->accumulator->FetchJVP(existing_input); | |||
2970 | if (jvp != nullptr) { | |||
2971 | new_tensors.push_back(jvp); | |||
2972 | new_input_ids.push_back(FastTensorId(jvp)); | |||
2973 | accumulator_indices.emplace_back( | |||
2974 | input_index, | |||
2975 | augmented_input_ids.size() + new_input_ids.size() - 1); | |||
2976 | } | |||
2977 | } | |||
2978 | } | |||
2979 | tensorflow::Safe_PyObjectPtr accumulator_indices_py( | |||
2980 | PyTuple_New(accumulator_indices.size())); | |||
2981 | for (int i = 0; i < accumulator_indices.size(); ++i) { | |||
2982 | tensorflow::Safe_PyObjectPtr from_index( | |||
2983 | GetPythonObjectFromInt(accumulator_indices[i].first)); | |||
2984 | tensorflow::Safe_PyObjectPtr to_index( | |||
2985 | GetPythonObjectFromInt(accumulator_indices[i].second)); | |||
2986 | PyTuple_SetItem(accumulator_indices_py.get(), i, | |||
2987 | PyTuple_Pack(2, from_index.get(), to_index.get())); | |||
2988 | } | |||
2989 | PyTuple_SetItem(all_indices.get(), accumulator_index, | |||
2990 | accumulator_indices_py.release()); | |||
2991 | augmented_input_ids.insert(augmented_input_ids.end(), new_input_ids.begin(), | |||
2992 | new_input_ids.end()); | |||
2993 | } | |||
2994 | ||||
2995 | tensorflow::Safe_PyObjectPtr new_tensors_py(PyList_New(new_tensors.size())); | |||
2996 | for (int i = 0; i < new_tensors.size(); ++i) { | |||
2997 | PyObject* jvp = new_tensors[i]; | |||
2998 | Py_INCREF(jvp)_Py_INCREF(((PyObject*)(jvp))); | |||
2999 | PyList_SET_ITEM(new_tensors_py.get(), i, jvp)PyList_SetItem(new_tensors_py.get(), i, jvp); | |||
3000 | } | |||
3001 | return PyTuple_Pack(2, all_indices.get(), new_tensors_py.get()); | |||
3002 | } | |||
3003 | ||||
3004 | namespace { | |||
3005 | ||||
// Positions of the fixed leading entries in the "args" tuple passed to
// TFE_Py_FastPathExecute_C. Op inputs start at
// FAST_PATH_EXECUTE_ARG_INPUT_START.
enum FastPathExecuteArgIndex {
  FAST_PATH_EXECUTE_ARG_CONTEXT = 0,      // args[0]
  FAST_PATH_EXECUTE_ARG_OP_NAME = 1,      // args[1]
  FAST_PATH_EXECUTE_ARG_NAME = 2,         // args[2]
  FAST_PATH_EXECUTE_ARG_INPUT_START = 3,  // args[3] onward: op inputs
};
3013 | ||||
3014 | PyObject* GetPythonObjectFromString(tensorflow::StringPiece s) { | |||
3015 | #if PY_MAJOR_VERSION3 >= 3 | |||
3016 | return PyUnicode_FromStringAndSize(s.data(), s.size()); | |||
3017 | #else | |||
3018 | return PyBytes_FromStringAndSize(s.data(), s.size()); | |||
3019 | #endif | |||
3020 | } | |||
3021 | ||||
3022 | bool CheckResourceVariable(PyObject* item) { | |||
3023 | if (tensorflow::swig::IsResourceVariable(item)) { | |||
3024 | tensorflow::Safe_PyObjectPtr handle( | |||
3025 | PyObject_GetAttrString(item, "_handle")); | |||
3026 | return EagerTensor_CheckExact(handle.get()); | |||
3027 | } | |||
3028 | ||||
3029 | return false; | |||
3030 | } | |||
3031 | ||||
3032 | bool IsNumberType(PyObject* item) { | |||
3033 | #if PY_MAJOR_VERSION3 >= 3 | |||
3034 | return PyFloat_Check(item)((((PyObject*)(item))->ob_type) == (&PyFloat_Type) || PyType_IsSubtype ((((PyObject*)(item))->ob_type), (&PyFloat_Type))) || PyLong_Check(item)((((((PyObject*)(item))->ob_type))->tp_flags & ((1UL << 24))) != 0); | |||
3035 | #else | |||
3036 | return PyFloat_Check(item)((((PyObject*)(item))->ob_type) == (&PyFloat_Type) || PyType_IsSubtype ((((PyObject*)(item))->ob_type), (&PyFloat_Type))) || PyInt_Check(item) || PyLong_Check(item)((((((PyObject*)(item))->ob_type))->tp_flags & ((1UL << 24))) != 0); | |||
3037 | #endif | |||
3038 | } | |||
3039 | ||||
3040 | bool CheckOneInput(PyObject* item) { | |||
3041 | if (EagerTensor_CheckExact(item) || CheckResourceVariable(item) || | |||
3042 | PyArray_Check(item)((((PyObject*)(item))->ob_type) == (&(*(PyTypeObject * )_tensorflow_numpy_api[2])) || PyType_IsSubtype((((PyObject*) (item))->ob_type), (&(*(PyTypeObject *)_tensorflow_numpy_api [2])))) || IsNumberType(item)) { | |||
3043 | return true; | |||
3044 | } | |||
3045 | ||||
3046 | // Sequences are not properly handled. Sequences with purely python numeric | |||
3047 | // types work, but sequences with mixes of EagerTensors and python numeric | |||
3048 | // types don't work. | |||
3049 | // TODO(nareshmodi): fix | |||
3050 | return false; | |||
3051 | } | |||
3052 | ||||
3053 | bool CheckInputsOk(PyObject* seq, int start_index, | |||
3054 | const tensorflow::OpDef& op_def) { | |||
3055 | for (int i = 0; i < op_def.input_arg_size(); i++) { | |||
3056 | PyObject* item = PyTuple_GET_ITEM(seq, i + start_index)(((static_cast<void> (0)), (PyTupleObject *)(seq))-> ob_item[i + start_index]); | |||
3057 | if (!op_def.input_arg(i).number_attr().empty() || | |||
3058 | !op_def.input_arg(i).type_list_attr().empty()) { | |||
3059 | // This item should be a seq input. | |||
3060 | if (!PySequence_Check(item)) { | |||
3061 | VLOG(1)(__builtin_expect(!!(!(([](int level, const char* fname) { static const bool vmodule_activated = ::tensorflow::internal::LogMessage ::VmoduleActivated(fname, level); return vmodule_activated; } )(1, "tensorflow/python/eager/pywrap_tfe_src.cc"))), 1)) ? (void )0 : ::tensorflow::internal::Voidifier() & ::tensorflow:: internal::LogMessage("tensorflow/python/eager/pywrap_tfe_src.cc" , 3061, tensorflow::INFO) << "Falling back to slow path for Op \"" << op_def.name() | |||
3062 | << "\", Input \"" << op_def.input_arg(i).name() | |||
3063 | << "\" since we expected a sequence, but got " | |||
3064 | << item->ob_type->tp_name; | |||
3065 | return false; | |||
3066 | } | |||
3067 | tensorflow::Safe_PyObjectPtr fast_item( | |||
3068 | PySequence_Fast(item, "Could not parse sequence.")); | |||
3069 | if (fast_item.get() == nullptr) { | |||
3070 | return false; | |||
3071 | } | |||
3072 | int len = PySequence_Fast_GET_SIZE(fast_item.get())(((((((PyObject*)(fast_item.get()))->ob_type))->tp_flags & ((1UL << 25))) != 0) ? ((static_cast<void> (0)), (((PyVarObject*)(fast_item.get()))->ob_size)) : ((( PyVarObject*)(((static_cast<void> (0)), (PyTupleObject * )(fast_item.get()))))->ob_size)); | |||
3073 | PyObject** fast_item_array = PySequence_Fast_ITEMS(fast_item.get())(((((((PyObject*)(fast_item.get()))->ob_type))->tp_flags & ((1UL << 25))) != 0) ? ((PyListObject *)(fast_item .get()))->ob_item : ((PyTupleObject *)(fast_item.get()))-> ob_item); | |||
3074 | for (Py_ssize_t j = 0; j < len; j++) { | |||
3075 | PyObject* inner_item = fast_item_array[j]; | |||
3076 | if (!CheckOneInput(inner_item)) { | |||
3077 | VLOG(1)(__builtin_expect(!!(!(([](int level, const char* fname) { static const bool vmodule_activated = ::tensorflow::internal::LogMessage ::VmoduleActivated(fname, level); return vmodule_activated; } )(1, "tensorflow/python/eager/pywrap_tfe_src.cc"))), 1)) ? (void )0 : ::tensorflow::internal::Voidifier() & ::tensorflow:: internal::LogMessage("tensorflow/python/eager/pywrap_tfe_src.cc" , 3077, tensorflow::INFO) << "Falling back to slow path for Op \"" << op_def.name() | |||
3078 | << "\", Input \"" << op_def.input_arg(i).name() | |||
3079 | << "\", Index " << j | |||
3080 | << " since we expected an EagerTensor/ResourceVariable, " | |||
3081 | "but got " | |||
3082 | << inner_item->ob_type->tp_name; | |||
3083 | return false; | |||
3084 | } | |||
3085 | } | |||
3086 | } else if (!CheckOneInput(item)) { | |||
3087 | VLOG(1)(__builtin_expect(!!(!(([](int level, const char* fname) { static const bool vmodule_activated = ::tensorflow::internal::LogMessage ::VmoduleActivated(fname, level); return vmodule_activated; } )(1, "tensorflow/python/eager/pywrap_tfe_src.cc"))), 1)) ? (void )0 : ::tensorflow::internal::Voidifier() & ::tensorflow:: internal::LogMessage("tensorflow/python/eager/pywrap_tfe_src.cc" , 3087, tensorflow::INFO) | |||
3088 | << "Falling back to slow path for Op \"" << op_def.name() | |||
3089 | << "\", Input \"" << op_def.input_arg(i).name() | |||
3090 | << "\" since we expected an EagerTensor/ResourceVariable, but got " | |||
3091 | << item->ob_type->tp_name; | |||
3092 | return false; | |||
3093 | } | |||
3094 | } | |||
3095 | ||||
3096 | return true; | |||
3097 | } | |||
3098 | ||||
3099 | tensorflow::DataType MaybeGetDType(PyObject* item) { | |||
3100 | if (EagerTensor_CheckExact(item) || CheckResourceVariable(item)) { | |||
3101 | return tensorflow::PyTensor_DataType(item); | |||
3102 | } | |||
3103 | ||||
3104 | return tensorflow::DT_INVALID; | |||
3105 | } | |||
3106 | ||||
3107 | tensorflow::DataType MaybeGetDTypeForAttr(const string& attr, | |||
3108 | FastPathOpExecInfo* op_exec_info) { | |||
3109 | auto cached_it = op_exec_info->cached_dtypes.find(attr); | |||
3110 | if (cached_it != op_exec_info->cached_dtypes.end()) { | |||
3111 | return cached_it->second; | |||
3112 | } | |||
3113 | ||||
3114 | auto it = op_exec_info->attr_to_inputs_map->find(attr); | |||
3115 | if (it == op_exec_info->attr_to_inputs_map->end()) { | |||
3116 | // No other inputs - this should never happen. | |||
3117 | return tensorflow::DT_INVALID; | |||
3118 | } | |||
3119 | ||||
3120 | for (const auto& input_info : it->second) { | |||
3121 | PyObject* item = PyTuple_GET_ITEM((((static_cast<void> (0)), (PyTupleObject *)(op_exec_info ->args))->ob_item[FAST_PATH_EXECUTE_ARG_INPUT_START + input_info .i]) | |||
3122 | op_exec_info->args, FAST_PATH_EXECUTE_ARG_INPUT_START + input_info.i)(((static_cast<void> (0)), (PyTupleObject *)(op_exec_info ->args))->ob_item[FAST_PATH_EXECUTE_ARG_INPUT_START + input_info .i]); | |||
3123 | if (input_info.is_list) { | |||
3124 | tensorflow::Safe_PyObjectPtr fast_item( | |||
3125 | PySequence_Fast(item, "Unable to allocate")); | |||
3126 | int len = PySequence_Fast_GET_SIZE(fast_item.get())(((((((PyObject*)(fast_item.get()))->ob_type))->tp_flags & ((1UL << 25))) != 0) ? ((static_cast<void> (0)), (((PyVarObject*)(fast_item.get()))->ob_size)) : ((( PyVarObject*)(((static_cast<void> (0)), (PyTupleObject * )(fast_item.get()))))->ob_size)); | |||
3127 | PyObject** fast_item_array = PySequence_Fast_ITEMS(fast_item.get())(((((((PyObject*)(fast_item.get()))->ob_type))->tp_flags & ((1UL << 25))) != 0) ? ((PyListObject *)(fast_item .get()))->ob_item : ((PyTupleObject *)(fast_item.get()))-> ob_item); | |||
3128 | for (int i = 0; i < len; i++) { | |||
3129 | auto dtype = MaybeGetDType(fast_item_array[i]); | |||
3130 | if (dtype != tensorflow::DT_INVALID) return dtype; | |||
3131 | } | |||
3132 | } else { | |||
3133 | auto dtype = MaybeGetDType(item); | |||
3134 | if (dtype != tensorflow::DT_INVALID) return dtype; | |||
3135 | } | |||
3136 | } | |||
3137 | ||||
3138 | auto default_it = op_exec_info->default_dtypes->find(attr); | |||
3139 | if (default_it != op_exec_info->default_dtypes->end()) { | |||
3140 | return default_it->second; | |||
3141 | } | |||
3142 | ||||
3143 | return tensorflow::DT_INVALID; | |||
3144 | } | |||
3145 | ||||
3146 | PyObject* CopySequenceSettingIndicesToNull( | |||
3147 | PyObject* seq, const tensorflow::gtl::FlatSet<int>& indices) { | |||
3148 | tensorflow::Safe_PyObjectPtr fast_seq( | |||
3149 | PySequence_Fast(seq, "unable to allocate")); | |||
3150 | int len = PySequence_Fast_GET_SIZE(fast_seq.get())(((((((PyObject*)(fast_seq.get()))->ob_type))->tp_flags & ((1UL << 25))) != 0) ? ((static_cast<void> (0)), (((PyVarObject*)(fast_seq.get()))->ob_size)) : (((PyVarObject *)(((static_cast<void> (0)), (PyTupleObject *)(fast_seq .get()))))->ob_size)); | |||
3151 | PyObject** fast_seq_array = PySequence_Fast_ITEMS(fast_seq.get())(((((((PyObject*)(fast_seq.get()))->ob_type))->tp_flags & ((1UL << 25))) != 0) ? ((PyListObject *)(fast_seq .get()))->ob_item : ((PyTupleObject *)(fast_seq.get()))-> ob_item); | |||
3152 | PyObject* result = PyTuple_New(len); | |||
3153 | for (int i = 0; i < len; i++) { | |||
3154 | PyObject* item; | |||
3155 | if (indices.find(i) != indices.end()) { | |||
3156 | item = Py_None(&_Py_NoneStruct); | |||
3157 | } else { | |||
3158 | item = fast_seq_array[i]; | |||
3159 | } | |||
3160 | Py_INCREF(item)_Py_INCREF(((PyObject*)(item))); | |||
3161 | PyTuple_SET_ITEM(result, i, item)PyTuple_SetItem(result, i, item); | |||
3162 | } | |||
3163 | return result; | |||
3164 | } | |||
3165 | ||||
3166 | PyObject* RecordGradient(PyObject* op_name, PyObject* inputs, PyObject* attrs, | |||
3167 | PyObject* results, | |||
3168 | PyObject* forward_pass_name_scope = nullptr) { | |||
3169 | std::vector<int64_t> input_ids = MakeTensorIDList(inputs); | |||
3170 | if (PyErr_Occurred()) return nullptr; | |||
3171 | std::vector<tensorflow::DataType> input_dtypes = MakeTensorDtypeList(inputs); | |||
3172 | if (PyErr_Occurred()) return nullptr; | |||
3173 | ||||
3174 | bool should_record = false; | |||
3175 | for (TFE_Py_Tape* tape : SafeTapeSet()) { | |||
3176 | if (tape->tape->ShouldRecord(input_ids, input_dtypes)) { | |||
3177 | should_record = true; | |||
3178 | break; | |||
3179 | } | |||
3180 | } | |||
3181 | if (!should_record) { | |||
3182 | for (TFE_Py_ForwardAccumulator* accumulator : SafeAccumulatorSet()) { | |||
3183 | if (accumulator->accumulator->ShouldRecord(input_ids, input_dtypes)) { | |||
3184 | should_record = true; | |||
3185 | break; | |||
3186 | } | |||
3187 | } | |||
3188 | } | |||
3189 | if (!should_record) Py_RETURN_NONEreturn _Py_INCREF(((PyObject*)((&_Py_NoneStruct)))), (& _Py_NoneStruct); | |||
3190 | ||||
3191 | string c_op_name = TFE_GetPythonString(op_name); | |||
3192 | ||||
3193 | PyObject* op_outputs; | |||
3194 | bool op_outputs_tuple_created = false; | |||
3195 | ||||
3196 | if (const auto unused_output_indices = | |||
3197 | OpGradientUnusedOutputIndices(c_op_name)) { | |||
3198 | if (unused_output_indices->empty()) { | |||
3199 | op_outputs = Py_None(&_Py_NoneStruct); | |||
3200 | } else { | |||
3201 | op_outputs_tuple_created = true; | |||
3202 | op_outputs = | |||
3203 | CopySequenceSettingIndicesToNull(results, *unused_output_indices); | |||
3204 | } | |||
3205 | } else { | |||
3206 | op_outputs = results; | |||
3207 | } | |||
3208 | ||||
3209 | PyObject* op_inputs; | |||
3210 | bool op_inputs_tuple_created = false; | |||
3211 | ||||
3212 | if (const auto unused_input_indices = | |||
3213 | OpGradientUnusedInputIndices(c_op_name)) { | |||
3214 | if (unused_input_indices->empty()) { | |||
3215 | op_inputs = Py_None(&_Py_NoneStruct); | |||
3216 | } else { | |||
3217 | op_inputs_tuple_created = true; | |||
3218 | op_inputs = | |||
3219 | CopySequenceSettingIndicesToNull(inputs, *unused_input_indices); | |||
3220 | } | |||
3221 | } else { | |||
3222 | op_inputs = inputs; | |||
3223 | } | |||
3224 | ||||
3225 | tensorflow::eager::ForwardFunction<PyObject> py_forward_function( | |||
3226 | [op_name, attrs, inputs, results]( | |||
3227 | const std::vector<PyObject*>& input_tangents, | |||
3228 | std::vector<PyObject*>* output_tangents, bool use_batch) { | |||
3229 | return CallJVPFunction(op_name, attrs, inputs, results, input_tangents, | |||
3230 | output_tangents, use_batch); | |||
3231 | }); | |||
3232 | tensorflow::eager::ForwardFunction<PyObject>* forward_function; | |||
3233 | if (c_op_name == "While" || c_op_name == "StatelessWhile" || | |||
3234 | c_op_name == "If" || c_op_name == "StatelessIf") { | |||
3235 | // Control flow contains non-hashable attributes. Handling them in Python is | |||
3236 | // a headache, so instead we'll stay as close to GradientTape's handling as | |||
3237 | // possible (a null forward function means the accumulator forwards to a | |||
3238 | // tape). | |||
3239 | // | |||
3240 | // This is safe to do since we'll only see control flow when graph building, | |||
3241 | // in which case we can rely on pruning. | |||
3242 | forward_function = nullptr; | |||
3243 | } else { | |||
3244 | forward_function = &py_forward_function; | |||
3245 | } | |||
3246 | ||||
3247 | PyObject* num_inputs = PyLong_FromLong(PySequence_Size(inputs)); | |||
3248 | ||||
3249 | if (!forward_pass_name_scope) forward_pass_name_scope = Py_None(&_Py_NoneStruct); | |||
3250 | ||||
3251 | TapeSetRecordOperation( | |||
3252 | op_name, inputs, results, input_ids, input_dtypes, | |||
3253 | [op_name, attrs, num_inputs, op_inputs, op_outputs, | |||
3254 | forward_pass_name_scope]() { | |||
3255 | Py_INCREF(op_name)_Py_INCREF(((PyObject*)(op_name))); | |||
3256 | Py_INCREF(attrs)_Py_INCREF(((PyObject*)(attrs))); | |||
3257 | Py_INCREF(num_inputs)_Py_INCREF(((PyObject*)(num_inputs))); | |||
3258 | Py_INCREF(op_inputs)_Py_INCREF(((PyObject*)(op_inputs))); | |||
3259 | Py_INCREF(op_outputs)_Py_INCREF(((PyObject*)(op_outputs))); | |||
3260 | Py_INCREF(forward_pass_name_scope)_Py_INCREF(((PyObject*)(forward_pass_name_scope))); | |||
3261 | PyBackwardFunction* function = new PyBackwardFunction( | |||
3262 | [op_name, attrs, num_inputs, op_inputs, op_outputs, | |||
3263 | forward_pass_name_scope]( | |||
3264 | PyObject* output_grads, | |||
3265 | const std::vector<int64_t>& unneeded_gradients) { | |||
3266 | if (PyErr_Occurred()) { | |||
3267 | return static_cast<PyObject*>(nullptr); | |||
3268 | } | |||
3269 | tensorflow::Safe_PyObjectPtr skip_input_indices; | |||
3270 | if (!unneeded_gradients.empty()) { | |||
3271 | skip_input_indices.reset( | |||
3272 | PyTuple_New(unneeded_gradients.size())); | |||
3273 | for (int i = 0; i < unneeded_gradients.size(); i++) { | |||
3274 | PyTuple_SET_ITEM(PyTuple_SetItem(skip_input_indices.get(), i, GetPythonObjectFromInt (unneeded_gradients[i])) | |||
3275 | skip_input_indices.get(), i,PyTuple_SetItem(skip_input_indices.get(), i, GetPythonObjectFromInt (unneeded_gradients[i])) | |||
3276 | GetPythonObjectFromInt(unneeded_gradients[i]))PyTuple_SetItem(skip_input_indices.get(), i, GetPythonObjectFromInt (unneeded_gradients[i])); | |||
3277 | } | |||
3278 | } else { | |||
3279 | Py_INCREF(Py_None)_Py_INCREF(((PyObject*)((&_Py_NoneStruct)))); | |||
3280 | skip_input_indices.reset(Py_None(&_Py_NoneStruct)); | |||
3281 | } | |||
3282 | tensorflow::Safe_PyObjectPtr callback_args(Py_BuildValue( | |||
3283 | "OOOOOOOO", op_name, attrs, num_inputs, op_inputs, op_outputs, | |||
3284 | output_grads, skip_input_indices.get(), | |||
3285 | forward_pass_name_scope)); | |||
3286 | ||||
3287 | tensorflow::Safe_PyObjectPtr result( | |||
3288 | PyObject_CallObject(gradient_function, callback_args.get())); | |||
3289 | ||||
3290 | if (PyErr_Occurred()) return static_cast<PyObject*>(nullptr); | |||
3291 | ||||
3292 | return tensorflow::swig::Flatten(result.get()); | |||
3293 | }); | |||
3294 | return function; | |||
3295 | }, | |||
3296 | [op_name, attrs, num_inputs, op_inputs, op_outputs, | |||
3297 | forward_pass_name_scope](PyBackwardFunction* backward_function) { | |||
3298 | Py_DECREF(op_name)_Py_DECREF(((PyObject*)(op_name))); | |||
3299 | Py_DECREF(attrs)_Py_DECREF(((PyObject*)(attrs))); | |||
3300 | Py_DECREF(num_inputs)_Py_DECREF(((PyObject*)(num_inputs))); | |||
3301 | Py_DECREF(op_inputs)_Py_DECREF(((PyObject*)(op_inputs))); | |||
3302 | Py_DECREF(op_outputs)_Py_DECREF(((PyObject*)(op_outputs))); | |||
3303 | Py_DECREF(forward_pass_name_scope)_Py_DECREF(((PyObject*)(forward_pass_name_scope))); | |||
3304 | ||||
3305 | delete backward_function; | |||
3306 | }, | |||
3307 | forward_function); | |||
3308 | ||||
3309 | Py_DECREF(num_inputs)_Py_DECREF(((PyObject*)(num_inputs))); | |||
3310 | if (op_outputs_tuple_created) Py_DECREF(op_outputs)_Py_DECREF(((PyObject*)(op_outputs))); | |||
3311 | if (op_inputs_tuple_created) Py_DECREF(op_inputs)_Py_DECREF(((PyObject*)(op_inputs))); | |||
3312 | ||||
3313 | if (PyErr_Occurred()) { | |||
3314 | return nullptr; | |||
3315 | } | |||
3316 | ||||
3317 | Py_RETURN_NONEreturn _Py_INCREF(((PyObject*)((&_Py_NoneStruct)))), (& _Py_NoneStruct); | |||
3318 | } | |||
3319 | ||||
3320 | void MaybeNotifyVariableAccessed(PyObject* input) { | |||
3321 | DCHECK(CheckResourceVariable(input))while (false && (CheckResourceVariable(input))) ::tensorflow ::internal::LogMessageFatal("tensorflow/python/eager/pywrap_tfe_src.cc" , 3321); | |||
3322 | DCHECK(PyObject_HasAttrString(input, "_trainable"))while (false && (PyObject_HasAttrString(input, "_trainable" ))) ::tensorflow::internal::LogMessageFatal("tensorflow/python/eager/pywrap_tfe_src.cc" , 3322); | |||
3323 | ||||
3324 | tensorflow::Safe_PyObjectPtr trainable( | |||
3325 | PyObject_GetAttrString(input, "_trainable")); | |||
3326 | if (trainable.get() == Py_False((PyObject *) &_Py_FalseStruct)) return; | |||
3327 | TFE_Py_TapeVariableAccessed(input); | |||
3328 | TFE_Py_VariableWatcherVariableAccessed(input); | |||
3329 | } | |||
3330 | ||||
3331 | bool ReadVariableOp(const FastPathOpExecInfo& parent_op_exec_info, | |||
3332 | PyObject* input, tensorflow::Safe_PyObjectPtr* output, | |||
3333 | TF_Status* status) { | |||
3334 | MaybeNotifyVariableAccessed(input); | |||
3335 | ||||
3336 | TFE_Op* op = TFE_NewOp(parent_op_exec_info.ctx, "ReadVariableOp", status); | |||
3337 | auto cleaner = tensorflow::gtl::MakeCleanup([op] { TFE_DeleteOp(op); }); | |||
3338 | if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) return false; | |||
3339 | ||||
3340 | TFE_OpSetDevice(op, parent_op_exec_info.device_name, status); | |||
3341 | if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) return false; | |||
3342 | ||||
3343 | // Set dtype | |||
3344 | DCHECK(PyObject_HasAttrString(input, "_dtype"))while (false && (PyObject_HasAttrString(input, "_dtype" ))) ::tensorflow::internal::LogMessageFatal("tensorflow/python/eager/pywrap_tfe_src.cc" , 3344); | |||
3345 | tensorflow::Safe_PyObjectPtr dtype(PyObject_GetAttrString(input, "_dtype")); | |||
3346 | int value; | |||
3347 | if (!ParseTypeValue("_dtype", dtype.get(), status, &value)) { | |||
3348 | return false; | |||
3349 | } | |||
3350 | TFE_OpSetAttrType(op, "dtype", static_cast<TF_DataType>(value)); | |||
3351 | ||||
3352 | // Get handle | |||
3353 | tensorflow::Safe_PyObjectPtr handle(PyObject_GetAttrString(input, "_handle")); | |||
3354 | if (!EagerTensor_CheckExact(handle.get())) return false; | |||
3355 | TFE_OpAddInput(op, EagerTensor_Handle(handle.get()), status); | |||
3356 | if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) return false; | |||
3357 | ||||
3358 | int num_retvals = 1; | |||
3359 | TFE_TensorHandle* output_handle; | |||
3360 | TFE_Execute(op, &output_handle, &num_retvals, status); | |||
3361 | if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) return false; | |||
3362 | ||||
3363 | // Always create the py object (and correctly DECREF it) from the returned | |||
3364 | // value, else the data will leak. | |||
3365 | output->reset(EagerTensorFromHandle(output_handle)); | |||
3366 | ||||
3367 | // TODO(nareshmodi): Should we run post exec callbacks here? | |||
3368 | if (parent_op_exec_info.run_gradient_callback) { | |||
3369 | tensorflow::Safe_PyObjectPtr inputs(PyTuple_New(1)); | |||
3370 | PyTuple_SET_ITEM(inputs.get(), 0, handle.release())PyTuple_SetItem(inputs.get(), 0, handle.release()); | |||
3371 | ||||
3372 | tensorflow::Safe_PyObjectPtr outputs(PyTuple_New(1)); | |||
3373 | Py_INCREF(output->get())_Py_INCREF(((PyObject*)(output->get()))); // stay alive after since tuple steals. | |||
3374 | PyTuple_SET_ITEM(outputs.get(), 0, output->get())PyTuple_SetItem(outputs.get(), 0, output->get()); | |||
3375 | ||||
3376 | tensorflow::Safe_PyObjectPtr op_string( | |||
3377 | GetPythonObjectFromString("ReadVariableOp")); | |||
3378 | if (!RecordGradient(op_string.get(), inputs.get(), Py_None(&_Py_NoneStruct), | |||
3379 | outputs.get())) { | |||
3380 | return false; | |||
3381 | } | |||
3382 | } | |||
3383 | ||||
3384 | return true; | |||
3385 | } | |||
3386 | ||||
3387 | // Supports 3 cases at the moment: | |||
3388 | // i) input is an EagerTensor. | |||
3389 | // ii) input is a ResourceVariable - in this case, the is_variable param is | |||
3390 | // set to true. | |||
3391 | // iii) input is an arbitrary python list/tuple (note, this handling doesn't | |||
3392 | // support packing). | |||
3393 | // | |||
3394 | // NOTE: dtype_hint_getter must *always* return a PyObject that can be | |||
3395 | // decref'd. So if no hint is found, Py_RETURN_NONE (which correctly | |||
3396 | // increfs Py_None). | |||
3397 | // | |||
3398 | // NOTE: This function sets a python error directly, and returns false. | |||
3399 | // TF_Status is only passed since we don't want to have to reallocate it. | |||
3400 | bool ConvertToTensor( | |||
3401 | const FastPathOpExecInfo& op_exec_info, PyObject* input, | |||
3402 | tensorflow::Safe_PyObjectPtr* output_handle, | |||
3403 | // This gets a hint for this particular input. | |||
3404 | const std::function<tensorflow::DataType()>& dtype_hint_getter, | |||
3405 | // This sets the dtype after conversion is complete. | |||
3406 | const std::function<void(const tensorflow::DataType dtype)>& dtype_setter, | |||
3407 | TF_Status* status) { | |||
3408 | if (EagerTensor_CheckExact(input)) { | |||
3409 | Py_INCREF(input)_Py_INCREF(((PyObject*)(input))); | |||
3410 | output_handle->reset(input); | |||
3411 | return true; | |||
3412 | } else if (CheckResourceVariable(input)) { | |||
3413 | return ReadVariableOp(op_exec_info, input, output_handle, status); | |||
3414 | } | |||
3415 | ||||
3416 | // The hint comes from a supposedly similarly typed tensor. | |||
3417 | tensorflow::DataType dtype_hint = dtype_hint_getter(); | |||
3418 | ||||
3419 | TFE_TensorHandle* handle = tensorflow::ConvertToEagerTensor( | |||
3420 | op_exec_info.ctx, input, dtype_hint, op_exec_info.device_name); | |||
3421 | if (handle == nullptr) { | |||
3422 | return MaybeRaiseExceptionFromTFStatus(status, nullptr); | |||
3423 | } | |||
3424 | ||||
3425 | output_handle->reset(EagerTensorFromHandle(handle)); | |||
3426 | dtype_setter( | |||
3427 | static_cast<tensorflow::DataType>(TFE_TensorHandleDataType(handle))); | |||
3428 | ||||
3429 | return true; | |||
3430 | } | |||
3431 | ||||
3432 | // Adds input and type attr to the op, and to the list of flattened | |||
3433 | // inputs/attrs. | |||
3434 | bool AddInputToOp(FastPathOpExecInfo* op_exec_info, PyObject* input, | |||
3435 | const bool add_type_attr, | |||
3436 | const tensorflow::OpDef::ArgDef& input_arg, | |||
3437 | std::vector<tensorflow::Safe_PyObjectPtr>* flattened_attrs, | |||
3438 | std::vector<tensorflow::Safe_PyObjectPtr>* flattened_inputs, | |||
3439 | TFE_Op* op, TF_Status* status) { | |||
3440 | // py_eager_tensor's ownership is transferred to flattened_inputs if it is | |||
3441 | // required, else the object is destroyed and DECREF'd when the object goes | |||
3442 | // out of scope in this function. | |||
3443 | tensorflow::Safe_PyObjectPtr py_eager_tensor = nullptr; | |||
3444 | ||||
3445 | if (!ConvertToTensor( | |||
3446 | *op_exec_info, input, &py_eager_tensor, | |||
3447 | [&]() { | |||
3448 | if (input_arg.type() != tensorflow::DataType::DT_INVALID) { | |||
3449 | return input_arg.type(); | |||
3450 | } | |||
3451 | return MaybeGetDTypeForAttr(input_arg.type_attr(), op_exec_info); | |||
3452 | }, | |||
3453 | [&](const tensorflow::DataType dtype) { | |||
3454 | op_exec_info->cached_dtypes[input_arg.type_attr()] = dtype; | |||
3455 | }, | |||
3456 | status)) { | |||
3457 | return false; | |||
3458 | } | |||
3459 | ||||
3460 | TFE_TensorHandle* input_handle = EagerTensor_Handle(py_eager_tensor.get()); | |||
3461 | ||||
3462 | if (add_type_attr && !input_arg.type_attr().empty()) { | |||
3463 | auto dtype = TFE_TensorHandleDataType(input_handle); | |||
3464 | TFE_OpSetAttrType(op, input_arg.type_attr().data(), dtype); | |||
3465 | if (flattened_attrs != nullptr) { | |||
3466 | flattened_attrs->emplace_back( | |||
3467 | GetPythonObjectFromString(input_arg.type_attr())); | |||
3468 | flattened_attrs->emplace_back(PyLong_FromLong(dtype)); | |||
3469 | } | |||
3470 | } | |||
3471 | ||||
3472 | if (flattened_inputs != nullptr) { | |||
3473 | flattened_inputs->emplace_back(std::move(py_eager_tensor)); | |||
3474 | } | |||
3475 | ||||
3476 | TFE_OpAddInput(op, input_handle, status); | |||
3477 | if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) { | |||
3478 | return false; | |||
3479 | } | |||
3480 | ||||
3481 | return true; | |||
3482 | } | |||
3483 | ||||
3484 | const char* GetDeviceName(PyObject* py_device_name) { | |||
3485 | if (py_device_name != Py_None(&_Py_NoneStruct)) { | |||
3486 | return TFE_GetPythonString(py_device_name); | |||
3487 | } | |||
3488 | return nullptr; | |||
3489 | } | |||
3490 | ||||
3491 | bool RaiseIfNotPySequence(PyObject* seq, const string& attr_name) { | |||
3492 | if (!PySequence_Check(seq)) { | |||
3493 | PyErr_SetString(PyExc_TypeError, | |||
3494 | Printf("expected a sequence for attr %s, got %s instead", | |||
3495 | attr_name.data(), seq->ob_type->tp_name) | |||
3496 | .data()); | |||
3497 | ||||
3498 | return false; | |||
3499 | } | |||
3500 | if (PyArray_Check(seq)((((PyObject*)(seq))->ob_type) == (&(*(PyTypeObject *) _tensorflow_numpy_api[2])) || PyType_IsSubtype((((PyObject*)( seq))->ob_type), (&(*(PyTypeObject *)_tensorflow_numpy_api [2])))) && | |||
3501 | PyArray_NDIM(reinterpret_cast<PyArrayObject*>(seq)) != 1) { | |||
3502 | PyErr_SetString(PyExc_ValueError, | |||
3503 | Printf("expected a sequence for attr %s, got an ndarray " | |||
3504 | "with rank %d instead", | |||
3505 | attr_name.data(), | |||
3506 | PyArray_NDIM(reinterpret_cast<PyArrayObject*>(seq))) | |||
3507 | .data()); | |||
3508 | return false; | |||
3509 | } | |||
3510 | return true; | |||
3511 | } | |||
3512 | ||||
3513 | bool RunCallbacks( | |||
3514 | const FastPathOpExecInfo& op_exec_info, PyObject* args, | |||
3515 | int num_inferred_attrs, | |||
3516 | const std::vector<tensorflow::Safe_PyObjectPtr>& flattened_inputs, | |||
3517 | const std::vector<tensorflow::Safe_PyObjectPtr>& flattened_attrs, | |||
3518 | PyObject* flattened_result) { | |||
3519 | DCHECK(op_exec_info.run_callbacks)while (false && (op_exec_info.run_callbacks)) ::tensorflow ::internal::LogMessageFatal("tensorflow/python/eager/pywrap_tfe_src.cc" , 3519); | |||
3520 | ||||
3521 | tensorflow::Safe_PyObjectPtr inputs(PyTuple_New(flattened_inputs.size())); | |||
3522 | for (int i = 0; i < flattened_inputs.size(); i++) { | |||
3523 | PyObject* input = flattened_inputs[i].get(); | |||
3524 | Py_INCREF(input)_Py_INCREF(((PyObject*)(input))); | |||
3525 | PyTuple_SET_ITEM(inputs.get(), i, input)PyTuple_SetItem(inputs.get(), i, input); | |||
3526 | } | |||
3527 | ||||
3528 | int num_non_inferred_attrs = PyTuple_GET_SIZE(args)(((PyVarObject*)(((static_cast<void> (0)), (PyTupleObject *)(args))))->ob_size) - num_inferred_attrs; | |||
3529 | int num_attrs = flattened_attrs.size() + num_non_inferred_attrs; | |||
3530 | tensorflow::Safe_PyObjectPtr attrs(PyTuple_New(num_attrs)); | |||
3531 | ||||
3532 | for (int i = 0; i < num_non_inferred_attrs; i++) { | |||
3533 | auto* attr = PyTuple_GET_ITEM(args, num_inferred_attrs + i)(((static_cast<void> (0)), (PyTupleObject *)(args))-> ob_item[num_inferred_attrs + i]); | |||
3534 | Py_INCREF(attr)_Py_INCREF(((PyObject*)(attr))); | |||
3535 | PyTuple_SET_ITEM(attrs.get(), i, attr)PyTuple_SetItem(attrs.get(), i, attr); | |||
3536 | } | |||
3537 | ||||
3538 | for (int i = num_non_inferred_attrs; i < num_attrs; i++) { | |||
3539 | PyObject* attr_or_name = | |||
3540 | flattened_attrs.at(i - num_non_inferred_attrs).get(); | |||
3541 | Py_INCREF(attr_or_name)_Py_INCREF(((PyObject*)(attr_or_name))); | |||
3542 | PyTuple_SET_ITEM(attrs.get(), i, attr_or_name)PyTuple_SetItem(attrs.get(), i, attr_or_name); | |||
3543 | } | |||
3544 | ||||
3545 | if (op_exec_info.run_gradient_callback) { | |||
3546 | if (!RecordGradient(op_exec_info.op_name, inputs.get(), attrs.get(), | |||
3547 | flattened_result)) { | |||
3548 | return false; | |||
3549 | } | |||
3550 | } | |||
3551 | ||||
3552 | if (op_exec_info.run_post_exec_callbacks) { | |||
3553 | tensorflow::Safe_PyObjectPtr callback_args( | |||
3554 | Py_BuildValue("OOOOO", op_exec_info.op_name, inputs.get(), attrs.get(), | |||
3555 | flattened_result, op_exec_info.name)); | |||
3556 | for (Py_ssize_t i = 0; i < PyList_Size(op_exec_info.callbacks); i++) { | |||
3557 | PyObject* callback_fn = PyList_GET_ITEM(op_exec_info.callbacks, i)(((PyListObject *)(op_exec_info.callbacks))->ob_item[i]); | |||
3558 | if (!PyCallable_Check(callback_fn)) { | |||
3559 | PyErr_SetString( | |||
3560 | PyExc_TypeError, | |||
3561 | Printf("expected a function for " | |||
3562 | "post execution callback in index %ld, got %s instead", | |||
3563 | i, callback_fn->ob_type->tp_name) | |||
3564 | .c_str()); | |||
3565 | return false; | |||
3566 | } | |||
3567 | PyObject* callback_result = | |||
3568 | PyObject_CallObject(callback_fn, callback_args.get()); | |||
3569 | if (!callback_result) { | |||
3570 | return false; | |||
3571 | } | |||
3572 | Py_DECREF(callback_result)_Py_DECREF(((PyObject*)(callback_result))); | |||
3573 | } | |||
3574 | } | |||
3575 | ||||
3576 | return true; | |||
3577 | } | |||
3578 | ||||
3579 | } // namespace | |||
3580 | ||||
3581 | PyObject* TFE_Py_FastPathExecute_C(PyObject* args) { | |||
3582 | tensorflow::profiler::TraceMe activity( | |||
3583 | "TFE_Py_FastPathExecute_C", tensorflow::profiler::TraceMeLevel::kInfo); | |||
3584 | Py_ssize_t args_size = PyTuple_GET_SIZE(args)(((PyVarObject*)(((static_cast<void> (0)), (PyTupleObject *)(args))))->ob_size); | |||
3585 | if (args_size < FAST_PATH_EXECUTE_ARG_INPUT_START) { | |||
3586 | PyErr_SetString( | |||
3587 | PyExc_ValueError, | |||
3588 | Printf("There must be at least %d items in the input tuple.", | |||
3589 | FAST_PATH_EXECUTE_ARG_INPUT_START) | |||
3590 | .c_str()); | |||
3591 | return nullptr; | |||
3592 | } | |||
3593 | ||||
3594 | FastPathOpExecInfo op_exec_info; | |||
3595 | ||||
3596 | PyObject* py_eager_context = | |||
3597 | PyTuple_GET_ITEM(args, FAST_PATH_EXECUTE_ARG_CONTEXT)(((static_cast<void> (0)), (PyTupleObject *)(args))-> ob_item[FAST_PATH_EXECUTE_ARG_CONTEXT]); | |||
3598 | ||||
3599 | // TODO(edoper): Use interned string here | |||
3600 | PyObject* eager_context_handle = | |||
3601 | PyObject_GetAttrString(py_eager_context, "_context_handle"); | |||
3602 | ||||
3603 | TFE_Context* ctx = reinterpret_cast<TFE_Context*>( | |||
3604 | PyCapsule_GetPointer(eager_context_handle, nullptr)); | |||
3605 | op_exec_info.ctx = ctx; | |||
3606 | op_exec_info.args = args; | |||
3607 | ||||
3608 | if (ctx == nullptr) { | |||
3609 | // The context hasn't been initialized. It will be in the slow path. | |||
3610 | RaiseFallbackException( | |||
3611 | "This function does not handle the case of the path where " | |||
3612 | "all inputs are not already EagerTensors."); | |||
3613 | return nullptr; | |||
3614 | } | |||
3615 | ||||
3616 | auto* tld = tensorflow::GetEagerContextThreadLocalData(py_eager_context); | |||
3617 | if (tld == nullptr) { | |||
3618 | return nullptr; | |||
3619 | } | |||
3620 | op_exec_info.device_name = GetDeviceName(tld->device_name.get()); | |||
3621 | op_exec_info.callbacks = tld->op_callbacks.get(); | |||
3622 | ||||
3623 | op_exec_info.op_name = PyTuple_GET_ITEM(args, FAST_PATH_EXECUTE_ARG_OP_NAME)(((static_cast<void> (0)), (PyTupleObject *)(args))-> ob_item[FAST_PATH_EXECUTE_ARG_OP_NAME]); | |||
3624 | op_exec_info.name = PyTuple_GET_ITEM(args, FAST_PATH_EXECUTE_ARG_NAME)(((static_cast<void> (0)), (PyTupleObject *)(args))-> ob_item[FAST_PATH_EXECUTE_ARG_NAME]); | |||
3625 | ||||
3626 | // TODO(nareshmodi): Add a benchmark for the fast-path with gradient callbacks | |||
3627 | // (similar to benchmark_tf_gradient_function_*). Also consider using an | |||
3628 | // InlinedVector for flattened_attrs and flattened_inputs if the benchmarks | |||
3629 | // point out problems with heap allocs. | |||
3630 | op_exec_info.run_gradient_callback = | |||
3631 | !*ThreadTapeIsStopped() && HasAccumulatorOrTape(); | |||
3632 | op_exec_info.run_post_exec_callbacks = | |||
3633 | op_exec_info.callbacks != Py_None(&_Py_NoneStruct) && | |||
3634 | PyList_Size(op_exec_info.callbacks) > 0; | |||
3635 | op_exec_info.run_callbacks = op_exec_info.run_gradient_callback || | |||
3636 | op_exec_info.run_post_exec_callbacks; | |||
3637 | ||||
3638 | TF_Status* status = GetStatus(); | |||
3639 | const char* op_name = TFE_GetPythonString(op_exec_info.op_name); | |||
3640 | if (op_name == nullptr) { | |||
3641 | PyErr_SetString(PyExc_TypeError, | |||
3642 | Printf("expected a string for op_name, got %s instead", | |||
3643 | op_exec_info.op_name->ob_type->tp_name) | |||
3644 | .c_str()); | |||
3645 | return nullptr; | |||
3646 | } | |||
3647 | ||||
3648 | TFE_Op* op = GetOp(ctx, op_name, op_exec_info.device_name, status); | |||
3649 | ||||
3650 | auto cleaner = tensorflow::gtl::MakeCleanup([status, ctx, op] { | |||
3651 | ReturnStatus(status); | |||
3652 | ReturnOp(ctx, op); | |||
3653 | }); | |||
3654 | ||||
3655 | if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) { | |||
3656 | return nullptr; | |||
3657 | } | |||
3658 | ||||
3659 | tensorflow::unwrap(op)->SetStackTrace(tensorflow::GetStackTrace( | |||
3660 | tensorflow::StackTrace::kStackTraceInitialSize)); | |||
3661 | ||||
3662 | const tensorflow::OpDef* op_def = tensorflow::unwrap(op)->OpDef(); | |||
3663 | if (op_def == nullptr) return nullptr; | |||
3664 | ||||
3665 | if (args_size < | |||
3666 | FAST_PATH_EXECUTE_ARG_INPUT_START + op_def->input_arg_size()) { | |||
3667 | PyErr_SetString( | |||
3668 | PyExc_ValueError, | |||
3669 | Printf("Tuple size smaller than intended. Expected to be at least %d, " | |||
3670 | "was %ld", | |||
3671 | FAST_PATH_EXECUTE_ARG_INPUT_START + op_def->input_arg_size(), | |||
3672 | args_size) | |||
3673 | .c_str()); | |||
3674 | return nullptr; | |||
3675 | } | |||
3676 | ||||
3677 | if (!CheckInputsOk(args, FAST_PATH_EXECUTE_ARG_INPUT_START, *op_def)) { | |||
3678 | RaiseFallbackException( | |||
3679 | "This function does not handle the case of the path where " | |||
3680 | "all inputs are not already EagerTensors."); | |||
3681 | return nullptr; | |||
3682 | } | |||
3683 | ||||
3684 | op_exec_info.attr_to_inputs_map = GetAttrToInputsMapHoldingGIL(*op_def); | |||
3685 | op_exec_info.default_dtypes = GetAttrToDefaultsMapHoldingGIL(*op_def); | |||
3686 | ||||
3687 | // Mapping of attr name to size - used to calculate the number of values | |||
3688 | // to be expected by the TFE_Execute run. | |||
3689 | tensorflow::gtl::FlatMap<string, int64_t> attr_list_sizes; | |||
3690 | ||||
3691 | // Set non-inferred attrs, including setting defaults if the attr is passed in | |||
3692 | // as None. | |||
3693 | for (int i = FAST_PATH_EXECUTE_ARG_INPUT_START + op_def->input_arg_size(); | |||
3694 | i < args_size; i += 2) { | |||
3695 | PyObject* py_attr_name = PyTuple_GET_ITEM(args, i)(((static_cast<void> (0)), (PyTupleObject *)(args))-> ob_item[i]); | |||
3696 | const char* attr_name = TFE_GetPythonString(py_attr_name); | |||
3697 | PyObject* py_attr_value = PyTuple_GET_ITEM(args, i + 1)(((static_cast<void> (0)), (PyTupleObject *)(args))-> ob_item[i + 1]); | |||
3698 | ||||
3699 | // Not creating an index since most of the time there are not more than a | |||
3700 | // few attrs. | |||
3701 | // TODO(nareshmodi): Maybe include the index as part of the | |||
3702 | // OpRegistrationData. | |||
3703 | for (const auto& attr : op_def->attr()) { | |||
3704 | if (tensorflow::StringPiece(attr_name) == attr.name()) { | |||
3705 | SetOpAttrWithDefaults(ctx, op, attr, attr_name, py_attr_value, | |||
3706 | &attr_list_sizes, status); | |||
3707 | ||||
3708 | if (!status->status.ok()) { | |||
3709 | VLOG(1)(__builtin_expect(!!(!(([](int level, const char* fname) { static const bool vmodule_activated = ::tensorflow::internal::LogMessage ::VmoduleActivated(fname, level); return vmodule_activated; } )(1, "tensorflow/python/eager/pywrap_tfe_src.cc"))), 1)) ? (void )0 : ::tensorflow::internal::Voidifier() & ::tensorflow:: internal::LogMessage("tensorflow/python/eager/pywrap_tfe_src.cc" , 3709, tensorflow::INFO) << "Falling back to slow path for Op \"" << op_def->name() | |||
3710 | << "\" since we are unable to set the value for attr \"" | |||
3711 | << attr.name() << "\" due to: " << TF_Message(status); | |||
3712 | RaiseFallbackException(TF_Message(status)); | |||
3713 | return nullptr; | |||
3714 | } | |||
3715 | ||||
3716 | break; | |||
3717 | } | |||
3718 | } | |||
3719 | } | |||
3720 | ||||
3721 | // Flat attrs and inputs as required by the record_gradient call. The attrs | |||
3722 | // here only contain inferred attrs (non-inferred attrs are added directly | |||
3723 | // from the input args). | |||
3724 | // All items in flattened_attrs and flattened_inputs contain | |||
3725 | // Safe_PyObjectPtr - any time something steals a reference to this, it must | |||
3726 | // INCREF. | |||
3727 | // TODO(nareshmodi): figure out why PyList_New/PyList_Append don't work | |||
3728 | // directly. | |||
3729 | std::unique_ptr<std::vector<tensorflow::Safe_PyObjectPtr>> flattened_attrs = | |||
3730 | nullptr; | |||
3731 | std::unique_ptr<std::vector<tensorflow::Safe_PyObjectPtr>> flattened_inputs = | |||
3732 | nullptr; | |||
3733 | ||||
3734 | // TODO(nareshmodi): Encapsulate callbacks information into a struct. | |||
3735 | if (op_exec_info.run_callbacks) { | |||
3736 | flattened_attrs.reset(new std::vector<tensorflow::Safe_PyObjectPtr>); | |||
3737 | flattened_inputs.reset(new std::vector<tensorflow::Safe_PyObjectPtr>); | |||
3738 | } | |||
3739 | ||||
3740 | // Add inferred attrs and inputs. | |||
3741 | // The following code might set duplicate type attrs. This will result in | |||
3742 | // the CacheKey for the generated AttrBuilder possibly differing from | |||
3743 | // those where the type attrs are correctly set. Inconsistent CacheKeys | |||
3744 | // for ops means that there might be unnecessarily duplicated kernels. | |||
3745 | // TODO(nareshmodi): Fix this. | |||
3746 | for (int i = 0; i < op_def->input_arg_size(); i++) { | |||
3747 | const auto& input_arg = op_def->input_arg(i); | |||
3748 | ||||
3749 | PyObject* input = | |||
3750 | PyTuple_GET_ITEM(args, FAST_PATH_EXECUTE_ARG_INPUT_START + i)(((static_cast<void> (0)), (PyTupleObject *)(args))-> ob_item[FAST_PATH_EXECUTE_ARG_INPUT_START + i]); | |||
3751 | if (!input_arg.number_attr().empty()) { | |||
3752 | // The item is a homogeneous list. | |||
3753 | if (!RaiseIfNotPySequence(input, input_arg.number_attr())) return nullptr; | |||
3754 | tensorflow::Safe_PyObjectPtr fast_input( | |||
3755 | PySequence_Fast(input, "Could not parse sequence.")); | |||
3756 | if (fast_input.get() == nullptr) { | |||
3757 | return nullptr; | |||
3758 | } | |||
3759 | Py_ssize_t len = PySequence_Fast_GET_SIZE(fast_input.get())(((((((PyObject*)(fast_input.get()))->ob_type))->tp_flags & ((1UL << 25))) != 0) ? ((static_cast<void> (0)), (((PyVarObject*)(fast_input.get()))->ob_size)) : (( (PyVarObject*)(((static_cast<void> (0)), (PyTupleObject *)(fast_input.get()))))->ob_size)); | |||
3760 | PyObject** fast_input_array = PySequence_Fast_ITEMS(fast_input.get())(((((((PyObject*)(fast_input.get()))->ob_type))->tp_flags & ((1UL << 25))) != 0) ? ((PyListObject *)(fast_input .get()))->ob_item : ((PyTupleObject *)(fast_input.get()))-> ob_item); | |||
3761 | ||||
3762 | TFE_OpSetAttrInt(op, input_arg.number_attr().data(), len); | |||
3763 | if (op_exec_info.run_callbacks) { | |||
3764 | flattened_attrs->emplace_back( | |||
3765 | GetPythonObjectFromString(input_arg.number_attr())); | |||
3766 | flattened_attrs->emplace_back(PyLong_FromLong(len)); | |||
3767 | } | |||
3768 | attr_list_sizes[input_arg.number_attr()] = len; | |||
3769 | ||||
3770 | if (len > 0) { | |||
3771 | // First item adds the type attr. | |||
3772 | if (!AddInputToOp(&op_exec_info, fast_input_array[0], true, input_arg, | |||
3773 | flattened_attrs.get(), flattened_inputs.get(), op, | |||
3774 | status)) { | |||
3775 | return nullptr; | |||
3776 | } | |||
3777 | ||||
3778 | for (Py_ssize_t j = 1; j < len; j++) { | |||
3779 | // Since the list is homogeneous, we don't need to re-add the attr. | |||
3780 | if (!AddInputToOp(&op_exec_info, fast_input_array[j], false, | |||
3781 | input_arg, nullptr /* flattened_attrs */, | |||
3782 | flattened_inputs.get(), op, status)) { | |||
3783 | return nullptr; | |||
3784 | } | |||
3785 | } | |||
3786 | } | |||
3787 | } else if (!input_arg.type_list_attr().empty()) { | |||
3788 | // The item is a heterogeneous list. | |||
3789 | if (!RaiseIfNotPySequence(input, input_arg.type_list_attr())) { | |||
3790 | return nullptr; | |||
3791 | } | |||
3792 | tensorflow::Safe_PyObjectPtr fast_input( | |||
3793 | PySequence_Fast(input, "Could not parse sequence.")); | |||
3794 | if (fast_input.get() == nullptr) { | |||
3795 | return nullptr; | |||
3796 | } | |||
3797 | const string& attr_name = input_arg.type_list_attr(); | |||
3798 | Py_ssize_t len = PySequence_Fast_GET_SIZE(fast_input.get())(((((((PyObject*)(fast_input.get()))->ob_type))->tp_flags & ((1UL << 25))) != 0) ? ((static_cast<void> (0)), (((PyVarObject*)(fast_input.get()))->ob_size)) : (( (PyVarObject*)(((static_cast<void> (0)), (PyTupleObject *)(fast_input.get()))))->ob_size)); | |||
3799 | PyObject** fast_input_array = PySequence_Fast_ITEMS(fast_input.get())(((((((PyObject*)(fast_input.get()))->ob_type))->tp_flags & ((1UL << 25))) != 0) ? ((PyListObject *)(fast_input .get()))->ob_item : ((PyTupleObject *)(fast_input.get()))-> ob_item); | |||
3800 | tensorflow::gtl::InlinedVector<TF_DataType, 4> attr_value(len); | |||
3801 | PyObject* py_attr_value = nullptr; | |||
3802 | if (op_exec_info.run_callbacks) { | |||
3803 | py_attr_value = PyTuple_New(len); | |||
3804 | } | |||
3805 | for (Py_ssize_t j = 0; j < len; j++) { | |||
3806 | PyObject* py_input = fast_input_array[j]; | |||
3807 | tensorflow::Safe_PyObjectPtr py_eager_tensor; | |||
3808 | if (!ConvertToTensor( | |||
3809 | op_exec_info, py_input, &py_eager_tensor, | |||
3810 | []() { return tensorflow::DT_INVALID; }, | |||
3811 | [](const tensorflow::DataType dtype) {}, status)) { | |||
3812 | return nullptr; | |||
3813 | } | |||
3814 | ||||
3815 | TFE_TensorHandle* input_handle = | |||
3816 | EagerTensor_Handle(py_eager_tensor.get()); | |||
3817 | ||||
3818 | attr_value[j] = TFE_TensorHandleDataType(input_handle); | |||
3819 | ||||
3820 | TFE_OpAddInput(op, input_handle, status); | |||
3821 | if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) { | |||
3822 | return nullptr; | |||
3823 | } | |||
3824 | ||||
3825 | if (op_exec_info.run_callbacks) { | |||
3826 | flattened_inputs->emplace_back(std::move(py_eager_tensor)); | |||
3827 | ||||
3828 | PyTuple_SET_ITEM(py_attr_value, j, PyLong_FromLong(attr_value[j]))PyTuple_SetItem(py_attr_value, j, PyLong_FromLong(attr_value[ j])); | |||
3829 | } | |||
3830 | } | |||
3831 | if (op_exec_info.run_callbacks) { | |||
3832 | flattened_attrs->emplace_back(GetPythonObjectFromString(attr_name)); | |||
3833 | flattened_attrs->emplace_back(py_attr_value); | |||
3834 | } | |||
3835 | TFE_OpSetAttrTypeList(op, attr_name.data(), attr_value.data(), | |||
3836 | attr_value.size()); | |||
3837 | attr_list_sizes[attr_name] = len; | |||
3838 | } else { | |||
3839 | // The item is a single item. | |||
3840 | if (!AddInputToOp(&op_exec_info, input, true, input_arg, | |||
3841 | flattened_attrs.get(), flattened_inputs.get(), op, | |||
3842 | status)) { | |||
3843 | return nullptr; | |||
3844 | } | |||
3845 | } | |||
3846 | } | |||
3847 | ||||
3848 | int64_t num_outputs = 0; | |||
3849 | for (int i = 0; i < op_def->output_arg_size(); i++) { | |||
3850 | const auto& output_arg = op_def->output_arg(i); | |||
3851 | int64_t delta = 1; | |||
3852 | if (!output_arg.number_attr().empty()) { | |||
3853 | delta = attr_list_sizes[output_arg.number_attr()]; | |||
3854 | } else if (!output_arg.type_list_attr().empty()) { | |||
3855 | delta = attr_list_sizes[output_arg.type_list_attr()]; | |||
3856 | } | |||
3857 | if (delta < 0) { | |||
3858 | RaiseFallbackException( | |||
3859 | "Attributes suggest that the size of an output list is less than 0"); | |||
3860 | return nullptr; | |||
3861 | } | |||
3862 | num_outputs += delta; | |||
3863 | } | |||
3864 | ||||
3865 | // If number of retvals is larger than int32, we error out. | |||
3866 | if (static_cast<int64_t>(static_cast<int32_t>(num_outputs)) != num_outputs) { | |||
3867 | PyErr_SetString( | |||
3868 | PyExc_ValueError, | |||
3869 | Printf("Number of outputs is too big: %ld", num_outputs).c_str()); | |||
3870 | return nullptr; | |||
3871 | } | |||
3872 | int num_retvals = num_outputs; | |||
3873 | ||||
3874 | tensorflow::gtl::InlinedVector<TFE_TensorHandle*, 2> retvals(num_retvals); | |||
3875 | ||||
3876 | Py_BEGIN_ALLOW_THREADS{ PyThreadState *_save; _save = PyEval_SaveThread();; | |||
3877 | TFE_Execute(op, retvals.data(), &num_retvals, status); | |||
3878 | Py_END_ALLOW_THREADSPyEval_RestoreThread(_save); }; | |||
3879 | ||||
3880 | if (!status->status.ok()) { | |||
3881 | // Augment the status with the op_name for easier debugging similar to | |||
3882 | // TFE_Py_Execute. | |||
3883 | status->status = tensorflow::errors::CreateWithUpdatedMessage( | |||
3884 | status->status, tensorflow::strings::StrCat( | |||
3885 | TF_Message(status), " [Op:", | |||
3886 | TFE_GetPythonString(op_exec_info.op_name), "]")); | |||
3887 | MaybeRaiseExceptionFromTFStatus(status, nullptr); | |||
3888 | return nullptr; | |||
3889 | } | |||
3890 | ||||
3891 | tensorflow::Safe_PyObjectPtr flat_result(PyList_New(num_retvals)); | |||
3892 | for (int i = 0; i < num_retvals; ++i) { | |||
3893 | PyList_SET_ITEM(flat_result.get(), i, EagerTensorFromHandle(retvals[i]))PyList_SetItem(flat_result.get(), i, EagerTensorFromHandle(retvals [i])); | |||
3894 | } | |||
3895 | ||||
3896 | if (op_exec_info.run_callbacks) { | |||
3897 | if (!RunCallbacks( | |||
3898 | op_exec_info, args, | |||
3899 | FAST_PATH_EXECUTE_ARG_INPUT_START + op_def->input_arg_size(), | |||
3900 | *flattened_inputs, *flattened_attrs, flat_result.get())) { | |||
3901 | return nullptr; | |||
3902 | } | |||
3903 | } | |||
3904 | ||||
3905 | // Unflatten results. | |||
3906 | if (op_def->output_arg_size() == 0) { | |||
3907 | Py_RETURN_NONEreturn _Py_INCREF(((PyObject*)((&_Py_NoneStruct)))), (& _Py_NoneStruct); | |||
3908 | } | |||
3909 | ||||
3910 | if (op_def->output_arg_size() == 1) { | |||
3911 | if (!op_def->output_arg(0).number_attr().empty() || | |||
3912 | !op_def->output_arg(0).type_list_attr().empty()) { | |||
3913 | return flat_result.release(); | |||
3914 | } else { | |||
3915 | auto* result = PyList_GET_ITEM(flat_result.get(), 0)(((PyListObject *)(flat_result.get()))->ob_item[0]); | |||
3916 | Py_INCREF(result)_Py_INCREF(((PyObject*)(result))); | |||
3917 | return result; | |||
3918 | } | |||
3919 | } | |||
3920 | ||||
3921 | // Correctly output the results that are made into a namedtuple. | |||
3922 | PyObject* result = PyList_New(op_def->output_arg_size()); | |||
3923 | int flat_result_index = 0; | |||
3924 | for (int i = 0; i < op_def->output_arg_size(); i++) { | |||
3925 | if (!op_def->output_arg(i).number_attr().empty()) { | |||
3926 | int list_length = attr_list_sizes[op_def->output_arg(i).number_attr()]; | |||
3927 | PyObject* inner_list = PyList_New(list_length); | |||
3928 | for (int j = 0; j < list_length; j++) { | |||
3929 | PyObject* obj = PyList_GET_ITEM(flat_result.get(), flat_result_index++)(((PyListObject *)(flat_result.get()))->ob_item[flat_result_index ++]); | |||
3930 | Py_INCREF(obj)_Py_INCREF(((PyObject*)(obj))); | |||
3931 | PyList_SET_ITEM(inner_list, j, obj)PyList_SetItem(inner_list, j, obj); | |||
3932 | } | |||
3933 | PyList_SET_ITEM(result, i, inner_list)PyList_SetItem(result, i, inner_list); | |||
3934 | } else if (!op_def->output_arg(i).type_list_attr().empty()) { | |||
3935 | int list_length = attr_list_sizes[op_def->output_arg(i).type_list_attr()]; | |||
3936 | PyObject* inner_list = PyList_New(list_length); | |||
3937 | for (int j = 0; j < list_length; j++) { | |||
3938 | PyObject* obj = PyList_GET_ITEM(flat_result.get(), flat_result_index++)(((PyListObject *)(flat_result.get()))->ob_item[flat_result_index ++]); | |||
3939 | Py_INCREF(obj)_Py_INCREF(((PyObject*)(obj))); | |||
3940 | PyList_SET_ITEM(inner_list, j, obj)PyList_SetItem(inner_list, j, obj); | |||
3941 | } | |||
3942 | PyList_SET_ITEM(result, i, inner_list)PyList_SetItem(result, i, inner_list); | |||
3943 | } else { | |||
3944 | PyObject* obj = PyList_GET_ITEM(flat_result.get(), flat_result_index++)(((PyListObject *)(flat_result.get()))->ob_item[flat_result_index ++]); | |||
3945 | Py_INCREF(obj)_Py_INCREF(((PyObject*)(obj))); | |||
3946 | PyList_SET_ITEM(result, i, obj)PyList_SetItem(result, i, obj); | |||
3947 | } | |||
3948 | } | |||
3949 | return result; | |||
3950 | } | |||
3951 | ||||
3952 | PyObject* TFE_Py_RecordGradient(PyObject* op_name, PyObject* inputs, | |||
3953 | PyObject* attrs, PyObject* results, | |||
3954 | PyObject* forward_pass_name_scope) { | |||
3955 | if (*ThreadTapeIsStopped() || !HasAccumulatorOrTape()) { | |||
3956 | Py_RETURN_NONEreturn _Py_INCREF(((PyObject*)((&_Py_NoneStruct)))), (& _Py_NoneStruct); | |||
3957 | } | |||
3958 | ||||
3959 | return RecordGradient(op_name, inputs, attrs, results, | |||
3960 | forward_pass_name_scope); | |||
3961 | } | |||
3962 | ||||
3963 | namespace { | |||
// Single-character tags that make up the string produced by TFE_Py_EncodeArg.
constexpr char kTensor[] = "T";            // A Tensor or tf.TensorSpec.
constexpr char kList[] = "L";              // Start of a list.
constexpr char kListEnd[] = "l";           // End of a list.
constexpr char kTuple[] = "U";             // Start of a tuple.
constexpr char kTupleEnd[] = "u";          // End of a tuple.
constexpr char kDIter[] = "I";             // OwnedIterator with a resource id.
constexpr char kDict[] = "D";              // Start of a mapping.
constexpr char kRaw[] = "R";               // Value kept in `objects`, not str.
constexpr char kResourceVariable[] = "r";  // Resource encoded by resource id.
constexpr char kShape[] = "s";             // Start of a known shape.
constexpr char kShapeDelim[] = "-";        // Follows each known dimension.
constexpr char kDType[] = "d";             // Precedes a dtype enum value.
constexpr char kNone[] = "n";              // None / unknown shape or dim.
constexpr char kCompositeTensor[] = "C";   // CompositeTensor (spec in objects).
constexpr char kAttrs[] = "A";             // Start of an attrs object.
constexpr char kAttrsEnd[] = "a";          // End of an attrs object.
constexpr char kName[] = "'";              // Opens a TensorSpec name.
constexpr char kNameEnd[] = "'";           // Closes a TensorSpec name.
constexpr char kLocalIdDelim[] = "_";      // Follows a local resource id.
3983 | ||||
3984 | // Container for storing generated string encoding as well as the raw python | |||
3985 | // objects that were not included in the string. | |||
3986 | struct EncodeResult { | |||
3987 | string str; | |||
3988 | std::vector<PyObject*> objects; | |||
3989 | ||||
3990 | PyObject* ToPyTuple() { | |||
3991 | PyObject* result = PyTuple_New(2); | |||
3992 | ||||
3993 | PyTuple_SET_ITEM(result, 0, GetPythonObjectFromString(str))PyTuple_SetItem(result, 0, GetPythonObjectFromString(str)); | |||
3994 | ||||
3995 | if (objects.empty()) { | |||
3996 | Py_INCREF(Py_None)_Py_INCREF(((PyObject*)((&_Py_NoneStruct)))); | |||
3997 | PyTuple_SET_ITEM(result, 1, Py_None)PyTuple_SetItem(result, 1, (&_Py_NoneStruct)); | |||
3998 | } else { | |||
3999 | PyObject* objects_tuple = PyTuple_New(objects.size()); | |||
4000 | ||||
4001 | for (int i = 0; i < objects.size(); i++) { | |||
4002 | PyTuple_SET_ITEM(objects_tuple, i, objects[i])PyTuple_SetItem(objects_tuple, i, objects[i]); | |||
4003 | } | |||
4004 | ||||
4005 | PyTuple_SET_ITEM(result, 1, objects_tuple)PyTuple_SetItem(result, 1, objects_tuple); | |||
4006 | } | |||
4007 | ||||
4008 | return result; | |||
4009 | } | |||
4010 | }; | |||
4011 | ||||
4012 | // Gives each unique resource_id a unique incremental local_id. Provides a | |||
4013 | // string encoding that informs an order and uniqueness sensitive input | |||
4014 | // signature. | |||
4015 | // This class is not thread safe and is not meant to be shared across threads. | |||
4016 | class LocalResourceIdMap { | |||
4017 | public: | |||
4018 | // When the resource ID is known (such as for OwnedIterator). | |||
4019 | // Returns the existing local ID (if present) or a new unique one. | |||
4020 | int AddResourceId(int resource_id) { | |||
4021 | const auto& it = resource_id_to_local_id_.find(resource_id); | |||
4022 | if (it == resource_id_to_local_id_.end()) { | |||
4023 | resource_id_to_local_id_[resource_id] = next_local_id_; | |||
4024 | return next_local_id_++; | |||
4025 | } else { | |||
4026 | return it->second; | |||
4027 | } | |||
4028 | } | |||
4029 | ||||
4030 | // When the resource ID is not known (such as for IteratorSpec). | |||
4031 | // Returns a new unique local ID. | |||
4032 | int AddUnknownResource() { return next_local_id_++; } | |||
4033 | ||||
4034 | private: | |||
4035 | absl::flat_hash_map<int, int> resource_id_to_local_id_; | |||
4036 | int next_local_id_ = 0; | |||
4037 | }; | |||
4038 | ||||
4039 | // Contains encoding configuration, intermediary data and result. | |||
4040 | struct EncodingContext { | |||
4041 | bool include_tensor_ranks_only; | |||
4042 | bool encode_variable_by_resource_id; | |||
4043 | ||||
4044 | LocalResourceIdMap resource_id_map; | |||
4045 | EncodeResult result; | |||
4046 | }; | |||
4047 | ||||
4048 | tensorflow::Status EncodeTensorOrTensorSpec(PyObject* arg, bool is_tensor_spec, | |||
4049 | EncodingContext& context) { | |||
4050 | absl::StrAppend(&context.result.str, kTensor); | |||
4051 | ||||
4052 | if (is_tensor_spec) { | |||
4053 | tensorflow::Safe_PyObjectPtr name(PyObject_GetAttrString(arg, "name")); | |||
4054 | if (name != nullptr && name.get() != Py_None(&_Py_NoneStruct)) { | |||
4055 | absl::StrAppend(&context.result.str, kName, | |||
4056 | TFE_GetPythonString(name.get()), kNameEnd); | |||
4057 | } | |||
4058 | } | |||
4059 | ||||
4060 | tensorflow::Safe_PyObjectPtr dtype_object( | |||
4061 | PyObject_GetAttrString(arg, "dtype")); | |||
4062 | if (dtype_object == nullptr) { | |||
4063 | return tensorflow::errors::InvalidArgument( | |||
4064 | "tf.TensorSpec object doesn't have dtype() attr."); | |||
4065 | } | |||
4066 | ||||
4067 | tensorflow::Safe_PyObjectPtr dtype_enum( | |||
4068 | PyObject_GetAttrString(dtype_object.get(), "_type_enum")); | |||
4069 | if (dtype_enum == nullptr) { | |||
4070 | return tensorflow::errors::InvalidArgument( | |||
4071 | "tf.TensorSpec's dtype object doesn't have _type_enum() " | |||
4072 | "attr."); | |||
4073 | } | |||
4074 | ||||
4075 | tensorflow::DataType dtype = | |||
4076 | static_cast<tensorflow::DataType>(MakeInt(dtype_enum.get())); | |||
4077 | absl::StrAppend(&context.result.str, kDType, dtype); | |||
4078 | ||||
4079 | tensorflow::Safe_PyObjectPtr shape_tuple( | |||
4080 | PyObject_GetAttrString(arg, "shape")); | |||
4081 | if (shape_tuple == nullptr) { | |||
4082 | return tensorflow::errors::InvalidArgument( | |||
4083 | "tf.TensorSpec object doesn't have shape() attr."); | |||
4084 | } | |||
4085 | ||||
4086 | tensorflow::Safe_PyObjectPtr rank( | |||
4087 | PyObject_GetAttr(shape_tuple.get(), PyUnicode_FromString("rank"))); | |||
4088 | if (rank == nullptr || rank.get() == Py_None(&_Py_NoneStruct)) { | |||
4089 | // Unknown shape, encode that directly. | |||
4090 | absl::StrAppend(&context.result.str, kNone); | |||
4091 | return tensorflow::Status::OK(); | |||
4092 | } | |||
4093 | ||||
4094 | absl::StrAppend(&context.result.str, kShape); | |||
4095 | ||||
4096 | tensorflow::Safe_PyObjectPtr shape_seq(PySequence_Fast( | |||
4097 | shape_tuple.get(), "shape_tuple didn't return a sequence")); | |||
4098 | ||||
4099 | int len = MakeInt(rank.get()); | |||
4100 | PyObject** shape_seq_array = PySequence_Fast_ITEMS(shape_seq.get())(((((((PyObject*)(shape_seq.get()))->ob_type))->tp_flags & ((1UL << 25))) != 0) ? ((PyListObject *)(shape_seq .get()))->ob_item : ((PyTupleObject *)(shape_seq.get()))-> ob_item); | |||
4101 | ||||
4102 | if (context.include_tensor_ranks_only) { | |||
4103 | absl::StrAppend(&context.result.str, len); | |||
4104 | } else { | |||
4105 | for (int i = 0; i < len; ++i) { | |||
4106 | // Can be None, int or a Dimension object. | |||
4107 | PyObject* dimension = shape_seq_array[i]; | |||
4108 | ||||
4109 | // If it is a Dimension object, then we must extract value from it first. | |||
4110 | bool is_dimension_class = PyObject_HasAttrString(dimension, "value"); | |||
4111 | tensorflow::Safe_PyObjectPtr dimension_holder; | |||
4112 | if (is_dimension_class) { | |||
4113 | dimension_holder = | |||
4114 | tensorflow::make_safe(PyObject_GetAttrString(dimension, "value")); | |||
4115 | dimension = dimension_holder.get(); | |||
4116 | } | |||
4117 | ||||
4118 | if (dimension == Py_None(&_Py_NoneStruct)) { | |||
4119 | absl::StrAppend(&context.result.str, kNone); | |||
4120 | } else { | |||
4121 | absl::StrAppend(&context.result.str, MakeInt(dimension), kShapeDelim); | |||
4122 | } | |||
4123 | } | |||
4124 | } | |||
4125 | ||||
4126 | return tensorflow::Status::OK(); | |||
4127 | } | |||
4128 | ||||
4129 | // TODO(b/199534088): Remove this function by using EncodeResource instead. | |||
4130 | tensorflow::Status EncodeOwnedIterator(PyObject* arg, | |||
4131 | EncodingContext& context) { | |||
4132 | PyObject* type_spec(PyObject_GetAttrString(arg, "_type_spec")); | |||
4133 | if (type_spec == nullptr) { | |||
4134 | return tensorflow::errors::InvalidArgument( | |||
4135 | "Error while reading OwnedIterator._type_spec."); | |||
4136 | } | |||
4137 | context.result.objects.push_back(type_spec); | |||
4138 | ||||
4139 | // Add resource tracking | |||
4140 | tensorflow::Safe_PyObjectPtr itr_res( | |||
4141 | PyObject_GetAttrString(arg, "_iterator_resource")); | |||
4142 | if (itr_res == nullptr) { | |||
4143 | return tensorflow::errors::InvalidArgument( | |||
4144 | "Error while reading Dataset iterator resource."); | |||
4145 | } | |||
4146 | // OwnedIterator should ideally always provide a unique resource id. | |||
4147 | // TODO(b/199534088) Cases where resource_id is not provided need to be fixed. | |||
4148 | if (tensorflow::swig::IsTensor(itr_res.get())) { | |||
4149 | absl::StrAppend(&context.result.str, kDIter); | |||
4150 | tensorflow::Safe_PyObjectPtr py_resource_id( | |||
4151 | PyObject_GetAttrString(itr_res.get(), "_id")); | |||
4152 | if (py_resource_id == nullptr) { | |||
4153 | return tensorflow::errors::InvalidArgument( | |||
4154 | "Error while reading Dataset iterator resouce id."); | |||
4155 | } | |||
4156 | int resource_id = PyLong_AsSize_t(py_resource_id.get()); | |||
4157 | if (resource_id < 0) { | |||
4158 | return tensorflow::errors::InvalidArgument("PyLong_AsSize_t failure"); | |||
4159 | } | |||
4160 | int local_id = context.resource_id_map.AddResourceId(resource_id); | |||
4161 | absl::StrAppend(&context.result.str, local_id, kLocalIdDelim); | |||
4162 | } else { | |||
4163 | // If '_iterator_resource' is not a Tensor, there is no resource id. | |||
4164 | // Instead we treat it the same way as a CompositeTensor | |||
4165 | absl::StrAppend(&context.result.str, kCompositeTensor); | |||
4166 | } | |||
4167 | return tensorflow::Status::OK(); | |||
4168 | } | |||
4169 | ||||
4170 | tensorflow::Status EncodeResource(PyObject* arg, EncodingContext& context) { | |||
4171 | absl::StrAppend(&context.result.str, kResourceVariable); | |||
4172 | tensorflow::Safe_PyObjectPtr py_resource_id( | |||
4173 | PyObject_CallMethod(arg, "__tf_resource_id__", nullptr)); | |||
4174 | DCHECK(py_resource_id != nullptr)while (false && (py_resource_id != nullptr)) ::tensorflow ::internal::LogMessageFatal("tensorflow/python/eager/pywrap_tfe_src.cc" , 4174); | |||
4175 | ||||
4176 | int resource_id = PyLong_AsSize_t(py_resource_id.get()); | |||
4177 | DCHECK_GE(resource_id, 0)while (false && ((void)(resource_id), (void)(0), 0)) :: tensorflow::internal::LogMessageFatal("tensorflow/python/eager/pywrap_tfe_src.cc" , 4177); | |||
4178 | int local_id = context.resource_id_map.AddResourceId(resource_id); | |||
4179 | absl::StrAppend(&context.result.str, local_id, kLocalIdDelim); | |||
4180 | ||||
4181 | tensorflow::Safe_PyObjectPtr type_spec( | |||
4182 | PyObject_CallMethod(arg, "__tf_function_cache_spec__", nullptr)); | |||
4183 | absl::StrAppend(&context.result.str, PyUnicode_AsUTF8(type_spec.get())); | |||
4184 | DCHECK(type_spec != nullptr)while (false && (type_spec != nullptr)) ::tensorflow:: internal::LogMessageFatal("tensorflow/python/eager/pywrap_tfe_src.cc" , 4184); | |||
4185 | ||||
4186 | return tensorflow::Status::OK(); | |||
4187 | } | |||
4188 | ||||
4189 | tensorflow::Status EncodeArgHelperInternal(PyObject* arg, | |||
4190 | EncodingContext& context); | |||
4191 | ||||
4192 | // This function doesn't set the type of sequence before | |||
4193 | tensorflow::Status EncodeSequence(PyObject* arg, const char* type, | |||
4194 | const char* end_type, | |||
4195 | EncodingContext& context) { | |||
4196 | tensorflow::Safe_PyObjectPtr arg_seq( | |||
4197 | PySequence_Fast(arg, "unable to create seq from list/tuple")); | |||
4198 | ||||
4199 | absl::StrAppend(&context.result.str, type); | |||
4200 | int len = PySequence_Fast_GET_SIZE(arg_seq.get())(((((((PyObject*)(arg_seq.get()))->ob_type))->tp_flags & ((1UL << 25))) != 0) ? ((static_cast<void> (0)), (((PyVarObject*)(arg_seq.get()))->ob_size)) : (((PyVarObject *)(((static_cast<void> (0)), (PyTupleObject *)(arg_seq. get()))))->ob_size)); | |||
4201 | PyObject** arg_seq_array = PySequence_Fast_ITEMS(arg_seq.get())(((((((PyObject*)(arg_seq.get()))->ob_type))->tp_flags & ((1UL << 25))) != 0) ? ((PyListObject *)(arg_seq.get() ))->ob_item : ((PyTupleObject *)(arg_seq.get()))->ob_item ); | |||
4202 | for (int i = 0; i < len; ++i) { | |||
4203 | PyObject* item = arg_seq_array[i]; | |||
4204 | if (item == Py_None(&_Py_NoneStruct)) { | |||
4205 | absl::StrAppend(&context.result.str, kNone); | |||
4206 | } else { | |||
4207 | TF_RETURN_IF_ERROR(EncodeArgHelperInternal(item, context))do { ::tensorflow::Status _status = (EncodeArgHelperInternal( item, context)); if ((__builtin_expect(!_status.ok(), 0))) return _status; } while (0); | |||
4208 | } | |||
4209 | } | |||
4210 | absl::StrAppend(&context.result.str, end_type); | |||
4211 | ||||
4212 | return tensorflow::Status::OK(); | |||
4213 | } | |||
4214 | ||||
4215 | tensorflow::Status EncodeMapping(PyObject* arg, EncodingContext& context) { | |||
4216 | tensorflow::Safe_PyObjectPtr keys(tensorflow::swig::MappingKeys(arg)); | |||
4217 | if (PyList_Sort(keys.get()) == -1) { | |||
4218 | return tensorflow::errors::Internal("Unable to sort keys"); | |||
4219 | } | |||
4220 | ||||
4221 | absl::StrAppend(&context.result.str, kDict); | |||
4222 | int len = PyList_Size(keys.get()); | |||
4223 | ||||
4224 | for (int i = 0; i < len; i++) { | |||
4225 | PyObject* key = PyList_GetItem(keys.get(), i); | |||
4226 | TF_RETURN_IF_ERROR(EncodeArgHelperInternal(key, context))do { ::tensorflow::Status _status = (EncodeArgHelperInternal( key, context)); if ((__builtin_expect(!_status.ok(), 0))) return _status; } while (0); | |||
4227 | tensorflow::Safe_PyObjectPtr value(PyObject_GetItem(arg, key)); | |||
4228 | TF_RETURN_IF_ERROR(EncodeArgHelperInternal(value.get(), context))do { ::tensorflow::Status _status = (EncodeArgHelperInternal( value.get(), context)); if ((__builtin_expect(!_status.ok(), 0 ))) return _status; } while (0); | |||
4229 | } | |||
4230 | ||||
4231 | return tensorflow::Status::OK(); | |||
4232 | } | |||
4233 | ||||
4234 | tensorflow::Status EncodeCompositeTensor(PyObject* arg, | |||
4235 | EncodingContext& context) { | |||
4236 | absl::StrAppend(&context.result.str, kCompositeTensor); | |||
4237 | PyObject* type_spec(PyObject_GetAttrString(arg, "_type_spec")); | |||
4238 | if (type_spec == nullptr) { | |||
4239 | return tensorflow::errors::InvalidArgument( | |||
4240 | "Error while reading CompositeTensor._type_spec."); | |||
4241 | } | |||
4242 | context.result.objects.push_back(type_spec); | |||
4243 | ||||
4244 | return tensorflow::Status::OK(); | |||
4245 | } | |||
4246 | ||||
4247 | tensorflow::Status EncodeTypeSpec(PyObject* arg, EncodingContext& context) { | |||
4248 | absl::StrAppend(&context.result.str, kRaw); | |||
4249 | Py_INCREF(arg)_Py_INCREF(((PyObject*)(arg))); | |||
4250 | context.result.objects.push_back(arg); | |||
4251 | return tensorflow::Status::OK(); | |||
4252 | } | |||
4253 | ||||
4254 | tensorflow::Status EncodeAttrs(PyObject* arg, EncodingContext& context) { | |||
4255 | absl::StrAppend(&context.result.str, kAttrs); | |||
4256 | tensorflow::Safe_PyObjectPtr attrs( | |||
4257 | PyObject_GetAttrString(arg, "__attrs_attrs__")); | |||
4258 | tensorflow::Safe_PyObjectPtr iter(PyObject_GetIter(attrs.get())); | |||
4259 | for (tensorflow::Safe_PyObjectPtr item(PyIter_Next(iter.get())); item; | |||
4260 | item.reset(PyIter_Next(iter.get()))) { | |||
| ||||
4261 | tensorflow::Safe_PyObjectPtr name( | |||
4262 | PyObject_GetAttrString(item.get(), "name")); | |||
4263 | tensorflow::Safe_PyObjectPtr attr_arg(PyObject_GetAttr(arg, name.get())); | |||
4264 | TF_RETURN_IF_ERROR(EncodeArgHelperInternal(attr_arg.get(), context))do { ::tensorflow::Status _status = (EncodeArgHelperInternal( attr_arg.get(), context)); if ((__builtin_expect(!_status.ok( ), 0))) return _status; } while (0); | |||
4265 | } | |||
4266 | absl::StrAppend(&context.result.str, kAttrsEnd); | |||
4267 | ||||
4268 | return tensorflow::Status::OK(); | |||
4269 | } | |||
4270 | ||||
4271 | tensorflow::Status EncodeUnidentified(PyObject* arg, EncodingContext& context) { | |||
4272 | // We hold a weak reference because cache keys live practically forever, and | |||
4273 | // this may leak heavy objects. | |||
4274 | PyObject* object = PyWeakref_NewRef(arg, nullptr); | |||
4275 | if (object == nullptr) { | |||
4276 | PyErr_Clear(); | |||
4277 | object = arg; | |||
4278 | Py_INCREF(object)_Py_INCREF(((PyObject*)(object))); | |||
4279 | } | |||
4280 | ||||
4281 | absl::StrAppend(&context.result.str, kRaw); | |||
4282 | context.result.objects.push_back(object); | |||
4283 | return tensorflow::Status::OK(); | |||
4284 | } | |||
4285 | ||||
4286 | tensorflow::Status EncodeArgHelperInternal(PyObject* arg, | |||
4287 | EncodingContext& context) { | |||
4288 | if (tensorflow::swig::IsTensorSpec(arg)) { | |||
4289 | TF_RETURN_IF_ERROR(EncodeTensorOrTensorSpec(arg, true, context))do { ::tensorflow::Status _status = (EncodeTensorOrTensorSpec (arg, true, context)); if ((__builtin_expect(!_status.ok(), 0 ))) return _status; } while (0); | |||
4290 | } else if (tensorflow::swig::IsTensor(arg)) { | |||
4291 | TF_RETURN_IF_ERROR(EncodeTensorOrTensorSpec(arg, false, context))do { ::tensorflow::Status _status = (EncodeTensorOrTensorSpec (arg, false, context)); if ((__builtin_expect(!_status.ok(), 0 ))) return _status; } while (0); | |||
4292 | } else if (tensorflow::swig::IsOwnedIterator(arg)) { | |||
4293 | TF_RETURN_IF_ERROR(EncodeOwnedIterator(arg, context))do { ::tensorflow::Status _status = (EncodeOwnedIterator(arg, context)); if ((__builtin_expect(!_status.ok(), 0))) return _status ; } while (0); | |||
4294 | } else if (PyList_Check(arg)((((((PyObject*)(arg))->ob_type))->tp_flags & ((1UL << 25))) != 0)) { | |||
4295 | TF_RETURN_IF_ERROR(EncodeSequence(arg, kList, kListEnd, context))do { ::tensorflow::Status _status = (EncodeSequence(arg, kList , kListEnd, context)); if ((__builtin_expect(!_status.ok(), 0 ))) return _status; } while (0); | |||
4296 | } else if (tensorflow::swig::IsTuple(arg)) { | |||
4297 | TF_RETURN_IF_ERROR(EncodeSequence(arg, kTuple, kTupleEnd, context))do { ::tensorflow::Status _status = (EncodeSequence(arg, kTuple , kTupleEnd, context)); if ((__builtin_expect(!_status.ok(), 0 ))) return _status; } while (0); | |||
4298 | } else if (tensorflow::swig::IsMapping(arg)) { | |||
4299 | TF_RETURN_IF_ERROR(EncodeMapping(arg, context))do { ::tensorflow::Status _status = (EncodeMapping(arg, context )); if ((__builtin_expect(!_status.ok(), 0))) return _status; } while (0); | |||
4300 | } else if (tensorflow::swig::IsCompositeTensor(arg)) { | |||
4301 | TF_RETURN_IF_ERROR(EncodeCompositeTensor(arg, context))do { ::tensorflow::Status _status = (EncodeCompositeTensor(arg , context)); if ((__builtin_expect(!_status.ok(), 0))) return _status; } while (0); | |||
4302 | } else if (tensorflow::swig::IsTypeSpec(arg)) { | |||
4303 | TF_RETURN_IF_ERROR(EncodeTypeSpec(arg, context))do { ::tensorflow::Status _status = (EncodeTypeSpec(arg, context )); if ((__builtin_expect(!_status.ok(), 0))) return _status; } while (0); | |||
4304 | } else if (tensorflow::swig::IsAttrs(arg)) { | |||
4305 | TF_RETURN_IF_ERROR(EncodeAttrs(arg, context))do { ::tensorflow::Status _status = (EncodeAttrs(arg, context )); if ((__builtin_expect(!_status.ok(), 0))) return _status; } while (0); | |||
4306 | } else if (tensorflow::swig::IsResourceVariable(arg) && | |||
4307 | context.encode_variable_by_resource_id) { | |||
4308 | TF_RETURN_IF_ERROR(EncodeResource(arg, context))do { ::tensorflow::Status _status = (EncodeResource(arg, context )); if ((__builtin_expect(!_status.ok(), 0))) return _status; } while (0); | |||
4309 | } else { | |||
4310 | TF_RETURN_IF_ERROR(EncodeUnidentified(arg, context))do { ::tensorflow::Status _status = (EncodeUnidentified(arg, context )); if ((__builtin_expect(!_status.ok(), 0))) return _status; } while (0); | |||
4311 | } | |||
4312 | ||||
4313 | return tensorflow::Status::OK(); | |||
4314 | } | |||
4315 | ||||
4316 | tensorflow::Status TFE_Py_EncodeArgHelper(PyObject* arg, | |||
4317 | EncodingContext& context) { | |||
4318 | auto status = EncodeArgHelperInternal(arg, context); | |||
4319 | return status; | |||
4320 | } | |||
4321 | ||||
4322 | } // namespace | |||
4323 | ||||
4324 | // `defun` uses dtypes and shapes instead of `Tensors` as cache keys. Dtypes | |||
4325 | // are used because TensorFlow graphs are not parametric w.r.t. dtypes. Shapes | |||
4326 | // are used for both performance reasons, as much TensorFlow code specializes | |||
4327 | // on known shapes to produce slimmer graphs, and correctness, as some | |||
4328 | // high-level APIs require shapes to be fully-known. | |||
4329 | // | |||
4330 | // `include_tensor_ranks_only` allows caching on arguments excluding shape info, | |||
4331 | // so that a slow path using relaxed shape can rely on a cache key that excludes | |||
4332 | // shapes. | |||
4333 | PyObject* TFE_Py_EncodeArg(PyObject* arg, bool include_tensor_ranks_only, | |||
4334 | bool encode_variable_by_resource_id) { | |||
4335 | EncodingContext context; | |||
4336 | context.include_tensor_ranks_only = include_tensor_ranks_only; | |||
4337 | context.encode_variable_by_resource_id = encode_variable_by_resource_id; | |||
4338 | const auto status = TFE_Py_EncodeArgHelper(arg, context); | |||
| ||||
4339 | if (MaybeRaiseExceptionFromStatus(status, nullptr)) { | |||
4340 | return nullptr; | |||
4341 | } | |||
4342 | ||||
4343 | return context.result.ToPyTuple(); | |||
4344 | } | |||
4345 | ||||
4346 | // A method prints incoming messages directly to Python's | |||
4347 | // stdout using Python's C API. This is necessary in Jupyter notebooks | |||
4348 | // and colabs where messages to the C stdout don't go to the notebook | |||
4349 | // cell outputs, but calls to Python's stdout do. | |||
4350 | void PrintToPythonStdout(const char* msg) { | |||
4351 | if (Py_IsInitialized()) { | |||
4352 | PyGILState_STATE py_threadstate; | |||
4353 | py_threadstate = PyGILState_Ensure(); | |||
4354 | ||||
4355 | string string_msg = msg; | |||
4356 | // PySys_WriteStdout truncates strings over 1000 bytes, so | |||
4357 | // we write the message in chunks small enough to not be truncated. | |||
4358 | int CHUNK_SIZE = 900; | |||
4359 | auto len = string_msg.length(); | |||
4360 | for (int i = 0; i < len; i += CHUNK_SIZE) { | |||
4361 | PySys_WriteStdout("%s", string_msg.substr(i, CHUNK_SIZE).c_str()); | |||
4362 | } | |||
4363 | ||||
4364 | // Force flushing to make sure print newlines aren't interleaved in | |||
4365 | // some colab environments | |||
4366 | PyRun_SimpleString("import sys; sys.stdout.flush()")PyRun_SimpleStringFlags("import sys; sys.stdout.flush()", __null ); | |||
4367 | ||||
4368 | PyGILState_Release(py_threadstate); | |||
4369 | } | |||
4370 | } | |||
4371 | ||||
4372 | // Register PrintToPythonStdout as a log listener, to allow | |||
4373 | // printing in colabs and jupyter notebooks to work. | |||
4374 | void TFE_Py_EnableInteractivePythonLogging() { | |||
4375 | static bool enabled_interactive_logging = false; | |||
4376 | if (!enabled_interactive_logging) { | |||
4377 | enabled_interactive_logging = true; | |||
4378 | TF_RegisterLogListener(PrintToPythonStdout); | |||
4379 | } | |||
4380 | } | |||
4381 | ||||
4382 | namespace { | |||
4383 | // weak reference to Python Context object currently active | |||
4384 | PyObject* weak_eager_context = nullptr; | |||
4385 | } // namespace | |||
4386 | ||||
4387 | PyObject* TFE_Py_SetEagerContext(PyObject* py_context) { | |||
4388 | Py_XDECREF(weak_eager_context)_Py_XDECREF(((PyObject*)(weak_eager_context))); | |||
4389 | weak_eager_context = PyWeakref_NewRef(py_context, nullptr); | |||
4390 | if (weak_eager_context == nullptr) { | |||
4391 | return nullptr; | |||
4392 | } | |||
4393 | Py_RETURN_NONEreturn _Py_INCREF(((PyObject*)((&_Py_NoneStruct)))), (& _Py_NoneStruct); | |||
4394 | } | |||
4395 | ||||
4396 | PyObject* GetPyEagerContext() { | |||
4397 | if (weak_eager_context == nullptr) { | |||
4398 | PyErr_SetString(PyExc_RuntimeError, "Python eager context is not set"); | |||
4399 | return nullptr; | |||
4400 | } | |||
4401 | PyObject* py_context = PyWeakref_GET_OBJECT(weak_eager_context)((((PyObject*)(((PyWeakReference *)(weak_eager_context))-> wr_object))->ob_refcnt) > 0 ? ((PyWeakReference *)(weak_eager_context ))->wr_object : (&_Py_NoneStruct)); | |||
4402 | if (py_context == Py_None(&_Py_NoneStruct)) { | |||
4403 | PyErr_SetString(PyExc_RuntimeError, "Eager context has been destroyed"); | |||
4404 | return nullptr; | |||
4405 | } | |||
4406 | Py_INCREF(py_context)_Py_INCREF(((PyObject*)(py_context))); | |||
4407 | return py_context; | |||
4408 | } | |||
4409 | ||||
4410 | namespace { | |||
4411 | ||||
4412 | // Default values for thread_local_data fields. | |||
4413 | struct EagerContextThreadLocalDataDefaults { | |||
4414 | tensorflow::Safe_PyObjectPtr is_eager; | |||
4415 | tensorflow::Safe_PyObjectPtr device_spec; | |||
4416 | }; | |||
4417 | ||||
4418 | // Maps each py_eager_context object to its thread_local_data. | |||
4419 | // | |||
4420 | // Note: we need to use the python Context object as the key here (and not | |||
4421 | // its handle object), because the handle object isn't created until the | |||
4422 | // context is initialized; but thread_local_data is potentially accessed | |||
4423 | // before then. | |||
4424 | using EagerContextThreadLocalDataMap = absl::flat_hash_map< | |||
4425 | PyObject*, std::unique_ptr<tensorflow::EagerContextThreadLocalData>>; | |||
4426 | thread_local EagerContextThreadLocalDataMap* | |||
4427 | eager_context_thread_local_data_map = nullptr; | |||
4428 | ||||
4429 | // Maps each py_eager_context object to default values. | |||
4430 | using EagerContextThreadLocalDataDefaultsMap = | |||
4431 | absl::flat_hash_map<PyObject*, EagerContextThreadLocalDataDefaults>; | |||
4432 | EagerContextThreadLocalDataDefaultsMap* | |||
4433 | eager_context_thread_local_data_defaults = nullptr; | |||
4434 | ||||
4435 | } // namespace | |||
4436 | ||||
4437 | namespace tensorflow { | |||
4438 | ||||
4439 | void MakeEagerContextThreadLocalData(PyObject* py_eager_context, | |||
4440 | PyObject* is_eager, | |||
4441 | PyObject* device_spec) { | |||
4442 | DCheckPyGilState(); | |||
4443 | if (eager_context_thread_local_data_defaults == nullptr) { | |||
4444 | absl::LeakCheckDisabler disabler; | |||
4445 | eager_context_thread_local_data_defaults = | |||
4446 | new EagerContextThreadLocalDataDefaultsMap(); | |||
4447 | } | |||
4448 | if (eager_context_thread_local_data_defaults->count(py_eager_context) > 0) { | |||
4449 | PyErr_SetString(PyExc_AssertionError, | |||
4450 | "MakeEagerContextThreadLocalData may not be called " | |||
4451 | "twice on the same eager Context object."); | |||
4452 | } | |||
4453 | ||||
4454 | auto& defaults = | |||
4455 | (*eager_context_thread_local_data_defaults)[py_eager_context]; | |||
4456 | Py_INCREF(is_eager)_Py_INCREF(((PyObject*)(is_eager))); | |||
4457 | defaults.is_eager.reset(is_eager); | |||
4458 | Py_INCREF(device_spec)_Py_INCREF(((PyObject*)(device_spec))); | |||
4459 | defaults.device_spec.reset(device_spec); | |||
4460 | } | |||
4461 | ||||
4462 | EagerContextThreadLocalData* GetEagerContextThreadLocalData( | |||
4463 | PyObject* py_eager_context) { | |||
4464 | if (eager_context_thread_local_data_defaults == nullptr) { | |||
4465 | PyErr_SetString(PyExc_AssertionError, | |||
4466 | "MakeEagerContextThreadLocalData must be called " | |||
4467 | "before GetEagerContextThreadLocalData."); | |||
4468 | return nullptr; | |||
4469 | } | |||
4470 | auto defaults = | |||
4471 | eager_context_thread_local_data_defaults->find(py_eager_context); | |||
4472 | if (defaults == eager_context_thread_local_data_defaults->end()) { | |||
4473 | PyErr_SetString(PyExc_AssertionError, | |||
4474 | "MakeEagerContextThreadLocalData must be called " | |||
4475 | "before GetEagerContextThreadLocalData."); | |||
4476 | return nullptr; | |||
4477 | } | |||
4478 | ||||
4479 | if (eager_context_thread_local_data_map == nullptr) { | |||
4480 | absl::LeakCheckDisabler disabler; | |||
4481 | eager_context_thread_local_data_map = new EagerContextThreadLocalDataMap(); | |||
4482 | } | |||
4483 | auto& thread_local_data = | |||
4484 | (*eager_context_thread_local_data_map)[py_eager_context]; | |||
4485 | ||||
4486 | if (!thread_local_data) { | |||
4487 | thread_local_data.reset(new EagerContextThreadLocalData()); | |||
4488 | ||||
4489 | Safe_PyObjectPtr is_eager( | |||
4490 | PyObject_CallFunctionObjArgs(defaults->second.is_eager.get(), nullptr)); | |||
4491 | if (!is_eager) return nullptr; | |||
4492 | thread_local_data->is_eager = PyObject_IsTrue(is_eager.get()); | |||
4493 | ||||
4494 | #if PY_MAJOR_VERSION3 >= 3 | |||
4495 | PyObject* scope_name = PyUnicode_FromString(""); | |||
4496 | #else | |||
4497 | PyObject* scope_name = PyString_FromString(""); | |||
4498 | #endif | |||
4499 | thread_local_data->scope_name.reset(scope_name); | |||
4500 | ||||
4501 | #if PY_MAJOR_VERSION3 >= 3 | |||
4502 | PyObject* device_name = PyUnicode_FromString(""); | |||
4503 | #else | |||
4504 | PyObject* device_name = PyString_FromString(""); | |||
4505 | #endif | |||
4506 | thread_local_data->device_name.reset(device_name); | |||
4507 | ||||
4508 | Py_INCREF(defaults->second.device_spec.get())_Py_INCREF(((PyObject*)(defaults->second.device_spec.get() ))); | |||
4509 | thread_local_data->device_spec.reset(defaults->second.device_spec.get()); | |||
4510 | ||||
4511 | Py_INCREF(Py_None)_Py_INCREF(((PyObject*)((&_Py_NoneStruct)))); | |||
4512 | thread_local_data->function_call_options.reset(Py_None(&_Py_NoneStruct)); | |||
4513 | ||||
4514 | Py_INCREF(Py_None)_Py_INCREF(((PyObject*)((&_Py_NoneStruct)))); | |||
4515 | thread_local_data->executor.reset(Py_None(&_Py_NoneStruct)); | |||
4516 | ||||
4517 | thread_local_data->op_callbacks.reset(PyList_New(0)); | |||
4518 | } | |||
4519 | return thread_local_data.get(); | |||
4520 | } | |||
4521 | ||||
4522 | void DestroyEagerContextThreadLocalData(PyObject* py_eager_context) { | |||
4523 | DCheckPyGilState(); | |||
4524 | if (eager_context_thread_local_data_defaults) { | |||
4525 | eager_context_thread_local_data_defaults->erase(py_eager_context); | |||
4526 | } | |||
4527 | if (eager_context_thread_local_data_map) { | |||
4528 | eager_context_thread_local_data_map->erase(py_eager_context); | |||
4529 | } | |||
4530 | } | |||
4531 | ||||
4532 | } // namespace tensorflow |
#ifdef PyIter_Next
#warning "API PyIter_Next is defined as a macro."
#else
/* Static-analyzer model for the CPython API function PyIter_Next.
 * The real function cannot be modeled when it is a macro (see the branch
 * above); otherwise this stub forwards to the analyzer hook
 * clang_analyzer_PyObject_New_Reference(). Presumably that hook tells the
 * refcount checker the result is a new (caller-owned) reference; confirm
 * against the checker's documentation. Kept C-compatible (typedef, named
 * parameter) in case the model is compiled as C. */
struct _object;
typedef struct _object PyObject;
PyObject* clang_analyzer_PyObject_New_Reference();
PyObject* PyIter_Next(PyObject* iterator) {
  return clang_analyzer_PyObject_New_Reference();
}
#endif