Bug Summary

File: .cache/bazel/_bazel_alan/39be661231df2a680c9b74265384c13c/execroot/org_tensorflow/tensorflow/python/tfe_wrapper.cc
Warning: line 140, column 13
PyObject ownership leak with reference count of 1

Annotated Source Code

Press '?' to see keyboard shortcuts

clang -cc1 -cc1 -triple x86_64-unknown-linux-gnu -analyze -disable-free -disable-llvm-verifier -discard-value-names -main-file-name tfe_wrapper.cc -analyzer-store=region -analyzer-opt-analyze-nested-blocks -analyzer-checker=core -analyzer-checker=apiModeling -analyzer-checker=unix -analyzer-checker=deadcode -analyzer-checker=cplusplus -analyzer-checker=security.insecureAPI.UncheckedReturn -analyzer-checker=security.insecureAPI.getpw -analyzer-checker=security.insecureAPI.gets -analyzer-checker=security.insecureAPI.mktemp -analyzer-checker=security.insecureAPI.mkstemp -analyzer-checker=security.insecureAPI.vfork -analyzer-checker=nullability.NullPassedToNonnull -analyzer-checker=nullability.NullReturnedFromNonnull -analyzer-output plist -w -analyzer-output=html -analyzer-checker=python -analyzer-disable-checker=deadcode -analyzer-config prune-paths=true,suppress-c++-stdlib=true,suppress-null-return-paths=false,crosscheck-with-z3=true,model-path=/opt/pyrefcon/lib/pyrefcon/models/models -analyzer-config experimental-enable-naive-ctu-analysis=true,ctu-dir=/tmp/pyrefcon/tensorflow/csa-scan,ctu-index-name=/tmp/pyrefcon/tensorflow/csa-scan/externalDefMap.txt,ctu-invocation-list=/tmp/pyrefcon/tensorflow/csa-scan/invocations.yaml,display-ctu-progress=false -setup-static-analyzer -analyzer-config-compatibility-mode=true -mrelocation-model pic -pic-level 2 -fhalf-no-semantic-interposition -mframe-pointer=all -relaxed-aliasing -fmath-errno -fno-rounding-math -mconstructor-aliases -munwind-tables -target-cpu x86-64 -tune-cpu generic -debugger-tuning=gdb -ffunction-sections -fdata-sections -fcoverage-compilation-dir=/home/pyrefcon/.cache/bazel/_bazel_alan/39be661231df2a680c9b74265384c13c/execroot/org_tensorflow -resource-dir /opt/pyrefcon/lib/clang/13.0.0 -iquote . 
-iquote bazel-out/k8-opt/bin -iquote external/local_config_python -iquote bazel-out/k8-opt/bin/external/local_config_python -iquote external/pybind11 -iquote bazel-out/k8-opt/bin/external/pybind11 -iquote external/com_google_protobuf -iquote bazel-out/k8-opt/bin/external/com_google_protobuf -iquote external/com_google_absl -iquote bazel-out/k8-opt/bin/external/com_google_absl -iquote external/nsync -iquote bazel-out/k8-opt/bin/external/nsync -iquote external/eigen_archive -iquote bazel-out/k8-opt/bin/external/eigen_archive -iquote external/gif -iquote bazel-out/k8-opt/bin/external/gif -iquote external/libjpeg_turbo -iquote bazel-out/k8-opt/bin/external/libjpeg_turbo -iquote external/com_googlesource_code_re2 -iquote bazel-out/k8-opt/bin/external/com_googlesource_code_re2 -iquote external/farmhash_archive -iquote bazel-out/k8-opt/bin/external/farmhash_archive -iquote external/fft2d -iquote bazel-out/k8-opt/bin/external/fft2d -iquote external/highwayhash -iquote bazel-out/k8-opt/bin/external/highwayhash -iquote external/zlib -iquote bazel-out/k8-opt/bin/external/zlib -iquote external/double_conversion -iquote bazel-out/k8-opt/bin/external/double_conversion -iquote external/llvm-project -iquote bazel-out/k8-opt/bin/external/llvm-project -iquote external/llvm_terminfo -iquote bazel-out/k8-opt/bin/external/llvm_terminfo -iquote external/llvm_zlib -iquote bazel-out/k8-opt/bin/external/llvm_zlib -iquote external/bazel_tools -iquote bazel-out/k8-opt/bin/external/bazel_tools -isystem /opt/pyrefcon/lib/pyrefcon/models/python3.8 -isystem /opt/pyrefcon/lib/pyrefcon/models/python3.8 -isystem external/pybind11/include -isystem bazel-out/k8-opt/bin/external/pybind11/include -isystem external/local_config_python/numpy_include -isystem bazel-out/k8-opt/bin/external/local_config_python/numpy_include -isystem external/com_google_protobuf/src -isystem bazel-out/k8-opt/bin/external/com_google_protobuf/src -isystem external/nsync/public -isystem 
bazel-out/k8-opt/bin/external/nsync/public -isystem third_party/eigen3/mkl_include -isystem bazel-out/k8-opt/bin/third_party/eigen3/mkl_include -isystem external/eigen_archive -isystem bazel-out/k8-opt/bin/external/eigen_archive -isystem external/gif -isystem bazel-out/k8-opt/bin/external/gif -isystem external/farmhash_archive/src -isystem bazel-out/k8-opt/bin/external/farmhash_archive/src -isystem external/zlib -isystem bazel-out/k8-opt/bin/external/zlib -isystem external/double_conversion -isystem bazel-out/k8-opt/bin/external/double_conversion -isystem external/llvm-project/llvm/include -isystem bazel-out/k8-opt/bin/external/llvm-project/llvm/include -U _FORTIFY_SOURCE -D _FORTIFY_SOURCE=1 -D NDEBUG -D EIGEN_MPL2_ONLY -D EIGEN_MAX_ALIGN_BYTES=64 -D LLVM_ON_UNIX=1 -D HAVE_BACKTRACE=1 -D BACKTRACE_HEADER=<execinfo.h> -D LTDL_SHLIB_EXT=".so" -D LLVM_PLUGIN_EXT=".so" -D LLVM_ENABLE_THREADS=1 -D HAVE_SYSEXITS_H=1 -D HAVE_UNISTD_H=1 -D HAVE_STRERROR_R=1 -D HAVE_LIBPTHREAD=1 -D HAVE_PTHREAD_GETNAME_NP=1 -D HAVE_PTHREAD_SETNAME_NP=1 -D HAVE_PTHREAD_GETSPECIFIC=1 -D HAVE_REGISTER_FRAME=1 -D HAVE_DEREGISTER_FRAME=1 -D _GNU_SOURCE -D HAVE_LINK_H=1 -D HAVE_LSEEK64=1 -D HAVE_MALLINFO=1 -D HAVE_POSIX_FALLOCATE=1 -D HAVE_SBRK=1 -D HAVE_STRUCT_STAT_ST_MTIM_TV_NSEC=1 -D LLVM_NATIVE_ARCH="X86" -D LLVM_NATIVE_ASMPARSER=LLVMInitializeX86AsmParser -D LLVM_NATIVE_ASMPRINTER=LLVMInitializeX86AsmPrinter -D LLVM_NATIVE_DISASSEMBLER=LLVMInitializeX86Disassembler -D LLVM_NATIVE_TARGET=LLVMInitializeX86Target -D LLVM_NATIVE_TARGETINFO=LLVMInitializeX86TargetInfo -D LLVM_NATIVE_TARGETMC=LLVMInitializeX86TargetMC -D LLVM_NATIVE_TARGETMCA=LLVMInitializeX86TargetMCA -D LLVM_HOST_TRIPLE="x86_64-unknown-linux-gnu" -D LLVM_DEFAULT_TARGET_TRIPLE="x86_64-unknown-linux-gnu" -D __STDC_LIMIT_MACROS -D __STDC_CONSTANT_MACROS -D __STDC_FORMAT_MACROS -I bazel-out/k8-opt/bin/external/pybind11/_virtual_includes/pybind11 -D AUTOLOAD_DYNAMIC_KERNELS -D __DATE__="redacted" -D __TIMESTAMP__="redacted" -D 
__TIME__="redacted" -internal-isystem /usr/lib/gcc/x86_64-linux-gnu/10/../../../../include/c++/10 -internal-isystem /usr/lib/gcc/x86_64-linux-gnu/10/../../../../include/x86_64-linux-gnu/c++/10 -internal-isystem /usr/lib/gcc/x86_64-linux-gnu/10/../../../../include/c++/10/backward -internal-isystem /opt/pyrefcon/lib/clang/13.0.0/include -internal-isystem /usr/local/include -internal-isystem /usr/lib/gcc/x86_64-linux-gnu/10/../../../../x86_64-linux-gnu/include -internal-externc-isystem /usr/include/x86_64-linux-gnu -internal-externc-isystem /include -internal-externc-isystem /usr/include -O2 -Wall -Wunused-but-set-parameter -Wno-free-nonheap-object -Wno-builtin-macro-redefined -w -std=c++14 -fdeprecated-macro -fdebug-compilation-dir=/home/pyrefcon/.cache/bazel/_bazel_alan/39be661231df2a680c9b74265384c13c/execroot/org_tensorflow -ferror-limit 19 -fvisibility hidden -stack-protector 1 -fgnuc-version=4.2.1 -fcxx-exceptions -fexceptions -vectorize-loops -vectorize-slp -faddrsig -D__GCC_HAVE_DWARF2_CFI_ASM=1 -o /tmp/pyrefcon/tensorflow/csa-scan/reports -x c++ tensorflow/python/tfe_wrapper.cc

tensorflow/python/tfe_wrapper.cc

1/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include <memory>
17
18#include "Python.h"
19#include "absl/strings/str_format.h"
20#include "pybind11/chrono.h"
21#include "pybind11/complex.h"
22#include "pybind11/functional.h"
23#include "pybind11/pybind11.h"
24#include "pybind11/pytypes.h"
25#include "pybind11/stl.h"
26#include "tensorflow/c/c_api.h"
27#include "tensorflow/c/c_api_experimental.h"
28#include "tensorflow/c/eager/c_api.h"
29#include "tensorflow/c/eager/c_api_experimental.h"
30#include "tensorflow/c/eager/c_api_internal.h"
31#include "tensorflow/c/eager/dlpack.h"
32#include "tensorflow/c/eager/tfe_cancellation_manager_internal.h"
33#include "tensorflow/c/eager/tfe_context_internal.h"
34#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
35#include "tensorflow/c/tf_status.h"
36#include "tensorflow/c/tf_status_helper.h"
37#include "tensorflow/compiler/jit/flags.h"
38#include "tensorflow/compiler/jit/get_compiler_ir.h"
39#include "tensorflow/python/eager/pywrap_tensor_conversion.h"
40#include "tensorflow/python/eager/pywrap_tfe.h"
41#include "tensorflow/python/lib/core/py_exception_registry.h"
42#include "tensorflow/python/lib/core/pybind11_lib.h"
43#include "tensorflow/python/lib/core/pybind11_status.h"
44#include "tensorflow/python/lib/core/safe_ptr.h"
45#include "tensorflow/python/lib/core/safe_pyobject_ptr.h"
46#include "tensorflow/python/util/util.h"
47
48namespace py = pybind11;
49
50PYBIND11_MAKE_OPAQUE(TFE_Executor)namespace pybind11 { namespace detail { template<> class
type_caster<TFE_Executor> : public type_caster_base<
TFE_Executor> { }; }}
;
51PYBIND11_MAKE_OPAQUE(TFE_ContextOptions)namespace pybind11 { namespace detail { template<> class
type_caster<TFE_ContextOptions> : public type_caster_base
<TFE_ContextOptions> { }; }}
;
52PYBIND11_MAKE_OPAQUE(tensorflow::CancellationManager)namespace pybind11 { namespace detail { template<> class
type_caster<tensorflow::CancellationManager> : public type_caster_base
<tensorflow::CancellationManager> { }; }}
;
53
54PYBIND11_MAKE_OPAQUE(TFE_MonitoringCounter0)namespace pybind11 { namespace detail { template<> class
type_caster<TFE_MonitoringCounter0> : public type_caster_base
<TFE_MonitoringCounter0> { }; }}
;
55PYBIND11_MAKE_OPAQUE(TFE_MonitoringCounter1)namespace pybind11 { namespace detail { template<> class
type_caster<TFE_MonitoringCounter1> : public type_caster_base
<TFE_MonitoringCounter1> { }; }}
;
56PYBIND11_MAKE_OPAQUE(TFE_MonitoringCounter2)namespace pybind11 { namespace detail { template<> class
type_caster<TFE_MonitoringCounter2> : public type_caster_base
<TFE_MonitoringCounter2> { }; }}
;
57PYBIND11_MAKE_OPAQUE(TFE_MonitoringStringGauge0)namespace pybind11 { namespace detail { template<> class
type_caster<TFE_MonitoringStringGauge0> : public type_caster_base
<TFE_MonitoringStringGauge0> { }; }}
;
58PYBIND11_MAKE_OPAQUE(TFE_MonitoringStringGauge1)namespace pybind11 { namespace detail { template<> class
type_caster<TFE_MonitoringStringGauge1> : public type_caster_base
<TFE_MonitoringStringGauge1> { }; }}
;
59PYBIND11_MAKE_OPAQUE(TFE_MonitoringStringGauge2)namespace pybind11 { namespace detail { template<> class
type_caster<TFE_MonitoringStringGauge2> : public type_caster_base
<TFE_MonitoringStringGauge2> { }; }}
;
60PYBIND11_MAKE_OPAQUE(TFE_MonitoringStringGauge3)namespace pybind11 { namespace detail { template<> class
type_caster<TFE_MonitoringStringGauge3> : public type_caster_base
<TFE_MonitoringStringGauge3> { }; }}
;
61PYBIND11_MAKE_OPAQUE(TFE_MonitoringStringGauge4)namespace pybind11 { namespace detail { template<> class
type_caster<TFE_MonitoringStringGauge4> : public type_caster_base
<TFE_MonitoringStringGauge4> { }; }}
;
62PYBIND11_MAKE_OPAQUE(TFE_MonitoringIntGauge0)namespace pybind11 { namespace detail { template<> class
type_caster<TFE_MonitoringIntGauge0> : public type_caster_base
<TFE_MonitoringIntGauge0> { }; }}
;
63PYBIND11_MAKE_OPAQUE(TFE_MonitoringIntGauge1)namespace pybind11 { namespace detail { template<> class
type_caster<TFE_MonitoringIntGauge1> : public type_caster_base
<TFE_MonitoringIntGauge1> { }; }}
;
64PYBIND11_MAKE_OPAQUE(TFE_MonitoringIntGauge2)namespace pybind11 { namespace detail { template<> class
type_caster<TFE_MonitoringIntGauge2> : public type_caster_base
<TFE_MonitoringIntGauge2> { }; }}
;
65PYBIND11_MAKE_OPAQUE(TFE_MonitoringBoolGauge0)namespace pybind11 { namespace detail { template<> class
type_caster<TFE_MonitoringBoolGauge0> : public type_caster_base
<TFE_MonitoringBoolGauge0> { }; }}
;
66PYBIND11_MAKE_OPAQUE(TFE_MonitoringBoolGauge1)namespace pybind11 { namespace detail { template<> class
type_caster<TFE_MonitoringBoolGauge1> : public type_caster_base
<TFE_MonitoringBoolGauge1> { }; }}
;
67PYBIND11_MAKE_OPAQUE(TFE_MonitoringBoolGauge2)namespace pybind11 { namespace detail { template<> class
type_caster<TFE_MonitoringBoolGauge2> : public type_caster_base
<TFE_MonitoringBoolGauge2> { }; }}
;
68PYBIND11_MAKE_OPAQUE(TFE_MonitoringSampler0)namespace pybind11 { namespace detail { template<> class
type_caster<TFE_MonitoringSampler0> : public type_caster_base
<TFE_MonitoringSampler0> { }; }}
;
69PYBIND11_MAKE_OPAQUE(TFE_MonitoringSampler1)namespace pybind11 { namespace detail { template<> class
type_caster<TFE_MonitoringSampler1> : public type_caster_base
<TFE_MonitoringSampler1> { }; }}
;
70PYBIND11_MAKE_OPAQUE(TFE_MonitoringSampler2)namespace pybind11 { namespace detail { template<> class
type_caster<TFE_MonitoringSampler2> : public type_caster_base
<TFE_MonitoringSampler2> { }; }}
;
71PYBIND11_MAKE_OPAQUE(TFE_MonitoringCounterCell)namespace pybind11 { namespace detail { template<> class
type_caster<TFE_MonitoringCounterCell> : public type_caster_base
<TFE_MonitoringCounterCell> { }; }}
;
72PYBIND11_MAKE_OPAQUE(TFE_MonitoringIntGaugeCell)namespace pybind11 { namespace detail { template<> class
type_caster<TFE_MonitoringIntGaugeCell> : public type_caster_base
<TFE_MonitoringIntGaugeCell> { }; }}
;
73PYBIND11_MAKE_OPAQUE(TFE_MonitoringStringGaugeCell)namespace pybind11 { namespace detail { template<> class
type_caster<TFE_MonitoringStringGaugeCell> : public type_caster_base
<TFE_MonitoringStringGaugeCell> { }; }}
;
74PYBIND11_MAKE_OPAQUE(TFE_MonitoringBoolGaugeCell)namespace pybind11 { namespace detail { template<> class
type_caster<TFE_MonitoringBoolGaugeCell> : public type_caster_base
<TFE_MonitoringBoolGaugeCell> { }; }}
;
75PYBIND11_MAKE_OPAQUE(TFE_MonitoringSamplerCell)namespace pybind11 { namespace detail { template<> class
type_caster<TFE_MonitoringSamplerCell> : public type_caster_base
<TFE_MonitoringSamplerCell> { }; }}
;
76
77PYBIND11_MAKE_OPAQUE(TF_DeviceList)namespace pybind11 { namespace detail { template<> class
type_caster<TF_DeviceList> : public type_caster_base<
TF_DeviceList> { }; }}
;
78PYBIND11_MAKE_OPAQUE(TF_Function)namespace pybind11 { namespace detail { template<> class
type_caster<TF_Function> : public type_caster_base<
TF_Function> { }; }}
;
79PYBIND11_MAKE_OPAQUE(TF_Buffer)namespace pybind11 { namespace detail { template<> class
type_caster<TF_Buffer> : public type_caster_base<TF_Buffer
> { }; }}
;
80
81// Eager helper functions migrated from pywrap_tfe.i.
82
83namespace tensorflow {
84
85// We cannot use Context as an opaque type. SWIG also had
86// difficulty directly passing the pointer around. These
87// typemaps are migrated over from pywrap_tfe.i. I tried
88// using a custom type caster, but we get segfaults periodically.
89
90// TODO(amitpatankar): Move input and output logic of Context into a
91// pybind11 custom type caster.
92
93TFE_Context* InputTFE_Context(const py::handle& ctx) {
94 return static_cast<TFE_Context*>(PyCapsule_GetPointer(ctx.ptr(), nullptr));
95}
96
97PyObject* OutputTFE_Context(TFE_Context* context) {
98 return PyCapsule_New(context, nullptr, TFE_DeleteContextCapsule);
99}
100
101TF_Buffer* ProtoStringToTFBuffer(PyObject* input) {
102 // Convert a Python string object to TF_Buffer.
103 char* c_string;
104 Py_ssize_t py_size;
105 // PyBytes_AsStringAndSize() does not copy but simply interprets the input
106 if (PyBytes_AsStringAndSize(input, &c_string, &py_size) == -1) {
107 // Python has raised an error (likely TypeError or UnicodeEncodeError).
108 throw py::error_already_set();
109 }
110 return TF_NewBufferFromString(static_cast<void*>(c_string),
111 static_cast<size_t>(py_size));
112}
113
114// These functions are typemaps from the Python side. I did not use
115// a custom type caster since the logic is slightly harder to follow. This
116// converter is also only used once in `TFE_Py_ExecuteCancelable_wrapper`.
117TFE_InputTensorHandles InputTFE_InputTensorHandles(
118 const py::handle& input_tensors) {
119 TFE_InputTensorHandles input_tensor_handles;
120 if (input_tensors.ptr() != Py_None(&_Py_NoneStruct)) {
2
Assuming the condition is true
3
Taking true branch
121 if (!PyList_Check(input_tensors.ptr())((((((PyObject*)(input_tensors.ptr()))->ob_type))->tp_flags
& ((1UL << 25))) != 0)
) {
4
Assuming the condition is true
5
Taking false branch
122 tensorflow::ThrowTypeError("must provide a list of Tensors as inputs");
123 }
124 Py_ssize_t len = PyList_Size(input_tensors.ptr());
125 input_tensor_handles.resize(len);
126 for (Py_ssize_t i = 0; i < len; ++i) {
6
Assuming 'i' is < 'len'
7
Loop condition is true. Entering loop body
127 PyObject* elem = PyList_GetItem(input_tensors.ptr(), i);
128 if (!elem) {
8
Assuming 'elem' is non-null
9
Taking false branch
129 tensorflow::ThrowTypeError("Input Tensor does not exist.");
130 }
131 if (EagerTensor_CheckExact(elem)) {
10
Assuming the condition is false
11
Taking false branch
132 (input_tensor_handles)[i] = EagerTensor_Handle(elem);
133 } else if (tensorflow::swig::IsEagerTensorSlow(elem)) {
12
Assuming the condition is true
13
Taking true branch
134 // Use equivalent of object.__getattribute__ to get the underlying
135 // tf wrapped EagerTensor (if there is one).
136 tensorflow::Safe_PyObjectPtr tf_should_use_attr(
137#if PY_MAJOR_VERSION3 < 3
138 PyString_InternFromString("_tf_should_use_wrapped_value")
139#else
140 PyUnicode_InternFromString("_tf_should_use_wrapped_value")
14
Calling 'PyUnicode_InternFromString'
16
Returning from 'PyUnicode_InternFromString'
18
PyObject ownership leak with reference count of 1
141#endif
142 );
143 tensorflow::Safe_PyObjectPtr value_attr(
144 PyObject_GenericGetAttr(elem, tf_should_use_attr.get()));
145 if (value_attr) {
17
Taking false branch
146 // This is an EagerTensor wrapped inside a TFShouldUse wrapped object.
147 (input_tensor_handles)[i] = EagerTensor_Handle(value_attr.get());
148 } else {
149 // This is a subclass of EagerTensor that we don't support.
150 PyErr_Clear();
151 tensorflow::ThrowTypeError(
152 tensorflow::strings::StrCat(
153 "Saw an object that is an instance of a strict subclass of "
154 "EagerTensor, which is not supported. Item ",
155 i, " is type: ", elem->ob_type->tp_name)
156 .c_str());
157 }
158 } else if (tensorflow::swig::IsTensor(elem)) {
159 // If it isnt an EagerTensor, but is still a Tensor, it must be a graph
160 // tensor.
161 tensorflow::Safe_PyObjectPtr name_attr(
162 PyObject_GetAttrString(elem, "name"));
163 tensorflow::ThrowTypeError(
164 tensorflow::strings::StrCat(
165 "An op outside of the function building code is being passed\n"
166 "a \"Graph\" tensor. It is possible to have Graph tensors\n"
167 "leak out of the function building context by including a\n"
168 "tf.init_scope in your function building code.\n"
169 "For example, the following function will fail:\n",
170 " @tf.function\n", " def has_init_scope():\n",
171 " my_constant = tf.constant(1.)\n",
172 " with tf.init_scope():\n",
173 " added = my_constant * 2\n",
174 "The graph tensor has name: ",
175 name_attr ? TFE_GetPythonString(name_attr.get()) : "<unknown>")
176 .c_str());
177 } else {
178 tensorflow::ThrowTypeError(
179 tensorflow::strings::StrCat(
180 "provided list of inputs contains objects other "
181 "than 'EagerTensor'. Item ",
182 i, " is type: ", elem->ob_type->tp_name)
183 .c_str());
184 }
185 }
186 }
187 return input_tensor_handles;
188}
189
190// These functions are typemaps from the Python side. I did not use
191// a custom type caster since the logic is slightly harder to follow. This
192// converter is also only used once in `TFE_Py_ExecuteCancelable_wrapper`.
193// This function actually takes a number rather than an output Tensor holder.
194TFE_OutputTensorHandles InputTFE_OutputTensorHandles(
195 const py::handle& num_outputs) {
196 TFE_OutputTensorHandles output_tensor_handles;
197#if PY_MAJOR_VERSION3 < 3
198 if (!PyInt_Check(num_outputs.ptr())) {
199#else
200 if (!PyLong_Check(num_outputs.ptr())((((((PyObject*)(num_outputs.ptr()))->ob_type))->tp_flags
& ((1UL << 24))) != 0)
) {
201#endif
202 PyErr_SetString(PyExc_TypeError,
203 "expected an integer value (size of the number of "
204 "outputs of the operation)");
205 throw py::error_already_set();
206 }
207#if PY_MAJOR_VERSION3 < 3
208 long sz = PyInt_AsLong(num_outputs.ptr()); // NOLINT
209#else
210 long sz = PyLong_AsLong(num_outputs.ptr()); // NOLINT
211#endif
212 // PyLong_AsLong might throw an error if an overflow occurs.
213 if (PyErr_Occurred()) {
214 PyErr_SetString(PyExc_ValueError, tensorflow::strings::StrCat(
215 "Number of outputs is too big: ", sz)
216 .c_str());
217 throw py::error_already_set();
218 }
219 // We can't handle more than int32 sizes for number of outputs.
220 if (static_cast<long>(static_cast<int32_t>(sz)) != sz) { // NOLINT
221 PyErr_SetString(PyExc_ValueError, tensorflow::strings::StrCat(
222 "Number of outputs is too big: ", sz)
223 .c_str());
224 throw py::error_already_set();
225 }
226 if (sz > 0) {
227#if PY_MAJOR_VERSION3 < 3
228 output_tensor_handles.resize(PyInt_AsLong(num_outputs.ptr()), nullptr);
229#else
230 output_tensor_handles.resize(PyLong_AsLong(num_outputs.ptr()), nullptr);
231#endif
232 }
233 return output_tensor_handles;
234}
235
236tensorflow::Device* GetMatchedDevice(py::handle& ctx, const char* device_name) {
237 auto* context = reinterpret_cast<tensorflow::ImmediateExecutionContext*>(
238 tensorflow::InputTFE_Context(ctx));
239
240 tensorflow::DeviceNameUtils::ParsedName input_device_name;
241 if (!tensorflow::DeviceNameUtils::ParseFullOrLocalName(device_name,
242 &input_device_name)) {
243 tensorflow::ThrowValueError(
244 absl::StrFormat("Failed parsing device name: '%s'. Note a valid device "
245 "string should at least contain a device type and a "
246 "device index, like \"GPU:0\".",
247 device_name)
248 .c_str());
249 }
250
251 std::vector<tensorflow::Device*> devices = context->ListLocalTfDevices();
252
253 tensorflow::Device* matched_device = nullptr;
254 for (int device_idx = 0; device_idx < devices.size(); device_idx++) {
255 tensorflow::Device* device = devices[device_idx];
256
257 if (tensorflow::DeviceNameUtils::AreCompatibleDevNames(
258 input_device_name, device->parsed_name())) {
259 if (matched_device != nullptr) {
260 tensorflow::ThrowValueError(
261 absl::StrFormat("Multiple devices match the provided string "
262 "'%s': '%s' and '%s'.",
263 device_name, matched_device->name(), device->name())
264 .c_str());
265 }
266 matched_device = device;
267 }
268 }
269
270 if (matched_device == nullptr) {
271 tensorflow::ThrowValueError(
272 absl::StrFormat("No matching devices found for '%s'", device_name)
273 .c_str());
274 }
275
276 return matched_device;
277}
278
279// Packs multiple `EagerTensor`s of the same dtype and shape into one
280// `EagerTensor`.
281py::object TFE_Py_PackEagerTensors_wrapper(const py::handle& context,
282 const py::handle& tensors) {
283 TFE_Context* ctx = tensorflow::InputTFE_Context(context);
284 TFE_InputTensorHandles handles = InputTFE_InputTensorHandles(tensors);
285 tensorflow::Safe_TF_StatusPtr status = tensorflow::make_safe(TF_NewStatus());
286 int size = handles.size();
287 TFE_TensorHandle* packed_handle =
288 TFE_CreatePackedTensorHandle(ctx, handles.data(), &size, status.get());
289 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
290 PyObject* packed_tensor =
291 EagerTensorFromHandle(packed_handle, /*is_packed=*/true);
292 return tensorflow::PyoOrThrow(packed_tensor);
293}
294
295// This function was created from fusing the typemap logic in platform/base.i.
296py::object TFE_Py_ExecuteCancelable_wrapper(
297 const py::handle& context, const char* device_name, const char* op_name,
298 const py::handle& inputs, const py::handle& attrs,
299 tensorflow::CancellationManager* cancellation_manager,
300 const py::handle& num_outputs) {
301 TFE_Context* ctx = tensorflow::InputTFE_Context(context);
302 TFE_InputTensorHandles input_tensor_handles =
303 InputTFE_InputTensorHandles(inputs);
304 TFE_OutputTensorHandles output_tensor_handles =
305 InputTFE_OutputTensorHandles(num_outputs);
306 tensorflow::Safe_TF_StatusPtr status = tensorflow::make_safe(TF_NewStatus());
307 TFE_Py_ExecuteCancelable(ctx, device_name, op_name, &input_tensor_handles,
308 attrs.ptr(), tensorflow::wrap(cancellation_manager),
309 &output_tensor_handles, status.get());
310
311 int output_len = output_tensor_handles.size();
312 PyObject* output_list = PyList_New(output_len);
313 for (int i = 0; i < output_len; ++i) {
314 PyObject* output;
315 output = EagerTensorFromHandle(output_tensor_handles.at(i));
316 PyList_SetItem(output_list, i, output);
317 }
318 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
319 return tensorflow::PyoOrThrow(output_list);
320}
321
322static py::object TF_ListPhysicalDevices() {
323 std::vector<string> devices;
324 tensorflow::Status s =
325 tensorflow::DeviceFactory::ListAllPhysicalDevices(&devices);
326 MaybeRaiseRegisteredFromStatus(s);
327 PyObject* result = PyList_New(devices.size());
328 int i = 0;
329 for (auto& dev : devices) {
330 PyObject* dev_obj = PyBytes_FromStringAndSize(dev.data(), dev.size());
331 PyList_SetItem(result, i, dev_obj);
332 ++i;
333 }
334 return tensorflow::PyoOrThrow(result);
335}
336
337static py::object TF_ListPluggablePhysicalDevices() {
338 std::vector<string> devices;
339 tensorflow::Status s =
340 tensorflow::DeviceFactory::ListPluggablePhysicalDevices(&devices);
341 MaybeRaiseRegisteredFromStatus(s);
342 Safe_PyObjectPtr result(PyList_New(devices.size()));
343 int i = 0;
344 for (auto& dev : devices) {
345 PyObject* dev_obj = PyBytes_FromStringAndSize(dev.data(), dev.size());
346 PyList_SetItem(result.get(), i, dev_obj);
347 ++i;
348 }
349 return tensorflow::PyoOrThrow(result.release());
350}
351
352static std::unordered_map<string, string> TF_GetDeviceDetails(int index) {
353 tensorflow::Safe_TF_StatusPtr status = tensorflow::make_safe(TF_NewStatus());
354 std::unordered_map<string, string> device_details;
355 tensorflow::Status s =
356 tensorflow::DeviceFactory::GetAnyDeviceDetails(index, &device_details);
357 tensorflow::Set_TF_Status_from_Status(status.get(), s);
358 MaybeRaiseRegisteredFromTFStatus(status.get());
359 return device_details;
360}
361
362static py::object TFE_ClearScalarCache() {
363 tensorflow::TFE_TensorHandleCache::Get()->Clear();
364 return py::none();
365}
366
367// Returns compiler IR for a given function.
368static py::bytes TFE_GetCompilerIr(py::handle& ctx,
369 const char* concrete_function_name,
370 const char* stage, const char* device_name,
371 py::handle& inputs) {
372 EagerContext* context = ContextFromInterface(
373 reinterpret_cast<ImmediateExecutionContext*>(InputTFE_Context(ctx)));
374
375 std::string s_stage(stage);
376 IrExportStage selected_stage = [&] {
377 if (s_stage == "hlo") {
378 return IrExportStage::HLO;
379 } else if (s_stage == "hlo_serialized") {
380 return IrExportStage::HLO_SERIALIZED;
381 } else if (s_stage == "optimized_hlo") {
382 return IrExportStage::OPTIMIZED_HLO;
383 } else if (s_stage == "optimized_hlo_serialized") {
384 return IrExportStage::OPTIMIZED_HLO_SERIALIZED;
385 } else if (s_stage == "optimized_hlo_proto_serialized") {
386 return IrExportStage::OPTIMIZED_HLO_PROTO_SERIALIZED;
387 } else if (s_stage == "optimized_hlo_dot") {
388 return IrExportStage::OPTIMIZED_HLO_DOT;
389 } else {
390 ThrowValueError(
391 absl::StrFormat("Invalid stage selected: '%s'. Valid values are: "
392 "'hlo', 'hlo_serialized', 'optimized_hlo', "
393 "'optimized_hlo_serialized', 'optimized_hlo_dot'",
394 s_stage)
395 .c_str());
396 }
397 }();
398
399 TFE_InputTensorHandles handles = InputTFE_InputTensorHandles(inputs);
1
Calling 'InputTFE_InputTensorHandles'
400
401 std::vector<const TensorHandle*> input_handles;
402 for (TFE_TensorHandle* tensor_handle : handles) {
403 AbstractTensorHandle* abstract_tensor_handle = unwrap(tensor_handle);
404 input_handles.push_back(TensorHandleFromInterface(abstract_tensor_handle));
405 }
406
407 DeviceNameUtils::ParsedName input_device_name;
408 if (!DeviceNameUtils::ParseFullOrLocalName(device_name, &input_device_name)) {
409 ThrowValueError(
410 absl::StrFormat("Failed parsing device name: '%s'", device_name)
411 .c_str());
412 }
413
414 std::vector<Device*> devices = context->local_device_mgr()->ListDevices();
415 auto selected_device = absl::c_find_if(devices, [&](const Device* d) {
416 return DeviceNameUtils::AreCompatibleDevNames(input_device_name,
417 d->parsed_name());
418 });
419 if (selected_device == devices.end()) {
420 ThrowValueError(
421 absl::StrFormat("No matching device found for '%s'", device_name)
422 .c_str());
423 }
424
425 xla::StatusOr<std::string> hlo_str =
426 GetCompilerIr(selected_stage, context->pflr(), concrete_function_name,
427 *selected_device, context, input_handles);
428
429 if (!hlo_str.ok()) {
430 ThrowValueError(absl::StrFormat("Failed getting HLO text: '%s'",
431 hlo_str.status().error_message())
432 .c_str());
433 }
434 return py::bytes(*hlo_str);
435}
436
437} // namespace tensorflow
438
439namespace {
440
441// Wrapper around the EagerContextThreadLocalData struct (defined in
442// pywrap_tfe.h), so it can be accessed from Python.
443//
444// For PyObject* fields, the get_*() methods return a new reference; and the
445// set_*() methods create a new reference (i.e., they do not steal a reference).
446class EagerContextThreadLocalDataWrapper {
447 public:
448 explicit EagerContextThreadLocalDataWrapper(py::handle py_eager_context,
449 py::handle is_eager,
450 py::handle device_spec)
451 : py_eager_context_(py_eager_context.ptr()) {
452 tensorflow::MakeEagerContextThreadLocalData(
453 py_eager_context.ptr(), is_eager.ptr(), device_spec.ptr());
454 }
455
456 ~EagerContextThreadLocalDataWrapper() {
457 tensorflow::DestroyEagerContextThreadLocalData(py_eager_context_);
458 }
459
460 bool get_is_eager() const { return GetData()->is_eager; }
461 void set_is_eager(bool v) { GetData()->is_eager = v; }
462
463 bool get_invoking_op_callbacks() const {
464 return GetData()->invoking_op_callbacks;
465 }
466 void set_invoking_op_callbacks(bool v) {
467 GetData()->invoking_op_callbacks = v;
468 }
469
470 py::object get_device_name() const {
471 return GetPyObject(&GetData()->device_name);
472 }
473 void set_device_name(py::handle v) {
474 SetPyObject(v, &GetData()->device_name);
475 }
476
477 py::object get_scope_name() const {
478 return GetPyObject(&GetData()->scope_name);
479 }
480 void set_scope_name(py::handle v) { SetPyObject(v, &GetData()->scope_name); }
481
482 py::object get_device_spec() const {
483 return GetPyObject(&GetData()->device_spec);
484 }
485 void set_device_spec(py::handle v) {
486 SetPyObject(v, &GetData()->device_spec);
487 }
488
489 py::object get_function_call_options() const {
490 return GetPyObject(&GetData()->function_call_options);
491 }
492 void set_function_call_options(py::handle v) {
493 SetPyObject(v, &GetData()->function_call_options);
494 }
495
496 py::handle get_executor() const { return GetPyObject(&GetData()->executor); }
497 void set_executor(py::handle v) { SetPyObject(v, &GetData()->executor); }
498
499 py::object get_op_callbacks() const {
500 return GetPyObject(&GetData()->op_callbacks);
501 }
502 void set_op_callbacks(py::handle v) {
503 SetPyObject(v, &GetData()->op_callbacks);
504 }
505
506 private:
507 tensorflow::EagerContextThreadLocalData* GetData() const {
508 auto* result =
509 tensorflow::GetEagerContextThreadLocalData(py_eager_context_);
510 if (!result) {
511 throw py::error_already_set();
512 }
513 return result;
514 }
515
516 py::object GetPyObject(tensorflow::Safe_PyObjectPtr* obj) const {
517 return pybind11::reinterpret_borrow<py::object>(obj->get());
518 }
519
520 void SetPyObject(py::handle value, tensorflow::Safe_PyObjectPtr* ptr) {
521 Py_INCREF(value.ptr())_Py_INCREF(((PyObject*)(value.ptr())));
522 ptr->reset(value.ptr());
523 }
524
525 PyObject* py_eager_context_; // not owned (borrowed reference).
526};
527
528} // namespace
529
// py::return_value_policy::reference behaves as described in the pybind11
// documentation:
// https://pybind11.readthedocs.io/en/stable/advanced/functions.html#return-value-policies
// With this policy Python receives a reference to the returned object, but
// C++ keeps ownership and remains responsible for freeing it. We only apply
// this policy to functions that return opaque types.
535
536PYBIND11_MODULE(_pywrap_tfe, m)static ::pybind11::module_::module_def pybind11_module_def__pywrap_tfe
; __attribute__ ((__unused__)) static void pybind11_init__pywrap_tfe
(::pybind11::module_ &); extern "C" __attribute__ ((__unused__
)) __attribute__ ((visibility("default"))) PyObject *PyInit__pywrap_tfe
(); extern "C" __attribute__ ((visibility("default"))) PyObject
*PyInit__pywrap_tfe() { { const char *compiled_ver = "3" "."
"8"; const char *runtime_ver = Py_GetVersion(); size_t len =
std::strlen(compiled_ver); if (std::strncmp(runtime_ver, compiled_ver
, len) != 0 || (runtime_ver[len] >= '0' && runtime_ver
[len] <= '9')) { PyErr_Format(PyExc_ImportError, "Python version mismatch: module was compiled for Python %s, "
"but the interpreter version is incompatible: %s.", compiled_ver
, runtime_ver); return nullptr; } } pybind11::detail::get_internals
(); auto m = ::pybind11::module_::create_extension_module( "_pywrap_tfe"
, nullptr, &pybind11_module_def__pywrap_tfe); try { pybind11_init__pywrap_tfe
(m); return m.ptr(); } catch (pybind11::error_already_set &
e) { PyErr_SetString(PyExc_ImportError, e.what()); return nullptr
; } catch (const std::exception &e) { PyErr_SetString(PyExc_ImportError
, e.what()); return nullptr; } } void pybind11_init__pywrap_tfe
(::pybind11::module_ &m)
{
537 py::class_<TFE_Executor> TFE_Executor_class(m, "TFE_Executor");
538 py::class_<TFE_ContextOptions> TFE_ContextOptions_class(m,
539 "TFE_ContextOptions");
540 py::class_<TFE_MonitoringCounter0> TFE_MonitoringCounter0_class(
541 m, "TFE_MonitoringCounter0");
542 py::class_<TFE_MonitoringCounter1> TFE_MonitoringCounter1_class(
543 m, "TFE_MonitoringCounter1");
544 py::class_<TFE_MonitoringCounter2> TFE_MonitoringCounter2_class(
545 m, "TFE_MonitoringCounter2");
546 py::class_<TFE_MonitoringStringGauge0> TFE_MonitoringStringGauge0_class(
547 m, "TFE_MonitoringStringGauge0");
548 py::class_<TFE_MonitoringStringGauge1> TFE_MonitoringStringGauge1_class(
549 m, "TFE_MonitoringStringGauge1");
550 py::class_<TFE_MonitoringStringGauge2> TFE_MonitoringStringGauge2_class(
551 m, "TFE_MonitoringStringGauge2");
552 py::class_<TFE_MonitoringStringGauge3> TFE_MonitoringStringGauge3_class(
553 m, "TFE_MonitoringStringGauge3");
554 py::class_<TFE_MonitoringStringGauge4> TFE_MonitoringStringGauge4_class(
555 m, "TFE_MonitoringStringGauge4");
556 py::class_<TFE_MonitoringIntGauge0> TFE_MonitoringIntGauge0_class(
557 m, "TFE_MonitoringIntGauge0");
558 py::class_<TFE_MonitoringIntGauge1> TFE_MonitoringIntGauge1_class(
559 m, "TFE_MonitoringIntGauge1");
560 py::class_<TFE_MonitoringIntGauge2> TFE_MonitoringIntGauge2_class(
561 m, "TFE_MonitoringIntGauge2");
562 py::class_<TFE_MonitoringBoolGauge0> TFE_MonitoringBoolGauge0_class(
563 m, "TFE_MonitoringBoolGauge0");
564 py::class_<TFE_MonitoringBoolGauge1> TFE_MonitoringBoolGauge1_class(
565 m, "TFE_MonitoringBoolGauge1");
566 py::class_<TFE_MonitoringBoolGauge2> TFE_MonitoringBoolGauge2_class(
567 m, "TFE_MonitoringBoolGauge2");
568 py::class_<TFE_MonitoringCounterCell> TFE_MonitoringCounterCell_class(
569 m, "TFE_MonitoringCounterCell");
570 py::class_<TFE_MonitoringIntGaugeCell> TFE_MonitoringIntGaugeCell_class(
571 m, "TFE_MonitoringIntGaugeCell");
572 py::class_<TFE_MonitoringStringGaugeCell> TFE_MonitoringStringGaugeCell_class(
573 m, "TFE_MonitoringStringGaugeCell");
574 py::class_<TFE_MonitoringBoolGaugeCell> TFE_MonitoringBoolGaugeCell_class(
575 m, "TFE_MonitoringBoolGaugeCell");
576 py::class_<TFE_MonitoringSamplerCell> TFE_MonitoringSamplerCell_class(
577 m, "TFE_MonitoringSamplerCell");
578 py::class_<TFE_MonitoringBuckets> TFE_MonitoringBuckets_class(
579 m, "TFE_MonitoringBuckets");
580 py::class_<TFE_MonitoringSampler0> TFE_MonitoringSampler0_class(
581 m, "TFE_MonitoringSampler0");
582 py::class_<TFE_MonitoringSampler1> TFE_MonitoringSampler1_class(
583 m, "TFE_MonitoringSampler1");
584 py::class_<TFE_MonitoringSampler2> TFE_MonitoringSampler2_class(
585 m, "TFE_MonitoringSampler2");
586 py::class_<tensorflow::CancellationManager> TFE_CancellationManager_class(
587 m, "TFE_CancellationManager");
588
589 py::class_<TF_DeviceList> TF_DeviceList_class(m, "TF_DeviceList");
590 py::class_<TF_Function> TF_Function_class(m, "TF_Function");
591
592 m.def("TFE_Py_RegisterExceptionClass", [](const py::handle& e) {
593 return tensorflow::PyoOrThrow(TFE_Py_RegisterExceptionClass(e.ptr()));
594 });
595 m.def("TFE_Py_RegisterFallbackExceptionClass", [](const py::handle& e) {
596 return tensorflow::PyoOrThrow(
597 TFE_Py_RegisterFallbackExceptionClass(e.ptr()));
598 });
599
600 m.def("TFE_GetMemoryInfo", [](py::handle& ctx, const char* device_name) {
601 tensorflow::Device* matched_device =
602 tensorflow::GetMatchedDevice(ctx, device_name);
603
604 tensorflow::AllocatorAttributes attrs;
605 tensorflow::Allocator* allocator = matched_device->GetAllocator(attrs);
606
607 if (absl::optional<tensorflow::AllocatorStats> stats =
608 allocator->GetStats()) {
609 return std::map<std::string, int64_t>{{"current", stats->bytes_in_use},
610 {"peak", stats->peak_bytes_in_use}};
611 }
612
613 tensorflow::ThrowValueError(
614 absl::StrFormat("Allocator stats not available for device '%s'",
615 device_name)
616 .c_str());
617 });
618
619 m.def("TFE_ResetMemoryStats", [](py::handle& ctx, const char* device_name) {
620 tensorflow::Device* matched_device =
621 tensorflow::GetMatchedDevice(ctx, device_name);
622
623 tensorflow::AllocatorAttributes attrs;
624 tensorflow::Allocator* allocator = matched_device->GetAllocator(attrs);
625
626 if (!allocator->ClearStats()) {
627 tensorflow::ThrowValueError(
628 absl::StrFormat("Cannot reset memory stats for device '%s'",
629 device_name)
630 .c_str());
631 }
632 });
633
634 // XLA Eager Logic
635 m.def("TF_SetXlaEnableLazyCompilation", &TF_SetXlaEnableLazyCompilation);
636 m.def("TF_SetTfXlaCpuGlobalJit", &TF_SetTfXlaCpuGlobalJit);
637 m.def("TF_SetXlaAutoJitMode", &TF_SetXlaAutoJitMode);
638 m.def("TF_SetXlaConstantFoldingDisabled", &TF_SetXlaConstantFoldingDisabled);
639 m.def("TF_GetXlaConstantFoldingDisabled", &TF_GetXlaConstantFoldingDisabled);
640 m.def("TF_SetXlaMinClusterSize", &TF_SetXlaMinClusterSize);
641 m.def("TF_GetCompilerIr", &tensorflow::TFE_GetCompilerIr);
642
643 // MLIR Logic
644 m.def("TF_IsMlirBridgeEnabled", [] {
645 // Since python protobuf enums are integers, cast to an integer before
646 // returning the enum to python.
647 return static_cast<int32_t>(
648 tensorflow::GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge);
649 });
650 m.def("TF_EnableMlirBridge", [](bool enabled) {
651 tensorflow::GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge =
652 enabled
653 ? tensorflow::ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED
654 : tensorflow::ConfigProto::Experimental::
655 MLIR_BRIDGE_ROLLOUT_DISABLED;
656 });
657 m.def("TF_EnableXlaDevices", [] {
658 tensorflow::GetXlaDeviceFlags()->tf_xla_enable_xla_devices = true;
659 });
660
661 // // TFE_Context Logic
662 m.def(
663 "TFE_NewContext",
664 [](const TFE_ContextOptions* opts) {
665 tensorflow::Safe_TF_StatusPtr status =
666 tensorflow::make_safe(TF_NewStatus());
667 TFE_Context* context = TFE_NewContext(opts, status.get());
668 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
669 return tensorflow::PyoOrThrow(tensorflow::OutputTFE_Context(context));
670 },
671 py::return_value_policy::reference);
672 m.def("TFE_DeleteContext", [](py::handle& o) {
673 TFE_DeleteContext(tensorflow::InputTFE_Context(o));
674 });
675 m.def(
676 "TFE_ContextListDevices",
677 [](py::handle& o) {
678 tensorflow::Safe_TF_StatusPtr status =
679 tensorflow::make_safe(TF_NewStatus());
680 auto output = TFE_ContextListDevices(tensorflow::InputTFE_Context(o),
681 status.get());
682 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
683 return output;
684 },
685 py::return_value_policy::reference);
686 m.def(
687 "TFE_SetLogicalCpuDevices",
688 [](py::handle& ctx, int num_cpus, const char* prefix) {
689 tensorflow::Safe_TF_StatusPtr status =
690 tensorflow::make_safe(TF_NewStatus());
691 TFE_SetLogicalCpuDevices(tensorflow::InputTFE_Context(ctx), num_cpus,
692 prefix, status.get());
693 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
694 },
695 py::return_value_policy::reference);
696 m.def("TFE_HostAddressSpace", [](py::handle& o, TF_Buffer& buf) {
697 TFE_HostAddressSpace(tensorflow::InputTFE_Context(o), &buf);
698 });
699 m.def("TFE_ContextAddFunction", [](py::handle& ctx, TF_Function* func) {
700 tensorflow::Safe_TF_StatusPtr status =
701 tensorflow::make_safe(TF_NewStatus());
702 TFE_ContextAddFunction(tensorflow::InputTFE_Context(ctx), func,
703 status.get());
704 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
705 });
706 m.def("TFE_ContextAddFunctionDef",
707 [](py::handle& ctx, const char* serialized_function_def, size_t size) {
708 tensorflow::Safe_TF_StatusPtr status =
709 tensorflow::make_safe(TF_NewStatus());
710 TFE_ContextAddFunctionDef(tensorflow::InputTFE_Context(ctx),
711 serialized_function_def, size,
712 status.get());
713 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
714 });
715 m.def("TFE_ContextGetFunctionDef",
716 [](py::handle& ctx, const char* function_name, TF_Buffer& buf) {
717 tensorflow::Safe_TF_StatusPtr status =
718 tensorflow::make_safe(TF_NewStatus());
719 TFE_ContextGetFunctionDef(tensorflow::InputTFE_Context(ctx),
720 function_name, &buf, status.get());
721 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
722 });
723 m.def("TFE_ContextRemoveFunction", [](py::handle& ctx, const char* name) {
724 tensorflow::Safe_TF_StatusPtr status =
725 tensorflow::make_safe(TF_NewStatus());
726 TFE_ContextRemoveFunction(tensorflow::InputTFE_Context(ctx), name,
727 status.get());
728 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
729 });
730 m.def("TFE_ContextHasFunction", [](py::handle& ctx, const char* name) {
731 tensorflow::Safe_TF_StatusPtr status =
732 tensorflow::make_safe(TF_NewStatus());
733 auto output =
734 TFE_ContextHasFunction(tensorflow::InputTFE_Context(ctx), name);
735 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
736 return output;
737 });
738 m.def("TFE_ContextListFunctionNames", [](py::handle& ctx) {
739 return tensorflow::unwrap(tensorflow::InputTFE_Context(ctx))
740 ->ListFunctionNames();
741 });
742 m.def("TFE_ContextEnableRunMetadata", [](py::handle& ctx) {
743 TFE_ContextEnableRunMetadata(tensorflow::InputTFE_Context(ctx));
744 });
745 m.def("TFE_ContextDisableRunMetadata", [](py::handle& ctx) {
746 TFE_ContextEnableRunMetadata(tensorflow::InputTFE_Context(ctx));
747 });
748 m.def("TFE_ContextEnableGraphCollection", [](py::handle& ctx) {
749 TFE_ContextEnableGraphCollection(tensorflow::InputTFE_Context(ctx));
750 });
751 m.def("TFE_ContextDisableGraphCollection", [](py::handle& ctx) {
752 TFE_ContextDisableGraphCollection(tensorflow::InputTFE_Context(ctx));
753 });
754 m.def("TFE_ContextExportRunMetadata", [](py::handle& ctx, TF_Buffer& buf) {
755 tensorflow::Safe_TF_StatusPtr status =
756 tensorflow::make_safe(TF_NewStatus());
757 TFE_ContextExportRunMetadata(tensorflow::InputTFE_Context(ctx), &buf,
758 status.get());
759 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
760 });
761 m.def("TFE_ContextClearCaches", [](py::handle& o) {
762 TFE_ContextClearCaches(tensorflow::InputTFE_Context(o));
763 });
764 m.def("TFE_GetContextId", [](py::handle& ctx) {
765 return TFE_GetContextId(tensorflow::InputTFE_Context(ctx));
766 });
767 m.def("TFE_ContextGetDevicePlacementPolicy", [](py::handle& ctx) {
768 return TFE_ContextGetDevicePlacementPolicy(
769 tensorflow::InputTFE_Context(ctx));
770 });
771 m.def("TFE_ContextSetThreadLocalDevicePlacementPolicy",
772 [](py::handle& ctx, TFE_ContextDevicePlacementPolicy policy) {
773 TFE_ContextSetThreadLocalDevicePlacementPolicy(
774 tensorflow::InputTFE_Context(ctx), policy);
775 });
776 m.def("TFE_ContextSetServerDef", [](py::handle& ctx, int keep_alive_secs,
777 py::bytes proto) {
778 tensorflow::Safe_TF_StatusPtr status =
779 tensorflow::make_safe(TF_NewStatus());
780 tensorflow::Safe_TF_BufferPtr buf =
781 tensorflow::make_safe(tensorflow::ProtoStringToTFBuffer(proto.ptr()));
782 TFE_ContextSetServerDef(tensorflow::InputTFE_Context(ctx), keep_alive_secs,
783 buf.get()->data, buf.get()->length, status.get());
784 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
785 });
786 m.def("TFE_ContextUpdateServerDef", [](py::handle& ctx, int keep_alive_secs,
787 py::bytes proto) {
788 tensorflow::Safe_TF_StatusPtr status =
789 tensorflow::make_safe(TF_NewStatus());
790 tensorflow::Safe_TF_BufferPtr buf =
791 tensorflow::make_safe(tensorflow::ProtoStringToTFBuffer(proto.ptr()));
792 Py_BEGIN_ALLOW_THREADS{ PyThreadState *_save; _save = PyEval_SaveThread();;
793 TFE_ContextUpdateServerDef(tensorflow::InputTFE_Context(ctx),
794 keep_alive_secs, buf.get()->data,
795 buf.get()->length, status.get());
796 Py_END_ALLOW_THREADSPyEval_RestoreThread(_save); };
797 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
798 });
799 m.def("TFE_ContextCheckAlive", [](py::handle& ctx, const char* worker_name) {
800 tensorflow::Safe_TF_StatusPtr status =
801 tensorflow::make_safe(TF_NewStatus());
802 bool output = TFE_ContextCheckAlive(tensorflow::InputTFE_Context(ctx),
803 worker_name, status.get());
804 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
805 return output;
806 });
807 m.def("TFE_ContextSyncExecutors", [](py::handle& ctx) {
808 tensorflow::Safe_TF_StatusPtr status =
809 tensorflow::make_safe(TF_NewStatus());
810 // NOTE: release Python GIL for pending PyFunc ops to be executed properly.
811 Py_BEGIN_ALLOW_THREADS{ PyThreadState *_save; _save = PyEval_SaveThread();;
812 TFE_ContextAsyncWait(tensorflow::InputTFE_Context(ctx), status.get());
813 Py_END_ALLOW_THREADSPyEval_RestoreThread(_save); };
814 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
815 });
816 m.def("TFE_ContextClearExecutors", [](py::handle& ctx) {
817 tensorflow::Safe_TF_StatusPtr status =
818 tensorflow::make_safe(TF_NewStatus());
819 // NOTE: release Python GIL for pending PyFunc ops to be executed properly.
820 Py_BEGIN_ALLOW_THREADS{ PyThreadState *_save; _save = PyEval_SaveThread();;
821 TFE_ContextAsyncWait(tensorflow::InputTFE_Context(ctx), status.get());
822 Py_END_ALLOW_THREADSPyEval_RestoreThread(_save); };
823 // NOTE: different from TFE_ContextSyncExecutors that raises potential
824 // errors, deliberately ignore executor statuses in cleanup.
825 });
826 m.def(
827 "TFE_InsertConfigKeyValue",
828 [](py::handle& ctx, const char* config_key, const char* config_value) {
829 tensorflow::Safe_TF_StatusPtr status =
830 tensorflow::make_safe(TF_NewStatus());
831 Py_BEGIN_ALLOW_THREADS{ PyThreadState *_save; _save = PyEval_SaveThread();;
832 TFE_InsertConfigKeyValue(tensorflow::InputTFE_Context(ctx), config_key,
833 config_value, status.get());
834 Py_END_ALLOW_THREADSPyEval_RestoreThread(_save); };
835 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
836 },
837 py::return_value_policy::reference);
838 m.def(
839 "TFE_GetConfigKeyValue",
840 [](py::handle& ctx, const char* config_key, TF_Buffer& config_value) {
841 tensorflow::Safe_TF_StatusPtr status =
842 tensorflow::make_safe(TF_NewStatus());
843 Py_BEGIN_ALLOW_THREADS{ PyThreadState *_save; _save = PyEval_SaveThread();;
844 TFE_GetConfigKeyValue(tensorflow::InputTFE_Context(ctx), config_key,
845 &config_value, status.get());
846 Py_END_ALLOW_THREADSPyEval_RestoreThread(_save); };
847 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
848 },
849 py::return_value_policy::reference);
850 m.def(
851 "TFE_DeleteConfigKeyValue",
852 [](py::handle& ctx, const char* config_key) {
853 tensorflow::Safe_TF_StatusPtr status =
854 tensorflow::make_safe(TF_NewStatus());
855 Py_BEGIN_ALLOW_THREADS{ PyThreadState *_save; _save = PyEval_SaveThread();;
856 TFE_DeleteConfigKeyValue(tensorflow::InputTFE_Context(ctx), config_key,
857 status.get());
858 Py_END_ALLOW_THREADSPyEval_RestoreThread(_save); };
859 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
860 },
861 py::return_value_policy::reference);
862 m.def(
863 "TFE_ReportErrorToCluster",
864 [](py::handle& ctx, int error_code, const char* error_message) {
865 tensorflow::Safe_TF_StatusPtr status =
866 tensorflow::make_safe(TF_NewStatus());
867 TFE_ReportErrorToCluster(tensorflow::InputTFE_Context(ctx), error_code,
868 error_message, status.get());
869 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
870 },
871 py::return_value_policy::reference);
872 m.def("TFE_ContextSetSoftDevicePlacement", [](py::handle& ctx, bool enable) {
873 tensorflow::Safe_TF_StatusPtr status =
874 tensorflow::make_safe(TF_NewStatus());
875 TFE_ContextSetSoftDevicePlacement(tensorflow::InputTFE_Context(ctx), enable,
876 status.get());
877 });
878 m.def("TFE_ContextSetLogDevicePlacement", [](py::handle& ctx, bool enable) {
879 tensorflow::Safe_TF_StatusPtr status =
880 tensorflow::make_safe(TF_NewStatus());
881 TFE_ContextSetSoftDevicePlacement(tensorflow::InputTFE_Context(ctx), enable,
882 status.get());
883 });
884
885 // TFE_Executor logic
886 m.def(
887 "TFE_NewExecutor",
888 [](const bool is_async) {
889 TFE_Executor* exc = TFE_NewExecutor(is_async);
890 return exc;
891 },
892 py::return_value_policy::reference);
893 m.def("TFE_DeleteExecutor", &TFE_DeleteExecutor);
894 m.def("TFE_ExecutorIsAsync", &TFE_ExecutorIsAsync);
895 m.def("TFE_ExecutorWaitForAllPendingNodes", [](TFE_Executor& exc) {
896 tensorflow::Safe_TF_StatusPtr status =
897 tensorflow::make_safe(TF_NewStatus());
898 // NOTE: release Python GIL for pending PyFunc ops to be executed properly.
899 Py_BEGIN_ALLOW_THREADS{ PyThreadState *_save; _save = PyEval_SaveThread();;
900 TFE_ExecutorWaitForAllPendingNodes(&exc, status.get());
901 Py_END_ALLOW_THREADSPyEval_RestoreThread(_save); };
902 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
903 });
904 m.def("TFE_ExecutorClearError", &TFE_ExecutorClearError);
905 m.def("TFE_ContextSetExecutorForThread", [](py::handle& ctx,
906 TFE_Executor& exc) {
907 TFE_ContextSetExecutorForThread(tensorflow::InputTFE_Context(ctx), &exc);
908 });
909 m.def(
910 "TFE_ContextGetExecutorForThread",
911 [](py::handle& o) {
912 return TFE_ContextGetExecutorForThread(tensorflow::InputTFE_Context(o));
913 },
914 py::return_value_policy::reference);
915
916 m.def("TFE_OpNameGetAttrType",
917 [](py::handle& ctx, const char* op_or_function_name,
918 const char* attr_name) {
919 int temp = 0;
920 unsigned char* is_list = reinterpret_cast<unsigned char*>(&temp);
921 tensorflow::Safe_TF_StatusPtr status =
922 tensorflow::make_safe(TF_NewStatus());
923 auto output = TFE_OpNameGetAttrType(tensorflow::InputTFE_Context(ctx),
924 op_or_function_name, attr_name,
925 is_list, status.get());
926 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
927#if PY_MAJOR_VERSION3 < 3
928 PyObject* output_pyo = PyInt_FromLong(output);
929#else
930 PyObject* output_pyo = PyLong_FromLong(output);
931#endif
932 if (*is_list == 1) {
933 PyObject* list = PyList_New(1);
934 PyList_SetItem(list, 0, output_pyo);
935 return tensorflow::PyoOrThrow(list);
936 }
937 return tensorflow::PyoOrThrow(output_pyo);
938 });
939 m.def("TFE_Py_InitEagerTensor", [](const py::handle& o) {
940 return tensorflow::PyoOrThrow(TFE_Py_InitEagerTensor(o.ptr()));
941 });
942 m.def("TFE_Py_PackEagerTensors",
943 [](const py::handle& context, const py::handle& handles) {
944 return tensorflow::TFE_Py_PackEagerTensors_wrapper(context, handles);
945 });
946 m.def("TFE_Py_SetEagerTensorProfiler", &TFE_Py_SetEagerTensorProfiler);
947 m.def("TFE_Py_RegisterJVPFunction", [](const py::handle& o) {
948 return tensorflow::PyoOrThrow(TFE_Py_RegisterJVPFunction(o.ptr()));
949 });
950 m.def("TFE_Py_RegisterGradientFunction", [](const py::handle& o) {
951 return tensorflow::PyoOrThrow(TFE_Py_RegisterGradientFunction(o.ptr()));
952 });
953 m.def("TFE_Py_Execute",
954 [](const py::handle& context, const char* device_name,
955 const char* op_name, const py::handle& inputs,
956 const py::handle& attrs, const py::handle& num_outputs) {
957 return tensorflow::TFE_Py_ExecuteCancelable_wrapper(
958 context, device_name, op_name, inputs, attrs.ptr(), nullptr,
959 num_outputs);
960 });
961 m.def(
962 "TFE_Py_ExecuteCancelable",
963 [](const py::handle& context, const char* device_name,
964 const char* op_name, const py::handle& inputs, const py::handle& attrs,
965 tensorflow::CancellationManager& cancellation_manager,
966 const py::handle& num_outputs) {
967 return tensorflow::TFE_Py_ExecuteCancelable_wrapper(
968 context, device_name, op_name, inputs, attrs.ptr(),
969 &cancellation_manager, num_outputs);
970 });
971 m.def("TFE_Py_FastPathExecute", [](const py::args args) {
972 // TFE_Py_FastPathExecute requires error checking prior to returning.
973 return tensorflow::PyoOrThrow(TFE_Py_FastPathExecute_C(args.ptr()));
974 });
975 m.def("TFE_Py_RecordGradient",
976 [](const py::handle& op_name, const py::handle& inputs,
977 const py::handle& attrs, const py::handle& results,
978 const py::handle& forward_pass_name_scope) {
979 return tensorflow::PyoOrThrow(TFE_Py_RecordGradient(
980 op_name.ptr(), inputs.ptr(), attrs.ptr(), results.ptr(),
981 forward_pass_name_scope.ptr()));
982 });
983 m.def("TFE_Py_UID", []() { return tensorflow::PyoOrThrow(TFE_Py_UID()); });
984
985 // TFE_Py_Tape Logic
986 m.def("TFE_Py_TapeSetNew", [](const py::handle& persistent,
987 const py::handle& watch_accessed_variables) {
988 return tensorflow::PyoOrThrow(
989 TFE_Py_TapeSetNew(persistent.ptr(), watch_accessed_variables.ptr()));
990 });
991 m.def("TFE_Py_TapeSetAdd",
992 [](const py::handle& tape) { TFE_Py_TapeSetAdd(tape.ptr()); });
993 m.def("TFE_Py_TapeSetRemove",
994 [](const py::handle& tape) { TFE_Py_TapeSetRemove(tape.ptr()); });
995 m.def("TFE_Py_TapeSetStopOnThread", &TFE_Py_TapeSetStopOnThread);
996 m.def("TFE_Py_TapeSetRestartOnThread", &TFE_Py_TapeSetRestartOnThread);
997 m.def("TFE_Py_TapeSetIsStopped",
998 []() { return tensorflow::PyoOrThrow(TFE_Py_TapeSetIsStopped()); });
999 m.def("TFE_Py_TapeSetIsEmpty",
1000 []() { return tensorflow::PyoOrThrow(TFE_Py_TapeSetIsEmpty()); });
1001 m.def("TFE_Py_TapeSetShouldRecordBackprop", [](const py::handle& tensors) {
1002 return tensorflow::PyoOrThrow(
1003 TFE_Py_TapeSetShouldRecordBackprop(tensors.ptr()));
1004 });
1005 m.def("TFE_Py_TapeSetPossibleGradientTypes", [](const py::handle& tensors) {
1006 return tensorflow::PyoOrThrow(
1007 TFE_Py_TapeSetPossibleGradientTypes(tensors.ptr()));
1008 });
1009 m.def("TFE_Py_TapeSetDeleteTrace", &TFE_Py_TapeSetDeleteTrace);
1010 m.def("TFE_Py_TapeSetRecordOperation",
1011 [](const py::handle& op_type, const py::handle& output_tensors,
1012 const py::handle& input_tensors, const py::handle& backward_function,
1013 const py::handle& forward_function) {
1014 return tensorflow::PyoOrThrow(TFE_Py_TapeSetRecordOperation(
1015 op_type.ptr(), output_tensors.ptr(), input_tensors.ptr(),
1016 backward_function.ptr(), forward_function.ptr()));
1017 });
1018 m.def(
1019 "TFE_Py_TapeSetRecordOperationBackprop",
1020 [](const py::handle& op_type, const py::handle& output_tensors,
1021 const py::handle& input_tensors, const py::handle& backward_function) {
1022 return tensorflow::PyoOrThrow(TFE_Py_TapeSetRecordOperationBackprop(
1023 op_type.ptr(), output_tensors.ptr(), input_tensors.ptr(),
1024 backward_function.ptr()));
1025 });
1026 m.def(
1027 "TFE_Py_TapeSetRecordOperationForwardprop",
1028 [](const py::handle& op_type, const py::handle& output_tensors,
1029 const py::handle& input_tensors, const py::handle& backward_function,
1030 const py::handle& forwardprop_output_indices) {
1031 return tensorflow::PyoOrThrow(TFE_Py_TapeSetRecordOperationForwardprop(
1032 op_type.ptr(), output_tensors.ptr(), input_tensors.ptr(),
1033 backward_function.ptr(), forwardprop_output_indices.ptr()));
1034 });
1035 m.def("TFE_Py_TapeGradient",
1036 [](const py::handle& tape, const py::handle& target,
1037 const py::handle& sources, const py::handle& output_gradients,
1038 const py::handle& sources_raw,
1039 const py::handle& unconnected_gradients) {
1040 tensorflow::Safe_TF_StatusPtr status =
1041 tensorflow::make_safe(TF_NewStatus());
1042 PyObject* output = TFE_Py_TapeGradient(
1043 tape.ptr(), target.ptr(), sources.ptr(), output_gradients.ptr(),
1044 sources_raw.ptr(), unconnected_gradients.ptr(), status.get());
1045 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1046 return tensorflow::PyoOrThrow(output);
1047 });
1048
1049 m.def("TFE_Py_TapeVariableAccessed", [](const py::handle& variable) {
1050 TFE_Py_TapeVariableAccessed(variable.ptr());
1051 });
1052 m.def("TFE_Py_TapeWatch",
1053 [](const py::handle& tape, const py::handle& tensor) {
1054 TFE_Py_TapeWatch(tape.ptr(), tensor.ptr());
1055 });
1056 m.def("TFE_Py_TapeWatchVariable",
1057 [](const py::handle& tape, const py::handle& variable) {
1058 TFE_Py_TapeWatchVariable(tape.ptr(), variable.ptr());
1059 });
1060 m.def("TFE_Py_TapeWatchedVariables", [](const py::handle& tape) {
1061 return tensorflow::PyoOrThrow(TFE_Py_TapeWatchedVariables(tape.ptr()));
1062 });
1063
1064 // TFE_Py_VariableWatcher logic.
1065 m.def("TFE_Py_VariableWatcherNew",
1066 []() { return tensorflow::PyoOrThrow(TFE_Py_VariableWatcherNew()); });
1067 m.def("TFE_Py_VariableWatcherRemove", [](const py::handle& variable_watcher) {
1068 TFE_Py_VariableWatcherRemove(variable_watcher.ptr());
1069 });
1070 m.def("TFE_Py_VariableWatcherVariableAccessed",
1071 [](const py::handle& variable) {
1072 TFE_Py_VariableWatcherVariableAccessed(variable.ptr());
1073 });
1074 m.def("TFE_Py_VariableWatcherWatchedVariables",
1075 [](const py::handle& variable_watcher) {
1076 return tensorflow::PyoOrThrow(
1077 TFE_Py_VariableWatcherWatchedVariables(variable_watcher.ptr()));
1078 });
1079
1080 // TFE_Py_ForwardAccumulator logic.
1081 m.def("TFE_Py_ForwardAccumulatorNew", [](bool use_batch) {
1082 return tensorflow::PyoOrThrow(TFE_Py_ForwardAccumulatorNew(use_batch));
1083 });
1084
1085 m.def("TFE_Py_ForwardAccumulatorSetAdd", [](const py::handle& accumulator) {
1086 return tensorflow::PyoOrThrow(
1087 TFE_Py_ForwardAccumulatorSetAdd(accumulator.ptr()));
1088 });
1089 m.def("TFE_Py_ForwardAccumulatorSetRemove",
1090 [](const py::handle& accumulator) {
1091 TFE_Py_ForwardAccumulatorSetRemove(accumulator.ptr());
1092 });
1093
1094 m.def("TFE_Py_ForwardAccumulatorWatch",
1095 [](const py::handle& accumulator, const py::handle& tensor,
1096 const py::handle& tangent) {
1097 TFE_Py_ForwardAccumulatorWatch(accumulator.ptr(), tensor.ptr(),
1098 tangent.ptr());
1099 });
1100 m.def("TFE_Py_ForwardAccumulatorJVP",
1101 [](const py::handle& accumulator, const py::handle& tensor) {
1102 return tensorflow::PyoOrThrow(
1103 TFE_Py_ForwardAccumulatorJVP(accumulator.ptr(), tensor.ptr()));
1104 });
1105 m.def("TFE_Py_ForwardAccumulatorPushState", []() {
1106 return tensorflow::PyoOrThrow(TFE_Py_ForwardAccumulatorPushState());
1107 });
1108 m.def("TFE_Py_ForwardAccumulatorPopState", []() {
1109 return tensorflow::PyoOrThrow(TFE_Py_ForwardAccumulatorPopState());
1110 });
1111 m.def("TFE_Py_PackJVPs", [](const py::handle& tensors) {
1112 return tensorflow::PyoOrThrow(TFE_Py_PackJVPs(tensors.ptr()));
1113 });
1114
1115 // TFE_ContextOptions Logic
1116 m.def("TFE_NewContextOptions", &TFE_NewContextOptions,
1117 py::return_value_policy::reference);
1118 m.def("TFE_ContextOptionsSetConfig", [](TFE_ContextOptions* options,
1119 py::bytes proto) {
1120 tensorflow::Safe_TF_StatusPtr status =
1121 tensorflow::make_safe(TF_NewStatus());
1122 tensorflow::Safe_TF_BufferPtr buf =
1123 tensorflow::make_safe(tensorflow::ProtoStringToTFBuffer(proto.ptr()));
1124 TFE_ContextOptionsSetConfig(options, buf.get()->data, buf.get()->length,
1125 status.get());
1126 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1127 });
1128 m.def("TFE_ContextOptionsSetDevicePlacementPolicy",
1129 &TFE_ContextOptionsSetDevicePlacementPolicy);
1130 m.def("TFE_ContextOptionsSetTfrt", &TFE_ContextOptionsSetTfrt);
1131 m.def("TFE_ContextOptionsSetTfrtDistributedRuntime",
1132 &TFE_ContextOptionsSetTfrtDistributedRuntime);
1133 // Experimental feature, intentionally not exposed as a C API yet.
1134 m.def("TFE_ContextOptionsSetRunEagerOpAsFunction",
1135 [](TFE_ContextOptions* options, bool run_eager_op_as_function) {
1136 options->run_eager_op_as_function = run_eager_op_as_function;
1137 });
1138 m.def("TFE_ContextOptionsSetAsync", &TFE_ContextOptionsSetAsync);
1139 m.def("TFE_DeleteContextOptions", &TFE_DeleteContextOptions,
1140 py::return_value_policy::reference);
1141
1142 // TFE_Py_TensorShape Logic
1143 m.def("TFE_Py_TensorShapeSlice",
1144 [](const py::handle& tensors, int slice_dim) {
1145 return tensorflow::PyoOrThrow(
1146 TFE_Py_TensorShapeSlice(tensors.ptr(), slice_dim));
1147 });
1148 m.def("TFE_Py_TensorShapeOnDevice", [](const py::handle& tensors,
1149 int slice_dim) {
1150 return tensorflow::PyoOrThrow(TFE_Py_TensorShapeOnDevice(tensors.ptr()));
1151 });
1152 m.def("TFE_Py_EnableInteractivePythonLogging",
1153 &TFE_Py_EnableInteractivePythonLogging);
1154
1155 // Additional Context Logic
1156 m.def("TFE_Py_SetEagerContext", [](const py::handle& o) {
1157 return tensorflow::PyoOrThrow(TFE_Py_SetEagerContext(o.ptr()));
1158 });
1159 m.def("TFE_Py_RegisterVSpace", [](const py::handle& o) {
1160 return tensorflow::PyoOrThrow(TFE_Py_RegisterVSpace(o.ptr()));
1161 });
1162 m.def("TFE_Py_EncodeArg", [](const py::handle& o,
1163 bool include_tensor_ranks_only,
1164 bool encode_variables_by_resource_id) {
1165 return tensorflow::PyoOrThrow(TFE_Py_EncodeArg(
1166 o.ptr(), include_tensor_ranks_only, encode_variables_by_resource_id));
1167 });
1168 m.def("TFE_EnableCollectiveOps", [](const py::handle& ctx, py::bytes proto) {
1169 tensorflow::Safe_TF_StatusPtr status =
1170 tensorflow::make_safe(TF_NewStatus());
1171 tensorflow::Safe_TF_BufferPtr buf =
1172 tensorflow::make_safe(tensorflow::ProtoStringToTFBuffer(proto.ptr()));
1173 TFE_EnableCollectiveOps(tensorflow::InputTFE_Context(ctx), buf.get()->data,
1174 buf.get()->length, status.get());
1175 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1176 });
1177 m.def("TFE_AbortCollectiveOps", [](const py::handle& ctx, int code,
1178 const char* message) {
1179 tensorflow::Safe_TF_StatusPtr status =
1180 tensorflow::make_safe(TF_NewStatus());
1181 TF_SetStatus(status.get(), static_cast<TF_Code>(code), message);
1182 TFE_AbortCollectiveOps(tensorflow::InputTFE_Context(ctx), status.get());
1183 });
1184 m.def("TFE_CollectiveOpsCheckPeerHealth",
1185 [](const py::handle& ctx, const char* task, int64_t timeout_in_ms) {
1186 tensorflow::Safe_TF_StatusPtr status =
1187 tensorflow::make_safe(TF_NewStatus());
1188 TFE_CollectiveOpsCheckPeerHealth(tensorflow::InputTFE_Context(ctx),
1189 task, timeout_in_ms, status.get());
1190 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1191 });
1192 m.def("TF_ListPhysicalDevices", &tensorflow::TF_ListPhysicalDevices);
1193 m.def("TF_ListPluggablePhysicalDevices",
1194 &tensorflow::TF_ListPluggablePhysicalDevices);
1195 m.def("TF_GetDeviceDetails", &tensorflow::TF_GetDeviceDetails);
1196 m.def("TF_DeleteDeviceList", &TF_DeleteDeviceList,
1197 py::return_value_policy::reference);
1198 m.def("TF_DeviceListCount", &TF_DeviceListCount);
1199 m.def("TF_DeviceListName", [](const TF_DeviceList* list, int index) {
1200 tensorflow::Safe_TF_StatusPtr status =
1201 tensorflow::make_safe(TF_NewStatus());
1202 auto output = TF_DeviceListName(list, index, status.get());
1203 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1204 return output;
1205 });
1206 m.def("TF_DeviceListType", [](const TF_DeviceList* list, int index) {
1207 tensorflow::Safe_TF_StatusPtr status =
1208 tensorflow::make_safe(TF_NewStatus());
1209 auto output = TF_DeviceListType(list, index, status.get());
1210 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1211 return output;
1212 });
1213
1214 m.def("TF_PickUnusedPortOrDie", &TF_PickUnusedPortOrDie);
1215
1216 // TFE_MonitoringCounter Logic
1217 m.def("TFE_MonitoringCounterCellIncrementBy",
1218 &TFE_MonitoringCounterCellIncrementBy);
1219 m.def("TFE_MonitoringCounterCellValue", &TFE_MonitoringCounterCellValue);
1220 m.def(
1221 "TFE_MonitoringNewCounter0",
1222 [](const char* name, const char* description) {
1223 tensorflow::Safe_TF_StatusPtr status =
1224 tensorflow::make_safe(TF_NewStatus());
1225 auto output =
1226 TFE_MonitoringNewCounter0(name, status.get(), description);
1227 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1228 return output;
1229 },
1230 py::return_value_policy::reference);
1231 m.def("TFE_MonitoringDeleteCounter0", &TFE_MonitoringDeleteCounter0,
1232 py::return_value_policy::reference);
1233 m.def("TFE_MonitoringGetCellCounter0", &TFE_MonitoringGetCellCounter0,
1234 py::return_value_policy::reference);
1235 m.def(
1236 "TFE_MonitoringNewCounter1",
1237 [](const char* name, const char* description, const char* label1) {
1238 tensorflow::Safe_TF_StatusPtr status =
1239 tensorflow::make_safe(TF_NewStatus());
1240 auto output =
1241 TFE_MonitoringNewCounter1(name, status.get(), description, label1);
1242 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1243 return output;
1244 },
1245 py::return_value_policy::reference);
1246 m.def("TFE_MonitoringDeleteCounter1", &TFE_MonitoringDeleteCounter1,
1247 py::return_value_policy::reference);
1248 m.def("TFE_MonitoringGetCellCounter1", &TFE_MonitoringGetCellCounter1,
1249 py::return_value_policy::reference);
1250 m.def(
1251 "TFE_MonitoringNewCounter2",
1252 [](const char* name, const char* description, const char* label1,
1253 const char* label2) {
1254 tensorflow::Safe_TF_StatusPtr status =
1255 tensorflow::make_safe(TF_NewStatus());
1256 auto output = TFE_MonitoringNewCounter2(name, status.get(), description,
1257 label1, label2);
1258 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1259 return output;
1260 },
1261 py::return_value_policy::reference);
1262 m.def("TFE_MonitoringDeleteCounter2", &TFE_MonitoringDeleteCounter2,
1263 py::return_value_policy::reference);
1264 m.def("TFE_MonitoringGetCellCounter2", &TFE_MonitoringGetCellCounter2,
1265 py::return_value_policy::reference);
1266
1267 // TFE_MonitoringIntGauge Logic
1268 m.def("TFE_MonitoringIntGaugeCellSet", &TFE_MonitoringIntGaugeCellSet);
1269 m.def("TFE_MonitoringIntGaugeCellValue", &TFE_MonitoringIntGaugeCellValue);
1270 m.def(
1271 "TFE_MonitoringNewIntGauge0",
1272 [](const char* name, const char* description) {
1273 tensorflow::Safe_TF_StatusPtr status =
1274 tensorflow::make_safe(TF_NewStatus());
1275 auto output =
1276 TFE_MonitoringNewIntGauge0(name, status.get(), description);
1277 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1278 return output;
1279 },
1280 py::return_value_policy::reference);
1281 m.def("TFE_MonitoringDeleteIntGauge0", &TFE_MonitoringDeleteIntGauge0,
1282 py::return_value_policy::reference);
1283 m.def("TFE_MonitoringGetCellIntGauge0", &TFE_MonitoringGetCellIntGauge0,
1284 py::return_value_policy::reference);
1285 m.def(
1286 "TFE_MonitoringNewIntGauge1",
1287 [](const char* name, const char* description, const char* label1) {
1288 tensorflow::Safe_TF_StatusPtr status =
1289 tensorflow::make_safe(TF_NewStatus());
1290 auto output =
1291 TFE_MonitoringNewIntGauge1(name, status.get(), description, label1);
1292 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1293 return output;
1294 },
1295 py::return_value_policy::reference);
1296 m.def("TFE_MonitoringDeleteIntGauge1", &TFE_MonitoringDeleteIntGauge1,
1297 py::return_value_policy::reference);
1298 m.def("TFE_MonitoringGetCellIntGauge1", &TFE_MonitoringGetCellIntGauge1,
1299 py::return_value_policy::reference);
1300 m.def(
1301 "TFE_MonitoringNewIntGauge2",
1302 [](const char* name, const char* description, const char* label1,
1303 const char* label2) {
1304 tensorflow::Safe_TF_StatusPtr status =
1305 tensorflow::make_safe(TF_NewStatus());
1306 auto output = TFE_MonitoringNewIntGauge2(name, status.get(),
1307 description, label1, label2);
1308 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1309 return output;
1310 },
1311 py::return_value_policy::reference);
1312 m.def("TFE_MonitoringDeleteIntGauge2", &TFE_MonitoringDeleteIntGauge2,
1313 py::return_value_policy::reference);
1314 m.def("TFE_MonitoringGetCellIntGauge2", &TFE_MonitoringGetCellIntGauge2,
1315 py::return_value_policy::reference);
1316 m.def("TFE_MonitoringStringGaugeCellSet", &TFE_MonitoringStringGaugeCellSet);
1317 m.def("TFE_MonitoringStringGaugeCellValue",
1318 &TFE_MonitoringStringGaugeCellValue);
1319 m.def(
1320 "TFE_MonitoringNewStringGauge0",
1321 [](const char* name, const char* description) {
1322 tensorflow::Safe_TF_StatusPtr status =
1323 tensorflow::make_safe(TF_NewStatus());
1324 auto output =
1325 TFE_MonitoringNewStringGauge0(name, status.get(), description);
1326 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1327 return output;
1328 },
1329 py::return_value_policy::reference);
1330
1331 // TFE_MonitoringStringGauge Logic
1332 m.def("TFE_MonitoringDeleteStringGauge0", &TFE_MonitoringDeleteStringGauge0);
1333 m.def("TFE_MonitoringGetCellStringGauge0", &TFE_MonitoringGetCellStringGauge0,
1334 py::return_value_policy::reference);
1335 m.def(
1336 "TFE_MonitoringNewStringGauge1",
1337 [](const char* name, const char* description, const char* label1) {
1338 tensorflow::Safe_TF_StatusPtr status =
1339 tensorflow::make_safe(TF_NewStatus());
1340 auto output = TFE_MonitoringNewStringGauge1(name, status.get(),
1341 description, label1);
1342 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1343 return output;
1344 },
1345 py::return_value_policy::reference);
1346 m.def("TFE_MonitoringDeleteStringGauge1", &TFE_MonitoringDeleteStringGauge1);
1347 m.def("TFE_MonitoringGetCellStringGauge1", &TFE_MonitoringGetCellStringGauge1,
1348 py::return_value_policy::reference);
1349 m.def(
1350 "TFE_MonitoringNewStringGauge2",
1351 [](const char* name, const char* description, const char* label1,
1352 const char* label2) {
1353 tensorflow::Safe_TF_StatusPtr status =
1354 tensorflow::make_safe(TF_NewStatus());
1355 auto output = TFE_MonitoringNewStringGauge2(
1356 name, status.get(), description, label1, label2);
1357 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1358 return output;
1359 },
1360 py::return_value_policy::reference);
1361 m.def("TFE_MonitoringDeleteStringGauge2", &TFE_MonitoringDeleteStringGauge2);
1362 m.def("TFE_MonitoringGetCellStringGauge2", &TFE_MonitoringGetCellStringGauge2,
1363 py::return_value_policy::reference);
1364
1365 m.def(
1366 "TFE_MonitoringNewStringGauge3",
1367 [](const char* name, const char* description, const char* label1,
1368 const char* label2, const char* label3) {
1369 tensorflow::Safe_TF_StatusPtr status =
1370 tensorflow::make_safe(TF_NewStatus());
1371 auto output = TFE_MonitoringNewStringGauge3(
1372 name, status.get(), description, label1, label2, label3);
1373 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1374 return output;
1375 },
1376 py::return_value_policy::reference);
1377 m.def("TFE_MonitoringDeleteStringGauge3", &TFE_MonitoringDeleteStringGauge3);
1378 m.def("TFE_MonitoringGetCellStringGauge3", &TFE_MonitoringGetCellStringGauge3,
1379 py::return_value_policy::reference);
1380
1381 m.def(
1382 "TFE_MonitoringNewStringGauge4",
1383 [](const char* name, const char* description, const char* label1,
1384 const char* label2, const char* label3, const char* label4) {
1385 tensorflow::Safe_TF_StatusPtr status =
1386 tensorflow::make_safe(TF_NewStatus());
1387 auto output = TFE_MonitoringNewStringGauge4(
1388 name, status.get(), description, label1, label2, label3, label4);
1389 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1390 return output;
1391 },
1392 py::return_value_policy::reference);
1393 m.def("TFE_MonitoringDeleteStringGauge4", &TFE_MonitoringDeleteStringGauge4);
1394 m.def("TFE_MonitoringGetCellStringGauge4", &TFE_MonitoringGetCellStringGauge4,
1395 py::return_value_policy::reference);
1396
1397 // TFE_MonitoringBoolGauge Logic
1398 m.def("TFE_MonitoringBoolGaugeCellSet", &TFE_MonitoringBoolGaugeCellSet);
1399 m.def("TFE_MonitoringBoolGaugeCellValue", &TFE_MonitoringBoolGaugeCellValue);
1400 m.def(
1401 "TFE_MonitoringNewBoolGauge0",
1402 [](const char* name, const char* description) {
1403 tensorflow::Safe_TF_StatusPtr status =
1404 tensorflow::make_safe(TF_NewStatus());
1405 auto output =
1406 TFE_MonitoringNewBoolGauge0(name, status.get(), description);
1407 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1408 return output;
1409 },
1410 py::return_value_policy::reference);
1411 m.def("TFE_MonitoringDeleteBoolGauge0", &TFE_MonitoringDeleteBoolGauge0,
1412 py::return_value_policy::reference);
1413 m.def("TFE_MonitoringGetCellBoolGauge0", &TFE_MonitoringGetCellBoolGauge0,
1414 py::return_value_policy::reference);
1415 m.def(
1416 "TFE_MonitoringNewBoolGauge1",
1417 [](const char* name, const char* description, const char* label1) {
1418 tensorflow::Safe_TF_StatusPtr status =
1419 tensorflow::make_safe(TF_NewStatus());
1420 auto output = TFE_MonitoringNewBoolGauge1(name, status.get(),
1421 description, label1);
1422 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1423 return output;
1424 },
1425 py::return_value_policy::reference);
1426 m.def("TFE_MonitoringDeleteBoolGauge1", &TFE_MonitoringDeleteBoolGauge1,
1427 py::return_value_policy::reference);
1428 m.def("TFE_MonitoringGetCellBoolGauge1", &TFE_MonitoringGetCellBoolGauge1,
1429 py::return_value_policy::reference);
1430 m.def(
1431 "TFE_MonitoringNewBoolGauge2",
1432 [](const char* name, const char* description, const char* label1,
1433 const char* label2) {
1434 tensorflow::Safe_TF_StatusPtr status =
1435 tensorflow::make_safe(TF_NewStatus());
1436 auto output = TFE_MonitoringNewBoolGauge2(name, status.get(),
1437 description, label1, label2);
1438 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1439 return output;
1440 },
1441 py::return_value_policy::reference);
1442 m.def("TFE_MonitoringDeleteBoolGauge2", &TFE_MonitoringDeleteBoolGauge2,
1443 py::return_value_policy::reference);
1444 m.def("TFE_MonitoringGetCellBoolGauge2", &TFE_MonitoringGetCellBoolGauge2,
1445 py::return_value_policy::reference);
1446
1447 // TFE_MonitoringSampler Logic
1448 m.def("TFE_MonitoringSamplerCellAdd", &TFE_MonitoringSamplerCellAdd);
1449 m.def("TFE_MonitoringSamplerCellValue", &TFE_MonitoringSamplerCellValue);
1450 m.def("TFE_MonitoringNewExponentialBuckets",
1451 &TFE_MonitoringNewExponentialBuckets,
1452 py::return_value_policy::reference);
1453 m.def("TFE_MonitoringDeleteBuckets", &TFE_MonitoringDeleteBuckets,
1454 py::return_value_policy::reference);
1455 m.def(
1456 "TFE_MonitoringNewSampler0",
1457 [](const char* name, TFE_MonitoringBuckets* buckets,
1458 const char* description) {
1459 tensorflow::Safe_TF_StatusPtr status =
1460 tensorflow::make_safe(TF_NewStatus());
1461 auto output =
1462 TFE_MonitoringNewSampler0(name, buckets, status.get(), description);
1463 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1464 return output;
1465 },
1466 py::return_value_policy::reference);
1467 m.def("TFE_MonitoringDeleteSampler0", &TFE_MonitoringDeleteSampler0,
1468 py::return_value_policy::reference);
1469 m.def("TFE_MonitoringGetCellSampler0", &TFE_MonitoringGetCellSampler0,
1470 py::return_value_policy::reference);
1471 m.def(
1472 "TFE_MonitoringNewSampler1",
1473 [](const char* name, TFE_MonitoringBuckets* buckets,
1474 const char* description, const char* label1) {
1475 tensorflow::Safe_TF_StatusPtr status =
1476 tensorflow::make_safe(TF_NewStatus());
1477 auto output = TFE_MonitoringNewSampler1(name, buckets, status.get(),
1478 description, label1);
1479 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1480 return output;
1481 },
1482 py::return_value_policy::reference);
1483 m.def("TFE_MonitoringDeleteSampler1", &TFE_MonitoringDeleteSampler1,
1484 py::return_value_policy::reference);
1485 m.def("TFE_MonitoringGetCellSampler1", &TFE_MonitoringGetCellSampler1,
1486 py::return_value_policy::reference);
1487 m.def(
1488 "TFE_MonitoringNewSampler2",
1489 [](const char* name, TFE_MonitoringBuckets* buckets,
1490 const char* description, const char* label1, const char* label2) {
1491 tensorflow::Safe_TF_StatusPtr status =
1492 tensorflow::make_safe(TF_NewStatus());
1493 auto output = TFE_MonitoringNewSampler2(name, buckets, status.get(),
1494 description, label1, label2);
1495 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1496 return output;
1497 },
1498 py::return_value_policy::reference);
1499 m.def("TFE_MonitoringDeleteSampler2", &TFE_MonitoringDeleteSampler2,
1500 py::return_value_policy::reference);
1501 m.def("TFE_MonitoringGetCellSampler2", &TFE_MonitoringGetCellSampler2,
1502 py::return_value_policy::reference);
1503
1504 // TFE_CancellationManager Logic
1505 m.def("TFE_NewCancellationManager",
1506 []() { return new tensorflow::CancellationManager(); });
1507 m.def("TFE_CancellationManagerIsCancelled",
1508 &tensorflow::CancellationManager::IsCancelled);
1509 m.def("TFE_CancellationManagerStartCancel",
1510 &tensorflow::CancellationManager::StartCancel);
1511
1512 m.def("TFE_ClearScalarCache", &tensorflow::TFE_ClearScalarCache);
1513
1514 // Util buffer helper functions
1515 m.def("TF_NewBufferFromString", &TF_NewBufferFromString,
1516 py::return_value_policy::reference);
1517
1518 // DLPack functions
1519 m.def("TFE_ToDlpackCapsule", [](py::handle& o) {
1520 PyObject* eager_tensor_pyobject_ptr = o.ptr();
1521 tensorflow::Safe_TF_StatusPtr status =
1522 tensorflow::make_safe(TF_NewStatus());
1523
1524 if (!EagerTensor_CheckExact(eager_tensor_pyobject_ptr)) {
1525 status->status = tensorflow::errors::InvalidArgument(
1526 "The argument to `to_dlpack` must be a TF tensor, not Python object");
1527 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1528 }
1529
1530 TFE_TensorHandle* thandle = EagerTensor_Handle(eager_tensor_pyobject_ptr);
1531 void* dlm_ptr = tensorflow::TFE_HandleToDLPack(thandle, status.get());
1532 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1533
1534 py::capsule capsule(
1535 dlm_ptr, tensorflow::kDlTensorCapsuleName, [](PyObject* capsule) {
1536 if (PyCapsule_IsValid(capsule, tensorflow::kDlTensorCapsuleName)) {
1537 void* dlm_rptr =
1538 PyCapsule_GetPointer(capsule, tensorflow::kDlTensorCapsuleName);
1539 if (dlm_rptr) {
1540 tensorflow::TFE_CallDLManagedTensorDeleter(dlm_rptr);
1541 PyCapsule_SetDestructor(capsule, nullptr);
1542 }
1543 }
1544 });
1545 return capsule;
1546 });
1547
1548 m.def("TFE_FromDlpackCapsule", [](const py::capsule& pycapsule,
1549 const py::handle& context) {
1550 tensorflow::Safe_TF_StatusPtr status =
1551 tensorflow::make_safe(TF_NewStatus());
1552 if (absl::string_view(pycapsule.name()) !=
1553 tensorflow::kDlTensorCapsuleName) {
1554 status->status = tensorflow::errors::InvalidArgument(
1555 "DLPack tensor must be a capsule with name \"dltensor\", got \"%s\". "
1556 "Note that a DLPack tensor may be consumed at most once.",
1557 absl::string_view(pycapsule.name()));
1558 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1559 }
1560
1561 TFE_TensorHandle* thandle = tensorflow::TFE_HandleFromDLPack(
1562 pycapsule, status.get(), tensorflow::InputTFE_Context(context));
1563
1564 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1565
1566 PyCapsule_SetName(pycapsule.ptr(), "used_dltensor");
1567 PyCapsule_SetDestructor(pycapsule.ptr(), nullptr);
1568
1569 PyObject* pyhandle = EagerTensorFromHandle(thandle);
1570 return tensorflow::PyoOrThrow(pyhandle);
1571 });
1572
1573 m.def("TFE_Py_RegisterCustomDevice", [](const py::handle& context,
1574 const py::capsule& device,
1575 const char* device_name,
1576 const py::capsule& device_info) {
1577 tensorflow::Safe_TF_StatusPtr status =
1578 tensorflow::make_safe(TF_NewStatus());
1579 if (absl::string_view(device.name()) != "TFE_CustomDevice") {
1580 status->status = tensorflow::errors::InvalidArgument(
1581 "Expected a capsule named 'TFE_CustomDevice' for the `device` "
1582 "argument, got ",
1583 absl::string_view(device.name()));
1584 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1585 }
1586 if (absl::string_view(device_info.name()) !=
1587 "TFE_CustomDevice_DeviceInfo") {
1588 status->status = tensorflow::errors::InvalidArgument(
1589 "Expected a capsule named 'TFE_CustomDevice_DeviceInfo' for "
1590 "the `device_info` argument, got ",
1591 absl::string_view(device_info.name()));
1592 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1593 }
1594 // TFE_RegisterCustomDevice takes ownership
1595 PyCapsule_SetDestructor(device_info.ptr(), nullptr);
1596 TFE_RegisterCustomDevice(
1597 tensorflow::InputTFE_Context(context),
1598 *reinterpret_cast<TFE_CustomDevice*>(
1599 PyCapsule_GetPointer(device.ptr(), "TFE_CustomDevice")),
1600 device_name,
1601 PyCapsule_GetPointer(device_info.ptr(), "TFE_CustomDevice_DeviceInfo"),
1602 status.get());
1603 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1604 });
1605
1606 py::class_<EagerContextThreadLocalDataWrapper>(m,
1607 "EagerContextThreadLocalData")
1608 .def(py::init<py::handle, py::handle, py::handle>(),
1609 py::arg("py_eager_context"), py::arg("is_eager"),
1610 py::arg("device_spec"))
1611 .def_property("is_eager",
1612 &EagerContextThreadLocalDataWrapper::get_is_eager,
1613 &EagerContextThreadLocalDataWrapper::set_is_eager)
1614 .def_property(
1615 "invoking_op_callbacks",
1616 &EagerContextThreadLocalDataWrapper::get_invoking_op_callbacks,
1617 &EagerContextThreadLocalDataWrapper::set_invoking_op_callbacks)
1618 .def_property("device_name",
1619 &EagerContextThreadLocalDataWrapper::get_device_name,
1620 &EagerContextThreadLocalDataWrapper::set_device_name)
1621 .def_property("scope_name",
1622 &EagerContextThreadLocalDataWrapper::get_scope_name,
1623 &EagerContextThreadLocalDataWrapper::set_scope_name)
1624 .def_property("device_spec",
1625 &EagerContextThreadLocalDataWrapper::get_device_spec,
1626 &EagerContextThreadLocalDataWrapper::set_device_spec)
1627 .def_property(
1628 "function_call_options",
1629 &EagerContextThreadLocalDataWrapper::get_function_call_options,
1630 &EagerContextThreadLocalDataWrapper::set_function_call_options)
1631 .def_property("executor",
1632 &EagerContextThreadLocalDataWrapper::get_executor,
1633 &EagerContextThreadLocalDataWrapper::set_executor)
1634 .def_property("op_callbacks",
1635 &EagerContextThreadLocalDataWrapper::get_op_callbacks,
1636 &EagerContextThreadLocalDataWrapper::set_op_callbacks);
1637
1638 // C API Enum
1639
1640 py::enum_<TFE_ContextDevicePlacementPolicy>(
1641 m, "TFE_ContextDevicePlacementPolicy")
1642 .value("TFE_DEVICE_PLACEMENT_EXPLICIT", TFE_DEVICE_PLACEMENT_EXPLICIT)
1643 .value("TFE_DEVICE_PLACEMENT_WARN", TFE_DEVICE_PLACEMENT_WARN)
1644 .value("TFE_DEVICE_PLACEMENT_SILENT", TFE_DEVICE_PLACEMENT_SILENT)
1645 .value("TFE_DEVICE_PLACEMENT_SILENT_FOR_INT32",
1646 TFE_DEVICE_PLACEMENT_SILENT_FOR_INT32)
1647 .export_values();
1648
1649 py::enum_<TF_AttrType>(m, "TF_AttrType")
1650 .value("TF_ATTR_STRING", TF_ATTR_STRING)
1651 .value("TF_ATTR_INT", TF_ATTR_INT)
1652 .value("TF_ATTR_FLOAT", TF_ATTR_FLOAT)
1653 .value("TF_ATTR_BOOL", TF_ATTR_BOOL)
1654 .value("TF_ATTR_TYPE", TF_ATTR_TYPE)
1655 .value("TF_ATTR_SHAPE", TF_ATTR_SHAPE)
1656 .value("TF_ATTR_TENSOR", TF_ATTR_TENSOR)
1657 .value("TF_ATTR_PLACEHOLDER", TF_ATTR_PLACEHOLDER)
1658 .value("TF_ATTR_FUNC", TF_ATTR_FUNC)
1659 .export_values();
1660};

/opt/pyrefcon/lib/pyrefcon/models/models/PyUnicode_InternFromString.model

#ifndef PyUnicode_InternFromString
struct _object;
typedef struct _object PyObject;
PyObject* clang_analyzer_PyObject_New_Reference();
/* Analyzer model: PyUnicode_InternFromString is modeled as returning a new
   reference (reference count 1), which the caller is responsible for
   releasing. The argument is not inspected by the model. */
PyObject* PyUnicode_InternFromString(const char *v) {
  return clang_analyzer_PyObject_New_Reference();
}
#else
/* If the API is a macro, a function model cannot replace it. */
#warning "API PyUnicode_InternFromString is defined as a macro."
#endif