File: .cache/bazel/_bazel_alan/39be661231df2a680c9b74265384c13c/execroot/org_tensorflow/tensorflow/python/eager/pywrap_tfe_src.cc
Warning: line 4087, column 43: PyObject ownership leak with reference count of 1
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <atomic>
#include <cstring>
#include <unordered_map>

#include "absl/debugging/leak_check.h"
#include "absl/strings/str_cat.h"
#include "absl/types/variant.h"
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/eager/tape.h"
#include "tensorflow/c/eager/tfe_context_internal.h"
#include "tensorflow/c/eager/tfe_op_internal.h"
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/gtl/compactptrset.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/util/managed_stack_trace.h"
#include "tensorflow/python/eager/pywrap_gradient_exclusions.h"
#include "tensorflow/python/eager/pywrap_tensor.h"
#include "tensorflow/python/eager/pywrap_tfe.h"
#include "tensorflow/python/lib/core/py_util.h"
#include "tensorflow/python/lib/core/safe_ptr.h"
#include "tensorflow/python/util/stack_trace.h"
#include "tensorflow/python/util/util.h"

using tensorflow::Status;
using tensorflow::string;
using tensorflow::strings::Printf;
namespace {
// NOTE: Items are retrieved from and returned to these unique_ptrs, and they
// act as arenas. This is important if the same thread requests 2 items without
// releasing one.
// The following sequence of events on the same thread will still succeed:
// - GetOp <- Returns existing.
// - GetOp <- Allocates and returns a new pointer.
// - ReleaseOp <- Sets the item in the unique_ptr.
// - ReleaseOp <- Sets the item in the unique_ptr, deleting the old one.
// This occurs when a PyFunc kernel is run. This behavior makes it safe in that
// case, as well as in the case where Python decides to reuse the same
// underlying C++ thread for 2 Python threads.
struct OpDeleter {
  void operator()(TFE_Op* op) const { TFE_DeleteOp(op); }
};
thread_local std::unordered_map<TFE_Context*,
                                std::unique_ptr<TFE_Op, OpDeleter>>
    thread_local_eager_operation_map;  // NOLINT
thread_local std::unique_ptr<TF_Status> thread_local_tf_status =  // NOLINT
    nullptr;

std::unique_ptr<TFE_Op, OpDeleter> ReleaseThreadLocalOp(TFE_Context* ctx) {
  auto it = thread_local_eager_operation_map.find(ctx);
  if (it == thread_local_eager_operation_map.end()) {
    return nullptr;
  }
  return std::move(it->second);
}

TFE_Op* GetOp(TFE_Context* ctx, const char* op_or_function_name,
              const char* raw_device_name, TF_Status* status) {
  auto op = ReleaseThreadLocalOp(ctx);
  if (!op) {
    op.reset(tensorflow::wrap(tensorflow::unwrap(ctx)->CreateOperation()));
  }
  status->status =
      tensorflow::unwrap(op.get())->Reset(op_or_function_name, raw_device_name);
  if (!status->status.ok()) {
    op.reset();
  }
  return op.release();
}

void ReturnOp(TFE_Context* ctx, TFE_Op* op) {
  if (op) {
    tensorflow::unwrap(op)->Clear();
    thread_local_eager_operation_map[ctx].reset(op);
  }
}
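
// Illustrative round trip for the per-thread op pool above (a sketch added
// for clarity, not part of the original source; the op name and empty device
// string are hypothetical):
//
//   TFE_Op* op = GetOp(ctx, "MatMul", /*raw_device_name=*/"", status);
//   if (status->status.ok()) {
//     // ... add inputs and attrs, then TFE_Execute ...
//   }
//   ReturnOp(ctx, op);  // Clears the op and parks it in the thread-local map.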

TF_Status* ReleaseThreadLocalStatus() {
  if (thread_local_tf_status == nullptr) {
    return nullptr;
  }
  return thread_local_tf_status.release();
}

struct InputInfo {
  InputInfo(int i, bool is_list) : i(i), is_list(is_list) {}

  int i;
  bool is_list = false;
};

// Takes in output gradients, returns input gradients.
typedef std::function<PyObject*(PyObject*, const std::vector<int64_t>&)>
    PyBackwardFunction;

using AttrToInputsMap =
    tensorflow::gtl::FlatMap<string,
                             tensorflow::gtl::InlinedVector<InputInfo, 4>>;

tensorflow::gtl::FlatMap<string, AttrToInputsMap*>* GetAllAttrToInputsMaps() {
  static auto* all_attr_to_input_maps =
      new tensorflow::gtl::FlatMap<string, AttrToInputsMap*>;
  return all_attr_to_input_maps;
}

// This function doesn't use a lock, since we depend on the GIL directly.
AttrToInputsMap* GetAttrToInputsMapHoldingGIL(const tensorflow::OpDef& op_def) {
#if PY_MAJOR_VERSION >= 3 && PY_MINOR_VERSION >= 4
  DCHECK(PyGILState_Check())
      << "This function needs to hold the GIL when called.";
#endif
  auto* all_attr_to_input_maps = GetAllAttrToInputsMaps();
  auto* output =
      tensorflow::gtl::FindPtrOrNull(*all_attr_to_input_maps, op_def.name());
  if (output != nullptr) {
    return output;
  }

  std::unique_ptr<AttrToInputsMap> m(new AttrToInputsMap);

  // Store a list of InputIndex -> List of corresponding inputs.
  for (int i = 0; i < op_def.input_arg_size(); i++) {
    if (!op_def.input_arg(i).type_attr().empty()) {
      auto it = m->find(op_def.input_arg(i).type_attr());
      if (it == m->end()) {
        it = m->insert({op_def.input_arg(i).type_attr(), {}}).first;
      }
      it->second.emplace_back(i, !op_def.input_arg(i).number_attr().empty());
    }
  }

  auto* retval = m.get();
  (*all_attr_to_input_maps)[op_def.name()] = m.release();

  return retval;
}

// This function doesn't use a lock, since we depend on the GIL directly.
tensorflow::gtl::FlatMap<
    string, tensorflow::gtl::FlatMap<string, tensorflow::DataType>*>*
GetAllAttrToDefaultsMaps() {
  static auto* all_attr_to_defaults_maps = new tensorflow::gtl::FlatMap<
      string, tensorflow::gtl::FlatMap<string, tensorflow::DataType>*>;
  return all_attr_to_defaults_maps;
}

tensorflow::gtl::FlatMap<string, tensorflow::DataType>*
GetAttrToDefaultsMapHoldingGIL(const tensorflow::OpDef& op_def) {
#if PY_MAJOR_VERSION >= 3 && PY_MINOR_VERSION >= 4
  DCHECK(PyGILState_Check())
      << "This function needs to hold the GIL when called.";
#endif
  auto* all_attr_to_defaults_maps = GetAllAttrToDefaultsMaps();
  auto* output =
      tensorflow::gtl::FindPtrOrNull(*all_attr_to_defaults_maps, op_def.name());
  if (output != nullptr) {
    return output;
  }

  auto* new_map = new tensorflow::gtl::FlatMap<string, tensorflow::DataType>;

  for (const auto& attr : op_def.attr()) {
    if (attr.type() == "type" && attr.has_default_value()) {
      new_map->insert({attr.name(), attr.default_value().type()});
    }
  }

  (*all_attr_to_defaults_maps)[op_def.name()] = new_map;

  return new_map;
}

struct FastPathOpExecInfo {
  TFE_Context* ctx;
  const char* device_name;

  bool run_callbacks;
  bool run_post_exec_callbacks;
  bool run_gradient_callback;

  // The op name of the main op being executed.
  PyObject* name;
  // The op type name of the main op being executed.
  PyObject* op_name;
  PyObject* callbacks;

  // All the args passed into the FastPathOpExecInfo.
  PyObject* args;

  // DTypes can come from another input that has the same attr. So build that
  // map.
  const AttrToInputsMap* attr_to_inputs_map;
  const tensorflow::gtl::FlatMap<string, tensorflow::DataType>* default_dtypes;
  tensorflow::gtl::FlatMap<string, tensorflow::DataType> cached_dtypes;
};

#define PARSE_VALUE(fn_name, type, check_fn, parse_fn)                       \
  bool fn_name(const string& key, PyObject* py_value, TF_Status* status,     \
               type* value) {                                                \
    if (check_fn(py_value)) {                                                \
      *value = static_cast<type>(parse_fn(py_value));                        \
      return true;                                                           \
    } else {                                                                 \
      TF_SetStatus(status, TF_INVALID_ARGUMENT,                              \
                   tensorflow::strings::StrCat(                              \
                       "Expecting " #type " value for attr ", key, ", got ", \
                       py_value->ob_type->tp_name)                           \
                       .c_str());                                            \
      return false;                                                          \
    }                                                                        \
  }

#if PY_MAJOR_VERSION >= 3
PARSE_VALUE(ParseIntValue, int, PyLong_Check, PyLong_AsLong)
PARSE_VALUE(ParseInt64Value, int64_t, PyLong_Check, PyLong_AsLongLong)
#else
PARSE_VALUE(ParseIntValue, int, PyInt_Check, PyInt_AsLong)
#endif
PARSE_VALUE(ParseFloatValue, float, PyFloat_Check, PyFloat_AsDouble)
#undef PARSE_VALUE
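
// For reference (added for clarity, not part of the original source), the
// first PARSE_VALUE instantiation above expands to a plain function roughly
// equivalent to this sketch:
//
//   bool ParseIntValue(const string& key, PyObject* py_value,
//                      TF_Status* status, int* value) {
//     if (PyLong_Check(py_value)) {
//       *value = static_cast<int>(PyLong_AsLong(py_value));
//       return true;
//     }
//     TF_SetStatus(status, TF_INVALID_ARGUMENT, /* "Expecting int ..." */);
//     return false;
//   }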

#if PY_MAJOR_VERSION < 3
bool ParseInt64Value(const string& key, PyObject* py_value, TF_Status* status,
                     int64_t* value) {
  if (PyInt_Check(py_value)) {
    *value = static_cast<int64_t>(PyInt_AsLong(py_value));
    return true;
  } else if (PyLong_Check(py_value)) {
    *value = static_cast<int64_t>(PyLong_AsLong(py_value));
    return true;
  }
  TF_SetStatus(
      status, TF_INVALID_ARGUMENT,
      tensorflow::strings::StrCat("Expecting int or long value for attr ", key,
                                  ", got ", py_value->ob_type->tp_name)
          .c_str());
  return false;
}
#endif

Py_ssize_t TensorShapeNumDims(PyObject* value) {
  const auto size = PySequence_Size(value);
  if (size == -1) {
    // TensorShape.__len__ raises an error when the shape is unknown; that
    // error needs to be cleared here.
    // TODO(nareshmodi): ensure that this is actually a TensorShape.
    PyErr_Clear();
  }
  return size;
}

bool IsInteger(PyObject* py_value) {
#if PY_MAJOR_VERSION >= 3
  return PyLong_Check(py_value);
#else
  return PyInt_Check(py_value) || PyLong_Check(py_value);
#endif
}

// This function considers a Dimension._value of None to be valid, and sets the
// value to be -1 in that case.
bool ParseDimensionValue(const string& key, PyObject* py_value,
                         TF_Status* status, int64_t* value) {
  if (IsInteger(py_value)) {
    return ParseInt64Value(key, py_value, status, value);
  }

  tensorflow::Safe_PyObjectPtr dimension_value(
      PyObject_GetAttrString(py_value, "_value"));
  if (dimension_value == nullptr) {
    PyErr_Clear();
    TF_SetStatus(
        status, TF_INVALID_ARGUMENT,
        tensorflow::strings::StrCat("Expecting a Dimension for attr ", key,
                                    ", got ", py_value->ob_type->tp_name)
            .c_str());
    return false;
  }

  if (dimension_value.get() == Py_None) {
    *value = -1;
    return true;
  }

  return ParseInt64Value(key, dimension_value.get(), status, value);
}
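
// Example behavior (illustrative only, not in the original source):
//   ParseDimensionValue("axis", <Python int 3>, ...)    -> *value == 3
//   ParseDimensionValue("axis", <Dimension(None)>, ...) -> *value == -1
//   ParseDimensionValue("axis", <a str object>, ...)    -> false, status set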

bool ParseStringValue(const string& key, PyObject* py_value, TF_Status* status,
                      tensorflow::StringPiece* value) {
  if (PyBytes_Check(py_value)) {
    Py_ssize_t size = 0;
    char* buf = nullptr;
    if (PyBytes_AsStringAndSize(py_value, &buf, &size) < 0) return false;
    *value = tensorflow::StringPiece(buf, size);
    return true;
  }
#if PY_MAJOR_VERSION >= 3
  if (PyUnicode_Check(py_value)) {
    Py_ssize_t size = 0;
    const char* buf = PyUnicode_AsUTF8AndSize(py_value, &size);
    if (buf == nullptr) return false;
    *value = tensorflow::StringPiece(buf, size);
    return true;
  }
#endif
  TF_SetStatus(
      status, TF_INVALID_ARGUMENT,
      tensorflow::strings::StrCat("Expecting a string value for attr ", key,
                                  ", got ", py_value->ob_type->tp_name)
          .c_str());
  return false;
}

bool ParseBoolValue(const string& key, PyObject* py_value, TF_Status* status,
                    unsigned char* value) {
  *value = PyObject_IsTrue(py_value);
  return true;
}

// The passed in py_value is expected to be an object of the python type
// dtypes.DType or an int.
bool ParseTypeValue(const string& key, PyObject* py_value, TF_Status* status,
                    int* value) {
  if (IsInteger(py_value)) {
    return ParseIntValue(key, py_value, status, value);
  }

  tensorflow::Safe_PyObjectPtr py_type_enum(
      PyObject_GetAttrString(py_value, "_type_enum"));
  if (py_type_enum == nullptr) {
    PyErr_Clear();
    TF_SetStatus(
        status, TF_INVALID_ARGUMENT,
        tensorflow::strings::StrCat("Expecting a DType.dtype for attr ", key,
                                    ", got ", py_value->ob_type->tp_name)
            .c_str());
    return false;
  }

  return ParseIntValue(key, py_type_enum.get(), status, value);
}

bool SetOpAttrList(TFE_Context* ctx, TFE_Op* op, const char* key,
                   PyObject* py_list, TF_AttrType type,
                   tensorflow::gtl::FlatMap<string, int64_t>* attr_list_sizes,
                   TF_Status* status) {
  if (!PySequence_Check(py_list)) {
    TF_SetStatus(
        status, TF_INVALID_ARGUMENT,
        tensorflow::strings::StrCat("Expecting sequence value for attr ", key,
                                    ", got ", py_list->ob_type->tp_name)
            .c_str());
    return false;
  }
  const int num_values = PySequence_Size(py_list);
  if (attr_list_sizes != nullptr) (*attr_list_sizes)[key] = num_values;

#define PARSE_LIST(c_type, parse_fn)                                      \
  std::unique_ptr<c_type[]> values(new c_type[num_values]);               \
  for (int i = 0; i < num_values; ++i) {                                  \
    tensorflow::Safe_PyObjectPtr py_value(PySequence_ITEM(py_list, i));   \
    if (!parse_fn(key, py_value.get(), status, &values[i])) return false; \
  }

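  // Note (added for clarity): PARSE_LIST deliberately has no enclosing
  // braces, so the `values` array it declares stays in scope for the
  // TFE_OpSetAttr*List call that immediately follows each use below.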
  if (type == TF_ATTR_STRING) {
    std::unique_ptr<const void*[]> values(new const void*[num_values]);
    std::unique_ptr<size_t[]> lengths(new size_t[num_values]);
    for (int i = 0; i < num_values; ++i) {
      tensorflow::StringPiece value;
      tensorflow::Safe_PyObjectPtr py_value(PySequence_ITEM(py_list, i));
      if (!ParseStringValue(key, py_value.get(), status, &value)) return false;
      values[i] = value.data();
      lengths[i] = value.size();
    }
    TFE_OpSetAttrStringList(op, key, values.get(), lengths.get(), num_values);
  } else if (type == TF_ATTR_INT) {
    PARSE_LIST(int64_t, ParseInt64Value);
    TFE_OpSetAttrIntList(op, key, values.get(), num_values);
  } else if (type == TF_ATTR_FLOAT) {
    PARSE_LIST(float, ParseFloatValue);
    TFE_OpSetAttrFloatList(op, key, values.get(), num_values);
  } else if (type == TF_ATTR_BOOL) {
    PARSE_LIST(unsigned char, ParseBoolValue);
    TFE_OpSetAttrBoolList(op, key, values.get(), num_values);
  } else if (type == TF_ATTR_TYPE) {
    PARSE_LIST(int, ParseTypeValue);
    TFE_OpSetAttrTypeList(op, key,
                          reinterpret_cast<const TF_DataType*>(values.get()),
                          num_values);
  } else if (type == TF_ATTR_SHAPE) {
    // Make one pass through the input counting the total number of
    // dims across all the input lists.
    int total_dims = 0;
    for (int i = 0; i < num_values; ++i) {
      tensorflow::Safe_PyObjectPtr py_value(PySequence_ITEM(py_list, i));
      if (py_value.get() != Py_None) {
        if (!PySequence_Check(py_value.get())) {
          TF_SetStatus(
              status, TF_INVALID_ARGUMENT,
              tensorflow::strings::StrCat(
                  "Expecting None or sequence value for element", i,
                  " of attr ", key, ", got ", py_value->ob_type->tp_name)
                  .c_str());
          return false;
        }
        const auto size = TensorShapeNumDims(py_value.get());
        if (size >= 0) {
          total_dims += size;
        }
      }
    }
    // Allocate a buffer that can fit all of the dims together.
    std::unique_ptr<int64_t[]> buffer(new int64_t[total_dims]);
    // Copy the input dims into the buffer and set dims to point to
    // the start of each list's dims.
    std::unique_ptr<const int64_t*[]> dims(new const int64_t*[num_values]);
    std::unique_ptr<int[]> num_dims(new int[num_values]);
    int64_t* offset = buffer.get();
    for (int i = 0; i < num_values; ++i) {
      tensorflow::Safe_PyObjectPtr py_value(PySequence_ITEM(py_list, i));
      if (py_value.get() == Py_None) {
        dims[i] = nullptr;
        num_dims[i] = -1;
      } else {
        const auto size = TensorShapeNumDims(py_value.get());
        if (size == -1) {
          dims[i] = nullptr;
          num_dims[i] = -1;
          continue;
        }
        dims[i] = offset;
        num_dims[i] = size;
        for (int j = 0; j < size; ++j) {
          tensorflow::Safe_PyObjectPtr inner_py_value(
              PySequence_ITEM(py_value.get(), j));
          if (inner_py_value.get() == Py_None) {
            *offset = -1;
          } else if (!ParseDimensionValue(key, inner_py_value.get(), status,
                                          offset)) {
            return false;
          }
          ++offset;
        }
      }
    }
    TFE_OpSetAttrShapeList(op, key, dims.get(), num_dims.get(), num_values,
                           status);
    if (!status->status.ok()) return false;
  } else if (type == TF_ATTR_FUNC) {
    std::unique_ptr<const TFE_Op*[]> funcs(new const TFE_Op*[num_values]);
    for (int i = 0; i < num_values; ++i) {
      tensorflow::Safe_PyObjectPtr py_value(PySequence_ITEM(py_list, i));
      // Allow:
      // (1) String function name, OR
      // (2) A Python object with a .name attribute
      //     (A crude test for being a
      //     tensorflow.python.framework.function._DefinedFunction)
      //     (which is what the various "defun" or "Defun" decorators do).
      // And in the future also allow an object that can encapsulate
      // the function name and its attribute values.
      tensorflow::StringPiece func_name;
      if (!ParseStringValue(key, py_value.get(), status, &func_name)) {
        PyObject* name_attr = PyObject_GetAttrString(py_value.get(), "name");
        if (name_attr == nullptr ||
            !ParseStringValue(key, name_attr, status, &func_name)) {
          TF_SetStatus(
              status, TF_INVALID_ARGUMENT,
              tensorflow::strings::StrCat(
                  "unable to set function value attribute from a ",
                  py_value.get()->ob_type->tp_name,
                  " object. If you think this is an error, please file an "
                  "issue at "
                  "https://github.com/tensorflow/tensorflow/issues/new")
                  .c_str());
          return false;
        }
      }
      funcs[i] = TFE_NewOp(ctx, func_name.data(), status);
      if (!status->status.ok()) return false;
    }
    TFE_OpSetAttrFunctionList(op, key, funcs.get(), num_values);
    if (!status->status.ok()) return false;
  } else {
    TF_SetStatus(status, TF_UNIMPLEMENTED,
                 tensorflow::strings::StrCat("Attr ", key,
                                             " has unhandled list type ", type)
                     .c_str());
    return false;
  }
#undef PARSE_LIST
  return true;
}

TFE_Op* GetFunc(TFE_Context* ctx, const tensorflow::NameAttrList& func,
                TF_Status* status) {
  TFE_Op* func_op = TFE_NewOp(ctx, func.name().data(), status);
  for (const auto& attr : func.attr()) {
    if (!status->status.ok()) return nullptr;
    SetOpAttrValueScalar(ctx, func_op, attr.second, attr.first.data(), status);
    if (!status->status.ok()) return nullptr;
  }
  return func_op;
}

void SetOpAttrListDefault(
    TFE_Context* ctx, TFE_Op* op, const tensorflow::OpDef::AttrDef& attr,
    const char* key, TF_AttrType type,
    tensorflow::gtl::FlatMap<string, int64_t>* attr_list_sizes,
    TF_Status* status) {
  if (type == TF_ATTR_STRING) {
    int num_values = attr.default_value().list().s_size();
    std::unique_ptr<const void*[]> values(new const void*[num_values]);
    std::unique_ptr<size_t[]> lengths(new size_t[num_values]);
    (*attr_list_sizes)[key] = num_values;
    for (int i = 0; i < num_values; i++) {
      const string& v = attr.default_value().list().s(i);
      values[i] = v.data();
      lengths[i] = v.size();
    }
    TFE_OpSetAttrStringList(op, key, values.get(), lengths.get(), num_values);
  } else if (type == TF_ATTR_INT) {
    int num_values = attr.default_value().list().i_size();
    std::unique_ptr<int64_t[]> values(new int64_t[num_values]);
    (*attr_list_sizes)[key] = num_values;
    for (int i = 0; i < num_values; i++) {
      values[i] = attr.default_value().list().i(i);
    }
    TFE_OpSetAttrIntList(op, key, values.get(), num_values);
  } else if (type == TF_ATTR_FLOAT) {
    int num_values = attr.default_value().list().f_size();
    std::unique_ptr<float[]> values(new float[num_values]);
    (*attr_list_sizes)[key] = num_values;
    for (int i = 0; i < num_values; i++) {
      values[i] = attr.default_value().list().f(i);
    }
    TFE_OpSetAttrFloatList(op, key, values.get(), num_values);
  } else if (type == TF_ATTR_BOOL) {
    int num_values = attr.default_value().list().b_size();
    std::unique_ptr<unsigned char[]> values(new unsigned char[num_values]);
    (*attr_list_sizes)[key] = num_values;
    for (int i = 0; i < num_values; i++) {
      values[i] = attr.default_value().list().b(i);
    }
    TFE_OpSetAttrBoolList(op, key, values.get(), num_values);
  } else if (type == TF_ATTR_TYPE) {
    int num_values = attr.default_value().list().type_size();
    std::unique_ptr<int[]> values(new int[num_values]);
    (*attr_list_sizes)[key] = num_values;
    for (int i = 0; i < num_values; i++) {
      values[i] = attr.default_value().list().type(i);
    }
    TFE_OpSetAttrTypeList(op, key,
                          reinterpret_cast<const TF_DataType*>(values.get()),
                          attr.default_value().list().type_size());
  } else if (type == TF_ATTR_SHAPE) {
    int num_values = attr.default_value().list().shape_size();
    (*attr_list_sizes)[key] = num_values;
    int total_dims = 0;
    for (int i = 0; i < num_values; ++i) {
      if (!attr.default_value().list().shape(i).unknown_rank()) {
        total_dims += attr.default_value().list().shape(i).dim_size();
      }
    }
    // Allocate a buffer that can fit all of the dims together.
    std::unique_ptr<int64_t[]> buffer(new int64_t[total_dims]);
    // Copy the input dims into the buffer and set dims to point to
    // the start of each list's dims.
    std::unique_ptr<const int64_t*[]> dims(new const int64_t*[num_values]);
    std::unique_ptr<int[]> num_dims(new int[num_values]);
    int64_t* offset = buffer.get();
    for (int i = 0; i < num_values; ++i) {
      const auto& shape = attr.default_value().list().shape(i);
      if (shape.unknown_rank()) {
        dims[i] = nullptr;
        num_dims[i] = -1;
      } else {
        for (int j = 0; j < shape.dim_size(); j++) {
          *offset = shape.dim(j).size();
          ++offset;
        }
      }
    }
    TFE_OpSetAttrShapeList(op, key, dims.get(), num_dims.get(), num_values,
                           status);
  } else if (type == TF_ATTR_FUNC) {
    int num_values = attr.default_value().list().func_size();
    (*attr_list_sizes)[key] = num_values;
    std::unique_ptr<const TFE_Op*[]> funcs(new const TFE_Op*[num_values]);
    for (int i = 0; i < num_values; i++) {
      funcs[i] = GetFunc(ctx, attr.default_value().list().func(i), status);
    }
    TFE_OpSetAttrFunctionList(op, key, funcs.get(), num_values);
  } else {
    TF_SetStatus(status, TF_UNIMPLEMENTED,
                 "Lists of tensors are not yet implemented for default valued "
                 "attributes for an operation.");
  }
}

bool SetOpAttrScalar(TFE_Context* ctx, TFE_Op* op, const char* key,
                     PyObject* py_value, TF_AttrType type,
                     tensorflow::gtl::FlatMap<string, int64_t>* attr_list_sizes,
                     TF_Status* status) {
  if (type == TF_ATTR_STRING) {
    tensorflow::StringPiece value;
    if (!ParseStringValue(key, py_value, status, &value)) return false;
    TFE_OpSetAttrString(op, key, value.data(), value.size());
  } else if (type == TF_ATTR_INT) {
    int64_t value;
    if (!ParseInt64Value(key, py_value, status, &value)) return false;
    TFE_OpSetAttrInt(op, key, value);
    // attr_list_sizes is set for all int attributes (since at this point we
    // are not aware if that attribute might be used to calculate the size of
    // an output list or not).
    if (attr_list_sizes != nullptr) (*attr_list_sizes)[key] = value;
  } else if (type == TF_ATTR_FLOAT) {
    float value;
    if (!ParseFloatValue(key, py_value, status, &value)) return false;
    TFE_OpSetAttrFloat(op, key, value);
  } else if (type == TF_ATTR_BOOL) {
    unsigned char value;
    if (!ParseBoolValue(key, py_value, status, &value)) return false;
    TFE_OpSetAttrBool(op, key, value);
  } else if (type == TF_ATTR_TYPE) {
    int value;
    if (!ParseTypeValue(key, py_value, status, &value)) return false;
    TFE_OpSetAttrType(op, key, static_cast<TF_DataType>(value));
  } else if (type == TF_ATTR_SHAPE) {
    if (py_value == Py_None) {
      TFE_OpSetAttrShape(op, key, nullptr, -1, status);
    } else {
      if (!PySequence_Check(py_value)) {
        TF_SetStatus(status, TF_INVALID_ARGUMENT,
                     tensorflow::strings::StrCat(
                         "Expecting None or sequence value for attr", key,
                         ", got ", py_value->ob_type->tp_name)
                         .c_str());
        return false;
      }
      const auto num_dims = TensorShapeNumDims(py_value);
      if (num_dims == -1) {
        TFE_OpSetAttrShape(op, key, nullptr, -1, status);
        return true;
      }
      std::unique_ptr<int64_t[]> dims(new int64_t[num_dims]);
      for (int i = 0; i < num_dims; ++i) {
        tensorflow::Safe_PyObjectPtr inner_py_value(
            PySequence_ITEM(py_value, i));
        if (inner_py_value.get() == Py_None) {
          dims[i] = -1;
        } else if (!ParseDimensionValue(key, inner_py_value.get(), status,
                                        &dims[i])) {
          return false;
        }
      }
      TFE_OpSetAttrShape(op, key, dims.get(), num_dims, status);
    }
    if (!status->status.ok()) return false;
  } else if (type == TF_ATTR_FUNC) {
    // Allow:
    // (1) String function name, OR
    // (2) A Python object with a .name attribute
    //     (A crude test for being a
    //     tensorflow.python.framework.function._DefinedFunction)
    //     (which is what the various "defun" or "Defun" decorators do).
    // And in the future also allow an object that can encapsulate
    // the function name and its attribute values.
    tensorflow::StringPiece func_name;
    if (!ParseStringValue(key, py_value, status, &func_name)) {
      PyObject* name_attr = PyObject_GetAttrString(py_value, "name");
      if (name_attr == nullptr ||
          !ParseStringValue(key, name_attr, status, &func_name)) {
        TF_SetStatus(
            status, TF_INVALID_ARGUMENT,
            tensorflow::strings::StrCat(
                "unable to set function value attribute from a ",
                py_value->ob_type->tp_name,
                " object. If you think this is an error, please file an issue "
                "at https://github.com/tensorflow/tensorflow/issues/new")
                .c_str());
        return false;
      }
    }
    TF_SetStatus(status, TF_OK, "");
    TFE_OpSetAttrFunctionName(op, key, func_name.data(), func_name.size());
  } else {
    TF_SetStatus(
        status, TF_UNIMPLEMENTED,
        tensorflow::strings::StrCat("Attr ", key, " has unhandled type ", type)
            .c_str());
    return false;
  }
  return true;
}

void SetOpAttrScalarDefault(
    TFE_Context* ctx, TFE_Op* op, const tensorflow::AttrValue& default_value,
    const char* attr_name,
    tensorflow::gtl::FlatMap<string, int64_t>* attr_list_sizes,
    TF_Status* status) {
  SetOpAttrValueScalar(ctx, op, default_value, attr_name, status);
  if (default_value.value_case() == tensorflow::AttrValue::kI) {
    (*attr_list_sizes)[attr_name] = default_value.i();
  }
}

// start_index is the index at which the Tuple/List attrs will start getting
// processed.
void SetOpAttrs(TFE_Context* ctx, TFE_Op* op, PyObject* attrs, int start_index,
                TF_Status* out_status) {
  if (attrs == Py_None) return;
  Py_ssize_t len = PyTuple_GET_SIZE(attrs) - start_index;
  if ((len & 1) != 0) {
    TF_SetStatus(out_status, TF_INVALID_ARGUMENT,
                 "Expecting attrs tuple to have even length.");
    return;
  }
  // Parse attrs
  for (Py_ssize_t i = 0; i < len; i += 2) {
    PyObject* py_key = PyTuple_GET_ITEM(attrs, start_index + i);
    PyObject* py_value = PyTuple_GET_ITEM(attrs, start_index + i + 1);
#if PY_MAJOR_VERSION >= 3
    const char* key = PyBytes_Check(py_key) ? PyBytes_AsString(py_key)
                                            : PyUnicode_AsUTF8(py_key);
#else
    const char* key = PyBytes_AsString(py_key);
#endif
    unsigned char is_list = 0;
    const TF_AttrType type = TFE_OpGetAttrType(op, key, &is_list, out_status);
    if (!out_status->status.ok()) return;
    if (is_list != 0) {
      if (!SetOpAttrList(ctx, op, key, py_value, type, nullptr, out_status))
        return;
    } else {
      if (!SetOpAttrScalar(ctx, op, key, py_value, type, nullptr, out_status))
        return;
    }
  }
}
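
// Illustrative attrs layout for SetOpAttrs (a sketch added for clarity, not
// in the original source; the attr names and values are hypothetical): with
// start_index == 0, attrs is expected to be a flat key/value tuple such as
//
//   attrs = ("transpose_a", True, "T", <dtype enum int>)
//
// i.e. attrs[i] is an attr name and attrs[i + 1] is its value, which is why
// the length past start_index must be even.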

// This function will set the op attrs required. If an attr has the value of
// None, then it will read the AttrDef to get the default value and set that
// instead. Any failure in this function will simply fall back to the slow
// path.
void SetOpAttrWithDefaults(
    TFE_Context* ctx, TFE_Op* op, const tensorflow::OpDef::AttrDef& attr,
    const char* attr_name, PyObject* attr_value,
    tensorflow::gtl::FlatMap<string, int64_t>* attr_list_sizes,
    TF_Status* status) {
  unsigned char is_list = 0;
  const TF_AttrType type = TFE_OpGetAttrType(op, attr_name, &is_list, status);
  if (!status->status.ok()) return;
  if (attr_value == Py_None) {
    if (is_list != 0) {
      SetOpAttrListDefault(ctx, op, attr, attr_name, type, attr_list_sizes,
                           status);
    } else {
      SetOpAttrScalarDefault(ctx, op, attr.default_value(), attr_name,
                             attr_list_sizes, status);
    }
  } else {
    if (is_list != 0) {
      SetOpAttrList(ctx, op, attr_name, attr_value, type, attr_list_sizes,
                    status);
    } else {
      SetOpAttrScalar(ctx, op, attr_name, attr_value, type, attr_list_sizes,
                      status);
    }
  }
}

PyObject* GetPythonObjectFromInt(int num) {
#if PY_MAJOR_VERSION >= 3
  return PyLong_FromLong(num);
#else
  return PyInt_FromLong(num);
#endif
}

// Python subclass of Exception that is raised on a non-OK Status.
tensorflow::mutex exception_class_mutex(tensorflow::LINKER_INITIALIZED);
PyObject* exception_class TF_GUARDED_BY(exception_class_mutex) = nullptr;

// Python subclass of Exception that is created to signal fallback.
PyObject* fallback_exception_class = nullptr;

// Python function that returns input gradients given output gradients.
PyObject* gradient_function = nullptr;

// Python function that returns output gradients given input gradients.
PyObject* forward_gradient_function = nullptr;

static std::atomic<int64_t> _uid;

}  // namespace

TF_Status* GetStatus() {
  TF_Status* maybe_status = ReleaseThreadLocalStatus();
  if (maybe_status) {
    TF_SetStatus(maybe_status, TF_OK, "");
    return maybe_status;
  } else {
    return TF_NewStatus();
  }
}

void ReturnStatus(TF_Status* status) {
  TF_SetStatus(status, TF_OK, "");
  thread_local_tf_status.reset(status);
}
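
// Illustrative round trip for the thread-local status pool (a sketch added
// for clarity, not part of the original source):
//
//   TF_Status* status = GetStatus();  // reuses a pooled status if available
//   // ... run an eager call that may set *status ...
//   ReturnStatus(status);             // resets to TF_OK and parks it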

void TFE_Py_Execute(TFE_Context* ctx, const char* device_name,
                    const char* op_name, TFE_InputTensorHandles* inputs,
                    PyObject* attrs, TFE_OutputTensorHandles* outputs,
                    TF_Status* out_status) {
  TFE_Py_ExecuteCancelable(ctx, device_name, op_name, inputs, attrs,
                           /*cancellation_manager=*/nullptr, outputs,
                           out_status);
}

void TFE_Py_ExecuteCancelable(TFE_Context* ctx, const char* device_name,
                              const char* op_name,
                              TFE_InputTensorHandles* inputs, PyObject* attrs,
                              TFE_CancellationManager* cancellation_manager,
                              TFE_OutputTensorHandles* outputs,
                              TF_Status* out_status) {
  tensorflow::profiler::TraceMe activity(
      "TFE_Py_ExecuteCancelable", tensorflow::profiler::TraceMeLevel::kInfo);

  TFE_Op* op = GetOp(ctx, op_name, device_name, out_status);

  auto cleaner = tensorflow::gtl::MakeCleanup([ctx, op] { ReturnOp(ctx, op); });
  if (!out_status->status.ok()) return;

  tensorflow::unwrap(op)->SetStackTrace(tensorflow::GetStackTrace(
      tensorflow::StackTrace::kStackTraceInitialSize));

  for (int i = 0; i < inputs->size() && out_status->status.ok(); ++i) {
    TFE_OpAddInput(op, inputs->at(i), out_status);
  }
  if (cancellation_manager && out_status->status.ok()) {
    TFE_OpSetCancellationManager(op, cancellation_manager, out_status);
  }
  if (out_status->status.ok()) {
    SetOpAttrs(ctx, op, attrs, 0, out_status);
  }
  Py_BEGIN_ALLOW_THREADS;

  int num_outputs = outputs->size();

  if (out_status->status.ok()) {
    TFE_Execute(op, outputs->data(), &num_outputs, out_status);
  }

  if (out_status->status.ok()) {
    outputs->resize(num_outputs);
  } else {
    TF_SetStatus(out_status, TF_GetCode(out_status),
                 tensorflow::strings::StrCat(TF_Message(out_status),
                                             " [Op:", op_name, "]")
                     .c_str());
  }

  Py_END_ALLOW_THREADS;
}
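
// Note (added for clarity): the Py_BEGIN_ALLOW_THREADS / Py_END_ALLOW_THREADS
// pair above releases the GIL around TFE_Execute and the C++-only output
// handling, which touch no Python objects; the input/attr setup before the
// block still runs with the GIL held.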

PyObject* TFE_Py_RegisterExceptionClass(PyObject* e) {
  tensorflow::mutex_lock l(exception_class_mutex);
  if (exception_class != nullptr) {
    Py_DECREF(exception_class);
  }
  if (PyObject_IsSubclass(e, PyExc_Exception) <= 0) {
    exception_class = nullptr;
    PyErr_SetString(PyExc_TypeError,
                    "TFE_Py_RegisterExceptionClass: "
                    "Registered class should be subclass of Exception.");
    return nullptr;
  }

  Py_INCREF(e);
  exception_class = e;
  Py_RETURN_NONE;
}

PyObject* TFE_Py_RegisterFallbackExceptionClass(PyObject* e) {
  if (fallback_exception_class != nullptr) {
    Py_DECREF(fallback_exception_class);
  }
  if (PyObject_IsSubclass(e, PyExc_Exception) <= 0) {
    fallback_exception_class = nullptr;
    PyErr_SetString(PyExc_TypeError,
                    "TFE_Py_RegisterFallbackExceptionClass: "
                    "Registered class should be subclass of Exception.");
    return nullptr;
  } else {
    Py_INCREF(e);
    fallback_exception_class = e;
    Py_RETURN_NONE;
  }
}

PyObject* TFE_Py_RegisterGradientFunction(PyObject* e) {
  if (gradient_function != nullptr) {
    Py_DECREF(gradient_function);
  }
  if (!PyCallable_Check(e)) {
    gradient_function = nullptr;
    PyErr_SetString(PyExc_TypeError,
                    "TFE_Py_RegisterGradientFunction: "
                    "Registered object should be function.");
    return nullptr;
  } else {
    Py_INCREF(e);
    gradient_function = e;
    Py_RETURN_NONE;
  }
}

PyObject* TFE_Py_RegisterJVPFunction(PyObject* e) {
  if (forward_gradient_function != nullptr) {
    Py_DECREF(forward_gradient_function);
  }
  if (!PyCallable_Check(e)) {
    forward_gradient_function = nullptr;
    PyErr_SetString(PyExc_TypeError,
                    "TFE_Py_RegisterJVPFunction: "
                    "Registered object should be function.");
    return nullptr;
  } else {
    Py_INCREF(e);
    forward_gradient_function = e;
    Py_RETURN_NONE;
  }
}

void RaiseFallbackException(const char* message) {
  if (fallback_exception_class != nullptr) {
    PyErr_SetString(fallback_exception_class, message);
    return;
  }

  PyErr_SetString(
      PyExc_RuntimeError,
      tensorflow::strings::StrCat(
          "Fallback exception type not set, attempting to fallback due to ",
          message)
          .data());
}

// Format and return `status`' error message with the attached stack trace if
// available. `status` must have an error.
std::string FormatErrorStatusStackTrace(const tensorflow::Status& status) {
  tensorflow::DCheckPyGilState();
  DCHECK(!status.ok());

  if (status.stack_trace().empty()) return status.error_message();

  const std::vector<tensorflow::StackFrame>& stack_trace = status.stack_trace();

  PyObject* linecache = PyImport_ImportModule("linecache");
  PyObject* getline =
      PyObject_GetAttr(linecache, PyUnicode_FromString("getline"));
  DCHECK(getline);

  std::ostringstream result;
  result << "Exception originated from\n\n";

  for (const tensorflow::StackFrame& stack_frame : stack_trace) {
    PyObject* line_str_obj = PyObject_CallFunction(
        getline, const_cast<char*>("si"), stack_frame.file_name.c_str(),
        stack_frame.line_number);
    tensorflow::StringPiece line_str = TFE_GetPythonString(line_str_obj);
    tensorflow::str_util::RemoveWhitespaceContext(&line_str);
    result << "  File \"" << stack_frame.file_name << "\", line "
           << stack_frame.line_number << ", in " << stack_frame.function_name
           << '\n';

    if (!line_str.empty()) result << "    " << line_str << '\n';
    Py_XDECREF(line_str_obj);
  }

  Py_DecRef(getline);
  Py_DecRef(linecache);

  result << '\n' << status.error_message();
  return result.str();
}

int MaybeRaiseExceptionFromTFStatus(TF_Status* status, PyObject* exception) {
  if (status->status.ok()) return 0;
  const char* msg = TF_Message(status);
  if (exception == nullptr) {
    tensorflow::mutex_lock l(exception_class_mutex);
    if (exception_class != nullptr) {
      tensorflow::Safe_PyObjectPtr payloads(PyDict_New());
      for (const auto& payload :
           tensorflow::errors::GetPayloads(status->status)) {
        PyDict_SetItem(payloads.get(),
                       PyBytes_FromString(payload.first.c_str()),
                       PyBytes_FromString(payload.second.c_str()));
      }
      tensorflow::Safe_PyObjectPtr val(Py_BuildValue(
          "siO", FormatErrorStatusStackTrace(status->status).c_str(),
          TF_GetCode(status), payloads.get()));
      if (PyErr_Occurred()) {
        // NOTE: This hides the actual error (i.e. the reason `status` was not
        // TF_OK), but there is nothing we can do at this point since we can't
        // generate a reasonable error from the status.
        // Consider adding a message explaining this.
        return -1;
      }
      PyErr_SetObject(exception_class, val.get());
      return -1;
    } else {
      exception = PyExc_RuntimeError;
    }
  }
  // Maybe update an already-set exception.
  PyErr_SetString(exception, msg);
  return -1;
}

int MaybeRaiseExceptionFromStatus(const tensorflow::Status& status,
                                  PyObject* exception) {
  if (status.ok()) return 0;
  const char* msg = status.error_message().c_str();
  if (exception == nullptr) {
    tensorflow::mutex_lock l(exception_class_mutex);
    if (exception_class != nullptr) {
      tensorflow::Safe_PyObjectPtr payloads(PyDict_New());
      for (const auto& element : tensorflow::errors::GetPayloads(status)) {
        PyDict_SetItem(payloads.get(),
                       PyBytes_FromString(element.first.c_str()),
                       PyBytes_FromString(element.second.c_str()));
      }
      tensorflow::Safe_PyObjectPtr val(
          Py_BuildValue("siO", FormatErrorStatusStackTrace(status).c_str(),
                        status.code(), payloads.get()));
      PyErr_SetObject(exception_class, val.get());
      return -1;
    } else {
      exception = PyExc_RuntimeError;
    }
  }
  // Maybe update an already-set exception.
  PyErr_SetString(exception, msg);
  return -1;
}

const char* TFE_GetPythonString(PyObject* o) {
#if PY_MAJOR_VERSION >= 3
  if (PyBytes_Check(o)) {
    return PyBytes_AsString(o);
  } else {
    return PyUnicode_AsUTF8(o);
  }
#else
  return PyBytes_AsString(o);
#endif
}

int64_t get_uid() { return _uid++; }

PyObject* TFE_Py_UID() { return PyLong_FromLongLong(get_uid()); }

void TFE_DeleteContextCapsule(PyObject* context) {
  TFE_Context* ctx =
      reinterpret_cast<TFE_Context*>(PyCapsule_GetPointer(context, nullptr));
  auto op = ReleaseThreadLocalOp(ctx);
  op.reset();
  TFE_DeleteContext(ctx);
}

static int64_t MakeInt(PyObject* integer) {
#if PY_MAJOR_VERSION >= 3
  return PyLong_AsLong(integer);
#else
  return PyInt_AsLong(integer);
#endif
}

static int64_t FastTensorId(PyObject* tensor) {
  if (EagerTensor_CheckExact(tensor)) {
    return PyEagerTensor_ID(tensor);
  }
  PyObject* id_field = PyObject_GetAttrString(tensor, "_id");
  if (id_field == nullptr) {
    return -1;
  }
  int64_t id = MakeInt(id_field);
  Py_DECREF(id_field);
  return id;
}

namespace tensorflow {
DataType PyTensor_DataType(PyObject* tensor) {
  if (EagerTensor_CheckExact(tensor)) {
    return PyEagerTensor_Dtype(tensor);
  } else {
#if PY_MAJOR_VERSION < 3
    // Python 2.x:
    static PyObject* dtype_attr = PyString_InternFromString("dtype");
    static PyObject* type_enum_attr = PyString_InternFromString("_type_enum");
#else
    // Python 3.x:
    static PyObject* dtype_attr = PyUnicode_InternFromString("dtype");
    static PyObject* type_enum_attr = PyUnicode_InternFromString("_type_enum");
#endif
    Safe_PyObjectPtr dtype_field(PyObject_GetAttr(tensor, dtype_attr));
    if (!dtype_field) {
      return DT_INVALID;
    }

    Safe_PyObjectPtr enum_field(
        PyObject_GetAttr(dtype_field.get(), type_enum_attr));
    if (!enum_field) {
      return DT_INVALID;
    }

    return static_cast<DataType>(MakeInt(enum_field.get()));
  }
}
}  // namespace tensorflow
| 1157 | class PyTapeTensor { | ||||
| 1158 | public: | ||||
| 1159 | PyTapeTensor(int64_t id, tensorflow::DataType dtype, | ||||
| 1160 | const tensorflow::TensorShape& shape) | ||||
| 1161 | : id_(id), dtype_(dtype), shape_(shape) {} | ||||
| 1162 | PyTapeTensor(int64_t id, tensorflow::DataType dtype, PyObject* shape) | ||||
| 1163 | : id_(id), dtype_(dtype), shape_(shape) { | ||||
| 1164 | Py_INCREF(absl::get<1>(shape_))_Py_INCREF(((PyObject*)(absl::get<1>(shape_)))); | ||||
| 1165 | } | ||||
| 1166 | PyTapeTensor(const PyTapeTensor& other) { | ||||
| 1167 | id_ = other.id_; | ||||
| 1168 | dtype_ = other.dtype_; | ||||
| 1169 | shape_ = other.shape_; | ||||
| 1170 | if (shape_.index() == 1) { | ||||
| 1171 | Py_INCREF(absl::get<1>(shape_))_Py_INCREF(((PyObject*)(absl::get<1>(shape_)))); | ||||
| 1172 | } | ||||
| 1173 | } | ||||
| 1174 | |||||
| 1175 | ~PyTapeTensor() { | ||||
| 1176 | if (shape_.index() == 1) { | ||||
| 1177 | Py_DECREF(absl::get<1>(shape_))_Py_DECREF(((PyObject*)(absl::get<1>(shape_)))); | ||||
| 1178 | } | ||||
| 1179 | } | ||||
| 1180 | PyObject* GetShape() const; | ||||
| 1181 | PyObject* GetPyDType() const { return PyLong_FromLong(dtype_); } | ||||
| 1182 | int64_t GetID() const { return id_; } | ||||
| 1183 | tensorflow::DataType GetDType() const { return dtype_; } | ||||
| 1184 | |||||
| 1185 | PyObject* OnesLike() const; | ||||
| 1186 | PyObject* ZerosLike() const; | ||||
| 1187 | |||||
| 1188 | private: | ||||
| 1189 | int64_t id_; | ||||
| 1190 | tensorflow::DataType dtype_; | ||||
| 1191 | |||||
| 1192 | // Note that if shape_.index() == 1, meaning shape_ contains a PyObject, that | ||||
| 1193 | // PyObject is the tensor itself. This is used to support tf.shape(tensor) for | ||||
| 1194 | // partially-defined shapes and tf.zeros_like(tensor) for variant-dtype | ||||
| 1195 | // tensors. | ||||
| 1196 | absl::variant<tensorflow::TensorShape, PyObject*> shape_; | ||||
| 1197 | }; | ||||
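
// Illustrative sketch (not part of the original file) of the two shape_
// states above: a fully-known shape stores a TensorShape by value, while a
// tensor whose gradient needs the tensor itself (partially-defined shape or
// variant dtype) stores the tensor and holds a reference to it:
//
//   PyTapeTensor dense(id, tensorflow::DT_FLOAT,
//                      tensorflow::TensorShape({2, 3}));   // shape_.index() == 0
//   PyTapeTensor lazy(id, tensorflow::DT_VARIANT, tensor); // shape_.index() == 1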

static PyTapeTensor TapeTensorFromTensor(PyObject* tensor);

class PyVSpace : public tensorflow::eager::VSpace<PyObject, PyBackwardFunction,
                                                  PyTapeTensor> {
 public:
  explicit PyVSpace(PyObject* py_vspace) : py_vspace_(py_vspace) {
    Py_INCREF(py_vspace_);
  }

  tensorflow::Status Initialize() {
    num_elements_ = PyObject_GetAttrString(py_vspace_, "num_elements_fn");
    if (num_elements_ == nullptr) {
      return tensorflow::errors::InvalidArgument("invalid vspace");
    }
    aggregate_fn_ = PyObject_GetAttrString(py_vspace_, "aggregate_fn");
    if (aggregate_fn_ == nullptr) {
      return tensorflow::errors::InvalidArgument("invalid vspace");
    }
    zeros_fn_ = PyObject_GetAttrString(py_vspace_, "zeros_fn");
    if (zeros_fn_ == nullptr) {
      return tensorflow::errors::InvalidArgument("invalid vspace");
    }
    zeros_like_fn_ = PyObject_GetAttrString(py_vspace_, "zeros_like_fn");
    if (zeros_like_fn_ == nullptr) {
      return tensorflow::errors::InvalidArgument("invalid vspace");
    }
    ones_fn_ = PyObject_GetAttrString(py_vspace_, "ones_fn");
    if (ones_fn_ == nullptr) {
      return tensorflow::errors::InvalidArgument("invalid vspace");
    }
    ones_like_fn_ = PyObject_GetAttrString(py_vspace_, "ones_like_fn");
    if (ones_like_fn_ == nullptr) {
      return tensorflow::errors::InvalidArgument("invalid vspace");
    }
    graph_shape_fn_ = PyObject_GetAttrString(py_vspace_, "graph_shape_fn");
    if (graph_shape_fn_ == nullptr) {
      return tensorflow::errors::InvalidArgument("invalid vspace");
    }
    return tensorflow::Status::OK();
  }
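
  // For reference (illustrative, not from this file): Initialize() expects
  // the Python-side vspace object, registered via TFE_Py_RegisterVSpace, to
  // expose callables under exactly these attribute names:
  //
  //   vspace.num_elements_fn, vspace.aggregate_fn, vspace.zeros_fn,
  //   vspace.zeros_like_fn, vspace.ones_fn, vspace.ones_like_fn,
  //   vspace.graph_shape_fn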

  ~PyVSpace() override {
    Py_XDECREF(num_elements_);
    Py_XDECREF(aggregate_fn_);
    Py_XDECREF(zeros_fn_);
    Py_XDECREF(zeros_like_fn_);
    Py_XDECREF(ones_fn_);
    Py_XDECREF(ones_like_fn_);
    Py_XDECREF(graph_shape_fn_);

    Py_DECREF(py_vspace_);
  }

  int64_t NumElements(PyObject* tensor) const final {
    if (EagerTensor_CheckExact(tensor)) {
      return PyEagerTensor_NumElements(tensor);
    }
    PyObject* arglist =
        Py_BuildValue("(O)", reinterpret_cast<PyObject*>(tensor));
    PyObject* result = PyEval_CallObject(num_elements_, arglist);
    Py_DECREF(arglist);
    if (result == nullptr) {
      // The caller detects whether a python exception has been raised.
      return -1;
    }
    int64_t r = MakeInt(result);
    Py_DECREF(result);
    return r;
  }

  PyObject* AggregateGradients(
      tensorflow::gtl::ArraySlice<PyObject*> gradient_tensors) const final {
    PyObject* list = PyList_New(gradient_tensors.size());
    for (int i = 0; i < gradient_tensors.size(); ++i) {
      // Note: stealing a reference to the gradient tensors.
      CHECK(gradient_tensors[i] != nullptr);
      CHECK(gradient_tensors[i] != Py_None);
      PyList_SET_ITEM(list, i,
                      reinterpret_cast<PyObject*>(gradient_tensors[i]));
    }
    PyObject* arglist = Py_BuildValue("(O)", list);
    CHECK(arglist != nullptr);
    PyObject* result = PyEval_CallObject(aggregate_fn_, arglist);
    Py_DECREF(arglist);
    Py_DECREF(list);
    return result;
  }

  int64_t TensorId(PyObject* tensor) const final {
    return FastTensorId(tensor);
  }

  void MarkAsResult(PyObject* gradient) const final { Py_INCREF(gradient); }

  PyObject* Ones(PyObject* shape, PyObject* dtype) const {
    if (PyErr_Occurred()) {
      return nullptr;
    }
    PyObject* arg_list = Py_BuildValue("OO", shape, dtype);
    PyObject* result = PyEval_CallObject(ones_fn_, arg_list);
    Py_DECREF(arg_list);
    return result;
  }

  PyObject* OnesLike(PyObject* tensor) const {
    if (PyErr_Occurred()) {
      return nullptr;
    }
    return PyObject_CallFunctionObjArgs(ones_like_fn_, tensor, NULL);
  }

  // Builds a tensor filled with ones with the same shape and dtype as `t`.
  Status BuildOnesLike(const PyTapeTensor& t,
                       PyObject** result) const override {
    *result = t.OnesLike();
    return Status::OK();
  }

  PyObject* Zeros(PyObject* shape, PyObject* dtype) const {
    if (PyErr_Occurred()) {
      return nullptr;
    }
    PyObject* arg_list = Py_BuildValue("OO", shape, dtype);
    PyObject* result = PyEval_CallObject(zeros_fn_, arg_list);
    Py_DECREF(arg_list);
    return result;
  }

  PyObject* ZerosLike(PyObject* tensor) const {
    if (PyErr_Occurred()) {
      return nullptr;
    }
    return PyObject_CallFunctionObjArgs(zeros_like_fn_, tensor, NULL);
  }

  PyObject* GraphShape(PyObject* tensor) const {
    PyObject* arg_list = Py_BuildValue("(O)", tensor);
    PyObject* result = PyEval_CallObject(graph_shape_fn_, arg_list);
    Py_DECREF(arg_list);
    return result;
  }

  tensorflow::Status CallBackwardFunction(
      const string& op_type, PyBackwardFunction* backward_function,
      const std::vector<int64_t>& unneeded_gradients,
      tensorflow::gtl::ArraySlice<PyObject*> output_gradients,
      absl::Span<PyObject*> result) const final {
    PyObject* grads = PyTuple_New(output_gradients.size());
    for (int i = 0; i < output_gradients.size(); ++i) {
      if (output_gradients[i] == nullptr) {
        Py_INCREF(Py_None);
        PyTuple_SET_ITEM(grads, i, Py_None);
      } else {
        PyTuple_SET_ITEM(grads, i,
                         reinterpret_cast<PyObject*>(output_gradients[i]));
      }
    }
    PyObject* py_result = (*backward_function)(grads, unneeded_gradients);
    Py_DECREF(grads);
    if (py_result == nullptr) {
      return tensorflow::errors::Internal("gradient function threw exceptions");
    }
    PyObject* seq =
        PySequence_Fast(py_result, "expected a sequence of gradients");
    if (seq == nullptr) {
      return tensorflow::errors::InvalidArgument(
          "gradient function did not return a list");
    }
    int len = PySequence_Fast_GET_SIZE(seq);
    if (len != result.size()) {
      return tensorflow::errors::Internal(
          "Recorded operation '", op_type,
          "' returned too few gradients. Expected ", result.size(),
          " but received ", len);
    }
    PyObject** seq_array = PySequence_Fast_ITEMS(seq);
    VLOG(1) << "Gradient length is " << len;
    for (int i = 0; i < len; ++i) {
      PyObject* item = seq_array[i];
      if (item == Py_None) {
        result[i] = nullptr;
      } else {
        Py_INCREF(item);
        result[i] = item;
      }
    }
    Py_DECREF(seq);
    Py_DECREF(py_result);
    return tensorflow::Status::OK();
  }
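
  // Sketch of the contract enforced above (illustrative, Python side): the
  // backward function receives the tuple of output gradients (None where a
  // gradient is missing) plus the indices of unneeded gradients, and must
  // return a sequence with exactly `result.size()` entries; a None entry is
  // mapped to a nullptr result slot. Roughly:
  //
  //   def backward_function(output_grads, unneeded):
  //     return [None if i in unneeded else compute_grad(i, output_grads)
  //             for i in range(num_inputs)]
  //
  // (`compute_grad` and `num_inputs` are hypothetical names.)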

  void DeleteGradient(PyObject* tensor) const final { Py_XDECREF(tensor); }

  PyTapeTensor TapeTensorFromGradient(PyObject* tensor) const final {
    return TapeTensorFromTensor(tensor);
  }

 private:
  PyObject* py_vspace_;

  PyObject* num_elements_;
  PyObject* aggregate_fn_;
  PyObject* zeros_fn_;
  PyObject* zeros_like_fn_;
  PyObject* ones_fn_;
  PyObject* ones_like_fn_;
  PyObject* graph_shape_fn_;
};
PyVSpace* py_vspace = nullptr;

bool HasAccumulator();

PyObject* TFE_Py_RegisterVSpace(PyObject* e) {
  if (py_vspace != nullptr) {
    if (HasAccumulator()) {
      // Accumulators reference py_vspace, so we can't swap it out while one is
      // active. This is unlikely to ever happen.
      MaybeRaiseExceptionFromStatus(
          tensorflow::errors::Internal(
              "Can't change the vspace implementation while a "
              "forward accumulator is active."),
          nullptr);
    }
    delete py_vspace;
  }

  py_vspace = new PyVSpace(e);
  auto status = py_vspace->Initialize();
  if (MaybeRaiseExceptionFromStatus(status, nullptr)) {
    delete py_vspace;
    return nullptr;
  }

  Py_RETURN_NONE;
}

PyObject* PyTapeTensor::GetShape() const {
  if (shape_.index() == 0) {
    auto& shape = absl::get<0>(shape_);
    PyObject* py_shape = PyTuple_New(shape.dims());
    for (int i = 0; i < shape.dims(); ++i) {
      PyTuple_SET_ITEM(py_shape, i, PyLong_FromLong(shape.dim_size(i)));
    }

    return py_shape;
  }

  return py_vspace->GraphShape(absl::get<1>(shape_));
}

PyObject* PyTapeTensor::OnesLike() const {
  if (shape_.index() == 1) {
    PyObject* tensor = absl::get<1>(shape_);
    return py_vspace->OnesLike(tensor);
  }
  PyObject* py_shape = GetShape();
  PyObject* dtype_field = GetPyDType();
  PyObject* result = py_vspace->Ones(py_shape, dtype_field);
  Py_DECREF(dtype_field);
  Py_DECREF(py_shape);
  return result;
}

PyObject* PyTapeTensor::ZerosLike() const {
  if (GetDType() == tensorflow::DT_RESOURCE) {
    // Gradient functions for ops which return resource tensors accept
    // None. This is the behavior of py_vspace->Zeros, but checking here avoids
    // issues with ZerosLike.
    Py_RETURN_NONE;
  }
  if (shape_.index() == 1) {
    PyObject* tensor = absl::get<1>(shape_);
    return py_vspace->ZerosLike(tensor);
  }
  PyObject* py_shape = GetShape();
  PyObject* dtype_field = GetPyDType();
  PyObject* result = py_vspace->Zeros(py_shape, dtype_field);
  Py_DECREF(dtype_field);
  Py_DECREF(py_shape);
  return result;
}

// Keeps track of all variables that have been accessed during execution.
class VariableWatcher {
 public:
  VariableWatcher() {}

  ~VariableWatcher() {
    for (const IdAndVariable& v : watched_variables_) {
      Py_DECREF(v.variable);
    }
  }

  int64_t WatchVariable(PyObject* v) {
    tensorflow::Safe_PyObjectPtr handle(PyObject_GetAttrString(v, "handle"));
    if (handle == nullptr) {
      return -1;
    }
    int64_t id = FastTensorId(handle.get());

    tensorflow::mutex_lock l(watched_variables_mu_);
    auto insert_result = watched_variables_.emplace(id, v);

    if (insert_result.second) {
      // Only increment the reference count if we aren't already watching this
      // variable.
      Py_INCREF(v);
    }

    return id;
  }

  PyObject* GetVariablesAsPyTuple() {
    tensorflow::mutex_lock l(watched_variables_mu_);
    PyObject* result = PyTuple_New(watched_variables_.size());
    Py_ssize_t pos = 0;
    for (const IdAndVariable& id_and_variable : watched_variables_) {
      PyTuple_SET_ITEM(result, pos++, id_and_variable.variable);
      Py_INCREF(id_and_variable.variable);
    }
    return result;
  }

 private:
  // We store an IdAndVariable in the set since the set needs to be locked
  // during insert, but should not call back into python during insert to
  // avoid deadlocking with the GIL.
  struct IdAndVariable {
    int64_t id;
    PyObject* variable;

    IdAndVariable(int64_t id, PyObject* variable)
        : id(id), variable(variable) {}
  };
  struct CompareById {
    bool operator()(const IdAndVariable& lhs, const IdAndVariable& rhs) const {
      return lhs.id < rhs.id;
    }
  };

  tensorflow::mutex watched_variables_mu_;
  std::set<IdAndVariable, CompareById> watched_variables_
      TF_GUARDED_BY(watched_variables_mu_);
};
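
// Illustrative use (not part of the original file): WatchVariable is
// idempotent per variable id, so repeated accesses of the same variable hold
// exactly one reference:
//
//   VariableWatcher watcher;
//   watcher.WatchVariable(v);  // takes a new reference to v
//   watcher.WatchVariable(v);  // already watched: no extra reference
//   PyObject* vars = watcher.GetVariablesAsPyTuple();  // new reference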

class GradientTape
    : public tensorflow::eager::GradientTape<PyObject, PyBackwardFunction,
                                             PyTapeTensor> {
 public:
  explicit GradientTape(bool persistent, bool watch_accessed_variables)
      : tensorflow::eager::GradientTape<PyObject, PyBackwardFunction,
                                        PyTapeTensor>(persistent),
        watch_accessed_variables_(watch_accessed_variables) {}

  virtual ~GradientTape() {}

  void VariableAccessed(PyObject* v) {
    if (watch_accessed_variables_) {
      WatchVariable(v);
    }
  }

  void WatchVariable(PyObject* v) {
    int64_t id = variable_watcher_.WatchVariable(v);

    if (!PyErr_Occurred()) {
      this->Watch(id);
    }
  }

  PyObject* GetVariablesAsPyTuple() {
    return variable_watcher_.GetVariablesAsPyTuple();
  }

 private:
  bool watch_accessed_variables_;
  VariableWatcher variable_watcher_;
};

typedef tensorflow::eager::ForwardAccumulator<PyObject, PyBackwardFunction,
                                              PyTapeTensor>
    ForwardAccumulator;

// Incremented when a GradientTape or accumulator is newly added to a set, and
// used to enforce an ordering between them.
std::atomic_uint_fast64_t tape_nesting_id_counter(0);

typedef struct {
  PyObject_HEAD
  /* Type-specific fields go here. */
  GradientTape* tape;
  // A nesting order between GradientTapes and ForwardAccumulators, used to
  // ensure that GradientTapes do not watch the products of outer
  // ForwardAccumulators.
  int64_t nesting_id;
} TFE_Py_Tape;

static void TFE_Py_Tape_Delete(PyObject* tape) {
  delete reinterpret_cast<TFE_Py_Tape*>(tape)->tape;
  Py_TYPE(tape)->tp_free(tape);
}

static PyTypeObject TFE_Py_Tape_Type = {
    PyVarObject_HEAD_INIT(nullptr, 0) "tfe.Tape", /* tp_name */
    sizeof(TFE_Py_Tape),                          /* tp_basicsize */
    0,                                            /* tp_itemsize */
    &TFE_Py_Tape_Delete,                          /* tp_dealloc */
#if PY_VERSION_HEX < 0x03080000
    nullptr, /* tp_print */
#else
    0, /* tp_vectorcall_offset */
#endif
    nullptr,               /* tp_getattr */
    nullptr,               /* tp_setattr */
    nullptr,               /* tp_reserved */
    nullptr,               /* tp_repr */
    nullptr,               /* tp_as_number */
    nullptr,               /* tp_as_sequence */
    nullptr,               /* tp_as_mapping */
    nullptr,               /* tp_hash */
    nullptr,               /* tp_call */
    nullptr,               /* tp_str */
    nullptr,               /* tp_getattro */
    nullptr,               /* tp_setattro */
    nullptr,               /* tp_as_buffer */
    Py_TPFLAGS_DEFAULT,    /* tp_flags */
    "TFE_Py_Tape objects", /* tp_doc */
};

typedef struct {
  PyObject_HEAD
  /* Type-specific fields go here. */
  ForwardAccumulator* accumulator;
  // A nesting order between GradientTapes and ForwardAccumulators, used to
  // ensure that GradientTapes do not watch the products of outer
  // ForwardAccumulators.
  int64_t nesting_id;
} TFE_Py_ForwardAccumulator;

static void TFE_Py_ForwardAccumulatorDelete(PyObject* accumulator) {
  delete reinterpret_cast<TFE_Py_ForwardAccumulator*>(accumulator)
      ->accumulator;
  Py_TYPE(accumulator)->tp_free(accumulator);
}

static PyTypeObject TFE_Py_ForwardAccumulator_Type = {
    PyVarObject_HEAD_INIT(nullptr, 0) "ForwardAccumulator", /* tp_name */
    sizeof(TFE_Py_ForwardAccumulator),                      /* tp_basicsize */
    0,                                                      /* tp_itemsize */
    &TFE_Py_ForwardAccumulatorDelete,                       /* tp_dealloc */
#if PY_VERSION_HEX < 0x03080000
    nullptr, /* tp_print */
#else
    0, /* tp_vectorcall_offset */
#endif
    nullptr,                             /* tp_getattr */
    nullptr,                             /* tp_setattr */
    nullptr,                             /* tp_reserved */
    nullptr,                             /* tp_repr */
    nullptr,                             /* tp_as_number */
    nullptr,                             /* tp_as_sequence */
    nullptr,                             /* tp_as_mapping */
    nullptr,                             /* tp_hash */
    nullptr,                             /* tp_call */
    nullptr,                             /* tp_str */
    nullptr,                             /* tp_getattro */
    nullptr,                             /* tp_setattro */
    nullptr,                             /* tp_as_buffer */
    Py_TPFLAGS_DEFAULT,                  /* tp_flags */
    "TFE_Py_ForwardAccumulator objects", /* tp_doc */
};

typedef struct {
  PyObject_HEAD
  /* Type-specific fields go here. */
  VariableWatcher* variable_watcher;
} TFE_Py_VariableWatcher;

static void TFE_Py_VariableWatcher_Delete(PyObject* variable_watcher) {
  delete reinterpret_cast<TFE_Py_VariableWatcher*>(variable_watcher)
      ->variable_watcher;
  Py_TYPE(variable_watcher)->tp_free(variable_watcher);
}

static PyTypeObject TFE_Py_VariableWatcher_Type = {
    PyVarObject_HEAD_INIT(nullptr, 0) "tfe.VariableWatcher", /* tp_name */
    sizeof(TFE_Py_VariableWatcher),                          /* tp_basicsize */
    0,                                                       /* tp_itemsize */
    &TFE_Py_VariableWatcher_Delete,                          /* tp_dealloc */
#if PY_VERSION_HEX < 0x03080000
    nullptr, /* tp_print */
#else
    0, /* tp_vectorcall_offset */
#endif
    nullptr,                          /* tp_getattr */
    nullptr,                          /* tp_setattr */
    nullptr,                          /* tp_reserved */
    nullptr,                          /* tp_repr */
    nullptr,                          /* tp_as_number */
    nullptr,                          /* tp_as_sequence */
    nullptr,                          /* tp_as_mapping */
    nullptr,                          /* tp_hash */
    nullptr,                          /* tp_call */
    nullptr,                          /* tp_str */
    nullptr,                          /* tp_getattro */
    nullptr,                          /* tp_setattro */
    nullptr,                          /* tp_as_buffer */
    Py_TPFLAGS_DEFAULT,               /* tp_flags */
    "TFE_Py_VariableWatcher objects", /* tp_doc */
};

// Note: in the current design no mutex is needed here because of the python
// GIL, which is always held when any TFE_Py_* methods are called. We should
// revisit this if/when we decide not to hold the GIL while manipulating the
// tape stack.
tensorflow::gtl::CompactPointerSet<TFE_Py_Tape*>* GetTapeSet() {
  thread_local std::unique_ptr<tensorflow::gtl::CompactPointerSet<TFE_Py_Tape*>>
      tape_set = nullptr;
  if (tape_set == nullptr) {
    tape_set.reset(new tensorflow::gtl::CompactPointerSet<TFE_Py_Tape*>);
  }
  return tape_set.get();
}

tensorflow::gtl::CompactPointerSet<TFE_Py_VariableWatcher*>*
GetVariableWatcherSet() {
  thread_local std::unique_ptr<
      tensorflow::gtl::CompactPointerSet<TFE_Py_VariableWatcher*>>
      variable_watcher_set = nullptr;
  if (variable_watcher_set == nullptr) {
    variable_watcher_set.reset(
        new tensorflow::gtl::CompactPointerSet<TFE_Py_VariableWatcher*>);
  }
  return variable_watcher_set.get();
}

// A linked hash set, where iteration is in insertion order.
//
// Nested accumulators rely on op recording happening in insertion order, so an
// unordered data structure like CompactPointerSet is not suitable. Outer
// accumulators need to observe operations first so they know to watch the
// inner accumulator's jvp computation.
//
// Not thread safe.
class AccumulatorSet {
 public:
  // Returns true if `element` was newly inserted, false if it already exists.
  bool insert(TFE_Py_ForwardAccumulator* element) {
    if (map_.find(element) != map_.end()) {
      return false;
    }
    ListType::iterator it = ordered_.insert(ordered_.end(), element);
    map_.insert(std::make_pair(element, it));
    return true;
  }

  void erase(TFE_Py_ForwardAccumulator* element) {
    MapType::iterator existing = map_.find(element);
    if (existing == map_.end()) {
      return;
    }
    ListType::iterator list_position = existing->second;
    map_.erase(existing);
    ordered_.erase(list_position);
  }

  bool empty() const { return ordered_.empty(); }

  size_t size() const { return ordered_.size(); }

 private:
  typedef std::list<TFE_Py_ForwardAccumulator*> ListType;
  typedef tensorflow::gtl::FlatMap<TFE_Py_ForwardAccumulator*,
                                   ListType::iterator>
      MapType;

 public:
  typedef ListType::const_iterator const_iterator;
  typedef ListType::const_reverse_iterator const_reverse_iterator;

  const_iterator begin() const { return ordered_.begin(); }
  const_iterator end() const { return ordered_.end(); }

  const_reverse_iterator rbegin() const { return ordered_.rbegin(); }
  const_reverse_iterator rend() const { return ordered_.rend(); }

 private:
  MapType map_;
  ListType ordered_;
};
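
// Illustrative use of the insertion-order guarantee (not part of the original
// file): an outer accumulator inserted before an inner one is visited first
// on forward iteration and last on reverse iteration, which is what lets it
// watch the inner accumulator's jvp computation:
//
//   AccumulatorSet accumulators;
//   accumulators.insert(outer);
//   accumulators.insert(inner);
//   for (TFE_Py_ForwardAccumulator* a : accumulators) {
//     // visits outer, then inner
//   }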

AccumulatorSet* GetAccumulatorSet() {
  thread_local std::unique_ptr<AccumulatorSet> accumulator_set{nullptr};
  if (accumulator_set == nullptr) {
    accumulator_set.reset(new AccumulatorSet);
  }
  return accumulator_set.get();
}

inline bool HasAccumulator() { return !GetAccumulatorSet()->empty(); }

inline bool HasGradientTape() { return !GetTapeSet()->empty(); }

inline bool HasAccumulatorOrTape() {
  return HasGradientTape() || HasAccumulator();
}

// A safe copy of a set, used for tapes and accumulators. The copy is not
// affected by other python threads changing the set of active tapes.
template <typename ContainerType>
class SafeSetCopy {
 public:
  explicit SafeSetCopy(const ContainerType& to_copy) : set_copy_(to_copy) {
    for (auto* member : set_copy_) {
      Py_INCREF(member);
    }
  }

  ~SafeSetCopy() {
    for (auto* member : set_copy_) {
      Py_DECREF(member);
    }
  }

  typename ContainerType::const_iterator begin() const {
    return set_copy_.begin();
  }

  typename ContainerType::const_iterator end() const {
    return set_copy_.end();
  }

  bool empty() const { return set_copy_.empty(); }
  size_t size() const { return set_copy_.size(); }

 protected:
  ContainerType set_copy_;
};

class SafeTapeSet
    : public SafeSetCopy<tensorflow::gtl::CompactPointerSet<TFE_Py_Tape*>> {
 public:
  SafeTapeSet()
      : SafeSetCopy<tensorflow::gtl::CompactPointerSet<TFE_Py_Tape*>>(
            *GetTapeSet()) {}
};

class SafeAccumulatorSet : public SafeSetCopy<AccumulatorSet> {
 public:
  SafeAccumulatorSet() : SafeSetCopy<AccumulatorSet>(*GetAccumulatorSet()) {}

  typename AccumulatorSet::const_reverse_iterator rbegin() const {
    return set_copy_.rbegin();
  }

  typename AccumulatorSet::const_reverse_iterator rend() const {
    return set_copy_.rend();
  }
};

class SafeVariableWatcherSet
    : public SafeSetCopy<
          tensorflow::gtl::CompactPointerSet<TFE_Py_VariableWatcher*>> {
 public:
  SafeVariableWatcherSet()
      : SafeSetCopy<
            tensorflow::gtl::CompactPointerSet<TFE_Py_VariableWatcher*>>(
            *GetVariableWatcherSet()) {}
};
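
// Typical use of the snapshot types above (illustrative): iterate over a
// stable, reference-holding copy so that Python callbacks which push or pop
// tapes mid-iteration cannot invalidate the loop:
//
//   for (TFE_Py_Tape* tape : SafeTapeSet()) {
//     tape->tape->VariableAccessed(variable);
//   }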

bool* ThreadTapeIsStopped() {
  thread_local bool thread_tape_is_stopped{false};
  return &thread_tape_is_stopped;
}

void TFE_Py_TapeSetStopOnThread() { *ThreadTapeIsStopped() = true; }

void TFE_Py_TapeSetRestartOnThread() { *ThreadTapeIsStopped() = false; }

PyObject* TFE_Py_TapeSetIsStopped() {
  if (*ThreadTapeIsStopped()) {
    Py_RETURN_TRUE;
  }
  Py_RETURN_FALSE;
}

PyObject* TFE_Py_TapeSetNew(PyObject* persistent,
                            PyObject* watch_accessed_variables) {
  TFE_Py_Tape_Type.tp_new = PyType_GenericNew;
  if (PyType_Ready(&TFE_Py_Tape_Type) < 0) return nullptr;
  TFE_Py_Tape* tape = PyObject_NEW(TFE_Py_Tape, &TFE_Py_Tape_Type);
  tape->tape = new GradientTape(persistent == Py_True,
                                watch_accessed_variables == Py_True);
  Py_INCREF(tape);
  tape->nesting_id = tape_nesting_id_counter.fetch_add(1);
  GetTapeSet()->insert(tape);
  return reinterpret_cast<PyObject*>(tape);
}

void TFE_Py_TapeSetAdd(PyObject* tape) {
  Py_INCREF(tape);
  TFE_Py_Tape* tfe_tape = reinterpret_cast<TFE_Py_Tape*>(tape);
  if (!GetTapeSet()->insert(tfe_tape).second) {
    // Already exists in the tape set.
    Py_DECREF(tape);
  } else {
    tfe_tape->nesting_id = tape_nesting_id_counter.fetch_add(1);
  }
}

PyObject* TFE_Py_TapeSetIsEmpty() {
  if (*ThreadTapeIsStopped() || !HasAccumulatorOrTape()) {
    Py_RETURN_TRUE;
  }
  Py_RETURN_FALSE;
}

void TFE_Py_TapeSetRemove(PyObject* tape) {
  auto* stack = GetTapeSet();
  stack->erase(reinterpret_cast<TFE_Py_Tape*>(tape));
  // We kept a reference to the tape in the set to ensure it wouldn't get
  // deleted under us; cleaning it up here.
  Py_DECREF(tape);
}

static std::vector<int64_t> MakeIntList(PyObject* list) {
  if (list == Py_None) {
    return {};
  }
  PyObject* seq = PySequence_Fast(list, "expected a sequence");
  if (seq == nullptr) {
    return {};
  }
  int len = PySequence_Size(list);
  PyObject** seq_array = PySequence_Fast_ITEMS(seq);
  std::vector<int64_t> tensor_ids;
  tensor_ids.reserve(len);
  for (int i = 0; i < len; ++i) {
    PyObject* item = seq_array[i];
#if PY_MAJOR_VERSION >= 3
    if (PyLong_Check(item)) {
#else
    if (PyLong_Check(item) || PyInt_Check(item)) {
#endif
      int64_t id = MakeInt(item);
      tensor_ids.push_back(id);
    } else {
      tensor_ids.push_back(-1);
    }
  }
  Py_DECREF(seq);
  return tensor_ids;
}
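
// Note: the -1 entries produced above for non-integer items are sentinel
// values; TapeTensorFromTensor below clamps negative dimensions to 0 before
// constructing a TensorShape.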

// Fill `tensor_ids` and `dtypes` from `tensors`, none of which may be
// null. Returns true on success and false on a Python exception.
bool TensorShapesAndDtypes(PyObject* tensors, std::vector<int64_t>* tensor_ids,
                           std::vector<tensorflow::DataType>* dtypes) {
  tensorflow::Safe_PyObjectPtr seq(
      PySequence_Fast(tensors, "expected a sequence"));
  if (seq == nullptr) {
    return false;
  }
  int len = PySequence_Fast_GET_SIZE(seq.get());
  PyObject** seq_array = PySequence_Fast_ITEMS(seq.get());
  tensor_ids->reserve(len);
  dtypes->reserve(len);
  for (int i = 0; i < len; ++i) {
    PyObject* item = seq_array[i];
    tensor_ids->push_back(FastTensorId(item));
    dtypes->push_back(tensorflow::PyTensor_DataType(item));
  }
  return true;
}

bool TapeCouldPossiblyRecord(PyObject* tensors) {
  if (tensors == Py_None) {
    return false;
  }
  if (*ThreadTapeIsStopped()) {
    return false;
  }
  if (!HasAccumulatorOrTape()) {
    return false;
  }
  return true;
}

bool CouldBackprop() { return !*ThreadTapeIsStopped() && HasGradientTape(); }

bool CouldForwardprop() { return !*ThreadTapeIsStopped() && HasAccumulator(); }

PyObject* TFE_Py_TapeSetShouldRecordBackprop(PyObject* tensors) {
  if (!TapeCouldPossiblyRecord(tensors) || !CouldBackprop()) {
    Py_RETURN_FALSE;
  }
  // TODO(apassos) consider not building a list and changing the API to check
  // each tensor individually.
  std::vector<int64_t> tensor_ids;
  std::vector<tensorflow::DataType> dtypes;
  if (!TensorShapesAndDtypes(tensors, &tensor_ids, &dtypes)) {
    return nullptr;
  }
  auto tape_set = *GetTapeSet();
  for (TFE_Py_Tape* tape : tape_set) {
    if (tape->tape->ShouldRecord(tensor_ids, dtypes)) {
      Py_RETURN_TRUE;
    }
  }

  Py_RETURN_FALSE;
}

PyObject* TFE_Py_ForwardAccumulatorPushState() {
  auto forward_accumulators = *GetAccumulatorSet();
  for (TFE_Py_ForwardAccumulator* accumulator : forward_accumulators) {
    accumulator->accumulator->PushState();
  }
  Py_RETURN_NONE;
}

PyObject* TFE_Py_ForwardAccumulatorPopState() {
  auto forward_accumulators = *GetAccumulatorSet();
  for (TFE_Py_ForwardAccumulator* accumulator : forward_accumulators) {
    accumulator->accumulator->PopState();
  }
  Py_RETURN_NONE;
}

PyObject* TFE_Py_TapeSetPossibleGradientTypes(PyObject* tensors) {
  if (!TapeCouldPossiblyRecord(tensors)) {
    return GetPythonObjectFromInt(0);
  }
  std::vector<int64_t> tensor_ids;
  std::vector<tensorflow::DataType> dtypes;
  if (!TensorShapesAndDtypes(tensors, &tensor_ids, &dtypes)) {
    return nullptr;
  }

  // If there is a persistent tape watching, or if there are multiple tapes
  // watching, we'll return immediately indicating that higher-order tape
  // gradients are possible.
  bool some_tape_watching = false;
  if (CouldBackprop()) {
    auto tape_set = *GetTapeSet();
    for (TFE_Py_Tape* tape : tape_set) {
      if (tape->tape->ShouldRecord(tensor_ids, dtypes)) {
        if (tape->tape->IsPersistent() || some_tape_watching) {
          // Either this is the second tape watching, or this tape is
          // persistent: higher-order gradients are possible.
          return GetPythonObjectFromInt(2);
        }
        some_tape_watching = true;
      }
    }
  }
  if (CouldForwardprop()) {
    auto forward_accumulators = *GetAccumulatorSet();
    for (TFE_Py_ForwardAccumulator* accumulator : forward_accumulators) {
      if (accumulator->accumulator->ShouldRecord(tensor_ids, dtypes)) {
        if (some_tape_watching) {
          // This is the second tape watching: higher-order gradients are
          // possible. Note that there's no equivalent of persistence for
          // forward-mode.
          return GetPythonObjectFromInt(2);
        }
        some_tape_watching = true;
      }
    }
  }
  if (some_tape_watching) {
    // There's exactly one non-persistent tape. The user can request
    // first-order gradients but won't be able to get higher-order tape
    // gradients.
    return GetPythonObjectFromInt(1);
  } else {
    // There are no tapes. The user can't request tape gradients.
    return GetPythonObjectFromInt(0);
  }
}
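
// Summary of the return contract above: 0 means nothing could record on these
// tensors; 1 means exactly one non-persistent tape would record (first-order
// gradients only); 2 means higher-order gradients are possible (a persistent
// tape, or at least two recorders counting forward accumulators).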

void TFE_Py_TapeWatch(PyObject* tape, PyObject* tensor) {
  if (!CouldBackprop()) {
    return;
  }
  int64_t tensor_id = FastTensorId(tensor);
  if (PyErr_Occurred()) {
    return;
  }
  reinterpret_cast<TFE_Py_Tape*>(tape)->tape->Watch(tensor_id);
}

bool ListContainsNone(PyObject* list) {
  if (list == Py_None) return true;
  tensorflow::Safe_PyObjectPtr seq(
      PySequence_Fast(list, "expected a sequence"));
  if (seq == nullptr) {
    return false;
  }

  int len = PySequence_Size(list);
  PyObject** seq_array = PySequence_Fast_ITEMS(seq.get());
  for (int i = 0; i < len; ++i) {
    PyObject* item = seq_array[i];
    if (item == Py_None) return true;
  }

  return false;
}

// As an optimization, the tape generally keeps only the shape and dtype of
// tensors, and uses this information to generate ones/zeros tensors. However,
// some tensors require OnesLike/ZerosLike because their gradients do not match
// their inference shape/dtype.
bool DTypeNeedsHandleData(tensorflow::DataType dtype) {
  return dtype == tensorflow::DT_VARIANT || dtype == tensorflow::DT_RESOURCE;
}
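
// For example (an illustration, not from this file): a DT_VARIANT tensor
// holding a TensorList has gradients whose element shapes are not captured by
// the handle's inference shape, and a DT_RESOURCE variable handle needs the
// underlying tensor to build ones/zeros, so both keep the tensor itself.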

static PyTapeTensor TapeTensorFromTensor(PyObject* tensor) {
  if (EagerTensor_CheckExact(tensor)) {
    tensorflow::ImmediateExecutionTensorHandle* handle =
        tensorflow::unwrap(EagerTensor_Handle(tensor));
    int64_t id = PyEagerTensor_ID(tensor);
    tensorflow::DataType dtype =
        static_cast<tensorflow::DataType>(handle->DataType());
    if (DTypeNeedsHandleData(dtype)) {
      return PyTapeTensor(id, dtype, tensor);
    }

    tensorflow::TensorShape tensor_shape;
    int num_dims;
    tensorflow::Status status = handle->NumDims(&num_dims);
    if (status.ok()) {
      for (int i = 0; i < num_dims; ++i) {
        int64_t dim_size;
        status = handle->Dim(i, &dim_size);
        if (!status.ok()) break;
        tensor_shape.AddDim(dim_size);
      }
    }

    if (MaybeRaiseExceptionFromStatus(status, nullptr)) {
      return PyTapeTensor(id, static_cast<tensorflow::DataType>(0),
                          tensorflow::TensorShape({}));
    } else {
      return PyTapeTensor(id, dtype, tensor_shape);
    }
  }
  int64_t id = FastTensorId(tensor);
  if (PyErr_Occurred()) {
    return PyTapeTensor(id, static_cast<tensorflow::DataType>(0),
                        tensorflow::TensorShape({}));
  }
  PyObject* dtype_object = PyObject_GetAttrString(tensor, "dtype");
  PyObject* dtype_enum = PyObject_GetAttrString(dtype_object, "_type_enum");
  Py_DECREF(dtype_object);
  tensorflow::DataType dtype =
      static_cast<tensorflow::DataType>(MakeInt(dtype_enum));
  Py_DECREF(dtype_enum);
  if (PyErr_Occurred()) {
    return PyTapeTensor(id, static_cast<tensorflow::DataType>(0),
                        tensorflow::TensorShape({}));
  }
  static char _shape_tuple[] = "_shape_tuple";
  tensorflow::Safe_PyObjectPtr shape_tuple(
      PyObject_CallMethod(tensor, _shape_tuple, nullptr));
  if (PyErr_Occurred()) {
    return PyTapeTensor(id, static_cast<tensorflow::DataType>(0),
                        tensorflow::TensorShape({}));
  }

  if (ListContainsNone(shape_tuple.get()) || DTypeNeedsHandleData(dtype)) {
    return PyTapeTensor(id, dtype, tensor);
  }

  auto l = MakeIntList(shape_tuple.get());
  // Replace -1 (which represents accidental Nones that can occur in graph
  // mode and can cause errors in shape construction) with 0s.
  for (auto& c : l) {
    if (c < 0) {
      c = 0;
    }
  }
  tensorflow::TensorShape shape(l);
  return PyTapeTensor(id, dtype, shape);
}

// Populates output_info from output_seq, which must come from PySequence_Fast.
//
// Does not take ownership of output_seq. Returns true on success and false if
// a Python exception has been set.
bool TapeTensorsFromTensorSequence(PyObject* output_seq,
                                   std::vector<PyTapeTensor>* output_info) {
  Py_ssize_t output_len = PySequence_Fast_GET_SIZE(output_seq);
  PyObject** output_seq_array = PySequence_Fast_ITEMS(output_seq);
  output_info->reserve(output_len);
  for (Py_ssize_t i = 0; i < output_len; ++i) {
    output_info->push_back(TapeTensorFromTensor(output_seq_array[i]));
    if (PyErr_Occurred() != nullptr) {
      return false;
    }
  }
  return true;
}

std::vector<int64_t> MakeTensorIDList(PyObject* tensors) {
  PyObject* seq = PySequence_Fast(tensors, "expected a sequence");
  if (seq == nullptr) {
    return {};
  }
  int len = PySequence_Fast_GET_SIZE(seq);
  PyObject** seq_array = PySequence_Fast_ITEMS(seq);
  std::vector<int64_t> list;
  list.reserve(len);
  for (int i = 0; i < len; ++i) {
    PyObject* tensor = seq_array[i];
    list.push_back(FastTensorId(tensor));
    if (PyErr_Occurred()) {
      Py_DECREF(seq);
      return list;
    }
  }
  Py_DECREF(seq);
  return list;
}

void TFE_Py_TapeVariableAccessed(PyObject* variable) {
  if (!CouldBackprop()) {
    return;
  }
  for (TFE_Py_Tape* tape : SafeTapeSet()) {
    tape->tape->VariableAccessed(variable);
  }
}

void TFE_Py_TapeWatchVariable(PyObject* tape, PyObject* variable) {
  if (!CouldBackprop()) {
    return;
  }
  reinterpret_cast<TFE_Py_Tape*>(tape)->tape->WatchVariable(variable);
}

PyObject* TFE_Py_TapeWatchedVariables(PyObject* tape) {
  return reinterpret_cast<TFE_Py_Tape*>(tape)->tape->GetVariablesAsPyTuple();
}

PyObject* TFE_Py_VariableWatcherNew() {
  TFE_Py_VariableWatcher_Type.tp_new = PyType_GenericNew;
  if (PyType_Ready(&TFE_Py_VariableWatcher_Type) < 0) return nullptr;
  TFE_Py_VariableWatcher* variable_watcher =
      PyObject_NEW(TFE_Py_VariableWatcher, &TFE_Py_VariableWatcher_Type);
  variable_watcher->variable_watcher = new VariableWatcher();
  Py_INCREF(variable_watcher);
  GetVariableWatcherSet()->insert(variable_watcher);
  return reinterpret_cast<PyObject*>(variable_watcher);
}

void TFE_Py_VariableWatcherRemove(PyObject* variable_watcher) {
  auto* stack = GetVariableWatcherSet();
  stack->erase(reinterpret_cast<TFE_Py_VariableWatcher*>(variable_watcher));
  // We kept a reference to the variable watcher in the set to ensure it
  // wouldn't get deleted under us; cleaning it up here.
  Py_DECREF(variable_watcher);
}

void TFE_Py_VariableWatcherVariableAccessed(PyObject* variable) {
  for (TFE_Py_VariableWatcher* variable_watcher : SafeVariableWatcherSet()) {
    variable_watcher->variable_watcher->WatchVariable(variable);
  }
}

PyObject* TFE_Py_VariableWatcherWatchedVariables(PyObject* variable_watcher) {
  return reinterpret_cast<TFE_Py_VariableWatcher*>(variable_watcher)
      ->variable_watcher->GetVariablesAsPyTuple();
}

namespace {
std::vector<tensorflow::DataType> MakeTensorDtypeList(PyObject* tensors) {
  PyObject* seq = PySequence_Fast(tensors, "expected a sequence");
  if (seq == nullptr) {
    return {};
  }
  int len = PySequence_Fast_GET_SIZE(seq);
  PyObject** seq_array = PySequence_Fast_ITEMS(seq);
  std::vector<tensorflow::DataType> list;
  list.reserve(len);
  for (int i = 0; i < len; ++i) {
    PyObject* tensor = seq_array[i];
    list.push_back(tensorflow::PyTensor_DataType(tensor));
  }
  Py_DECREF(seq);
  return list;
}

PyObject* ForwardAccumulatorDeleteGradient(PyObject* tensor_id,
                                           PyObject* weak_tensor_ref) {
  int64_t parsed_tensor_id = MakeInt(tensor_id);
  for (TFE_Py_ForwardAccumulator* accumulator : *GetAccumulatorSet()) {
    accumulator->accumulator->DeleteGradient(parsed_tensor_id);
  }
  Py_DECREF(weak_tensor_ref);
  Py_DECREF(tensor_id);
  Py_INCREF(Py_None);
  return Py_None;
}

static PyMethodDef forward_accumulator_delete_gradient_method_def = {
    "ForwardAccumulatorDeleteGradient", ForwardAccumulatorDeleteGradient,
    METH_O, "ForwardAccumulatorDeleteGradient"};
| 2304 | |||||
| 2305 | void RegisterForwardAccumulatorCleanup(PyObject* tensor, int64_t tensor_id) { | ||||
| 2306 | tensorflow::Safe_PyObjectPtr callback( | ||||
| 2307 |       PyCFunction_New(&forward_accumulator_delete_gradient_method_def,PyCFunction_NewEx((&forward_accumulator_delete_gradient_method_def ), (PyLong_FromLong(tensor_id)), __null)  | ||||
| 2308 |                       PyLong_FromLong(tensor_id))PyCFunction_NewEx((&forward_accumulator_delete_gradient_method_def ), (PyLong_FromLong(tensor_id)), __null));  | ||||
| 2309 | // We need to keep a reference to the weakref active if we want our callback | ||||
| 2310 | // called. The callback itself now owns the weakref object and the tensor ID | ||||
| 2311 | // object. | ||||
| 2312 | PyWeakref_NewRef(tensor, callback.get()); | ||||
| 2313 | } | ||||
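
// How the cleanup above works (editorial sketch, not original commentary):
// PyCFunction_New binds the tensor-id PyLong as the METH_O callback's `self`
// argument, and PyWeakref_NewRef registers the callback to fire when the
// watched tensor is garbage collected. When it fires, CPython calls
// ForwardAccumulatorDeleteGradient(tensor_id, weak_tensor_ref), which drops
// the stored JVP for that id in every active accumulator. Roughly the
// Python-equivalent behavior, for intuition only:
//
//   // weakref.ref(tensor, lambda ref: delete_gradient(tensor_id, ref))
//
// The reference returned by PyWeakref_NewRef is deliberately never released,
// since a weakref whose own object dies before its referent does not invoke
// its callback.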

void TapeSetRecordBackprop(
    const string& op_type, const std::vector<PyTapeTensor>& output_info,
    const std::vector<int64_t>& input_ids,
    const std::vector<tensorflow::DataType>& input_dtypes,
    const std::function<PyBackwardFunction*()>& backward_function_getter,
    const std::function<void(PyBackwardFunction*)>& backward_function_killer,
    tensorflow::uint64 max_gradient_tape_id) {
  if (!CouldBackprop()) {
    return;
  }
  for (TFE_Py_Tape* tape : SafeTapeSet()) {
    if (tape->nesting_id < max_gradient_tape_id) {
      tape->tape->RecordOperation(op_type, output_info, input_ids,
                                  input_dtypes, backward_function_getter,
                                  backward_function_killer);
    }
  }
}
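
// Editorial note: `max_gradient_tape_id` acts as a cutoff. Only tapes with a
// smaller nesting_id (i.e. tapes opened before that point) record the
// operation; tapes opened after it are skipped. The cutoff itself is computed
// by TapeSetRecordForwardprop below when an accumulator is busy accumulating.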

bool TapeSetRecordForwardprop(
    const string& op_type, PyObject* output_seq,
    const std::vector<PyTapeTensor>& output_info, PyObject* input_tensors,
    const std::vector<int64_t>& input_ids,
    const std::vector<tensorflow::DataType>& input_dtypes,
    const std::function<PyBackwardFunction*()>& backward_function_getter,
    const std::function<void(PyBackwardFunction*)>& backward_function_killer,
    const tensorflow::eager::ForwardFunction<PyObject>* forward_function,
    PyObject* forwardprop_output_indices,
    tensorflow::uint64* max_gradient_tape_id) {
  *max_gradient_tape_id = std::numeric_limits<tensorflow::uint64>::max();
  if (!CouldForwardprop()) {
    return true;
  }
  auto accumulator_set = SafeAccumulatorSet();
  tensorflow::Safe_PyObjectPtr input_seq(
      PySequence_Fast(input_tensors, "expected a sequence of tensors"));
  if (input_seq == nullptr || PyErr_Occurred()) return false;
  Py_ssize_t input_len = PySequence_Fast_GET_SIZE(input_seq.get());
  PyObject** output_seq_array = PySequence_Fast_ITEMS(output_seq);
  for (int i = 0; i < output_info.size(); ++i) {
    RegisterForwardAccumulatorCleanup(output_seq_array[i],
                                      output_info[i].GetID());
  }
  if (forwardprop_output_indices != nullptr &&
      forwardprop_output_indices != Py_None) {
    tensorflow::Safe_PyObjectPtr indices_fast(PySequence_Fast(
        forwardprop_output_indices, "Expected a sequence of indices"));
    if (indices_fast == nullptr || PyErr_Occurred()) {
      return false;
    }
    if (PySequence_Fast_GET_SIZE(indices_fast.get()) !=
        accumulator_set.size()) {
      MaybeRaiseExceptionFromStatus(
          tensorflow::errors::Internal(
              "Accumulators were added or removed from the active set "
              "between packing and unpacking."),
          nullptr);
    }
    PyObject** indices_fast_array = PySequence_Fast_ITEMS(indices_fast.get());
    Py_ssize_t accumulator_index = 0;
    for (AccumulatorSet::const_reverse_iterator it = accumulator_set.rbegin();
         it != accumulator_set.rend(); ++it, ++accumulator_index) {
      tensorflow::Safe_PyObjectPtr jvp_index_seq(
          PySequence_Fast(indices_fast_array[accumulator_index],
                          "Expected a sequence of jvp indices."));
      if (jvp_index_seq == nullptr || PyErr_Occurred()) {
        return false;
      }
      Py_ssize_t num_jvps = PySequence_Fast_GET_SIZE(jvp_index_seq.get());
      PyObject** jvp_index_seq_array =
          PySequence_Fast_ITEMS(jvp_index_seq.get());
      for (Py_ssize_t jvp_index = 0; jvp_index < num_jvps; ++jvp_index) {
        PyObject* tuple = jvp_index_seq_array[jvp_index];
        int64_t primal_tensor_id =
            output_info[MakeInt(PyTuple_GetItem(tuple, 0))].GetID();
        (*it)->accumulator->Watch(
            primal_tensor_id,
            output_seq_array[MakeInt(PyTuple_GetItem(tuple, 1))]);
      }
    }
  } else {
    std::vector<PyTapeTensor> input_info;
    input_info.reserve(input_len);
    PyObject** input_seq_array = PySequence_Fast_ITEMS(input_seq.get());
    for (Py_ssize_t i = 0; i < input_len; ++i) {
      input_info.push_back(TapeTensorFromTensor(input_seq_array[i]));
    }
    for (TFE_Py_ForwardAccumulator* accumulator : accumulator_set) {
      tensorflow::Status status = accumulator->accumulator->Accumulate(
          op_type, input_info, output_info, input_ids, input_dtypes,
          forward_function, backward_function_getter,
          backward_function_killer);
      if (PyErr_Occurred()) return false;  // Don't swallow Python exceptions.
      if (MaybeRaiseExceptionFromStatus(status, nullptr)) {
        return false;
      }
      if (accumulator->accumulator->BusyAccumulating()) {
        // Ensure inner accumulators don't see outer accumulators' jvps. This
        // mostly happens on its own, with some potentially surprising
        // exceptions, so the blanket policy is for consistency.
        *max_gradient_tape_id = accumulator->nesting_id;
        break;
      }
    }
  }
  return true;
}
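
// Editorial summary of the two branches above: when special-cased JVP indices
// are supplied (the packed-JVP path), accumulators simply re-watch previously
// computed JVPs among the outputs instead of recomputing them. Otherwise
// every accumulator runs Accumulate() to compute fresh JVPs, and the first
// accumulator found busy accumulating caps `max_gradient_tape_id` at its
// nesting_id so that more deeply nested tapes skip this operation.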

PyObject* TangentsAsPyTuple(const std::vector<PyObject*>& input_tangents) {
  PyObject* py_input_tangents = PyTuple_New(input_tangents.size());
  for (int i = 0; i < input_tangents.size(); ++i) {
    PyObject* element;
    if (input_tangents[i] == nullptr) {
      element = Py_None;
    } else {
      element = input_tangents[i];
    }
    Py_INCREF(element);
    PyTuple_SET_ITEM(py_input_tangents, i, element);
  }
  return py_input_tangents;
}
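
// Editorial note: PyTuple_SET_ITEM steals a reference, so the Py_INCREF above
// creates the reference that the new tuple will own; the caller's borrowed
// pointers are untouched. Null tangents are mapped to Py_None so the
// Python-side callback sees a fully populated tuple.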

tensorflow::Status ParseTangentOutputs(
    PyObject* user_output, std::vector<PyObject*>* output_tangents) {
  if (user_output == Py_None) {
    // No connected gradients.
    return tensorflow::Status::OK();
  }
  tensorflow::Safe_PyObjectPtr fast_result(PySequence_Fast(
      user_output, "expected a sequence of forward gradients"));
  if (fast_result == nullptr) {
    return tensorflow::errors::InvalidArgument(
        "forward gradient function did not return a sequence.");
  }
  int len = PySequence_Fast_GET_SIZE(fast_result.get());
  PyObject** fast_result_array = PySequence_Fast_ITEMS(fast_result.get());
  output_tangents->reserve(len);
  for (int i = 0; i < len; ++i) {
    PyObject* item = fast_result_array[i];
    if (item == Py_None) {
      output_tangents->push_back(nullptr);
    } else {
      Py_INCREF(item);
      output_tangents->push_back(item);
    }
  }
  return tensorflow::Status::OK();
}

// Calls the registered forward_gradient_function, computing `output_tangents`
// from `input_tangents`. `output_tangents` must not be null.
//
// `op_name`, `attrs`, `inputs`, and `results` describe the operation for which
// the forward function is being called.
tensorflow::Status CallJVPFunction(PyObject* op_name, PyObject* attrs,
                                   PyObject* inputs, PyObject* results,
                                   const std::vector<PyObject*>& input_tangents,
                                   std::vector<PyObject*>* output_tangents,
                                   bool use_batch) {
  if (forward_gradient_function == nullptr) {
    return tensorflow::errors::Internal(
        "No forward gradient function registered.");
  }
  tensorflow::Safe_PyObjectPtr py_input_tangents(
      TangentsAsPyTuple(input_tangents));

  // Normalize the input sequence to a tuple so it works with function
  // caching; otherwise it may be an opaque _InputList object.
  tensorflow::Safe_PyObjectPtr input_tuple(PySequence_Tuple(inputs));
  PyObject* to_batch = use_batch ? Py_True : Py_False;
  tensorflow::Safe_PyObjectPtr callback_args(
      Py_BuildValue("OOOOOO", op_name, attrs, input_tuple.get(), results,
                    py_input_tangents.get(), to_batch));
  tensorflow::Safe_PyObjectPtr py_result(
      PyObject_CallObject(forward_gradient_function, callback_args.get()));
  if (py_result == nullptr || PyErr_Occurred()) {
    return tensorflow::errors::Internal(
        "forward gradient function threw exceptions");
  }
  return ParseTangentOutputs(py_result.get(), output_tangents);
}
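
// Editorial sketch: judging from the Py_BuildValue format above, the
// registered Python callback receives six positional arguments and returns a
// sequence of tangents (or None). A rough shape of such a callback, with a
// hypothetical name purely for illustration:
//
//   // def _jvp_callback(op_name, attrs, inputs, results, tangents, use_batch):
//   //   ...compute JVPs for `results` from `tangents`...
//   //   return output_tangents  # one entry per result, entries may be None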

// Like CallJVPFunction, but calls a pre-bound forward function.
// These are passed in from a record_gradient argument.
tensorflow::Status CallOpSpecificJVPFunction(
    PyObject* op_specific_forward_function,
    const std::vector<PyObject*>& input_tangents,
    std::vector<PyObject*>* output_tangents) {
  tensorflow::Safe_PyObjectPtr py_input_tangents(
      TangentsAsPyTuple(input_tangents));

  tensorflow::Safe_PyObjectPtr py_result(PyObject_CallObject(
      op_specific_forward_function, py_input_tangents.get()));
  if (py_result == nullptr || PyErr_Occurred()) {
    return tensorflow::errors::Internal(
        "forward gradient function threw exceptions");
  }
  return ParseTangentOutputs(py_result.get(), output_tangents);
}

bool ParseOpTypeString(PyObject* op_type, string* op_type_string) {
  if (PyBytes_Check(op_type)) {
    *op_type_string = PyBytes_AsString(op_type);
  } else if (PyUnicode_Check(op_type)) {
#if PY_MAJOR_VERSION >= 3
    *op_type_string = PyUnicode_AsUTF8(op_type);
#else
    PyObject* py_str = PyUnicode_AsUTF8String(op_type);
    if (py_str == nullptr) {
      return false;
    }
    *op_type_string = PyBytes_AS_STRING(py_str);
    Py_DECREF(py_str);
#endif
  } else {
    PyErr_SetString(PyExc_RuntimeError, "op_type should be a string.");
    return false;
  }
  return true;
}
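
// Editorial note: on Python 3, PyUnicode_AsUTF8 returns a UTF-8 buffer cached
// inside the unicode object itself, so no extra decref is needed; the
// Python 2 branch has to materialize a temporary bytes object and release it
// after copying the string out.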

bool TapeSetRecordOperation(
    PyObject* op_type, PyObject* input_tensors, PyObject* output_tensors,
    const std::vector<int64_t>& input_ids,
    const std::vector<tensorflow::DataType>& input_dtypes,
    const std::function<PyBackwardFunction*()>& backward_function_getter,
    const std::function<void(PyBackwardFunction*)>& backward_function_killer,
    const tensorflow::eager::ForwardFunction<PyObject>* forward_function) {
  std::vector<PyTapeTensor> output_info;
  tensorflow::Safe_PyObjectPtr output_seq(PySequence_Fast(
      output_tensors, "expected a sequence of integer tensor ids"));
  if (PyErr_Occurred() ||
      !TapeTensorsFromTensorSequence(output_seq.get(), &output_info)) {
    return false;
  }
  string op_type_str;
  if (!ParseOpTypeString(op_type, &op_type_str)) {
    return false;
  }
  tensorflow::uint64 max_gradient_tape_id;
  if (!TapeSetRecordForwardprop(
          op_type_str, output_seq.get(), output_info, input_tensors, input_ids,
          input_dtypes, backward_function_getter, backward_function_killer,
          forward_function, nullptr /* No special-cased jvps. */,
          &max_gradient_tape_id)) {
    return false;
  }
  TapeSetRecordBackprop(op_type_str, output_info, input_ids, input_dtypes,
                        backward_function_getter, backward_function_killer,
                        max_gradient_tape_id);
  return true;
}
}  // namespace
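
// Editorial note: forward-prop recording runs first because it computes
// `max_gradient_tape_id`, which backprop recording then uses to decide which
// tapes are allowed to record the operation.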

PyObject* TFE_Py_TapeSetRecordOperation(PyObject* op_type,
                                        PyObject* output_tensors,
                                        PyObject* input_tensors,
                                        PyObject* backward_function,
                                        PyObject* forward_function) {
  if (!HasAccumulatorOrTape() || *ThreadTapeIsStopped()) {
    Py_RETURN_NONE;
  }
  std::vector<int64_t> input_ids = MakeTensorIDList(input_tensors);
  if (PyErr_Occurred()) return nullptr;

  std::vector<tensorflow::DataType> input_dtypes =
      MakeTensorDtypeList(input_tensors);
  if (PyErr_Occurred()) return nullptr;

  std::function<PyBackwardFunction*()> backward_function_getter(
      [backward_function]() {
        Py_INCREF(backward_function);
        PyBackwardFunction* function = new PyBackwardFunction(
            [backward_function](PyObject* out_grads,
                                const std::vector<int64_t>& unused) {
              return PyObject_CallObject(backward_function, out_grads);
            });
        return function;
      });
  std::function<void(PyBackwardFunction*)> backward_function_killer(
      [backward_function](PyBackwardFunction* py_backward_function) {
        Py_DECREF(backward_function);
        delete py_backward_function;
      });

  if (forward_function == Py_None) {
    if (!TapeSetRecordOperation(
            op_type, input_tensors, output_tensors, input_ids, input_dtypes,
            backward_function_getter, backward_function_killer,
            nullptr /* No special-cased forward function */)) {
      return nullptr;
    }
  } else {
    tensorflow::eager::ForwardFunction<PyObject> wrapped_forward_function(
        [forward_function](const std::vector<PyObject*>& input_tangents,
                           std::vector<PyObject*>* output_tangents,
                           bool use_batch = false) {
          return CallOpSpecificJVPFunction(forward_function, input_tangents,
                                           output_tangents);
        });
    if (!TapeSetRecordOperation(
            op_type, input_tensors, output_tensors, input_ids, input_dtypes,
            backward_function_getter, backward_function_killer,
            &wrapped_forward_function)) {
      return nullptr;
    }
  }
  Py_RETURN_NONE;
}
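
// Editorial note: the getter/killer pair above implements manual ownership of
// the Python backward function. Each getter call takes one reference on
// `backward_function` and heap-allocates a PyBackwardFunction closure; the
// killer releases both, so tapes and accumulators can create and destroy
// backward closures independently of one another.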

PyObject* TFE_Py_TapeSetRecordOperationForwardprop(
    PyObject* op_type, PyObject* output_tensors, PyObject* input_tensors,
    PyObject* backward_function, PyObject* forwardprop_output_indices) {
  if (!HasAccumulator() || *ThreadTapeIsStopped()) {
    Py_RETURN_NONE;
  }
  std::vector<int64_t> input_ids = MakeTensorIDList(input_tensors);
  if (PyErr_Occurred()) return nullptr;

  std::vector<tensorflow::DataType> input_dtypes =
      MakeTensorDtypeList(input_tensors);
  if (PyErr_Occurred()) return nullptr;

  std::function<PyBackwardFunction*()> backward_function_getter(
      [backward_function]() {
        Py_INCREF(backward_function);
        PyBackwardFunction* function = new PyBackwardFunction(
            [backward_function](PyObject* out_grads,
                                const std::vector<int64_t>& unused) {
              return PyObject_CallObject(backward_function, out_grads);
            });
        return function;
      });
  std::function<void(PyBackwardFunction*)> backward_function_killer(
      [backward_function](PyBackwardFunction* py_backward_function) {
        Py_DECREF(backward_function);
        delete py_backward_function;
      });
  std::vector<PyTapeTensor> output_info;
  tensorflow::Safe_PyObjectPtr output_seq(PySequence_Fast(
      output_tensors, "expected a sequence of integer tensor ids"));
  if (PyErr_Occurred() ||
      !TapeTensorsFromTensorSequence(output_seq.get(), &output_info)) {
    return nullptr;
  }
  string op_type_str;
  if (!ParseOpTypeString(op_type, &op_type_str)) {
    return nullptr;
  }
  tensorflow::uint64 max_gradient_tape_id;
  if (!TapeSetRecordForwardprop(
          op_type_str, output_seq.get(), output_info, input_tensors, input_ids,
          input_dtypes, backward_function_getter, backward_function_killer,
          nullptr /* no special-cased forward function */,
          forwardprop_output_indices, &max_gradient_tape_id)) {
    return nullptr;
  }
  Py_RETURN_NONE;
}

PyObject* TFE_Py_TapeSetRecordOperationBackprop(PyObject* op_type,
                                                PyObject* output_tensors,
                                                PyObject* input_tensors,
                                                PyObject* backward_function) {
  if (!CouldBackprop()) {
    Py_RETURN_NONE;
  }
  std::vector<int64_t> input_ids = MakeTensorIDList(input_tensors);
  if (PyErr_Occurred()) return nullptr;

  std::vector<tensorflow::DataType> input_dtypes =
      MakeTensorDtypeList(input_tensors);
  if (PyErr_Occurred()) return nullptr;

  std::function<PyBackwardFunction*()> backward_function_getter(
      [backward_function]() {
        Py_INCREF(backward_function);
        PyBackwardFunction* function = new PyBackwardFunction(
            [backward_function](PyObject* out_grads,
                                const std::vector<int64_t>& unused) {
              return PyObject_CallObject(backward_function, out_grads);
            });
        return function;
      });
  std::function<void(PyBackwardFunction*)> backward_function_killer(
      [backward_function](PyBackwardFunction* py_backward_function) {
        Py_DECREF(backward_function);
        delete py_backward_function;
      });
  std::vector<PyTapeTensor> output_info;
  tensorflow::Safe_PyObjectPtr output_seq(PySequence_Fast(
      output_tensors, "expected a sequence of integer tensor ids"));
  if (PyErr_Occurred() ||
      !TapeTensorsFromTensorSequence(output_seq.get(), &output_info)) {
    return nullptr;
  }
  string op_type_str;
  if (!ParseOpTypeString(op_type, &op_type_str)) {
    return nullptr;
  }
  TapeSetRecordBackprop(op_type_str, output_info, input_ids, input_dtypes,
                        backward_function_getter, backward_function_killer,
                        // No filtering based on relative ordering with forward
                        // accumulators.
                        std::numeric_limits<tensorflow::uint64>::max());
  Py_RETURN_NONE;
}

void TFE_Py_TapeSetDeleteTrace(int64_t tensor_id) {
  for (TFE_Py_Tape* tape : *GetTapeSet()) {
    tape->tape->DeleteTrace(tensor_id);
  }
}

std::vector<PyObject*> MakeTensorList(PyObject* tensors) {
  PyObject* seq = PySequence_Fast(tensors, "expected a sequence");
  if (seq == nullptr) {
    return {};
  }
  int len = PySequence_Fast_GET_SIZE(seq);
  PyObject** seq_array = PySequence_Fast_ITEMS(seq);
  std::vector<PyObject*> list(seq_array, seq_array + len);
  Py_DECREF(seq);
  return list;
}
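
// Editorial note: MakeTensorList returns borrowed pointers copied out of the
// fast-sequence item array; callers that need the elements to outlive the
// input sequence (e.g. the output_gradients handling below) must Py_INCREF
// them explicitly.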

PyObject* TFE_Py_TapeGradient(PyObject* tape, PyObject* target,
                              PyObject* sources, PyObject* output_gradients,
                              PyObject* sources_raw,
                              PyObject* unconnected_gradients,
                              TF_Status* status) {
  TFE_Py_Tape* tape_obj = reinterpret_cast<TFE_Py_Tape*>(tape);
  if (!tape_obj->tape->IsPersistent()) {
    auto* tape_set = GetTapeSet();
    if (tape_set->find(tape_obj) != tape_set->end()) {
      PyErr_SetString(PyExc_RuntimeError,
                      "gradient() cannot be invoked within the "
                      "GradientTape context (i.e., while operations are being "
                      "recorded). Either move the call to gradient() to be "
                      "outside the 'with tf.GradientTape' block, or "
                      "use a persistent tape: "
                      "'with tf.GradientTape(persistent=True)'");
      return nullptr;
    }
  }

  std::vector<int64_t> target_vec = MakeTensorIDList(target);
  if (PyErr_Occurred()) {
    return nullptr;
  }
  std::vector<int64_t> sources_vec = MakeTensorIDList(sources);
  if (PyErr_Occurred()) {
    return nullptr;
  }
  tensorflow::gtl::FlatSet<int64_t> sources_set(sources_vec.begin(),
                                                sources_vec.end());

  tensorflow::Safe_PyObjectPtr seq =
      tensorflow::make_safe(PySequence_Fast(target, "expected a sequence"));
  int len = PySequence_Fast_GET_SIZE(seq.get());
  PyObject** seq_array = PySequence_Fast_ITEMS(seq.get());
  std::unordered_map<int64_t, PyTapeTensor> source_tensors_that_are_targets;
  for (int i = 0; i < len; ++i) {
    int64_t target_id = target_vec[i];
    if (sources_set.find(target_id) != sources_set.end()) {
      auto tensor = seq_array[i];
      source_tensors_that_are_targets.insert(
          std::make_pair(target_id, TapeTensorFromTensor(tensor)));
    }
    if (PyErr_Occurred()) {
      return nullptr;
    }
  }
  if (PyErr_Occurred()) {
    return nullptr;
  }

  std::vector<PyObject*> outgrad_vec;
  if (output_gradients != Py_None) {
    outgrad_vec = MakeTensorList(output_gradients);
    if (PyErr_Occurred()) {
      return nullptr;
    }
    for (PyObject* tensor : outgrad_vec) {
      // Calling the backward function will eat a reference to the tensors in
      // outgrad_vec, so we need to increase their reference count.
      Py_INCREF(tensor);
    }
  }
  std::vector<PyObject*> result(sources_vec.size());
  status->status = tape_obj->tape->ComputeGradient(
      *py_vspace, target_vec, sources_vec, source_tensors_that_are_targets,
      outgrad_vec, absl::MakeSpan(result));
  if (!status->status.ok()) {
    if (PyErr_Occurred()) {
      // Do not propagate the erroneous status as that would swallow the
      // exception which caused the problem.
      status->status = tensorflow::Status::OK();
    }
    return nullptr;
  }

  bool unconnected_gradients_zero =
      strcmp(TFE_GetPythonString(unconnected_gradients), "zero") == 0;
  std::vector<PyObject*> sources_obj;
  if (unconnected_gradients_zero) {
    // Uses the "raw" sources here so it can properly make a zeros tensor even
    // if there are resource variables as sources.
    sources_obj = MakeTensorList(sources_raw);
  }

  if (!result.empty()) {
    PyObject* py_result = PyList_New(result.size());
    tensorflow::gtl::FlatSet<PyObject*> seen_results(result.size());
    for (int i = 0; i < result.size(); ++i) {
      if (result[i] == nullptr) {
        if (unconnected_gradients_zero) {
          // Generate a zeros tensor in the shape of sources[i].
          tensorflow::DataType dtype =
              tensorflow::PyTensor_DataType(sources_obj[i]);
          PyTapeTensor tensor =
              PyTapeTensor(sources_vec[i], dtype, sources_obj[i]);
          result[i] = tensor.ZerosLike();
        } else {
          Py_INCREF(Py_None);
          result[i] = Py_None;
        }
      } else if (seen_results.find(result[i]) != seen_results.end()) {
        Py_INCREF(result[i]);
      }
      seen_results.insert(result[i]);
      PyList_SET_ITEM(py_result, i, reinterpret_cast<PyObject*>(result[i]));
    }
    return py_result;
  }
  return PyList_New(0);
}
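
// Editorial note on the duplicate-result handling above: ComputeGradient may
// return the same PyObject* for several sources. PyList_SET_ITEM steals one
// reference per slot, so the extra Py_INCREF on already-seen results keeps
// the reference counts balanced when one object fills multiple slots; the
// first occurrence consumes the reference ComputeGradient produced.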

PyObject* TFE_Py_ForwardAccumulatorNew(bool use_batch) {
  TFE_Py_ForwardAccumulator_Type.tp_new = PyType_GenericNew;
  if (PyType_Ready(&TFE_Py_ForwardAccumulator_Type) < 0) return nullptr;
  TFE_Py_ForwardAccumulator* accumulator =
      PyObject_NEW(TFE_Py_ForwardAccumulator, &TFE_Py_ForwardAccumulator_Type);
  if (py_vspace == nullptr) {
    MaybeRaiseExceptionFromStatus(
        tensorflow::errors::Internal(
            "ForwardAccumulator requires a PyVSpace to be registered."),
        nullptr);
  }
  accumulator->accumulator = new ForwardAccumulator(*py_vspace, use_batch);
  return reinterpret_cast<PyObject*>(accumulator);
}

PyObject* TFE_Py_ForwardAccumulatorSetAdd(PyObject* accumulator) {
  TFE_Py_ForwardAccumulator* c_accumulator(
      reinterpret_cast<TFE_Py_ForwardAccumulator*>(accumulator));
  c_accumulator->nesting_id = tape_nesting_id_counter.fetch_add(1);
  if (GetAccumulatorSet()->insert(c_accumulator)) {
    Py_INCREF(accumulator);
    Py_RETURN_NONE;
  } else {
    MaybeRaiseExceptionFromStatus(
        tensorflow::errors::Internal(
            "A ForwardAccumulator was added to the active set twice."),
        nullptr);
    return nullptr;
  }
}

void TFE_Py_ForwardAccumulatorSetRemove(PyObject* accumulator) {
  GetAccumulatorSet()->erase(
      reinterpret_cast<TFE_Py_ForwardAccumulator*>(accumulator));
  Py_DECREF(accumulator);
}

void TFE_Py_ForwardAccumulatorWatch(PyObject* accumulator, PyObject* tensor,
                                    PyObject* tangent) {
  int64_t tensor_id = FastTensorId(tensor);
  reinterpret_cast<TFE_Py_ForwardAccumulator*>(accumulator)
      ->accumulator->Watch(tensor_id, tangent);
  RegisterForwardAccumulatorCleanup(tensor, tensor_id);
}

// Returns a new reference to the JVP Tensor.
PyObject* TFE_Py_ForwardAccumulatorJVP(PyObject* accumulator,
                                       PyObject* tensor) {
  PyObject* jvp = reinterpret_cast<TFE_Py_ForwardAccumulator*>(accumulator)
                      ->accumulator->FetchJVP(FastTensorId(tensor));
  if (jvp == nullptr) {
    jvp = Py_None;
  }
  Py_INCREF(jvp);
  return jvp;
}
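
// Editorial sketch of the accumulator lifecycle these entry points implement;
// the Python bindings' exact names are assumptions for illustration:
//
//   // acc = TFE_Py_ForwardAccumulatorNew(use_batch=False)
//   // TFE_Py_ForwardAccumulatorSetAdd(acc)       # start recording
//   // TFE_Py_ForwardAccumulatorWatch(acc, x, v)  # seed tangent v for x
//   // ... run ops; JVPs accumulate ...
//   // jvp = TFE_Py_ForwardAccumulatorJVP(acc, y) # None if unconnected
//   // TFE_Py_ForwardAccumulatorSetRemove(acc)    # stop recording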

PyObject* TFE_Py_PackJVPs(PyObject* tensors) {
  if (!TapeCouldPossiblyRecord(tensors)) {
    tensorflow::Safe_PyObjectPtr empty_tuple(PyTuple_New(0));
    tensorflow::Safe_PyObjectPtr empty_list(PyList_New(0));
    return PyTuple_Pack(2, empty_tuple.get(), empty_list.get());
  }
  auto accumulators = *GetAccumulatorSet();
  tensorflow::Safe_PyObjectPtr tensors_fast(
      PySequence_Fast(tensors, "Expected a sequence of input Tensors."));
  if (tensors_fast == nullptr || PyErr_Occurred()) {
    return nullptr;
  }
  std::vector<int64_t> augmented_input_ids;
  int len = PySequence_Fast_GET_SIZE(tensors_fast.get());
  PyObject** tensors_fast_array = PySequence_Fast_ITEMS(tensors_fast.get());
  for (Py_ssize_t position = 0; position < len; ++position) {
    PyObject* input = tensors_fast_array[position];
    if (input == Py_None) {
      continue;
    }
    tensorflow::DataType input_dtype(tensorflow::PyTensor_DataType(input));
    if (input_dtype == tensorflow::DT_INVALID) {
      return nullptr;
    }
    augmented_input_ids.push_back(FastTensorId(input));
  }
  if (PyErr_Occurred()) {
    return nullptr;
  }
  // Find the innermost accumulator such that all outer accumulators are
  // recording. Any more deeply nested accumulators will not have their JVPs
  // saved.
  AccumulatorSet::const_iterator innermost_all_recording =
      accumulators.begin();
  for (; innermost_all_recording != accumulators.end();
       ++innermost_all_recording) {
    if ((*innermost_all_recording)->accumulator->BusyAccumulating()) {
      break;
    }
  }
  AccumulatorSet::const_reverse_iterator reverse_innermost_all_recording(
      innermost_all_recording);

  bool saving_jvps = false;
  tensorflow::Safe_PyObjectPtr all_indices(PyTuple_New(accumulators.size()));
  std::vector<PyObject*> new_tensors;
  Py_ssize_t accumulator_index = 0;
  // Start with the innermost accumulators to give outer accumulators a chance
  // to find their higher-order JVPs.
  for (AccumulatorSet::const_reverse_iterator it = accumulators.rbegin();
       it != accumulators.rend(); ++it, ++accumulator_index) {
    std::vector<int64_t> new_input_ids;
    std::vector<std::pair<int64_t, int64_t>> accumulator_indices;
    if (it == reverse_innermost_all_recording) {
      saving_jvps = true;
    }
    if (saving_jvps) {
      for (int input_index = 0; input_index < augmented_input_ids.size();
           ++input_index) {
        int64_t existing_input = augmented_input_ids[input_index];
        PyObject* jvp = (*it)->accumulator->FetchJVP(existing_input);
        if (jvp != nullptr) {
          new_tensors.push_back(jvp);
          new_input_ids.push_back(FastTensorId(jvp));
          accumulator_indices.emplace_back(
              input_index,
              augmented_input_ids.size() + new_input_ids.size() - 1);
        }
      }
    }
    tensorflow::Safe_PyObjectPtr accumulator_indices_py(
        PyTuple_New(accumulator_indices.size()));
    for (int i = 0; i < accumulator_indices.size(); ++i) {
      tensorflow::Safe_PyObjectPtr from_index(
          GetPythonObjectFromInt(accumulator_indices[i].first));
      tensorflow::Safe_PyObjectPtr to_index(
          GetPythonObjectFromInt(accumulator_indices[i].second));
      PyTuple_SetItem(accumulator_indices_py.get(), i,
                      PyTuple_Pack(2, from_index.get(), to_index.get()));
    }
    PyTuple_SetItem(all_indices.get(), accumulator_index,
                    accumulator_indices_py.release());
    augmented_input_ids.insert(augmented_input_ids.end(),
                               new_input_ids.begin(), new_input_ids.end());
  }

  tensorflow::Safe_PyObjectPtr new_tensors_py(PyList_New(new_tensors.size()));
  for (int i = 0; i < new_tensors.size(); ++i) {
    PyObject* jvp = new_tensors[i];
    Py_INCREF(jvp);
    PyList_SET_ITEM(new_tensors_py.get(), i, jvp);
  }
  return PyTuple_Pack(2, all_indices.get(), new_tensors_py.get());
}
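
// Editorial note: TFE_Py_PackJVPs returns (indices, new_tensors), where
// `indices` holds one tuple of (primal_index, jvp_index) pairs per
// accumulator, innermost first. The positions index into the original inputs
// concatenated with `new_tensors`, which is what lets
// TapeSetRecordForwardprop re-watch the packed JVPs later instead of
// recomputing them ("packing and unpacking" in its error message above).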

namespace {

// Indices for the "args" tuple that's passed to TFE_Py_FastPathExecute_C.
enum FastPathExecuteArgIndex {
  FAST_PATH_EXECUTE_ARG_CONTEXT = 0,
  FAST_PATH_EXECUTE_ARG_OP_NAME = 1,
  FAST_PATH_EXECUTE_ARG_NAME = 2,
  FAST_PATH_EXECUTE_ARG_INPUT_START = 3
};

PyObject* GetPythonObjectFromString(tensorflow::StringPiece s) {
#if PY_MAJOR_VERSION >= 3
  return PyUnicode_FromStringAndSize(s.data(), s.size());
#else
  return PyBytes_FromStringAndSize(s.data(), s.size());
#endif
}

bool CheckResourceVariable(PyObject* item) {
  if (tensorflow::swig::IsResourceVariable(item)) {
    tensorflow::Safe_PyObjectPtr handle(
        PyObject_GetAttrString(item, "_handle"));
    return EagerTensor_CheckExact(handle.get());
  }

  return false;
}

bool IsNumberType(PyObject* item) {
#if PY_MAJOR_VERSION >= 3
  return PyFloat_Check(item) || PyLong_Check(item);
#else
  return PyFloat_Check(item) || PyInt_Check(item) || PyLong_Check(item);
#endif
}
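
// Editorial note: the fast path only accepts scalars CPython can classify
// cheaply. Python 2 distinguished fixed-width int from arbitrary-precision
// long, hence the extra PyInt_Check in that branch; Python 3 folds both into
// PyLong.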

bool CheckOneInput(PyObject* item) {
  if (EagerTensor_CheckExact(item) || CheckResourceVariable(item) ||
      PyArray_Check(item) || IsNumberType(item)) {
    return true;
  }

  // Sequences are not properly handled. Sequences with purely python numeric
  // types work, but sequences with mixes of EagerTensors and python numeric
  // types don't work.
  // TODO(nareshmodi): fix
  return false;
}

bool CheckInputsOk(PyObject* seq, int start_index,
                   const tensorflow::OpDef& op_def) {
  for (int i = 0; i < op_def.input_arg_size(); i++) {
    PyObject* item = PyTuple_GET_ITEM(seq, i + start_index);
    if (!op_def.input_arg(i).number_attr().empty() ||
        !op_def.input_arg(i).type_list_attr().empty()) {
      // This item should be a seq input.
      if (!PySequence_Check(item)) {
        VLOG(1) << "Falling back to slow path for Op \"" << op_def.name()
                << "\", Input \"" << op_def.input_arg(i).name()
                << "\" since we expected a sequence, but got "
                << item->ob_type->tp_name;
        return false;
      }
      tensorflow::Safe_PyObjectPtr fast_item(
          PySequence_Fast(item, "Could not parse sequence."));
      if (fast_item.get() == nullptr) {
        return false;
      }
      int len = PySequence_Fast_GET_SIZE(fast_item.get());
      PyObject** fast_item_array = PySequence_Fast_ITEMS(fast_item.get());
      for (Py_ssize_t j = 0; j < len; j++) {
        PyObject* inner_item = fast_item_array[j];
        if (!CheckOneInput(inner_item)) {
          VLOG(1) << "Falling back to slow path for Op \"" << op_def.name()
                  << "\", Input \"" << op_def.input_arg(i).name()
                  << "\", Index " << j
                  << " since we expected an EagerTensor/ResourceVariable, "
                     "but got "
                  << inner_item->ob_type->tp_name;
          return false;
        }
      }
    } else if (!CheckOneInput(item)) {
      VLOG(1)
          << "Falling back to slow path for Op \"" << op_def.name()
          << "\", Input \"" << op_def.input_arg(i).name()
          << "\" since we expected an EagerTensor/ResourceVariable, but got "
          << item->ob_type->tp_name;
      return false;
    }
  }

  return true;
}

tensorflow::DataType MaybeGetDType(PyObject* item) {
  if (EagerTensor_CheckExact(item) || CheckResourceVariable(item)) {
    return tensorflow::PyTensor_DataType(item);
  }

  return tensorflow::DT_INVALID;
}

tensorflow::DataType MaybeGetDTypeForAttr(const string& attr,
                                          FastPathOpExecInfo* op_exec_info) {
  auto cached_it = op_exec_info->cached_dtypes.find(attr);
  if (cached_it != op_exec_info->cached_dtypes.end()) {
    return cached_it->second;
  }

  auto it = op_exec_info->attr_to_inputs_map->find(attr);
  if (it == op_exec_info->attr_to_inputs_map->end()) {
    // No other inputs - this should never happen.
    return tensorflow::DT_INVALID;
  }

  for (const auto& input_info : it->second) {
    PyObject* item = PyTuple_GET_ITEM(
        op_exec_info->args, FAST_PATH_EXECUTE_ARG_INPUT_START + input_info.i);
    if (input_info.is_list) {
      tensorflow::Safe_PyObjectPtr fast_item(
          PySequence_Fast(item, "Unable to allocate"));
      int len = PySequence_Fast_GET_SIZE(fast_item.get());
      PyObject** fast_item_array = PySequence_Fast_ITEMS(fast_item.get());
      for (int i = 0; i < len; i++) {
        auto dtype = MaybeGetDType(fast_item_array[i]);
        if (dtype != tensorflow::DT_INVALID) return dtype;
      }
    } else {
      auto dtype = MaybeGetDType(item);
      if (dtype != tensorflow::DT_INVALID) return dtype;
    }
  }

  auto default_it = op_exec_info->default_dtypes->find(attr);
  if (default_it != op_exec_info->default_dtypes->end()) {
    return default_it->second;
  }

  return tensorflow::DT_INVALID;
}

PyObject* CopySequenceSettingIndicesToNull(
    PyObject* seq, const tensorflow::gtl::FlatSet<int>& indices) {
  tensorflow::Safe_PyObjectPtr fast_seq(
      PySequence_Fast(seq, "unable to allocate"));
  int len = PySequence_Fast_GET_SIZE(fast_seq.get());
  PyObject** fast_seq_array = PySequence_Fast_ITEMS(fast_seq.get());
  PyObject* result = PyTuple_New(len);
  for (int i = 0; i < len; i++) {
    PyObject* item;
    if (indices.find(i) != indices.end()) {
      item = Py_None;
    } else {
      item = fast_seq_array[i];
    }
    Py_INCREF(item);
    PyTuple_SET_ITEM(result, i, item);
  }
  return result;
}
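
// Editorial note: despite the name, the skipped indices are set to Py_None
// rather than a null pointer, since the resulting tuple is handed back to
// Python, where every slot must hold a valid object. This is how
// RecordGradient below avoids keeping inputs/outputs alive that the gradient
// function never reads.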
| 3165 | |||||
| 3166 | PyObject* RecordGradient(PyObject* op_name, PyObject* inputs, PyObject* attrs, | ||||
| 3167 | PyObject* results, | ||||
| 3168 | PyObject* forward_pass_name_scope = nullptr) { | ||||
| 3169 | std::vector<int64_t> input_ids = MakeTensorIDList(inputs); | ||||
| 3170 | if (PyErr_Occurred()) return nullptr; | ||||
| 3171 | std::vector<tensorflow::DataType> input_dtypes = MakeTensorDtypeList(inputs); | ||||
| 3172 | if (PyErr_Occurred()) return nullptr; | ||||
| 3173 | |||||
| 3174 | bool should_record = false; | ||||
| 3175 | for (TFE_Py_Tape* tape : SafeTapeSet()) { | ||||
| 3176 | if (tape->tape->ShouldRecord(input_ids, input_dtypes)) { | ||||
| 3177 | should_record = true; | ||||
| 3178 | break; | ||||
| 3179 | } | ||||
| 3180 | } | ||||
| 3181 | if (!should_record) { | ||||
| 3182 | for (TFE_Py_ForwardAccumulator* accumulator : SafeAccumulatorSet()) { | ||||
| 3183 | if (accumulator->accumulator->ShouldRecord(input_ids, input_dtypes)) { | ||||
| 3184 | should_record = true; | ||||
| 3185 | break; | ||||
| 3186 | } | ||||
| 3187 | } | ||||
| 3188 | } | ||||
  if (!should_record) Py_RETURN_NONE;

  string c_op_name = TFE_GetPythonString(op_name);

  PyObject* op_outputs;
  bool op_outputs_tuple_created = false;

  if (const auto unused_output_indices =
          OpGradientUnusedOutputIndices(c_op_name)) {
    if (unused_output_indices->empty()) {
      op_outputs = Py_None;
    } else {
      op_outputs_tuple_created = true;
      op_outputs =
          CopySequenceSettingIndicesToNull(results, *unused_output_indices);
    }
  } else {
    op_outputs = results;
  }

  PyObject* op_inputs;
  bool op_inputs_tuple_created = false;

  if (const auto unused_input_indices =
          OpGradientUnusedInputIndices(c_op_name)) {
    if (unused_input_indices->empty()) {
      op_inputs = Py_None;
    } else {
      op_inputs_tuple_created = true;
      op_inputs =
          CopySequenceSettingIndicesToNull(inputs, *unused_input_indices);
    }
  } else {
    op_inputs = inputs;
  }

  tensorflow::eager::ForwardFunction<PyObject> py_forward_function(
      [op_name, attrs, inputs, results](
          const std::vector<PyObject*>& input_tangents,
          std::vector<PyObject*>* output_tangents, bool use_batch) {
        return CallJVPFunction(op_name, attrs, inputs, results, input_tangents,
                               output_tangents, use_batch);
      });
  tensorflow::eager::ForwardFunction<PyObject>* forward_function;
  if (c_op_name == "While" || c_op_name == "StatelessWhile" ||
      c_op_name == "If" || c_op_name == "StatelessIf") {
    // Control flow contains non-hashable attributes. Handling them in Python
    // is a headache, so instead we'll stay as close to GradientTape's handling
    // as possible (a null forward function means the accumulator forwards to a
    // tape).
    //
    // This is safe to do since we'll only see control flow when graph
    // building, in which case we can rely on pruning.
    forward_function = nullptr;
  } else {
    forward_function = &py_forward_function;
  }

  PyObject* num_inputs = PyLong_FromLong(PySequence_Size(inputs));

  if (!forward_pass_name_scope) forward_pass_name_scope = Py_None;

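  // TapeSetRecordOperation takes a factory for the backward function and a
  // deleter for it. Everything the backward function captures is INCREF'd in
  // the factory and DECREF'd again in the deleter, so the captured Python
  // objects stay alive exactly as long as the tape entry does.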
  TapeSetRecordOperation(
      op_name, inputs, results, input_ids, input_dtypes,
      [op_name, attrs, num_inputs, op_inputs, op_outputs,
       forward_pass_name_scope]() {
        Py_INCREF(op_name);
        Py_INCREF(attrs);
        Py_INCREF(num_inputs);
        Py_INCREF(op_inputs);
        Py_INCREF(op_outputs);
        Py_INCREF(forward_pass_name_scope);
        PyBackwardFunction* function = new PyBackwardFunction(
            [op_name, attrs, num_inputs, op_inputs, op_outputs,
             forward_pass_name_scope](
                PyObject* output_grads,
                const std::vector<int64_t>& unneeded_gradients) {
              if (PyErr_Occurred()) {
                return static_cast<PyObject*>(nullptr);
              }
              tensorflow::Safe_PyObjectPtr skip_input_indices;
              if (!unneeded_gradients.empty()) {
                skip_input_indices.reset(
                    PyTuple_New(unneeded_gradients.size()));
                for (int i = 0; i < unneeded_gradients.size(); i++) {
                  PyTuple_SET_ITEM(
                      skip_input_indices.get(), i,
                      GetPythonObjectFromInt(unneeded_gradients[i]));
                }
              } else {
                Py_INCREF(Py_None);
                skip_input_indices.reset(Py_None);
              }
              tensorflow::Safe_PyObjectPtr callback_args(Py_BuildValue(
                  "OOOOOOOO", op_name, attrs, num_inputs, op_inputs,
                  op_outputs, output_grads, skip_input_indices.get(),
                  forward_pass_name_scope));

              tensorflow::Safe_PyObjectPtr result(
                  PyObject_CallObject(gradient_function, callback_args.get()));

              if (PyErr_Occurred()) return static_cast<PyObject*>(nullptr);

              return tensorflow::swig::Flatten(result.get());
            });
        return function;
      },
      [op_name, attrs, num_inputs, op_inputs, op_outputs,
       forward_pass_name_scope](PyBackwardFunction* backward_function) {
        Py_DECREF(op_name);
        Py_DECREF(attrs);
        Py_DECREF(num_inputs);
        Py_DECREF(op_inputs);
        Py_DECREF(op_outputs);
        Py_DECREF(forward_pass_name_scope);

        delete backward_function;
      },
      forward_function);

  Py_DECREF(num_inputs);
  if (op_outputs_tuple_created) Py_DECREF(op_outputs);
  if (op_inputs_tuple_created) Py_DECREF(op_inputs);

  if (PyErr_Occurred()) {
    return nullptr;
  }

  Py_RETURN_NONE;
}

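// Tells any active tapes and variable watchers that a trainable
// ResourceVariable is about to be read. Non-trainable variables are skipped.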
void MaybeNotifyVariableAccessed(PyObject* input) {
  DCHECK(CheckResourceVariable(input));
  DCHECK(PyObject_HasAttrString(input, "_trainable"));

  tensorflow::Safe_PyObjectPtr trainable(
      PyObject_GetAttrString(input, "_trainable"));
  if (trainable.get() == Py_False) return;
  TFE_Py_TapeVariableAccessed(input);
  TFE_Py_VariableWatcherVariableAccessed(input);
}

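// Reads a ResourceVariable's value by building and eagerly executing a
// ReadVariableOp, recording a gradient for the read if a tape is active.
// Returns false (with a Python error set) on failure.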
bool ReadVariableOp(const FastPathOpExecInfo& parent_op_exec_info,
                    PyObject* input, tensorflow::Safe_PyObjectPtr* output,
                    TF_Status* status) {
  MaybeNotifyVariableAccessed(input);

  TFE_Op* op = TFE_NewOp(parent_op_exec_info.ctx, "ReadVariableOp", status);
  auto cleaner = tensorflow::gtl::MakeCleanup([op] { TFE_DeleteOp(op); });
  if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) return false;

  TFE_OpSetDevice(op, parent_op_exec_info.device_name, status);
  if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) return false;

  // Set dtype
  DCHECK(PyObject_HasAttrString(input, "_dtype"));
  tensorflow::Safe_PyObjectPtr dtype(PyObject_GetAttrString(input, "_dtype"));
  int value;
  if (!ParseTypeValue("_dtype", dtype.get(), status, &value)) {
    return false;
  }
  TFE_OpSetAttrType(op, "dtype", static_cast<TF_DataType>(value));

  // Get handle
  tensorflow::Safe_PyObjectPtr handle(PyObject_GetAttrString(input, "_handle"));
  if (!EagerTensor_CheckExact(handle.get())) return false;
  TFE_OpAddInput(op, EagerTensor_Handle(handle.get()), status);
  if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) return false;

  int num_retvals = 1;
  TFE_TensorHandle* output_handle;
  TFE_Execute(op, &output_handle, &num_retvals, status);
  if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) return false;

  // Always create the py object (and correctly DECREF it) from the returned
  // value, else the data will leak.
  output->reset(EagerTensorFromHandle(output_handle));

  // TODO(nareshmodi): Should we run post exec callbacks here?
  if (parent_op_exec_info.run_gradient_callback) {
    tensorflow::Safe_PyObjectPtr inputs(PyTuple_New(1));
    PyTuple_SET_ITEM(inputs.get(), 0, handle.release());

    tensorflow::Safe_PyObjectPtr outputs(PyTuple_New(1));
    Py_INCREF(output->get());  // stay alive after since tuple steals.
    PyTuple_SET_ITEM(outputs.get(), 0, output->get());

    tensorflow::Safe_PyObjectPtr op_string(
        GetPythonObjectFromString("ReadVariableOp"));
    if (!RecordGradient(op_string.get(), inputs.get(), Py_None,
                        outputs.get())) {
      return false;
    }
  }

  return true;
}

// Supports 3 cases at the moment:
// i) input is an EagerTensor.
// ii) input is a ResourceVariable - in this case, the is_variable param is
// set to true.
// iii) input is an arbitrary python list/tuple (note, this handling doesn't
// support packing).
//
// NOTE: dtype_hint_getter must *always* return a PyObject that can be
// decref'd. So if no hint is found, Py_RETURN_NONE (which correctly
// increfs Py_None).
//
// NOTE: This function sets a python error directly, and returns false.
// TF_Status is only passed since we don't want to have to reallocate it.
bool ConvertToTensor(
    const FastPathOpExecInfo& op_exec_info, PyObject* input,
    tensorflow::Safe_PyObjectPtr* output_handle,
    // This gets a hint for this particular input.
    const std::function<tensorflow::DataType()>& dtype_hint_getter,
    // This sets the dtype after conversion is complete.
    const std::function<void(const tensorflow::DataType dtype)>& dtype_setter,
    TF_Status* status) {
  if (EagerTensor_CheckExact(input)) {
    Py_INCREF(input);
    output_handle->reset(input);
    return true;
  } else if (CheckResourceVariable(input)) {
    return ReadVariableOp(op_exec_info, input, output_handle, status);
  }

  // The hint comes from a supposedly similarly typed tensor.
  tensorflow::DataType dtype_hint = dtype_hint_getter();

  TFE_TensorHandle* handle = tensorflow::ConvertToEagerTensor(
      op_exec_info.ctx, input, dtype_hint, op_exec_info.device_name);
  if (handle == nullptr) {
    return MaybeRaiseExceptionFromTFStatus(status, nullptr);
  }

  output_handle->reset(EagerTensorFromHandle(handle));
  dtype_setter(
      static_cast<tensorflow::DataType>(TFE_TensorHandleDataType(handle)));

  return true;
}

// Adds input and type attr to the op, and to the list of flattened
// inputs/attrs.
bool AddInputToOp(FastPathOpExecInfo* op_exec_info, PyObject* input,
                  const bool add_type_attr,
                  const tensorflow::OpDef::ArgDef& input_arg,
                  std::vector<tensorflow::Safe_PyObjectPtr>* flattened_attrs,
                  std::vector<tensorflow::Safe_PyObjectPtr>* flattened_inputs,
                  TFE_Op* op, TF_Status* status) {
  // py_eager_tensor's ownership is transferred to flattened_inputs if it is
  // required, else the object is destroyed and DECREF'd when the object goes
  // out of scope in this function.
  tensorflow::Safe_PyObjectPtr py_eager_tensor = nullptr;

  if (!ConvertToTensor(
          *op_exec_info, input, &py_eager_tensor,
          [&]() {
            if (input_arg.type() != tensorflow::DataType::DT_INVALID) {
              return input_arg.type();
            }
            return MaybeGetDTypeForAttr(input_arg.type_attr(), op_exec_info);
          },
          [&](const tensorflow::DataType dtype) {
            op_exec_info->cached_dtypes[input_arg.type_attr()] = dtype;
          },
          status)) {
    return false;
  }

  TFE_TensorHandle* input_handle = EagerTensor_Handle(py_eager_tensor.get());

  if (add_type_attr && !input_arg.type_attr().empty()) {
    auto dtype = TFE_TensorHandleDataType(input_handle);
    TFE_OpSetAttrType(op, input_arg.type_attr().data(), dtype);
    if (flattened_attrs != nullptr) {
      flattened_attrs->emplace_back(
          GetPythonObjectFromString(input_arg.type_attr()));
      flattened_attrs->emplace_back(PyLong_FromLong(dtype));
    }
  }

  if (flattened_inputs != nullptr) {
    flattened_inputs->emplace_back(std::move(py_eager_tensor));
  }

  TFE_OpAddInput(op, input_handle, status);
  if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) {
    return false;
  }

  return true;
}

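// Returns the C string for a Python device name, or nullptr when the device
// is unspecified (Py_None).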
const char* GetDeviceName(PyObject* py_device_name) {
  if (py_device_name != Py_None) {
    return TFE_GetPythonString(py_device_name);
  }
  return nullptr;
}

bool RaiseIfNotPySequence(PyObject* seq, const string& attr_name) {
  if (!PySequence_Check(seq)) {
    PyErr_SetString(PyExc_TypeError,
                    Printf("expected a sequence for attr %s, got %s instead",
                           attr_name.data(), seq->ob_type->tp_name)
                        .data());

    return false;
  }
  if (PyArray_Check(seq) &&
      PyArray_NDIM(reinterpret_cast<PyArrayObject*>(seq)) != 1) {
    PyErr_SetString(PyExc_ValueError,
                    Printf("expected a sequence for attr %s, got an ndarray "
                           "with rank %d instead",
                           attr_name.data(),
                           PyArray_NDIM(reinterpret_cast<PyArrayObject*>(seq)))
                        .data());
    return false;
  }
  return true;
}

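// Packages the flattened inputs/attrs into Python tuples, then runs the
// gradient-recording callback and any post-execution callbacks for a
// fast-path op.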
bool RunCallbacks(
    const FastPathOpExecInfo& op_exec_info, PyObject* args,
    int num_inferred_attrs,
    const std::vector<tensorflow::Safe_PyObjectPtr>& flattened_inputs,
    const std::vector<tensorflow::Safe_PyObjectPtr>& flattened_attrs,
    PyObject* flattened_result) {
  DCHECK(op_exec_info.run_callbacks);

  tensorflow::Safe_PyObjectPtr inputs(PyTuple_New(flattened_inputs.size()));
  for (int i = 0; i < flattened_inputs.size(); i++) {
    PyObject* input = flattened_inputs[i].get();
    Py_INCREF(input);
    PyTuple_SET_ITEM(inputs.get(), i, input);
  }

  int num_non_inferred_attrs = PyTuple_GET_SIZE(args) - num_inferred_attrs;
  int num_attrs = flattened_attrs.size() + num_non_inferred_attrs;
  tensorflow::Safe_PyObjectPtr attrs(PyTuple_New(num_attrs));

  for (int i = 0; i < num_non_inferred_attrs; i++) {
    auto* attr = PyTuple_GET_ITEM(args, num_inferred_attrs + i);
    Py_INCREF(attr);
    PyTuple_SET_ITEM(attrs.get(), i, attr);
  }

  for (int i = num_non_inferred_attrs; i < num_attrs; i++) {
    PyObject* attr_or_name =
        flattened_attrs.at(i - num_non_inferred_attrs).get();
    Py_INCREF(attr_or_name);
    PyTuple_SET_ITEM(attrs.get(), i, attr_or_name);
  }

  if (op_exec_info.run_gradient_callback) {
    if (!RecordGradient(op_exec_info.op_name, inputs.get(), attrs.get(),
                        flattened_result)) {
      return false;
    }
  }

  if (op_exec_info.run_post_exec_callbacks) {
    tensorflow::Safe_PyObjectPtr callback_args(
        Py_BuildValue("OOOOO", op_exec_info.op_name, inputs.get(), attrs.get(),
                      flattened_result, op_exec_info.name));
    for (Py_ssize_t i = 0; i < PyList_Size(op_exec_info.callbacks); i++) {
      PyObject* callback_fn = PyList_GET_ITEM(op_exec_info.callbacks, i);
      if (!PyCallable_Check(callback_fn)) {
        PyErr_SetString(
            PyExc_TypeError,
            Printf("expected a function for "
                   "post execution callback in index %ld, got %s instead",
                   i, callback_fn->ob_type->tp_name)
                .c_str());
        return false;
      }
      PyObject* callback_result =
          PyObject_CallObject(callback_fn, callback_args.get());
      if (!callback_result) {
        return false;
      }
      Py_DECREF(callback_result);
    }
  }

  return true;
}

}  // namespace

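// Fast-path execution of an eager op. `args` is a flat tuple laid out as
// (context, op_name, name, inputs..., attr_name, attr_value, ...); the slot
// indices come from the FAST_PATH_EXECUTE_ARG_* constants used below.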
PyObject* TFE_Py_FastPathExecute_C(PyObject* args) {
  tensorflow::profiler::TraceMe activity(
      "TFE_Py_FastPathExecute_C", tensorflow::profiler::TraceMeLevel::kInfo);
  Py_ssize_t args_size = PyTuple_GET_SIZE(args);
  if (args_size < FAST_PATH_EXECUTE_ARG_INPUT_START) {
    PyErr_SetString(
        PyExc_ValueError,
        Printf("There must be at least %d items in the input tuple.",
               FAST_PATH_EXECUTE_ARG_INPUT_START)
            .c_str());
    return nullptr;
  }

  FastPathOpExecInfo op_exec_info;

  PyObject* py_eager_context =
      PyTuple_GET_ITEM(args, FAST_PATH_EXECUTE_ARG_CONTEXT);

  // TODO(edoper): Use interned string here
  PyObject* eager_context_handle =
      PyObject_GetAttrString(py_eager_context, "_context_handle");

  TFE_Context* ctx = reinterpret_cast<TFE_Context*>(
      PyCapsule_GetPointer(eager_context_handle, nullptr));
  op_exec_info.ctx = ctx;
  op_exec_info.args = args;

  if (ctx == nullptr) {
    // The context hasn't been initialized. It will be in the slow path.
    RaiseFallbackException(
        "This function does not handle the case of the path where "
        "all inputs are not already EagerTensors.");
    return nullptr;
  }

  auto* tld = tensorflow::GetEagerContextThreadLocalData(py_eager_context);
  if (tld == nullptr) {
    return nullptr;
  }
  op_exec_info.device_name = GetDeviceName(tld->device_name.get());
  op_exec_info.callbacks = tld->op_callbacks.get();

  op_exec_info.op_name = PyTuple_GET_ITEM(args, FAST_PATH_EXECUTE_ARG_OP_NAME);
  op_exec_info.name = PyTuple_GET_ITEM(args, FAST_PATH_EXECUTE_ARG_NAME);

  // TODO(nareshmodi): Add a benchmark for the fast-path with gradient
  // callbacks (similar to benchmark_tf_gradient_function_*). Also consider
  // using an InlinedVector for flattened_attrs and flattened_inputs if the
  // benchmarks point out problems with heap allocs.
  op_exec_info.run_gradient_callback =
      !*ThreadTapeIsStopped() && HasAccumulatorOrTape();
  op_exec_info.run_post_exec_callbacks =
      op_exec_info.callbacks != Py_None &&
      PyList_Size(op_exec_info.callbacks) > 0;
  op_exec_info.run_callbacks = op_exec_info.run_gradient_callback ||
                               op_exec_info.run_post_exec_callbacks;

  TF_Status* status = GetStatus();
  const char* op_name = TFE_GetPythonString(op_exec_info.op_name);
  if (op_name == nullptr) {
    PyErr_SetString(PyExc_TypeError,
                    Printf("expected a string for op_name, got %s instead",
                           op_exec_info.op_name->ob_type->tp_name)
                        .c_str());
    return nullptr;
  }

  TFE_Op* op = GetOp(ctx, op_name, op_exec_info.device_name, status);

  auto cleaner = tensorflow::gtl::MakeCleanup([status, ctx, op] {
    ReturnStatus(status);
    ReturnOp(ctx, op);
  });

  if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) {
    return nullptr;
  }

  tensorflow::unwrap(op)->SetStackTrace(tensorflow::GetStackTrace(
      tensorflow::StackTrace::kStackTraceInitialSize));

  const tensorflow::OpDef* op_def = tensorflow::unwrap(op)->OpDef();
  if (op_def == nullptr) return nullptr;

  if (args_size <
      FAST_PATH_EXECUTE_ARG_INPUT_START + op_def->input_arg_size()) {
    PyErr_SetString(
        PyExc_ValueError,
        Printf("Tuple size smaller than intended. Expected to be at least %d, "
               "was %ld",
               FAST_PATH_EXECUTE_ARG_INPUT_START + op_def->input_arg_size(),
               args_size)
            .c_str());
    return nullptr;
  }

  if (!CheckInputsOk(args, FAST_PATH_EXECUTE_ARG_INPUT_START, *op_def)) {
    RaiseFallbackException(
        "This function does not handle the case of the path where "
        "all inputs are not already EagerTensors.");
    return nullptr;
  }

  op_exec_info.attr_to_inputs_map = GetAttrToInputsMapHoldingGIL(*op_def);
  op_exec_info.default_dtypes = GetAttrToDefaultsMapHoldingGIL(*op_def);

  // Mapping of attr name to size - used to calculate the number of values
  // to be expected by the TFE_Execute run.
  tensorflow::gtl::FlatMap<string, int64_t> attr_list_sizes;

  // Set non-inferred attrs, including setting defaults if the attr is passed
  // in as None.
  for (int i = FAST_PATH_EXECUTE_ARG_INPUT_START + op_def->input_arg_size();
       i < args_size; i += 2) {
    PyObject* py_attr_name = PyTuple_GET_ITEM(args, i);
    const char* attr_name = TFE_GetPythonString(py_attr_name);
    PyObject* py_attr_value = PyTuple_GET_ITEM(args, i + 1);

    // Not creating an index since most of the time there are not more than a
    // few attrs.
    // TODO(nareshmodi): Maybe include the index as part of the
    // OpRegistrationData.
    for (const auto& attr : op_def->attr()) {
      if (tensorflow::StringPiece(attr_name) == attr.name()) {
        SetOpAttrWithDefaults(ctx, op, attr, attr_name, py_attr_value,
                              &attr_list_sizes, status);

        if (!status->status.ok()) {
          VLOG(1) << "Falling back to slow path for Op \"" << op_def->name()
                  << "\" since we are unable to set the value for attr \""
                  << attr.name() << "\" due to: " << TF_Message(status);
          RaiseFallbackException(TF_Message(status));
          return nullptr;
        }

        break;
      }
    }
  }

  // Flat attrs and inputs as required by the record_gradient call. The attrs
  // here only contain inferred attrs (non-inferred attrs are added directly
  // from the input args).
  // All items in flattened_attrs and flattened_inputs contain
  // Safe_PyObjectPtr - any time something steals a reference to this, it must
  // INCREF.
  // TODO(nareshmodi): figure out why PyList_New/PyList_Append don't work
  // directly.
  std::unique_ptr<std::vector<tensorflow::Safe_PyObjectPtr>> flattened_attrs =
      nullptr;
  std::unique_ptr<std::vector<tensorflow::Safe_PyObjectPtr>> flattened_inputs =
      nullptr;

  // TODO(nareshmodi): Encapsulate callbacks information into a struct.
  if (op_exec_info.run_callbacks) {
    flattened_attrs.reset(new std::vector<tensorflow::Safe_PyObjectPtr>);
    flattened_inputs.reset(new std::vector<tensorflow::Safe_PyObjectPtr>);
  }

  // Add inferred attrs and inputs.
  // The following code might set duplicate type attrs. This will result in
  // the CacheKey for the generated AttrBuilder possibly differing from
  // those where the type attrs are correctly set. Inconsistent CacheKeys
  // for ops means that there might be unnecessarily duplicated kernels.
  // TODO(nareshmodi): Fix this.
  for (int i = 0; i < op_def->input_arg_size(); i++) {
    const auto& input_arg = op_def->input_arg(i);

    PyObject* input =
        PyTuple_GET_ITEM(args, FAST_PATH_EXECUTE_ARG_INPUT_START + i);
    if (!input_arg.number_attr().empty()) {
      // The item is a homogeneous list.
      if (!RaiseIfNotPySequence(input, input_arg.number_attr())) return nullptr;
      tensorflow::Safe_PyObjectPtr fast_input(
          PySequence_Fast(input, "Could not parse sequence."));
      if (fast_input.get() == nullptr) {
        return nullptr;
      }
      Py_ssize_t len = PySequence_Fast_GET_SIZE(fast_input.get());
      PyObject** fast_input_array = PySequence_Fast_ITEMS(fast_input.get());

      TFE_OpSetAttrInt(op, input_arg.number_attr().data(), len);
      if (op_exec_info.run_callbacks) {
        flattened_attrs->emplace_back(
            GetPythonObjectFromString(input_arg.number_attr()));
        flattened_attrs->emplace_back(PyLong_FromLong(len));
      }
      attr_list_sizes[input_arg.number_attr()] = len;

      if (len > 0) {
        // First item adds the type attr.
        if (!AddInputToOp(&op_exec_info, fast_input_array[0], true, input_arg,
                          flattened_attrs.get(), flattened_inputs.get(), op,
                          status)) {
          return nullptr;
        }

        for (Py_ssize_t j = 1; j < len; j++) {
          // Since the list is homogeneous, we don't need to re-add the attr.
          if (!AddInputToOp(&op_exec_info, fast_input_array[j], false,
                            input_arg, nullptr /* flattened_attrs */,
                            flattened_inputs.get(), op, status)) {
            return nullptr;
          }
        }
      }
    } else if (!input_arg.type_list_attr().empty()) {
      // The item is a heterogeneous list.
      if (!RaiseIfNotPySequence(input, input_arg.type_list_attr())) {
        return nullptr;
      }
      tensorflow::Safe_PyObjectPtr fast_input(
          PySequence_Fast(input, "Could not parse sequence."));
      if (fast_input.get() == nullptr) {
        return nullptr;
      }
      const string& attr_name = input_arg.type_list_attr();
      Py_ssize_t len = PySequence_Fast_GET_SIZE(fast_input.get());
      PyObject** fast_input_array = PySequence_Fast_ITEMS(fast_input.get());
      tensorflow::gtl::InlinedVector<TF_DataType, 4> attr_value(len);
      PyObject* py_attr_value = nullptr;
      if (op_exec_info.run_callbacks) {
        py_attr_value = PyTuple_New(len);
      }
      for (Py_ssize_t j = 0; j < len; j++) {
        PyObject* py_input = fast_input_array[j];
        tensorflow::Safe_PyObjectPtr py_eager_tensor;
        if (!ConvertToTensor(
                op_exec_info, py_input, &py_eager_tensor,
                []() { return tensorflow::DT_INVALID; },
                [](const tensorflow::DataType dtype) {}, status)) {
          return nullptr;
        }

        TFE_TensorHandle* input_handle =
            EagerTensor_Handle(py_eager_tensor.get());

        attr_value[j] = TFE_TensorHandleDataType(input_handle);

        TFE_OpAddInput(op, input_handle, status);
        if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) {
          return nullptr;
        }

        if (op_exec_info.run_callbacks) {
          flattened_inputs->emplace_back(std::move(py_eager_tensor));

          PyTuple_SET_ITEM(py_attr_value, j, PyLong_FromLong(attr_value[j]));
        }
      }
      if (op_exec_info.run_callbacks) {
        flattened_attrs->emplace_back(GetPythonObjectFromString(attr_name));
        flattened_attrs->emplace_back(py_attr_value);
      }
      TFE_OpSetAttrTypeList(op, attr_name.data(), attr_value.data(),
                            attr_value.size());
      attr_list_sizes[attr_name] = len;
    } else {
      // The item is a single item.
      if (!AddInputToOp(&op_exec_info, input, true, input_arg,
                        flattened_attrs.get(), flattened_inputs.get(), op,
                        status)) {
        return nullptr;
      }
    }
  }

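  // Compute how many return values TFE_Execute should expect, expanding list
  // outputs using the attr sizes recorded above.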
  int64_t num_outputs = 0;
  for (int i = 0; i < op_def->output_arg_size(); i++) {
    const auto& output_arg = op_def->output_arg(i);
    int64_t delta = 1;
    if (!output_arg.number_attr().empty()) {
      delta = attr_list_sizes[output_arg.number_attr()];
    } else if (!output_arg.type_list_attr().empty()) {
      delta = attr_list_sizes[output_arg.type_list_attr()];
    }
    if (delta < 0) {
      RaiseFallbackException(
          "Attributes suggest that the size of an output list is less than 0");
      return nullptr;
    }
    num_outputs += delta;
  }

  // If number of retvals is larger than int32, we error out.
  if (static_cast<int64_t>(static_cast<int32_t>(num_outputs)) != num_outputs) {
    PyErr_SetString(
        PyExc_ValueError,
        Printf("Number of outputs is too big: %ld", num_outputs).c_str());
    return nullptr;
  }
  int num_retvals = num_outputs;

  tensorflow::gtl::InlinedVector<TFE_TensorHandle*, 2> retvals(num_retvals);

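  // Release the GIL while the kernel runs so other Python threads can make
  // progress.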
  Py_BEGIN_ALLOW_THREADS;
  TFE_Execute(op, retvals.data(), &num_retvals, status);
  Py_END_ALLOW_THREADS;

  if (!status->status.ok()) {
    // Augment the status with the op_name for easier debugging similar to
    // TFE_Py_Execute.
    status->status = tensorflow::errors::CreateWithUpdatedMessage(
        status->status, tensorflow::strings::StrCat(
                            TF_Message(status), " [Op:",
                            TFE_GetPythonString(op_exec_info.op_name), "]"));
    MaybeRaiseExceptionFromTFStatus(status, nullptr);
    return nullptr;
  }

  tensorflow::Safe_PyObjectPtr flat_result(PyList_New(num_retvals));
  for (int i = 0; i < num_retvals; ++i) {
    PyList_SET_ITEM(flat_result.get(), i, EagerTensorFromHandle(retvals[i]));
  }

  if (op_exec_info.run_callbacks) {
    if (!RunCallbacks(
            op_exec_info, args,
            FAST_PATH_EXECUTE_ARG_INPUT_START + op_def->input_arg_size(),
            *flattened_inputs, *flattened_attrs, flat_result.get())) {
      return nullptr;
    }
  }

  // Unflatten results.
  if (op_def->output_arg_size() == 0) {
    Py_RETURN_NONE;
  }

  if (op_def->output_arg_size() == 1) {
    if (!op_def->output_arg(0).number_attr().empty() ||
        !op_def->output_arg(0).type_list_attr().empty()) {
      return flat_result.release();
    } else {
      auto* result = PyList_GET_ITEM(flat_result.get(), 0);
      Py_INCREF(result);
      return result;
    }
  }

  // Correctly output the results that are made into a namedtuple.
  PyObject* result = PyList_New(op_def->output_arg_size());
  int flat_result_index = 0;
  for (int i = 0; i < op_def->output_arg_size(); i++) {
    if (!op_def->output_arg(i).number_attr().empty()) {
      int list_length = attr_list_sizes[op_def->output_arg(i).number_attr()];
      PyObject* inner_list = PyList_New(list_length);
      for (int j = 0; j < list_length; j++) {
        PyObject* obj = PyList_GET_ITEM(flat_result.get(), flat_result_index++);
        Py_INCREF(obj);
        PyList_SET_ITEM(inner_list, j, obj);
      }
      PyList_SET_ITEM(result, i, inner_list);
    } else if (!op_def->output_arg(i).type_list_attr().empty()) {
      int list_length = attr_list_sizes[op_def->output_arg(i).type_list_attr()];
      PyObject* inner_list = PyList_New(list_length);
      for (int j = 0; j < list_length; j++) {
        PyObject* obj = PyList_GET_ITEM(flat_result.get(), flat_result_index++);
        Py_INCREF(obj);
        PyList_SET_ITEM(inner_list, j, obj);
      }
      PyList_SET_ITEM(result, i, inner_list);
    } else {
      PyObject* obj = PyList_GET_ITEM(flat_result.get(), flat_result_index++);
      Py_INCREF(obj);
      PyList_SET_ITEM(result, i, obj);
    }
  }
  return result;
}

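// Public wrapper around RecordGradient that returns None without doing any
// work when no tape or accumulator is active on this thread.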
PyObject* TFE_Py_RecordGradient(PyObject* op_name, PyObject* inputs,
                                PyObject* attrs, PyObject* results,
                                PyObject* forward_pass_name_scope) {
  if (*ThreadTapeIsStopped() || !HasAccumulatorOrTape()) {
    Py_RETURN_NONE;
  }

  return RecordGradient(op_name, inputs, attrs, results,
                        forward_pass_name_scope);
}

namespace {
const char kTensor[] = "T";
const char kList[] = "L";
const char kListEnd[] = "l";
const char kTuple[] = "U";
const char kTupleEnd[] = "u";
const char kDIter[] = "I";
const char kDict[] = "D";
const char kRaw[] = "R";
const char kResourceVariable[] = "r";
const char kShape[] = "s";
const char kShapeDelim[] = "-";
const char kDType[] = "d";
const char kNone[] = "n";
const char kCompositeTensor[] = "C";
const char kAttrs[] = "A";
const char kAttrsEnd[] = "a";
const char kName[] = "'";
const char kNameEnd[] = "'";
const char kLocalIdDelim[] = "_";

// Container for storing generated string encoding as well as the raw python
// objects that were not included in the string.
struct EncodeResult {
  string str;
  std::vector<PyObject*> objects;

  PyObject* ToPyTuple() {
    PyObject* result = PyTuple_New(2);

    PyTuple_SET_ITEM(result, 0, GetPythonObjectFromString(str));

    if (objects.empty()) {
      Py_INCREF(Py_None);
      PyTuple_SET_ITEM(result, 1, Py_None);
    } else {
      PyObject* objects_tuple = PyTuple_New(objects.size());

      for (int i = 0; i < objects.size(); i++) {
        PyTuple_SET_ITEM(objects_tuple, i, objects[i]);
      }

      PyTuple_SET_ITEM(result, 1, objects_tuple);
    }

    return result;
  }
};

// Gives each unique resource_id a unique incremental local_id. Provides a
// string encoding that informs an order and uniqueness sensitive input
// signature.
// This class is not thread safe and is not meant to be shared across threads.
class LocalResourceIdMap {
 public:
  // When the resource ID is known (such as for OwnedIterator).
  // Returns the existing local ID (if present) or a new unique one.
  int AddResourceId(int resource_id) {
    const auto& it = resource_id_to_local_id_.find(resource_id);
    if (it == resource_id_to_local_id_.end()) {
      resource_id_to_local_id_[resource_id] = next_local_id_;
      return next_local_id_++;
    } else {
      return it->second;
    }
  }

  // When the resource ID is not known (such as for IteratorSpec).
  // Returns a new unique local ID.
  int AddUnknownResource() { return next_local_id_++; }

 private:
  absl::flat_hash_map<int, int> resource_id_to_local_id_;
  int next_local_id_ = 0;
};

// Contains encoding configuration, intermediary data and result.
struct EncodingContext {
  bool include_tensor_ranks_only;
  bool encode_variable_by_resource_id;

  LocalResourceIdMap resource_id_map;
  EncodeResult result;
};

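// Encodes a Tensor or tf.TensorSpec as kTensor, then (for specs) an optional
// quoted name, then the dtype enum and the shape. For illustration only: a
// float32 spec with shape [2, None] would encode roughly as "Td1s2-n" given
// the constants above.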
tensorflow::Status EncodeTensorOrTensorSpec(PyObject* arg, bool is_tensor_spec,
                                            EncodingContext& context) {
  absl::StrAppend(&context.result.str, kTensor);

  if (is_tensor_spec) {
    tensorflow::Safe_PyObjectPtr name(PyObject_GetAttrString(arg, "name"));
    if (name != nullptr && name.get() != Py_None) {
      absl::StrAppend(&context.result.str, kName,
                      TFE_GetPythonString(name.get()), kNameEnd);
    }
  }

  tensorflow::Safe_PyObjectPtr dtype_object(
      PyObject_GetAttrString(arg, "dtype"));
  if (dtype_object == nullptr) {
    return tensorflow::errors::InvalidArgument(
        "tf.TensorSpec object doesn't have dtype() attr.");
  }

  tensorflow::Safe_PyObjectPtr dtype_enum(
      PyObject_GetAttrString(dtype_object.get(), "_type_enum"));
  if (dtype_enum == nullptr) {
    return tensorflow::errors::InvalidArgument(
        "tf.TensorSpec's dtype object doesn't have _type_enum() "
        "attr.");
  }

  tensorflow::DataType dtype =
      static_cast<tensorflow::DataType>(MakeInt(dtype_enum.get()));
  absl::StrAppend(&context.result.str, kDType, dtype);

  tensorflow::Safe_PyObjectPtr shape_tuple(
      PyObject_GetAttrString(arg, "shape"));
  if (shape_tuple == nullptr) {
    return tensorflow::errors::InvalidArgument(
        "tf.TensorSpec object doesn't have shape() attr.");
  }

  // Use GetAttrString here: fetching the attribute via a freshly created
  // PyUnicode key leaked the key object (an ownership leak with reference
  // count of 1); GetAttrString avoids the temporary entirely.
  tensorflow::Safe_PyObjectPtr rank(
      PyObject_GetAttrString(shape_tuple.get(), "rank"));

  if (rank == nullptr || rank.get() == Py_None) {
    // Unknown shape, encode that directly.
    absl::StrAppend(&context.result.str, kNone);
    return tensorflow::Status::OK();
  }

  absl::StrAppend(&context.result.str, kShape);

  tensorflow::Safe_PyObjectPtr shape_seq(PySequence_Fast(
      shape_tuple.get(), "shape_tuple didn't return a sequence"));

  int len = MakeInt(rank.get());
  PyObject** shape_seq_array = PySequence_Fast_ITEMS(shape_seq.get());

  if (context.include_tensor_ranks_only) {
    absl::StrAppend(&context.result.str, len);
  } else {
    for (int i = 0; i < len; ++i) {
      // Can be None, int or a Dimension object.
      PyObject* dimension = shape_seq_array[i];

      // If it is a Dimension object, then we must extract value from it first.
      bool is_dimension_class = PyObject_HasAttrString(dimension, "value");
      tensorflow::Safe_PyObjectPtr dimension_holder;
      if (is_dimension_class) {
        dimension_holder =
            tensorflow::make_safe(PyObject_GetAttrString(dimension, "value"));
        dimension = dimension_holder.get();
      }

      if (dimension == Py_None) {
        absl::StrAppend(&context.result.str, kNone);
      } else {
        absl::StrAppend(&context.result.str, MakeInt(dimension), kShapeDelim);
      }
    }
  }

  return tensorflow::Status::OK();
}

// TODO(b/199534088): Remove this function by using EncodeResource instead.
tensorflow::Status EncodeOwnedIterator(PyObject* arg,
                                       EncodingContext& context) {
  PyObject* type_spec(PyObject_GetAttrString(arg, "_type_spec"));
  if (type_spec == nullptr) {
    return tensorflow::errors::InvalidArgument(
        "Error while reading OwnedIterator._type_spec.");
  }
  context.result.objects.push_back(type_spec);

  // Add resource tracking
  tensorflow::Safe_PyObjectPtr itr_res(
      PyObject_GetAttrString(arg, "_iterator_resource"));
  if (itr_res == nullptr) {
    return tensorflow::errors::InvalidArgument(
        "Error while reading Dataset iterator resource.");
  }
  // OwnedIterator should ideally always provide a unique resource id.
  // TODO(b/199534088) Cases where resource_id is not provided need to be
  // fixed.
  if (tensorflow::swig::IsTensor(itr_res.get())) {
    absl::StrAppend(&context.result.str, kDIter);
    tensorflow::Safe_PyObjectPtr py_resource_id(
        PyObject_GetAttrString(itr_res.get(), "_id"));
    if (py_resource_id == nullptr) {
      return tensorflow::errors::InvalidArgument(
          "Error while reading Dataset iterator resource id.");
    }
    int resource_id = PyLong_AsSize_t(py_resource_id.get());
    if (resource_id < 0) {
      return tensorflow::errors::InvalidArgument("PyLong_AsSize_t failure");
    }
    int local_id = context.resource_id_map.AddResourceId(resource_id);
    absl::StrAppend(&context.result.str, local_id, kLocalIdDelim);
  } else {
    // If '_iterator_resource' is not a Tensor, there is no resource id.
    // Instead we treat it the same way as a CompositeTensor.
    absl::StrAppend(&context.result.str, kCompositeTensor);
  }
  return tensorflow::Status::OK();
}

tensorflow::Status EncodeResource(PyObject* arg, EncodingContext& context) {
  absl::StrAppend(&context.result.str, kResourceVariable);
  tensorflow::Safe_PyObjectPtr py_resource_id(
      PyObject_CallMethod(arg, "__tf_resource_id__", nullptr));
  DCHECK(py_resource_id != nullptr);

  int resource_id = PyLong_AsSize_t(py_resource_id.get());
  DCHECK_GE(resource_id, 0);
  int local_id = context.resource_id_map.AddResourceId(resource_id);
  absl::StrAppend(&context.result.str, local_id, kLocalIdDelim);

  tensorflow::Safe_PyObjectPtr type_spec(
      PyObject_CallMethod(arg, "__tf_function_cache_spec__", nullptr));
  // Check the spec before dereferencing it (the check previously came after
  // the StrAppend that used it).
  DCHECK(type_spec != nullptr);
  absl::StrAppend(&context.result.str, PyUnicode_AsUTF8(type_spec.get()));

  return tensorflow::Status::OK();
}

tensorflow::Status EncodeArgHelperInternal(PyObject* arg,
                                           EncodingContext& context);

// Encodes the elements of a list/tuple, bracketed by the caller-supplied
// begin/end markers; None elements encode as kNone, everything else is
// encoded recursively.
tensorflow::Status EncodeSequence(PyObject* arg, const char* type,
                                  const char* end_type,
                                  EncodingContext& context) {
  tensorflow::Safe_PyObjectPtr arg_seq(
      PySequence_Fast(arg, "unable to create seq from list/tuple"));

  absl::StrAppend(&context.result.str, type);
  int len = PySequence_Fast_GET_SIZE(arg_seq.get());
  PyObject** arg_seq_array = PySequence_Fast_ITEMS(arg_seq.get());
  for (int i = 0; i < len; ++i) {
    PyObject* item = arg_seq_array[i];
    if (item == Py_None) {
      absl::StrAppend(&context.result.str, kNone);
    } else {
      TF_RETURN_IF_ERROR(EncodeArgHelperInternal(item, context));
    }
  }
  absl::StrAppend(&context.result.str, end_type);

  return tensorflow::Status::OK();
}

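// Encodes a dict-like mapping in sorted key order so the resulting cache key
// is deterministic.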
tensorflow::Status EncodeMapping(PyObject* arg, EncodingContext& context) {
  tensorflow::Safe_PyObjectPtr keys(tensorflow::swig::MappingKeys(arg));
  if (PyList_Sort(keys.get()) == -1) {
    return tensorflow::errors::Internal("Unable to sort keys");
  }

  absl::StrAppend(&context.result.str, kDict);
  int len = PyList_Size(keys.get());

  for (int i = 0; i < len; i++) {
    PyObject* key = PyList_GetItem(keys.get(), i);
    TF_RETURN_IF_ERROR(EncodeArgHelperInternal(key, context));
    tensorflow::Safe_PyObjectPtr value(PyObject_GetItem(arg, key));
    TF_RETURN_IF_ERROR(EncodeArgHelperInternal(value.get(), context));
  }

  return tensorflow::Status::OK();
}

| 4234 | tensorflow::Status EncodeCompositeTensor(PyObject* arg, | ||||
| 4235 | EncodingContext& context) { | ||||
| 4236 | absl::StrAppend(&context.result.str, kCompositeTensor); | ||||
| 4237 | PyObject* type_spec(PyObject_GetAttrString(arg, "_type_spec")); | ||||
| 4238 | if (type_spec == nullptr) { | ||||
| 4239 | return tensorflow::errors::InvalidArgument( | ||||
| 4240 | "Error while reading CompositeTensor._type_spec."); | ||||
| 4241 | } | ||||
| 4242 | context.result.objects.push_back(type_spec); | ||||
| 4243 | |||||
| 4244 | return tensorflow::Status::OK(); | ||||
| 4245 | } | ||||
| 4246 | |||||
| 4247 | tensorflow::Status EncodeTypeSpec(PyObject* arg, EncodingContext& context) { | ||||
| 4248 | absl::StrAppend(&context.result.str, kRaw); | ||||
| 4249 | Py_INCREF(arg)_Py_INCREF(((PyObject*)(arg))); | ||||
| 4250 | context.result.objects.push_back(arg); | ||||
| 4251 | return tensorflow::Status::OK(); | ||||
| 4252 | } | ||||
| 4253 | |||||
| 4254 | tensorflow::Status EncodeAttrs(PyObject* arg, EncodingContext& context) { | ||||
| 4255 | absl::StrAppend(&context.result.str, kAttrs); | ||||
  tensorflow::Safe_PyObjectPtr attrs(
      PyObject_GetAttrString(arg, "__attrs_attrs__"));
  tensorflow::Safe_PyObjectPtr iter(PyObject_GetIter(attrs.get()));
  for (tensorflow::Safe_PyObjectPtr item(PyIter_Next(iter.get())); item;
       item.reset(PyIter_Next(iter.get()))) {
    tensorflow::Safe_PyObjectPtr name(
        PyObject_GetAttrString(item.get(), "name"));
    tensorflow::Safe_PyObjectPtr attr_arg(PyObject_GetAttr(arg, name.get()));
    TF_RETURN_IF_ERROR(EncodeArgHelperInternal(attr_arg.get(), context));
  }
  absl::StrAppend(&context.result.str, kAttrsEnd);

  return tensorflow::Status::OK();
}

tensorflow::Status EncodeUnidentified(PyObject* arg, EncodingContext& context) {
  // We hold a weak reference because cache keys live practically forever;
  // holding a strong reference to an arbitrary argument could leak heavy
  // objects.
  PyObject* object = PyWeakref_NewRef(arg, nullptr);
  if (object == nullptr) {
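    // `arg` does not support weak references (true of several built-in types,
    // e.g. int and str); fall back to holding a strong reference instead.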
    PyErr_Clear();
    object = arg;
    Py_INCREF(object);
  }

  absl::StrAppend(&context.result.str, kRaw);
  context.result.objects.push_back(object);
  return tensorflow::Status::OK();
}

tensorflow::Status EncodeArgHelperInternal(PyObject* arg,
                                           EncodingContext& context) {
  if (tensorflow::swig::IsTensorSpec(arg)) {
    TF_RETURN_IF_ERROR(EncodeTensorOrTensorSpec(arg, true, context));
  } else if (tensorflow::swig::IsTensor(arg)) {
    TF_RETURN_IF_ERROR(EncodeTensorOrTensorSpec(arg, false, context));
  } else if (tensorflow::swig::IsOwnedIterator(arg)) {
    TF_RETURN_IF_ERROR(EncodeOwnedIterator(arg, context));
  } else if (PyList_Check(arg)) {
    TF_RETURN_IF_ERROR(EncodeSequence(arg, kList, kListEnd, context));
  } else if (tensorflow::swig::IsTuple(arg)) {
    TF_RETURN_IF_ERROR(EncodeSequence(arg, kTuple, kTupleEnd, context));
  } else if (tensorflow::swig::IsMapping(arg)) {
    TF_RETURN_IF_ERROR(EncodeMapping(arg, context));
  } else if (tensorflow::swig::IsCompositeTensor(arg)) {
    TF_RETURN_IF_ERROR(EncodeCompositeTensor(arg, context));
  } else if (tensorflow::swig::IsTypeSpec(arg)) {
    TF_RETURN_IF_ERROR(EncodeTypeSpec(arg, context));
  } else if (tensorflow::swig::IsAttrs(arg)) {
    TF_RETURN_IF_ERROR(EncodeAttrs(arg, context));
  } else if (tensorflow::swig::IsResourceVariable(arg) &&
             context.encode_variable_by_resource_id) {
    TF_RETURN_IF_ERROR(EncodeResource(arg, context));
  } else {
    TF_RETURN_IF_ERROR(EncodeUnidentified(arg, context));
  }

  return tensorflow::Status::OK();
}

tensorflow::Status TFE_Py_EncodeArgHelper(PyObject* arg,
                                          EncodingContext& context) {
  return EncodeArgHelperInternal(arg, context);
}

}  // namespace

// `defun` uses dtypes and shapes instead of `Tensors` as cache keys. Dtypes
// are used because TensorFlow graphs are not parametric w.r.t. dtypes. Shapes
// are used both for performance, since much TensorFlow code specializes on
// known shapes to produce slimmer graphs, and for correctness, since some
// high-level APIs require shapes to be fully known.
//
// `include_tensor_ranks_only` allows caching on arguments excluding shape
// info, so that a slow path using relaxed shapes can rely on a cache key that
// excludes shapes.
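//
// Illustrative call (editor's sketch, not from the original source; `arg` is
// any Python object, and the returned tuple is a new reference owned by the
// caller):
//
//   PyObject* key = TFE_Py_EncodeArg(arg,
//                                    /*include_tensor_ranks_only=*/true,
//                                    /*encode_variable_by_resource_id=*/true);
//   if (key == nullptr) {
//     return nullptr;  // A Python exception has already been set.
//   }
//   // ... use `key` to look up or populate the function cache ...
//   Py_DECREF(key);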
PyObject* TFE_Py_EncodeArg(PyObject* arg, bool include_tensor_ranks_only,
                           bool encode_variable_by_resource_id) {
  EncodingContext context;
  context.include_tensor_ranks_only = include_tensor_ranks_only;
  context.encode_variable_by_resource_id = encode_variable_by_resource_id;
  const auto status = TFE_Py_EncodeArgHelper(arg, context);
  if (MaybeRaiseExceptionFromStatus(status, nullptr)) {
    return nullptr;
  }

  return context.result.ToPyTuple();
}

// Prints incoming messages directly to Python's stdout using Python's C API.
// This is necessary in Jupyter notebooks and Colabs, where messages written
// to the C-level stdout don't reach the notebook cell outputs, but calls to
// Python's stdout do.
void PrintToPythonStdout(const char* msg) {
  if (Py_IsInitialized()) {
    PyGILState_STATE py_threadstate;
    py_threadstate = PyGILState_Ensure();

    string string_msg = msg;
    // PySys_WriteStdout truncates strings over 1000 bytes, so we write the
    // message in chunks small enough not to be truncated.
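    // For example, a 2000-byte message is emitted as three writes of 900,
    // 900, and 200 bytes.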
    int CHUNK_SIZE = 900;
    auto len = string_msg.length();
    for (int i = 0; i < len; i += CHUNK_SIZE) {
      PySys_WriteStdout("%s", string_msg.substr(i, CHUNK_SIZE).c_str());
    }

    // Force a flush so that printed newlines aren't interleaved in some
    // Colab environments.
    PyRun_SimpleString("import sys; sys.stdout.flush()");

    PyGILState_Release(py_threadstate);
  }
}

// Registers PrintToPythonStdout as a log listener so that printing works in
// Colabs and Jupyter notebooks.
void TFE_Py_EnableInteractivePythonLogging() {
  static bool enabled_interactive_logging = false;
  if (!enabled_interactive_logging) {
    enabled_interactive_logging = true;
    TF_RegisterLogListener(PrintToPythonStdout);
  }
}

namespace {
// Weak reference to the currently-active Python eager Context object.
PyObject* weak_eager_context = nullptr;
}  // namespace

PyObject* TFE_Py_SetEagerContext(PyObject* py_context) {
  Py_XDECREF(weak_eager_context);
  weak_eager_context = PyWeakref_NewRef(py_context, nullptr);
  if (weak_eager_context == nullptr) {
    return nullptr;
  }
  Py_RETURN_NONE;
}

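// Returns a new reference to the active Python eager Context object, or sets
// a Python error and returns nullptr if no context is set or the context has
// already been destroyed.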
PyObject* GetPyEagerContext() {
  if (weak_eager_context == nullptr) {
    PyErr_SetString(PyExc_RuntimeError, "Python eager context is not set");
    return nullptr;
  }
  PyObject* py_context = PyWeakref_GET_OBJECT(weak_eager_context);
  if (py_context == Py_None) {
    PyErr_SetString(PyExc_RuntimeError, "Eager context has been destroyed");
    return nullptr;
  }
  Py_INCREF(py_context);
  return py_context;
}

namespace {

// Default values for thread_local_data fields.
struct EagerContextThreadLocalDataDefaults {
  tensorflow::Safe_PyObjectPtr is_eager;
  tensorflow::Safe_PyObjectPtr device_spec;
};

// Maps each py_eager_context object to its thread_local_data.
//
// Note: we need to use the Python Context object as the key here (and not
// its handle object), because the handle object isn't created until the
// context is initialized, but thread_local_data is potentially accessed
// before then.
using EagerContextThreadLocalDataMap = absl::flat_hash_map<
    PyObject*, std::unique_ptr<tensorflow::EagerContextThreadLocalData>>;
thread_local EagerContextThreadLocalDataMap*
    eager_context_thread_local_data_map = nullptr;

// Maps each py_eager_context object to its default values.
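// Unlike the per-thread map above, this map is process-global and shared by
// all threads; in practice it is only touched while the GIL is held.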
using EagerContextThreadLocalDataDefaultsMap =
    absl::flat_hash_map<PyObject*, EagerContextThreadLocalDataDefaults>;
EagerContextThreadLocalDataDefaultsMap*
    eager_context_thread_local_data_defaults = nullptr;

}  // namespace

namespace tensorflow {

void MakeEagerContextThreadLocalData(PyObject* py_eager_context,
                                     PyObject* is_eager,
                                     PyObject* device_spec) {
  DCheckPyGilState();
  if (eager_context_thread_local_data_defaults == nullptr) {
    absl::LeakCheckDisabler disabler;
    eager_context_thread_local_data_defaults =
        new EagerContextThreadLocalDataDefaultsMap();
  }
  if (eager_context_thread_local_data_defaults->count(py_eager_context) > 0) {
    PyErr_SetString(PyExc_AssertionError,
                    "MakeEagerContextThreadLocalData may not be called "
                    "twice on the same eager Context object.");
  }

  auto& defaults =
      (*eager_context_thread_local_data_defaults)[py_eager_context];
  Py_INCREF(is_eager);
  defaults.is_eager.reset(is_eager);
  Py_INCREF(device_spec);
  defaults.device_spec.reset(device_spec);
}

EagerContextThreadLocalData* GetEagerContextThreadLocalData(
    PyObject* py_eager_context) {
  if (eager_context_thread_local_data_defaults == nullptr) {
    PyErr_SetString(PyExc_AssertionError,
                    "MakeEagerContextThreadLocalData must be called "
                    "before GetEagerContextThreadLocalData.");
    return nullptr;
  }
  auto defaults =
      eager_context_thread_local_data_defaults->find(py_eager_context);
  if (defaults == eager_context_thread_local_data_defaults->end()) {
    PyErr_SetString(PyExc_AssertionError,
                    "MakeEagerContextThreadLocalData must be called "
                    "before GetEagerContextThreadLocalData.");
    return nullptr;
  }

  if (eager_context_thread_local_data_map == nullptr) {
    absl::LeakCheckDisabler disabler;
    eager_context_thread_local_data_map = new EagerContextThreadLocalDataMap();
  }
  auto& thread_local_data =
      (*eager_context_thread_local_data_map)[py_eager_context];

  if (!thread_local_data) {
    thread_local_data.reset(new EagerContextThreadLocalData());

    Safe_PyObjectPtr is_eager(PyObject_CallFunctionObjArgs(
        defaults->second.is_eager.get(), nullptr));
    if (!is_eager) return nullptr;
    thread_local_data->is_eager = PyObject_IsTrue(is_eager.get());

#if PY_MAJOR_VERSION >= 3
    PyObject* scope_name = PyUnicode_FromString("");
#else
    PyObject* scope_name = PyString_FromString("");
#endif
    thread_local_data->scope_name.reset(scope_name);

#if PY_MAJOR_VERSION >= 3
    PyObject* device_name = PyUnicode_FromString("");
#else
    PyObject* device_name = PyString_FromString("");
#endif
    thread_local_data->device_name.reset(device_name);

    Py_INCREF(defaults->second.device_spec.get());
    thread_local_data->device_spec.reset(defaults->second.device_spec.get());

    Py_INCREF(Py_None);
    thread_local_data->function_call_options.reset(Py_None);

    Py_INCREF(Py_None);
    thread_local_data->executor.reset(Py_None);

    thread_local_data->op_callbacks.reset(PyList_New(0));
  }
  return thread_local_data.get();
}
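
// Typical lookup pattern (editor's sketch, not from the original source):
//
//   auto* tld = tensorflow::GetEagerContextThreadLocalData(py_context);
//   if (tld == nullptr) return nullptr;  // A Python error is already set.
//   if (tld->is_eager) { /* take the eager path */ }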
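
// Removes the calling thread's cached data and the shared defaults for
// `py_eager_context`. Entries created by other threads live in their own
// thread-local maps and are not touched here.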
void DestroyEagerContextThreadLocalData(PyObject* py_eager_context) {
  DCheckPyGilState();
  if (eager_context_thread_local_data_defaults) {
    eager_context_thread_local_data_defaults->erase(py_eager_context);
  }
  if (eager_context_thread_local_data_map) {
    eager_context_thread_local_data_map->erase(py_eager_context);
  }
}

}  // namespace tensorflow
// What follows is not part of pywrap_tfe_src.cc proper: it appears to be the
// stub the static analyzer appends to model PyUnicode_FromString as returning
// a new PyObject reference, so that reference ownership can be tracked.
#ifndef PyUnicode_FromString
struct _object;
typedef struct _object PyObject;
PyObject* clang_analyzer_PyObject_New_Reference();
PyObject* PyUnicode_FromString(const char* u) {
  return clang_analyzer_PyObject_New_Reference();
}
#else
#warning "API PyUnicode_FromString is defined as a macro."
#endif