File: | .cache/bazel/_bazel_alan/39be661231df2a680c9b74265384c13c/execroot/org_tensorflow/tensorflow/python/eager/pywrap_tfe_src.cc |
Warning: | line 4087, column 43 PyObject ownership leak with reference count of 1 |
Press '?' to see keyboard shortcuts
Keyboard shortcuts:
1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. | ||||
2 | |||||
3 | Licensed under the Apache License, Version 2.0 (the "License"); | ||||
4 | you may not use this file except in compliance with the License. | ||||
5 | You may obtain a copy of the License at | ||||
6 | |||||
7 | http://www.apache.org/licenses/LICENSE-2.0 | ||||
8 | |||||
9 | Unless required by applicable law or agreed to in writing, software | ||||
10 | distributed under the License is distributed on an "AS IS" BASIS, | ||||
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
12 | See the License for the specific language governing permissions and | ||||
13 | limitations under the License. | ||||
14 | ==============================================================================*/ | ||||
15 | |||||
16 | #include <atomic> | ||||
17 | #include <cstring> | ||||
18 | #include <unordered_map> | ||||
19 | |||||
20 | #include "absl/debugging/leak_check.h" | ||||
21 | #include "absl/strings/str_cat.h" | ||||
22 | #include "absl/types/variant.h" | ||||
23 | #include "tensorflow/c/c_api.h" | ||||
24 | #include "tensorflow/c/c_api_internal.h" | ||||
25 | #include "tensorflow/c/eager/c_api.h" | ||||
26 | #include "tensorflow/c/eager/c_api_internal.h" | ||||
27 | #include "tensorflow/c/eager/tape.h" | ||||
28 | #include "tensorflow/c/eager/tfe_context_internal.h" | ||||
29 | #include "tensorflow/c/eager/tfe_op_internal.h" | ||||
30 | #include "tensorflow/c/eager/tfe_tensorhandle_internal.h" | ||||
31 | #include "tensorflow/c/tf_status.h" | ||||
32 | #include "tensorflow/core/framework/types.pb.h" | ||||
33 | #include "tensorflow/core/lib/core/errors.h" | ||||
34 | #include "tensorflow/core/lib/gtl/cleanup.h" | ||||
35 | #include "tensorflow/core/lib/gtl/compactptrset.h" | ||||
36 | #include "tensorflow/core/lib/gtl/flatmap.h" | ||||
37 | #include "tensorflow/core/lib/gtl/flatset.h" | ||||
38 | #include "tensorflow/core/lib/strings/strcat.h" | ||||
39 | #include "tensorflow/core/lib/strings/stringprintf.h" | ||||
40 | #include "tensorflow/core/platform/casts.h" | ||||
41 | #include "tensorflow/core/platform/errors.h" | ||||
42 | #include "tensorflow/core/platform/mutex.h" | ||||
43 | #include "tensorflow/core/platform/protobuf.h" | ||||
44 | #include "tensorflow/core/platform/status.h" | ||||
45 | #include "tensorflow/core/platform/types.h" | ||||
46 | #include "tensorflow/core/profiler/lib/traceme.h" | ||||
47 | #include "tensorflow/core/util/managed_stack_trace.h" | ||||
48 | #include "tensorflow/python/eager/pywrap_gradient_exclusions.h" | ||||
49 | #include "tensorflow/python/eager/pywrap_tensor.h" | ||||
50 | #include "tensorflow/python/eager/pywrap_tfe.h" | ||||
51 | #include "tensorflow/python/lib/core/py_util.h" | ||||
52 | #include "tensorflow/python/lib/core/safe_ptr.h" | ||||
53 | #include "tensorflow/python/util/stack_trace.h" | ||||
54 | #include "tensorflow/python/util/util.h" | ||||
55 | |||||
56 | using tensorflow::Status; | ||||
57 | using tensorflow::string; | ||||
58 | using tensorflow::strings::Printf; | ||||
59 | |||||
60 | namespace { | ||||
61 | // NOTE: Items are retrieved from and returned to these unique_ptrs, and they | ||||
62 | // act as arenas. This is important if the same thread requests 2 items without | ||||
63 | // releasing one. | ||||
64 | // The following sequence of events on the same thread will still succeed: | ||||
65 | // - GetOp <- Returns existing. | ||||
66 | // - GetOp <- Allocates and returns a new pointer. | ||||
67 | // - ReleaseOp <- Sets the item in the unique_ptr. | ||||
68 | // - ReleaseOp <- Sets the item in the unique_ptr, deleting the old one. | ||||
69 | // This occurs when a PyFunc kernel is run. This behavior makes it safe in that | ||||
70 | // case, as well as the case where python decides to reuse the underlying | ||||
71 | // C++ thread in 2 python threads case. | ||||
72 | struct OpDeleter { | ||||
73 | void operator()(TFE_Op* op) const { TFE_DeleteOp(op); } | ||||
74 | }; | ||||
75 | thread_local std::unordered_map<TFE_Context*, | ||||
76 | std::unique_ptr<TFE_Op, OpDeleter>> | ||||
77 | thread_local_eager_operation_map; // NOLINT | ||||
78 | thread_local std::unique_ptr<TF_Status> thread_local_tf_status = // NOLINT | ||||
79 | nullptr; | ||||
80 | |||||
81 | std::unique_ptr<TFE_Op, OpDeleter> ReleaseThreadLocalOp(TFE_Context* ctx) { | ||||
82 | auto it = thread_local_eager_operation_map.find(ctx); | ||||
83 | if (it == thread_local_eager_operation_map.end()) { | ||||
84 | return nullptr; | ||||
85 | } | ||||
86 | return std::move(it->second); | ||||
87 | } | ||||
88 | |||||
89 | TFE_Op* GetOp(TFE_Context* ctx, const char* op_or_function_name, | ||||
90 | const char* raw_device_name, TF_Status* status) { | ||||
91 | auto op = ReleaseThreadLocalOp(ctx); | ||||
92 | if (!op) { | ||||
93 | op.reset(tensorflow::wrap(tensorflow::unwrap(ctx)->CreateOperation())); | ||||
94 | } | ||||
95 | status->status = | ||||
96 | tensorflow::unwrap(op.get())->Reset(op_or_function_name, raw_device_name); | ||||
97 | if (!status->status.ok()) { | ||||
98 | op.reset(); | ||||
99 | } | ||||
100 | return op.release(); | ||||
101 | } | ||||
102 | |||||
103 | void ReturnOp(TFE_Context* ctx, TFE_Op* op) { | ||||
104 | if (op) { | ||||
105 | tensorflow::unwrap(op)->Clear(); | ||||
106 | thread_local_eager_operation_map[ctx].reset(op); | ||||
107 | } | ||||
108 | } | ||||
109 | |||||
110 | TF_Status* ReleaseThreadLocalStatus() { | ||||
111 | if (thread_local_tf_status == nullptr) { | ||||
112 | return nullptr; | ||||
113 | } | ||||
114 | return thread_local_tf_status.release(); | ||||
115 | } | ||||
116 | |||||
// One op input: its position in the OpDef's input_arg list and whether it is
// a list input (i.e. the arg declares a number_attr).
struct InputInfo {
  InputInfo(int index, bool list) : i(index), is_list(list) {}

  int i;                 // Position of the input among the op's inputs.
  bool is_list = false;  // True when the input is a list of tensors.
};
123 | |||||
124 | // Takes in output gradients, returns input gradients. | ||||
125 | typedef std::function<PyObject*(PyObject*, const std::vector<int64_t>&)> | ||||
126 | PyBackwardFunction; | ||||
127 | |||||
128 | using AttrToInputsMap = | ||||
129 | tensorflow::gtl::FlatMap<string, | ||||
130 | tensorflow::gtl::InlinedVector<InputInfo, 4>>; | ||||
131 | |||||
132 | tensorflow::gtl::FlatMap<string, AttrToInputsMap*>* GetAllAttrToInputsMaps() { | ||||
133 | static auto* all_attr_to_input_maps = | ||||
134 | new tensorflow::gtl::FlatMap<string, AttrToInputsMap*>; | ||||
135 | return all_attr_to_input_maps; | ||||
136 | } | ||||
137 | |||||
138 | // This function doesn't use a lock, since we depend on the GIL directly. | ||||
139 | AttrToInputsMap* GetAttrToInputsMapHoldingGIL(const tensorflow::OpDef& op_def) { | ||||
140 | #if PY_MAJOR_VERSION3 >= 3 && PY_MINOR_VERSION8 >= 4 | ||||
141 | DCHECK(PyGILState_Check())while (false && (PyGILState_Check())) ::tensorflow::internal ::LogMessageFatal("tensorflow/python/eager/pywrap_tfe_src.cc" , 141) | ||||
142 | << "This function needs to hold the GIL when called."; | ||||
143 | #endif | ||||
144 | auto* all_attr_to_input_maps = GetAllAttrToInputsMaps(); | ||||
145 | auto* output = | ||||
146 | tensorflow::gtl::FindPtrOrNull(*all_attr_to_input_maps, op_def.name()); | ||||
147 | if (output != nullptr) { | ||||
148 | return output; | ||||
149 | } | ||||
150 | |||||
151 | std::unique_ptr<AttrToInputsMap> m(new AttrToInputsMap); | ||||
152 | |||||
153 | // Store a list of InputIndex -> List of corresponding inputs. | ||||
154 | for (int i = 0; i < op_def.input_arg_size(); i++) { | ||||
155 | if (!op_def.input_arg(i).type_attr().empty()) { | ||||
156 | auto it = m->find(op_def.input_arg(i).type_attr()); | ||||
157 | if (it == m->end()) { | ||||
158 | it = m->insert({op_def.input_arg(i).type_attr(), {}}).first; | ||||
159 | } | ||||
160 | it->second.emplace_back(i, !op_def.input_arg(i).number_attr().empty()); | ||||
161 | } | ||||
162 | } | ||||
163 | |||||
164 | auto* retval = m.get(); | ||||
165 | (*all_attr_to_input_maps)[op_def.name()] = m.release(); | ||||
166 | |||||
167 | return retval; | ||||
168 | } | ||||
169 | |||||
170 | // This function doesn't use a lock, since we depend on the GIL directly. | ||||
171 | tensorflow::gtl::FlatMap< | ||||
172 | string, tensorflow::gtl::FlatMap<string, tensorflow::DataType>*>* | ||||
173 | GetAllAttrToDefaultsMaps() { | ||||
174 | static auto* all_attr_to_defaults_maps = new tensorflow::gtl::FlatMap< | ||||
175 | string, tensorflow::gtl::FlatMap<string, tensorflow::DataType>*>; | ||||
176 | return all_attr_to_defaults_maps; | ||||
177 | } | ||||
178 | |||||
179 | tensorflow::gtl::FlatMap<string, tensorflow::DataType>* | ||||
180 | GetAttrToDefaultsMapHoldingGIL(const tensorflow::OpDef& op_def) { | ||||
181 | #if PY_MAJOR_VERSION3 >= 3 && PY_MINOR_VERSION8 >= 4 | ||||
182 | DCHECK(PyGILState_Check())while (false && (PyGILState_Check())) ::tensorflow::internal ::LogMessageFatal("tensorflow/python/eager/pywrap_tfe_src.cc" , 182) | ||||
183 | << "This function needs to hold the GIL when called."; | ||||
184 | #endif | ||||
185 | auto* all_attr_to_defaults_maps = GetAllAttrToDefaultsMaps(); | ||||
186 | auto* output = | ||||
187 | tensorflow::gtl::FindPtrOrNull(*all_attr_to_defaults_maps, op_def.name()); | ||||
188 | if (output != nullptr) { | ||||
189 | return output; | ||||
190 | } | ||||
191 | |||||
192 | auto* new_map = new tensorflow::gtl::FlatMap<string, tensorflow::DataType>; | ||||
193 | |||||
194 | for (const auto& attr : op_def.attr()) { | ||||
195 | if (attr.type() == "type" && attr.has_default_value()) { | ||||
196 | new_map->insert({attr.name(), attr.default_value().type()}); | ||||
197 | } | ||||
198 | } | ||||
199 | |||||
200 | (*all_attr_to_defaults_maps)[op_def.name()] = new_map; | ||||
201 | |||||
202 | return new_map; | ||||
203 | } | ||||
204 | |||||
205 | struct FastPathOpExecInfo { | ||||
206 | TFE_Context* ctx; | ||||
207 | const char* device_name; | ||||
208 | |||||
209 | bool run_callbacks; | ||||
210 | bool run_post_exec_callbacks; | ||||
211 | bool run_gradient_callback; | ||||
212 | |||||
213 | // The op name of the main op being executed. | ||||
214 | PyObject* name; | ||||
215 | // The op type name of the main op being executed. | ||||
216 | PyObject* op_name; | ||||
217 | PyObject* callbacks; | ||||
218 | |||||
219 | // All the args passed into the FastPathOpExecInfo. | ||||
220 | PyObject* args; | ||||
221 | |||||
222 | // DTypes can come from another input that has the same attr. So build that | ||||
223 | // map. | ||||
224 | const AttrToInputsMap* attr_to_inputs_map; | ||||
225 | const tensorflow::gtl::FlatMap<string, tensorflow::DataType>* default_dtypes; | ||||
226 | tensorflow::gtl::FlatMap<string, tensorflow::DataType> cached_dtypes; | ||||
227 | }; | ||||
228 | |||||
229 | #define PARSE_VALUE(fn_name, type, check_fn, parse_fn) \ | ||||
230 | bool fn_name(const string& key, PyObject* py_value, TF_Status* status, \ | ||||
231 | type* value) { \ | ||||
232 | if (check_fn(py_value)) { \ | ||||
233 | *value = static_cast<type>(parse_fn(py_value)); \ | ||||
234 | return true; \ | ||||
235 | } else { \ | ||||
236 | TF_SetStatus(status, TF_INVALID_ARGUMENT, \ | ||||
237 | tensorflow::strings::StrCat( \ | ||||
238 | "Expecting " #type " value for attr ", key, ", got ", \ | ||||
239 | py_value->ob_type->tp_name) \ | ||||
240 | .c_str()); \ | ||||
241 | return false; \ | ||||
242 | } \ | ||||
243 | } | ||||
244 | |||||
245 | #if PY_MAJOR_VERSION3 >= 3 | ||||
246 | PARSE_VALUE(ParseIntValue, int, PyLong_Check, PyLong_AsLong) | ||||
247 | PARSE_VALUE(ParseInt64Value, int64_t, PyLong_Check, PyLong_AsLongLong) | ||||
248 | #else | ||||
249 | PARSE_VALUE(ParseIntValue, int, PyInt_Check, PyInt_AsLong) | ||||
250 | #endif | ||||
251 | PARSE_VALUE(ParseFloatValue, float, PyFloat_Check, PyFloat_AsDouble) | ||||
252 | #undef PARSE_VALUE | ||||
253 | |||||
254 | #if PY_MAJOR_VERSION3 < 3 | ||||
255 | bool ParseInt64Value(const string& key, PyObject* py_value, TF_Status* status, | ||||
256 | int64_t* value) { | ||||
257 | if (PyInt_Check(py_value)) { | ||||
258 | *value = static_cast<int64_t>(PyInt_AsLong(py_value)); | ||||
259 | return true; | ||||
260 | } else if (PyLong_Check(py_value)((((((PyObject*)(py_value))->ob_type))->tp_flags & ( (1UL << 24))) != 0)) { | ||||
261 | *value = static_cast<int64_t>(PyLong_AsLong(py_value)); | ||||
262 | return true; | ||||
263 | } | ||||
264 | TF_SetStatus( | ||||
265 | status, TF_INVALID_ARGUMENT, | ||||
266 | tensorflow::strings::StrCat("Expecting int or long value for attr ", key, | ||||
267 | ", got ", py_value->ob_type->tp_name) | ||||
268 | .c_str()); | ||||
269 | return false; | ||||
270 | } | ||||
271 | #endif | ||||
272 | |||||
273 | Py_ssize_t TensorShapeNumDims(PyObject* value) { | ||||
274 | const auto size = PySequence_Size(value); | ||||
275 | if (size == -1) { | ||||
276 | // TensorShape.__len__ raises an error in the scenario where the shape is an | ||||
277 | // unknown, which needs to be cleared. | ||||
278 | // TODO(nareshmodi): ensure that this is actually a TensorShape. | ||||
279 | PyErr_Clear(); | ||||
280 | } | ||||
281 | return size; | ||||
282 | } | ||||
283 | |||||
284 | bool IsInteger(PyObject* py_value) { | ||||
285 | #if PY_MAJOR_VERSION3 >= 3 | ||||
286 | return PyLong_Check(py_value)((((((PyObject*)(py_value))->ob_type))->tp_flags & ( (1UL << 24))) != 0); | ||||
287 | #else | ||||
288 | return PyInt_Check(py_value) || PyLong_Check(py_value)((((((PyObject*)(py_value))->ob_type))->tp_flags & ( (1UL << 24))) != 0); | ||||
289 | #endif | ||||
290 | } | ||||
291 | |||||
292 | // This function considers a Dimension._value of None to be valid, and sets the | ||||
293 | // value to be -1 in that case. | ||||
294 | bool ParseDimensionValue(const string& key, PyObject* py_value, | ||||
295 | TF_Status* status, int64_t* value) { | ||||
296 | if (IsInteger(py_value)) { | ||||
297 | return ParseInt64Value(key, py_value, status, value); | ||||
298 | } | ||||
299 | |||||
300 | tensorflow::Safe_PyObjectPtr dimension_value( | ||||
301 | PyObject_GetAttrString(py_value, "_value")); | ||||
302 | if (dimension_value == nullptr) { | ||||
303 | PyErr_Clear(); | ||||
304 | TF_SetStatus( | ||||
305 | status, TF_INVALID_ARGUMENT, | ||||
306 | tensorflow::strings::StrCat("Expecting a Dimension for attr ", key, | ||||
307 | ", got ", py_value->ob_type->tp_name) | ||||
308 | .c_str()); | ||||
309 | return false; | ||||
310 | } | ||||
311 | |||||
312 | if (dimension_value.get() == Py_None(&_Py_NoneStruct)) { | ||||
313 | *value = -1; | ||||
314 | return true; | ||||
315 | } | ||||
316 | |||||
317 | return ParseInt64Value(key, dimension_value.get(), status, value); | ||||
318 | } | ||||
319 | |||||
320 | bool ParseStringValue(const string& key, PyObject* py_value, TF_Status* status, | ||||
321 | tensorflow::StringPiece* value) { | ||||
322 | if (PyBytes_Check(py_value)((((((PyObject*)(py_value))->ob_type))->tp_flags & ( (1UL << 27))) != 0)) { | ||||
323 | Py_ssize_t size = 0; | ||||
324 | char* buf = nullptr; | ||||
325 | if (PyBytes_AsStringAndSize(py_value, &buf, &size) < 0) return false; | ||||
326 | *value = tensorflow::StringPiece(buf, size); | ||||
327 | return true; | ||||
328 | } | ||||
329 | #if PY_MAJOR_VERSION3 >= 3 | ||||
330 | if (PyUnicode_Check(py_value)((((((PyObject*)(py_value))->ob_type))->tp_flags & ( (1UL << 28))) != 0)) { | ||||
331 | Py_ssize_t size = 0; | ||||
332 | const char* buf = PyUnicode_AsUTF8AndSize(py_value, &size); | ||||
333 | if (buf == nullptr) return false; | ||||
334 | *value = tensorflow::StringPiece(buf, size); | ||||
335 | return true; | ||||
336 | } | ||||
337 | #endif | ||||
338 | TF_SetStatus( | ||||
339 | status, TF_INVALID_ARGUMENT, | ||||
340 | tensorflow::strings::StrCat("Expecting a string value for attr ", key, | ||||
341 | ", got ", py_value->ob_type->tp_name) | ||||
342 | .c_str()); | ||||
343 | return false; | ||||
344 | } | ||||
345 | |||||
346 | bool ParseBoolValue(const string& key, PyObject* py_value, TF_Status* status, | ||||
347 | unsigned char* value) { | ||||
348 | *value = PyObject_IsTrue(py_value); | ||||
349 | return true; | ||||
350 | } | ||||
351 | |||||
352 | // The passed in py_value is expected to be an object of the python type | ||||
353 | // dtypes.DType or an int. | ||||
354 | bool ParseTypeValue(const string& key, PyObject* py_value, TF_Status* status, | ||||
355 | int* value) { | ||||
356 | if (IsInteger(py_value)) { | ||||
357 | return ParseIntValue(key, py_value, status, value); | ||||
358 | } | ||||
359 | |||||
360 | tensorflow::Safe_PyObjectPtr py_type_enum( | ||||
361 | PyObject_GetAttrString(py_value, "_type_enum")); | ||||
362 | if (py_type_enum == nullptr) { | ||||
363 | PyErr_Clear(); | ||||
364 | TF_SetStatus( | ||||
365 | status, TF_INVALID_ARGUMENT, | ||||
366 | tensorflow::strings::StrCat("Expecting a DType.dtype for attr ", key, | ||||
367 | ", got ", py_value->ob_type->tp_name) | ||||
368 | .c_str()); | ||||
369 | return false; | ||||
370 | } | ||||
371 | |||||
372 | return ParseIntValue(key, py_type_enum.get(), status, value); | ||||
373 | } | ||||
374 | |||||
375 | bool SetOpAttrList(TFE_Context* ctx, TFE_Op* op, const char* key, | ||||
376 | PyObject* py_list, TF_AttrType type, | ||||
377 | tensorflow::gtl::FlatMap<string, int64_t>* attr_list_sizes, | ||||
378 | TF_Status* status) { | ||||
379 | if (!PySequence_Check(py_list)) { | ||||
380 | TF_SetStatus( | ||||
381 | status, TF_INVALID_ARGUMENT, | ||||
382 | tensorflow::strings::StrCat("Expecting sequence value for attr ", key, | ||||
383 | ", got ", py_list->ob_type->tp_name) | ||||
384 | .c_str()); | ||||
385 | return false; | ||||
386 | } | ||||
387 | const int num_values = PySequence_Size(py_list); | ||||
388 | if (attr_list_sizes != nullptr) (*attr_list_sizes)[key] = num_values; | ||||
389 | |||||
390 | #define PARSE_LIST(c_type, parse_fn) \ | ||||
391 | std::unique_ptr<c_type[]> values(new c_type[num_values]); \ | ||||
392 | for (int i = 0; i < num_values; ++i) { \ | ||||
393 | tensorflow::Safe_PyObjectPtr py_value(PySequence_ITEM(py_list, i)( (((PyObject*)(py_list))->ob_type)->tp_as_sequence-> sq_item(py_list, i) )); \ | ||||
394 | if (!parse_fn(key, py_value.get(), status, &values[i])) return false; \ | ||||
395 | } | ||||
396 | |||||
397 | if (type == TF_ATTR_STRING) { | ||||
398 | std::unique_ptr<const void*[]> values(new const void*[num_values]); | ||||
399 | std::unique_ptr<size_t[]> lengths(new size_t[num_values]); | ||||
400 | for (int i = 0; i < num_values; ++i) { | ||||
401 | tensorflow::StringPiece value; | ||||
402 | tensorflow::Safe_PyObjectPtr py_value(PySequence_ITEM(py_list, i)( (((PyObject*)(py_list))->ob_type)->tp_as_sequence-> sq_item(py_list, i) )); | ||||
403 | if (!ParseStringValue(key, py_value.get(), status, &value)) return false; | ||||
404 | values[i] = value.data(); | ||||
405 | lengths[i] = value.size(); | ||||
406 | } | ||||
407 | TFE_OpSetAttrStringList(op, key, values.get(), lengths.get(), num_values); | ||||
408 | } else if (type == TF_ATTR_INT) { | ||||
409 | PARSE_LIST(int64_t, ParseInt64Value); | ||||
410 | TFE_OpSetAttrIntList(op, key, values.get(), num_values); | ||||
411 | } else if (type == TF_ATTR_FLOAT) { | ||||
412 | PARSE_LIST(float, ParseFloatValue); | ||||
413 | TFE_OpSetAttrFloatList(op, key, values.get(), num_values); | ||||
414 | } else if (type == TF_ATTR_BOOL) { | ||||
415 | PARSE_LIST(unsigned char, ParseBoolValue); | ||||
416 | TFE_OpSetAttrBoolList(op, key, values.get(), num_values); | ||||
417 | } else if (type == TF_ATTR_TYPE) { | ||||
418 | PARSE_LIST(int, ParseTypeValue); | ||||
419 | TFE_OpSetAttrTypeList(op, key, | ||||
420 | reinterpret_cast<const TF_DataType*>(values.get()), | ||||
421 | num_values); | ||||
422 | } else if (type == TF_ATTR_SHAPE) { | ||||
423 | // Make one pass through the input counting the total number of | ||||
424 | // dims across all the input lists. | ||||
425 | int total_dims = 0; | ||||
426 | for (int i = 0; i < num_values; ++i) { | ||||
427 | tensorflow::Safe_PyObjectPtr py_value(PySequence_ITEM(py_list, i)( (((PyObject*)(py_list))->ob_type)->tp_as_sequence-> sq_item(py_list, i) )); | ||||
428 | if (py_value.get() != Py_None(&_Py_NoneStruct)) { | ||||
429 | if (!PySequence_Check(py_value.get())) { | ||||
430 | TF_SetStatus( | ||||
431 | status, TF_INVALID_ARGUMENT, | ||||
432 | tensorflow::strings::StrCat( | ||||
433 | "Expecting None or sequence value for element", i, | ||||
434 | " of attr ", key, ", got ", py_value->ob_type->tp_name) | ||||
435 | .c_str()); | ||||
436 | return false; | ||||
437 | } | ||||
438 | const auto size = TensorShapeNumDims(py_value.get()); | ||||
439 | if (size >= 0) { | ||||
440 | total_dims += size; | ||||
441 | } | ||||
442 | } | ||||
443 | } | ||||
444 | // Allocate a buffer that can fit all of the dims together. | ||||
445 | std::unique_ptr<int64_t[]> buffer(new int64_t[total_dims]); | ||||
446 | // Copy the input dims into the buffer and set dims to point to | ||||
447 | // the start of each list's dims. | ||||
448 | std::unique_ptr<const int64_t*[]> dims(new const int64_t*[num_values]); | ||||
449 | std::unique_ptr<int[]> num_dims(new int[num_values]); | ||||
450 | int64_t* offset = buffer.get(); | ||||
451 | for (int i = 0; i < num_values; ++i) { | ||||
452 | tensorflow::Safe_PyObjectPtr py_value(PySequence_ITEM(py_list, i)( (((PyObject*)(py_list))->ob_type)->tp_as_sequence-> sq_item(py_list, i) )); | ||||
453 | if (py_value.get() == Py_None(&_Py_NoneStruct)) { | ||||
454 | dims[i] = nullptr; | ||||
455 | num_dims[i] = -1; | ||||
456 | } else { | ||||
457 | const auto size = TensorShapeNumDims(py_value.get()); | ||||
458 | if (size == -1) { | ||||
459 | dims[i] = nullptr; | ||||
460 | num_dims[i] = -1; | ||||
461 | continue; | ||||
462 | } | ||||
463 | dims[i] = offset; | ||||
464 | num_dims[i] = size; | ||||
465 | for (int j = 0; j < size; ++j) { | ||||
466 | tensorflow::Safe_PyObjectPtr inner_py_value( | ||||
467 | PySequence_ITEM(py_value.get(), j)( (((PyObject*)(py_value.get()))->ob_type)->tp_as_sequence ->sq_item(py_value.get(), j) )); | ||||
468 | if (inner_py_value.get() == Py_None(&_Py_NoneStruct)) { | ||||
469 | *offset = -1; | ||||
470 | } else if (!ParseDimensionValue(key, inner_py_value.get(), status, | ||||
471 | offset)) { | ||||
472 | return false; | ||||
473 | } | ||||
474 | ++offset; | ||||
475 | } | ||||
476 | } | ||||
477 | } | ||||
478 | TFE_OpSetAttrShapeList(op, key, dims.get(), num_dims.get(), num_values, | ||||
479 | status); | ||||
480 | if (!status->status.ok()) return false; | ||||
481 | } else if (type == TF_ATTR_FUNC) { | ||||
482 | std::unique_ptr<const TFE_Op*[]> funcs(new const TFE_Op*[num_values]); | ||||
483 | for (int i = 0; i < num_values; ++i) { | ||||
484 | tensorflow::Safe_PyObjectPtr py_value(PySequence_ITEM(py_list, i)( (((PyObject*)(py_list))->ob_type)->tp_as_sequence-> sq_item(py_list, i) )); | ||||
485 | // Allow: | ||||
486 | // (1) String function name, OR | ||||
487 | // (2) A Python object with a .name attribute | ||||
488 | // (A crude test for being a | ||||
489 | // tensorflow.python.framework.function._DefinedFunction) | ||||
490 | // (which is what the various "defun" or "Defun" decorators do). | ||||
491 | // And in the future also allow an object that can encapsulate | ||||
492 | // the function name and its attribute values. | ||||
493 | tensorflow::StringPiece func_name; | ||||
494 | if (!ParseStringValue(key, py_value.get(), status, &func_name)) { | ||||
495 | PyObject* name_attr = PyObject_GetAttrString(py_value.get(), "name"); | ||||
496 | if (name_attr == nullptr || | ||||
497 | !ParseStringValue(key, name_attr, status, &func_name)) { | ||||
498 | TF_SetStatus( | ||||
499 | status, TF_INVALID_ARGUMENT, | ||||
500 | tensorflow::strings::StrCat( | ||||
501 | "unable to set function value attribute from a ", | ||||
502 | py_value.get()->ob_type->tp_name, | ||||
503 | " object. If you think this is an error, please file an " | ||||
504 | "issue at " | ||||
505 | "https://github.com/tensorflow/tensorflow/issues/new") | ||||
506 | .c_str()); | ||||
507 | return false; | ||||
508 | } | ||||
509 | } | ||||
510 | funcs[i] = TFE_NewOp(ctx, func_name.data(), status); | ||||
511 | if (!status->status.ok()) return false; | ||||
512 | } | ||||
513 | TFE_OpSetAttrFunctionList(op, key, funcs.get(), num_values); | ||||
514 | if (!status->status.ok()) return false; | ||||
515 | } else { | ||||
516 | TF_SetStatus(status, TF_UNIMPLEMENTED, | ||||
517 | tensorflow::strings::StrCat("Attr ", key, | ||||
518 | " has unhandled list type ", type) | ||||
519 | .c_str()); | ||||
520 | return false; | ||||
521 | } | ||||
522 | #undef PARSE_LIST | ||||
523 | return true; | ||||
524 | } | ||||
525 | |||||
526 | TFE_Op* GetFunc(TFE_Context* ctx, const tensorflow::NameAttrList& func, | ||||
527 | TF_Status* status) { | ||||
528 | TFE_Op* func_op = TFE_NewOp(ctx, func.name().data(), status); | ||||
529 | for (const auto& attr : func.attr()) { | ||||
530 | if (!status->status.ok()) return nullptr; | ||||
531 | SetOpAttrValueScalar(ctx, func_op, attr.second, attr.first.data(), status); | ||||
532 | if (!status->status.ok()) return nullptr; | ||||
533 | } | ||||
534 | return func_op; | ||||
535 | } | ||||
536 | |||||
537 | void SetOpAttrListDefault( | ||||
538 | TFE_Context* ctx, TFE_Op* op, const tensorflow::OpDef::AttrDef& attr, | ||||
539 | const char* key, TF_AttrType type, | ||||
540 | tensorflow::gtl::FlatMap<string, int64_t>* attr_list_sizes, | ||||
541 | TF_Status* status) { | ||||
542 | if (type == TF_ATTR_STRING) { | ||||
543 | int num_values = attr.default_value().list().s_size(); | ||||
544 | std::unique_ptr<const void*[]> values(new const void*[num_values]); | ||||
545 | std::unique_ptr<size_t[]> lengths(new size_t[num_values]); | ||||
546 | (*attr_list_sizes)[key] = num_values; | ||||
547 | for (int i = 0; i < num_values; i++) { | ||||
548 | const string& v = attr.default_value().list().s(i); | ||||
549 | values[i] = v.data(); | ||||
550 | lengths[i] = v.size(); | ||||
551 | } | ||||
552 | TFE_OpSetAttrStringList(op, key, values.get(), lengths.get(), num_values); | ||||
553 | } else if (type == TF_ATTR_INT) { | ||||
554 | int num_values = attr.default_value().list().i_size(); | ||||
555 | std::unique_ptr<int64_t[]> values(new int64_t[num_values]); | ||||
556 | (*attr_list_sizes)[key] = num_values; | ||||
557 | for (int i = 0; i < num_values; i++) { | ||||
558 | values[i] = attr.default_value().list().i(i); | ||||
559 | } | ||||
560 | TFE_OpSetAttrIntList(op, key, values.get(), num_values); | ||||
561 | } else if (type == TF_ATTR_FLOAT) { | ||||
562 | int num_values = attr.default_value().list().f_size(); | ||||
563 | std::unique_ptr<float[]> values(new float[num_values]); | ||||
564 | (*attr_list_sizes)[key] = num_values; | ||||
565 | for (int i = 0; i < num_values; i++) { | ||||
566 | values[i] = attr.default_value().list().f(i); | ||||
567 | } | ||||
568 | TFE_OpSetAttrFloatList(op, key, values.get(), num_values); | ||||
569 | } else if (type == TF_ATTR_BOOL) { | ||||
570 | int num_values = attr.default_value().list().b_size(); | ||||
571 | std::unique_ptr<unsigned char[]> values(new unsigned char[num_values]); | ||||
572 | (*attr_list_sizes)[key] = num_values; | ||||
573 | for (int i = 0; i < num_values; i++) { | ||||
574 | values[i] = attr.default_value().list().b(i); | ||||
575 | } | ||||
576 | TFE_OpSetAttrBoolList(op, key, values.get(), num_values); | ||||
577 | } else if (type == TF_ATTR_TYPE) { | ||||
578 | int num_values = attr.default_value().list().type_size(); | ||||
579 | std::unique_ptr<int[]> values(new int[num_values]); | ||||
580 | (*attr_list_sizes)[key] = num_values; | ||||
581 | for (int i = 0; i < num_values; i++) { | ||||
582 | values[i] = attr.default_value().list().type(i); | ||||
583 | } | ||||
584 | TFE_OpSetAttrTypeList(op, key, | ||||
585 | reinterpret_cast<const TF_DataType*>(values.get()), | ||||
586 | attr.default_value().list().type_size()); | ||||
587 | } else if (type == TF_ATTR_SHAPE) { | ||||
588 | int num_values = attr.default_value().list().shape_size(); | ||||
589 | (*attr_list_sizes)[key] = num_values; | ||||
590 | int total_dims = 0; | ||||
591 | for (int i = 0; i < num_values; ++i) { | ||||
592 | if (!attr.default_value().list().shape(i).unknown_rank()) { | ||||
593 | total_dims += attr.default_value().list().shape(i).dim_size(); | ||||
594 | } | ||||
595 | } | ||||
596 | // Allocate a buffer that can fit all of the dims together. | ||||
597 | std::unique_ptr<int64_t[]> buffer(new int64_t[total_dims]); | ||||
598 | // Copy the input dims into the buffer and set dims to point to | ||||
599 | // the start of each list's dims. | ||||
600 | std::unique_ptr<const int64_t*[]> dims(new const int64_t*[num_values]); | ||||
601 | std::unique_ptr<int[]> num_dims(new int[num_values]); | ||||
602 | int64_t* offset = buffer.get(); | ||||
603 | for (int i = 0; i < num_values; ++i) { | ||||
604 | const auto& shape = attr.default_value().list().shape(i); | ||||
605 | if (shape.unknown_rank()) { | ||||
606 | dims[i] = nullptr; | ||||
607 | num_dims[i] = -1; | ||||
608 | } else { | ||||
609 | for (int j = 0; j < shape.dim_size(); j++) { | ||||
610 | *offset = shape.dim(j).size(); | ||||
611 | ++offset; | ||||
612 | } | ||||
613 | } | ||||
614 | } | ||||
615 | TFE_OpSetAttrShapeList(op, key, dims.get(), num_dims.get(), num_values, | ||||
616 | status); | ||||
617 | } else if (type == TF_ATTR_FUNC) { | ||||
618 | int num_values = attr.default_value().list().func_size(); | ||||
619 | (*attr_list_sizes)[key] = num_values; | ||||
620 | std::unique_ptr<const TFE_Op*[]> funcs(new const TFE_Op*[num_values]); | ||||
621 | for (int i = 0; i < num_values; i++) { | ||||
622 | funcs[i] = GetFunc(ctx, attr.default_value().list().func(i), status); | ||||
623 | } | ||||
624 | TFE_OpSetAttrFunctionList(op, key, funcs.get(), num_values); | ||||
625 | } else { | ||||
626 | TF_SetStatus(status, TF_UNIMPLEMENTED, | ||||
627 | "Lists of tensors are not yet implemented for default valued " | ||||
628 | "attributes for an operation."); | ||||
629 | } | ||||
630 | } | ||||
631 | |||||
632 | bool SetOpAttrScalar(TFE_Context* ctx, TFE_Op* op, const char* key, | ||||
633 | PyObject* py_value, TF_AttrType type, | ||||
634 | tensorflow::gtl::FlatMap<string, int64_t>* attr_list_sizes, | ||||
635 | TF_Status* status) { | ||||
636 | if (type == TF_ATTR_STRING) { | ||||
637 | tensorflow::StringPiece value; | ||||
638 | if (!ParseStringValue(key, py_value, status, &value)) return false; | ||||
639 | TFE_OpSetAttrString(op, key, value.data(), value.size()); | ||||
640 | } else if (type == TF_ATTR_INT) { | ||||
641 | int64_t value; | ||||
642 | if (!ParseInt64Value(key, py_value, status, &value)) return false; | ||||
643 | TFE_OpSetAttrInt(op, key, value); | ||||
644 | // attr_list_sizes is set for all int attributes (since at this point we are | ||||
645 | // not aware if that attribute might be used to calculate the size of an | ||||
646 | // output list or not). | ||||
647 | if (attr_list_sizes != nullptr) (*attr_list_sizes)[key] = value; | ||||
648 | } else if (type == TF_ATTR_FLOAT) { | ||||
649 | float value; | ||||
650 | if (!ParseFloatValue(key, py_value, status, &value)) return false; | ||||
651 | TFE_OpSetAttrFloat(op, key, value); | ||||
652 | } else if (type == TF_ATTR_BOOL) { | ||||
653 | unsigned char value; | ||||
654 | if (!ParseBoolValue(key, py_value, status, &value)) return false; | ||||
655 | TFE_OpSetAttrBool(op, key, value); | ||||
656 | } else if (type == TF_ATTR_TYPE) { | ||||
657 | int value; | ||||
658 | if (!ParseTypeValue(key, py_value, status, &value)) return false; | ||||
659 | TFE_OpSetAttrType(op, key, static_cast<TF_DataType>(value)); | ||||
660 | } else if (type == TF_ATTR_SHAPE) { | ||||
661 | if (py_value == Py_None(&_Py_NoneStruct)) { | ||||
662 | TFE_OpSetAttrShape(op, key, nullptr, -1, status); | ||||
663 | } else { | ||||
664 | if (!PySequence_Check(py_value)) { | ||||
665 | TF_SetStatus(status, TF_INVALID_ARGUMENT, | ||||
666 | tensorflow::strings::StrCat( | ||||
667 | "Expecting None or sequence value for attr", key, | ||||
668 | ", got ", py_value->ob_type->tp_name) | ||||
669 | .c_str()); | ||||
670 | return false; | ||||
671 | } | ||||
672 | const auto num_dims = TensorShapeNumDims(py_value); | ||||
673 | if (num_dims == -1) { | ||||
674 | TFE_OpSetAttrShape(op, key, nullptr, -1, status); | ||||
675 | return true; | ||||
676 | } | ||||
677 | std::unique_ptr<int64_t[]> dims(new int64_t[num_dims]); | ||||
678 | for (int i = 0; i < num_dims; ++i) { | ||||
679 | tensorflow::Safe_PyObjectPtr inner_py_value( | ||||
680 | PySequence_ITEM(py_value, i)( (((PyObject*)(py_value))->ob_type)->tp_as_sequence-> sq_item(py_value, i) )); | ||||
681 | if (inner_py_value.get() == Py_None(&_Py_NoneStruct)) { | ||||
682 | dims[i] = -1; | ||||
683 | } else if (!ParseDimensionValue(key, inner_py_value.get(), status, | ||||
684 | &dims[i])) { | ||||
685 | return false; | ||||
686 | } | ||||
687 | } | ||||
688 | TFE_OpSetAttrShape(op, key, dims.get(), num_dims, status); | ||||
689 | } | ||||
690 | if (!status->status.ok()) return false; | ||||
691 | } else if (type == TF_ATTR_FUNC) { | ||||
692 | // Allow: | ||||
693 | // (1) String function name, OR | ||||
694 | // (2) A Python object with a .name attribute | ||||
695 | // (A crude test for being a | ||||
696 | // tensorflow.python.framework.function._DefinedFunction) | ||||
697 | // (which is what the various "defun" or "Defun" decorators do). | ||||
698 | // And in the future also allow an object that can encapsulate | ||||
699 | // the function name and its attribute values. | ||||
700 | tensorflow::StringPiece func_name; | ||||
701 | if (!ParseStringValue(key, py_value, status, &func_name)) { | ||||
702 | PyObject* name_attr = PyObject_GetAttrString(py_value, "name"); | ||||
703 | if (name_attr == nullptr || | ||||
704 | !ParseStringValue(key, name_attr, status, &func_name)) { | ||||
705 | TF_SetStatus( | ||||
706 | status, TF_INVALID_ARGUMENT, | ||||
707 | tensorflow::strings::StrCat( | ||||
708 | "unable to set function value attribute from a ", | ||||
709 | py_value->ob_type->tp_name, | ||||
710 | " object. If you think this is an error, please file an issue " | ||||
711 | "at https://github.com/tensorflow/tensorflow/issues/new") | ||||
712 | .c_str()); | ||||
713 | return false; | ||||
714 | } | ||||
715 | } | ||||
716 | TF_SetStatus(status, TF_OK, ""); | ||||
717 | TFE_OpSetAttrFunctionName(op, key, func_name.data(), func_name.size()); | ||||
718 | } else { | ||||
719 | TF_SetStatus( | ||||
720 | status, TF_UNIMPLEMENTED, | ||||
721 | tensorflow::strings::StrCat("Attr ", key, " has unhandled type ", type) | ||||
722 | .c_str()); | ||||
723 | return false; | ||||
724 | } | ||||
725 | return true; | ||||
726 | } | ||||
727 | |||||
728 | void SetOpAttrScalarDefault( | ||||
729 | TFE_Context* ctx, TFE_Op* op, const tensorflow::AttrValue& default_value, | ||||
730 | const char* attr_name, | ||||
731 | tensorflow::gtl::FlatMap<string, int64_t>* attr_list_sizes, | ||||
732 | TF_Status* status) { | ||||
733 | SetOpAttrValueScalar(ctx, op, default_value, attr_name, status); | ||||
734 | if (default_value.value_case() == tensorflow::AttrValue::kI) { | ||||
735 | (*attr_list_sizes)[attr_name] = default_value.i(); | ||||
736 | } | ||||
737 | } | ||||
738 | |||||
739 | // start_index is the index at which the Tuple/List attrs will start getting | ||||
740 | // processed. | ||||
741 | void SetOpAttrs(TFE_Context* ctx, TFE_Op* op, PyObject* attrs, int start_index, | ||||
742 | TF_Status* out_status) { | ||||
743 | if (attrs == Py_None(&_Py_NoneStruct)) return; | ||||
744 | Py_ssize_t len = PyTuple_GET_SIZE(attrs)(((PyVarObject*)(((static_cast<void> (0)), (PyTupleObject *)(attrs))))->ob_size) - start_index; | ||||
745 | if ((len & 1) != 0) { | ||||
746 | TF_SetStatus(out_status, TF_INVALID_ARGUMENT, | ||||
747 | "Expecting attrs tuple to have even length."); | ||||
748 | return; | ||||
749 | } | ||||
750 | // Parse attrs | ||||
751 | for (Py_ssize_t i = 0; i < len; i += 2) { | ||||
752 | PyObject* py_key = PyTuple_GET_ITEM(attrs, start_index + i)(((static_cast<void> (0)), (PyTupleObject *)(attrs))-> ob_item[start_index + i]); | ||||
753 | PyObject* py_value = PyTuple_GET_ITEM(attrs, start_index + i + 1)(((static_cast<void> (0)), (PyTupleObject *)(attrs))-> ob_item[start_index + i + 1]); | ||||
754 | #if PY_MAJOR_VERSION3 >= 3 | ||||
755 | const char* key = PyBytes_Check(py_key)((((((PyObject*)(py_key))->ob_type))->tp_flags & (( 1UL << 27))) != 0) ? PyBytes_AsString(py_key) | ||||
756 | : PyUnicode_AsUTF8(py_key); | ||||
757 | #else | ||||
758 | const char* key = PyBytes_AsString(py_key); | ||||
759 | #endif | ||||
760 | unsigned char is_list = 0; | ||||
761 | const TF_AttrType type = TFE_OpGetAttrType(op, key, &is_list, out_status); | ||||
762 | if (!out_status->status.ok()) return; | ||||
763 | if (is_list != 0) { | ||||
764 | if (!SetOpAttrList(ctx, op, key, py_value, type, nullptr, out_status)) | ||||
765 | return; | ||||
766 | } else { | ||||
767 | if (!SetOpAttrScalar(ctx, op, key, py_value, type, nullptr, out_status)) | ||||
768 | return; | ||||
769 | } | ||||
770 | } | ||||
771 | } | ||||
772 | |||||
773 | // This function will set the op attrs required. If an attr has the value of | ||||
774 | // None, then it will read the AttrDef to get the default value and set that | ||||
775 | // instead. Any failure in this function will simply fall back to the slow | ||||
776 | // path. | ||||
777 | void SetOpAttrWithDefaults( | ||||
778 | TFE_Context* ctx, TFE_Op* op, const tensorflow::OpDef::AttrDef& attr, | ||||
779 | const char* attr_name, PyObject* attr_value, | ||||
780 | tensorflow::gtl::FlatMap<string, int64_t>* attr_list_sizes, | ||||
781 | TF_Status* status) { | ||||
782 | unsigned char is_list = 0; | ||||
783 | const TF_AttrType type = TFE_OpGetAttrType(op, attr_name, &is_list, status); | ||||
784 | if (!status->status.ok()) return; | ||||
785 | if (attr_value == Py_None(&_Py_NoneStruct)) { | ||||
786 | if (is_list != 0) { | ||||
787 | SetOpAttrListDefault(ctx, op, attr, attr_name, type, attr_list_sizes, | ||||
788 | status); | ||||
789 | } else { | ||||
790 | SetOpAttrScalarDefault(ctx, op, attr.default_value(), attr_name, | ||||
791 | attr_list_sizes, status); | ||||
792 | } | ||||
793 | } else { | ||||
794 | if (is_list != 0) { | ||||
795 | SetOpAttrList(ctx, op, attr_name, attr_value, type, attr_list_sizes, | ||||
796 | status); | ||||
797 | } else { | ||||
798 | SetOpAttrScalar(ctx, op, attr_name, attr_value, type, attr_list_sizes, | ||||
799 | status); | ||||
800 | } | ||||
801 | } | ||||
802 | } | ||||
803 | |||||
804 | PyObject* GetPythonObjectFromInt(int num) { | ||||
805 | #if PY_MAJOR_VERSION3 >= 3 | ||||
806 | return PyLong_FromLong(num); | ||||
807 | #else | ||||
808 | return PyInt_FromLong(num); | ||||
809 | #endif | ||||
810 | } | ||||
811 | |||||
812 | // Python subclass of Exception that is created on not ok Status. | ||||
813 | tensorflow::mutex exception_class_mutex(tensorflow::LINKER_INITIALIZED); | ||||
814 | PyObject* exception_class TF_GUARDED_BY(exception_class_mutex)__attribute__((guarded_by(exception_class_mutex))) = nullptr; | ||||
815 | |||||
816 | // Python subclass of Exception that is created to signal fallback. | ||||
817 | PyObject* fallback_exception_class = nullptr; | ||||
818 | |||||
819 | // Python function that returns input gradients given output gradients. | ||||
820 | PyObject* gradient_function = nullptr; | ||||
821 | |||||
822 | // Python function that returns output gradients given input gradients. | ||||
823 | PyObject* forward_gradient_function = nullptr; | ||||
824 | |||||
825 | static std::atomic<int64_t> _uid; | ||||
826 | |||||
827 | } // namespace | ||||
828 | |||||
829 | TF_Status* GetStatus() { | ||||
830 | TF_Status* maybe_status = ReleaseThreadLocalStatus(); | ||||
831 | if (maybe_status) { | ||||
832 | TF_SetStatus(maybe_status, TF_OK, ""); | ||||
833 | return maybe_status; | ||||
834 | } else { | ||||
835 | return TF_NewStatus(); | ||||
836 | } | ||||
837 | } | ||||
838 | |||||
839 | void ReturnStatus(TF_Status* status) { | ||||
840 | TF_SetStatus(status, TF_OK, ""); | ||||
841 | thread_local_tf_status.reset(status); | ||||
842 | } | ||||
843 | |||||
844 | void TFE_Py_Execute(TFE_Context* ctx, const char* device_name, | ||||
845 | const char* op_name, TFE_InputTensorHandles* inputs, | ||||
846 | PyObject* attrs, TFE_OutputTensorHandles* outputs, | ||||
847 | TF_Status* out_status) { | ||||
848 | TFE_Py_ExecuteCancelable(ctx, device_name, op_name, inputs, attrs, | ||||
849 | /*cancellation_manager=*/nullptr, outputs, | ||||
850 | out_status); | ||||
851 | } | ||||
852 | |||||
853 | void TFE_Py_ExecuteCancelable(TFE_Context* ctx, const char* device_name, | ||||
854 | const char* op_name, | ||||
855 | TFE_InputTensorHandles* inputs, PyObject* attrs, | ||||
856 | TFE_CancellationManager* cancellation_manager, | ||||
857 | TFE_OutputTensorHandles* outputs, | ||||
858 | TF_Status* out_status) { | ||||
859 | tensorflow::profiler::TraceMe activity( | ||||
860 | "TFE_Py_ExecuteCancelable", tensorflow::profiler::TraceMeLevel::kInfo); | ||||
861 | |||||
862 | TFE_Op* op = GetOp(ctx, op_name, device_name, out_status); | ||||
863 | |||||
864 | auto cleaner = tensorflow::gtl::MakeCleanup([ctx, op] { ReturnOp(ctx, op); }); | ||||
865 | if (!out_status->status.ok()) return; | ||||
866 | |||||
867 | tensorflow::unwrap(op)->SetStackTrace(tensorflow::GetStackTrace( | ||||
868 | tensorflow::StackTrace::kStackTraceInitialSize)); | ||||
869 | |||||
870 | for (int i = 0; i < inputs->size() && out_status->status.ok(); ++i) { | ||||
871 | TFE_OpAddInput(op, inputs->at(i), out_status); | ||||
872 | } | ||||
873 | if (cancellation_manager && out_status->status.ok()) { | ||||
874 | TFE_OpSetCancellationManager(op, cancellation_manager, out_status); | ||||
875 | } | ||||
876 | if (out_status->status.ok()) { | ||||
877 | SetOpAttrs(ctx, op, attrs, 0, out_status); | ||||
878 | } | ||||
879 | Py_BEGIN_ALLOW_THREADS{ PyThreadState *_save; _save = PyEval_SaveThread();; | ||||
880 | |||||
881 | int num_outputs = outputs->size(); | ||||
882 | |||||
883 | if (out_status->status.ok()) { | ||||
884 | TFE_Execute(op, outputs->data(), &num_outputs, out_status); | ||||
885 | } | ||||
886 | |||||
887 | if (out_status->status.ok()) { | ||||
888 | outputs->resize(num_outputs); | ||||
889 | } else { | ||||
890 | TF_SetStatus(out_status, TF_GetCode(out_status), | ||||
891 | tensorflow::strings::StrCat(TF_Message(out_status), | ||||
892 | " [Op:", op_name, "]") | ||||
893 | .c_str()); | ||||
894 | } | ||||
895 | |||||
896 | Py_END_ALLOW_THREADSPyEval_RestoreThread(_save); }; | ||||
897 | } | ||||
898 | |||||
899 | PyObject* TFE_Py_RegisterExceptionClass(PyObject* e) { | ||||
900 | tensorflow::mutex_lock l(exception_class_mutex); | ||||
901 | if (exception_class != nullptr) { | ||||
902 | Py_DECREF(exception_class)_Py_DECREF(((PyObject*)(exception_class))); | ||||
903 | } | ||||
904 | if (PyObject_IsSubclass(e, PyExc_Exception) <= 0) { | ||||
905 | exception_class = nullptr; | ||||
906 | PyErr_SetString(PyExc_TypeError, | ||||
907 | "TFE_Py_RegisterExceptionClass: " | ||||
908 | "Registered class should be subclass of Exception."); | ||||
909 | return nullptr; | ||||
910 | } | ||||
911 | |||||
912 | Py_INCREF(e)_Py_INCREF(((PyObject*)(e))); | ||||
913 | exception_class = e; | ||||
914 | Py_RETURN_NONEreturn _Py_INCREF(((PyObject*)((&_Py_NoneStruct)))), (& _Py_NoneStruct); | ||||
915 | } | ||||
916 | |||||
917 | PyObject* TFE_Py_RegisterFallbackExceptionClass(PyObject* e) { | ||||
918 | if (fallback_exception_class != nullptr) { | ||||
919 | Py_DECREF(fallback_exception_class)_Py_DECREF(((PyObject*)(fallback_exception_class))); | ||||
920 | } | ||||
921 | if (PyObject_IsSubclass(e, PyExc_Exception) <= 0) { | ||||
922 | fallback_exception_class = nullptr; | ||||
923 | PyErr_SetString(PyExc_TypeError, | ||||
924 | "TFE_Py_RegisterFallbackExceptionClass: " | ||||
925 | "Registered class should be subclass of Exception."); | ||||
926 | return nullptr; | ||||
927 | } else { | ||||
928 | Py_INCREF(e)_Py_INCREF(((PyObject*)(e))); | ||||
929 | fallback_exception_class = e; | ||||
930 | Py_RETURN_NONEreturn _Py_INCREF(((PyObject*)((&_Py_NoneStruct)))), (& _Py_NoneStruct); | ||||
931 | } | ||||
932 | } | ||||
933 | |||||
934 | PyObject* TFE_Py_RegisterGradientFunction(PyObject* e) { | ||||
935 | if (gradient_function != nullptr) { | ||||
936 | Py_DECREF(gradient_function)_Py_DECREF(((PyObject*)(gradient_function))); | ||||
937 | } | ||||
938 | if (!PyCallable_Check(e)) { | ||||
939 | gradient_function = nullptr; | ||||
940 | PyErr_SetString(PyExc_TypeError, | ||||
941 | "TFE_Py_RegisterGradientFunction: " | ||||
942 | "Registered object should be function."); | ||||
943 | return nullptr; | ||||
944 | } else { | ||||
945 | Py_INCREF(e)_Py_INCREF(((PyObject*)(e))); | ||||
946 | gradient_function = e; | ||||
947 | Py_RETURN_NONEreturn _Py_INCREF(((PyObject*)((&_Py_NoneStruct)))), (& _Py_NoneStruct); | ||||
948 | } | ||||
949 | } | ||||
950 | |||||
951 | PyObject* TFE_Py_RegisterJVPFunction(PyObject* e) { | ||||
952 | if (forward_gradient_function != nullptr) { | ||||
953 | Py_DECREF(forward_gradient_function)_Py_DECREF(((PyObject*)(forward_gradient_function))); | ||||
954 | } | ||||
955 | if (!PyCallable_Check(e)) { | ||||
956 | forward_gradient_function = nullptr; | ||||
957 | PyErr_SetString(PyExc_TypeError, | ||||
958 | "TFE_Py_RegisterJVPFunction: " | ||||
959 | "Registered object should be function."); | ||||
960 | return nullptr; | ||||
961 | } else { | ||||
962 | Py_INCREF(e)_Py_INCREF(((PyObject*)(e))); | ||||
963 | forward_gradient_function = e; | ||||
964 | Py_RETURN_NONEreturn _Py_INCREF(((PyObject*)((&_Py_NoneStruct)))), (& _Py_NoneStruct); | ||||
965 | } | ||||
966 | } | ||||
967 | |||||
968 | void RaiseFallbackException(const char* message) { | ||||
969 | if (fallback_exception_class != nullptr) { | ||||
970 | PyErr_SetString(fallback_exception_class, message); | ||||
971 | return; | ||||
972 | } | ||||
973 | |||||
974 | PyErr_SetString( | ||||
975 | PyExc_RuntimeError, | ||||
976 | tensorflow::strings::StrCat( | ||||
977 | "Fallback exception type not set, attempting to fallback due to ", | ||||
978 | message) | ||||
979 | .data()); | ||||
980 | } | ||||
981 | |||||
982 | // Format and return `status`' error message with the attached stack trace if | ||||
983 | // available. `status` must have an error. | ||||
984 | std::string FormatErrorStatusStackTrace(const tensorflow::Status& status) { | ||||
985 | tensorflow::DCheckPyGilState(); | ||||
986 | DCHECK(!status.ok())while (false && (!status.ok())) ::tensorflow::internal ::LogMessageFatal("tensorflow/python/eager/pywrap_tfe_src.cc" , 986); | ||||
987 | |||||
988 | if (status.stack_trace().empty()) return status.error_message(); | ||||
989 | |||||
990 | const std::vector<tensorflow::StackFrame>& stack_trace = status.stack_trace(); | ||||
991 | |||||
992 | PyObject* linecache = PyImport_ImportModule("linecache"); | ||||
993 | PyObject* getline = | ||||
994 | PyObject_GetAttr(linecache, PyUnicode_FromString("getline")); | ||||
995 | DCHECK(getline)while (false && (getline)) ::tensorflow::internal::LogMessageFatal ("tensorflow/python/eager/pywrap_tfe_src.cc", 995); | ||||
996 | |||||
997 | std::ostringstream result; | ||||
998 | result << "Exception originated from\n\n"; | ||||
999 | |||||
1000 | for (const tensorflow::StackFrame& stack_frame : stack_trace) { | ||||
1001 | PyObject* line_str_obj = PyObject_CallFunction( | ||||
1002 | getline, const_cast<char*>("si"), stack_frame.file_name.c_str(), | ||||
1003 | stack_frame.line_number); | ||||
1004 | tensorflow::StringPiece line_str = TFE_GetPythonString(line_str_obj); | ||||
1005 | tensorflow::str_util::RemoveWhitespaceContext(&line_str); | ||||
1006 | result << " File \"" << stack_frame.file_name << "\", line " | ||||
1007 | << stack_frame.line_number << ", in " << stack_frame.function_name | ||||
1008 | << '\n'; | ||||
1009 | |||||
1010 | if (!line_str.empty()) result << " " << line_str << '\n'; | ||||
1011 | Py_XDECREF(line_str_obj)_Py_XDECREF(((PyObject*)(line_str_obj))); | ||||
1012 | } | ||||
1013 | |||||
1014 | Py_DecRef(getline); | ||||
1015 | Py_DecRef(linecache); | ||||
1016 | |||||
1017 | result << '\n' << status.error_message(); | ||||
1018 | return result.str(); | ||||
1019 | } | ||||
1020 | |||||
1021 | int MaybeRaiseExceptionFromTFStatus(TF_Status* status, PyObject* exception) { | ||||
1022 | if (status->status.ok()) return 0; | ||||
1023 | const char* msg = TF_Message(status); | ||||
1024 | if (exception == nullptr) { | ||||
1025 | tensorflow::mutex_lock l(exception_class_mutex); | ||||
1026 | if (exception_class != nullptr) { | ||||
1027 | tensorflow::Safe_PyObjectPtr payloads(PyDict_New()); | ||||
1028 | for (const auto& payload : | ||||
1029 | tensorflow::errors::GetPayloads(status->status)) { | ||||
1030 | PyDict_SetItem(payloads.get(), | ||||
1031 | PyBytes_FromString(payload.first.c_str()), | ||||
1032 | PyBytes_FromString(payload.second.c_str())); | ||||
1033 | } | ||||
1034 | tensorflow::Safe_PyObjectPtr val(Py_BuildValue( | ||||
1035 | "siO", FormatErrorStatusStackTrace(status->status).c_str(), | ||||
1036 | TF_GetCode(status), payloads.get())); | ||||
1037 | if (PyErr_Occurred()) { | ||||
1038 | // NOTE: This hides the actual error (i.e. the reason `status` was not | ||||
1039 | // TF_OK), but there is nothing we can do at this point since we can't | ||||
1040 | // generate a reasonable error from the status. | ||||
1041 | // Consider adding a message explaining this. | ||||
1042 | return -1; | ||||
1043 | } | ||||
1044 | PyErr_SetObject(exception_class, val.get()); | ||||
1045 | return -1; | ||||
1046 | } else { | ||||
1047 | exception = PyExc_RuntimeError; | ||||
1048 | } | ||||
1049 | } | ||||
1050 | // May be update already set exception. | ||||
1051 | PyErr_SetString(exception, msg); | ||||
1052 | return -1; | ||||
1053 | } | ||||
1054 | |||||
1055 | int MaybeRaiseExceptionFromStatus(const tensorflow::Status& status, | ||||
1056 | PyObject* exception) { | ||||
1057 | if (status.ok()) return 0; | ||||
1058 | const char* msg = status.error_message().c_str(); | ||||
1059 | if (exception == nullptr) { | ||||
1060 | tensorflow::mutex_lock l(exception_class_mutex); | ||||
1061 | if (exception_class != nullptr) { | ||||
1062 | tensorflow::Safe_PyObjectPtr payloads(PyDict_New()); | ||||
1063 | for (const auto& element : tensorflow::errors::GetPayloads(status)) { | ||||
1064 | PyDict_SetItem(payloads.get(), | ||||
1065 | PyBytes_FromString(element.first.c_str()), | ||||
1066 | PyBytes_FromString(element.second.c_str())); | ||||
1067 | } | ||||
1068 | tensorflow::Safe_PyObjectPtr val( | ||||
1069 | Py_BuildValue("siO", FormatErrorStatusStackTrace(status).c_str(), | ||||
1070 | status.code(), payloads.get())); | ||||
1071 | PyErr_SetObject(exception_class, val.get()); | ||||
1072 | return -1; | ||||
1073 | } else { | ||||
1074 | exception = PyExc_RuntimeError; | ||||
1075 | } | ||||
1076 | } | ||||
1077 | // May be update already set exception. | ||||
1078 | PyErr_SetString(exception, msg); | ||||
1079 | return -1; | ||||
1080 | } | ||||
1081 | |||||
1082 | const char* TFE_GetPythonString(PyObject* o) { | ||||
1083 | #if PY_MAJOR_VERSION3 >= 3 | ||||
1084 | if (PyBytes_Check(o)((((((PyObject*)(o))->ob_type))->tp_flags & ((1UL << 27))) != 0)) { | ||||
1085 | return PyBytes_AsString(o); | ||||
1086 | } else { | ||||
1087 | return PyUnicode_AsUTF8(o); | ||||
1088 | } | ||||
1089 | #else | ||||
1090 | return PyBytes_AsString(o); | ||||
1091 | #endif | ||||
1092 | } | ||||
1093 | |||||
1094 | int64_t get_uid() { return _uid++; } | ||||
1095 | |||||
1096 | PyObject* TFE_Py_UID() { return PyLong_FromLongLong(get_uid()); } | ||||
1097 | |||||
1098 | void TFE_DeleteContextCapsule(PyObject* context) { | ||||
1099 | TFE_Context* ctx = | ||||
1100 | reinterpret_cast<TFE_Context*>(PyCapsule_GetPointer(context, nullptr)); | ||||
1101 | auto op = ReleaseThreadLocalOp(ctx); | ||||
1102 | op.reset(); | ||||
1103 | TFE_DeleteContext(ctx); | ||||
1104 | } | ||||
1105 | |||||
1106 | static int64_t MakeInt(PyObject* integer) { | ||||
1107 | #if PY_MAJOR_VERSION3 >= 3 | ||||
1108 | return PyLong_AsLong(integer); | ||||
1109 | #else | ||||
1110 | return PyInt_AsLong(integer); | ||||
1111 | #endif | ||||
1112 | } | ||||
1113 | |||||
1114 | static int64_t FastTensorId(PyObject* tensor) { | ||||
1115 | if (EagerTensor_CheckExact(tensor)) { | ||||
1116 | return PyEagerTensor_ID(tensor); | ||||
1117 | } | ||||
1118 | PyObject* id_field = PyObject_GetAttrString(tensor, "_id"); | ||||
1119 | if (id_field == nullptr) { | ||||
1120 | return -1; | ||||
1121 | } | ||||
1122 | int64_t id = MakeInt(id_field); | ||||
1123 | Py_DECREF(id_field)_Py_DECREF(((PyObject*)(id_field))); | ||||
1124 | return id; | ||||
1125 | } | ||||
1126 | |||||
1127 | namespace tensorflow { | ||||
1128 | DataType PyTensor_DataType(PyObject* tensor) { | ||||
1129 | if (EagerTensor_CheckExact(tensor)) { | ||||
1130 | return PyEagerTensor_Dtype(tensor); | ||||
1131 | } else { | ||||
1132 | #if PY_MAJOR_VERSION3 < 3 | ||||
1133 | // Python 2.x: | ||||
1134 | static PyObject* dtype_attr = PyString_InternFromString("dtype"); | ||||
1135 | static PyObject* type_enum_attr = PyString_InternFromString("_type_enum"); | ||||
1136 | #else | ||||
1137 | // Python 3.x: | ||||
1138 | static PyObject* dtype_attr = PyUnicode_InternFromString("dtype"); | ||||
1139 | static PyObject* type_enum_attr = PyUnicode_InternFromString("_type_enum"); | ||||
1140 | #endif | ||||
1141 | Safe_PyObjectPtr dtype_field(PyObject_GetAttr(tensor, dtype_attr)); | ||||
1142 | if (!dtype_field) { | ||||
1143 | return DT_INVALID; | ||||
1144 | } | ||||
1145 | |||||
1146 | Safe_PyObjectPtr enum_field( | ||||
1147 | PyObject_GetAttr(dtype_field.get(), type_enum_attr)); | ||||
1148 | if (!enum_field) { | ||||
1149 | return DT_INVALID; | ||||
1150 | } | ||||
1151 | |||||
1152 | return static_cast<DataType>(MakeInt(enum_field.get())); | ||||
1153 | } | ||||
1154 | } | ||||
1155 | } // namespace tensorflow | ||||
1156 | |||||
1157 | class PyTapeTensor { | ||||
1158 | public: | ||||
1159 | PyTapeTensor(int64_t id, tensorflow::DataType dtype, | ||||
1160 | const tensorflow::TensorShape& shape) | ||||
1161 | : id_(id), dtype_(dtype), shape_(shape) {} | ||||
1162 | PyTapeTensor(int64_t id, tensorflow::DataType dtype, PyObject* shape) | ||||
1163 | : id_(id), dtype_(dtype), shape_(shape) { | ||||
1164 | Py_INCREF(absl::get<1>(shape_))_Py_INCREF(((PyObject*)(absl::get<1>(shape_)))); | ||||
1165 | } | ||||
1166 | PyTapeTensor(const PyTapeTensor& other) { | ||||
1167 | id_ = other.id_; | ||||
1168 | dtype_ = other.dtype_; | ||||
1169 | shape_ = other.shape_; | ||||
1170 | if (shape_.index() == 1) { | ||||
1171 | Py_INCREF(absl::get<1>(shape_))_Py_INCREF(((PyObject*)(absl::get<1>(shape_)))); | ||||
1172 | } | ||||
1173 | } | ||||
1174 | |||||
1175 | ~PyTapeTensor() { | ||||
1176 | if (shape_.index() == 1) { | ||||
1177 | Py_DECREF(absl::get<1>(shape_))_Py_DECREF(((PyObject*)(absl::get<1>(shape_)))); | ||||
1178 | } | ||||
1179 | } | ||||
1180 | PyObject* GetShape() const; | ||||
1181 | PyObject* GetPyDType() const { return PyLong_FromLong(dtype_); } | ||||
1182 | int64_t GetID() const { return id_; } | ||||
1183 | tensorflow::DataType GetDType() const { return dtype_; } | ||||
1184 | |||||
1185 | PyObject* OnesLike() const; | ||||
1186 | PyObject* ZerosLike() const; | ||||
1187 | |||||
1188 | private: | ||||
1189 | int64_t id_; | ||||
1190 | tensorflow::DataType dtype_; | ||||
1191 | |||||
1192 | // Note that if shape_.index() == 1, meaning shape_ contains a PyObject, that | ||||
1193 | // PyObject is the tensor itself. This is used to support tf.shape(tensor) for | ||||
1194 | // partially-defined shapes and tf.zeros_like(tensor) for variant-dtype | ||||
1195 | // tensors. | ||||
1196 | absl::variant<tensorflow::TensorShape, PyObject*> shape_; | ||||
1197 | }; | ||||
1198 | |||||
1199 | static PyTapeTensor TapeTensorFromTensor(PyObject* tensor); | ||||
1200 | |||||
1201 | class PyVSpace : public tensorflow::eager::VSpace<PyObject, PyBackwardFunction, | ||||
1202 | PyTapeTensor> { | ||||
1203 | public: | ||||
1204 | explicit PyVSpace(PyObject* py_vspace) : py_vspace_(py_vspace) { | ||||
1205 | Py_INCREF(py_vspace_)_Py_INCREF(((PyObject*)(py_vspace_))); | ||||
1206 | } | ||||
1207 | |||||
1208 | tensorflow::Status Initialize() { | ||||
1209 | num_elements_ = PyObject_GetAttrString(py_vspace_, "num_elements_fn"); | ||||
1210 | if (num_elements_ == nullptr) { | ||||
1211 | return tensorflow::errors::InvalidArgument("invalid vspace"); | ||||
1212 | } | ||||
1213 | aggregate_fn_ = PyObject_GetAttrString(py_vspace_, "aggregate_fn"); | ||||
1214 | if (aggregate_fn_ == nullptr) { | ||||
1215 | return tensorflow::errors::InvalidArgument("invalid vspace"); | ||||
1216 | } | ||||
1217 | zeros_fn_ = PyObject_GetAttrString(py_vspace_, "zeros_fn"); | ||||
1218 | if (zeros_fn_ == nullptr) { | ||||
1219 | return tensorflow::errors::InvalidArgument("invalid vspace"); | ||||
1220 | } | ||||
1221 | zeros_like_fn_ = PyObject_GetAttrString(py_vspace_, "zeros_like_fn"); | ||||
1222 | if (zeros_like_fn_ == nullptr) { | ||||
1223 | return tensorflow::errors::InvalidArgument("invalid vspace"); | ||||
1224 | } | ||||
1225 | ones_fn_ = PyObject_GetAttrString(py_vspace_, "ones_fn"); | ||||
1226 | if (ones_fn_ == nullptr) { | ||||
1227 | return tensorflow::errors::InvalidArgument("invalid vspace"); | ||||
1228 | } | ||||
1229 | ones_like_fn_ = PyObject_GetAttrString(py_vspace_, "ones_like_fn"); | ||||
1230 | if (ones_like_fn_ == nullptr) { | ||||
1231 | return tensorflow::errors::InvalidArgument("invalid vspace"); | ||||
1232 | } | ||||
1233 | graph_shape_fn_ = PyObject_GetAttrString(py_vspace_, "graph_shape_fn"); | ||||
1234 | if (graph_shape_fn_ == nullptr) { | ||||
1235 | return tensorflow::errors::InvalidArgument("invalid vspace"); | ||||
1236 | } | ||||
1237 | return tensorflow::Status::OK(); | ||||
1238 | } | ||||
1239 | |||||
1240 | ~PyVSpace() override { | ||||
1241 | Py_XDECREF(num_elements_)_Py_XDECREF(((PyObject*)(num_elements_))); | ||||
1242 | Py_XDECREF(aggregate_fn_)_Py_XDECREF(((PyObject*)(aggregate_fn_))); | ||||
1243 | Py_XDECREF(zeros_fn_)_Py_XDECREF(((PyObject*)(zeros_fn_))); | ||||
1244 | Py_XDECREF(zeros_like_fn_)_Py_XDECREF(((PyObject*)(zeros_like_fn_))); | ||||
1245 | Py_XDECREF(ones_fn_)_Py_XDECREF(((PyObject*)(ones_fn_))); | ||||
1246 | Py_XDECREF(ones_like_fn_)_Py_XDECREF(((PyObject*)(ones_like_fn_))); | ||||
1247 | Py_XDECREF(graph_shape_fn_)_Py_XDECREF(((PyObject*)(graph_shape_fn_))); | ||||
1248 | |||||
1249 | Py_DECREF(py_vspace_)_Py_DECREF(((PyObject*)(py_vspace_))); | ||||
1250 | } | ||||
1251 | |||||
1252 | int64_t NumElements(PyObject* tensor) const final { | ||||
1253 | if (EagerTensor_CheckExact(tensor)) { | ||||
1254 | return PyEagerTensor_NumElements(tensor); | ||||
1255 | } | ||||
1256 | PyObject* arglist = | ||||
1257 | Py_BuildValue("(O)", reinterpret_cast<PyObject*>(tensor)); | ||||
1258 | PyObject* result = PyEval_CallObject(num_elements_, arglist)PyEval_CallObjectWithKeywords(num_elements_, arglist, (PyObject *)__null); | ||||
1259 | Py_DECREF(arglist)_Py_DECREF(((PyObject*)(arglist))); | ||||
1260 | if (result == nullptr) { | ||||
1261 | // The caller detects whether a python exception has been raised. | ||||
1262 | return -1; | ||||
1263 | } | ||||
1264 | int64_t r = MakeInt(result); | ||||
1265 | Py_DECREF(result)_Py_DECREF(((PyObject*)(result))); | ||||
1266 | return r; | ||||
1267 | } | ||||
1268 | |||||
1269 | PyObject* AggregateGradients( | ||||
1270 | tensorflow::gtl::ArraySlice<PyObject*> gradient_tensors) const final { | ||||
1271 | PyObject* list = PyList_New(gradient_tensors.size()); | ||||
1272 | for (int i = 0; i < gradient_tensors.size(); ++i) { | ||||
1273 | // Note: stealing a reference to the gradient tensors. | ||||
1274 | CHECK(gradient_tensors[i] != nullptr)if ((__builtin_expect(!(gradient_tensors[i] != nullptr), 0))) ::tensorflow::internal::LogMessageFatal("tensorflow/python/eager/pywrap_tfe_src.cc" , 1274) << "Check failed: " "gradient_tensors[i] != nullptr" " "; | ||||
1275 | CHECK(gradient_tensors[i] != Py_None)if ((__builtin_expect(!(gradient_tensors[i] != (&_Py_NoneStruct )), 0))) ::tensorflow::internal::LogMessageFatal("tensorflow/python/eager/pywrap_tfe_src.cc" , 1275) << "Check failed: " "gradient_tensors[i] != Py_None" " "; | ||||
1276 | PyList_SET_ITEM(list, i,PyList_SetItem(list, i, reinterpret_cast<PyObject*>(gradient_tensors [i])) | ||||
1277 | reinterpret_cast<PyObject*>(gradient_tensors[i]))PyList_SetItem(list, i, reinterpret_cast<PyObject*>(gradient_tensors [i])); | ||||
1278 | } | ||||
1279 | PyObject* arglist = Py_BuildValue("(O)", list); | ||||
1280 | CHECK(arglist != nullptr)if ((__builtin_expect(!(arglist != nullptr), 0))) ::tensorflow ::internal::LogMessageFatal("tensorflow/python/eager/pywrap_tfe_src.cc" , 1280) << "Check failed: " "arglist != nullptr" " "; | ||||
1281 | PyObject* result = PyEval_CallObject(aggregate_fn_, arglist)PyEval_CallObjectWithKeywords(aggregate_fn_, arglist, (PyObject *)__null); | ||||
1282 | Py_DECREF(arglist)_Py_DECREF(((PyObject*)(arglist))); | ||||
1283 | Py_DECREF(list)_Py_DECREF(((PyObject*)(list))); | ||||
1284 | return result; | ||||
1285 | } | ||||
1286 | |||||
1287 | int64_t TensorId(PyObject* tensor) const final { | ||||
1288 | return FastTensorId(tensor); | ||||
1289 | } | ||||
1290 | |||||
1291 | void MarkAsResult(PyObject* gradient) const final { Py_INCREF(gradient)_Py_INCREF(((PyObject*)(gradient))); } | ||||
1292 | |||||
1293 | PyObject* Ones(PyObject* shape, PyObject* dtype) const { | ||||
1294 | if (PyErr_Occurred()) { | ||||
1295 | return nullptr; | ||||
1296 | } | ||||
1297 | PyObject* arg_list = Py_BuildValue("OO", shape, dtype); | ||||
1298 | PyObject* result = PyEval_CallObject(ones_fn_, arg_list)PyEval_CallObjectWithKeywords(ones_fn_, arg_list, (PyObject * )__null); | ||||
1299 | Py_DECREF(arg_list)_Py_DECREF(((PyObject*)(arg_list))); | ||||
1300 | return result; | ||||
1301 | } | ||||
1302 | |||||
1303 | PyObject* OnesLike(PyObject* tensor) const { | ||||
1304 | if (PyErr_Occurred()) { | ||||
1305 | return nullptr; | ||||
1306 | } | ||||
1307 | return PyObject_CallFunctionObjArgs(ones_like_fn_, tensor, NULL__null); | ||||
1308 | } | ||||
1309 | |||||
1310 | // Builds a tensor filled with ones with the same shape and dtype as `t`. | ||||
1311 | Status BuildOnesLike(const PyTapeTensor& t, | ||||
1312 | PyObject** result) const override { | ||||
1313 | *result = t.OnesLike(); | ||||
1314 | return Status::OK(); | ||||
1315 | } | ||||
1316 | |||||
1317 | PyObject* Zeros(PyObject* shape, PyObject* dtype) const { | ||||
1318 | if (PyErr_Occurred()) { | ||||
1319 | return nullptr; | ||||
1320 | } | ||||
1321 | PyObject* arg_list = Py_BuildValue("OO", shape, dtype); | ||||
1322 | PyObject* result = PyEval_CallObject(zeros_fn_, arg_list)PyEval_CallObjectWithKeywords(zeros_fn_, arg_list, (PyObject * )__null); | ||||
1323 | Py_DECREF(arg_list)_Py_DECREF(((PyObject*)(arg_list))); | ||||
1324 | return result; | ||||
1325 | } | ||||
1326 | |||||
1327 | PyObject* ZerosLike(PyObject* tensor) const { | ||||
1328 | if (PyErr_Occurred()) { | ||||
1329 | return nullptr; | ||||
1330 | } | ||||
1331 | return PyObject_CallFunctionObjArgs(zeros_like_fn_, tensor, NULL__null); | ||||
1332 | } | ||||
1333 | |||||
1334 | PyObject* GraphShape(PyObject* tensor) const { | ||||
1335 | PyObject* arg_list = Py_BuildValue("(O)", tensor); | ||||
1336 | PyObject* result = PyEval_CallObject(graph_shape_fn_, arg_list)PyEval_CallObjectWithKeywords(graph_shape_fn_, arg_list, (PyObject *)__null); | ||||
1337 | Py_DECREF(arg_list)_Py_DECREF(((PyObject*)(arg_list))); | ||||
1338 | return result; | ||||
1339 | } | ||||
1340 | |||||
1341 | tensorflow::Status CallBackwardFunction( | ||||
1342 | const string& op_type, PyBackwardFunction* backward_function, | ||||
1343 | const std::vector<int64_t>& unneeded_gradients, | ||||
1344 | tensorflow::gtl::ArraySlice<PyObject*> output_gradients, | ||||
1345 | absl::Span<PyObject*> result) const final { | ||||
1346 | PyObject* grads = PyTuple_New(output_gradients.size()); | ||||
1347 | for (int i = 0; i < output_gradients.size(); ++i) { | ||||
1348 | if (output_gradients[i] == nullptr) { | ||||
1349 | Py_INCREF(Py_None)_Py_INCREF(((PyObject*)((&_Py_NoneStruct)))); | ||||
1350 | PyTuple_SET_ITEM(grads, i, Py_None)PyTuple_SetItem(grads, i, (&_Py_NoneStruct)); | ||||
1351 | } else { | ||||
1352 | PyTuple_SET_ITEM(grads, i,PyTuple_SetItem(grads, i, reinterpret_cast<PyObject*>(output_gradients [i])) | ||||
1353 | reinterpret_cast<PyObject*>(output_gradients[i]))PyTuple_SetItem(grads, i, reinterpret_cast<PyObject*>(output_gradients [i])); | ||||
1354 | } | ||||
1355 | } | ||||
1356 | PyObject* py_result = (*backward_function)(grads, unneeded_gradients); | ||||
1357 | Py_DECREF(grads)_Py_DECREF(((PyObject*)(grads))); | ||||
1358 | if (py_result == nullptr) { | ||||
1359 | return tensorflow::errors::Internal("gradient function threw exceptions"); | ||||
1360 | } | ||||
1361 | PyObject* seq = | ||||
1362 | PySequence_Fast(py_result, "expected a sequence of gradients"); | ||||
1363 | if (seq == nullptr) { | ||||
1364 | return tensorflow::errors::InvalidArgument( | ||||
1365 | "gradient function did not return a list"); | ||||
1366 | } | ||||
1367 | int len = PySequence_Fast_GET_SIZE(seq)(((((((PyObject*)(seq))->ob_type))->tp_flags & ((1UL << 25))) != 0) ? ((static_cast<void> (0)), (((PyVarObject *)(seq))->ob_size)) : (((PyVarObject*)(((static_cast<void > (0)), (PyTupleObject *)(seq))))->ob_size)); | ||||
1368 | if (len != result.size()) { | ||||
1369 | return tensorflow::errors::Internal( | ||||
1370 | "Recorded operation '", op_type, | ||||
1371 | "' returned too few gradients. Expected ", result.size(), | ||||
1372 | " but received ", len); | ||||
1373 | } | ||||
1374 | PyObject** seq_array = PySequence_Fast_ITEMS(seq)(((((((PyObject*)(seq))->ob_type))->tp_flags & ((1UL << 25))) != 0) ? ((PyListObject *)(seq))->ob_item : ((PyTupleObject *)(seq))->ob_item); | ||||
1375 | VLOG(1)(__builtin_expect(!!(!(([](int level, const char* fname) { static const bool vmodule_activated = ::tensorflow::internal::LogMessage ::VmoduleActivated(fname, level); return vmodule_activated; } )(1, "tensorflow/python/eager/pywrap_tfe_src.cc"))), 1)) ? (void )0 : ::tensorflow::internal::Voidifier() & ::tensorflow:: internal::LogMessage("tensorflow/python/eager/pywrap_tfe_src.cc" , 1375, tensorflow::INFO) << "Gradient length is " << len; | ||||
1376 | for (int i = 0; i < len; ++i) { | ||||
1377 | PyObject* item = seq_array[i]; | ||||
1378 | if (item == Py_None(&_Py_NoneStruct)) { | ||||
1379 | result[i] = nullptr; | ||||
1380 | } else { | ||||
1381 | Py_INCREF(item)_Py_INCREF(((PyObject*)(item))); | ||||
1382 | result[i] = item; | ||||
1383 | } | ||||
1384 | } | ||||
1385 | Py_DECREF(seq)_Py_DECREF(((PyObject*)(seq))); | ||||
1386 | Py_DECREF(py_result)_Py_DECREF(((PyObject*)(py_result))); | ||||
1387 | return tensorflow::Status::OK(); | ||||
1388 | } | ||||
1389 | |||||
1390 | void DeleteGradient(PyObject* tensor) const final { Py_XDECREF(tensor)_Py_XDECREF(((PyObject*)(tensor))); } | ||||
1391 | |||||
1392 | PyTapeTensor TapeTensorFromGradient(PyObject* tensor) const final { | ||||
1393 | return TapeTensorFromTensor(tensor); | ||||
1394 | } | ||||
1395 | |||||
1396 | private: | ||||
1397 | PyObject* py_vspace_; | ||||
1398 | |||||
1399 | PyObject* num_elements_; | ||||
1400 | PyObject* aggregate_fn_; | ||||
1401 | PyObject* zeros_fn_; | ||||
1402 | PyObject* zeros_like_fn_; | ||||
1403 | PyObject* ones_fn_; | ||||
1404 | PyObject* ones_like_fn_; | ||||
1405 | PyObject* graph_shape_fn_; | ||||
1406 | }; | ||||
1407 | PyVSpace* py_vspace = nullptr; | ||||
1408 | |||||
1409 | bool HasAccumulator(); | ||||
1410 | |||||
1411 | PyObject* TFE_Py_RegisterVSpace(PyObject* e) { | ||||
1412 | if (py_vspace != nullptr) { | ||||
1413 | if (HasAccumulator()) { | ||||
1414 | // Accumulators reference py_vspace, so we can't swap it out while one is | ||||
1415 | // active. This is unlikely to ever happen. | ||||
1416 | MaybeRaiseExceptionFromStatus( | ||||
1417 | tensorflow::errors::Internal( | ||||
1418 | "Can't change the vspace implementation while a " | ||||
1419 | "forward accumulator is active."), | ||||
1420 | nullptr); | ||||
1421 | } | ||||
1422 | delete py_vspace; | ||||
1423 | } | ||||
1424 | |||||
1425 | py_vspace = new PyVSpace(e); | ||||
1426 | auto status = py_vspace->Initialize(); | ||||
1427 | if (MaybeRaiseExceptionFromStatus(status, nullptr)) { | ||||
1428 | delete py_vspace; | ||||
1429 | return nullptr; | ||||
1430 | } | ||||
1431 | |||||
1432 | Py_RETURN_NONEreturn _Py_INCREF(((PyObject*)((&_Py_NoneStruct)))), (& _Py_NoneStruct); | ||||
1433 | } | ||||
1434 | |||||
1435 | PyObject* PyTapeTensor::GetShape() const { | ||||
1436 | if (shape_.index() == 0) { | ||||
1437 | auto& shape = absl::get<0>(shape_); | ||||
1438 | PyObject* py_shape = PyTuple_New(shape.dims()); | ||||
1439 | for (int i = 0; i < shape.dims(); ++i) { | ||||
1440 | PyTuple_SET_ITEM(py_shape, i, PyLong_FromLong(shape.dim_size(i)))PyTuple_SetItem(py_shape, i, PyLong_FromLong(shape.dim_size(i ))); | ||||
1441 | } | ||||
1442 | |||||
1443 | return py_shape; | ||||
1444 | } | ||||
1445 | |||||
1446 | return py_vspace->GraphShape(absl::get<1>(shape_)); | ||||
1447 | } | ||||
1448 | |||||
1449 | PyObject* PyTapeTensor::OnesLike() const { | ||||
1450 | if (shape_.index() == 1) { | ||||
1451 | PyObject* tensor = absl::get<1>(shape_); | ||||
1452 | return py_vspace->OnesLike(tensor); | ||||
1453 | } | ||||
1454 | PyObject* py_shape = GetShape(); | ||||
1455 | PyObject* dtype_field = GetPyDType(); | ||||
1456 | PyObject* result = py_vspace->Ones(py_shape, dtype_field); | ||||
1457 | Py_DECREF(dtype_field)_Py_DECREF(((PyObject*)(dtype_field))); | ||||
1458 | Py_DECREF(py_shape)_Py_DECREF(((PyObject*)(py_shape))); | ||||
1459 | return result; | ||||
1460 | } | ||||
1461 | |||||
1462 | PyObject* PyTapeTensor::ZerosLike() const { | ||||
1463 | if (GetDType() == tensorflow::DT_RESOURCE) { | ||||
1464 | // Gradient functions for ops which return resource tensors accept | ||||
1465 | // None. This is the behavior of py_vspace->Zeros, but checking here avoids | ||||
1466 | // issues with ZerosLike. | ||||
1467 | Py_RETURN_NONEreturn _Py_INCREF(((PyObject*)((&_Py_NoneStruct)))), (& _Py_NoneStruct); | ||||
1468 | } | ||||
1469 | if (shape_.index() == 1) { | ||||
1470 | PyObject* tensor = absl::get<1>(shape_); | ||||
1471 | return py_vspace->ZerosLike(tensor); | ||||
1472 | } | ||||
1473 | PyObject* py_shape = GetShape(); | ||||
1474 | PyObject* dtype_field = GetPyDType(); | ||||
1475 | PyObject* result = py_vspace->Zeros(py_shape, dtype_field); | ||||
1476 | Py_DECREF(dtype_field)_Py_DECREF(((PyObject*)(dtype_field))); | ||||
1477 | Py_DECREF(py_shape)_Py_DECREF(((PyObject*)(py_shape))); | ||||
1478 | return result; | ||||
1479 | } | ||||
1480 | |||||
1481 | // Keeps track of all variables that have been accessed during execution. | ||||
1482 | class VariableWatcher { | ||||
1483 | public: | ||||
1484 | VariableWatcher() {} | ||||
1485 | |||||
1486 | ~VariableWatcher() { | ||||
1487 | for (const IdAndVariable& v : watched_variables_) { | ||||
1488 | Py_DECREF(v.variable)_Py_DECREF(((PyObject*)(v.variable))); | ||||
1489 | } | ||||
1490 | } | ||||
1491 | |||||
1492 | int64_t WatchVariable(PyObject* v) { | ||||
1493 | tensorflow::Safe_PyObjectPtr handle(PyObject_GetAttrString(v, "handle")); | ||||
1494 | if (handle == nullptr) { | ||||
1495 | return -1; | ||||
1496 | } | ||||
1497 | int64_t id = FastTensorId(handle.get()); | ||||
1498 | |||||
1499 | tensorflow::mutex_lock l(watched_variables_mu_); | ||||
1500 | auto insert_result = watched_variables_.emplace(id, v); | ||||
1501 | |||||
1502 | if (insert_result.second) { | ||||
1503 | // Only increment the reference count if we aren't already watching this | ||||
1504 | // variable. | ||||
1505 | Py_INCREF(v)_Py_INCREF(((PyObject*)(v))); | ||||
1506 | } | ||||
1507 | |||||
1508 | return id; | ||||
1509 | } | ||||
1510 | |||||
1511 | PyObject* GetVariablesAsPyTuple() { | ||||
1512 | tensorflow::mutex_lock l(watched_variables_mu_); | ||||
1513 | PyObject* result = PyTuple_New(watched_variables_.size()); | ||||
1514 | Py_ssize_t pos = 0; | ||||
1515 | for (const IdAndVariable& id_and_variable : watched_variables_) { | ||||
1516 | PyTuple_SET_ITEM(result, pos++, id_and_variable.variable)PyTuple_SetItem(result, pos++, id_and_variable.variable); | ||||
1517 | Py_INCREF(id_and_variable.variable)_Py_INCREF(((PyObject*)(id_and_variable.variable))); | ||||
1518 | } | ||||
1519 | return result; | ||||
1520 | } | ||||
1521 | |||||
1522 | private: | ||||
1523 | // We store an IdAndVariable in the map since the map needs to be locked | ||||
1524 | // during insert, but should not call back into python during insert to avoid | ||||
1525 | // deadlocking with the GIL. | ||||
1526 | struct IdAndVariable { | ||||
1527 | int64_t id; | ||||
1528 | PyObject* variable; | ||||
1529 | |||||
1530 | IdAndVariable(int64_t id, PyObject* variable) | ||||
1531 | : id(id), variable(variable) {} | ||||
1532 | }; | ||||
1533 | struct CompareById { | ||||
1534 | bool operator()(const IdAndVariable& lhs, const IdAndVariable& rhs) const { | ||||
1535 | return lhs.id < rhs.id; | ||||
1536 | } | ||||
1537 | }; | ||||
1538 | |||||
1539 | tensorflow::mutex watched_variables_mu_; | ||||
1540 | std::set<IdAndVariable, CompareById> watched_variables_ | ||||
1541 | TF_GUARDED_BY(watched_variables_mu_)__attribute__((guarded_by(watched_variables_mu_))); | ||||
1542 | }; | ||||
1543 | |||||
1544 | class GradientTape | ||||
1545 | : public tensorflow::eager::GradientTape<PyObject, PyBackwardFunction, | ||||
1546 | PyTapeTensor> { | ||||
1547 | public: | ||||
1548 | explicit GradientTape(bool persistent, bool watch_accessed_variables) | ||||
1549 | : tensorflow::eager::GradientTape<PyObject, PyBackwardFunction, | ||||
1550 | PyTapeTensor>(persistent), | ||||
1551 | watch_accessed_variables_(watch_accessed_variables) {} | ||||
1552 | |||||
1553 | virtual ~GradientTape() {} | ||||
1554 | |||||
1555 | void VariableAccessed(PyObject* v) { | ||||
1556 | if (watch_accessed_variables_) { | ||||
1557 | WatchVariable(v); | ||||
1558 | } | ||||
1559 | } | ||||
1560 | |||||
1561 | void WatchVariable(PyObject* v) { | ||||
1562 | int64_t id = variable_watcher_.WatchVariable(v); | ||||
1563 | |||||
1564 | if (!PyErr_Occurred()) { | ||||
1565 | this->Watch(id); | ||||
1566 | } | ||||
1567 | } | ||||
1568 | |||||
1569 | PyObject* GetVariablesAsPyTuple() { | ||||
1570 | return variable_watcher_.GetVariablesAsPyTuple(); | ||||
1571 | } | ||||
1572 | |||||
1573 | private: | ||||
1574 | bool watch_accessed_variables_; | ||||
1575 | VariableWatcher variable_watcher_; | ||||
1576 | }; | ||||
1577 | |||||
1578 | typedef tensorflow::eager::ForwardAccumulator<PyObject, PyBackwardFunction, | ||||
1579 | PyTapeTensor> | ||||
1580 | ForwardAccumulator; | ||||
1581 | |||||
1582 | // Incremented when a GradientTape or accumulator is newly added to a set, and | ||||
1583 | // used to enforce an ordering between them. | ||||
1584 | std::atomic_uint_fast64_t tape_nesting_id_counter(0); | ||||
1585 | |||||
// Python object wrapping a C++ GradientTape. PyObject_HEAD must stay the first
// member: the struct is reinterpret_cast to and from PyObject*.
typedef struct {
  PyObject_HEAD
  /* Type-specific fields go here. */
  // Owned; deleted by TFE_Py_Tape_Delete.
  GradientTape* tape;
  // A nesting order between GradientTapes and ForwardAccumulators, used to
  // ensure that GradientTapes do not watch the products of outer
  // ForwardAccumulators.
  int64_t nesting_id;
} TFE_Py_Tape;
1595 | |||||
1596 | static void TFE_Py_Tape_Delete(PyObject* tape) { | ||||
1597 | delete reinterpret_cast<TFE_Py_Tape*>(tape)->tape; | ||||
1598 | Py_TYPE(tape)(((PyObject*)(tape))->ob_type)->tp_free(tape); | ||||
1599 | } | ||||
1600 | |||||
// Python type object for TFE_Py_Tape. The slots are initialized positionally,
// so every initializer must stay in PyTypeObject field order; slots after
// tp_doc are zero-initialized. tp_new is not set here: TFE_Py_TapeSetNew
// assigns PyType_GenericNew before calling PyType_Ready.
static PyTypeObject TFE_Py_Tape_Type = {
    PyVarObject_HEAD_INIT(nullptr, 0) "tfe.Tape", /* tp_name */
    sizeof(TFE_Py_Tape),                          /* tp_basicsize */
    0,                                            /* tp_itemsize */
    &TFE_Py_Tape_Delete,                          /* tp_dealloc */
// Python 3.8 replaced the tp_print slot with tp_vectorcall_offset at the same
// position.
#if PY_VERSION_HEX < 0x03080000
    nullptr, /* tp_print */
#else
    0, /* tp_vectorcall_offset */
#endif
    nullptr,               /* tp_getattr */
    nullptr,               /* tp_setattr */
    nullptr,               /* tp_reserved */
    nullptr,               /* tp_repr */
    nullptr,               /* tp_as_number */
    nullptr,               /* tp_as_sequence */
    nullptr,               /* tp_as_mapping */
    nullptr,               /* tp_hash */
    nullptr,               /* tp_call */
    nullptr,               /* tp_str */
    nullptr,               /* tp_getattro */
    nullptr,               /* tp_setattro */
    nullptr,               /* tp_as_buffer */
    Py_TPFLAGS_DEFAULT,    /* tp_flags */
    "TFE_Py_Tape objects", /* tp_doc */
};
1627 | |||||
// Python object wrapping a C++ ForwardAccumulator. PyObject_HEAD must stay the
// first member: the struct is reinterpret_cast to and from PyObject*.
typedef struct {
  PyObject_HEAD
  /* Type-specific fields go here. */
  // Owned; deleted by TFE_Py_ForwardAccumulatorDelete.
  ForwardAccumulator* accumulator;
  // A nesting order between GradientTapes and ForwardAccumulators, used to
  // ensure that GradientTapes do not watch the products of outer
  // ForwardAccumulators.
  int64_t nesting_id;
} TFE_Py_ForwardAccumulator;
1637 | |||||
1638 | static void TFE_Py_ForwardAccumulatorDelete(PyObject* accumulator) { | ||||
1639 | delete reinterpret_cast<TFE_Py_ForwardAccumulator*>(accumulator)->accumulator; | ||||
1640 | Py_TYPE(accumulator)(((PyObject*)(accumulator))->ob_type)->tp_free(accumulator); | ||||
1641 | } | ||||
1642 | |||||
// Python type object for TFE_Py_ForwardAccumulator. The slots are initialized
// positionally, so every initializer must stay in PyTypeObject field order;
// slots after tp_doc are zero-initialized.
static PyTypeObject TFE_Py_ForwardAccumulator_Type = {
    PyVarObject_HEAD_INIT(nullptr, 0) "ForwardAccumulator", /* tp_name */
    sizeof(TFE_Py_ForwardAccumulator),                      /* tp_basicsize */
    0,                                                      /* tp_itemsize */
    &TFE_Py_ForwardAccumulatorDelete,                       /* tp_dealloc */
// Python 3.8 replaced the tp_print slot with tp_vectorcall_offset at the same
// position.
#if PY_VERSION_HEX < 0x03080000
    nullptr, /* tp_print */
#else
    0, /* tp_vectorcall_offset */
#endif
    nullptr,                             /* tp_getattr */
    nullptr,                             /* tp_setattr */
    nullptr,                             /* tp_reserved */
    nullptr,                             /* tp_repr */
    nullptr,                             /* tp_as_number */
    nullptr,                             /* tp_as_sequence */
    nullptr,                             /* tp_as_mapping */
    nullptr,                             /* tp_hash */
    nullptr,                             /* tp_call */
    nullptr,                             /* tp_str */
    nullptr,                             /* tp_getattro */
    nullptr,                             /* tp_setattro */
    nullptr,                             /* tp_as_buffer */
    Py_TPFLAGS_DEFAULT,                  /* tp_flags */
    "TFE_Py_ForwardAccumulator objects", /* tp_doc */
};
1669 | |||||
// Python object wrapping a C++ VariableWatcher. PyObject_HEAD must stay the
// first member: the struct is reinterpret_cast to and from PyObject*.
typedef struct {
  PyObject_HEAD
  /* Type-specific fields go here. */
  // Owned; deleted by TFE_Py_VariableWatcher_Delete.
  VariableWatcher* variable_watcher;
} TFE_Py_VariableWatcher;
1675 | |||||
1676 | static void TFE_Py_VariableWatcher_Delete(PyObject* variable_watcher) { | ||||
1677 | delete reinterpret_cast<TFE_Py_VariableWatcher*>(variable_watcher) | ||||
1678 | ->variable_watcher; | ||||
1679 | Py_TYPE(variable_watcher)(((PyObject*)(variable_watcher))->ob_type)->tp_free(variable_watcher); | ||||
1680 | } | ||||
1681 | |||||
// Python type object for TFE_Py_VariableWatcher. The slots are initialized
// positionally, so every initializer must stay in PyTypeObject field order;
// slots after tp_doc are zero-initialized.
static PyTypeObject TFE_Py_VariableWatcher_Type = {
    PyVarObject_HEAD_INIT(nullptr, 0) "tfe.VariableWatcher", /* tp_name */
    sizeof(TFE_Py_VariableWatcher),                          /* tp_basicsize */
    0,                                                       /* tp_itemsize */
    &TFE_Py_VariableWatcher_Delete,                          /* tp_dealloc */
// Python 3.8 replaced the tp_print slot with tp_vectorcall_offset at the same
// position.
#if PY_VERSION_HEX < 0x03080000
    nullptr, /* tp_print */
#else
    0, /* tp_vectorcall_offset */
#endif
    nullptr,                          /* tp_getattr */
    nullptr,                          /* tp_setattr */
    nullptr,                          /* tp_reserved */
    nullptr,                          /* tp_repr */
    nullptr,                          /* tp_as_number */
    nullptr,                          /* tp_as_sequence */
    nullptr,                          /* tp_as_mapping */
    nullptr,                          /* tp_hash */
    nullptr,                          /* tp_call */
    nullptr,                          /* tp_str */
    nullptr,                          /* tp_getattro */
    nullptr,                          /* tp_setattro */
    nullptr,                          /* tp_as_buffer */
    Py_TPFLAGS_DEFAULT,               /* tp_flags */
    "TFE_Py_VariableWatcher objects", /* tp_doc */
};
1708 | |||||
1709 | // Note: in the current design no mutex is needed here because of the python | ||||
1710 | // GIL, which is always held when any TFE_Py_* methods are called. We should | ||||
1711 | // revisit this if/when decide to not hold the GIL while manipulating the tape | ||||
1712 | // stack. | ||||
1713 | tensorflow::gtl::CompactPointerSet<TFE_Py_Tape*>* GetTapeSet() { | ||||
1714 | thread_local std::unique_ptr<tensorflow::gtl::CompactPointerSet<TFE_Py_Tape*>> | ||||
1715 | tape_set = nullptr; | ||||
1716 | if (tape_set == nullptr) { | ||||
1717 | tape_set.reset(new tensorflow::gtl::CompactPointerSet<TFE_Py_Tape*>); | ||||
1718 | } | ||||
1719 | return tape_set.get(); | ||||
1720 | } | ||||
1721 | |||||
1722 | tensorflow::gtl::CompactPointerSet<TFE_Py_VariableWatcher*>* | ||||
1723 | GetVariableWatcherSet() { | ||||
1724 | thread_local std::unique_ptr< | ||||
1725 | tensorflow::gtl::CompactPointerSet<TFE_Py_VariableWatcher*>> | ||||
1726 | variable_watcher_set = nullptr; | ||||
1727 | if (variable_watcher_set == nullptr) { | ||||
1728 | variable_watcher_set.reset( | ||||
1729 | new tensorflow::gtl::CompactPointerSet<TFE_Py_VariableWatcher*>); | ||||
1730 | } | ||||
1731 | return variable_watcher_set.get(); | ||||
1732 | } | ||||
1733 | |||||
1734 | // A linked hash set, where iteration is in insertion order. | ||||
1735 | // | ||||
1736 | // Nested accumulators rely on op recording happening in insertion order, so an | ||||
1737 | // unordered data structure like CompactPointerSet is not suitable. Outer | ||||
1738 | // accumulators need to observe operations first so they know to watch the inner | ||||
1739 | // accumulator's jvp computation. | ||||
1740 | // | ||||
1741 | // Not thread safe. | ||||
1742 | class AccumulatorSet { | ||||
1743 | public: | ||||
1744 | // Returns true if `element` was newly inserted, false if it already exists. | ||||
1745 | bool insert(TFE_Py_ForwardAccumulator* element) { | ||||
1746 | if (map_.find(element) != map_.end()) { | ||||
1747 | return false; | ||||
1748 | } | ||||
1749 | ListType::iterator it = ordered_.insert(ordered_.end(), element); | ||||
1750 | map_.insert(std::make_pair(element, it)); | ||||
1751 | return true; | ||||
1752 | } | ||||
1753 | |||||
1754 | void erase(TFE_Py_ForwardAccumulator* element) { | ||||
1755 | MapType::iterator existing = map_.find(element); | ||||
1756 | if (existing == map_.end()) { | ||||
1757 | return; | ||||
1758 | } | ||||
1759 | ListType::iterator list_position = existing->second; | ||||
1760 | map_.erase(existing); | ||||
1761 | ordered_.erase(list_position); | ||||
1762 | } | ||||
1763 | |||||
1764 | bool empty() const { return ordered_.empty(); } | ||||
1765 | |||||
1766 | size_t size() const { return ordered_.size(); } | ||||
1767 | |||||
1768 | private: | ||||
1769 | typedef std::list<TFE_Py_ForwardAccumulator*> ListType; | ||||
1770 | typedef tensorflow::gtl::FlatMap<TFE_Py_ForwardAccumulator*, | ||||
1771 | ListType::iterator> | ||||
1772 | MapType; | ||||
1773 | |||||
1774 | public: | ||||
1775 | typedef ListType::const_iterator const_iterator; | ||||
1776 | typedef ListType::const_reverse_iterator const_reverse_iterator; | ||||
1777 | |||||
1778 | const_iterator begin() const { return ordered_.begin(); } | ||||
1779 | const_iterator end() const { return ordered_.end(); } | ||||
1780 | |||||
1781 | const_reverse_iterator rbegin() const { return ordered_.rbegin(); } | ||||
1782 | const_reverse_iterator rend() const { return ordered_.rend(); } | ||||
1783 | |||||
1784 | private: | ||||
1785 | MapType map_; | ||||
1786 | ListType ordered_; | ||||
1787 | }; | ||||
1788 | |||||
1789 | AccumulatorSet* GetAccumulatorSet() { | ||||
1790 | thread_local std::unique_ptr<AccumulatorSet> accumulator_set{nullptr}; | ||||
1791 | if (accumulator_set == nullptr) { | ||||
1792 | accumulator_set.reset(new AccumulatorSet); | ||||
1793 | } | ||||
1794 | return accumulator_set.get(); | ||||
1795 | } | ||||
1796 | |||||
1797 | inline bool HasAccumulator() { return !GetAccumulatorSet()->empty(); } | ||||
1798 | |||||
1799 | inline bool HasGradientTape() { return !GetTapeSet()->empty(); } | ||||
1800 | |||||
1801 | inline bool HasAccumulatorOrTape() { | ||||
1802 | return HasGradientTape() || HasAccumulator(); | ||||
1803 | } | ||||
1804 | |||||
// A safe copy of a set, used for tapes and accumulators. The copy is not
// affected by other python threads changing the set of active tapes, and each
// member is kept alive by a strong reference for the lifetime of the copy.
//
// The implicit copy operations would share members without taking references
// and then release them twice in the destructors. The copy operations below
// take their own references instead.
template <typename ContainerType>
class SafeSetCopy {
 public:
  explicit SafeSetCopy(const ContainerType& to_copy) : set_copy_(to_copy) {
    IncRefAll(set_copy_);
  }

  SafeSetCopy(const SafeSetCopy& other) : set_copy_(other.set_copy_) {
    IncRefAll(set_copy_);
  }

  SafeSetCopy& operator=(const SafeSetCopy& other) {
    // Reference the incoming members before releasing the current ones, so a
    // member present in both (or self-assignment) never drops to zero.
    IncRefAll(other.set_copy_);
    DecRefAll(set_copy_);
    set_copy_ = other.set_copy_;
    return *this;
  }

  ~SafeSetCopy() { DecRefAll(set_copy_); }

  typename ContainerType::const_iterator begin() const {
    return set_copy_.begin();
  }

  typename ContainerType::const_iterator end() const { return set_copy_.end(); }

  bool empty() const { return set_copy_.empty(); }
  size_t size() const { return set_copy_.size(); }

 protected:
  ContainerType set_copy_;

 private:
  // Takes one strong reference to every member of `members`.
  static void IncRefAll(const ContainerType& members) {
    for (auto* member : members) {
      Py_INCREF(member);
    }
  }

  // Releases one strong reference from every member of `members`.
  static void DecRefAll(const ContainerType& members) {
    for (auto* member : members) {
      Py_DECREF(member);
    }
  }
};
1834 | |||||
1835 | class SafeTapeSet | ||||
1836 | : public SafeSetCopy<tensorflow::gtl::CompactPointerSet<TFE_Py_Tape*>> { | ||||
1837 | public: | ||||
1838 | SafeTapeSet() | ||||
1839 | : SafeSetCopy<tensorflow::gtl::CompactPointerSet<TFE_Py_Tape*>>( | ||||
1840 | *GetTapeSet()) {} | ||||
1841 | }; | ||||
1842 | |||||
1843 | class SafeAccumulatorSet : public SafeSetCopy<AccumulatorSet> { | ||||
1844 | public: | ||||
1845 | SafeAccumulatorSet() : SafeSetCopy<AccumulatorSet>(*GetAccumulatorSet()) {} | ||||
1846 | |||||
1847 | typename AccumulatorSet::const_reverse_iterator rbegin() const { | ||||
1848 | return set_copy_.rbegin(); | ||||
1849 | } | ||||
1850 | |||||
1851 | typename AccumulatorSet::const_reverse_iterator rend() const { | ||||
1852 | return set_copy_.rend(); | ||||
1853 | } | ||||
1854 | }; | ||||
1855 | |||||
1856 | class SafeVariableWatcherSet | ||||
1857 | : public SafeSetCopy< | ||||
1858 | tensorflow::gtl::CompactPointerSet<TFE_Py_VariableWatcher*>> { | ||||
1859 | public: | ||||
1860 | SafeVariableWatcherSet() | ||||
1861 | : SafeSetCopy< | ||||
1862 | tensorflow::gtl::CompactPointerSet<TFE_Py_VariableWatcher*>>( | ||||
1863 | *GetVariableWatcherSet()) {} | ||||
1864 | }; | ||||
1865 | |||||
// Returns the address of the calling thread's "recording stopped" flag.
// Each thread has its own flag, initially false.
bool* ThreadTapeIsStopped() {
  static thread_local bool stopped = false;
  return &stopped;
}
1870 | |||||
1871 | void TFE_Py_TapeSetStopOnThread() { *ThreadTapeIsStopped() = true; } | ||||
1872 | |||||
1873 | void TFE_Py_TapeSetRestartOnThread() { *ThreadTapeIsStopped() = false; } | ||||
1874 | |||||
1875 | PyObject* TFE_Py_TapeSetIsStopped() { | ||||
1876 | if (*ThreadTapeIsStopped()) { | ||||
1877 | Py_RETURN_TRUEreturn _Py_INCREF(((PyObject*)(((PyObject *) &_Py_TrueStruct )))), ((PyObject *) &_Py_TrueStruct); | ||||
1878 | } | ||||
1879 | Py_RETURN_FALSEreturn _Py_INCREF(((PyObject*)(((PyObject *) &_Py_FalseStruct )))), ((PyObject *) &_Py_FalseStruct); | ||||
1880 | } | ||||
1881 | |||||
1882 | PyObject* TFE_Py_TapeSetNew(PyObject* persistent, | ||||
1883 | PyObject* watch_accessed_variables) { | ||||
1884 | TFE_Py_Tape_Type.tp_new = PyType_GenericNew; | ||||
1885 | if (PyType_Ready(&TFE_Py_Tape_Type) < 0) return nullptr; | ||||
1886 | TFE_Py_Tape* tape = PyObject_NEW(TFE_Py_Tape, &TFE_Py_Tape_Type)( (TFE_Py_Tape *) PyObject_Init( (PyObject *) PyObject_Malloc ( ( (&TFE_Py_Tape_Type)->tp_basicsize ) ), (&TFE_Py_Tape_Type )) ); | ||||
1887 | tape->tape = new GradientTape(persistent == Py_True((PyObject *) &_Py_TrueStruct), | ||||
1888 | watch_accessed_variables == Py_True((PyObject *) &_Py_TrueStruct)); | ||||
1889 | Py_INCREF(tape)_Py_INCREF(((PyObject*)(tape))); | ||||
1890 | tape->nesting_id = tape_nesting_id_counter.fetch_add(1); | ||||
1891 | GetTapeSet()->insert(tape); | ||||
1892 | return reinterpret_cast<PyObject*>(tape); | ||||
1893 | } | ||||
1894 | |||||
1895 | void TFE_Py_TapeSetAdd(PyObject* tape) { | ||||
1896 | Py_INCREF(tape)_Py_INCREF(((PyObject*)(tape))); | ||||
1897 | TFE_Py_Tape* tfe_tape = reinterpret_cast<TFE_Py_Tape*>(tape); | ||||
1898 | if (!GetTapeSet()->insert(tfe_tape).second) { | ||||
1899 | // Already exists in the tape set. | ||||
1900 | Py_DECREF(tape)_Py_DECREF(((PyObject*)(tape))); | ||||
1901 | } else { | ||||
1902 | tfe_tape->nesting_id = tape_nesting_id_counter.fetch_add(1); | ||||
1903 | } | ||||
1904 | } | ||||
1905 | |||||
1906 | PyObject* TFE_Py_TapeSetIsEmpty() { | ||||
1907 | if (*ThreadTapeIsStopped() || !HasAccumulatorOrTape()) { | ||||
1908 | Py_RETURN_TRUEreturn _Py_INCREF(((PyObject*)(((PyObject *) &_Py_TrueStruct )))), ((PyObject *) &_Py_TrueStruct); | ||||
1909 | } | ||||
1910 | Py_RETURN_FALSEreturn _Py_INCREF(((PyObject*)(((PyObject *) &_Py_FalseStruct )))), ((PyObject *) &_Py_FalseStruct); | ||||
1911 | } | ||||
1912 | |||||
1913 | void TFE_Py_TapeSetRemove(PyObject* tape) { | ||||
1914 | auto* stack = GetTapeSet(); | ||||
1915 | stack->erase(reinterpret_cast<TFE_Py_Tape*>(tape)); | ||||
1916 | // We kept a reference to the tape in the set to ensure it wouldn't get | ||||
1917 | // deleted under us; cleaning it up here. | ||||
1918 | Py_DECREF(tape)_Py_DECREF(((PyObject*)(tape))); | ||||
1919 | } | ||||
1920 | |||||
1921 | static std::vector<int64_t> MakeIntList(PyObject* list) { | ||||
1922 | if (list == Py_None(&_Py_NoneStruct)) { | ||||
1923 | return {}; | ||||
1924 | } | ||||
1925 | PyObject* seq = PySequence_Fast(list, "expected a sequence"); | ||||
1926 | if (seq == nullptr) { | ||||
1927 | return {}; | ||||
1928 | } | ||||
1929 | int len = PySequence_Size(list); | ||||
1930 | PyObject** seq_array = PySequence_Fast_ITEMS(seq)(((((((PyObject*)(seq))->ob_type))->tp_flags & ((1UL << 25))) != 0) ? ((PyListObject *)(seq))->ob_item : ((PyTupleObject *)(seq))->ob_item); | ||||
1931 | std::vector<int64_t> tensor_ids; | ||||
1932 | tensor_ids.reserve(len); | ||||
1933 | for (int i = 0; i < len; ++i) { | ||||
1934 | PyObject* item = seq_array[i]; | ||||
1935 | #if PY_MAJOR_VERSION3 >= 3 | ||||
1936 | if (PyLong_Check(item)((((((PyObject*)(item))->ob_type))->tp_flags & ((1UL << 24))) != 0)) { | ||||
1937 | #else | ||||
1938 | if (PyLong_Check(item)((((((PyObject*)(item))->ob_type))->tp_flags & ((1UL << 24))) != 0) || PyInt_Check(item)) { | ||||
1939 | #endif | ||||
1940 | int64_t id = MakeInt(item); | ||||
1941 | tensor_ids.push_back(id); | ||||
1942 | } else { | ||||
1943 | tensor_ids.push_back(-1); | ||||
1944 | } | ||||
1945 | } | ||||
1946 | Py_DECREF(seq)_Py_DECREF(((PyObject*)(seq))); | ||||
1947 | return tensor_ids; | ||||
1948 | } | ||||
1949 | |||||
1950 | // Fill `tensor_ids` and `dtypes` from `tensors`, none of which may be | ||||
1951 | // null. Returns true on success and false on a Python exception. | ||||
1952 | bool TensorShapesAndDtypes(PyObject* tensors, std::vector<int64_t>* tensor_ids, | ||||
1953 | std::vector<tensorflow::DataType>* dtypes) { | ||||
1954 | tensorflow::Safe_PyObjectPtr seq( | ||||
1955 | PySequence_Fast(tensors, "expected a sequence")); | ||||
1956 | if (seq == nullptr) { | ||||
1957 | return false; | ||||
1958 | } | ||||
1959 | int len = PySequence_Fast_GET_SIZE(seq.get())(((((((PyObject*)(seq.get()))->ob_type))->tp_flags & ((1UL << 25))) != 0) ? ((static_cast<void> (0)), (((PyVarObject*)(seq.get()))->ob_size)) : (((PyVarObject* )(((static_cast<void> (0)), (PyTupleObject *)(seq.get() ))))->ob_size)); | ||||
1960 | PyObject** seq_array = PySequence_Fast_ITEMS(seq.get())(((((((PyObject*)(seq.get()))->ob_type))->tp_flags & ((1UL << 25))) != 0) ? ((PyListObject *)(seq.get()))-> ob_item : ((PyTupleObject *)(seq.get()))->ob_item); | ||||
1961 | tensor_ids->reserve(len); | ||||
1962 | dtypes->reserve(len); | ||||
1963 | for (int i = 0; i < len; ++i) { | ||||
1964 | PyObject* item = seq_array[i]; | ||||
1965 | tensor_ids->push_back(FastTensorId(item)); | ||||
1966 | dtypes->push_back(tensorflow::PyTensor_DataType(item)); | ||||
1967 | } | ||||
1968 | return true; | ||||
1969 | } | ||||
1970 | |||||
1971 | bool TapeCouldPossiblyRecord(PyObject* tensors) { | ||||
1972 | if (tensors == Py_None(&_Py_NoneStruct)) { | ||||
1973 | return false; | ||||
1974 | } | ||||
1975 | if (*ThreadTapeIsStopped()) { | ||||
1976 | return false; | ||||
1977 | } | ||||
1978 | if (!HasAccumulatorOrTape()) { | ||||
1979 | return false; | ||||
1980 | } | ||||
1981 | return true; | ||||
1982 | } | ||||
1983 | |||||
1984 | bool CouldBackprop() { return !*ThreadTapeIsStopped() && HasGradientTape(); } | ||||
1985 | |||||
1986 | bool CouldForwardprop() { return !*ThreadTapeIsStopped() && HasAccumulator(); } | ||||
1987 | |||||
1988 | PyObject* TFE_Py_TapeSetShouldRecordBackprop(PyObject* tensors) { | ||||
1989 | if (!TapeCouldPossiblyRecord(tensors) || !CouldBackprop()) { | ||||
1990 | Py_RETURN_FALSEreturn _Py_INCREF(((PyObject*)(((PyObject *) &_Py_FalseStruct )))), ((PyObject *) &_Py_FalseStruct); | ||||
1991 | } | ||||
1992 | // TODO(apassos) consider not building a list and changing the API to check | ||||
1993 | // each tensor individually. | ||||
1994 | std::vector<int64_t> tensor_ids; | ||||
1995 | std::vector<tensorflow::DataType> dtypes; | ||||
1996 | if (!TensorShapesAndDtypes(tensors, &tensor_ids, &dtypes)) { | ||||
1997 | return nullptr; | ||||
1998 | } | ||||
1999 | auto tape_set = *GetTapeSet(); | ||||
2000 | for (TFE_Py_Tape* tape : tape_set) { | ||||
2001 | if (tape->tape->ShouldRecord(tensor_ids, dtypes)) { | ||||
2002 | Py_RETURN_TRUEreturn _Py_INCREF(((PyObject*)(((PyObject *) &_Py_TrueStruct )))), ((PyObject *) &_Py_TrueStruct); | ||||
2003 | } | ||||
2004 | } | ||||
2005 | |||||
2006 | Py_RETURN_FALSEreturn _Py_INCREF(((PyObject*)(((PyObject *) &_Py_FalseStruct )))), ((PyObject *) &_Py_FalseStruct); | ||||
2007 | } | ||||
2008 | |||||
2009 | PyObject* TFE_Py_ForwardAccumulatorPushState() { | ||||
2010 | auto forward_accumulators = *GetAccumulatorSet(); | ||||
2011 | for (TFE_Py_ForwardAccumulator* accumulator : forward_accumulators) { | ||||
2012 | accumulator->accumulator->PushState(); | ||||
2013 | } | ||||
2014 | Py_RETURN_NONEreturn _Py_INCREF(((PyObject*)((&_Py_NoneStruct)))), (& _Py_NoneStruct); | ||||
2015 | } | ||||
2016 | |||||
2017 | PyObject* TFE_Py_ForwardAccumulatorPopState() { | ||||
2018 | auto forward_accumulators = *GetAccumulatorSet(); | ||||
2019 | for (TFE_Py_ForwardAccumulator* accumulator : forward_accumulators) { | ||||
2020 | accumulator->accumulator->PopState(); | ||||
2021 | } | ||||
2022 | Py_RETURN_NONEreturn _Py_INCREF(((PyObject*)((&_Py_NoneStruct)))), (& _Py_NoneStruct); | ||||
2023 | } | ||||
2024 | |||||
2025 | PyObject* TFE_Py_TapeSetPossibleGradientTypes(PyObject* tensors) { | ||||
2026 | if (!TapeCouldPossiblyRecord(tensors)) { | ||||
2027 | return GetPythonObjectFromInt(0); | ||||
2028 | } | ||||
2029 | std::vector<int64_t> tensor_ids; | ||||
2030 | std::vector<tensorflow::DataType> dtypes; | ||||
2031 | if (!TensorShapesAndDtypes(tensors, &tensor_ids, &dtypes)) { | ||||
2032 | return nullptr; | ||||
2033 | } | ||||
2034 | |||||
2035 | // If there is a persistent tape watching, or if there are multiple tapes | ||||
2036 | // watching, we'll return immediately indicating that higher-order tape | ||||
2037 | // gradients are possible. | ||||
2038 | bool some_tape_watching = false; | ||||
2039 | if (CouldBackprop()) { | ||||
2040 | auto tape_set = *GetTapeSet(); | ||||
2041 | for (TFE_Py_Tape* tape : tape_set) { | ||||
2042 | if (tape->tape->ShouldRecord(tensor_ids, dtypes)) { | ||||
2043 | if (tape->tape->IsPersistent() || some_tape_watching) { | ||||
2044 | // Either this is the second tape watching, or this tape is | ||||
2045 | // persistent: higher-order gradients are possible. | ||||
2046 | return GetPythonObjectFromInt(2); | ||||
2047 | } | ||||
2048 | some_tape_watching = true; | ||||
2049 | } | ||||
2050 | } | ||||
2051 | } | ||||
2052 | if (CouldForwardprop()) { | ||||
2053 | auto forward_accumulators = *GetAccumulatorSet(); | ||||
2054 | for (TFE_Py_ForwardAccumulator* accumulator : forward_accumulators) { | ||||
2055 | if (accumulator->accumulator->ShouldRecord(tensor_ids, dtypes)) { | ||||
2056 | if (some_tape_watching) { | ||||
2057 | // This is the second tape watching: higher-order gradients are | ||||
2058 | // possible. Note that there's no equivalent of persistence for | ||||
2059 | // forward-mode. | ||||
2060 | return GetPythonObjectFromInt(2); | ||||
2061 | } | ||||
2062 | some_tape_watching = true; | ||||
2063 | } | ||||
2064 | } | ||||
2065 | } | ||||
2066 | if (some_tape_watching) { | ||||
2067 | // There's exactly one non-persistent tape. The user can request first-order | ||||
2068 | // gradients but won't be able to get higher-order tape gradients. | ||||
2069 | return GetPythonObjectFromInt(1); | ||||
2070 | } else { | ||||
2071 | // There are no tapes. The user can't request tape gradients. | ||||
2072 | return GetPythonObjectFromInt(0); | ||||
2073 | } | ||||
2074 | } | ||||
2075 | |||||
2076 | void TFE_Py_TapeWatch(PyObject* tape, PyObject* tensor) { | ||||
2077 | if (!CouldBackprop()) { | ||||
2078 | return; | ||||
2079 | } | ||||
2080 | int64_t tensor_id = FastTensorId(tensor); | ||||
2081 | if (PyErr_Occurred()) { | ||||
2082 | return; | ||||
2083 | } | ||||
2084 | reinterpret_cast<TFE_Py_Tape*>(tape)->tape->Watch(tensor_id); | ||||
2085 | } | ||||
2086 | |||||
2087 | bool ListContainsNone(PyObject* list) { | ||||
2088 | if (list == Py_None(&_Py_NoneStruct)) return true; | ||||
2089 | tensorflow::Safe_PyObjectPtr seq( | ||||
2090 | PySequence_Fast(list, "expected a sequence")); | ||||
2091 | if (seq == nullptr) { | ||||
2092 | return false; | ||||
2093 | } | ||||
2094 | |||||
2095 | int len = PySequence_Size(list); | ||||
2096 | PyObject** seq_array = PySequence_Fast_ITEMS(seq.get())(((((((PyObject*)(seq.get()))->ob_type))->tp_flags & ((1UL << 25))) != 0) ? ((PyListObject *)(seq.get()))-> ob_item : ((PyTupleObject *)(seq.get()))->ob_item); | ||||
2097 | for (int i = 0; i < len; ++i) { | ||||
2098 | PyObject* item = seq_array[i]; | ||||
2099 | if (item == Py_None(&_Py_NoneStruct)) return true; | ||||
2100 | } | ||||
2101 | |||||
2102 | return false; | ||||
2103 | } | ||||
2104 | |||||
2105 | // As an optimization, the tape generally keeps only the shape and dtype of | ||||
2106 | // tensors, and uses this information to generate ones/zeros tensors. However, | ||||
2107 | // some tensors require OnesLike/ZerosLike because their gradients do not match | ||||
2108 | // their inference shape/dtype. | ||||
2109 | bool DTypeNeedsHandleData(tensorflow::DataType dtype) { | ||||
2110 | return dtype == tensorflow::DT_VARIANT || dtype == tensorflow::DT_RESOURCE; | ||||
2111 | } | ||||
2112 | |||||
2113 | static PyTapeTensor TapeTensorFromTensor(PyObject* tensor) { | ||||
2114 | if (EagerTensor_CheckExact(tensor)) { | ||||
2115 | tensorflow::ImmediateExecutionTensorHandle* handle = | ||||
2116 | tensorflow::unwrap(EagerTensor_Handle(tensor)); | ||||
2117 | int64_t id = PyEagerTensor_ID(tensor); | ||||
2118 | tensorflow::DataType dtype = | ||||
2119 | static_cast<tensorflow::DataType>(handle->DataType()); | ||||
2120 | if (DTypeNeedsHandleData(dtype)) { | ||||
2121 | return PyTapeTensor(id, dtype, tensor); | ||||
2122 | } | ||||
2123 | |||||
2124 | tensorflow::TensorShape tensor_shape; | ||||
2125 | int num_dims; | ||||
2126 | tensorflow::Status status = handle->NumDims(&num_dims); | ||||
2127 | if (status.ok()) { | ||||
2128 | for (int i = 0; i < num_dims; ++i) { | ||||
2129 | int64_t dim_size; | ||||
2130 | status = handle->Dim(i, &dim_size); | ||||
2131 | if (!status.ok()) break; | ||||
2132 | tensor_shape.AddDim(dim_size); | ||||
2133 | } | ||||
2134 | } | ||||
2135 | |||||
2136 | if (MaybeRaiseExceptionFromStatus(status, nullptr)) { | ||||
2137 | return PyTapeTensor(id, static_cast<tensorflow::DataType>(0), | ||||
2138 | tensorflow::TensorShape({})); | ||||
2139 | } else { | ||||
2140 | return PyTapeTensor(id, dtype, tensor_shape); | ||||
2141 | } | ||||
2142 | } | ||||
2143 | int64_t id = FastTensorId(tensor); | ||||
2144 | if (PyErr_Occurred()) { | ||||
2145 | return PyTapeTensor(id, static_cast<tensorflow::DataType>(0), | ||||
2146 | tensorflow::TensorShape({})); | ||||
2147 | } | ||||
2148 | PyObject* dtype_object = PyObject_GetAttrString(tensor, "dtype"); | ||||
2149 | PyObject* dtype_enum = PyObject_GetAttrString(dtype_object, "_type_enum"); | ||||
2150 | Py_DECREF(dtype_object)_Py_DECREF(((PyObject*)(dtype_object))); | ||||
2151 | tensorflow::DataType dtype = | ||||
2152 | static_cast<tensorflow::DataType>(MakeInt(dtype_enum)); | ||||
2153 | Py_DECREF(dtype_enum)_Py_DECREF(((PyObject*)(dtype_enum))); | ||||
2154 | if (PyErr_Occurred()) { | ||||
2155 | return PyTapeTensor(id, static_cast<tensorflow::DataType>(0), | ||||
2156 | tensorflow::TensorShape({})); | ||||
2157 | } | ||||
2158 | static char _shape_tuple[] = "_shape_tuple"; | ||||
2159 | tensorflow::Safe_PyObjectPtr shape_tuple( | ||||
2160 | PyObject_CallMethod(tensor, _shape_tuple, nullptr)); | ||||
2161 | if (PyErr_Occurred()) { | ||||
2162 | return PyTapeTensor(id, static_cast<tensorflow::DataType>(0), | ||||
2163 | tensorflow::TensorShape({})); | ||||
2164 | } | ||||
2165 | |||||
2166 | if (ListContainsNone(shape_tuple.get()) || DTypeNeedsHandleData(dtype)) { | ||||
2167 | return PyTapeTensor(id, dtype, tensor); | ||||
2168 | } | ||||
2169 | |||||
2170 | auto l = MakeIntList(shape_tuple.get()); | ||||
2171 | // Replace -1, which represents accidental Nones which can occur in graph mode | ||||
2172 | // and can cause errors in shape construction with 0s. | ||||
2173 | for (auto& c : l) { | ||||
2174 | if (c < 0) { | ||||
2175 | c = 0; | ||||
2176 | } | ||||
2177 | } | ||||
2178 | tensorflow::TensorShape shape(l); | ||||
2179 | return PyTapeTensor(id, dtype, shape); | ||||
2180 | } | ||||
2181 | |||||
2182 | // Populates output_info from output_seq, which must come from PySequence_Fast. | ||||
2183 | // | ||||
2184 | // Does not take ownership of output_seq. Returns true on success and false if a | ||||
2185 | // Python exception has been set. | ||||
2186 | bool TapeTensorsFromTensorSequence(PyObject* output_seq, | ||||
2187 | std::vector<PyTapeTensor>* output_info) { | ||||
2188 | Py_ssize_t output_len = PySequence_Fast_GET_SIZE(output_seq)(((((((PyObject*)(output_seq))->ob_type))->tp_flags & ((1UL << 25))) != 0) ? ((static_cast<void> (0)), (((PyVarObject*)(output_seq))->ob_size)) : (((PyVarObject *)(((static_cast<void> (0)), (PyTupleObject *)(output_seq ))))->ob_size)); | ||||
2189 | PyObject** output_seq_array = PySequence_Fast_ITEMS(output_seq)(((((((PyObject*)(output_seq))->ob_type))->tp_flags & ((1UL << 25))) != 0) ? ((PyListObject *)(output_seq))-> ob_item : ((PyTupleObject *)(output_seq))->ob_item); | ||||
2190 | output_info->reserve(output_len); | ||||
2191 | for (Py_ssize_t i = 0; i < output_len; ++i) { | ||||
2192 | output_info->push_back(TapeTensorFromTensor(output_seq_array[i])); | ||||
2193 | if (PyErr_Occurred() != nullptr) { | ||||
2194 | return false; | ||||
2195 | } | ||||
2196 | } | ||||
2197 | return true; | ||||
2198 | } | ||||
2199 | |||||
2200 | std::vector<int64_t> MakeTensorIDList(PyObject* tensors) { | ||||
2201 | PyObject* seq = PySequence_Fast(tensors, "expected a sequence"); | ||||
2202 | if (seq == nullptr) { | ||||
2203 | return {}; | ||||
2204 | } | ||||
2205 | int len = PySequence_Fast_GET_SIZE(seq)(((((((PyObject*)(seq))->ob_type))->tp_flags & ((1UL << 25))) != 0) ? ((static_cast<void> (0)), (((PyVarObject *)(seq))->ob_size)) : (((PyVarObject*)(((static_cast<void > (0)), (PyTupleObject *)(seq))))->ob_size)); | ||||
2206 | PyObject** seq_array = PySequence_Fast_ITEMS(seq)(((((((PyObject*)(seq))->ob_type))->tp_flags & ((1UL << 25))) != 0) ? ((PyListObject *)(seq))->ob_item : ((PyTupleObject *)(seq))->ob_item); | ||||
2207 | std::vector<int64_t> list; | ||||
2208 | list.reserve(len); | ||||
2209 | for (int i = 0; i < len; ++i) { | ||||
2210 | PyObject* tensor = seq_array[i]; | ||||
2211 | list.push_back(FastTensorId(tensor)); | ||||
2212 | if (PyErr_Occurred()) { | ||||
2213 | Py_DECREF(seq)_Py_DECREF(((PyObject*)(seq))); | ||||
2214 | return list; | ||||
2215 | } | ||||
2216 | } | ||||
2217 | Py_DECREF(seq)_Py_DECREF(((PyObject*)(seq))); | ||||
2218 | return list; | ||||
2219 | } | ||||
2220 | |||||
2221 | void TFE_Py_TapeVariableAccessed(PyObject* variable) { | ||||
2222 | if (!CouldBackprop()) { | ||||
2223 | return; | ||||
2224 | } | ||||
2225 | for (TFE_Py_Tape* tape : SafeTapeSet()) { | ||||
2226 | tape->tape->VariableAccessed(variable); | ||||
2227 | } | ||||
2228 | } | ||||
2229 | |||||
2230 | void TFE_Py_TapeWatchVariable(PyObject* tape, PyObject* variable) { | ||||
2231 | if (!CouldBackprop()) { | ||||
2232 | return; | ||||
2233 | } | ||||
2234 | reinterpret_cast<TFE_Py_Tape*>(tape)->tape->WatchVariable(variable); | ||||
2235 | } | ||||
2236 | |||||
2237 | PyObject* TFE_Py_TapeWatchedVariables(PyObject* tape) { | ||||
2238 | return reinterpret_cast<TFE_Py_Tape*>(tape)->tape->GetVariablesAsPyTuple(); | ||||
2239 | } | ||||
2240 | |||||
2241 | PyObject* TFE_Py_VariableWatcherNew() { | ||||
2242 | TFE_Py_VariableWatcher_Type.tp_new = PyType_GenericNew; | ||||
2243 | if (PyType_Ready(&TFE_Py_VariableWatcher_Type) < 0) return nullptr; | ||||
2244 | TFE_Py_VariableWatcher* variable_watcher = | ||||
2245 | PyObject_NEW(TFE_Py_VariableWatcher, &TFE_Py_VariableWatcher_Type)( (TFE_Py_VariableWatcher *) PyObject_Init( (PyObject *) PyObject_Malloc ( ( (&TFE_Py_VariableWatcher_Type)->tp_basicsize ) ), ( &TFE_Py_VariableWatcher_Type)) ); | ||||
2246 | variable_watcher->variable_watcher = new VariableWatcher(); | ||||
2247 | Py_INCREF(variable_watcher)_Py_INCREF(((PyObject*)(variable_watcher))); | ||||
2248 | GetVariableWatcherSet()->insert(variable_watcher); | ||||
2249 | return reinterpret_cast<PyObject*>(variable_watcher); | ||||
2250 | } | ||||
2251 | |||||
2252 | void TFE_Py_VariableWatcherRemove(PyObject* variable_watcher) { | ||||
2253 | auto* stack = GetVariableWatcherSet(); | ||||
2254 | stack->erase(reinterpret_cast<TFE_Py_VariableWatcher*>(variable_watcher)); | ||||
2255 | // We kept a reference to the variable watcher in the set to ensure it | ||||
2256 | // wouldn't get deleted under us; cleaning it up here. | ||||
2257 | Py_DECREF(variable_watcher)_Py_DECREF(((PyObject*)(variable_watcher))); | ||||
2258 | } | ||||
2259 | |||||
2260 | void TFE_Py_VariableWatcherVariableAccessed(PyObject* variable) { | ||||
2261 | for (TFE_Py_VariableWatcher* variable_watcher : SafeVariableWatcherSet()) { | ||||
2262 | variable_watcher->variable_watcher->WatchVariable(variable); | ||||
2263 | } | ||||
2264 | } | ||||
2265 | |||||
2266 | PyObject* TFE_Py_VariableWatcherWatchedVariables(PyObject* variable_watcher) { | ||||
2267 | return reinterpret_cast<TFE_Py_VariableWatcher*>(variable_watcher) | ||||
2268 | ->variable_watcher->GetVariablesAsPyTuple(); | ||||
2269 | } | ||||
2270 | |||||
2271 | namespace { | ||||
2272 | std::vector<tensorflow::DataType> MakeTensorDtypeList(PyObject* tensors) { | ||||
2273 | PyObject* seq = PySequence_Fast(tensors, "expected a sequence"); | ||||
2274 | if (seq == nullptr) { | ||||
2275 | return {}; | ||||
2276 | } | ||||
2277 | int len = PySequence_Fast_GET_SIZE(seq)(((((((PyObject*)(seq))->ob_type))->tp_flags & ((1UL << 25))) != 0) ? ((static_cast<void> (0)), (((PyVarObject *)(seq))->ob_size)) : (((PyVarObject*)(((static_cast<void > (0)), (PyTupleObject *)(seq))))->ob_size)); | ||||
2278 | PyObject** seq_array = PySequence_Fast_ITEMS(seq)(((((((PyObject*)(seq))->ob_type))->tp_flags & ((1UL << 25))) != 0) ? ((PyListObject *)(seq))->ob_item : ((PyTupleObject *)(seq))->ob_item); | ||||
2279 | std::vector<tensorflow::DataType> list; | ||||
2280 | list.reserve(len); | ||||
2281 | for (int i = 0; i < len; ++i) { | ||||
2282 | PyObject* tensor = seq_array[i]; | ||||
2283 | list.push_back(tensorflow::PyTensor_DataType(tensor)); | ||||
2284 | } | ||||
2285 | Py_DECREF(seq)_Py_DECREF(((PyObject*)(seq))); | ||||
2286 | return list; | ||||
2287 | } | ||||
2288 | |||||
2289 | PyObject* ForwardAccumulatorDeleteGradient(PyObject* tensor_id, | ||||
2290 | PyObject* weak_tensor_ref) { | ||||
2291 | int64_t parsed_tensor_id = MakeInt(tensor_id); | ||||
2292 | for (TFE_Py_ForwardAccumulator* accumulator : *GetAccumulatorSet()) { | ||||
2293 | accumulator->accumulator->DeleteGradient(parsed_tensor_id); | ||||
2294 | } | ||||
2295 | Py_DECREF(weak_tensor_ref)_Py_DECREF(((PyObject*)(weak_tensor_ref))); | ||||
2296 | Py_DECREF(tensor_id)_Py_DECREF(((PyObject*)(tensor_id))); | ||||
2297 | Py_INCREF(Py_None)_Py_INCREF(((PyObject*)((&_Py_NoneStruct)))); | ||||
2298 | return Py_None(&_Py_NoneStruct); | ||||
2299 | } | ||||
2300 | |||||
2301 | static PyMethodDef forward_accumulator_delete_gradient_method_def = { | ||||
2302 | "ForwardAccumulatorDeleteGradient", ForwardAccumulatorDeleteGradient, | ||||
2303 | METH_O0x0008, "ForwardAccumulatorDeleteGradient"}; | ||||
2304 | |||||
2305 | void RegisterForwardAccumulatorCleanup(PyObject* tensor, int64_t tensor_id) { | ||||
2306 | tensorflow::Safe_PyObjectPtr callback( | ||||
2307 | PyCFunction_New(&forward_accumulator_delete_gradient_method_def,PyCFunction_NewEx((&forward_accumulator_delete_gradient_method_def ), (PyLong_FromLong(tensor_id)), __null) | ||||
2308 | PyLong_FromLong(tensor_id))PyCFunction_NewEx((&forward_accumulator_delete_gradient_method_def ), (PyLong_FromLong(tensor_id)), __null)); | ||||
2309 | // We need to keep a reference to the weakref active if we want our callback | ||||
2310 | // called. The callback itself now owns the weakref object and the tensor ID | ||||
2311 | // object. | ||||
2312 | PyWeakref_NewRef(tensor, callback.get()); | ||||
2313 | } | ||||
2314 | |||||
2315 | void TapeSetRecordBackprop( | ||||
2316 | const string& op_type, const std::vector<PyTapeTensor>& output_info, | ||||
2317 | const std::vector<int64_t>& input_ids, | ||||
2318 | const std::vector<tensorflow::DataType>& input_dtypes, | ||||
2319 | const std::function<PyBackwardFunction*()>& backward_function_getter, | ||||
2320 | const std::function<void(PyBackwardFunction*)>& backward_function_killer, | ||||
2321 | tensorflow::uint64 max_gradient_tape_id) { | ||||
2322 | if (!CouldBackprop()) { | ||||
2323 | return; | ||||
2324 | } | ||||
2325 | for (TFE_Py_Tape* tape : SafeTapeSet()) { | ||||
2326 | if (tape->nesting_id < max_gradient_tape_id) { | ||||
2327 | tape->tape->RecordOperation(op_type, output_info, input_ids, input_dtypes, | ||||
2328 | backward_function_getter, | ||||
2329 | backward_function_killer); | ||||
2330 | } | ||||
2331 | } | ||||
2332 | } | ||||
2333 | |||||
2334 | bool TapeSetRecordForwardprop( | ||||
2335 | const string& op_type, PyObject* output_seq, | ||||
2336 | const std::vector<PyTapeTensor>& output_info, PyObject* input_tensors, | ||||
2337 | const std::vector<int64_t>& input_ids, | ||||
2338 | const std::vector<tensorflow::DataType>& input_dtypes, | ||||
2339 | const std::function<PyBackwardFunction*()>& backward_function_getter, | ||||
2340 | const std::function<void(PyBackwardFunction*)>& backward_function_killer, | ||||
2341 | const tensorflow::eager::ForwardFunction<PyObject>* forward_function, | ||||
2342 | PyObject* forwardprop_output_indices, | ||||
2343 | tensorflow::uint64* max_gradient_tape_id) { | ||||
2344 | *max_gradient_tape_id = std::numeric_limits<tensorflow::uint64>::max(); | ||||
2345 | if (!CouldForwardprop()) { | ||||
2346 | return true; | ||||
2347 | } | ||||
2348 | auto accumulator_set = SafeAccumulatorSet(); | ||||
2349 | tensorflow::Safe_PyObjectPtr input_seq( | ||||
2350 | PySequence_Fast(input_tensors, "expected a sequence of tensors")); | ||||
2351 | if (input_seq == nullptr || PyErr_Occurred()) return false; | ||||
2352 | Py_ssize_t input_len = PySequence_Fast_GET_SIZE(input_seq.get())(((((((PyObject*)(input_seq.get()))->ob_type))->tp_flags & ((1UL << 25))) != 0) ? ((static_cast<void> (0)), (((PyVarObject*)(input_seq.get()))->ob_size)) : ((( PyVarObject*)(((static_cast<void> (0)), (PyTupleObject * )(input_seq.get()))))->ob_size)); | ||||
2353 | PyObject** output_seq_array = PySequence_Fast_ITEMS(output_seq)(((((((PyObject*)(output_seq))->ob_type))->tp_flags & ((1UL << 25))) != 0) ? ((PyListObject *)(output_seq))-> ob_item : ((PyTupleObject *)(output_seq))->ob_item); | ||||
2354 | for (int i = 0; i < output_info.size(); ++i) { | ||||
2355 | RegisterForwardAccumulatorCleanup(output_seq_array[i], | ||||
2356 | output_info[i].GetID()); | ||||
2357 | } | ||||
2358 | if (forwardprop_output_indices != nullptr && | ||||
2359 | forwardprop_output_indices != Py_None(&_Py_NoneStruct)) { | ||||
2360 | tensorflow::Safe_PyObjectPtr indices_fast(PySequence_Fast( | ||||
2361 | forwardprop_output_indices, "Expected a sequence of indices")); | ||||
2362 | if (indices_fast == nullptr || PyErr_Occurred()) { | ||||
2363 | return false; | ||||
2364 | } | ||||
2365 | if (PySequence_Fast_GET_SIZE(indices_fast.get())(((((((PyObject*)(indices_fast.get()))->ob_type))->tp_flags & ((1UL << 25))) != 0) ? ((static_cast<void> (0)), (((PyVarObject*)(indices_fast.get()))->ob_size)) : ( ((PyVarObject*)(((static_cast<void> (0)), (PyTupleObject *)(indices_fast.get()))))->ob_size)) != | ||||
2366 | accumulator_set.size()) { | ||||
2367 | MaybeRaiseExceptionFromStatus( | ||||
2368 | tensorflow::errors::Internal( | ||||
2369 | "Accumulators were added or removed from the active set " | ||||
2370 | "between packing and unpacking."), | ||||
2371 | nullptr); | ||||
2372 | } | ||||
2373 | PyObject** indices_fast_array = PySequence_Fast_ITEMS(indices_fast.get())(((((((PyObject*)(indices_fast.get()))->ob_type))->tp_flags & ((1UL << 25))) != 0) ? ((PyListObject *)(indices_fast .get()))->ob_item : ((PyTupleObject *)(indices_fast.get()) )->ob_item); | ||||
2374 | Py_ssize_t accumulator_index = 0; | ||||
2375 | for (AccumulatorSet::const_reverse_iterator it = accumulator_set.rbegin(); | ||||
2376 | it != accumulator_set.rend(); ++it, ++accumulator_index) { | ||||
2377 | tensorflow::Safe_PyObjectPtr jvp_index_seq( | ||||
2378 | PySequence_Fast(indices_fast_array[accumulator_index], | ||||
2379 | "Expected a sequence of jvp indices.")); | ||||
2380 | if (jvp_index_seq == nullptr || PyErr_Occurred()) { | ||||
2381 | return false; | ||||
2382 | } | ||||
2383 | Py_ssize_t num_jvps = PySequence_Fast_GET_SIZE(jvp_index_seq.get())(((((((PyObject*)(jvp_index_seq.get()))->ob_type))->tp_flags & ((1UL << 25))) != 0) ? ((static_cast<void> (0)), (((PyVarObject*)(jvp_index_seq.get()))->ob_size)) : (((PyVarObject*)(((static_cast<void> (0)), (PyTupleObject *)(jvp_index_seq.get()))))->ob_size)); | ||||
2384 | PyObject** jvp_index_seq_array = | ||||
2385 | PySequence_Fast_ITEMS(jvp_index_seq.get())(((((((PyObject*)(jvp_index_seq.get()))->ob_type))->tp_flags & ((1UL << 25))) != 0) ? ((PyListObject *)(jvp_index_seq .get()))->ob_item : ((PyTupleObject *)(jvp_index_seq.get() ))->ob_item); | ||||
2386 | for (Py_ssize_t jvp_index = 0; jvp_index < num_jvps; ++jvp_index) { | ||||
2387 | PyObject* tuple = jvp_index_seq_array[jvp_index]; | ||||
2388 | int64_t primal_tensor_id = | ||||
2389 | output_info[MakeInt(PyTuple_GetItem(tuple, 0))].GetID(); | ||||
2390 | (*it)->accumulator->Watch( | ||||
2391 | primal_tensor_id, | ||||
2392 | output_seq_array[MakeInt(PyTuple_GetItem(tuple, 1))]); | ||||
2393 | } | ||||
2394 | } | ||||
2395 | } else { | ||||
2396 | std::vector<PyTapeTensor> input_info; | ||||
2397 | input_info.reserve(input_len); | ||||
2398 | PyObject** input_seq_array = PySequence_Fast_ITEMS(input_seq.get())(((((((PyObject*)(input_seq.get()))->ob_type))->tp_flags & ((1UL << 25))) != 0) ? ((PyListObject *)(input_seq .get()))->ob_item : ((PyTupleObject *)(input_seq.get()))-> ob_item); | ||||
2399 | for (Py_ssize_t i = 0; i < input_len; ++i) { | ||||
2400 | input_info.push_back(TapeTensorFromTensor(input_seq_array[i])); | ||||
2401 | } | ||||
2402 | for (TFE_Py_ForwardAccumulator* accumulator : accumulator_set) { | ||||
2403 | tensorflow::Status status = accumulator->accumulator->Accumulate( | ||||
2404 | op_type, input_info, output_info, input_ids, input_dtypes, | ||||
2405 | forward_function, backward_function_getter, backward_function_killer); | ||||
2406 | if (PyErr_Occurred()) return false; // Don't swallow Python exceptions. | ||||
2407 | if (MaybeRaiseExceptionFromStatus(status, nullptr)) { | ||||
2408 | return false; | ||||
2409 | } | ||||
2410 | if (accumulator->accumulator->BusyAccumulating()) { | ||||
2411 | // Ensure inner accumulators don't see outer accumulators' jvps. This | ||||
2412 | // mostly happens on its own, with some potentially surprising | ||||
2413 | // exceptions, so the blanket policy is for consistency. | ||||
2414 | *max_gradient_tape_id = accumulator->nesting_id; | ||||
2415 | break; | ||||
2416 | } | ||||
2417 | } | ||||
2418 | } | ||||
2419 | return true; | ||||
2420 | } | ||||
2421 | |||||
2422 | PyObject* TangentsAsPyTuple(const std::vector<PyObject*>& input_tangents) { | ||||
2423 | PyObject* py_input_tangents = PyTuple_New(input_tangents.size()); | ||||
2424 | for (int i = 0; i < input_tangents.size(); ++i) { | ||||
2425 | PyObject* element; | ||||
2426 | if (input_tangents[i] == nullptr) { | ||||
2427 | element = Py_None(&_Py_NoneStruct); | ||||
2428 | } else { | ||||
2429 | element = input_tangents[i]; | ||||
2430 | } | ||||
2431 | Py_INCREF(element)_Py_INCREF(((PyObject*)(element))); | ||||
2432 | PyTuple_SET_ITEM(py_input_tangents, i, element)PyTuple_SetItem(py_input_tangents, i, element); | ||||
2433 | } | ||||
2434 | return py_input_tangents; | ||||
2435 | } | ||||
2436 | |||||
2437 | tensorflow::Status ParseTangentOutputs( | ||||
2438 | PyObject* user_output, std::vector<PyObject*>* output_tangents) { | ||||
2439 | if (user_output == Py_None(&_Py_NoneStruct)) { | ||||
2440 | // No connected gradients. | ||||
2441 | return tensorflow::Status::OK(); | ||||
2442 | } | ||||
2443 | tensorflow::Safe_PyObjectPtr fast_result( | ||||
2444 | PySequence_Fast(user_output, "expected a sequence of forward gradients")); | ||||
2445 | if (fast_result == nullptr) { | ||||
2446 | return tensorflow::errors::InvalidArgument( | ||||
2447 | "forward gradient function did not return a sequence."); | ||||
2448 | } | ||||
2449 | int len = PySequence_Fast_GET_SIZE(fast_result.get())(((((((PyObject*)(fast_result.get()))->ob_type))->tp_flags & ((1UL << 25))) != 0) ? ((static_cast<void> (0)), (((PyVarObject*)(fast_result.get()))->ob_size)) : ( ((PyVarObject*)(((static_cast<void> (0)), (PyTupleObject *)(fast_result.get()))))->ob_size)); | ||||
2450 | PyObject** fast_result_array = PySequence_Fast_ITEMS(fast_result.get())(((((((PyObject*)(fast_result.get()))->ob_type))->tp_flags & ((1UL << 25))) != 0) ? ((PyListObject *)(fast_result .get()))->ob_item : ((PyTupleObject *)(fast_result.get())) ->ob_item); | ||||
2451 | output_tangents->reserve(len); | ||||
2452 | for (int i = 0; i < len; ++i) { | ||||
2453 | PyObject* item = fast_result_array[i]; | ||||
2454 | if (item == Py_None(&_Py_NoneStruct)) { | ||||
2455 | output_tangents->push_back(nullptr); | ||||
2456 | } else { | ||||
2457 | Py_INCREF(item)_Py_INCREF(((PyObject*)(item))); | ||||
2458 | output_tangents->push_back(item); | ||||
2459 | } | ||||
2460 | } | ||||
2461 | return tensorflow::Status::OK(); | ||||
2462 | } | ||||
2463 | |||||
2464 | // Calls the registered forward_gradient_function, computing `output_tangents` | ||||
2465 | // from `input_tangents`. `output_tangents` must not be null. | ||||
2466 | // | ||||
2467 | // `op_name`, `attrs`, `inputs`, and `results` describe the operation for which | ||||
2468 | // the forward function is being called. | ||||
2469 | tensorflow::Status CallJVPFunction(PyObject* op_name, PyObject* attrs, | ||||
2470 | PyObject* inputs, PyObject* results, | ||||
2471 | const std::vector<PyObject*>& input_tangents, | ||||
2472 | std::vector<PyObject*>* output_tangents, | ||||
2473 | bool use_batch) { | ||||
2474 | if (forward_gradient_function == nullptr) { | ||||
2475 | return tensorflow::errors::Internal( | ||||
2476 | "No forward gradient function registered."); | ||||
2477 | } | ||||
2478 | tensorflow::Safe_PyObjectPtr py_input_tangents( | ||||
2479 | TangentsAsPyTuple(input_tangents)); | ||||
2480 | |||||
2481 | // Normalize the input sequence to a tuple so it works with function | ||||
2482 | // caching; otherwise it may be an opaque _InputList object. | ||||
2483 | tensorflow::Safe_PyObjectPtr input_tuple(PySequence_Tuple(inputs)); | ||||
2484 | PyObject* to_batch = (use_batch) ? Py_True((PyObject *) &_Py_TrueStruct) : Py_False((PyObject *) &_Py_FalseStruct); | ||||
2485 | tensorflow::Safe_PyObjectPtr callback_args( | ||||
2486 | Py_BuildValue("OOOOOO", op_name, attrs, input_tuple.get(), results, | ||||
2487 | py_input_tangents.get(), to_batch)); | ||||
2488 | tensorflow::Safe_PyObjectPtr py_result( | ||||
2489 | PyObject_CallObject(forward_gradient_function, callback_args.get())); | ||||
2490 | if (py_result == nullptr || PyErr_Occurred()) { | ||||
2491 | return tensorflow::errors::Internal( | ||||
2492 | "forward gradient function threw exceptions"); | ||||
2493 | } | ||||
2494 | return ParseTangentOutputs(py_result.get(), output_tangents); | ||||
2495 | } | ||||
2496 | |||||
2497 | // Like CallJVPFunction, but calls a pre-bound forward function. | ||||
2498 | // These are passed in from a record_gradient argument. | ||||
2499 | tensorflow::Status CallOpSpecificJVPFunction( | ||||
2500 | PyObject* op_specific_forward_function, | ||||
2501 | const std::vector<PyObject*>& input_tangents, | ||||
2502 | std::vector<PyObject*>* output_tangents) { | ||||
2503 | tensorflow::Safe_PyObjectPtr py_input_tangents( | ||||
2504 | TangentsAsPyTuple(input_tangents)); | ||||
2505 | |||||
2506 | tensorflow::Safe_PyObjectPtr py_result(PyObject_CallObject( | ||||
2507 | op_specific_forward_function, py_input_tangents.get())); | ||||
2508 | if (py_result == nullptr || PyErr_Occurred()) { | ||||
2509 | return tensorflow::errors::Internal( | ||||
2510 | "forward gradient function threw exceptions"); | ||||
2511 | } | ||||
2512 | return ParseTangentOutputs(py_result.get(), output_tangents); | ||||
2513 | } | ||||
2514 | |||||
2515 | bool ParseOpTypeString(PyObject* op_type, string* op_type_string) { | ||||
2516 | if (PyBytes_Check(op_type)((((((PyObject*)(op_type))->ob_type))->tp_flags & ( (1UL << 27))) != 0)) { | ||||
2517 | *op_type_string = PyBytes_AsString(op_type); | ||||
2518 | } else if (PyUnicode_Check(op_type)((((((PyObject*)(op_type))->ob_type))->tp_flags & ( (1UL << 28))) != 0)) { | ||||
2519 | #if PY_MAJOR_VERSION3 >= 3 | ||||
2520 | *op_type_string = PyUnicode_AsUTF8(op_type); | ||||
2521 | #else | ||||
2522 | PyObject* py_str = PyUnicode_AsUTF8String(op_type); | ||||
2523 | if (py_str == nullptr) { | ||||
2524 | return false; | ||||
2525 | } | ||||
2526 | *op_type_string = PyBytes_AS_STRING(py_str)((static_cast<void> (0)), (((PyBytesObject *)(py_str))-> ob_sval)); | ||||
2527 | Py_DECREF(py_str)_Py_DECREF(((PyObject*)(py_str))); | ||||
2528 | #endif | ||||
2529 | } else { | ||||
2530 | PyErr_SetString(PyExc_RuntimeError, "op_type should be a string."); | ||||
2531 | return false; | ||||
2532 | } | ||||
2533 | return true; | ||||
2534 | } | ||||
2535 | |||||
2536 | bool TapeSetRecordOperation( | ||||
2537 | PyObject* op_type, PyObject* input_tensors, PyObject* output_tensors, | ||||
2538 | const std::vector<int64_t>& input_ids, | ||||
2539 | const std::vector<tensorflow::DataType>& input_dtypes, | ||||
2540 | const std::function<PyBackwardFunction*()>& backward_function_getter, | ||||
2541 | const std::function<void(PyBackwardFunction*)>& backward_function_killer, | ||||
2542 | const tensorflow::eager::ForwardFunction<PyObject>* forward_function) { | ||||
2543 | std::vector<PyTapeTensor> output_info; | ||||
2544 | tensorflow::Safe_PyObjectPtr output_seq(PySequence_Fast( | ||||
2545 | output_tensors, "expected a sequence of integer tensor ids")); | ||||
2546 | if (PyErr_Occurred() || | ||||
2547 | !TapeTensorsFromTensorSequence(output_seq.get(), &output_info)) { | ||||
2548 | return false; | ||||
2549 | } | ||||
2550 | string op_type_str; | ||||
2551 | if (!ParseOpTypeString(op_type, &op_type_str)) { | ||||
2552 | return false; | ||||
2553 | } | ||||
2554 | tensorflow::uint64 max_gradient_tape_id; | ||||
2555 | if (!TapeSetRecordForwardprop( | ||||
2556 | op_type_str, output_seq.get(), output_info, input_tensors, input_ids, | ||||
2557 | input_dtypes, backward_function_getter, backward_function_killer, | ||||
2558 | forward_function, nullptr /* No special-cased jvps. */, | ||||
2559 | &max_gradient_tape_id)) { | ||||
2560 | return false; | ||||
2561 | } | ||||
2562 | TapeSetRecordBackprop(op_type_str, output_info, input_ids, input_dtypes, | ||||
2563 | backward_function_getter, backward_function_killer, | ||||
2564 | max_gradient_tape_id); | ||||
2565 | return true; | ||||
2566 | } | ||||
2567 | } // namespace | ||||
2568 | |||||
2569 | PyObject* TFE_Py_TapeSetRecordOperation(PyObject* op_type, | ||||
2570 | PyObject* output_tensors, | ||||
2571 | PyObject* input_tensors, | ||||
2572 | PyObject* backward_function, | ||||
2573 | PyObject* forward_function) { | ||||
2574 | if (!HasAccumulatorOrTape() || *ThreadTapeIsStopped()) { | ||||
2575 | Py_RETURN_NONEreturn _Py_INCREF(((PyObject*)((&_Py_NoneStruct)))), (& _Py_NoneStruct); | ||||
2576 | } | ||||
2577 | std::vector<int64_t> input_ids = MakeTensorIDList(input_tensors); | ||||
2578 | if (PyErr_Occurred()) return nullptr; | ||||
2579 | |||||
2580 | std::vector<tensorflow::DataType> input_dtypes = | ||||
2581 | MakeTensorDtypeList(input_tensors); | ||||
2582 | if (PyErr_Occurred()) return nullptr; | ||||
2583 | |||||
2584 | std::function<PyBackwardFunction*()> backward_function_getter( | ||||
2585 | [backward_function]() { | ||||
2586 | Py_INCREF(backward_function)_Py_INCREF(((PyObject*)(backward_function))); | ||||
2587 | PyBackwardFunction* function = new PyBackwardFunction( | ||||
2588 | [backward_function](PyObject* out_grads, | ||||
2589 | const std::vector<int64_t>& unused) { | ||||
2590 | return PyObject_CallObject(backward_function, out_grads); | ||||
2591 | }); | ||||
2592 | return function; | ||||
2593 | }); | ||||
2594 | std::function<void(PyBackwardFunction*)> backward_function_killer( | ||||
2595 | [backward_function](PyBackwardFunction* py_backward_function) { | ||||
2596 | Py_DECREF(backward_function)_Py_DECREF(((PyObject*)(backward_function))); | ||||
2597 | delete py_backward_function; | ||||
2598 | }); | ||||
2599 | |||||
2600 | if (forward_function == Py_None(&_Py_NoneStruct)) { | ||||
2601 | if (!TapeSetRecordOperation( | ||||
2602 | op_type, input_tensors, output_tensors, input_ids, input_dtypes, | ||||
2603 | backward_function_getter, backward_function_killer, | ||||
2604 | nullptr /* No special-cased forward function */)) { | ||||
2605 | return nullptr; | ||||
2606 | } | ||||
2607 | } else { | ||||
2608 | tensorflow::eager::ForwardFunction<PyObject> wrapped_forward_function( | ||||
2609 | [forward_function](const std::vector<PyObject*>& input_tangents, | ||||
2610 | std::vector<PyObject*>* output_tangents, | ||||
2611 | bool use_batch = false) { | ||||
2612 | return CallOpSpecificJVPFunction(forward_function, input_tangents, | ||||
2613 | output_tangents); | ||||
2614 | }); | ||||
2615 | if (!TapeSetRecordOperation( | ||||
2616 | op_type, input_tensors, output_tensors, input_ids, input_dtypes, | ||||
2617 | backward_function_getter, backward_function_killer, | ||||
2618 | &wrapped_forward_function)) { | ||||
2619 | return nullptr; | ||||
2620 | } | ||||
2621 | } | ||||
2622 | Py_RETURN_NONEreturn _Py_INCREF(((PyObject*)((&_Py_NoneStruct)))), (& _Py_NoneStruct); | ||||
2623 | } | ||||
2624 | |||||
2625 | PyObject* TFE_Py_TapeSetRecordOperationForwardprop( | ||||
2626 | PyObject* op_type, PyObject* output_tensors, PyObject* input_tensors, | ||||
2627 | PyObject* backward_function, PyObject* forwardprop_output_indices) { | ||||
2628 | if (!HasAccumulator() || *ThreadTapeIsStopped()) { | ||||
2629 | Py_RETURN_NONEreturn _Py_INCREF(((PyObject*)((&_Py_NoneStruct)))), (& _Py_NoneStruct); | ||||
2630 | } | ||||
2631 | std::vector<int64_t> input_ids = MakeTensorIDList(input_tensors); | ||||
2632 | if (PyErr_Occurred()) return nullptr; | ||||
2633 | |||||
2634 | std::vector<tensorflow::DataType> input_dtypes = | ||||
2635 | MakeTensorDtypeList(input_tensors); | ||||
2636 | if (PyErr_Occurred()) return nullptr; | ||||
2637 | |||||
2638 | std::function<PyBackwardFunction*()> backward_function_getter( | ||||
2639 | [backward_function]() { | ||||
2640 | Py_INCREF(backward_function)_Py_INCREF(((PyObject*)(backward_function))); | ||||
2641 | PyBackwardFunction* function = new PyBackwardFunction( | ||||
2642 | [backward_function](PyObject* out_grads, | ||||
2643 | const std::vector<int64_t>& unused) { | ||||
2644 | return PyObject_CallObject(backward_function, out_grads); | ||||
2645 | }); | ||||
2646 | return function; | ||||
2647 | }); | ||||
2648 | std::function<void(PyBackwardFunction*)> backward_function_killer( | ||||
2649 | [backward_function](PyBackwardFunction* py_backward_function) { | ||||
2650 | Py_DECREF(backward_function)_Py_DECREF(((PyObject*)(backward_function))); | ||||
2651 | delete py_backward_function; | ||||
2652 | }); | ||||
2653 | std::vector<PyTapeTensor> output_info; | ||||
2654 | tensorflow::Safe_PyObjectPtr output_seq(PySequence_Fast( | ||||
2655 | output_tensors, "expected a sequence of integer tensor ids")); | ||||
2656 | if (PyErr_Occurred() || | ||||
2657 | !TapeTensorsFromTensorSequence(output_seq.get(), &output_info)) { | ||||
2658 | return nullptr; | ||||
2659 | } | ||||
2660 | string op_type_str; | ||||
2661 | if (!ParseOpTypeString(op_type, &op_type_str)) { | ||||
2662 | return nullptr; | ||||
2663 | } | ||||
2664 | tensorflow::uint64 max_gradient_tape_id; | ||||
2665 | if (!TapeSetRecordForwardprop( | ||||
2666 | op_type_str, output_seq.get(), output_info, input_tensors, input_ids, | ||||
2667 | input_dtypes, backward_function_getter, backward_function_killer, | ||||
2668 | nullptr /* no special-cased forward function */, | ||||
2669 | forwardprop_output_indices, &max_gradient_tape_id)) { | ||||
2670 | return nullptr; | ||||
2671 | } | ||||
2672 | Py_RETURN_NONEreturn _Py_INCREF(((PyObject*)((&_Py_NoneStruct)))), (& _Py_NoneStruct); | ||||
2673 | } | ||||
2674 | |||||
2675 | PyObject* TFE_Py_TapeSetRecordOperationBackprop(PyObject* op_type, | ||||
2676 | PyObject* output_tensors, | ||||
2677 | PyObject* input_tensors, | ||||
2678 | PyObject* backward_function) { | ||||
2679 | if (!CouldBackprop()) { | ||||
2680 | Py_RETURN_NONEreturn _Py_INCREF(((PyObject*)((&_Py_NoneStruct)))), (& _Py_NoneStruct); | ||||
2681 | } | ||||
2682 | std::vector<int64_t> input_ids = MakeTensorIDList(input_tensors); | ||||
2683 | if (PyErr_Occurred()) return nullptr; | ||||
2684 | |||||
2685 | std::vector<tensorflow::DataType> input_dtypes = | ||||
2686 | MakeTensorDtypeList(input_tensors); | ||||
2687 | if (PyErr_Occurred()) return nullptr; | ||||
2688 | |||||
2689 | std::function<PyBackwardFunction*()> backward_function_getter( | ||||
2690 | [backward_function]() { | ||||
2691 | Py_INCREF(backward_function)_Py_INCREF(((PyObject*)(backward_function))); | ||||
2692 | PyBackwardFunction* function = new PyBackwardFunction( | ||||
2693 | [backward_function](PyObject* out_grads, | ||||
2694 | const std::vector<int64_t>& unused) { | ||||
2695 | return PyObject_CallObject(backward_function, out_grads); | ||||
2696 | }); | ||||
2697 | return function; | ||||
2698 | }); | ||||
2699 | std::function<void(PyBackwardFunction*)> backward_function_killer( | ||||
2700 | [backward_function](PyBackwardFunction* py_backward_function) { | ||||
2701 | Py_DECREF(backward_function)_Py_DECREF(((PyObject*)(backward_function))); | ||||
2702 | delete py_backward_function; | ||||
2703 | }); | ||||
2704 | std::vector<PyTapeTensor> output_info; | ||||
2705 | tensorflow::Safe_PyObjectPtr output_seq(PySequence_Fast( | ||||
2706 | output_tensors, "expected a sequence of integer tensor ids")); | ||||
2707 | if (PyErr_Occurred() || | ||||
2708 | !TapeTensorsFromTensorSequence(output_seq.get(), &output_info)) { | ||||
2709 | return nullptr; | ||||
2710 | } | ||||
2711 | string op_type_str; | ||||
2712 | if (!ParseOpTypeString(op_type, &op_type_str)) { | ||||
2713 | return nullptr; | ||||
2714 | } | ||||
2715 | TapeSetRecordBackprop(op_type_str, output_info, input_ids, input_dtypes, | ||||
2716 | backward_function_getter, backward_function_killer, | ||||
2717 | // No filtering based on relative ordering with forward | ||||
2718 | // accumulators. | ||||
2719 | std::numeric_limits<tensorflow::uint64>::max()); | ||||
2720 | Py_RETURN_NONEreturn _Py_INCREF(((PyObject*)((&_Py_NoneStruct)))), (& _Py_NoneStruct); | ||||
2721 | } | ||||
2722 | |||||
2723 | void TFE_Py_TapeSetDeleteTrace(int64_t tensor_id) { | ||||
2724 | for (TFE_Py_Tape* tape : *GetTapeSet()) { | ||||
2725 | tape->tape->DeleteTrace(tensor_id); | ||||
2726 | } | ||||
2727 | } | ||||
2728 | |||||
2729 | std::vector<PyObject*> MakeTensorList(PyObject* tensors) { | ||||
2730 | PyObject* seq = PySequence_Fast(tensors, "expected a sequence"); | ||||
2731 | if (seq == nullptr) { | ||||
2732 | return {}; | ||||
2733 | } | ||||
2734 | int len = PySequence_Fast_GET_SIZE(seq)(((((((PyObject*)(seq))->ob_type))->tp_flags & ((1UL << 25))) != 0) ? ((static_cast<void> (0)), (((PyVarObject *)(seq))->ob_size)) : (((PyVarObject*)(((static_cast<void > (0)), (PyTupleObject *)(seq))))->ob_size)); | ||||
2735 | PyObject** seq_array = PySequence_Fast_ITEMS(seq)(((((((PyObject*)(seq))->ob_type))->tp_flags & ((1UL << 25))) != 0) ? ((PyListObject *)(seq))->ob_item : ((PyTupleObject *)(seq))->ob_item); | ||||
2736 | std::vector<PyObject*> list(seq_array, seq_array + len); | ||||
2737 | Py_DECREF(seq)_Py_DECREF(((PyObject*)(seq))); | ||||
2738 | return list; | ||||
2739 | } | ||||
2740 | |||||
2741 | PyObject* TFE_Py_TapeGradient(PyObject* tape, PyObject* target, | ||||
2742 | PyObject* sources, PyObject* output_gradients, | ||||
2743 | PyObject* sources_raw, | ||||
2744 | PyObject* unconnected_gradients, | ||||
2745 | TF_Status* status) { | ||||
2746 | TFE_Py_Tape* tape_obj = reinterpret_cast<TFE_Py_Tape*>(tape); | ||||
2747 | if (!tape_obj->tape->IsPersistent()) { | ||||
2748 | auto* tape_set = GetTapeSet(); | ||||
2749 | if (tape_set->find(tape_obj) != tape_set->end()) { | ||||
2750 | PyErr_SetString(PyExc_RuntimeError, | ||||
2751 | "gradient() cannot be invoked within the " | ||||
2752 | "GradientTape context (i.e., while operations are being " | ||||
2753 | "recorded). Either move the call to gradient() to be " | ||||
2754 | "outside the 'with tf.GradientTape' block, or " | ||||
2755 | "use a persistent tape: " | ||||
2756 | "'with tf.GradientTape(persistent=true)'"); | ||||
2757 | return nullptr; | ||||
2758 | } | ||||
2759 | } | ||||
2760 | |||||
2761 | std::vector<int64_t> target_vec = MakeTensorIDList(target); | ||||
2762 | if (PyErr_Occurred()) { | ||||
2763 | return nullptr; | ||||
2764 | } | ||||
2765 | std::vector<int64_t> sources_vec = MakeTensorIDList(sources); | ||||
2766 | if (PyErr_Occurred()) { | ||||
2767 | return nullptr; | ||||
2768 | } | ||||
2769 | tensorflow::gtl::FlatSet<int64_t> sources_set(sources_vec.begin(), | ||||
2770 | sources_vec.end()); | ||||
2771 | |||||
2772 | tensorflow::Safe_PyObjectPtr seq = | ||||
2773 | tensorflow::make_safe(PySequence_Fast(target, "expected a sequence")); | ||||
2774 | int len = PySequence_Fast_GET_SIZE(seq.get())(((((((PyObject*)(seq.get()))->ob_type))->tp_flags & ((1UL << 25))) != 0) ? ((static_cast<void> (0)), (((PyVarObject*)(seq.get()))->ob_size)) : (((PyVarObject* )(((static_cast<void> (0)), (PyTupleObject *)(seq.get() ))))->ob_size)); | ||||
2775 | PyObject** seq_array = PySequence_Fast_ITEMS(seq.get())(((((((PyObject*)(seq.get()))->ob_type))->tp_flags & ((1UL << 25))) != 0) ? ((PyListObject *)(seq.get()))-> ob_item : ((PyTupleObject *)(seq.get()))->ob_item); | ||||
2776 | std::unordered_map<int64_t, PyTapeTensor> source_tensors_that_are_targets; | ||||
2777 | for (int i = 0; i < len; ++i) { | ||||
2778 | int64_t target_id = target_vec[i]; | ||||
2779 | if (sources_set.find(target_id) != sources_set.end()) { | ||||
2780 | auto tensor = seq_array[i]; | ||||
2781 | source_tensors_that_are_targets.insert( | ||||
2782 | std::make_pair(target_id, TapeTensorFromTensor(tensor))); | ||||
2783 | } | ||||
2784 | if (PyErr_Occurred()) { | ||||
2785 | return nullptr; | ||||
2786 | } | ||||
2787 | } | ||||
2788 | if (PyErr_Occurred()) { | ||||
2789 | return nullptr; | ||||
2790 | } | ||||
2791 | |||||
2792 | std::vector<PyObject*> outgrad_vec; | ||||
2793 | if (output_gradients != Py_None(&_Py_NoneStruct)) { | ||||
2794 | outgrad_vec = MakeTensorList(output_gradients); | ||||
2795 | if (PyErr_Occurred()) { | ||||
2796 | return nullptr; | ||||
2797 | } | ||||
2798 | for (PyObject* tensor : outgrad_vec) { | ||||
2799 | // Calling the backward function will eat a reference to the tensors in | ||||
2800 | // outgrad_vec, so we need to increase their reference count. | ||||
2801 | Py_INCREF(tensor)_Py_INCREF(((PyObject*)(tensor))); | ||||
2802 | } | ||||
2803 | } | ||||
2804 | std::vector<PyObject*> result(sources_vec.size()); | ||||
2805 | status->status = tape_obj->tape->ComputeGradient( | ||||
2806 | *py_vspace, target_vec, sources_vec, source_tensors_that_are_targets, | ||||
2807 | outgrad_vec, absl::MakeSpan(result)); | ||||
2808 | if (!status->status.ok()) { | ||||
2809 | if (PyErr_Occurred()) { | ||||
2810 | // Do not propagate the erroneous status as that would swallow the | ||||
2811 | // exception which caused the problem. | ||||
2812 | status->status = tensorflow::Status::OK(); | ||||
2813 | } | ||||
2814 | return nullptr; | ||||
2815 | } | ||||
2816 | |||||
2817 | bool unconnected_gradients_zero = | ||||
2818 | strcmp(TFE_GetPythonString(unconnected_gradients), "zero") == 0; | ||||
2819 | std::vector<PyObject*> sources_obj; | ||||
2820 | if (unconnected_gradients_zero) { | ||||
2821 | // Uses the "raw" sources here so it can properly make a zeros tensor even | ||||
2822 | // if there are resource variables as sources. | ||||
2823 | sources_obj = MakeTensorList(sources_raw); | ||||
2824 | } | ||||
2825 | |||||
2826 | if (!result.empty()) { | ||||
2827 | PyObject* py_result = PyList_New(result.size()); | ||||
2828 | tensorflow::gtl::FlatSet<PyObject*> seen_results(result.size()); | ||||
2829 | for (int i = 0; i < result.size(); ++i) { | ||||
2830 | if (result[i] == nullptr) { | ||||
2831 | if (unconnected_gradients_zero) { | ||||
2832 | // generate a zeros tensor in the shape of sources[i] | ||||
2833 | tensorflow::DataType dtype = | ||||
2834 | tensorflow::PyTensor_DataType(sources_obj[i]); | ||||
2835 | PyTapeTensor tensor = | ||||
2836 | PyTapeTensor(sources_vec[i], dtype, sources_obj[i]); | ||||
2837 | result[i] = tensor.ZerosLike(); | ||||
2838 | } else { | ||||
2839 | Py_INCREF(Py_None)_Py_INCREF(((PyObject*)((&_Py_NoneStruct)))); | ||||
2840 | result[i] = Py_None(&_Py_NoneStruct); | ||||
2841 | } | ||||
2842 | } else if (seen_results.find(result[i]) != seen_results.end()) { | ||||
2843 | Py_INCREF(result[i])_Py_INCREF(((PyObject*)(result[i]))); | ||||
2844 | } | ||||
2845 | seen_results.insert(result[i]); | ||||
2846 | PyList_SET_ITEM(py_result, i, reinterpret_cast<PyObject*>(result[i]))PyList_SetItem(py_result, i, reinterpret_cast<PyObject*> (result[i])); | ||||
2847 | } | ||||
2848 | return py_result; | ||||
2849 | } | ||||
2850 | return PyList_New(0); | ||||
2851 | } | ||||
2852 | |||||
2853 | PyObject* TFE_Py_ForwardAccumulatorNew(bool use_batch) { | ||||
2854 | TFE_Py_ForwardAccumulator_Type.tp_new = PyType_GenericNew; | ||||
2855 | if (PyType_Ready(&TFE_Py_ForwardAccumulator_Type) < 0) return nullptr; | ||||
2856 | TFE_Py_ForwardAccumulator* accumulator = | ||||
2857 | PyObject_NEW(TFE_Py_ForwardAccumulator, &TFE_Py_ForwardAccumulator_Type)( (TFE_Py_ForwardAccumulator *) PyObject_Init( (PyObject *) PyObject_Malloc ( ( (&TFE_Py_ForwardAccumulator_Type)->tp_basicsize ) ) , (&TFE_Py_ForwardAccumulator_Type)) ); | ||||
2858 | if (py_vspace == nullptr) { | ||||
2859 | MaybeRaiseExceptionFromStatus( | ||||
2860 | tensorflow::errors::Internal( | ||||
2861 | "ForwardAccumulator requires a PyVSpace to be registered."), | ||||
2862 | nullptr); | ||||
2863 | } | ||||
2864 | accumulator->accumulator = new ForwardAccumulator(*py_vspace, use_batch); | ||||
2865 | return reinterpret_cast<PyObject*>(accumulator); | ||||
2866 | } | ||||
2867 | |||||
2868 | PyObject* TFE_Py_ForwardAccumulatorSetAdd(PyObject* accumulator) { | ||||
2869 | TFE_Py_ForwardAccumulator* c_accumulator( | ||||
2870 | reinterpret_cast<TFE_Py_ForwardAccumulator*>(accumulator)); | ||||
2871 | c_accumulator->nesting_id = tape_nesting_id_counter.fetch_add(1); | ||||
2872 | if (GetAccumulatorSet()->insert(c_accumulator)) { | ||||
2873 | Py_INCREF(accumulator)_Py_INCREF(((PyObject*)(accumulator))); | ||||
2874 | Py_RETURN_NONEreturn _Py_INCREF(((PyObject*)((&_Py_NoneStruct)))), (& _Py_NoneStruct); | ||||
2875 | } else { | ||||
2876 | MaybeRaiseExceptionFromStatus( | ||||
2877 | tensorflow::errors::Internal( | ||||
2878 | "A ForwardAccumulator was added to the active set twice."), | ||||
2879 | nullptr); | ||||
2880 | return nullptr; | ||||
2881 | } | ||||
2882 | } | ||||
2883 | |||||
2884 | void TFE_Py_ForwardAccumulatorSetRemove(PyObject* accumulator) { | ||||
2885 | GetAccumulatorSet()->erase( | ||||
2886 | reinterpret_cast<TFE_Py_ForwardAccumulator*>(accumulator)); | ||||
2887 | Py_DECREF(accumulator)_Py_DECREF(((PyObject*)(accumulator))); | ||||
2888 | } | ||||
2889 | |||||
2890 | void TFE_Py_ForwardAccumulatorWatch(PyObject* accumulator, PyObject* tensor, | ||||
2891 | PyObject* tangent) { | ||||
2892 | int64_t tensor_id = FastTensorId(tensor); | ||||
2893 | reinterpret_cast<TFE_Py_ForwardAccumulator*>(accumulator) | ||||
2894 | ->accumulator->Watch(tensor_id, tangent); | ||||
2895 | RegisterForwardAccumulatorCleanup(tensor, tensor_id); | ||||
2896 | } | ||||
2897 | |||||
2898 | // Returns a new reference to the JVP Tensor. | ||||
2899 | PyObject* TFE_Py_ForwardAccumulatorJVP(PyObject* accumulator, | ||||
2900 | PyObject* tensor) { | ||||
2901 | PyObject* jvp = reinterpret_cast<TFE_Py_ForwardAccumulator*>(accumulator) | ||||
2902 | ->accumulator->FetchJVP(FastTensorId(tensor)); | ||||
2903 | if (jvp == nullptr) { | ||||
2904 | jvp = Py_None(&_Py_NoneStruct); | ||||
2905 | } | ||||
2906 | Py_INCREF(jvp)_Py_INCREF(((PyObject*)(jvp))); | ||||
2907 | return jvp; | ||||
2908 | } | ||||
2909 | |||||
2910 | PyObject* TFE_Py_PackJVPs(PyObject* tensors) { | ||||
2911 | if (!TapeCouldPossiblyRecord(tensors)) { | ||||
2912 | tensorflow::Safe_PyObjectPtr empty_tuple(PyTuple_New(0)); | ||||
2913 | tensorflow::Safe_PyObjectPtr empty_list(PyList_New(0)); | ||||
2914 | return PyTuple_Pack(2, empty_tuple.get(), empty_list.get()); | ||||
2915 | } | ||||
2916 | auto accumulators = *GetAccumulatorSet(); | ||||
2917 | tensorflow::Safe_PyObjectPtr tensors_fast( | ||||
2918 | PySequence_Fast(tensors, "Expected a sequence of input Tensors.")); | ||||
2919 | if (tensors_fast == nullptr || PyErr_Occurred()) { | ||||
2920 | return nullptr; | ||||
2921 | } | ||||
2922 | std::vector<int64_t> augmented_input_ids; | ||||
2923 | int len = PySequence_Fast_GET_SIZE(tensors_fast.get())(((((((PyObject*)(tensors_fast.get()))->ob_type))->tp_flags & ((1UL << 25))) != 0) ? ((static_cast<void> (0)), (((PyVarObject*)(tensors_fast.get()))->ob_size)) : ( ((PyVarObject*)(((static_cast<void> (0)), (PyTupleObject *)(tensors_fast.get()))))->ob_size)); | ||||
2924 | PyObject** tensors_fast_array = PySequence_Fast_ITEMS(tensors_fast.get())(((((((PyObject*)(tensors_fast.get()))->ob_type))->tp_flags & ((1UL << 25))) != 0) ? ((PyListObject *)(tensors_fast .get()))->ob_item : ((PyTupleObject *)(tensors_fast.get()) )->ob_item); | ||||
2925 | for (Py_ssize_t position = 0; position < len; ++position) { | ||||
2926 | PyObject* input = tensors_fast_array[position]; | ||||
2927 | if (input == Py_None(&_Py_NoneStruct)) { | ||||
2928 | continue; | ||||
2929 | } | ||||
2930 | tensorflow::DataType input_dtype(tensorflow::PyTensor_DataType(input)); | ||||
2931 | if (input_dtype == tensorflow::DT_INVALID) { | ||||
2932 | return nullptr; | ||||
2933 | } | ||||
2934 | augmented_input_ids.push_back(FastTensorId(input)); | ||||
2935 | } | ||||
2936 | if (PyErr_Occurred()) { | ||||
2937 | return nullptr; | ||||
2938 | } | ||||
2939 | // Find the innermost accumulator such that all outer accumulators are | ||||
2940 | // recording. Any more deeply nested accumulators will not have their JVPs | ||||
2941 | // saved. | ||||
2942 | AccumulatorSet::const_iterator innermost_all_recording = accumulators.begin(); | ||||
2943 | for (; innermost_all_recording != accumulators.end(); | ||||
2944 | ++innermost_all_recording) { | ||||
2945 | if ((*innermost_all_recording)->accumulator->BusyAccumulating()) { | ||||
2946 | break; | ||||
2947 | } | ||||
2948 | } | ||||
2949 | AccumulatorSet::const_reverse_iterator reverse_innermost_all_recording( | ||||
2950 | innermost_all_recording); | ||||
2951 | |||||
2952 | bool saving_jvps = false; | ||||
2953 | tensorflow::Safe_PyObjectPtr all_indices(PyTuple_New(accumulators.size())); | ||||
2954 | std::vector<PyObject*> new_tensors; | ||||
2955 | Py_ssize_t accumulator_index = 0; | ||||
2956 | // Start with the innermost accumulators to give outer accumulators a chance | ||||
2957 | // to find their higher-order JVPs. | ||||
2958 | for (AccumulatorSet::const_reverse_iterator it = accumulators.rbegin(); | ||||
2959 | it != accumulators.rend(); ++it, ++accumulator_index) { | ||||
2960 | std::vector<int64_t> new_input_ids; | ||||
2961 | std::vector<std::pair<int64_t, int64_t>> accumulator_indices; | ||||
2962 | if (it == reverse_innermost_all_recording) { | ||||
2963 | saving_jvps = true; | ||||
2964 | } | ||||
2965 | if (saving_jvps) { | ||||
2966 | for (int input_index = 0; input_index < augmented_input_ids.size(); | ||||
2967 | ++input_index) { | ||||
2968 | int64_t existing_input = augmented_input_ids[input_index]; | ||||
2969 | PyObject* jvp = (*it)->accumulator->FetchJVP(existing_input); | ||||
2970 | if (jvp != nullptr) { | ||||
2971 | new_tensors.push_back(jvp); | ||||
2972 | new_input_ids.push_back(FastTensorId(jvp)); | ||||
2973 | accumulator_indices.emplace_back( | ||||
2974 | input_index, | ||||
2975 | augmented_input_ids.size() + new_input_ids.size() - 1); | ||||
2976 | } | ||||
2977 | } | ||||
2978 | } | ||||
2979 | tensorflow::Safe_PyObjectPtr accumulator_indices_py( | ||||
2980 | PyTuple_New(accumulator_indices.size())); | ||||
2981 | for (int i = 0; i < accumulator_indices.size(); ++i) { | ||||
2982 | tensorflow::Safe_PyObjectPtr from_index( | ||||
2983 | GetPythonObjectFromInt(accumulator_indices[i].first)); | ||||
2984 | tensorflow::Safe_PyObjectPtr to_index( | ||||
2985 | GetPythonObjectFromInt(accumulator_indices[i].second)); | ||||
2986 | PyTuple_SetItem(accumulator_indices_py.get(), i, | ||||
2987 | PyTuple_Pack(2, from_index.get(), to_index.get())); | ||||
2988 | } | ||||
2989 | PyTuple_SetItem(all_indices.get(), accumulator_index, | ||||
2990 | accumulator_indices_py.release()); | ||||
2991 | augmented_input_ids.insert(augmented_input_ids.end(), new_input_ids.begin(), | ||||
2992 | new_input_ids.end()); | ||||
2993 | } | ||||
2994 | |||||
2995 | tensorflow::Safe_PyObjectPtr new_tensors_py(PyList_New(new_tensors.size())); | ||||
2996 | for (int i = 0; i < new_tensors.size(); ++i) { | ||||
2997 | PyObject* jvp = new_tensors[i]; | ||||
2998 | Py_INCREF(jvp)_Py_INCREF(((PyObject*)(jvp))); | ||||
2999 | PyList_SET_ITEM(new_tensors_py.get(), i, jvp)PyList_SetItem(new_tensors_py.get(), i, jvp); | ||||
3000 | } | ||||
3001 | return PyTuple_Pack(2, all_indices.get(), new_tensors_py.get()); | ||||
3002 | } | ||||
3003 | |||||
3004 | namespace { | ||||
3005 | |||||
// Positions of the leading arguments in the "args" tuple passed to
// TFE_Py_FastPathExecute_C. The op's inputs occupy the positions from
// FAST_PATH_EXECUTE_ARG_INPUT_START onward.
enum FastPathExecuteArgIndex {
  FAST_PATH_EXECUTE_ARG_CONTEXT = 0,
  FAST_PATH_EXECUTE_ARG_OP_NAME,      // == 1
  FAST_PATH_EXECUTE_ARG_NAME,         // == 2
  FAST_PATH_EXECUTE_ARG_INPUT_START,  // == 3
};
3013 | |||||
3014 | PyObject* GetPythonObjectFromString(tensorflow::StringPiece s) { | ||||
3015 | #if PY_MAJOR_VERSION3 >= 3 | ||||
3016 | return PyUnicode_FromStringAndSize(s.data(), s.size()); | ||||
3017 | #else | ||||
3018 | return PyBytes_FromStringAndSize(s.data(), s.size()); | ||||
3019 | #endif | ||||
3020 | } | ||||
3021 | |||||
3022 | bool CheckResourceVariable(PyObject* item) { | ||||
3023 | if (tensorflow::swig::IsResourceVariable(item)) { | ||||
3024 | tensorflow::Safe_PyObjectPtr handle( | ||||
3025 | PyObject_GetAttrString(item, "_handle")); | ||||
3026 | return EagerTensor_CheckExact(handle.get()); | ||||
3027 | } | ||||
3028 | |||||
3029 | return false; | ||||
3030 | } | ||||
3031 | |||||
3032 | bool IsNumberType(PyObject* item) { | ||||
3033 | #if PY_MAJOR_VERSION3 >= 3 | ||||
3034 | return PyFloat_Check(item)((((PyObject*)(item))->ob_type) == (&PyFloat_Type) || PyType_IsSubtype ((((PyObject*)(item))->ob_type), (&PyFloat_Type))) || PyLong_Check(item)((((((PyObject*)(item))->ob_type))->tp_flags & ((1UL << 24))) != 0); | ||||
3035 | #else | ||||
3036 | return PyFloat_Check(item)((((PyObject*)(item))->ob_type) == (&PyFloat_Type) || PyType_IsSubtype ((((PyObject*)(item))->ob_type), (&PyFloat_Type))) || PyInt_Check(item) || PyLong_Check(item)((((((PyObject*)(item))->ob_type))->tp_flags & ((1UL << 24))) != 0); | ||||
3037 | #endif | ||||
3038 | } | ||||
3039 | |||||
3040 | bool CheckOneInput(PyObject* item) { | ||||
3041 | if (EagerTensor_CheckExact(item) || CheckResourceVariable(item) || | ||||
3042 | PyArray_Check(item)((((PyObject*)(item))->ob_type) == (&(*(PyTypeObject * )_tensorflow_numpy_api[2])) || PyType_IsSubtype((((PyObject*) (item))->ob_type), (&(*(PyTypeObject *)_tensorflow_numpy_api [2])))) || IsNumberType(item)) { | ||||
3043 | return true; | ||||
3044 | } | ||||
3045 | |||||
3046 | // Sequences are not properly handled. Sequences with purely python numeric | ||||
3047 | // types work, but sequences with mixes of EagerTensors and python numeric | ||||
3048 | // types don't work. | ||||
3049 | // TODO(nareshmodi): fix | ||||
3050 | return false; | ||||
3051 | } | ||||
3052 | |||||
3053 | bool CheckInputsOk(PyObject* seq, int start_index, | ||||
3054 | const tensorflow::OpDef& op_def) { | ||||
3055 | for (int i = 0; i < op_def.input_arg_size(); i++) { | ||||
3056 | PyObject* item = PyTuple_GET_ITEM(seq, i + start_index)(((static_cast<void> (0)), (PyTupleObject *)(seq))-> ob_item[i + start_index]); | ||||
3057 | if (!op_def.input_arg(i).number_attr().empty() || | ||||
3058 | !op_def.input_arg(i).type_list_attr().empty()) { | ||||
3059 | // This item should be a seq input. | ||||
3060 | if (!PySequence_Check(item)) { | ||||
3061 | VLOG(1)(__builtin_expect(!!(!(([](int level, const char* fname) { static const bool vmodule_activated = ::tensorflow::internal::LogMessage ::VmoduleActivated(fname, level); return vmodule_activated; } )(1, "tensorflow/python/eager/pywrap_tfe_src.cc"))), 1)) ? (void )0 : ::tensorflow::internal::Voidifier() & ::tensorflow:: internal::LogMessage("tensorflow/python/eager/pywrap_tfe_src.cc" , 3061, tensorflow::INFO) << "Falling back to slow path for Op \"" << op_def.name() | ||||
3062 | << "\", Input \"" << op_def.input_arg(i).name() | ||||
3063 | << "\" since we expected a sequence, but got " | ||||
3064 | << item->ob_type->tp_name; | ||||
3065 | return false; | ||||
3066 | } | ||||
3067 | tensorflow::Safe_PyObjectPtr fast_item( | ||||
3068 | PySequence_Fast(item, "Could not parse sequence.")); | ||||
3069 | if (fast_item.get() == nullptr) { | ||||
3070 | return false; | ||||
3071 | } | ||||
3072 | int len = PySequence_Fast_GET_SIZE(fast_item.get())(((((((PyObject*)(fast_item.get()))->ob_type))->tp_flags & ((1UL << 25))) != 0) ? ((static_cast<void> (0)), (((PyVarObject*)(fast_item.get()))->ob_size)) : ((( PyVarObject*)(((static_cast<void> (0)), (PyTupleObject * )(fast_item.get()))))->ob_size)); | ||||
3073 | PyObject** fast_item_array = PySequence_Fast_ITEMS(fast_item.get())(((((((PyObject*)(fast_item.get()))->ob_type))->tp_flags & ((1UL << 25))) != 0) ? ((PyListObject *)(fast_item .get()))->ob_item : ((PyTupleObject *)(fast_item.get()))-> ob_item); | ||||
3074 | for (Py_ssize_t j = 0; j < len; j++) { | ||||
3075 | PyObject* inner_item = fast_item_array[j]; | ||||
3076 | if (!CheckOneInput(inner_item)) { | ||||
3077 | VLOG(1)(__builtin_expect(!!(!(([](int level, const char* fname) { static const bool vmodule_activated = ::tensorflow::internal::LogMessage ::VmoduleActivated(fname, level); return vmodule_activated; } )(1, "tensorflow/python/eager/pywrap_tfe_src.cc"))), 1)) ? (void )0 : ::tensorflow::internal::Voidifier() & ::tensorflow:: internal::LogMessage("tensorflow/python/eager/pywrap_tfe_src.cc" , 3077, tensorflow::INFO) << "Falling back to slow path for Op \"" << op_def.name() | ||||
3078 | << "\", Input \"" << op_def.input_arg(i).name() | ||||
3079 | << "\", Index " << j | ||||
3080 | << " since we expected an EagerTensor/ResourceVariable, " | ||||
3081 | "but got " | ||||
3082 | << inner_item->ob_type->tp_name; | ||||
3083 | return false; | ||||
3084 | } | ||||
3085 | } | ||||
3086 | } else if (!CheckOneInput(item)) { | ||||
3087 | VLOG(1)(__builtin_expect(!!(!(([](int level, const char* fname) { static const bool vmodule_activated = ::tensorflow::internal::LogMessage ::VmoduleActivated(fname, level); return vmodule_activated; } )(1, "tensorflow/python/eager/pywrap_tfe_src.cc"))), 1)) ? (void )0 : ::tensorflow::internal::Voidifier() & ::tensorflow:: internal::LogMessage("tensorflow/python/eager/pywrap_tfe_src.cc" , 3087, tensorflow::INFO) | ||||
3088 | << "Falling back to slow path for Op \"" << op_def.name() | ||||
3089 | << "\", Input \"" << op_def.input_arg(i).name() | ||||
3090 | << "\" since we expected an EagerTensor/ResourceVariable, but got " | ||||
3091 | << item->ob_type->tp_name; | ||||
3092 | return false; | ||||
3093 | } | ||||
3094 | } | ||||
3095 | |||||
3096 | return true; | ||||
3097 | } | ||||
3098 | |||||
3099 | tensorflow::DataType MaybeGetDType(PyObject* item) { | ||||
3100 | if (EagerTensor_CheckExact(item) || CheckResourceVariable(item)) { | ||||
3101 | return tensorflow::PyTensor_DataType(item); | ||||
3102 | } | ||||
3103 | |||||
3104 | return tensorflow::DT_INVALID; | ||||
3105 | } | ||||
3106 | |||||
3107 | tensorflow::DataType MaybeGetDTypeForAttr(const string& attr, | ||||
3108 | FastPathOpExecInfo* op_exec_info) { | ||||
3109 | auto cached_it = op_exec_info->cached_dtypes.find(attr); | ||||
3110 | if (cached_it != op_exec_info->cached_dtypes.end()) { | ||||
3111 | return cached_it->second; | ||||
3112 | } | ||||
3113 | |||||
3114 | auto it = op_exec_info->attr_to_inputs_map->find(attr); | ||||
3115 | if (it == op_exec_info->attr_to_inputs_map->end()) { | ||||
3116 | // No other inputs - this should never happen. | ||||
3117 | return tensorflow::DT_INVALID; | ||||
3118 | } | ||||
3119 | |||||
3120 | for (const auto& input_info : it->second) { | ||||
3121 | PyObject* item = PyTuple_GET_ITEM((((static_cast<void> (0)), (PyTupleObject *)(op_exec_info ->args))->ob_item[FAST_PATH_EXECUTE_ARG_INPUT_START + input_info .i]) | ||||
3122 | op_exec_info->args, FAST_PATH_EXECUTE_ARG_INPUT_START + input_info.i)(((static_cast<void> (0)), (PyTupleObject *)(op_exec_info ->args))->ob_item[FAST_PATH_EXECUTE_ARG_INPUT_START + input_info .i]); | ||||
3123 | if (input_info.is_list) { | ||||
3124 | tensorflow::Safe_PyObjectPtr fast_item( | ||||
3125 | PySequence_Fast(item, "Unable to allocate")); | ||||
3126 | int len = PySequence_Fast_GET_SIZE(fast_item.get())(((((((PyObject*)(fast_item.get()))->ob_type))->tp_flags & ((1UL << 25))) != 0) ? ((static_cast<void> (0)), (((PyVarObject*)(fast_item.get()))->ob_size)) : ((( PyVarObject*)(((static_cast<void> (0)), (PyTupleObject * )(fast_item.get()))))->ob_size)); | ||||
3127 | PyObject** fast_item_array = PySequence_Fast_ITEMS(fast_item.get())(((((((PyObject*)(fast_item.get()))->ob_type))->tp_flags & ((1UL << 25))) != 0) ? ((PyListObject *)(fast_item .get()))->ob_item : ((PyTupleObject *)(fast_item.get()))-> ob_item); | ||||
3128 | for (int i = 0; i < len; i++) { | ||||
3129 | auto dtype = MaybeGetDType(fast_item_array[i]); | ||||
3130 | if (dtype != tensorflow::DT_INVALID) return dtype; | ||||
3131 | } | ||||
3132 | } else { | ||||
3133 | auto dtype = MaybeGetDType(item); | ||||
3134 | if (dtype != tensorflow::DT_INVALID) return dtype; | ||||
3135 | } | ||||
3136 | } | ||||
3137 | |||||
3138 | auto default_it = op_exec_info->default_dtypes->find(attr); | ||||
3139 | if (default_it != op_exec_info->default_dtypes->end()) { | ||||
3140 | return default_it->second; | ||||
3141 | } | ||||
3142 | |||||
3143 | return tensorflow::DT_INVALID; | ||||
3144 | } | ||||
3145 | |||||
3146 | PyObject* CopySequenceSettingIndicesToNull( | ||||
3147 | PyObject* seq, const tensorflow::gtl::FlatSet<int>& indices) { | ||||
3148 | tensorflow::Safe_PyObjectPtr fast_seq( | ||||
3149 | PySequence_Fast(seq, "unable to allocate")); | ||||
3150 | int len = PySequence_Fast_GET_SIZE(fast_seq.get())(((((((PyObject*)(fast_seq.get()))->ob_type))->tp_flags & ((1UL << 25))) != 0) ? ((static_cast<void> (0)), (((PyVarObject*)(fast_seq.get()))->ob_size)) : (((PyVarObject *)(((static_cast<void> (0)), (PyTupleObject *)(fast_seq .get()))))->ob_size)); | ||||
3151 | PyObject** fast_seq_array = PySequence_Fast_ITEMS(fast_seq.get())(((((((PyObject*)(fast_seq.get()))->ob_type))->tp_flags & ((1UL << 25))) != 0) ? ((PyListObject *)(fast_seq .get()))->ob_item : ((PyTupleObject *)(fast_seq.get()))-> ob_item); | ||||
3152 | PyObject* result = PyTuple_New(len); | ||||
3153 | for (int i = 0; i < len; i++) { | ||||
3154 | PyObject* item; | ||||
3155 | if (indices.find(i) != indices.end()) { | ||||
3156 | item = Py_None(&_Py_NoneStruct); | ||||
3157 | } else { | ||||
3158 | item = fast_seq_array[i]; | ||||
3159 | } | ||||
3160 | Py_INCREF(item)_Py_INCREF(((PyObject*)(item))); | ||||
3161 | PyTuple_SET_ITEM(result, i, item)PyTuple_SetItem(result, i, item); | ||||
3162 | } | ||||
3163 | return result; | ||||
3164 | } | ||||
3165 | |||||
3166 | PyObject* RecordGradient(PyObject* op_name, PyObject* inputs, PyObject* attrs, | ||||
3167 | PyObject* results, | ||||
3168 | PyObject* forward_pass_name_scope = nullptr) { | ||||
3169 | std::vector<int64_t> input_ids = MakeTensorIDList(inputs); | ||||
3170 | if (PyErr_Occurred()) return nullptr; | ||||
3171 | std::vector<tensorflow::DataType> input_dtypes = MakeTensorDtypeList(inputs); | ||||
3172 | if (PyErr_Occurred()) return nullptr; | ||||
3173 | |||||
3174 | bool should_record = false; | ||||
3175 | for (TFE_Py_Tape* tape : SafeTapeSet()) { | ||||
3176 | if (tape->tape->ShouldRecord(input_ids, input_dtypes)) { | ||||
3177 | should_record = true; | ||||
3178 | break; | ||||
3179 | } | ||||
3180 | } | ||||
3181 | if (!should_record) { | ||||
3182 | for (TFE_Py_ForwardAccumulator* accumulator : SafeAccumulatorSet()) { | ||||
3183 | if (accumulator->accumulator->ShouldRecord(input_ids, input_dtypes)) { | ||||
3184 | should_record = true; | ||||
3185 | break; | ||||
3186 | } | ||||
3187 | } | ||||
3188 | } | ||||
3189 | if (!should_record) Py_RETURN_NONEreturn _Py_INCREF(((PyObject*)((&_Py_NoneStruct)))), (& _Py_NoneStruct); | ||||
3190 | |||||
3191 | string c_op_name = TFE_GetPythonString(op_name); | ||||
3192 | |||||
3193 | PyObject* op_outputs; | ||||
3194 | bool op_outputs_tuple_created = false; | ||||
3195 | |||||
3196 | if (const auto unused_output_indices = | ||||
3197 | OpGradientUnusedOutputIndices(c_op_name)) { | ||||
3198 | if (unused_output_indices->empty()) { | ||||
3199 | op_outputs = Py_None(&_Py_NoneStruct); | ||||
3200 | } else { | ||||
3201 | op_outputs_tuple_created = true; | ||||
3202 | op_outputs = | ||||
3203 | CopySequenceSettingIndicesToNull(results, *unused_output_indices); | ||||
3204 | } | ||||
3205 | } else { | ||||
3206 | op_outputs = results; | ||||
3207 | } | ||||
3208 | |||||
3209 | PyObject* op_inputs; | ||||
3210 | bool op_inputs_tuple_created = false; | ||||
3211 | |||||
3212 | if (const auto unused_input_indices = | ||||
3213 | OpGradientUnusedInputIndices(c_op_name)) { | ||||
3214 | if (unused_input_indices->empty()) { | ||||
3215 | op_inputs = Py_None(&_Py_NoneStruct); | ||||
3216 | } else { | ||||
3217 | op_inputs_tuple_created = true; | ||||
3218 | op_inputs = | ||||
3219 | CopySequenceSettingIndicesToNull(inputs, *unused_input_indices); | ||||
3220 | } | ||||
3221 | } else { | ||||
3222 | op_inputs = inputs; | ||||
3223 | } | ||||
3224 | |||||
3225 | tensorflow::eager::ForwardFunction<PyObject> py_forward_function( | ||||
3226 | [op_name, attrs, inputs, results]( | ||||
3227 | const std::vector<PyObject*>& input_tangents, | ||||
3228 | std::vector<PyObject*>* output_tangents, bool use_batch) { | ||||
3229 | return CallJVPFunction(op_name, attrs, inputs, results, input_tangents, | ||||
3230 | output_tangents, use_batch); | ||||
3231 | }); | ||||
3232 | tensorflow::eager::ForwardFunction<PyObject>* forward_function; | ||||
3233 | if (c_op_name == "While" || c_op_name == "StatelessWhile" || | ||||
3234 | c_op_name == "If" || c_op_name == "StatelessIf") { | ||||
3235 | // Control flow contains non-hashable attributes. Handling them in Python is | ||||
3236 | // a headache, so instead we'll stay as close to GradientTape's handling as | ||||
3237 | // possible (a null forward function means the accumulator forwards to a | ||||
3238 | // tape). | ||||
3239 | // | ||||
3240 | // This is safe to do since we'll only see control flow when graph building, | ||||
3241 | // in which case we can rely on pruning. | ||||
3242 | forward_function = nullptr; | ||||
3243 | } else { | ||||
3244 | forward_function = &py_forward_function; | ||||
3245 | } | ||||
3246 | |||||
3247 | PyObject* num_inputs = PyLong_FromLong(PySequence_Size(inputs)); | ||||
3248 | |||||
3249 | if (!forward_pass_name_scope) forward_pass_name_scope = Py_None(&_Py_NoneStruct); | ||||
3250 | |||||
3251 | TapeSetRecordOperation( | ||||
3252 | op_name, inputs, results, input_ids, input_dtypes, | ||||
3253 | [op_name, attrs, num_inputs, op_inputs, op_outputs, | ||||
3254 | forward_pass_name_scope]() { | ||||
3255 | Py_INCREF(op_name)_Py_INCREF(((PyObject*)(op_name))); | ||||
3256 | Py_INCREF(attrs)_Py_INCREF(((PyObject*)(attrs))); | ||||
3257 | Py_INCREF(num_inputs)_Py_INCREF(((PyObject*)(num_inputs))); | ||||
3258 | Py_INCREF(op_inputs)_Py_INCREF(((PyObject*)(op_inputs))); | ||||
3259 | Py_INCREF(op_outputs)_Py_INCREF(((PyObject*)(op_outputs))); | ||||
3260 | Py_INCREF(forward_pass_name_scope)_Py_INCREF(((PyObject*)(forward_pass_name_scope))); | ||||
3261 | PyBackwardFunction* function = new PyBackwardFunction( | ||||
3262 | [op_name, attrs, num_inputs, op_inputs, op_outputs, | ||||
3263 | forward_pass_name_scope]( | ||||
3264 | PyObject* output_grads, | ||||
3265 | const std::vector<int64_t>& unneeded_gradients) { | ||||
3266 | if (PyErr_Occurred()) { | ||||
3267 | return static_cast<PyObject*>(nullptr); | ||||
3268 | } | ||||
3269 | tensorflow::Safe_PyObjectPtr skip_input_indices; | ||||
3270 | if (!unneeded_gradients.empty()) { | ||||
3271 | skip_input_indices.reset( | ||||
3272 | PyTuple_New(unneeded_gradients.size())); | ||||
3273 | for (int i = 0; i < unneeded_gradients.size(); i++) { | ||||
3274 | PyTuple_SET_ITEM(PyTuple_SetItem(skip_input_indices.get(), i, GetPythonObjectFromInt (unneeded_gradients[i])) | ||||
3275 | skip_input_indices.get(), i,PyTuple_SetItem(skip_input_indices.get(), i, GetPythonObjectFromInt (unneeded_gradients[i])) | ||||
3276 | GetPythonObjectFromInt(unneeded_gradients[i]))PyTuple_SetItem(skip_input_indices.get(), i, GetPythonObjectFromInt (unneeded_gradients[i])); | ||||
3277 | } | ||||
3278 | } else { | ||||
3279 | Py_INCREF(Py_None)_Py_INCREF(((PyObject*)((&_Py_NoneStruct)))); | ||||
3280 | skip_input_indices.reset(Py_None(&_Py_NoneStruct)); | ||||
3281 | } | ||||
3282 | tensorflow::Safe_PyObjectPtr callback_args(Py_BuildValue( | ||||
3283 | "OOOOOOOO", op_name, attrs, num_inputs, op_inputs, op_outputs, | ||||
3284 | output_grads, skip_input_indices.get(), | ||||
3285 | forward_pass_name_scope)); | ||||
3286 | |||||
3287 | tensorflow::Safe_PyObjectPtr result( | ||||
3288 | PyObject_CallObject(gradient_function, callback_args.get())); | ||||
3289 | |||||
3290 | if (PyErr_Occurred()) return static_cast<PyObject*>(nullptr); | ||||
3291 | |||||
3292 | return tensorflow::swig::Flatten(result.get()); | ||||
3293 | }); | ||||
3294 | return function; | ||||
3295 | }, | ||||
3296 | [op_name, attrs, num_inputs, op_inputs, op_outputs, | ||||
3297 | forward_pass_name_scope](PyBackwardFunction* backward_function) { | ||||
3298 | Py_DECREF(op_name)_Py_DECREF(((PyObject*)(op_name))); | ||||
3299 | Py_DECREF(attrs)_Py_DECREF(((PyObject*)(attrs))); | ||||
3300 | Py_DECREF(num_inputs)_Py_DECREF(((PyObject*)(num_inputs))); | ||||
3301 | Py_DECREF(op_inputs)_Py_DECREF(((PyObject*)(op_inputs))); | ||||
3302 | Py_DECREF(op_outputs)_Py_DECREF(((PyObject*)(op_outputs))); | ||||
3303 | Py_DECREF(forward_pass_name_scope)_Py_DECREF(((PyObject*)(forward_pass_name_scope))); | ||||
3304 | |||||
3305 | delete backward_function; | ||||
3306 | }, | ||||
3307 | forward_function); | ||||
3308 | |||||
3309 | Py_DECREF(num_inputs)_Py_DECREF(((PyObject*)(num_inputs))); | ||||
3310 | if (op_outputs_tuple_created) Py_DECREF(op_outputs)_Py_DECREF(((PyObject*)(op_outputs))); | ||||
3311 | if (op_inputs_tuple_created) Py_DECREF(op_inputs)_Py_DECREF(((PyObject*)(op_inputs))); | ||||
3312 | |||||
3313 | if (PyErr_Occurred()) { | ||||
3314 | return nullptr; | ||||
3315 | } | ||||
3316 | |||||
3317 | Py_RETURN_NONEreturn _Py_INCREF(((PyObject*)((&_Py_NoneStruct)))), (& _Py_NoneStruct); | ||||
3318 | } | ||||
3319 | |||||
3320 | void MaybeNotifyVariableAccessed(PyObject* input) { | ||||
3321 | DCHECK(CheckResourceVariable(input))while (false && (CheckResourceVariable(input))) ::tensorflow ::internal::LogMessageFatal("tensorflow/python/eager/pywrap_tfe_src.cc" , 3321); | ||||
3322 | DCHECK(PyObject_HasAttrString(input, "_trainable"))while (false && (PyObject_HasAttrString(input, "_trainable" ))) ::tensorflow::internal::LogMessageFatal("tensorflow/python/eager/pywrap_tfe_src.cc" , 3322); | ||||
3323 | |||||
3324 | tensorflow::Safe_PyObjectPtr trainable( | ||||
3325 | PyObject_GetAttrString(input, "_trainable")); | ||||
3326 | if (trainable.get() == Py_False((PyObject *) &_Py_FalseStruct)) return; | ||||
3327 | TFE_Py_TapeVariableAccessed(input); | ||||
3328 | TFE_Py_VariableWatcherVariableAccessed(input); | ||||
3329 | } | ||||
3330 | |||||
3331 | bool ReadVariableOp(const FastPathOpExecInfo& parent_op_exec_info, | ||||
3332 | PyObject* input, tensorflow::Safe_PyObjectPtr* output, | ||||
3333 | TF_Status* status) { | ||||
3334 | MaybeNotifyVariableAccessed(input); | ||||
3335 | |||||
3336 | TFE_Op* op = TFE_NewOp(parent_op_exec_info.ctx, "ReadVariableOp", status); | ||||
3337 | auto cleaner = tensorflow::gtl::MakeCleanup([op] { TFE_DeleteOp(op); }); | ||||
3338 | if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) return false; | ||||
3339 | |||||
3340 | TFE_OpSetDevice(op, parent_op_exec_info.device_name, status); | ||||
3341 | if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) return false; | ||||
3342 | |||||
3343 | // Set dtype | ||||
3344 | DCHECK(PyObject_HasAttrString(input, "_dtype"))while (false && (PyObject_HasAttrString(input, "_dtype" ))) ::tensorflow::internal::LogMessageFatal("tensorflow/python/eager/pywrap_tfe_src.cc" , 3344); | ||||
3345 | tensorflow::Safe_PyObjectPtr dtype(PyObject_GetAttrString(input, "_dtype")); | ||||
3346 | int value; | ||||
3347 | if (!ParseTypeValue("_dtype", dtype.get(), status, &value)) { | ||||
3348 | return false; | ||||
3349 | } | ||||
3350 | TFE_OpSetAttrType(op, "dtype", static_cast<TF_DataType>(value)); | ||||
3351 | |||||
3352 | // Get handle | ||||
3353 | tensorflow::Safe_PyObjectPtr handle(PyObject_GetAttrString(input, "_handle")); | ||||
3354 | if (!EagerTensor_CheckExact(handle.get())) return false; | ||||
3355 | TFE_OpAddInput(op, EagerTensor_Handle(handle.get()), status); | ||||
3356 | if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) return false; | ||||
3357 | |||||
3358 | int num_retvals = 1; | ||||
3359 | TFE_TensorHandle* output_handle; | ||||
3360 | TFE_Execute(op, &output_handle, &num_retvals, status); | ||||
3361 | if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) return false; | ||||
3362 | |||||
3363 | // Always create the py object (and correctly DECREF it) from the returned | ||||
3364 | // value, else the data will leak. | ||||
3365 | output->reset(EagerTensorFromHandle(output_handle)); | ||||
3366 | |||||
3367 | // TODO(nareshmodi): Should we run post exec callbacks here? | ||||
3368 | if (parent_op_exec_info.run_gradient_callback) { | ||||
3369 | tensorflow::Safe_PyObjectPtr inputs(PyTuple_New(1)); | ||||
3370 | PyTuple_SET_ITEM(inputs.get(), 0, handle.release())PyTuple_SetItem(inputs.get(), 0, handle.release()); | ||||
3371 | |||||
3372 | tensorflow::Safe_PyObjectPtr outputs(PyTuple_New(1)); | ||||
3373 | Py_INCREF(output->get())_Py_INCREF(((PyObject*)(output->get()))); // stay alive after since tuple steals. | ||||
3374 | PyTuple_SET_ITEM(outputs.get(), 0, output->get())PyTuple_SetItem(outputs.get(), 0, output->get()); | ||||
3375 | |||||
3376 | tensorflow::Safe_PyObjectPtr op_string( | ||||
3377 | GetPythonObjectFromString("ReadVariableOp")); | ||||
3378 | if (!RecordGradient(op_string.get(), inputs.get(), Py_None(&_Py_NoneStruct), | ||||
3379 | outputs.get())) { | ||||
3380 | return false; | ||||
3381 | } | ||||
3382 | } | ||||
3383 | |||||
3384 | return true; | ||||
3385 | } | ||||
3386 | |||||
3387 | // Supports 3 cases at the moment: | ||||
3388 | // i) input is an EagerTensor. | ||||
3389 | // ii) input is a ResourceVariable - in this case, the is_variable param is | ||||
3390 | // set to true. | ||||
3391 | // iii) input is an arbitrary python list/tuple (note, this handling doesn't | ||||
3392 | // support packing). | ||||
3393 | // | ||||
3394 | // NOTE: dtype_hint_getter must *always* return a PyObject that can be | ||||
3395 | // decref'd. So if no hint is found, Py_RETURN_NONE (which correctly | ||||
3396 | // increfs Py_None). | ||||
3397 | // | ||||
3398 | // NOTE: This function sets a python error directly, and returns false. | ||||
3399 | // TF_Status is only passed since we don't want to have to reallocate it. | ||||
3400 | bool ConvertToTensor( | ||||
3401 | const FastPathOpExecInfo& op_exec_info, PyObject* input, | ||||
3402 | tensorflow::Safe_PyObjectPtr* output_handle, | ||||
3403 | // This gets a hint for this particular input. | ||||
3404 | const std::function<tensorflow::DataType()>& dtype_hint_getter, | ||||
3405 | // This sets the dtype after conversion is complete. | ||||
3406 | const std::function<void(const tensorflow::DataType dtype)>& dtype_setter, | ||||
3407 | TF_Status* status) { | ||||
3408 | if (EagerTensor_CheckExact(input)) { | ||||
3409 | Py_INCREF(input)_Py_INCREF(((PyObject*)(input))); | ||||
3410 | output_handle->reset(input); | ||||
3411 | return true; | ||||
3412 | } else if (CheckResourceVariable(input)) { | ||||
3413 | return ReadVariableOp(op_exec_info, input, output_handle, status); | ||||
3414 | } | ||||
3415 | |||||
3416 | // The hint comes from a supposedly similarly typed tensor. | ||||
3417 | tensorflow::DataType dtype_hint = dtype_hint_getter(); | ||||
3418 | |||||
3419 | TFE_TensorHandle* handle = tensorflow::ConvertToEagerTensor( | ||||
3420 | op_exec_info.ctx, input, dtype_hint, op_exec_info.device_name); | ||||
3421 | if (handle == nullptr) { | ||||
3422 | return MaybeRaiseExceptionFromTFStatus(status, nullptr); | ||||
3423 | } | ||||
3424 | |||||
3425 | output_handle->reset(EagerTensorFromHandle(handle)); | ||||
3426 | dtype_setter( | ||||
3427 | static_cast<tensorflow::DataType>(TFE_TensorHandleDataType(handle))); | ||||
3428 | |||||
3429 | return true; | ||||
3430 | } | ||||
3431 | |||||
3432 | // Adds input and type attr to the op, and to the list of flattened | ||||
3433 | // inputs/attrs. | ||||
3434 | bool AddInputToOp(FastPathOpExecInfo* op_exec_info, PyObject* input, | ||||
3435 | const bool add_type_attr, | ||||
3436 | const tensorflow::OpDef::ArgDef& input_arg, | ||||
3437 | std::vector<tensorflow::Safe_PyObjectPtr>* flattened_attrs, | ||||
3438 | std::vector<tensorflow::Safe_PyObjectPtr>* flattened_inputs, | ||||
3439 | TFE_Op* op, TF_Status* status) { | ||||
3440 | // py_eager_tensor's ownership is transferred to flattened_inputs if it is | ||||
3441 | // required, else the object is destroyed and DECREF'd when the object goes | ||||
3442 | // out of scope in this function. | ||||
3443 | tensorflow::Safe_PyObjectPtr py_eager_tensor = nullptr; | ||||
3444 | |||||
3445 | if (!ConvertToTensor( | ||||
3446 | *op_exec_info, input, &py_eager_tensor, | ||||
3447 | [&]() { | ||||
3448 | if (input_arg.type() != tensorflow::DataType::DT_INVALID) { | ||||
3449 | return input_arg.type(); | ||||
3450 | } | ||||
3451 | return MaybeGetDTypeForAttr(input_arg.type_attr(), op_exec_info); | ||||
3452 | }, | ||||
3453 | [&](const tensorflow::DataType dtype) { | ||||
3454 | op_exec_info->cached_dtypes[input_arg.type_attr()] = dtype; | ||||
3455 | }, | ||||
3456 | status)) { | ||||
3457 | return false; | ||||
3458 | } | ||||
3459 | |||||
3460 | TFE_TensorHandle* input_handle = EagerTensor_Handle(py_eager_tensor.get()); | ||||
3461 | |||||
3462 | if (add_type_attr && !input_arg.type_attr().empty()) { | ||||
3463 | auto dtype = TFE_TensorHandleDataType(input_handle); | ||||
3464 | TFE_OpSetAttrType(op, input_arg.type_attr().data(), dtype); | ||||
3465 | if (flattened_attrs != nullptr) { | ||||
3466 | flattened_attrs->emplace_back( | ||||
3467 | GetPythonObjectFromString(input_arg.type_attr())); | ||||
3468 | flattened_attrs->emplace_back(PyLong_FromLong(dtype)); | ||||
3469 | } | ||||
3470 | } | ||||
3471 | |||||
3472 | if (flattened_inputs != nullptr) { | ||||
3473 | flattened_inputs->emplace_back(std::move(py_eager_tensor)); | ||||
3474 | } | ||||
3475 | |||||
3476 | TFE_OpAddInput(op, input_handle, status); | ||||
3477 | if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) { | ||||
3478 | return false; | ||||
3479 | } | ||||
3480 | |||||
3481 | return true; | ||||
3482 | } | ||||
3483 | |||||
3484 | const char* GetDeviceName(PyObject* py_device_name) { | ||||
3485 | if (py_device_name != Py_None(&_Py_NoneStruct)) { | ||||
3486 | return TFE_GetPythonString(py_device_name); | ||||
3487 | } | ||||
3488 | return nullptr; | ||||
3489 | } | ||||
3490 | |||||
3491 | bool RaiseIfNotPySequence(PyObject* seq, const string& attr_name) { | ||||
3492 | if (!PySequence_Check(seq)) { | ||||
3493 | PyErr_SetString(PyExc_TypeError, | ||||
3494 | Printf("expected a sequence for attr %s, got %s instead", | ||||
3495 | attr_name.data(), seq->ob_type->tp_name) | ||||
3496 | .data()); | ||||
3497 | |||||
3498 | return false; | ||||
3499 | } | ||||
3500 | if (PyArray_Check(seq)((((PyObject*)(seq))->ob_type) == (&(*(PyTypeObject *) _tensorflow_numpy_api[2])) || PyType_IsSubtype((((PyObject*)( seq))->ob_type), (&(*(PyTypeObject *)_tensorflow_numpy_api [2])))) && | ||||
3501 | PyArray_NDIM(reinterpret_cast<PyArrayObject*>(seq)) != 1) { | ||||
3502 | PyErr_SetString(PyExc_ValueError, | ||||
3503 | Printf("expected a sequence for attr %s, got an ndarray " | ||||
3504 | "with rank %d instead", | ||||
3505 | attr_name.data(), | ||||
3506 | PyArray_NDIM(reinterpret_cast<PyArrayObject*>(seq))) | ||||
3507 | .data()); | ||||
3508 | return false; | ||||
3509 | } | ||||
3510 | return true; | ||||
3511 | } | ||||
3512 | |||||
3513 | bool RunCallbacks( | ||||
3514 | const FastPathOpExecInfo& op_exec_info, PyObject* args, | ||||
3515 | int num_inferred_attrs, | ||||
3516 | const std::vector<tensorflow::Safe_PyObjectPtr>& flattened_inputs, | ||||
3517 | const std::vector<tensorflow::Safe_PyObjectPtr>& flattened_attrs, | ||||
3518 | PyObject* flattened_result) { | ||||
3519 | DCHECK(op_exec_info.run_callbacks)while (false && (op_exec_info.run_callbacks)) ::tensorflow ::internal::LogMessageFatal("tensorflow/python/eager/pywrap_tfe_src.cc" , 3519); | ||||
3520 | |||||
3521 | tensorflow::Safe_PyObjectPtr inputs(PyTuple_New(flattened_inputs.size())); | ||||
3522 | for (int i = 0; i < flattened_inputs.size(); i++) { | ||||
3523 | PyObject* input = flattened_inputs[i].get(); | ||||
3524 | Py_INCREF(input)_Py_INCREF(((PyObject*)(input))); | ||||
3525 | PyTuple_SET_ITEM(inputs.get(), i, input)PyTuple_SetItem(inputs.get(), i, input); | ||||
3526 | } | ||||
3527 | |||||
3528 | int num_non_inferred_attrs = PyTuple_GET_SIZE(args)(((PyVarObject*)(((static_cast<void> (0)), (PyTupleObject *)(args))))->ob_size) - num_inferred_attrs; | ||||
3529 | int num_attrs = flattened_attrs.size() + num_non_inferred_attrs; | ||||
3530 | tensorflow::Safe_PyObjectPtr attrs(PyTuple_New(num_attrs)); | ||||
3531 | |||||
3532 | for (int i = 0; i < num_non_inferred_attrs; i++) { | ||||
3533 | auto* attr = PyTuple_GET_ITEM(args, num_inferred_attrs + i)(((static_cast<void> (0)), (PyTupleObject *)(args))-> ob_item[num_inferred_attrs + i]); | ||||
3534 | Py_INCREF(attr)_Py_INCREF(((PyObject*)(attr))); | ||||
3535 | PyTuple_SET_ITEM(attrs.get(), i, attr)PyTuple_SetItem(attrs.get(), i, attr); | ||||
3536 | } | ||||
3537 | |||||
3538 | for (int i = num_non_inferred_attrs; i < num_attrs; i++) { | ||||
3539 | PyObject* attr_or_name = | ||||
3540 | flattened_attrs.at(i - num_non_inferred_attrs).get(); | ||||
3541 | Py_INCREF(attr_or_name)_Py_INCREF(((PyObject*)(attr_or_name))); | ||||
3542 | PyTuple_SET_ITEM(attrs.get(), i, attr_or_name)PyTuple_SetItem(attrs.get(), i, attr_or_name); | ||||
3543 | } | ||||
3544 | |||||
3545 | if (op_exec_info.run_gradient_callback) { | ||||
3546 | if (!RecordGradient(op_exec_info.op_name, inputs.get(), attrs.get(), | ||||
3547 | flattened_result)) { | ||||
3548 | return false; | ||||
3549 | } | ||||
3550 | } | ||||
3551 | |||||
3552 | if (op_exec_info.run_post_exec_callbacks) { | ||||
3553 | tensorflow::Safe_PyObjectPtr callback_args( | ||||
3554 | Py_BuildValue("OOOOO", op_exec_info.op_name, inputs.get(), attrs.get(), | ||||
3555 | flattened_result, op_exec_info.name)); | ||||
3556 | for (Py_ssize_t i = 0; i < PyList_Size(op_exec_info.callbacks); i++) { | ||||
3557 | PyObject* callback_fn = PyList_GET_ITEM(op_exec_info.callbacks, i)(((PyListObject *)(op_exec_info.callbacks))->ob_item[i]); | ||||
3558 | if (!PyCallable_Check(callback_fn)) { | ||||
3559 | PyErr_SetString( | ||||
3560 | PyExc_TypeError, | ||||
3561 | Printf("expected a function for " | ||||
3562 | "post execution callback in index %ld, got %s instead", | ||||
3563 | i, callback_fn->ob_type->tp_name) | ||||
3564 | .c_str()); | ||||
3565 | return false; | ||||
3566 | } | ||||
3567 | PyObject* callback_result = | ||||
3568 | PyObject_CallObject(callback_fn, callback_args.get()); | ||||
3569 | if (!callback_result) { | ||||
3570 | return false; | ||||
3571 | } | ||||
3572 | Py_DECREF(callback_result)_Py_DECREF(((PyObject*)(callback_result))); | ||||
3573 | } | ||||
3574 | } | ||||
3575 | |||||
3576 | return true; | ||||
3577 | } | ||||
3578 | |||||
3579 | } // namespace | ||||
3580 | |||||
3581 | PyObject* TFE_Py_FastPathExecute_C(PyObject* args) { | ||||
3582 | tensorflow::profiler::TraceMe activity( | ||||
3583 | "TFE_Py_FastPathExecute_C", tensorflow::profiler::TraceMeLevel::kInfo); | ||||
3584 | Py_ssize_t args_size = PyTuple_GET_SIZE(args)(((PyVarObject*)(((static_cast<void> (0)), (PyTupleObject *)(args))))->ob_size); | ||||
3585 | if (args_size < FAST_PATH_EXECUTE_ARG_INPUT_START) { | ||||
3586 | PyErr_SetString( | ||||
3587 | PyExc_ValueError, | ||||
3588 | Printf("There must be at least %d items in the input tuple.", | ||||
3589 | FAST_PATH_EXECUTE_ARG_INPUT_START) | ||||
3590 | .c_str()); | ||||
3591 | return nullptr; | ||||
3592 | } | ||||
3593 | |||||
3594 | FastPathOpExecInfo op_exec_info; | ||||
3595 | |||||
3596 | PyObject* py_eager_context = | ||||
3597 | PyTuple_GET_ITEM(args, FAST_PATH_EXECUTE_ARG_CONTEXT)(((static_cast<void> (0)), (PyTupleObject *)(args))-> ob_item[FAST_PATH_EXECUTE_ARG_CONTEXT]); | ||||
3598 | |||||
3599 | // TODO(edoper): Use interned string here | ||||
3600 | PyObject* eager_context_handle = | ||||
3601 | PyObject_GetAttrString(py_eager_context, "_context_handle"); | ||||
3602 | |||||
3603 | TFE_Context* ctx = reinterpret_cast<TFE_Context*>( | ||||
3604 | PyCapsule_GetPointer(eager_context_handle, nullptr)); | ||||
3605 | op_exec_info.ctx = ctx; | ||||
3606 | op_exec_info.args = args; | ||||
3607 | |||||
3608 | if (ctx == nullptr) { | ||||
3609 | // The context hasn't been initialized. It will be in the slow path. | ||||
3610 | RaiseFallbackException( | ||||
3611 | "This function does not handle the case of the path where " | ||||
3612 | "all inputs are not already EagerTensors."); | ||||
3613 | return nullptr; | ||||
3614 | } | ||||
3615 | |||||
3616 | auto* tld = tensorflow::GetEagerContextThreadLocalData(py_eager_context); | ||||
3617 | if (tld == nullptr) { | ||||
3618 | return nullptr; | ||||
3619 | } | ||||
3620 | op_exec_info.device_name = GetDeviceName(tld->device_name.get()); | ||||
3621 | op_exec_info.callbacks = tld->op_callbacks.get(); | ||||
3622 | |||||
3623 | op_exec_info.op_name = PyTuple_GET_ITEM(args, FAST_PATH_EXECUTE_ARG_OP_NAME)(((static_cast<void> (0)), (PyTupleObject *)(args))-> ob_item[FAST_PATH_EXECUTE_ARG_OP_NAME]); | ||||
3624 | op_exec_info.name = PyTuple_GET_ITEM(args, FAST_PATH_EXECUTE_ARG_NAME)(((static_cast<void> (0)), (PyTupleObject *)(args))-> ob_item[FAST_PATH_EXECUTE_ARG_NAME]); | ||||
3625 | |||||
3626 | // TODO(nareshmodi): Add a benchmark for the fast-path with gradient callbacks | ||||
3627 | // (similar to benchmark_tf_gradient_function_*). Also consider using an | ||||
3628 | // InlinedVector for flattened_attrs and flattened_inputs if the benchmarks | ||||
3629 | // point out problems with heap allocs. | ||||
3630 | op_exec_info.run_gradient_callback = | ||||
3631 | !*ThreadTapeIsStopped() && HasAccumulatorOrTape(); | ||||
3632 | op_exec_info.run_post_exec_callbacks = | ||||
3633 | op_exec_info.callbacks != Py_None(&_Py_NoneStruct) && | ||||
3634 | PyList_Size(op_exec_info.callbacks) > 0; | ||||
3635 | op_exec_info.run_callbacks = op_exec_info.run_gradient_callback || | ||||
3636 | op_exec_info.run_post_exec_callbacks; | ||||
3637 | |||||
3638 | TF_Status* status = GetStatus(); | ||||
3639 | const char* op_name = TFE_GetPythonString(op_exec_info.op_name); | ||||
3640 | if (op_name == nullptr) { | ||||
3641 | PyErr_SetString(PyExc_TypeError, | ||||
3642 | Printf("expected a string for op_name, got %s instead", | ||||
3643 | op_exec_info.op_name->ob_type->tp_name) | ||||
3644 | .c_str()); | ||||
3645 | return nullptr; | ||||
3646 | } | ||||
3647 | |||||
3648 | TFE_Op* op = GetOp(ctx, op_name, op_exec_info.device_name, status); | ||||
3649 | |||||
3650 | auto cleaner = tensorflow::gtl::MakeCleanup([status, ctx, op] { | ||||
3651 | ReturnStatus(status); | ||||
3652 | ReturnOp(ctx, op); | ||||
3653 | }); | ||||
3654 | |||||
3655 | if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) { | ||||
3656 | return nullptr; | ||||
3657 | } | ||||
3658 | |||||
3659 | tensorflow::unwrap(op)->SetStackTrace(tensorflow::GetStackTrace( | ||||
3660 | tensorflow::StackTrace::kStackTraceInitialSize)); | ||||
3661 | |||||
3662 | const tensorflow::OpDef* op_def = tensorflow::unwrap(op)->OpDef(); | ||||
3663 | if (op_def == nullptr) return nullptr; | ||||
3664 | |||||
3665 | if (args_size < | ||||
3666 | FAST_PATH_EXECUTE_ARG_INPUT_START + op_def->input_arg_size()) { | ||||
3667 | PyErr_SetString( | ||||
3668 | PyExc_ValueError, | ||||
3669 | Printf("Tuple size smaller than intended. Expected to be at least %d, " | ||||
3670 | "was %ld", | ||||
3671 | FAST_PATH_EXECUTE_ARG_INPUT_START + op_def->input_arg_size(), | ||||
3672 | args_size) | ||||
3673 | .c_str()); | ||||
3674 | return nullptr; | ||||
3675 | } | ||||
3676 | |||||
3677 | if (!CheckInputsOk(args, FAST_PATH_EXECUTE_ARG_INPUT_START, *op_def)) { | ||||
3678 | RaiseFallbackException( | ||||
3679 | "This function does not handle the case of the path where " | ||||
3680 | "all inputs are not already EagerTensors."); | ||||
3681 | return nullptr; | ||||
3682 | } | ||||
3683 | |||||
3684 | op_exec_info.attr_to_inputs_map = GetAttrToInputsMapHoldingGIL(*op_def); | ||||
3685 | op_exec_info.default_dtypes = GetAttrToDefaultsMapHoldingGIL(*op_def); | ||||
3686 | |||||
3687 | // Mapping of attr name to size - used to calculate the number of values | ||||
3688 | // to be expected by the TFE_Execute run. | ||||
3689 | tensorflow::gtl::FlatMap<string, int64_t> attr_list_sizes; | ||||
3690 | |||||
3691 | // Set non-inferred attrs, including setting defaults if the attr is passed in | ||||
3692 | // as None. | ||||
3693 | for (int i = FAST_PATH_EXECUTE_ARG_INPUT_START + op_def->input_arg_size(); | ||||
3694 | i < args_size; i += 2) { | ||||
3695 | PyObject* py_attr_name = PyTuple_GET_ITEM(args, i)(((static_cast<void> (0)), (PyTupleObject *)(args))-> ob_item[i]); | ||||
3696 | const char* attr_name = TFE_GetPythonString(py_attr_name); | ||||
3697 | PyObject* py_attr_value = PyTuple_GET_ITEM(args, i + 1)(((static_cast<void> (0)), (PyTupleObject *)(args))-> ob_item[i + 1]); | ||||
3698 | |||||
3699 | // Not creating an index since most of the time there are not more than a | ||||
3700 | // few attrs. | ||||
3701 | // TODO(nareshmodi): Maybe include the index as part of the | ||||
3702 | // OpRegistrationData. | ||||
3703 | for (const auto& attr : op_def->attr()) { | ||||
3704 | if (tensorflow::StringPiece(attr_name) == attr.name()) { | ||||
3705 | SetOpAttrWithDefaults(ctx, op, attr, attr_name, py_attr_value, | ||||
3706 | &attr_list_sizes, status); | ||||
3707 | |||||
3708 | if (!status->status.ok()) { | ||||
3709 | VLOG(1)(__builtin_expect(!!(!(([](int level, const char* fname) { static const bool vmodule_activated = ::tensorflow::internal::LogMessage ::VmoduleActivated(fname, level); return vmodule_activated; } )(1, "tensorflow/python/eager/pywrap_tfe_src.cc"))), 1)) ? (void )0 : ::tensorflow::internal::Voidifier() & ::tensorflow:: internal::LogMessage("tensorflow/python/eager/pywrap_tfe_src.cc" , 3709, tensorflow::INFO) << "Falling back to slow path for Op \"" << op_def->name() | ||||
3710 | << "\" since we are unable to set the value for attr \"" | ||||
3711 | << attr.name() << "\" due to: " << TF_Message(status); | ||||
3712 | RaiseFallbackException(TF_Message(status)); | ||||
3713 | return nullptr; | ||||
3714 | } | ||||
3715 | |||||
3716 | break; | ||||
3717 | } | ||||
3718 | } | ||||
3719 | } | ||||
3720 | |||||
3721 | // Flat attrs and inputs as required by the record_gradient call. The attrs | ||||
3722 | // here only contain inferred attrs (non-inferred attrs are added directly | ||||
3723 | // from the input args). | ||||
3724 | // All items in flattened_attrs and flattened_inputs contain | ||||
3725 | // Safe_PyObjectPtr - any time something steals a reference to this, it must | ||||
3726 | // INCREF. | ||||
3727 | // TODO(nareshmodi): figure out why PyList_New/PyList_Append don't work | ||||
3728 | // directly. | ||||
3729 | std::unique_ptr<std::vector<tensorflow::Safe_PyObjectPtr>> flattened_attrs = | ||||
3730 | nullptr; | ||||
3731 | std::unique_ptr<std::vector<tensorflow::Safe_PyObjectPtr>> flattened_inputs = | ||||
3732 | nullptr; | ||||
3733 | |||||
3734 | // TODO(nareshmodi): Encapsulate callbacks information into a struct. | ||||
3735 | if (op_exec_info.run_callbacks) { | ||||
3736 | flattened_attrs.reset(new std::vector<tensorflow::Safe_PyObjectPtr>); | ||||
3737 | flattened_inputs.reset(new std::vector<tensorflow::Safe_PyObjectPtr>); | ||||
3738 | } | ||||
3739 | |||||
3740 | // Add inferred attrs and inputs. | ||||
3741 | // The following code might set duplicate type attrs. This will result in | ||||
3742 | // the CacheKey for the generated AttrBuilder possibly differing from | ||||
3743 | // those where the type attrs are correctly set. Inconsistent CacheKeys | ||||
3744 | // for ops means that there might be unnecessarily duplicated kernels. | ||||
3745 | // TODO(nareshmodi): Fix this. | ||||
3746 | for (int i = 0; i < op_def->input_arg_size(); i++) { | ||||
3747 | const auto& input_arg = op_def->input_arg(i); | ||||
3748 | |||||
3749 | PyObject* input = | ||||
3750 | PyTuple_GET_ITEM(args, FAST_PATH_EXECUTE_ARG_INPUT_START + i)(((static_cast<void> (0)), (PyTupleObject *)(args))-> ob_item[FAST_PATH_EXECUTE_ARG_INPUT_START + i]); | ||||
3751 | if (!input_arg.number_attr().empty()) { | ||||
3752 | // The item is a homogeneous list. | ||||
3753 | if (!RaiseIfNotPySequence(input, input_arg.number_attr())) return nullptr; | ||||
3754 | tensorflow::Safe_PyObjectPtr fast_input( | ||||
3755 | PySequence_Fast(input, "Could not parse sequence.")); | ||||
3756 | if (fast_input.get() == nullptr) { | ||||
3757 | return nullptr; | ||||
3758 | } | ||||
3759 | Py_ssize_t len = PySequence_Fast_GET_SIZE(fast_input.get())(((((((PyObject*)(fast_input.get()))->ob_type))->tp_flags & ((1UL << 25))) != 0) ? ((static_cast<void> (0)), (((PyVarObject*)(fast_input.get()))->ob_size)) : (( (PyVarObject*)(((static_cast<void> (0)), (PyTupleObject *)(fast_input.get()))))->ob_size)); | ||||
3760 | PyObject** fast_input_array = PySequence_Fast_ITEMS(fast_input.get())(((((((PyObject*)(fast_input.get()))->ob_type))->tp_flags & ((1UL << 25))) != 0) ? ((PyListObject *)(fast_input .get()))->ob_item : ((PyTupleObject *)(fast_input.get()))-> ob_item); | ||||
3761 | |||||
3762 | TFE_OpSetAttrInt(op, input_arg.number_attr().data(), len); | ||||
3763 | if (op_exec_info.run_callbacks) { | ||||
3764 | flattened_attrs->emplace_back( | ||||
3765 | GetPythonObjectFromString(input_arg.number_attr())); | ||||
3766 | flattened_attrs->emplace_back(PyLong_FromLong(len)); | ||||
3767 | } | ||||
3768 | attr_list_sizes[input_arg.number_attr()] = len; | ||||
3769 | |||||
3770 | if (len > 0) { | ||||
3771 | // First item adds the type attr. | ||||
3772 | if (!AddInputToOp(&op_exec_info, fast_input_array[0], true, input_arg, | ||||
3773 | flattened_attrs.get(), flattened_inputs.get(), op, | ||||
3774 | status)) { | ||||
3775 | return nullptr; | ||||
3776 | } | ||||
3777 | |||||
3778 | for (Py_ssize_t j = 1; j < len; j++) { | ||||
3779 | // Since the list is homogeneous, we don't need to re-add the attr. | ||||
3780 | if (!AddInputToOp(&op_exec_info, fast_input_array[j], false, | ||||
3781 | input_arg, nullptr /* flattened_attrs */, | ||||
3782 | flattened_inputs.get(), op, status)) { | ||||
3783 | return nullptr; | ||||
3784 | } | ||||
3785 | } | ||||
3786 | } | ||||
3787 | } else if (!input_arg.type_list_attr().empty()) { | ||||
3788 | // The item is a heterogeneous list. | ||||
3789 | if (!RaiseIfNotPySequence(input, input_arg.type_list_attr())) { | ||||
3790 | return nullptr; | ||||
3791 | } | ||||
3792 | tensorflow::Safe_PyObjectPtr fast_input( | ||||
3793 | PySequence_Fast(input, "Could not parse sequence.")); | ||||
3794 | if (fast_input.get() == nullptr) { | ||||
3795 | return nullptr; | ||||
3796 | } | ||||
3797 | const string& attr_name = input_arg.type_list_attr(); | ||||
3798 | Py_ssize_t len = PySequence_Fast_GET_SIZE(fast_input.get())(((((((PyObject*)(fast_input.get()))->ob_type))->tp_flags & ((1UL << 25))) != 0) ? ((static_cast<void> (0)), (((PyVarObject*)(fast_input.get()))->ob_size)) : (( (PyVarObject*)(((static_cast<void> (0)), (PyTupleObject *)(fast_input.get()))))->ob_size)); | ||||
3799 | PyObject** fast_input_array = PySequence_Fast_ITEMS(fast_input.get())(((((((PyObject*)(fast_input.get()))->ob_type))->tp_flags & ((1UL << 25))) != 0) ? ((PyListObject *)(fast_input .get()))->ob_item : ((PyTupleObject *)(fast_input.get()))-> ob_item); | ||||
3800 | tensorflow::gtl::InlinedVector<TF_DataType, 4> attr_value(len); | ||||
3801 | PyObject* py_attr_value = nullptr; | ||||
3802 | if (op_exec_info.run_callbacks) { | ||||
3803 | py_attr_value = PyTuple_New(len); | ||||
3804 | } | ||||
3805 | for (Py_ssize_t j = 0; j < len; j++) { | ||||
3806 | PyObject* py_input = fast_input_array[j]; | ||||
3807 | tensorflow::Safe_PyObjectPtr py_eager_tensor; | ||||
3808 | if (!ConvertToTensor( | ||||
3809 | op_exec_info, py_input, &py_eager_tensor, | ||||
3810 | []() { return tensorflow::DT_INVALID; }, | ||||
3811 | [](const tensorflow::DataType dtype) {}, status)) { | ||||
3812 | return nullptr; | ||||
3813 | } | ||||
3814 | |||||
3815 | TFE_TensorHandle* input_handle = | ||||
3816 | EagerTensor_Handle(py_eager_tensor.get()); | ||||
3817 | |||||
3818 | attr_value[j] = TFE_TensorHandleDataType(input_handle); | ||||
3819 | |||||
3820 | TFE_OpAddInput(op, input_handle, status); | ||||
3821 | if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) { | ||||
3822 | return nullptr; | ||||
3823 | } | ||||
3824 | |||||
3825 | if (op_exec_info.run_callbacks) { | ||||
3826 | flattened_inputs->emplace_back(std::move(py_eager_tensor)); | ||||
3827 | |||||
3828 | PyTuple_SET_ITEM(py_attr_value, j, PyLong_FromLong(attr_value[j]))PyTuple_SetItem(py_attr_value, j, PyLong_FromLong(attr_value[ j])); | ||||
3829 | } | ||||
3830 | } | ||||
3831 | if (op_exec_info.run_callbacks) { | ||||
3832 | flattened_attrs->emplace_back(GetPythonObjectFromString(attr_name)); | ||||
3833 | flattened_attrs->emplace_back(py_attr_value); | ||||
3834 | } | ||||
3835 | TFE_OpSetAttrTypeList(op, attr_name.data(), attr_value.data(), | ||||
3836 | attr_value.size()); | ||||
3837 | attr_list_sizes[attr_name] = len; | ||||
3838 | } else { | ||||
3839 | // The item is a single item. | ||||
3840 | if (!AddInputToOp(&op_exec_info, input, true, input_arg, | ||||
3841 | flattened_attrs.get(), flattened_inputs.get(), op, | ||||
3842 | status)) { | ||||
3843 | return nullptr; | ||||
3844 | } | ||||
3845 | } | ||||
3846 | } | ||||
3847 | |||||
3848 | int64_t num_outputs = 0; | ||||
3849 | for (int i = 0; i < op_def->output_arg_size(); i++) { | ||||
3850 | const auto& output_arg = op_def->output_arg(i); | ||||
3851 | int64_t delta = 1; | ||||
3852 | if (!output_arg.number_attr().empty()) { | ||||
3853 | delta = attr_list_sizes[output_arg.number_attr()]; | ||||
3854 | } else if (!output_arg.type_list_attr().empty()) { | ||||
3855 | delta = attr_list_sizes[output_arg.type_list_attr()]; | ||||
3856 | } | ||||
3857 | if (delta < 0) { | ||||
3858 | RaiseFallbackException( | ||||
3859 | "Attributes suggest that the size of an output list is less than 0"); | ||||
3860 | return nullptr; | ||||
3861 | } | ||||
3862 | num_outputs += delta; | ||||
3863 | } | ||||
3864 | |||||
3865 | // If number of retvals is larger than int32, we error out. | ||||
3866 | if (static_cast<int64_t>(static_cast<int32_t>(num_outputs)) != num_outputs) { | ||||
3867 | PyErr_SetString( | ||||
3868 | PyExc_ValueError, | ||||
3869 | Printf("Number of outputs is too big: %ld", num_outputs).c_str()); | ||||
3870 | return nullptr; | ||||
3871 | } | ||||
3872 | int num_retvals = num_outputs; | ||||
3873 | |||||
3874 | tensorflow::gtl::InlinedVector<TFE_TensorHandle*, 2> retvals(num_retvals); | ||||
3875 | |||||
3876 | Py_BEGIN_ALLOW_THREADS{ PyThreadState *_save; _save = PyEval_SaveThread();; | ||||
3877 | TFE_Execute(op, retvals.data(), &num_retvals, status); | ||||
3878 | Py_END_ALLOW_THREADSPyEval_RestoreThread(_save); }; | ||||
3879 | |||||
3880 | if (!status->status.ok()) { | ||||
3881 | // Augment the status with the op_name for easier debugging similar to | ||||
3882 | // TFE_Py_Execute. | ||||
3883 | status->status = tensorflow::errors::CreateWithUpdatedMessage( | ||||
3884 | status->status, tensorflow::strings::StrCat( | ||||
3885 | TF_Message(status), " [Op:", | ||||
3886 | TFE_GetPythonString(op_exec_info.op_name), "]")); | ||||
3887 | MaybeRaiseExceptionFromTFStatus(status, nullptr); | ||||
3888 | return nullptr; | ||||
3889 | } | ||||
3890 | |||||
3891 | tensorflow::Safe_PyObjectPtr flat_result(PyList_New(num_retvals)); | ||||
3892 | for (int i = 0; i < num_retvals; ++i) { | ||||
3893 | PyList_SET_ITEM(flat_result.get(), i, EagerTensorFromHandle(retvals[i]))PyList_SetItem(flat_result.get(), i, EagerTensorFromHandle(retvals [i])); | ||||
3894 | } | ||||
3895 | |||||
3896 | if (op_exec_info.run_callbacks) { | ||||
3897 | if (!RunCallbacks( | ||||
3898 | op_exec_info, args, | ||||
3899 | FAST_PATH_EXECUTE_ARG_INPUT_START + op_def->input_arg_size(), | ||||
3900 | *flattened_inputs, *flattened_attrs, flat_result.get())) { | ||||
3901 | return nullptr; | ||||
3902 | } | ||||
3903 | } | ||||
3904 | |||||
3905 | // Unflatten results. | ||||
3906 | if (op_def->output_arg_size() == 0) { | ||||
3907 | Py_RETURN_NONEreturn _Py_INCREF(((PyObject*)((&_Py_NoneStruct)))), (& _Py_NoneStruct); | ||||
3908 | } | ||||
3909 | |||||
3910 | if (op_def->output_arg_size() == 1) { | ||||
3911 | if (!op_def->output_arg(0).number_attr().empty() || | ||||
3912 | !op_def->output_arg(0).type_list_attr().empty()) { | ||||
3913 | return flat_result.release(); | ||||
3914 | } else { | ||||
3915 | auto* result = PyList_GET_ITEM(flat_result.get(), 0)(((PyListObject *)(flat_result.get()))->ob_item[0]); | ||||
3916 | Py_INCREF(result)_Py_INCREF(((PyObject*)(result))); | ||||
3917 | return result; | ||||
3918 | } | ||||
3919 | } | ||||
3920 | |||||
3921 | // Correctly output the results that are made into a namedtuple. | ||||
3922 | PyObject* result = PyList_New(op_def->output_arg_size()); | ||||
3923 | int flat_result_index = 0; | ||||
3924 | for (int i = 0; i < op_def->output_arg_size(); i++) { | ||||
3925 | if (!op_def->output_arg(i).number_attr().empty()) { | ||||
3926 | int list_length = attr_list_sizes[op_def->output_arg(i).number_attr()]; | ||||
3927 | PyObject* inner_list = PyList_New(list_length); | ||||
3928 | for (int j = 0; j < list_length; j++) { | ||||
3929 | PyObject* obj = PyList_GET_ITEM(flat_result.get(), flat_result_index++)(((PyListObject *)(flat_result.get()))->ob_item[flat_result_index ++]); | ||||
3930 | Py_INCREF(obj)_Py_INCREF(((PyObject*)(obj))); | ||||
3931 | PyList_SET_ITEM(inner_list, j, obj)PyList_SetItem(inner_list, j, obj); | ||||
3932 | } | ||||
3933 | PyList_SET_ITEM(result, i, inner_list)PyList_SetItem(result, i, inner_list); | ||||
3934 | } else if (!op_def->output_arg(i).type_list_attr().empty()) { | ||||
3935 | int list_length = attr_list_sizes[op_def->output_arg(i).type_list_attr()]; | ||||
3936 | PyObject* inner_list = PyList_New(list_length); | ||||
3937 | for (int j = 0; j < list_length; j++) { | ||||
3938 | PyObject* obj = PyList_GET_ITEM(flat_result.get(), flat_result_index++)(((PyListObject *)(flat_result.get()))->ob_item[flat_result_index ++]); | ||||
3939 | Py_INCREF(obj)_Py_INCREF(((PyObject*)(obj))); | ||||
3940 | PyList_SET_ITEM(inner_list, j, obj)PyList_SetItem(inner_list, j, obj); | ||||
3941 | } | ||||
3942 | PyList_SET_ITEM(result, i, inner_list)PyList_SetItem(result, i, inner_list); | ||||
3943 | } else { | ||||
3944 | PyObject* obj = PyList_GET_ITEM(flat_result.get(), flat_result_index++)(((PyListObject *)(flat_result.get()))->ob_item[flat_result_index ++]); | ||||
3945 | Py_INCREF(obj)_Py_INCREF(((PyObject*)(obj))); | ||||
3946 | PyList_SET_ITEM(result, i, obj)PyList_SetItem(result, i, obj); | ||||
3947 | } | ||||
3948 | } | ||||
3949 | return result; | ||||
3950 | } | ||||
3951 | |||||
3952 | PyObject* TFE_Py_RecordGradient(PyObject* op_name, PyObject* inputs, | ||||
3953 | PyObject* attrs, PyObject* results, | ||||
3954 | PyObject* forward_pass_name_scope) { | ||||
3955 | if (*ThreadTapeIsStopped() || !HasAccumulatorOrTape()) { | ||||
3956 | Py_RETURN_NONEreturn _Py_INCREF(((PyObject*)((&_Py_NoneStruct)))), (& _Py_NoneStruct); | ||||
3957 | } | ||||
3958 | |||||
3959 | return RecordGradient(op_name, inputs, attrs, results, | ||||
3960 | forward_pass_name_scope); | ||||
3961 | } | ||||
3962 | |||||
3963 | namespace { | ||||
// Single-character tags that make up the cache-key string built by the
// Encode* functions below.
constexpr char kTensor[] = "T";             // Tensor or TensorSpec.
constexpr char kList[] = "L";               // Start of a Python list.
constexpr char kListEnd[] = "l";            // End of a Python list.
constexpr char kTuple[] = "U";              // Start of a tuple.
constexpr char kTupleEnd[] = "u";           // End of a tuple.
constexpr char kDIter[] = "I";              // OwnedIterator with a resource id.
constexpr char kDict[] = "D";               // Start of a mapping.
constexpr char kRaw[] = "R";                // Object stored in result.objects.
constexpr char kResourceVariable[] = "r";   // Variable keyed by resource id.
constexpr char kShape[] = "s";              // Start of a known shape.
constexpr char kShapeDelim[] = "-";         // Follows each known dimension.
constexpr char kDType[] = "d";              // Precedes the dtype enum value.
constexpr char kNone[] = "n";               // None / unknown rank or dimension.
constexpr char kCompositeTensor[] = "C";    // CompositeTensor (spec in objects).
constexpr char kAttrs[] = "A";              // Start of an attrs-decorated class.
constexpr char kAttrsEnd[] = "a";           // End of an attrs-decorated class.
constexpr char kName[] = "'";               // Opens a TensorSpec name.
constexpr char kNameEnd[] = "'";            // Closes a TensorSpec name.
constexpr char kLocalIdDelim[] = "_";       // Follows a local resource id.
3983 | |||||
3984 | // Container for storing generated string encoding as well as the raw python | ||||
3985 | // objects that were not included in the string. | ||||
3986 | struct EncodeResult { | ||||
3987 | string str; | ||||
3988 | std::vector<PyObject*> objects; | ||||
3989 | |||||
3990 | PyObject* ToPyTuple() { | ||||
3991 | PyObject* result = PyTuple_New(2); | ||||
3992 | |||||
3993 | PyTuple_SET_ITEM(result, 0, GetPythonObjectFromString(str))PyTuple_SetItem(result, 0, GetPythonObjectFromString(str)); | ||||
3994 | |||||
3995 | if (objects.empty()) { | ||||
3996 | Py_INCREF(Py_None)_Py_INCREF(((PyObject*)((&_Py_NoneStruct)))); | ||||
3997 | PyTuple_SET_ITEM(result, 1, Py_None)PyTuple_SetItem(result, 1, (&_Py_NoneStruct)); | ||||
3998 | } else { | ||||
3999 | PyObject* objects_tuple = PyTuple_New(objects.size()); | ||||
4000 | |||||
4001 | for (int i = 0; i < objects.size(); i++) { | ||||
4002 | PyTuple_SET_ITEM(objects_tuple, i, objects[i])PyTuple_SetItem(objects_tuple, i, objects[i]); | ||||
4003 | } | ||||
4004 | |||||
4005 | PyTuple_SET_ITEM(result, 1, objects_tuple)PyTuple_SetItem(result, 1, objects_tuple); | ||||
4006 | } | ||||
4007 | |||||
4008 | return result; | ||||
4009 | } | ||||
4010 | }; | ||||
4011 | |||||
4012 | // Gives each unique resource_id a unique incremental local_id. Provides a | ||||
4013 | // string encoding that informs an order and uniqueness sensitive input | ||||
4014 | // signature. | ||||
4015 | // This class is not thread safe and is not meant to be shared across threads. | ||||
4016 | class LocalResourceIdMap { | ||||
4017 | public: | ||||
4018 | // When the resource ID is known (such as for OwnedIterator). | ||||
4019 | // Returns the existing local ID (if present) or a new unique one. | ||||
4020 | int AddResourceId(int resource_id) { | ||||
4021 | const auto& it = resource_id_to_local_id_.find(resource_id); | ||||
4022 | if (it == resource_id_to_local_id_.end()) { | ||||
4023 | resource_id_to_local_id_[resource_id] = next_local_id_; | ||||
4024 | return next_local_id_++; | ||||
4025 | } else { | ||||
4026 | return it->second; | ||||
4027 | } | ||||
4028 | } | ||||
4029 | |||||
4030 | // When the resource ID is not known (such as for IteratorSpec). | ||||
4031 | // Returns a new unique local ID. | ||||
4032 | int AddUnknownResource() { return next_local_id_++; } | ||||
4033 | |||||
4034 | private: | ||||
4035 | absl::flat_hash_map<int, int> resource_id_to_local_id_; | ||||
4036 | int next_local_id_ = 0; | ||||
4037 | }; | ||||
4038 | |||||
4039 | // Contains encoding configuration, intermediary data and result. | ||||
4040 | struct EncodingContext { | ||||
4041 | bool include_tensor_ranks_only; | ||||
4042 | bool encode_variable_by_resource_id; | ||||
4043 | |||||
4044 | LocalResourceIdMap resource_id_map; | ||||
4045 | EncodeResult result; | ||||
4046 | }; | ||||
4047 | |||||
4048 | tensorflow::Status EncodeTensorOrTensorSpec(PyObject* arg, bool is_tensor_spec, | ||||
4049 | EncodingContext& context) { | ||||
4050 | absl::StrAppend(&context.result.str, kTensor); | ||||
4051 | |||||
4052 | if (is_tensor_spec
| ||||
4053 | tensorflow::Safe_PyObjectPtr name(PyObject_GetAttrString(arg, "name")); | ||||
4054 | if (name != nullptr && name.get() != Py_None(&_Py_NoneStruct)) { | ||||
4055 | absl::StrAppend(&context.result.str, kName, | ||||
4056 | TFE_GetPythonString(name.get()), kNameEnd); | ||||
4057 | } | ||||
4058 | } | ||||
4059 | |||||
4060 | tensorflow::Safe_PyObjectPtr dtype_object( | ||||
4061 | PyObject_GetAttrString(arg, "dtype")); | ||||
4062 | if (dtype_object == nullptr) { | ||||
4063 | return tensorflow::errors::InvalidArgument( | ||||
4064 | "tf.TensorSpec object doesn't have dtype() attr."); | ||||
4065 | } | ||||
4066 | |||||
4067 | tensorflow::Safe_PyObjectPtr dtype_enum( | ||||
4068 | PyObject_GetAttrString(dtype_object.get(), "_type_enum")); | ||||
4069 | if (dtype_enum == nullptr) { | ||||
4070 | return tensorflow::errors::InvalidArgument( | ||||
4071 | "tf.TensorSpec's dtype object doesn't have _type_enum() " | ||||
4072 | "attr."); | ||||
4073 | } | ||||
4074 | |||||
4075 | tensorflow::DataType dtype = | ||||
4076 | static_cast<tensorflow::DataType>(MakeInt(dtype_enum.get())); | ||||
4077 | absl::StrAppend(&context.result.str, kDType, dtype); | ||||
4078 | |||||
4079 | tensorflow::Safe_PyObjectPtr shape_tuple( | ||||
4080 | PyObject_GetAttrString(arg, "shape")); | ||||
4081 | if (shape_tuple == nullptr) { | ||||
4082 | return tensorflow::errors::InvalidArgument( | ||||
4083 | "tf.TensorSpec object doesn't have shape() attr."); | ||||
4084 | } | ||||
4085 | |||||
4086 | tensorflow::Safe_PyObjectPtr rank( | ||||
4087 | PyObject_GetAttr(shape_tuple.get(), PyUnicode_FromString("rank"))); | ||||
| |||||
4088 | if (rank == nullptr || rank.get() == Py_None(&_Py_NoneStruct)) { | ||||
4089 | // Unknown shape, encode that directly. | ||||
4090 | absl::StrAppend(&context.result.str, kNone); | ||||
4091 | return tensorflow::Status::OK(); | ||||
4092 | } | ||||
4093 | |||||
4094 | absl::StrAppend(&context.result.str, kShape); | ||||
4095 | |||||
4096 | tensorflow::Safe_PyObjectPtr shape_seq(PySequence_Fast( | ||||
4097 | shape_tuple.get(), "shape_tuple didn't return a sequence")); | ||||
4098 | |||||
4099 | int len = MakeInt(rank.get()); | ||||
4100 | PyObject** shape_seq_array = PySequence_Fast_ITEMS(shape_seq.get())(((((((PyObject*)(shape_seq.get()))->ob_type))->tp_flags & ((1UL << 25))) != 0) ? ((PyListObject *)(shape_seq .get()))->ob_item : ((PyTupleObject *)(shape_seq.get()))-> ob_item); | ||||
4101 | |||||
4102 | if (context.include_tensor_ranks_only) { | ||||
4103 | absl::StrAppend(&context.result.str, len); | ||||
4104 | } else { | ||||
4105 | for (int i = 0; i < len; ++i) { | ||||
4106 | // Can be None, int or a Dimension object. | ||||
4107 | PyObject* dimension = shape_seq_array[i]; | ||||
4108 | |||||
4109 | // If it is a Dimension object, then we must extract value from it first. | ||||
4110 | bool is_dimension_class = PyObject_HasAttrString(dimension, "value"); | ||||
4111 | tensorflow::Safe_PyObjectPtr dimension_holder; | ||||
4112 | if (is_dimension_class) { | ||||
4113 | dimension_holder = | ||||
4114 | tensorflow::make_safe(PyObject_GetAttrString(dimension, "value")); | ||||
4115 | dimension = dimension_holder.get(); | ||||
4116 | } | ||||
4117 | |||||
4118 | if (dimension == Py_None(&_Py_NoneStruct)) { | ||||
4119 | absl::StrAppend(&context.result.str, kNone); | ||||
4120 | } else { | ||||
4121 | absl::StrAppend(&context.result.str, MakeInt(dimension), kShapeDelim); | ||||
4122 | } | ||||
4123 | } | ||||
4124 | } | ||||
4125 | |||||
4126 | return tensorflow::Status::OK(); | ||||
4127 | } | ||||
4128 | |||||
4129 | // TODO(b/199534088): Remove this function by using EncodeResource instead. | ||||
4130 | tensorflow::Status EncodeOwnedIterator(PyObject* arg, | ||||
4131 | EncodingContext& context) { | ||||
4132 | PyObject* type_spec(PyObject_GetAttrString(arg, "_type_spec")); | ||||
4133 | if (type_spec == nullptr) { | ||||
4134 | return tensorflow::errors::InvalidArgument( | ||||
4135 | "Error while reading OwnedIterator._type_spec."); | ||||
4136 | } | ||||
4137 | context.result.objects.push_back(type_spec); | ||||
4138 | |||||
4139 | // Add resource tracking | ||||
4140 | tensorflow::Safe_PyObjectPtr itr_res( | ||||
4141 | PyObject_GetAttrString(arg, "_iterator_resource")); | ||||
4142 | if (itr_res == nullptr) { | ||||
4143 | return tensorflow::errors::InvalidArgument( | ||||
4144 | "Error while reading Dataset iterator resource."); | ||||
4145 | } | ||||
4146 | // OwnedIterator should ideally always provide a unique resource id. | ||||
4147 | // TODO(b/199534088) Cases where resource_id is not provided need to be fixed. | ||||
4148 | if (tensorflow::swig::IsTensor(itr_res.get())) { | ||||
4149 | absl::StrAppend(&context.result.str, kDIter); | ||||
4150 | tensorflow::Safe_PyObjectPtr py_resource_id( | ||||
4151 | PyObject_GetAttrString(itr_res.get(), "_id")); | ||||
4152 | if (py_resource_id == nullptr) { | ||||
4153 | return tensorflow::errors::InvalidArgument( | ||||
4154 | "Error while reading Dataset iterator resouce id."); | ||||
4155 | } | ||||
4156 | int resource_id = PyLong_AsSize_t(py_resource_id.get()); | ||||
4157 | if (resource_id < 0) { | ||||
4158 | return tensorflow::errors::InvalidArgument("PyLong_AsSize_t failure"); | ||||
4159 | } | ||||
4160 | int local_id = context.resource_id_map.AddResourceId(resource_id); | ||||
4161 | absl::StrAppend(&context.result.str, local_id, kLocalIdDelim); | ||||
4162 | } else { | ||||
4163 | // If '_iterator_resource' is not a Tensor, there is no resource id. | ||||
4164 | // Instead we treat it the same way as a CompositeTensor | ||||
4165 | absl::StrAppend(&context.result.str, kCompositeTensor); | ||||
4166 | } | ||||
4167 | return tensorflow::Status::OK(); | ||||
4168 | } | ||||
4169 | |||||
4170 | tensorflow::Status EncodeResource(PyObject* arg, EncodingContext& context) { | ||||
4171 | absl::StrAppend(&context.result.str, kResourceVariable); | ||||
4172 | tensorflow::Safe_PyObjectPtr py_resource_id( | ||||
4173 | PyObject_CallMethod(arg, "__tf_resource_id__", nullptr)); | ||||
4174 | DCHECK(py_resource_id != nullptr)while (false && (py_resource_id != nullptr)) ::tensorflow ::internal::LogMessageFatal("tensorflow/python/eager/pywrap_tfe_src.cc" , 4174); | ||||
4175 | |||||
4176 | int resource_id = PyLong_AsSize_t(py_resource_id.get()); | ||||
4177 | DCHECK_GE(resource_id, 0)while (false && ((void)(resource_id), (void)(0), 0)) :: tensorflow::internal::LogMessageFatal("tensorflow/python/eager/pywrap_tfe_src.cc" , 4177); | ||||
4178 | int local_id = context.resource_id_map.AddResourceId(resource_id); | ||||
4179 | absl::StrAppend(&context.result.str, local_id, kLocalIdDelim); | ||||
4180 | |||||
4181 | tensorflow::Safe_PyObjectPtr type_spec( | ||||
4182 | PyObject_CallMethod(arg, "__tf_function_cache_spec__", nullptr)); | ||||
4183 | absl::StrAppend(&context.result.str, PyUnicode_AsUTF8(type_spec.get())); | ||||
4184 | DCHECK(type_spec != nullptr)while (false && (type_spec != nullptr)) ::tensorflow:: internal::LogMessageFatal("tensorflow/python/eager/pywrap_tfe_src.cc" , 4184); | ||||
4185 | |||||
4186 | return tensorflow::Status::OK(); | ||||
4187 | } | ||||
4188 | |||||
4189 | tensorflow::Status EncodeArgHelperInternal(PyObject* arg, | ||||
4190 | EncodingContext& context); | ||||
4191 | |||||
4192 | // This function doesn't set the type of sequence before | ||||
4193 | tensorflow::Status EncodeSequence(PyObject* arg, const char* type, | ||||
4194 | const char* end_type, | ||||
4195 | EncodingContext& context) { | ||||
4196 | tensorflow::Safe_PyObjectPtr arg_seq( | ||||
4197 | PySequence_Fast(arg, "unable to create seq from list/tuple")); | ||||
4198 | |||||
4199 | absl::StrAppend(&context.result.str, type); | ||||
4200 | int len = PySequence_Fast_GET_SIZE(arg_seq.get())(((((((PyObject*)(arg_seq.get()))->ob_type))->tp_flags & ((1UL << 25))) != 0) ? ((static_cast<void> (0)), (((PyVarObject*)(arg_seq.get()))->ob_size)) : (((PyVarObject *)(((static_cast<void> (0)), (PyTupleObject *)(arg_seq. get()))))->ob_size)); | ||||
4201 | PyObject** arg_seq_array = PySequence_Fast_ITEMS(arg_seq.get())(((((((PyObject*)(arg_seq.get()))->ob_type))->tp_flags & ((1UL << 25))) != 0) ? ((PyListObject *)(arg_seq.get() ))->ob_item : ((PyTupleObject *)(arg_seq.get()))->ob_item ); | ||||
4202 | for (int i = 0; i < len; ++i) { | ||||
4203 | PyObject* item = arg_seq_array[i]; | ||||
4204 | if (item == Py_None(&_Py_NoneStruct)) { | ||||
4205 | absl::StrAppend(&context.result.str, kNone); | ||||
4206 | } else { | ||||
4207 | TF_RETURN_IF_ERROR(EncodeArgHelperInternal(item, context))do { ::tensorflow::Status _status = (EncodeArgHelperInternal( item, context)); if ((__builtin_expect(!_status.ok(), 0))) return _status; } while (0); | ||||
4208 | } | ||||
4209 | } | ||||
4210 | absl::StrAppend(&context.result.str, end_type); | ||||
4211 | |||||
4212 | return tensorflow::Status::OK(); | ||||
4213 | } | ||||
4214 | |||||
4215 | tensorflow::Status EncodeMapping(PyObject* arg, EncodingContext& context) { | ||||
4216 | tensorflow::Safe_PyObjectPtr keys(tensorflow::swig::MappingKeys(arg)); | ||||
4217 | if (PyList_Sort(keys.get()) == -1) { | ||||
4218 | return tensorflow::errors::Internal("Unable to sort keys"); | ||||
4219 | } | ||||
4220 | |||||
4221 | absl::StrAppend(&context.result.str, kDict); | ||||
4222 | int len = PyList_Size(keys.get()); | ||||
4223 | |||||
4224 | for (int i = 0; i < len; i++) { | ||||
4225 | PyObject* key = PyList_GetItem(keys.get(), i); | ||||
4226 | TF_RETURN_IF_ERROR(EncodeArgHelperInternal(key, context))do { ::tensorflow::Status _status = (EncodeArgHelperInternal( key, context)); if ((__builtin_expect(!_status.ok(), 0))) return _status; } while (0); | ||||
4227 | tensorflow::Safe_PyObjectPtr value(PyObject_GetItem(arg, key)); | ||||
4228 | TF_RETURN_IF_ERROR(EncodeArgHelperInternal(value.get(), context))do { ::tensorflow::Status _status = (EncodeArgHelperInternal( value.get(), context)); if ((__builtin_expect(!_status.ok(), 0 ))) return _status; } while (0); | ||||
4229 | } | ||||
4230 | |||||
4231 | return tensorflow::Status::OK(); | ||||
4232 | } | ||||
4233 | |||||
4234 | tensorflow::Status EncodeCompositeTensor(PyObject* arg, | ||||
4235 | EncodingContext& context) { | ||||
4236 | absl::StrAppend(&context.result.str, kCompositeTensor); | ||||
4237 | PyObject* type_spec(PyObject_GetAttrString(arg, "_type_spec")); | ||||
4238 | if (type_spec == nullptr) { | ||||
4239 | return tensorflow::errors::InvalidArgument( | ||||
4240 | "Error while reading CompositeTensor._type_spec."); | ||||
4241 | } | ||||
4242 | context.result.objects.push_back(type_spec); | ||||
4243 | |||||
4244 | return tensorflow::Status::OK(); | ||||
4245 | } | ||||
4246 | |||||
4247 | tensorflow::Status EncodeTypeSpec(PyObject* arg, EncodingContext& context) { | ||||
4248 | absl::StrAppend(&context.result.str, kRaw); | ||||
4249 | Py_INCREF(arg)_Py_INCREF(((PyObject*)(arg))); | ||||
4250 | context.result.objects.push_back(arg); | ||||
4251 | return tensorflow::Status::OK(); | ||||
4252 | } | ||||
4253 | |||||
4254 | tensorflow::Status EncodeAttrs(PyObject* arg, EncodingContext& context) { | ||||
4255 | absl::StrAppend(&context.result.str, kAttrs); | ||||
4256 | tensorflow::Safe_PyObjectPtr attrs( | ||||
4257 | PyObject_GetAttrString(arg, "__attrs_attrs__")); | ||||
4258 | tensorflow::Safe_PyObjectPtr iter(PyObject_GetIter(attrs.get())); | ||||
4259 | for (tensorflow::Safe_PyObjectPtr item(PyIter_Next(iter.get())); item; | ||||
4260 | item.reset(PyIter_Next(iter.get()))) { | ||||
4261 | tensorflow::Safe_PyObjectPtr name( | ||||
4262 | PyObject_GetAttrString(item.get(), "name")); | ||||
4263 | tensorflow::Safe_PyObjectPtr attr_arg(PyObject_GetAttr(arg, name.get())); | ||||
4264 | TF_RETURN_IF_ERROR(EncodeArgHelperInternal(attr_arg.get(), context))do { ::tensorflow::Status _status = (EncodeArgHelperInternal( attr_arg.get(), context)); if ((__builtin_expect(!_status.ok( ), 0))) return _status; } while (0); | ||||
4265 | } | ||||
4266 | absl::StrAppend(&context.result.str, kAttrsEnd); | ||||
4267 | |||||
4268 | return tensorflow::Status::OK(); | ||||
4269 | } | ||||
4270 | |||||
4271 | tensorflow::Status EncodeUnidentified(PyObject* arg, EncodingContext& context) { | ||||
4272 | // We hold a weak reference because cache keys live practically forever, and | ||||
4273 | // this may leak heavy objects. | ||||
4274 | PyObject* object = PyWeakref_NewRef(arg, nullptr); | ||||
4275 | if (object == nullptr) { | ||||
4276 | PyErr_Clear(); | ||||
4277 | object = arg; | ||||
4278 | Py_INCREF(object)_Py_INCREF(((PyObject*)(object))); | ||||
4279 | } | ||||
4280 | |||||
4281 | absl::StrAppend(&context.result.str, kRaw); | ||||
4282 | context.result.objects.push_back(object); | ||||
4283 | return tensorflow::Status::OK(); | ||||
4284 | } | ||||
4285 | |||||
4286 | tensorflow::Status EncodeArgHelperInternal(PyObject* arg, | ||||
4287 | EncodingContext& context) { | ||||
4288 | if (tensorflow::swig::IsTensorSpec(arg)) { | ||||
4289 | TF_RETURN_IF_ERROR(EncodeTensorOrTensorSpec(arg, true, context))do { ::tensorflow::Status _status = (EncodeTensorOrTensorSpec (arg, true, context)); if ((__builtin_expect(!_status.ok(), 0 ))) return _status; } while (0); | ||||
4290 | } else if (tensorflow::swig::IsTensor(arg)) { | ||||
4291 | TF_RETURN_IF_ERROR(EncodeTensorOrTensorSpec(arg, false, context))do { ::tensorflow::Status _status = (EncodeTensorOrTensorSpec (arg, false, context)); if ((__builtin_expect(!_status.ok(), 0 ))) return _status; } while (0); | ||||
4292 | } else if (tensorflow::swig::IsOwnedIterator(arg)) { | ||||
4293 | TF_RETURN_IF_ERROR(EncodeOwnedIterator(arg, context))do { ::tensorflow::Status _status = (EncodeOwnedIterator(arg, context)); if ((__builtin_expect(!_status.ok(), 0))) return _status ; } while (0); | ||||
4294 | } else if (PyList_Check(arg)((((((PyObject*)(arg))->ob_type))->tp_flags & ((1UL << 25))) != 0)) { | ||||
4295 | TF_RETURN_IF_ERROR(EncodeSequence(arg, kList, kListEnd, context))do { ::tensorflow::Status _status = (EncodeSequence(arg, kList , kListEnd, context)); if ((__builtin_expect(!_status.ok(), 0 ))) return _status; } while (0); | ||||
4296 | } else if (tensorflow::swig::IsTuple(arg)) { | ||||
4297 | TF_RETURN_IF_ERROR(EncodeSequence(arg, kTuple, kTupleEnd, context))do { ::tensorflow::Status _status = (EncodeSequence(arg, kTuple , kTupleEnd, context)); if ((__builtin_expect(!_status.ok(), 0 ))) return _status; } while (0); | ||||
4298 | } else if (tensorflow::swig::IsMapping(arg)) { | ||||
4299 | TF_RETURN_IF_ERROR(EncodeMapping(arg, context))do { ::tensorflow::Status _status = (EncodeMapping(arg, context )); if ((__builtin_expect(!_status.ok(), 0))) return _status; } while (0); | ||||
4300 | } else if (tensorflow::swig::IsCompositeTensor(arg)) { | ||||
4301 | TF_RETURN_IF_ERROR(EncodeCompositeTensor(arg, context))do { ::tensorflow::Status _status = (EncodeCompositeTensor(arg , context)); if ((__builtin_expect(!_status.ok(), 0))) return _status; } while (0); | ||||
4302 | } else if (tensorflow::swig::IsTypeSpec(arg)) { | ||||
4303 | TF_RETURN_IF_ERROR(EncodeTypeSpec(arg, context))do { ::tensorflow::Status _status = (EncodeTypeSpec(arg, context )); if ((__builtin_expect(!_status.ok(), 0))) return _status; } while (0); | ||||
4304 | } else if (tensorflow::swig::IsAttrs(arg)) { | ||||
4305 | TF_RETURN_IF_ERROR(EncodeAttrs(arg, context))do { ::tensorflow::Status _status = (EncodeAttrs(arg, context )); if ((__builtin_expect(!_status.ok(), 0))) return _status; } while (0); | ||||
4306 | } else if (tensorflow::swig::IsResourceVariable(arg) && | ||||
4307 | context.encode_variable_by_resource_id) { | ||||
4308 | TF_RETURN_IF_ERROR(EncodeResource(arg, context))do { ::tensorflow::Status _status = (EncodeResource(arg, context )); if ((__builtin_expect(!_status.ok(), 0))) return _status; } while (0); | ||||
4309 | } else { | ||||
4310 | TF_RETURN_IF_ERROR(EncodeUnidentified(arg, context))do { ::tensorflow::Status _status = (EncodeUnidentified(arg, context )); if ((__builtin_expect(!_status.ok(), 0))) return _status; } while (0); | ||||
4311 | } | ||||
4312 | |||||
4313 | return tensorflow::Status::OK(); | ||||
4314 | } | ||||
4315 | |||||
4316 | tensorflow::Status TFE_Py_EncodeArgHelper(PyObject* arg, | ||||
4317 | EncodingContext& context) { | ||||
4318 | auto status = EncodeArgHelperInternal(arg, context); | ||||
4319 | return status; | ||||
4320 | } | ||||
4321 | |||||
4322 | } // namespace | ||||
4323 | |||||
4324 | // `defun` uses dtypes and shapes instead of `Tensors` as cache keys. Dtypes | ||||
4325 | // are used because TensorFlow graphs are not parametric w.r.t. dtypes. Shapes | ||||
4326 | // are used for both performance reasons, as much TensorFlow code specializes | ||||
4327 | // on known shapes to produce slimmer graphs, and correctness, as some | ||||
4328 | // high-level APIs require shapes to be fully-known. | ||||
4329 | // | ||||
4330 | // `include_tensor_ranks_only` allows caching on arguments excluding shape info, | ||||
4331 | // so that a slow path using relaxed shape can rely on a cache key that excludes | ||||
4332 | // shapes. | ||||
4333 | PyObject* TFE_Py_EncodeArg(PyObject* arg, bool include_tensor_ranks_only, | ||||
4334 | bool encode_variable_by_resource_id) { | ||||
4335 | EncodingContext context; | ||||
4336 | context.include_tensor_ranks_only = include_tensor_ranks_only; | ||||
4337 | context.encode_variable_by_resource_id = encode_variable_by_resource_id; | ||||
4338 | const auto status = TFE_Py_EncodeArgHelper(arg, context); | ||||
| |||||
4339 | if (MaybeRaiseExceptionFromStatus(status, nullptr)) { | ||||
4340 | return nullptr; | ||||
4341 | } | ||||
4342 | |||||
4343 | return context.result.ToPyTuple(); | ||||
4344 | } | ||||
4345 | |||||
4346 | // A method prints incoming messages directly to Python's | ||||
4347 | // stdout using Python's C API. This is necessary in Jupyter notebooks | ||||
4348 | // and colabs where messages to the C stdout don't go to the notebook | ||||
4349 | // cell outputs, but calls to Python's stdout do. | ||||
4350 | void PrintToPythonStdout(const char* msg) { | ||||
4351 | if (Py_IsInitialized()) { | ||||
4352 | PyGILState_STATE py_threadstate; | ||||
4353 | py_threadstate = PyGILState_Ensure(); | ||||
4354 | |||||
4355 | string string_msg = msg; | ||||
4356 | // PySys_WriteStdout truncates strings over 1000 bytes, so | ||||
4357 | // we write the message in chunks small enough to not be truncated. | ||||
4358 | int CHUNK_SIZE = 900; | ||||
4359 | auto len = string_msg.length(); | ||||
4360 | for (int i = 0; i < len; i += CHUNK_SIZE) { | ||||
4361 | PySys_WriteStdout("%s", string_msg.substr(i, CHUNK_SIZE).c_str()); | ||||
4362 | } | ||||
4363 | |||||
4364 | // Force flushing to make sure print newlines aren't interleaved in | ||||
4365 | // some colab environments | ||||
4366 | PyRun_SimpleString("import sys; sys.stdout.flush()")PyRun_SimpleStringFlags("import sys; sys.stdout.flush()", __null ); | ||||
4367 | |||||
4368 | PyGILState_Release(py_threadstate); | ||||
4369 | } | ||||
4370 | } | ||||
4371 | |||||
4372 | // Register PrintToPythonStdout as a log listener, to allow | ||||
4373 | // printing in colabs and jupyter notebooks to work. | ||||
4374 | void TFE_Py_EnableInteractivePythonLogging() { | ||||
4375 | static bool enabled_interactive_logging = false; | ||||
4376 | if (!enabled_interactive_logging) { | ||||
4377 | enabled_interactive_logging = true; | ||||
4378 | TF_RegisterLogListener(PrintToPythonStdout); | ||||
4379 | } | ||||
4380 | } | ||||
4381 | |||||
4382 | namespace { | ||||
4383 | // weak reference to Python Context object currently active | ||||
4384 | PyObject* weak_eager_context = nullptr; | ||||
4385 | } // namespace | ||||
4386 | |||||
4387 | PyObject* TFE_Py_SetEagerContext(PyObject* py_context) { | ||||
4388 | Py_XDECREF(weak_eager_context)_Py_XDECREF(((PyObject*)(weak_eager_context))); | ||||
4389 | weak_eager_context = PyWeakref_NewRef(py_context, nullptr); | ||||
4390 | if (weak_eager_context == nullptr) { | ||||
4391 | return nullptr; | ||||
4392 | } | ||||
4393 | Py_RETURN_NONEreturn _Py_INCREF(((PyObject*)((&_Py_NoneStruct)))), (& _Py_NoneStruct); | ||||
4394 | } | ||||
4395 | |||||
4396 | PyObject* GetPyEagerContext() { | ||||
4397 | if (weak_eager_context == nullptr) { | ||||
4398 | PyErr_SetString(PyExc_RuntimeError, "Python eager context is not set"); | ||||
4399 | return nullptr; | ||||
4400 | } | ||||
4401 | PyObject* py_context = PyWeakref_GET_OBJECT(weak_eager_context)((((PyObject*)(((PyWeakReference *)(weak_eager_context))-> wr_object))->ob_refcnt) > 0 ? ((PyWeakReference *)(weak_eager_context ))->wr_object : (&_Py_NoneStruct)); | ||||
4402 | if (py_context == Py_None(&_Py_NoneStruct)) { | ||||
4403 | PyErr_SetString(PyExc_RuntimeError, "Eager context has been destroyed"); | ||||
4404 | return nullptr; | ||||
4405 | } | ||||
4406 | Py_INCREF(py_context)_Py_INCREF(((PyObject*)(py_context))); | ||||
4407 | return py_context; | ||||
4408 | } | ||||
4409 | |||||
4410 | namespace { | ||||
4411 | |||||
4412 | // Default values for thread_local_data fields. | ||||
4413 | struct EagerContextThreadLocalDataDefaults { | ||||
4414 | tensorflow::Safe_PyObjectPtr is_eager; | ||||
4415 | tensorflow::Safe_PyObjectPtr device_spec; | ||||
4416 | }; | ||||
4417 | |||||
4418 | // Maps each py_eager_context object to its thread_local_data. | ||||
4419 | // | ||||
4420 | // Note: we need to use the python Context object as the key here (and not | ||||
4421 | // its handle object), because the handle object isn't created until the | ||||
4422 | // context is initialized; but thread_local_data is potentially accessed | ||||
4423 | // before then. | ||||
4424 | using EagerContextThreadLocalDataMap = absl::flat_hash_map< | ||||
4425 | PyObject*, std::unique_ptr<tensorflow::EagerContextThreadLocalData>>; | ||||
4426 | thread_local EagerContextThreadLocalDataMap* | ||||
4427 | eager_context_thread_local_data_map = nullptr; | ||||
4428 | |||||
4429 | // Maps each py_eager_context object to default values. | ||||
4430 | using EagerContextThreadLocalDataDefaultsMap = | ||||
4431 | absl::flat_hash_map<PyObject*, EagerContextThreadLocalDataDefaults>; | ||||
4432 | EagerContextThreadLocalDataDefaultsMap* | ||||
4433 | eager_context_thread_local_data_defaults = nullptr; | ||||
4434 | |||||
4435 | } // namespace | ||||
4436 | |||||
4437 | namespace tensorflow { | ||||
4438 | |||||
4439 | void MakeEagerContextThreadLocalData(PyObject* py_eager_context, | ||||
4440 | PyObject* is_eager, | ||||
4441 | PyObject* device_spec) { | ||||
4442 | DCheckPyGilState(); | ||||
4443 | if (eager_context_thread_local_data_defaults == nullptr) { | ||||
4444 | absl::LeakCheckDisabler disabler; | ||||
4445 | eager_context_thread_local_data_defaults = | ||||
4446 | new EagerContextThreadLocalDataDefaultsMap(); | ||||
4447 | } | ||||
4448 | if (eager_context_thread_local_data_defaults->count(py_eager_context) > 0) { | ||||
4449 | PyErr_SetString(PyExc_AssertionError, | ||||
4450 | "MakeEagerContextThreadLocalData may not be called " | ||||
4451 | "twice on the same eager Context object."); | ||||
4452 | } | ||||
4453 | |||||
4454 | auto& defaults = | ||||
4455 | (*eager_context_thread_local_data_defaults)[py_eager_context]; | ||||
4456 | Py_INCREF(is_eager)_Py_INCREF(((PyObject*)(is_eager))); | ||||
4457 | defaults.is_eager.reset(is_eager); | ||||
4458 | Py_INCREF(device_spec)_Py_INCREF(((PyObject*)(device_spec))); | ||||
4459 | defaults.device_spec.reset(device_spec); | ||||
4460 | } | ||||
4461 | |||||
4462 | EagerContextThreadLocalData* GetEagerContextThreadLocalData( | ||||
4463 | PyObject* py_eager_context) { | ||||
4464 | if (eager_context_thread_local_data_defaults == nullptr) { | ||||
4465 | PyErr_SetString(PyExc_AssertionError, | ||||
4466 | "MakeEagerContextThreadLocalData must be called " | ||||
4467 | "before GetEagerContextThreadLocalData."); | ||||
4468 | return nullptr; | ||||
4469 | } | ||||
4470 | auto defaults = | ||||
4471 | eager_context_thread_local_data_defaults->find(py_eager_context); | ||||
4472 | if (defaults == eager_context_thread_local_data_defaults->end()) { | ||||
4473 | PyErr_SetString(PyExc_AssertionError, | ||||
4474 | "MakeEagerContextThreadLocalData must be called " | ||||
4475 | "before GetEagerContextThreadLocalData."); | ||||
4476 | return nullptr; | ||||
4477 | } | ||||
4478 | |||||
4479 | if (eager_context_thread_local_data_map == nullptr) { | ||||
4480 | absl::LeakCheckDisabler disabler; | ||||
4481 | eager_context_thread_local_data_map = new EagerContextThreadLocalDataMap(); | ||||
4482 | } | ||||
4483 | auto& thread_local_data = | ||||
4484 | (*eager_context_thread_local_data_map)[py_eager_context]; | ||||
4485 | |||||
4486 | if (!thread_local_data) { | ||||
4487 | thread_local_data.reset(new EagerContextThreadLocalData()); | ||||
4488 | |||||
4489 | Safe_PyObjectPtr is_eager( | ||||
4490 | PyObject_CallFunctionObjArgs(defaults->second.is_eager.get(), nullptr)); | ||||
4491 | if (!is_eager) return nullptr; | ||||
4492 | thread_local_data->is_eager = PyObject_IsTrue(is_eager.get()); | ||||
4493 | |||||
4494 | #if PY_MAJOR_VERSION3 >= 3 | ||||
4495 | PyObject* scope_name = PyUnicode_FromString(""); | ||||
4496 | #else | ||||
4497 | PyObject* scope_name = PyString_FromString(""); | ||||
4498 | #endif | ||||
4499 | thread_local_data->scope_name.reset(scope_name); | ||||
4500 | |||||
4501 | #if PY_MAJOR_VERSION3 >= 3 | ||||
4502 | PyObject* device_name = PyUnicode_FromString(""); | ||||
4503 | #else | ||||
4504 | PyObject* device_name = PyString_FromString(""); | ||||
4505 | #endif | ||||
4506 | thread_local_data->device_name.reset(device_name); | ||||
4507 | |||||
4508 | Py_INCREF(defaults->second.device_spec.get())_Py_INCREF(((PyObject*)(defaults->second.device_spec.get() ))); | ||||
4509 | thread_local_data->device_spec.reset(defaults->second.device_spec.get()); | ||||
4510 | |||||
4511 | Py_INCREF(Py_None)_Py_INCREF(((PyObject*)((&_Py_NoneStruct)))); | ||||
4512 | thread_local_data->function_call_options.reset(Py_None(&_Py_NoneStruct)); | ||||
4513 | |||||
4514 | Py_INCREF(Py_None)_Py_INCREF(((PyObject*)((&_Py_NoneStruct)))); | ||||
4515 | thread_local_data->executor.reset(Py_None(&_Py_NoneStruct)); | ||||
4516 | |||||
4517 | thread_local_data->op_callbacks.reset(PyList_New(0)); | ||||
4518 | } | ||||
4519 | return thread_local_data.get(); | ||||
4520 | } | ||||
4521 | |||||
4522 | void DestroyEagerContextThreadLocalData(PyObject* py_eager_context) { | ||||
4523 | DCheckPyGilState(); | ||||
4524 | if (eager_context_thread_local_data_defaults) { | ||||
4525 | eager_context_thread_local_data_defaults->erase(py_eager_context); | ||||
4526 | } | ||||
4527 | if (eager_context_thread_local_data_map) { | ||||
4528 | eager_context_thread_local_data_map->erase(py_eager_context); | ||||
4529 | } | ||||
4530 | } | ||||
4531 | |||||
4532 | } // namespace tensorflow |
#ifndef PyUnicode_FromString
/* Static-analyzer model of PyUnicode_FromString. The result comes from
 * clang_analyzer_PyObject_New_Reference(), presumably so that the analyzer
 * treats it as a new (owned) reference — confirm against the model docs.
 * The model is only emitted when PyUnicode_FromString is not a macro. */
typedef struct _object PyObject;
PyObject *clang_analyzer_PyObject_New_Reference();
PyObject *PyUnicode_FromString(const char *u)
{
  PyObject *new_reference = clang_analyzer_PyObject_New_Reference();
  return new_reference;
}
#else
#warning "API PyUnicode_FromString is defined as a macro."
#endif