| File: | .cache/bazel/_bazel_alan/39be661231df2a680c9b74265384c13c/execroot/org_tensorflow/tensorflow/python/client/tf_session_wrapper.cc |
| Warning: | line 279, column 15 PyObject ownership leak with reference count of 1 |
Press '?' to see keyboard shortcuts
Keyboard shortcuts:
| 1 | /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. | |||
| 2 | ||||
| 3 | Licensed under the Apache License, Version 2.0 (the "License"); | |||
| 4 | you may not use this file except in compliance with the License. | |||
| 5 | You may obtain a copy of the License at | |||
| 6 | ||||
| 7 | http://www.apache.org/licenses/LICENSE-2.0 | |||
| 8 | ||||
| 9 | Unless required by applicable law or agreed to in writing, software | |||
| 10 | distributed under the License is distributed on an "AS IS" BASIS, | |||
| 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| 12 | See the License for the specific language governing permissions and | |||
| 13 | limitations under the License. | |||
| 14 | ==============================================================================*/ | |||
| 15 | ||||
| 16 | #include "Python.h" | |||
| 17 | #include "absl/types/optional.h" | |||
| 18 | #include "third_party/eigen3/Eigen/Core" | |||
| 19 | #include "pybind11/chrono.h" | |||
| 20 | #include "pybind11/complex.h" | |||
| 21 | #include "pybind11/functional.h" | |||
| 22 | #include "pybind11/pybind11.h" | |||
| 23 | #include "pybind11/stl.h" | |||
| 24 | #include "tensorflow/c/c_api.h" | |||
| 25 | #include "tensorflow/c/c_api_experimental.h" | |||
| 26 | #include "tensorflow/c/c_api_internal.h" | |||
| 27 | #include "tensorflow/c/python_api.h" | |||
| 28 | #include "tensorflow/c/tf_datatype.h" | |||
| 29 | #include "tensorflow/core/distributed_runtime/server_lib.h" | |||
| 30 | #include "tensorflow/core/public/version.h" | |||
| 31 | #include "tensorflow/python/client/tf_session_helper.h" | |||
| 32 | #include "tensorflow/python/lib/core/numpy.h" | |||
| 33 | #include "tensorflow/python/lib/core/pybind11_lib.h" | |||
| 34 | #include "tensorflow/python/lib/core/pybind11_status.h" | |||
| 35 | #include "tensorflow/python/lib/core/safe_ptr.h" | |||
| 36 | ||||
| 37 | namespace pybind11 { | |||
| 38 | namespace detail { | |||
| 39 | // Convert between absl::optional and python. | |||
| 40 | // | |||
| 41 | // pybind11 supports std::optional, and absl::optional is meant to be a | |||
| 42 | // drop-in replacement for std::optional, so we can just use the built in | |||
| 43 | // implementation. | |||
| 44 | #ifndef ABSL_USES_STD_OPTIONAL | |||
| 45 | template <typename T> | |||
| 46 | struct type_caster<absl::optional<T>> | |||
| 47 | : public optional_caster<absl::optional<T>> {}; | |||
| 48 | template <> | |||
| 49 | struct type_caster<absl::nullopt_t> : public void_caster<absl::nullopt_t> {}; | |||
| 50 | #endif | |||
| 51 | ||||
| 52 | } // namespace detail | |||
| 53 | } // namespace pybind11 | |||
| 54 | ||||
| 55 | // TODO(amitpatankar): Consolidate Buffer methods into a separate header file. | |||
| 56 | TF_Buffer* ProtoStringToTFBuffer(PyObject* input) { | |||
| 57 | // Convert a Python string object to TF_Buffer. | |||
| 58 | char* c_string; | |||
| 59 | Py_ssize_t py_size; | |||
| 60 | // PyBytes_AsStringAndSize() does not copy but simply interprets the input | |||
| 61 | if (PyBytes_AsStringAndSize(input, &c_string, &py_size) == -1) { | |||
| 62 | // Python has raised an error (likely TypeError or UnicodeEncodeError). | |||
| 63 | throw py::error_already_set(); | |||
| 64 | } | |||
| 65 | return TF_NewBufferFromString(static_cast<void*>(c_string), | |||
| 66 | static_cast<size_t>(py_size)); | |||
| 67 | } | |||
| 68 | ||||
| 69 | // Copied from tf_session.i | |||
| 70 | // We have to do convoluted logic of passing in a vector of py::bytes. If we | |||
| 71 | // pass in strings they are freed prior to the necessary function calls. | |||
| 72 | tensorflow::NameVector ConvertPyListToNameVector( | |||
| 73 | const std::vector<py::bytes>& py_vector) { | |||
| 74 | tensorflow::NameVector temp; | |||
| 75 | for (size_t i = 0; i < py_vector.size(); ++i) { | |||
| 76 | const char* string_elem = PyBytes_AsString(py_vector.at(i).ptr()); | |||
| 77 | temp.push_back(string_elem); | |||
| 78 | } | |||
| 79 | return temp; | |||
| 80 | } | |||
| 81 | ||||
| 82 | namespace py = pybind11; | |||
| 83 | ||||
| 84 | PYBIND11_MAKE_OPAQUE(TF_Graph)namespace pybind11 { namespace detail { template<> class type_caster<TF_Graph> : public type_caster_base<TF_Graph > { }; }}; | |||
| 85 | PYBIND11_MAKE_OPAQUE(TF_Session)namespace pybind11 { namespace detail { template<> class type_caster<TF_Session> : public type_caster_base<TF_Session > { }; }}; | |||
| 86 | PYBIND11_MAKE_OPAQUE(TF_Operation)namespace pybind11 { namespace detail { template<> class type_caster<TF_Operation> : public type_caster_base< TF_Operation> { }; }}; | |||
| 87 | PYBIND11_MAKE_OPAQUE(TF_Buffer)namespace pybind11 { namespace detail { template<> class type_caster<TF_Buffer> : public type_caster_base<TF_Buffer > { }; }}; | |||
| 88 | PYBIND11_MAKE_OPAQUE(TF_ImportGraphDefOptions)namespace pybind11 { namespace detail { template<> class type_caster<TF_ImportGraphDefOptions> : public type_caster_base <TF_ImportGraphDefOptions> { }; }}; | |||
| 89 | PYBIND11_MAKE_OPAQUE(TF_ImportGraphDefResults)namespace pybind11 { namespace detail { template<> class type_caster<TF_ImportGraphDefResults> : public type_caster_base <TF_ImportGraphDefResults> { }; }}; | |||
| 90 | PYBIND11_MAKE_OPAQUE(TF_DeprecatedSession)namespace pybind11 { namespace detail { template<> class type_caster<TF_DeprecatedSession> : public type_caster_base <TF_DeprecatedSession> { }; }}; | |||
| 91 | PYBIND11_MAKE_OPAQUE(TF_OperationDescription)namespace pybind11 { namespace detail { template<> class type_caster<TF_OperationDescription> : public type_caster_base <TF_OperationDescription> { }; }}; | |||
| 92 | PYBIND11_MAKE_OPAQUE(TF_Library)namespace pybind11 { namespace detail { template<> class type_caster<TF_Library> : public type_caster_base<TF_Library > { }; }}; | |||
| 93 | PYBIND11_MAKE_OPAQUE(TF_SessionOptions)namespace pybind11 { namespace detail { template<> class type_caster<TF_SessionOptions> : public type_caster_base <TF_SessionOptions> { }; }}; | |||
| 94 | PYBIND11_MAKE_OPAQUE(TF_ApiDefMap)namespace pybind11 { namespace detail { template<> class type_caster<TF_ApiDefMap> : public type_caster_base< TF_ApiDefMap> { }; }}; | |||
| 95 | PYBIND11_MAKE_OPAQUE(TF_Server)namespace pybind11 { namespace detail { template<> class type_caster<TF_Server> : public type_caster_base<TF_Server > { }; }}; | |||
| 96 | PYBIND11_MAKE_OPAQUE(TF_DeviceList)namespace pybind11 { namespace detail { template<> class type_caster<TF_DeviceList> : public type_caster_base< TF_DeviceList> { }; }}; | |||
| 97 | PYBIND11_MAKE_OPAQUE(TF_Status)namespace pybind11 { namespace detail { template<> class type_caster<TF_Status> : public type_caster_base<TF_Status > { }; }}; | |||
| 98 | ||||
| 99 | PYBIND11_MODULE(_pywrap_tf_session, m)static ::pybind11::module_::module_def pybind11_module_def__pywrap_tf_session ; __attribute__ ((__unused__)) static void pybind11_init__pywrap_tf_session (::pybind11::module_ &); extern "C" __attribute__ ((__unused__ )) __attribute__ ((visibility("default"))) PyObject *PyInit__pywrap_tf_session (); extern "C" __attribute__ ((visibility("default"))) PyObject *PyInit__pywrap_tf_session() { { const char *compiled_ver = "3" "." "8"; const char *runtime_ver = Py_GetVersion(); size_t len = std::strlen(compiled_ver); if (std::strncmp(runtime_ver, compiled_ver , len) != 0 || (runtime_ver[len] >= '0' && runtime_ver [len] <= '9')) { PyErr_Format(PyExc_ImportError, "Python version mismatch: module was compiled for Python %s, " "but the interpreter version is incompatible: %s.", compiled_ver , runtime_ver); return nullptr; } } pybind11::detail::get_internals (); auto m = ::pybind11::module_::create_extension_module( "_pywrap_tf_session" , nullptr, &pybind11_module_def__pywrap_tf_session); try { pybind11_init__pywrap_tf_session(m); return m.ptr(); } catch (pybind11::error_already_set &e) { PyErr_SetString(PyExc_ImportError , e.what()); return nullptr; } catch (const std::exception & e) { PyErr_SetString(PyExc_ImportError, e.what()); return nullptr ; } } void pybind11_init__pywrap_tf_session(::pybind11::module_ &m) { | |||
| 100 | // Numpy initialization code for array checks. | |||
| 101 | tensorflow::ImportNumpy(); | |||
| 102 | ||||
| 103 | py::class_<TF_Graph> TF_Graph_class(m, "TF_Graph"); | |||
| 104 | py::class_<TF_Operation> TF_Operation_class(m, "TF_Operation"); | |||
| 105 | ||||
| 106 | py::class_<TF_Output>(m, "TF_Output") | |||
| 107 | .def(py::init<>()) | |||
| 108 | .def_readwrite("oper", &TF_Output::oper) | |||
| 109 | .def_readwrite("index", &TF_Output::index); | |||
| 110 | ||||
| 111 | py::class_<TF_Input>(m, "TF_Input") | |||
| 112 | .def(py::init<>()) | |||
| 113 | .def_readwrite("oper", &TF_Input::oper) | |||
| 114 | .def_readwrite("index", &TF_Input::index); | |||
| 115 | ||||
| 116 | py::class_<TF_ImportGraphDefOptions> TF_ImportGraphDefOptions_class( | |||
| 117 | m, "TF_ImportGraphDefOptions"); | |||
| 118 | py::class_<TF_ImportGraphDefResults> TF_ImportGraphDefResults_class( | |||
| 119 | m, "TF_ImportGraphDefResults"); | |||
| 120 | py::class_<TF_DeprecatedSession> TF_DeprecatedSession_class( | |||
| 121 | m, "TF_DeprecatedSession"); | |||
| 122 | py::class_<TF_Session> TF_Session_class(m, "TF_Session"); | |||
| 123 | py::class_<TF_OperationDescription> TF_OperationDescription_class( | |||
| 124 | m, "TF_OperationDescription"); | |||
| 125 | py::class_<TF_Library> TF_Library_class(m, "TF_Library"); | |||
| 126 | py::class_<TF_SessionOptions> TF_SessionOptions_class(m, "TF_SessionOptions"); | |||
| 127 | py::class_<TF_Buffer> TF_Buffer_class(m, "TF_Buffer"); | |||
| 128 | py::class_<TF_ApiDefMap> TF_ApiDefMap_class(m, "TF_ApiDefMap"); | |||
| 129 | py::class_<TF_Server> TF_Server_class(m, "TF_Server"); | |||
| 130 | py::class_<TF_Status> TF_Status_class(m, "TF_Status"); | |||
| 131 | ||||
| 132 | // We only release the Python GIL for certain methods that are | |||
| 133 | // not explicitly marked. We disable this behavior for some functions | |||
| 134 | // because they uses Python method(s) that expect the GIL to be held | |||
| 135 | // (at least PyArray_Return, maybe others). | |||
| 136 | ||||
| 137 | // Do not release GIL. | |||
| 138 | m.def("TF_OperationGetControlInputs_wrapper", | |||
| 139 | tensorflow::TF_OperationGetControlInputs_wrapper); | |||
| 140 | // Do not release GIL. | |||
| 141 | m.def("TF_OperationGetControlOutputs_wrapper", | |||
| 142 | tensorflow::TF_OperationGetControlOutputs_wrapper); | |||
| 143 | m.def("TF_OperationOutputConsumers_wrapper", | |||
| 144 | tensorflow::TF_OperationOutputConsumers_wrapper); | |||
| 145 | // Do not release GIL. | |||
| 146 | m.def("GetOperationInputs", tensorflow::GetOperationInputs); | |||
| 147 | // Do not release GIL. | |||
| 148 | m.def("TF_ImportGraphDefOptionsSetValidateColocationConstraints", | |||
| 149 | TF_ImportGraphDefOptionsSetValidateColocationConstraints); | |||
| 150 | // Do not release GIL. | |||
| 151 | m.def("TF_ImportGraphDefResultsMissingUnusedInputMappings_wrapper", | |||
| 152 | tensorflow::TF_ImportGraphDefResultsMissingUnusedInputMappings_wrapper); | |||
| 153 | m.def("TF_SessionMakeCallable", | |||
| 154 | [](TF_Session* session, const TF_Buffer* callable_options) { | |||
| 155 | int64_t out_handle; | |||
| 156 | tensorflow::Safe_TF_StatusPtr status = | |||
| 157 | tensorflow::make_safe(TF_NewStatus()); | |||
| 158 | ||||
| 159 | // Release GIL. | |||
| 160 | py::gil_scoped_release release; | |||
| 161 | tensorflow::TF_SessionMakeCallable(session, callable_options, | |||
| 162 | &out_handle, status.get()); | |||
| 163 | ||||
| 164 | // Acquire GIL for returning int conversion. | |||
| 165 | pybind11::gil_scoped_acquire acquire; | |||
| 166 | tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); | |||
| 167 | return out_handle; | |||
| 168 | }); | |||
| 169 | m.def("_TF_SetTarget", TF_SetTarget); | |||
| 170 | m.def("_TF_SetConfig", [](TF_SessionOptions* options, py::bytes proto) { | |||
| 171 | tensorflow::Safe_TF_StatusPtr status = | |||
| 172 | tensorflow::make_safe(TF_NewStatus()); | |||
| 173 | tensorflow::Safe_TF_BufferPtr buf = | |||
| 174 | tensorflow::make_safe(ProtoStringToTFBuffer(proto.ptr())); | |||
| 175 | TF_SetConfig(options, buf.get()->data, buf.get()->length, status.get()); | |||
| 176 | tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); | |||
| 177 | }); | |||
| 178 | m.def("_TF_NewSessionOptions", TF_NewSessionOptions, | |||
| 179 | py::return_value_policy::reference, | |||
| 180 | py::call_guard<py::gil_scoped_release>()); | |||
| 181 | m.def("TF_DeleteSessionOptions", TF_DeleteSessionOptions, | |||
| 182 | py::call_guard<py::gil_scoped_release>()); | |||
| 183 | ||||
| 184 | m.def("EqualGraphDefWrapper", tensorflow::EqualGraphDefWrapper, | |||
| 185 | py::call_guard<py::gil_scoped_release>()); | |||
| 186 | m.def("EqualAttrValueWrapper", tensorflow::EqualAttrValueWrapper, | |||
| 187 | py::call_guard<py::gil_scoped_release>()); | |||
| 188 | ||||
| 189 | m.def( | |||
| 190 | "TF_GraphToFunction_wrapper", | |||
| 191 | [](const TF_Graph* fn_body, const char* fn_name, | |||
| 192 | bool append_hash_to_fn_name, | |||
| 193 | absl::optional<std::vector<TF_Operation*>> opers_opt, | |||
| 194 | const std::vector<TF_Output>& inputs, | |||
| 195 | const std::vector<TF_Output>& outputs, | |||
| 196 | const std::vector<py::bytes> output_names, | |||
| 197 | const std::vector<TF_Operation*> control_outputs, | |||
| 198 | const std::vector<py::bytes> control_output_names, py::none opts, | |||
| 199 | const char* description) { | |||
| 200 | tensorflow::Safe_TF_StatusPtr status = | |||
| 201 | tensorflow::make_safe(TF_NewStatus()); | |||
| 202 | ||||
| 203 | // TODO(b/147674626): Use pybind11 list_caster instead. | |||
| 204 | tensorflow::NameVector output_names_name_vector = | |||
| 205 | ConvertPyListToNameVector(output_names); | |||
| 206 | ||||
| 207 | // TODO(b/147674626): Use pybind11 list_caster instead. | |||
| 208 | tensorflow::NameVector control_output_names_name_vector = | |||
| 209 | ConvertPyListToNameVector(control_output_names); | |||
| 210 | ||||
| 211 | // Release GIL. | |||
| 212 | py::gil_scoped_release release; | |||
| 213 | auto output = tensorflow::TF_GraphToFunction_wrapper( | |||
| 214 | fn_body, fn_name, append_hash_to_fn_name, | |||
| 215 | opers_opt.has_value() ? &opers_opt.value() : nullptr, inputs, | |||
| 216 | outputs, output_names_name_vector, &control_outputs, | |||
| 217 | control_output_names_name_vector, | |||
| 218 | /*opts=*/nullptr, description, status.get()); | |||
| 219 | tensorflow::MaybeRaiseRegisteredFromTFStatusWithGIL(status.get()); | |||
| 220 | return output; | |||
| 221 | }, | |||
| 222 | py::return_value_policy::reference); | |||
| 223 | ||||
| 224 | m.def("TF_GraphGetTensorShapeHelper", [](TF_Graph* graph, TF_Output output) { | |||
| 225 | tensorflow::Safe_TF_StatusPtr status = | |||
| 226 | tensorflow::make_safe(TF_NewStatus()); | |||
| 227 | bool unknown_shape; | |||
| 228 | ||||
| 229 | auto result = tensorflow::TF_GraphGetTensorShapeHelper( | |||
| 230 | graph, output, status.get(), &unknown_shape); | |||
| 231 | tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); | |||
| 232 | ||||
| 233 | // Create a python list from InlinedVector | |||
| 234 | py::list py_list; | |||
| 235 | for (size_t i = 0; i < result.size(); ++i) { | |||
| 236 | py_list.append(py::cast(result[i])); | |||
| 237 | } | |||
| 238 | ||||
| 239 | // Return a tuple. | |||
| 240 | py::tuple result_tuple = py::make_tuple(py_list, py::cast(unknown_shape)); | |||
| 241 | return result_tuple; | |||
| 242 | }); | |||
| 243 | ||||
| 244 | m.def("TF_GraphSetTensorShape_wrapper", | |||
| 245 | [](TF_Graph* graph, TF_Output output, const std::vector<int64_t>& dims, | |||
| 246 | bool unknown_shape) { | |||
| 247 | tensorflow::Safe_TF_StatusPtr status = | |||
| 248 | tensorflow::make_safe(TF_NewStatus()); | |||
| 249 | ||||
| 250 | // Release GIL. | |||
| 251 | py::gil_scoped_release release; | |||
| 252 | tensorflow::TF_GraphSetTensorShape_wrapper( | |||
| 253 | graph, output, dims, unknown_shape, status.get()); | |||
| 254 | tensorflow::MaybeRaiseRegisteredFromTFStatusWithGIL(status.get()); | |||
| 255 | }); | |||
| 256 | ||||
| 257 | m.def("TF_GraphGetTensorShape_wrapper", | |||
| 258 | [](TF_Graph* graph, TF_Output output, const std::vector<int64_t>& dims, | |||
| 259 | bool unknown_shape) { | |||
| 260 | tensorflow::Safe_TF_StatusPtr status = | |||
| 261 | tensorflow::make_safe(TF_NewStatus()); | |||
| 262 | // Release GIL. | |||
| 263 | py::gil_scoped_release release; | |||
| 264 | tensorflow::TF_GraphSetTensorShape_wrapper( | |||
| 265 | graph, output, dims, unknown_shape, status.get()); | |||
| 266 | tensorflow::MaybeRaiseRegisteredFromTFStatusWithGIL(status.get()); | |||
| 267 | }); | |||
| 268 | ||||
| 269 | m.def("TF_GraphSetOutputHandleShapesAndTypes_wrapper", | |||
| 270 | [](TF_Graph* graph, TF_Output output, | |||
| 271 | const std::vector<absl::optional<std::vector<int64_t>>>& shapes, | |||
| 272 | const std::vector<int>& ranks, py::handle& types) { | |||
| 273 | tensorflow::Safe_TF_StatusPtr status = | |||
| 274 | tensorflow::make_safe(TF_NewStatus()); | |||
| 275 | ||||
| 276 | // Cast types | |||
| 277 | std::vector<TF_DataType> types_local; | |||
| 278 | PyObject* seq = | |||
| 279 | PySequence_Fast(types.ptr(), "$symname: expected list"); | |||
| ||||
| ||||
| 280 | if (seq == nullptr) { | |||
| 281 | PyErr_SetString(PyExc_RuntimeError, | |||
| 282 | "$symname: PySequence_Fast returned NULL."); | |||
| 283 | throw py::error_already_set(); | |||
| 284 | } | |||
| 285 | ||||
| 286 | int size = PySequence_Fast_GET_SIZE(seq)(((((((PyObject*)(seq))->ob_type))->tp_flags & ((1UL << 25))) != 0) ? ((static_cast<void> (0)), (((PyVarObject *)(seq))->ob_size)) : (((PyVarObject*)(((static_cast<void > (0)), (PyTupleObject *)(seq))))->ob_size)); | |||
| 287 | if (size == 0) { | |||
| 288 | PyErr_SetString(PyExc_ValueError, | |||
| 289 | "$symname: shapes list must be non-empty"); | |||
| 290 | throw py::error_already_set(); | |||
| 291 | } | |||
| 292 | ||||
| 293 | for (int i = 0; i < size; ++i) { | |||
| 294 | PyObject* item = PySequence_Fast_GET_ITEM(seq, i)(((((((PyObject*)(seq))->ob_type))->tp_flags & ((1UL << 25))) != 0) ? (((PyListObject *)(seq))->ob_item[ i]) : (((static_cast<void> (0)), (PyTupleObject *)(seq) )->ob_item[i])); | |||
| 295 | types_local.push_back((TF_DataType)PyLong_AsLong(item)); | |||
| 296 | } | |||
| 297 | ||||
| 298 | // Convert shapes nested vector | |||
| 299 | std::vector<std::vector<int64_t>> shapes_local; | |||
| 300 | for (size_t i = 0; i < shapes.size(); ++i) { | |||
| 301 | std::vector<int64_t> dims; | |||
| 302 | std::vector<int64_t> item = | |||
| 303 | shapes[i].has_value() ? shapes[i].value() : dims; | |||
| 304 | shapes_local.push_back(item); | |||
| 305 | } | |||
| 306 | ||||
| 307 | Py_DECREF(seq)_Py_DECREF(((PyObject*)(seq))); | |||
| 308 | ||||
| 309 | tensorflow::TF_GraphSetOutputHandleShapesAndTypes_wrapper( | |||
| 310 | graph, output, shapes_local, ranks, types_local, status.get()); | |||
| 311 | tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); | |||
| 312 | }); | |||
| 313 | ||||
| 314 | // Do not release GIL. | |||
| 315 | m.def("TF_CreatePlaceholders", | |||
| 316 | [](TF_Graph* graph, py::handle& dtypes, const char* prefix) { | |||
| 317 | tensorflow::Safe_TF_StatusPtr status = | |||
| 318 | tensorflow::make_safe(TF_NewStatus()); | |||
| 319 | auto output = tensorflow::TF_CreatePlaceholders(graph, dtypes.ptr(), | |||
| 320 | prefix, status.get()); | |||
| 321 | tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); | |||
| 322 | return output; | |||
| 323 | }); | |||
| 324 | ||||
| 325 | m.def( | |||
| 326 | "TF_NewSession", | |||
| 327 | [](TF_Graph* graph, const TF_SessionOptions* opts) { | |||
| 328 | tensorflow::Safe_TF_StatusPtr status = | |||
| 329 | tensorflow::make_safe(TF_NewStatus()); | |||
| 330 | // Release GIL. | |||
| 331 | py::gil_scoped_release release; | |||
| 332 | auto output = TF_NewSession(graph, opts, status.get()); | |||
| 333 | tensorflow::MaybeRaiseRegisteredFromTFStatusWithGIL(status.get()); | |||
| 334 | return output; | |||
| 335 | }, | |||
| 336 | py::return_value_policy::reference); | |||
| 337 | ||||
| 338 | m.def( | |||
| 339 | "TF_NewSessionRef", | |||
| 340 | [](TF_Graph* graph, const TF_SessionOptions* opts) { | |||
| 341 | tensorflow::Safe_TF_StatusPtr status = | |||
| 342 | tensorflow::make_safe(TF_NewStatus()); | |||
| 343 | // Release GIL. | |||
| 344 | py::gil_scoped_release release; | |||
| 345 | auto output = tensorflow::TF_NewSessionRef(graph, opts, status.get()); | |||
| 346 | tensorflow::MaybeRaiseRegisteredFromTFStatusWithGIL(status.get()); | |||
| 347 | return output; | |||
| 348 | }, | |||
| 349 | py::return_value_policy::reference); | |||
| 350 | ||||
| 351 | m.def("TF_CloseSession", [](TF_Session* session) { | |||
| 352 | tensorflow::Safe_TF_StatusPtr status = | |||
| 353 | tensorflow::make_safe(TF_NewStatus()); | |||
| 354 | ||||
| 355 | // Release GIL. | |||
| 356 | py::gil_scoped_release release; | |||
| 357 | TF_CloseSession(session, status.get()); | |||
| 358 | ||||
| 359 | tensorflow::MaybeRaiseRegisteredFromTFStatusWithGIL(status.get()); | |||
| 360 | }); | |||
| 361 | ||||
| 362 | m.def("TF_DeleteSession", [](TF_Session* session) { | |||
| 363 | tensorflow::Safe_TF_StatusPtr status = | |||
| 364 | tensorflow::make_safe(TF_NewStatus()); | |||
| 365 | // Release GIL. | |||
| 366 | py::gil_scoped_release release; | |||
| 367 | TF_DeleteSession(session, status.get()); | |||
| 368 | tensorflow::MaybeRaiseRegisteredFromTFStatusWithGIL(status.get()); | |||
| 369 | }); | |||
| 370 | ||||
| 371 | m.def("SetRequireShapeInferenceFns", tensorflow::SetRequireShapeInferenceFns); | |||
| 372 | ||||
| 373 | // Do not release GIL. | |||
| 374 | m.def("TF_TryEvaluateConstant_wrapper", | |||
| 375 | [](TF_Graph* graph, const TF_Output output) { | |||
| 376 | tensorflow::Safe_TF_StatusPtr status = | |||
| 377 | tensorflow::make_safe(TF_NewStatus()); | |||
| 378 | auto result = tensorflow::TF_TryEvaluateConstant_wrapper( | |||
| 379 | graph, output, status.get()); | |||
| 380 | tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); | |||
| 381 | return tensorflow::PyoOrThrow(result); | |||
| 382 | }); | |||
| 383 | ||||
| 384 | m.def("ExtendSession", [](TF_Session* session) { | |||
| 385 | tensorflow::Safe_TF_StatusPtr status = | |||
| 386 | tensorflow::make_safe(TF_NewStatus()); | |||
| 387 | // Release GIL for threading. | |||
| 388 | pybind11::gil_scoped_release release; | |||
| 389 | tensorflow::ExtendSession(session, status.get()); | |||
| 390 | tensorflow::MaybeRaiseRegisteredFromTFStatusWithGIL(status.get()); | |||
| 391 | }); | |||
| 392 | ||||
| 393 | m.def("GetHandleShapeAndType", [](TF_Graph* graph, TF_Output output) { | |||
| 394 | std::string output_string = | |||
| 395 | tensorflow::GetHandleShapeAndType(graph, output); | |||
| 396 | // Override default py3 behavior of attempting to encode into Unicode as | |||
| 397 | // the dependent functions expect bytes. | |||
| 398 | return py::bytes(output_string); | |||
| 399 | }); | |||
| 400 | ||||
| 401 | m.def("SetHandleShapeAndType", | |||
| 402 | [](TF_Graph* graph, TF_Output output, py::bytes proto) { | |||
| 403 | tensorflow::Safe_TF_StatusPtr status = | |||
| 404 | tensorflow::make_safe(TF_NewStatus()); | |||
| 405 | tensorflow::Safe_TF_BufferPtr buf = | |||
| 406 | tensorflow::make_safe(ProtoStringToTFBuffer(proto.ptr())); | |||
| 407 | tensorflow::SetHandleShapeAndType(graph, output, buf.get()->data, | |||
| 408 | buf.get()->length, status.get()); | |||
| 409 | tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); | |||
| 410 | }); | |||
| 411 | ||||
| 412 | // Do not release GIL. | |||
| 413 | m.def("TF_SessionRun_wrapper", [](TF_Session* session, TF_Buffer* run_options, | |||
| 414 | const py::handle& input_dict, | |||
| 415 | const std::vector<TF_Output>& outputs, | |||
| 416 | const std::vector<TF_Operation*>& targets, | |||
| 417 | TF_Buffer* run_metadata) { | |||
| 418 | // Convert inputs dictionary | |||
| 419 | std::vector<TF_Output> inputs; | |||
| 420 | std::vector<PyObject*> input_ndarrays; | |||
| 421 | if (!PyDict_Check(input_dict.ptr())((((((PyObject*)(input_dict.ptr()))->ob_type))->tp_flags & ((1UL << 29))) != 0)) { | |||
| 422 | PyErr_SetString( | |||
| 423 | PyExc_TypeError, | |||
| 424 | "Expected a dictionary as an argument to TF_SessionRun_wrapper."); | |||
| 425 | throw py::error_already_set(); | |||
| 426 | } | |||
| 427 | PyObject* key; | |||
| 428 | PyObject* value; | |||
| 429 | Py_ssize_t pos = 0; | |||
| 430 | while (PyDict_Next(input_dict.ptr(), &pos, &key, &value)) { | |||
| 431 | TF_Output item = py::cast<TF_Output>(key); | |||
| 432 | inputs.push_back(item); | |||
| 433 | ||||
| 434 | // TODO(amitpatankar): Fix this PyArray check. (b/147855599) | |||
| 435 | ||||
| 436 | // if (!PyArray_Check(value)) { | |||
| 437 | // PyErr_SetString( | |||
| 438 | // PyExc_TypeError, | |||
| 439 | // "$symname: Expected all values in input dict to be ndarray."); | |||
| 440 | // throw py::error_already_set(); | |||
| 441 | // } | |||
| 442 | input_ndarrays.push_back(value); | |||
| 443 | } | |||
| 444 | ||||
| 445 | tensorflow::Safe_TF_StatusPtr status = | |||
| 446 | tensorflow::make_safe(TF_NewStatus()); | |||
| 447 | std::vector<PyObject*> py_outputs; | |||
| 448 | tensorflow::TF_SessionRun_wrapper(session, run_options, inputs, | |||
| 449 | input_ndarrays, outputs, targets, | |||
| 450 | run_metadata, status.get(), &py_outputs); | |||
| 451 | tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); | |||
| 452 | ||||
| 453 | // Create a Python list using the C API rather than py::list. b/147855599 | |||
| 454 | PyObject* result = PyList_New(py_outputs.size()); | |||
| 455 | if (result == nullptr) { | |||
| 456 | PyErr_SetString(PyExc_MemoryError, "Failed to create a list."); | |||
| 457 | throw py::error_already_set(); | |||
| 458 | } | |||
| 459 | for (size_t i = 0; i < py_outputs.size(); ++i) { | |||
| 460 | PyList_SET_ITEM(result, i, py_outputs.at(i))PyList_SetItem(result, i, py_outputs.at(i)); | |||
| 461 | } | |||
| 462 | ||||
| 463 | return tensorflow::PyoOrThrow(result); | |||
| 464 | }); | |||
| 465 | ||||
| 466 | // Do not release GIL. | |||
| 467 | m.def("TF_SessionPRun_wrapper", [](TF_Session* session, const char* handle, | |||
| 468 | const py::handle& input_dict, | |||
| 469 | const std::vector<TF_Output>& outputs) { | |||
| 470 | // Convert inputs dictionary | |||
| 471 | std::vector<TF_Output> inputs; | |||
| 472 | std::vector<PyObject*> input_ndarrays; | |||
| 473 | if (!PyDict_Check(input_dict.ptr())((((((PyObject*)(input_dict.ptr()))->ob_type))->tp_flags & ((1UL << 29))) != 0)) { | |||
| 474 | PyErr_SetString( | |||
| 475 | PyExc_TypeError, | |||
| 476 | "Expected a dictionary as an argument to TF_SessionPRun_wrapper."); | |||
| 477 | throw py::error_already_set(); | |||
| 478 | } | |||
| 479 | PyObject* key; | |||
| 480 | PyObject* value; | |||
| 481 | Py_ssize_t pos = 0; | |||
| 482 | while (PyDict_Next(input_dict.ptr(), &pos, &key, &value)) { | |||
| 483 | TF_Output item = py::cast<TF_Output>(key); | |||
| 484 | inputs.push_back(item); | |||
| 485 | ||||
| 486 | // TODO(amitpatankar): Fix this PyArray check. (b/147855599) | |||
| 487 | ||||
| 488 | // if (!PyArray_Check(value)) { | |||
| 489 | // PyErr_SetString( | |||
| 490 | // PyExc_TypeError, | |||
| 491 | // "$symname: Expected all values in input dict to be ndarray."); | |||
| 492 | // throw py::error_already_set(); | |||
| 493 | // } | |||
| 494 | input_ndarrays.push_back(value); | |||
| 495 | } | |||
| 496 | ||||
| 497 | tensorflow::Safe_TF_StatusPtr status = | |||
| 498 | tensorflow::make_safe(TF_NewStatus()); | |||
| 499 | std::vector<PyObject*> py_outputs; | |||
| 500 | tensorflow::TF_SessionPRun_wrapper(session, handle, inputs, input_ndarrays, | |||
| 501 | outputs, status.get(), &py_outputs); | |||
| 502 | tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); | |||
| 503 | ||||
| 504 | PyObject* result = PyList_New(py_outputs.size()); | |||
| 505 | if (result == nullptr) { | |||
| 506 | PyErr_SetString(PyExc_MemoryError, "Failed to create a list."); | |||
| 507 | throw py::error_already_set(); | |||
| 508 | } | |||
| 509 | for (size_t i = 0; i < py_outputs.size(); ++i) { | |||
| 510 | PyList_SET_ITEM(result, i, py_outputs.at(i))PyList_SetItem(result, i, py_outputs.at(i)); | |||
| 511 | } | |||
| 512 | ||||
| 513 | return tensorflow::PyoOrThrow(result); | |||
| 514 | }); | |||
| 515 | ||||
| 516 | // Do not release GIL. | |||
| 517 | m.def("TF_SessionPRunSetup_wrapper", | |||
| 518 | [](TF_Session* session, const std::vector<TF_Output>& inputs, | |||
| 519 | const std::vector<TF_Output>& outputs, | |||
| 520 | const std::vector<TF_Operation*>& targets) { | |||
| 521 | tensorflow::Safe_TF_StatusPtr status = | |||
| 522 | tensorflow::make_safe(TF_NewStatus()); | |||
| 523 | const char* out_handle; | |||
| 524 | tensorflow::TF_SessionPRunSetup_wrapper( | |||
| 525 | session, inputs, outputs, targets, &out_handle, status.get()); | |||
| 526 | tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); | |||
| 527 | return out_handle; | |||
| 528 | }); | |||
| 529 | ||||
| 530 | // Do not release GIL. | |||
| 531 | m.def("TF_SessionRunCallable", [](TF_Session* session, int64_t handle, | |||
| 532 | py::object feed_values, | |||
| 533 | TF_Buffer* run_metadata) { | |||
| 534 | tensorflow::PyObjectVector out_values; | |||
| 535 | tensorflow::Safe_TF_StatusPtr status = | |||
| 536 | tensorflow::make_safe(TF_NewStatus()); | |||
| 537 | tensorflow::TF_SessionRunCallable(session, handle, feed_values.ptr(), | |||
| 538 | &out_values, run_metadata, status.get()); | |||
| 539 | tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); | |||
| 540 | ||||
| 541 | // Return out_values | |||
| 542 | py::list py_list; | |||
| 543 | for (size_t i = 0; i < out_values.size(); ++i) { | |||
| 544 | py::object obj = tensorflow::Pyo(out_values.at(i)); | |||
| 545 | py_list.append(obj); | |||
| 546 | } | |||
| 547 | return py_list; | |||
| 548 | }); | |||
| 549 | ||||
| 550 | m.def("TF_SessionReleaseCallable", [](TF_Session* session, int64_t handle) { | |||
| 551 | tensorflow::Safe_TF_StatusPtr status = | |||
| 552 | tensorflow::make_safe(TF_NewStatus()); | |||
| 553 | // Release GIL. | |||
| 554 | py::gil_scoped_release release; | |||
| 555 | tensorflow::TF_SessionReleaseCallable(session, handle, status.get()); | |||
| 556 | tensorflow::MaybeRaiseRegisteredFromTFStatusWithGIL(status.get()); | |||
| 557 | }); | |||
| 558 | ||||
| 559 | m.def("TF_NewGraph", TF_NewGraph, py::return_value_policy::reference, | |||
| 560 | py::call_guard<py::gil_scoped_release>()); | |||
| 561 | m.def("TF_DeleteGraph", TF_DeleteGraph, | |||
| 562 | py::call_guard<py::gil_scoped_release>()); | |||
| 563 | ||||
| 564 | m.def("TF_GraphGetOpDef", | |||
| 565 | [](TF_Graph* graph, const char* op_name, TF_Buffer* output_op_def) { | |||
| 566 | tensorflow::Safe_TF_StatusPtr status = | |||
| 567 | tensorflow::make_safe(TF_NewStatus()); | |||
| 568 | // Release GIL. | |||
| 569 | py::gil_scoped_release release; | |||
| 570 | TF_GraphGetOpDef(graph, op_name, output_op_def, status.get()); | |||
| 571 | tensorflow::MaybeRaiseRegisteredFromTFStatusWithGIL(status.get()); | |||
| 572 | }); | |||
| 573 | ||||
| 574 | m.def( | |||
| 575 | "TF_NewOperation", | |||
| 576 | [](TF_Graph* graph, const char* op_type, const char* oper_name) { | |||
| 577 | tensorflow::Safe_TF_StatusPtr status = | |||
| 578 | tensorflow::make_safe(TF_NewStatus()); | |||
| 579 | // Release GIL. | |||
| 580 | py::gil_scoped_release release; | |||
| 581 | TF_OperationDescription* output = | |||
| 582 | TF_NewOperation(graph, op_type, oper_name); | |||
| 583 | tensorflow::MaybeRaiseRegisteredFromTFStatusWithGIL(status.get()); | |||
| 584 | return output; | |||
| 585 | }, | |||
| 586 | py::return_value_policy::reference); | |||
| 587 | ||||
| 588 | m.def( | |||
| 589 | "TF_FinishOperation", | |||
| 590 | [](TF_OperationDescription* desc) { | |||
| 591 | tensorflow::Safe_TF_StatusPtr status = | |||
| 592 | tensorflow::make_safe(TF_NewStatus()); | |||
| 593 | // Release GIL. | |||
| 594 | py::gil_scoped_release release; | |||
| 595 | TF_Operation* output = TF_FinishOperation(desc, status.get()); | |||
| 596 | tensorflow::MaybeRaiseRegisteredFromTFStatusWithGIL(status.get()); | |||
| 597 | return output; | |||
| 598 | }, | |||
| 599 | py::return_value_policy::reference); | |||
| 600 | ||||
| 601 | m.def("TF_OperationGetAttrInt", | |||
| 602 | [](TF_Operation* oper, const char* attr_name) { | |||
| 603 | tensorflow::Safe_TF_StatusPtr status = | |||
| 604 | tensorflow::make_safe(TF_NewStatus()); | |||
| 605 | int64_t value; | |||
| 606 | // Release GIL. | |||
| 607 | py::gil_scoped_release release; | |||
| 608 | TF_OperationGetAttrInt(oper, attr_name, &value, status.get()); | |||
| 609 | tensorflow::MaybeRaiseRegisteredFromTFStatusWithGIL(status.get()); | |||
| 610 | // Convert TF_OperationGetAttrInt int64_t* out-argument to Python | |||
| 611 | // bool. | |||
| 612 | // Acquire GIL for returning output returning. | |||
| 613 | pybind11::gil_scoped_acquire acquire; | |||
| 614 | return tensorflow::Pyo(PyLong_FromLongLong(value)); | |||
| 615 | }); | |||
| 616 | ||||
| 617 | m.def("TF_SetAttrValueProto", [](TF_OperationDescription* desc, | |||
| 618 | const char* attr_name, py::bytes proto) { | |||
| 619 | tensorflow::Safe_TF_StatusPtr status = | |||
| 620 | tensorflow::make_safe(TF_NewStatus()); | |||
| 621 | tensorflow::Safe_TF_BufferPtr buf = | |||
| 622 | tensorflow::make_safe(ProtoStringToTFBuffer(proto.ptr())); | |||
| 623 | TF_SetAttrValueProto(desc, attr_name, buf.get()->data, buf.get()->length, | |||
| 624 | status.get()); | |||
| 625 | tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); | |||
| 626 | }); | |||
| 627 | ||||
| 628 | m.def("TF_OperationNumOutputs", TF_OperationNumOutputs, | |||
| 629 | py::call_guard<py::gil_scoped_release>()); | |||
| 630 | ||||
| 631 | // Convert types to ints | |||
| 632 | m.def("TF_OperationInputType", TF_OperationInputType, | |||
| 633 | py::call_guard<py::gil_scoped_release>()); | |||
| 634 | m.def("TF_OperationOutputType", TF_OperationOutputType, | |||
| 635 | py::call_guard<py::gil_scoped_release>()); | |||
| 636 | ||||
| 637 | m.def("TF_OperationName", TF_OperationName, | |||
| 638 | py::call_guard<py::gil_scoped_release>()); | |||
| 639 | m.def("TF_OperationOpType", TF_OperationOpType, | |||
| 640 | py::call_guard<py::gil_scoped_release>()); | |||
| 641 | m.def("TF_OperationDevice", TF_OperationDevice, | |||
| 642 | py::call_guard<py::gil_scoped_release>()); | |||
| 643 | ||||
| 644 | m.def("TF_AddInput", TF_AddInput); | |||
| 645 | ||||
| 646 | m.def("TF_OperationToNodeDef", | |||
| 647 | [](TF_Operation* oper, TF_Buffer* output_node_def) { | |||
| 648 | tensorflow::Safe_TF_StatusPtr status = | |||
| 649 | tensorflow::make_safe(TF_NewStatus()); | |||
| 650 | TF_OperationToNodeDef(oper, output_node_def, status.get()); | |||
| 651 | tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); | |||
| 652 | }); | |||
| 653 | ||||
| 654 | m.def("TF_OperationGetAttrValueProto", | |||
| 655 | [](TF_Operation* oper, const char* attr_name, | |||
| 656 | TF_Buffer* output_attr_value) { | |||
| 657 | tensorflow::Safe_TF_StatusPtr status = | |||
| 658 | tensorflow::make_safe(TF_NewStatus()); | |||
| 659 | TF_OperationGetAttrValueProto(oper, attr_name, output_attr_value, | |||
| 660 | status.get()); | |||
| 661 | tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); | |||
| 662 | }); | |||
| 663 | ||||
| 664 | m.def("SetRequestedDevice", tensorflow::SetRequestedDevice); | |||
| 665 | ||||
| 666 | // TF_Buffer util methods | |||
| 667 | // TODO(amitpatankar): Consolidate Buffer methods into a separate header file. | |||
| 668 | m.def("TF_NewBuffer", TF_NewBuffer, py::return_value_policy::reference); | |||
| 669 | m.def("TF_GetBuffer", [](TF_Buffer* buf) { | |||
| 670 | TF_Buffer buffer = TF_GetBuffer(buf); | |||
| 671 | return tensorflow::PyoOrThrow(PyBytes_FromStringAndSize( | |||
| 672 | reinterpret_cast<const char*>(buffer.data), buffer.length)); | |||
| 673 | }); | |||
| 674 | m.def("TF_DeleteBuffer", &TF_DeleteBuffer); | |||
| 675 | m.def( | |||
| 676 | "TF_NewBufferFromString", | |||
| 677 | [](py::bytes buffer_as_string) { | |||
| 678 | tensorflow::Safe_TF_BufferPtr buf = tensorflow::make_safe( | |||
| 679 | ProtoStringToTFBuffer(buffer_as_string.ptr())); | |||
| 680 | return TF_NewBufferFromString(buf.get()->data, buf.get()->length); | |||
| 681 | }, | |||
| 682 | py::return_value_policy::reference); | |||
| 683 | ||||
| 684 | m.def("SetAttr", [](TF_Graph* graph, TF_Operation* op, const char* attr_name, | |||
| 685 | TF_Buffer* attr_value_proto) { | |||
| 686 | tensorflow::Safe_TF_StatusPtr status = | |||
| 687 | tensorflow::make_safe(TF_NewStatus()); | |||
| 688 | // Release GIL. | |||
| 689 | py::gil_scoped_release release; | |||
| 690 | tensorflow::SetAttr(graph, op, attr_name, attr_value_proto, status.get()); | |||
| 691 | tensorflow::MaybeRaiseRegisteredFromTFStatusWithGIL(status.get()); | |||
| 692 | }); | |||
| 693 | ||||
| 694 | m.def("ClearAttr", | |||
| 695 | [](TF_Graph* graph, TF_Operation* op, const char* attr_name) { | |||
| 696 | tensorflow::Safe_TF_StatusPtr status = | |||
| 697 | tensorflow::make_safe(TF_NewStatus()); | |||
| 698 | // Release GIL. | |||
| 699 | py::gil_scoped_release release; | |||
| 700 | tensorflow::ClearAttr(graph, op, attr_name, status.get()); | |||
| 701 | tensorflow::MaybeRaiseRegisteredFromTFStatusWithGIL(status.get()); | |||
| 702 | }); | |||
| 703 | ||||
| 704 | m.def( | |||
| 705 | "TF_LoadLibrary", | |||
| 706 | [](const char* library_filename) { | |||
| 707 | tensorflow::Safe_TF_StatusPtr status = | |||
| 708 | tensorflow::make_safe(TF_NewStatus()); | |||
| 709 | auto output = TF_LoadLibrary(library_filename, status.get()); | |||
| 710 | tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); | |||
| 711 | return output; | |||
| 712 | }, | |||
| 713 | py::return_value_policy::reference); | |||
| 714 | ||||
| 715 | m.def( | |||
| 716 | "TF_LoadPluggableDeviceLibrary", | |||
| 717 | [](const char* library_filename) { | |||
| 718 | tensorflow::Safe_TF_StatusPtr status = | |||
| 719 | tensorflow::make_safe(TF_NewStatus()); | |||
| 720 | auto output = | |||
| 721 | TF_LoadPluggableDeviceLibrary(library_filename, status.get()); | |||
| 722 | tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); | |||
| 723 | return output; | |||
| 724 | }, | |||
| 725 | py::return_value_policy::reference); | |||
| 726 | ||||
| 727 | m.def("TF_GetOpList", [](TF_Library* lib_handle) { | |||
| 728 | TF_Buffer output_buffer = TF_GetOpList(lib_handle); | |||
| 729 | return tensorflow::PyoOrThrow(PyBytes_FromStringAndSize( | |||
| 730 | reinterpret_cast<const char*>(output_buffer.data), | |||
| 731 | output_buffer.length)); | |||
| 732 | }); | |||
| 733 | ||||
| 734 | m.def("TF_DeleteLibraryHandle", TF_DeleteLibraryHandle, | |||
| 735 | py::call_guard<py::gil_scoped_release>()); | |||
| 736 | ||||
| 737 | m.def("TF_PluggableDeviceLibraryHandle", | |||
| 738 | TF_DeletePluggableDeviceLibraryHandle, | |||
| 739 | py::call_guard<py::gil_scoped_release>()); | |||
| 740 | ||||
| 741 | m.def("TF_AddControlInput", TF_AddControlInput); | |||
| 742 | m.def( | |||
| 743 | "TF_AddInputList", [](TF_OperationDescription* desc, py::handle& inputs) { | |||
| 744 | std::vector<TF_Output> vec; | |||
| 745 | size_t size = PyList_Size(inputs.ptr()); | |||
| 746 | for (size_t i = 0; i < size; ++i) { | |||
| 747 | TF_Output item = py::cast<TF_Output>(PyList_GetItem(inputs.ptr(), i)); | |||
| 748 | vec.push_back(item); | |||
| 749 | } | |||
| 750 | TF_AddInputList(desc, vec.data(), vec.size()); | |||
| 751 | }); | |||
| 752 | ||||
| 753 | m.def("UpdateEdge", [](TF_Graph* graph, TF_Output new_src, TF_Input dst) { | |||
| 754 | tensorflow::Safe_TF_StatusPtr status = | |||
| 755 | tensorflow::make_safe(TF_NewStatus()); | |||
| 756 | // Release GIL. | |||
| 757 | py::gil_scoped_release release; | |||
| 758 | tensorflow::UpdateEdge(graph, new_src, dst, status.get()); | |||
| 759 | tensorflow::MaybeRaiseRegisteredFromTFStatusWithGIL(status.get()); | |||
| 760 | }); | |||
| 761 | ||||
| 762 | m.def("RemoveAllControlInputs", tensorflow::RemoveAllControlInputs, | |||
| 763 | py::call_guard<py::gil_scoped_release>()); | |||
| 764 | m.def("AddControlInput", tensorflow::AddControlInput, | |||
| 765 | py::call_guard<py::gil_scoped_release>()); | |||
| 766 | ||||
| 767 | m.def("TF_NewImportGraphDefOptions", TF_NewImportGraphDefOptions, | |||
| 768 | py::return_value_policy::reference, | |||
| 769 | py::call_guard<py::gil_scoped_release>()); | |||
| 770 | m.def("TF_ImportGraphDefOptionsSetPrefix", TF_ImportGraphDefOptionsSetPrefix, | |||
| 771 | py::call_guard<py::gil_scoped_release>()); | |||
| 772 | m.def("TF_ImportGraphDefOptionsSetUniquifyNames", | |||
| 773 | TF_ImportGraphDefOptionsSetUniquifyNames, | |||
| 774 | py::call_guard<py::gil_scoped_release>()); | |||
| 775 | m.def("TF_ImportGraphDefOptionsRemapControlDependency", | |||
| 776 | TF_ImportGraphDefOptionsRemapControlDependency, | |||
| 777 | py::call_guard<py::gil_scoped_release>()); | |||
| 778 | m.def("TF_ImportGraphDefOptionsAddInputMapping", | |||
| 779 | TF_ImportGraphDefOptionsAddInputMapping, | |||
| 780 | py::call_guard<py::gil_scoped_release>()); | |||
| 781 | m.def("TF_ImportGraphDefOptionsAddReturnOperation", | |||
| 782 | TF_ImportGraphDefOptionsAddReturnOperation, | |||
| 783 | py::call_guard<py::gil_scoped_release>()); | |||
| 784 | m.def("TF_ImportGraphDefOptionsAddReturnOutput", | |||
| 785 | TF_ImportGraphDefOptionsAddReturnOutput, | |||
| 786 | py::call_guard<py::gil_scoped_release>()); | |||
| 787 | ||||
| 788 | m.def( | |||
| 789 | "TF_GraphImportGraphDefWithResults", | |||
| 790 | [](TF_Graph* graph, const TF_Buffer* graph_def, | |||
| 791 | const TF_ImportGraphDefOptions* options) { | |||
| 792 | tensorflow::Safe_TF_StatusPtr status = | |||
| 793 | tensorflow::make_safe(TF_NewStatus()); | |||
| 794 | auto output = TF_GraphImportGraphDefWithResults(graph, graph_def, | |||
| 795 | options, status.get()); | |||
| 796 | tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); | |||
| 797 | return output; | |||
| 798 | }, | |||
| 799 | py::return_value_policy::reference); | |||
| 800 | ||||
| 801 | m.def( | |||
| 802 | "TF_GraphNextOperation", | |||
| 803 | [](TF_Graph* graph, size_t pos) { | |||
| 804 | tensorflow::Safe_TF_StatusPtr status = | |||
| 805 | tensorflow::make_safe(TF_NewStatus()); | |||
| 806 | auto output = TF_GraphNextOperation(graph, &pos); | |||
| 807 | tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); | |||
| 808 | ||||
| 809 | // Returns a (TF_Operation*, int pos) tuple. | |||
| 810 | py::tuple result_tuple = py::make_tuple( | |||
| 811 | py::cast(output), tensorflow::Pyo(PyLong_FromSize_t(pos))); | |||
| 812 | return result_tuple; | |||
| 813 | }, | |||
| 814 | py::return_value_policy::reference); | |||
| 815 | ||||
| 816 | // Python needs to own deletion of outputs | |||
| 817 | m.def("TF_ImportGraphDefResultsReturnOutputs", | |||
| 818 | [](TF_ImportGraphDefResults* results) { | |||
| 819 | int num_outputs; | |||
| 820 | TF_Output* outputs; | |||
| 821 | TF_ImportGraphDefResultsReturnOutputs(results, &num_outputs, | |||
| 822 | &outputs); | |||
| 823 | py::list py_list; | |||
| 824 | for (int i = 0; i < num_outputs; ++i) { | |||
| 825 | TF_Output tf_output = TF_Output(outputs[i]); | |||
| 826 | py_list.append(tf_output); | |||
| 827 | } | |||
| 828 | return py_list; | |||
| 829 | }); | |||
| 830 | ||||
| 831 | m.def( | |||
| 832 | "TF_ImportGraphDefResultsReturnOperations", | |||
| 833 | [](TF_ImportGraphDefResults* results) { | |||
| 834 | int num_opers; | |||
| 835 | TF_Operation** opers; | |||
| 836 | TF_ImportGraphDefResultsReturnOperations(results, &num_opers, &opers); | |||
| 837 | py::list py_list; | |||
| 838 | for (int i = 0; i < num_opers; ++i) { | |||
| 839 | py_list.append(opers[i]); | |||
| 840 | } | |||
| 841 | return py_list; | |||
| 842 | }, | |||
| 843 | py::return_value_policy::reference); | |||
| 844 | ||||
| 845 | m.def("TF_GraphToGraphDef", [](TF_Graph* graph, TF_Buffer* output_graph_def) { | |||
| 846 | tensorflow::Safe_TF_StatusPtr status = | |||
| 847 | tensorflow::make_safe(TF_NewStatus()); | |||
| 848 | // Release GIL. | |||
| 849 | py::gil_scoped_release release; | |||
| 850 | TF_GraphToGraphDef(graph, output_graph_def, status.get()); | |||
| 851 | tensorflow::MaybeRaiseRegisteredFromTFStatusWithGIL(status.get()); | |||
| 852 | }); | |||
| 853 | ||||
| 854 | m.def("TF_OperationNumInputs", TF_OperationNumInputs, | |||
| 855 | py::call_guard<py::gil_scoped_release>()); | |||
| 856 | ||||
| 857 | m.def("TF_GraphVersions", [](TF_Graph* graph, TF_Buffer* output_graph_def) { | |||
| 858 | tensorflow::Safe_TF_StatusPtr status = | |||
| 859 | tensorflow::make_safe(TF_NewStatus()); | |||
| 860 | // Release GIL. | |||
| 861 | py::gil_scoped_release release; | |||
| 862 | TF_GraphVersions(graph, output_graph_def, status.get()); | |||
| 863 | tensorflow::MaybeRaiseRegisteredFromTFStatusWithGIL(status.get()); | |||
| 864 | }); | |||
| 865 | ||||
| 866 | m.def("TF_DeleteFunction", TF_DeleteFunction, | |||
| 867 | py::call_guard<py::gil_scoped_release>()); | |||
| 868 | m.def("TF_DeleteImportGraphDefResults", TF_DeleteImportGraphDefResults, | |||
| 869 | py::call_guard<py::gil_scoped_release>()); | |||
| 870 | m.def("TF_DeleteImportGraphDefOptions", TF_DeleteImportGraphDefOptions, | |||
| 871 | py::call_guard<py::gil_scoped_release>()); | |||
| 872 | ||||
| 873 | m.def("TF_FunctionSetAttrValueProto", | |||
| 874 | [](TF_Function* func, const char* attr_name, py::bytes proto) { | |||
| 875 | tensorflow::Safe_TF_StatusPtr status = | |||
| 876 | tensorflow::make_safe(TF_NewStatus()); | |||
| 877 | tensorflow::Safe_TF_BufferPtr buf = | |||
| 878 | tensorflow::make_safe(ProtoStringToTFBuffer(proto.ptr())); | |||
| 879 | // Release GIL. | |||
| 880 | py::gil_scoped_release release; | |||
| 881 | TF_FunctionSetAttrValueProto(func, attr_name, buf.get()->data, | |||
| 882 | buf.get()->length, status.get()); | |||
| 883 | tensorflow::MaybeRaiseRegisteredFromTFStatusWithGIL(status.get()); | |||
| 884 | }); | |||
| 885 | ||||
| 886 | m.def("TF_FunctionToFunctionDef", | |||
| 887 | [](TF_Function* graph, TF_Buffer* output_func_def) { | |||
| 888 | tensorflow::Safe_TF_StatusPtr status = | |||
| 889 | tensorflow::make_safe(TF_NewStatus()); | |||
| 890 | // Release GIL. | |||
| 891 | py::gil_scoped_release release; | |||
| 892 | TF_FunctionToFunctionDef(graph, output_func_def, status.get()); | |||
| 893 | tensorflow::MaybeRaiseRegisteredFromTFStatusWithGIL(status.get()); | |||
| 894 | }); | |||
| 895 | ||||
| 896 | m.def("TF_GraphCopyFunction", | |||
| 897 | [](TF_Graph* graph, const TF_Function* func, const TF_Function* grad) { | |||
| 898 | tensorflow::Safe_TF_StatusPtr status = | |||
| 899 | tensorflow::make_safe(TF_NewStatus()); | |||
| 900 | // Release GIL. | |||
| 901 | py::gil_scoped_release release; | |||
| 902 | TF_GraphCopyFunction(graph, func, grad, status.get()); | |||
| 903 | tensorflow::MaybeRaiseRegisteredFromTFStatusWithGIL(status.get()); | |||
| 904 | }); | |||
| 905 | ||||
| 906 | m.def( | |||
| 907 | "TF_FunctionImportFunctionDef", | |||
| 908 | [](py::bytes proto) { | |||
| 909 | tensorflow::Safe_TF_StatusPtr status = | |||
| 910 | tensorflow::make_safe(TF_NewStatus()); | |||
| 911 | tensorflow::Safe_TF_BufferPtr buf = | |||
| 912 | tensorflow::make_safe(ProtoStringToTFBuffer(proto.ptr())); | |||
| 913 | ||||
| 914 | // Release GIL. | |||
| 915 | py::gil_scoped_release release; | |||
| 916 | auto output = TF_FunctionImportFunctionDef( | |||
| 917 | buf.get()->data, buf.get()->length, status.get()); | |||
| 918 | ||||
| 919 | // Acquire GIL for returning output returning. | |||
| 920 | pybind11::gil_scoped_acquire acquire; | |||
| 921 | tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); | |||
| 922 | return output; | |||
| 923 | }, | |||
| 924 | py::return_value_policy::reference); | |||
| 925 | ||||
| 926 | m.def("EqualAttrValueWrapper", tensorflow::EqualAttrValueWrapper, | |||
| 927 | py::call_guard<py::gil_scoped_release>()); | |||
| 928 | ||||
| 929 | m.def( | |||
| 930 | "TF_GetAllRegisteredKernels", | |||
| 931 | []() { | |||
| 932 | tensorflow::Safe_TF_StatusPtr status = | |||
| 933 | tensorflow::make_safe(TF_NewStatus()); | |||
| 934 | // Release GIL. | |||
| 935 | py::gil_scoped_release release; | |||
| 936 | auto output = TF_GetAllRegisteredKernels(status.get()); | |||
| 937 | tensorflow::MaybeRaiseRegisteredFromTFStatusWithGIL(status.get()); | |||
| 938 | return output; | |||
| 939 | }, | |||
| 940 | py::return_value_policy::reference); | |||
| 941 | ||||
| 942 | m.def( | |||
| 943 | "TF_GetRegisteredKernelsForOp", | |||
| 944 | [](const char* name) { | |||
| 945 | tensorflow::Safe_TF_StatusPtr status = | |||
| 946 | tensorflow::make_safe(TF_NewStatus()); | |||
| 947 | // Release GIL. | |||
| 948 | py::gil_scoped_release release; | |||
| 949 | auto output = TF_GetRegisteredKernelsForOp(name, status.get()); | |||
| 950 | tensorflow::MaybeRaiseRegisteredFromTFStatusWithGIL(status.get()); | |||
| 951 | return output; | |||
| 952 | }, | |||
| 953 | py::return_value_policy::reference); | |||
| 954 | ||||
| 955 | m.def("TF_GetAllOpList", TF_GetAllOpList, py::return_value_policy::reference, | |||
| 956 | py::call_guard<py::gil_scoped_release>()); | |||
| 957 | ||||
| 958 | m.def( | |||
| 959 | "TF_NewApiDefMap", | |||
| 960 | [](TF_Buffer* op_list_buffer) { | |||
| 961 | tensorflow::Safe_TF_StatusPtr status = | |||
| 962 | tensorflow::make_safe(TF_NewStatus()); | |||
| 963 | // Release GIL. | |||
| 964 | py::gil_scoped_release release; | |||
| 965 | auto output = TF_NewApiDefMap(op_list_buffer, status.get()); | |||
| 966 | tensorflow::MaybeRaiseRegisteredFromTFStatusWithGIL(status.get()); | |||
| 967 | return output; | |||
| 968 | }, | |||
| 969 | py::return_value_policy::reference); | |||
| 970 | ||||
| 971 | m.def("TF_DeleteApiDefMap", TF_DeleteApiDefMap, | |||
| 972 | py::call_guard<py::gil_scoped_release>()); | |||
| 973 | ||||
| 974 | m.def( | |||
| 975 | "TF_ApiDefMapGet", | |||
| 976 | [](TF_ApiDefMap* api_def_map, const char* name, size_t name_len) { | |||
| 977 | tensorflow::Safe_TF_StatusPtr status = | |||
| 978 | tensorflow::make_safe(TF_NewStatus()); | |||
| 979 | // Release GIL. | |||
| 980 | py::gil_scoped_release release; | |||
| 981 | auto output = | |||
| 982 | TF_ApiDefMapGet(api_def_map, name, name_len, status.get()); | |||
| 983 | tensorflow::MaybeRaiseRegisteredFromTFStatusWithGIL(status.get()); | |||
| 984 | return output; | |||
| 985 | }, | |||
| 986 | py::return_value_policy::reference); | |||
| 987 | ||||
| 988 | m.def("TF_ApiDefMapPut", | |||
| 989 | [](TF_ApiDefMap* api_def_map, const char* name, size_t name_len) { | |||
| 990 | tensorflow::Safe_TF_StatusPtr status = | |||
| 991 | tensorflow::make_safe(TF_NewStatus()); | |||
| 992 | // Release GIL. | |||
| 993 | py::gil_scoped_release release; | |||
| 994 | TF_ApiDefMapPut(api_def_map, name, name_len, status.get()); | |||
| 995 | tensorflow::MaybeRaiseRegisteredFromTFStatusWithGIL(status.get()); | |||
| 996 | }); | |||
| 997 | ||||
| 998 | m.def("TF_OperationGetAttrType", | |||
| 999 | [](TF_Operation* oper, const char* attr_name) { | |||
| 1000 | tensorflow::Safe_TF_StatusPtr status = | |||
| 1001 | tensorflow::make_safe(TF_NewStatus()); | |||
| 1002 | TF_DataType value; | |||
| 1003 | // Release GIL. | |||
| 1004 | py::gil_scoped_release release; | |||
| 1005 | TF_OperationGetAttrType(oper, attr_name, &value, status.get()); | |||
| 1006 | tensorflow::MaybeRaiseRegisteredFromTFStatusWithGIL(status.get()); | |||
| 1007 | return value; | |||
| 1008 | }); | |||
| 1009 | ||||
| 1010 | m.def( | |||
| 1011 | "TF_NewServer", | |||
| 1012 | [](py::bytes proto) { | |||
| 1013 | tensorflow::Safe_TF_StatusPtr status = | |||
| 1014 | tensorflow::make_safe(TF_NewStatus()); | |||
| 1015 | tensorflow::Safe_TF_BufferPtr buf = | |||
| 1016 | tensorflow::make_safe(ProtoStringToTFBuffer(proto.ptr())); | |||
| 1017 | TF_Server* output = | |||
| 1018 | TF_NewServer(buf.get()->data, buf.get()->length, status.get()); | |||
| 1019 | tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); | |||
| 1020 | return output; | |||
| 1021 | }, | |||
| 1022 | py::return_value_policy::reference); | |||
| 1023 | ||||
| 1024 | m.def("TF_ServerStart", [](TF_Server* server) { | |||
| 1025 | tensorflow::Safe_TF_StatusPtr status = | |||
| 1026 | tensorflow::make_safe(TF_NewStatus()); | |||
| 1027 | // Release GIL. | |||
| 1028 | py::gil_scoped_release release; | |||
| 1029 | TF_ServerStart(server, status.get()); | |||
| 1030 | tensorflow::MaybeRaiseRegisteredFromTFStatusWithGIL(status.get()); | |||
| 1031 | }); | |||
| 1032 | ||||
| 1033 | m.def("TF_ServerStop", [](TF_Server* server) { | |||
| 1034 | tensorflow::Safe_TF_StatusPtr status = | |||
| 1035 | tensorflow::make_safe(TF_NewStatus()); | |||
| 1036 | // Release GIL for threading. | |||
| 1037 | py::gil_scoped_release release; | |||
| 1038 | TF_ServerStop(server, status.get()); | |||
| 1039 | tensorflow::MaybeRaiseRegisteredFromTFStatusWithGIL(status.get()); | |||
| 1040 | }); | |||
| 1041 | ||||
| 1042 | m.def("TF_ServerJoin", [](TF_Server* server) { | |||
| 1043 | tensorflow::Safe_TF_StatusPtr status = | |||
| 1044 | tensorflow::make_safe(TF_NewStatus()); | |||
| 1045 | // Release GIL for threading. | |||
| 1046 | py::gil_scoped_release release; | |||
| 1047 | TF_ServerJoin(server, status.get()); | |||
| 1048 | tensorflow::MaybeRaiseRegisteredFromTFStatusWithGIL(status.get()); | |||
| 1049 | }); | |||
| 1050 | ||||
| 1051 | m.def( | |||
| 1052 | "TF_ServerTarget", | |||
| 1053 | [](TF_Server* server) { return TF_ServerTarget(server); }, | |||
| 1054 | py::call_guard<py::gil_scoped_release>()); | |||
| 1055 | ||||
| 1056 | m.def( | |||
| 1057 | "TF_SessionListDevices", | |||
| 1058 | [](TF_Session* session) { | |||
| 1059 | tensorflow::Safe_TF_StatusPtr status = | |||
| 1060 | tensorflow::make_safe(TF_NewStatus()); | |||
| 1061 | TF_DeviceList* output = TF_SessionListDevices(session, status.get()); | |||
| 1062 | tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); | |||
| 1063 | return output; | |||
| 1064 | }, | |||
| 1065 | py::return_value_policy::reference); | |||
| 1066 | ||||
| 1067 | m.def("TF_DeviceListCount", | |||
| 1068 | [](const TF_DeviceList* list) { return TF_DeviceListCount(list); }); | |||
| 1069 | ||||
| 1070 | m.def("TF_DeviceListName", [](const TF_DeviceList* list, int index) { | |||
| 1071 | tensorflow::Safe_TF_StatusPtr status = | |||
| 1072 | tensorflow::make_safe(TF_NewStatus()); | |||
| 1073 | const char* output = TF_DeviceListName(list, index, status.get()); | |||
| 1074 | tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); | |||
| 1075 | return output; | |||
| 1076 | }); | |||
| 1077 | ||||
| 1078 | m.def("TF_DeviceListType", [](const TF_DeviceList* list, int index) { | |||
| 1079 | tensorflow::Safe_TF_StatusPtr status = | |||
| 1080 | tensorflow::make_safe(TF_NewStatus()); | |||
| 1081 | const char* output = TF_DeviceListType(list, index, status.get()); | |||
| 1082 | tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); | |||
| 1083 | return output; | |||
| 1084 | }); | |||
| 1085 | ||||
| 1086 | m.def("TF_DeviceListMemoryBytes", [](const TF_DeviceList* list, int index) { | |||
| 1087 | tensorflow::Safe_TF_StatusPtr status = | |||
| 1088 | tensorflow::make_safe(TF_NewStatus()); | |||
| 1089 | int64_t output = TF_DeviceListMemoryBytes(list, index, status.get()); | |||
| 1090 | tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); | |||
| 1091 | return output; | |||
| 1092 | }); | |||
| 1093 | ||||
| 1094 | m.def("TF_DeviceListIncarnation", [](const TF_DeviceList* list, int index) { | |||
| 1095 | tensorflow::Safe_TF_StatusPtr status = | |||
| 1096 | tensorflow::make_safe(TF_NewStatus()); | |||
| 1097 | int64_t output = TF_DeviceListIncarnation(list, index, status.get()); | |||
| 1098 | tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); | |||
| 1099 | return output; | |||
| 1100 | }); | |||
| 1101 | ||||
| 1102 | m.def("TF_SetDevice", TF_SetDevice); | |||
| 1103 | ||||
| 1104 | m.def("TF_DeleteDeviceList", TF_DeleteDeviceList); | |||
| 1105 | ||||
| 1106 | m.def("TF_OperationGetAttrBool", | |||
| 1107 | [](TF_Operation* oper, const char* attr_name) { | |||
| 1108 | tensorflow::Safe_TF_StatusPtr status = | |||
| 1109 | tensorflow::make_safe(TF_NewStatus()); | |||
| 1110 | unsigned char value; | |||
| 1111 | // Release GIL for threading. | |||
| 1112 | py::gil_scoped_release release; | |||
| 1113 | TF_OperationGetAttrBool(oper, attr_name, &value, status.get()); | |||
| 1114 | tensorflow::MaybeRaiseRegisteredFromTFStatusWithGIL(status.get()); | |||
| 1115 | return tensorflow::Pyo(PyBool_FromLong(value)); | |||
| 1116 | }); | |||
| 1117 | ||||
| 1118 | m.def("TF_NewStatus", TF_NewStatus, py::return_value_policy::reference); | |||
| 1119 | m.def("TF_DeleteStatus", TF_DeleteStatus); | |||
| 1120 | ||||
| 1121 | m.def("TF_DeleteDeviceList", TF_DeleteDeviceList); | |||
| 1122 | ||||
| 1123 | m.def("AddWhileInputHack", | |||
| 1124 | [](TF_Graph* graph, TF_Output new_src, TF_Operation* dst) { | |||
| 1125 | tensorflow::Safe_TF_StatusPtr status = | |||
| 1126 | tensorflow::make_safe(TF_NewStatus()); | |||
| 1127 | // Release GIL for threading. | |||
| 1128 | py::gil_scoped_release release; | |||
| 1129 | tensorflow::AddWhileInputHack(graph, new_src, dst, status.get()); | |||
| 1130 | tensorflow::MaybeRaiseRegisteredFromTFStatusWithGIL(status.get()); | |||
| 1131 | }); | |||
| 1132 | ||||
| 1133 | m.def("TF_Reset_wrapper", [](const TF_SessionOptions* opt, | |||
| 1134 | const std::vector<py::bytes> containers) { | |||
| 1135 | tensorflow::Safe_TF_StatusPtr status = | |||
| 1136 | tensorflow::make_safe(TF_NewStatus()); | |||
| 1137 | // Release GIL for threading. | |||
| 1138 | py::gil_scoped_release release; | |||
| 1139 | tensorflow::NameVector containers_name_vector = | |||
| 1140 | ConvertPyListToNameVector(containers); | |||
| 1141 | tensorflow::TF_Reset_wrapper(opt, containers_name_vector, status.get()); | |||
| 1142 | tensorflow::MaybeRaiseRegisteredFromTFStatusWithGIL(status.get()); | |||
| 1143 | }); | |||
| 1144 | m.def("TF_GetCode", TF_GetCode); | |||
| 1145 | ||||
| 1146 | m.def("TF_SetXlaAutoJitMode", TF_SetXlaAutoJitMode); | |||
| 1147 | m.def("TF_SetXlaAutoJitMode", TF_SetXlaAutoJitMode); | |||
| 1148 | m.def("TF_SetXlaEnableLazyCompilation", TF_SetXlaEnableLazyCompilation); | |||
| 1149 | m.def("TF_SetTfXlaCpuGlobalJit", TF_SetTfXlaCpuGlobalJit); | |||
| 1150 | m.def("TF_SetXlaMinClusterSize", TF_SetXlaMinClusterSize); | |||
| 1151 | m.def("TF_GetXlaConstantFoldingDisabled", TF_GetXlaConstantFoldingDisabled); | |||
| 1152 | m.def("TF_SetXlaConstantFoldingDisabled", TF_SetXlaConstantFoldingDisabled); | |||
| 1153 | ||||
| 1154 | // // Static constants are not working on Windows. b/145559202 | |||
| 1155 | // // Creating getters instead. | |||
| 1156 | ||||
| 1157 | m.def("get_version", []() { return TF_VERSION_STRING("2" "." "7" "." "0" ""); }); | |||
| 1158 | m.def("get_git_version", []() { return tf_git_version(); }); | |||
| 1159 | m.def("get_compiler_version", []() { return tf_compiler_version(); }); | |||
| 1160 | m.def("get_cxx11_abi_flag", []() { return tf_cxx11_abi_flag(); }); | |||
| 1161 | m.def("get_eigen_max_align_bytes", []() { return EIGEN_MAX_ALIGN_BYTES64; }); | |||
| 1162 | m.def("get_monolithic_build", []() { return tf_monolithic_build(); }); | |||
| 1163 | m.def("get_graph_def_version", []() { return TF_GRAPH_DEF_VERSION892; }); | |||
| 1164 | m.def("get_graph_def_version_min_consumer", | |||
| 1165 | []() { return TF_GRAPH_DEF_VERSION_MIN_CONSUMER0; }); | |||
| 1166 | m.def("get_graph_def_version_min_producer", | |||
| 1167 | []() { return TF_GRAPH_DEF_VERSION_MIN_PRODUCER0; }); | |||
| 1168 | m.def("get_tensor_handle_key", []() { | |||
| 1169 | // TODO(amitpatankar): Look into a more elegant solution. | |||
| 1170 | // Since this is a shared object we will hard code the value from | |||
| 1171 | // third_party/tensorflow/core/common_runtime/session_state.cc because | |||
| 1172 | // the Windows import will not load the libraries necessarily | |||
| 1173 | // in order. b/145559202 | |||
| 1174 | return "TensorHandle"; | |||
| 1175 | }); | |||
| 1176 | ||||
| 1177 | m.def("TF_RegisterFilesystemPlugin", [](const char* plugin_filename) { | |||
| 1178 | tensorflow::Safe_TF_StatusPtr status = | |||
| 1179 | tensorflow::make_safe(TF_NewStatus()); | |||
| 1180 | TF_RegisterFilesystemPlugin(plugin_filename, status.get()); | |||
| 1181 | tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); | |||
| 1182 | }); | |||
| 1183 | ||||
| 1184 | py::enum_<TF_DataType>(m, "TF_DataType") | |||
| 1185 | .value("TF_FLOAT", TF_FLOAT) | |||
| 1186 | .value("TF_DOUBLE", TF_DOUBLE) | |||
| 1187 | .value("TF_INT32", TF_INT32) | |||
| 1188 | .value("TF_UINT8", TF_UINT8) | |||
| 1189 | .value("TF_INT16", TF_INT16) | |||
| 1190 | .value("TF_INT8", TF_INT8) | |||
| 1191 | .value("TF_STRING", TF_STRING) | |||
| 1192 | .value("TF_COMPLEX64", TF_COMPLEX64) | |||
| 1193 | .value("TF_COMPLEX", TF_COMPLEX) | |||
| 1194 | .value("TF_INT64", TF_INT64) | |||
| 1195 | .value("TF_BOOL", TF_BOOL) | |||
| 1196 | .value("TF_QINT8", TF_QINT8) | |||
| 1197 | .value("TF_QUINT8", TF_QUINT8) | |||
| 1198 | .value("TF_QINT32", TF_QINT32) | |||
| 1199 | .value("TF_BFLOAT16", TF_BFLOAT16) | |||
| 1200 | .value("TF_QINT16", TF_QINT16) | |||
| 1201 | .value("TF_QUINT16", TF_QUINT16) | |||
| 1202 | .value("TF_UINT16", TF_UINT16) | |||
| 1203 | .value("TF_COMPLEX128", TF_COMPLEX128) | |||
| 1204 | .value("TF_HALF", TF_HALF) | |||
| 1205 | .value("TF_RESOURCE", TF_RESOURCE) | |||
| 1206 | .value("TF_VARIANT", TF_VARIANT) | |||
| 1207 | .value("TF_UINT32", TF_UINT32) | |||
| 1208 | .value("TF_UINT64", TF_UINT64) | |||
| 1209 | .export_values(); | |||
| 1210 | ||||
| 1211 | py::enum_<TF_Code>(m, "TF_Code") | |||
| 1212 | .value("TF_OK", TF_OK) | |||
| 1213 | .value("TF_CANCELLED", TF_CANCELLED) | |||
| 1214 | .value("TF_UNKNOWN", TF_UNKNOWN) | |||
| 1215 | .value("TF_INVALID_ARGUMENT", TF_INVALID_ARGUMENT) | |||
| 1216 | .value("TF_DEADLINE_EXCEEDED", TF_DEADLINE_EXCEEDED) | |||
| 1217 | .value("TF_PERMISSION_DENIED", TF_PERMISSION_DENIED) | |||
| 1218 | .value("TF_UNAUTHENTICATED", TF_UNAUTHENTICATED) | |||
| 1219 | .value("TF_RESOURCE_EXHAUSTED", TF_RESOURCE_EXHAUSTED) | |||
| 1220 | .value("TF_FAILED_PRECONDITION", TF_FAILED_PRECONDITION) | |||
| 1221 | .value("TF_ABORTED", TF_ABORTED) | |||
| 1222 | .value("TF_OUT_OF_RANGE", TF_OUT_OF_RANGE) | |||
| 1223 | .value("TF_UNIMPLEMENTED", TF_UNIMPLEMENTED) | |||
| 1224 | .value("TF_INTERNAL", TF_INTERNAL) | |||
| 1225 | .value("TF_DATA_LOSS", TF_DATA_LOSS) | |||
| 1226 | .export_values(); | |||
| 1227 | }; |
// Static-analyzer model: when PySequence_Fast is a real function rather than
// a macro, model it as returning a new (owned) PyObject reference.
#ifndef PySequence_Fast
struct _object;
typedef struct _object PyObject;
PyObject* clang_analyzer_PyObject_New_Reference();
PyObject* PySequence_Fast(PyObject* /*o*/, const char* /*m*/) {
  PyObject* const fresh_reference = clang_analyzer_PyObject_New_Reference();
  return fresh_reference;
}
#else
#warning "API PySequence_Fast is defined as a macro."
#endif