Bug Summary

File: .cache/bazel/_bazel_alan/39be661231df2a680c9b74265384c13c/execroot/org_tensorflow/tensorflow/python/eager/pywrap_tfe_src.cc
Warning: line 994, column 35
PyObject ownership leak with reference count of 1

Annotated Source Code

Press '?' to see keyboard shortcuts

clang -cc1 -cc1 -triple x86_64-unknown-linux-gnu -analyze -disable-free -disable-llvm-verifier -discard-value-names -main-file-name pywrap_tfe_src.cc -analyzer-store=region -analyzer-opt-analyze-nested-blocks -analyzer-checker=core -analyzer-checker=apiModeling -analyzer-checker=unix -analyzer-checker=deadcode -analyzer-checker=cplusplus -analyzer-checker=security.insecureAPI.UncheckedReturn -analyzer-checker=security.insecureAPI.getpw -analyzer-checker=security.insecureAPI.gets -analyzer-checker=security.insecureAPI.mktemp -analyzer-checker=security.insecureAPI.mkstemp -analyzer-checker=security.insecureAPI.vfork -analyzer-checker=nullability.NullPassedToNonnull -analyzer-checker=nullability.NullReturnedFromNonnull -analyzer-output plist -w -analyzer-output=html -analyzer-checker=python -analyzer-disable-checker=deadcode -analyzer-config prune-paths=true,suppress-c++-stdlib=true,suppress-null-return-paths=false,crosscheck-with-z3=true,model-path=/opt/pyrefcon/lib/pyrefcon/models/models -analyzer-config experimental-enable-naive-ctu-analysis=true,ctu-dir=/tmp/pyrefcon/tensorflow/csa-scan,ctu-index-name=/tmp/pyrefcon/tensorflow/csa-scan/externalDefMap.txt,ctu-invocation-list=/tmp/pyrefcon/tensorflow/csa-scan/invocations.yaml,display-ctu-progress=false -setup-static-analyzer -analyzer-config-compatibility-mode=true -mrelocation-model pic -pic-level 2 -fhalf-no-semantic-interposition -mframe-pointer=all -fmath-errno -fno-rounding-math -mconstructor-aliases -munwind-tables -target-cpu x86-64 -tune-cpu generic -debugger-tuning=gdb -ffunction-sections -fdata-sections -fcoverage-compilation-dir=/home/pyrefcon/.cache/bazel/_bazel_alan/39be661231df2a680c9b74265384c13c/execroot/org_tensorflow -resource-dir /opt/pyrefcon/lib/clang/13.0.0 -iquote . 
-iquote bazel-out/k8-opt/bin -iquote external/com_google_absl -iquote bazel-out/k8-opt/bin/external/com_google_absl -iquote external/nsync -iquote bazel-out/k8-opt/bin/external/nsync -iquote external/eigen_archive -iquote bazel-out/k8-opt/bin/external/eigen_archive -iquote external/gif -iquote bazel-out/k8-opt/bin/external/gif -iquote external/libjpeg_turbo -iquote bazel-out/k8-opt/bin/external/libjpeg_turbo -iquote external/com_google_protobuf -iquote bazel-out/k8-opt/bin/external/com_google_protobuf -iquote external/com_googlesource_code_re2 -iquote bazel-out/k8-opt/bin/external/com_googlesource_code_re2 -iquote external/farmhash_archive -iquote bazel-out/k8-opt/bin/external/farmhash_archive -iquote external/fft2d -iquote bazel-out/k8-opt/bin/external/fft2d -iquote external/highwayhash -iquote bazel-out/k8-opt/bin/external/highwayhash -iquote external/zlib -iquote bazel-out/k8-opt/bin/external/zlib -iquote external/double_conversion -iquote bazel-out/k8-opt/bin/external/double_conversion -iquote external/snappy -iquote bazel-out/k8-opt/bin/external/snappy -iquote external/llvm-project -iquote bazel-out/k8-opt/bin/external/llvm-project -iquote external/llvm_terminfo -iquote bazel-out/k8-opt/bin/external/llvm_terminfo -iquote external/llvm_zlib -iquote bazel-out/k8-opt/bin/external/llvm_zlib -iquote external/curl -iquote bazel-out/k8-opt/bin/external/curl -iquote external/boringssl -iquote bazel-out/k8-opt/bin/external/boringssl -iquote external/jsoncpp_git -iquote bazel-out/k8-opt/bin/external/jsoncpp_git -iquote external/local_config_cuda -iquote bazel-out/k8-opt/bin/external/local_config_cuda -iquote external/local_config_rocm -iquote bazel-out/k8-opt/bin/external/local_config_rocm -iquote external/local_config_tensorrt -iquote bazel-out/k8-opt/bin/external/local_config_tensorrt -iquote external/mkl_dnn_v1 -iquote bazel-out/k8-opt/bin/external/mkl_dnn_v1 -iquote external/com_github_grpc_grpc -iquote bazel-out/k8-opt/bin/external/com_github_grpc_grpc -iquote 
external/upb -iquote bazel-out/k8-opt/bin/external/upb -iquote external/lmdb -iquote bazel-out/k8-opt/bin/external/lmdb -iquote external/png -iquote bazel-out/k8-opt/bin/external/png -iquote external/gemmlowp -iquote bazel-out/k8-opt/bin/external/gemmlowp -iquote external/icu -iquote bazel-out/k8-opt/bin/external/icu -iquote external/org_sqlite -iquote bazel-out/k8-opt/bin/external/org_sqlite -iquote external/dlpack -iquote bazel-out/k8-opt/bin/external/dlpack -iquote external/local_config_python -iquote bazel-out/k8-opt/bin/external/local_config_python -iquote external/pybind11 -iquote bazel-out/k8-opt/bin/external/pybind11 -isystem external/nsync/public -isystem bazel-out/k8-opt/bin/external/nsync/public -isystem third_party/eigen3/mkl_include -isystem bazel-out/k8-opt/bin/third_party/eigen3/mkl_include -isystem external/eigen_archive -isystem bazel-out/k8-opt/bin/external/eigen_archive -isystem external/gif -isystem bazel-out/k8-opt/bin/external/gif -isystem external/com_google_protobuf/src -isystem bazel-out/k8-opt/bin/external/com_google_protobuf/src -isystem external/farmhash_archive/src -isystem bazel-out/k8-opt/bin/external/farmhash_archive/src -isystem external/zlib -isystem bazel-out/k8-opt/bin/external/zlib -isystem external/double_conversion -isystem bazel-out/k8-opt/bin/external/double_conversion -isystem external/llvm-project/llvm/include -isystem bazel-out/k8-opt/bin/external/llvm-project/llvm/include -isystem external/llvm-project/mlir/include -isystem bazel-out/k8-opt/bin/external/llvm-project/mlir/include -isystem external/curl/include -isystem bazel-out/k8-opt/bin/external/curl/include -isystem external/boringssl/src/include -isystem bazel-out/k8-opt/bin/external/boringssl/src/include -isystem external/jsoncpp_git/include -isystem bazel-out/k8-opt/bin/external/jsoncpp_git/include -isystem external/local_config_cuda/cuda -isystem bazel-out/k8-opt/bin/external/local_config_cuda/cuda -isystem external/local_config_cuda/cuda/cuda/include -isystem 
bazel-out/k8-opt/bin/external/local_config_cuda/cuda/cuda/include -isystem external/local_config_rocm/rocm -isystem bazel-out/k8-opt/bin/external/local_config_rocm/rocm -isystem external/local_config_rocm/rocm/rocm/include -isystem bazel-out/k8-opt/bin/external/local_config_rocm/rocm/rocm/include -isystem external/local_config_rocm/rocm/rocm/include/rocrand -isystem bazel-out/k8-opt/bin/external/local_config_rocm/rocm/rocm/include/rocrand -isystem external/local_config_rocm/rocm/rocm/include/roctracer -isystem bazel-out/k8-opt/bin/external/local_config_rocm/rocm/rocm/include/roctracer -isystem tensorflow/compiler/mlir/tensorflow/include -isystem bazel-out/k8-opt/bin/tensorflow/compiler/mlir/tensorflow/include -isystem tensorflow/compiler/mlir/hlo/include -isystem bazel-out/k8-opt/bin/tensorflow/compiler/mlir/hlo/include -isystem tensorflow/compiler/mlir/xla/include -isystem bazel-out/k8-opt/bin/tensorflow/compiler/mlir/xla/include -isystem external/mkl_dnn_v1/include -isystem bazel-out/k8-opt/bin/external/mkl_dnn_v1/include -isystem external/mkl_dnn_v1/src -isystem bazel-out/k8-opt/bin/external/mkl_dnn_v1/src -isystem external/mkl_dnn_v1/src/common -isystem bazel-out/k8-opt/bin/external/mkl_dnn_v1/src/common -isystem external/mkl_dnn_v1/src/common/ittnotify -isystem bazel-out/k8-opt/bin/external/mkl_dnn_v1/src/common/ittnotify -isystem external/mkl_dnn_v1/src/cpu -isystem bazel-out/k8-opt/bin/external/mkl_dnn_v1/src/cpu -isystem external/mkl_dnn_v1/src/cpu/gemm -isystem bazel-out/k8-opt/bin/external/mkl_dnn_v1/src/cpu/gemm -isystem external/mkl_dnn_v1/src/cpu/x64/xbyak -isystem bazel-out/k8-opt/bin/external/mkl_dnn_v1/src/cpu/x64/xbyak -isystem external/com_github_grpc_grpc/include -isystem bazel-out/k8-opt/bin/external/com_github_grpc_grpc/include -isystem external/com_github_grpc_grpc/src/core/ext/upb-generated -isystem bazel-out/k8-opt/bin/external/com_github_grpc_grpc/src/core/ext/upb-generated -isystem 
external/com_github_grpc_grpc/third_party/address_sorting/include -isystem bazel-out/k8-opt/bin/external/com_github_grpc_grpc/third_party/address_sorting/include -isystem external/png -isystem bazel-out/k8-opt/bin/external/png -isystem external/icu/icu4c/source/common -isystem bazel-out/k8-opt/bin/external/icu/icu4c/source/common -isystem external/llvm-project/mlir/lib/Conversions/GPUToSPIRV -isystem bazel-out/k8-opt/bin/external/llvm-project/mlir/lib/Conversions/GPUToSPIRV -isystem external/llvm-project/mlir/lib/Conversion/MemRefToSPIRV -isystem bazel-out/k8-opt/bin/external/llvm-project/mlir/lib/Conversion/MemRefToSPIRV -isystem external/llvm-project/mlir/lib/Conversion/StandardToSPIRV -isystem bazel-out/k8-opt/bin/external/llvm-project/mlir/lib/Conversion/StandardToSPIRV -isystem external/llvm-project/mlir/lib/Conversion/MathToSPIRV -isystem bazel-out/k8-opt/bin/external/llvm-project/mlir/lib/Conversion/MathToSPIRV -isystem external/llvm-project/mlir/lib/Conversion/TosaToLinalg -isystem bazel-out/k8-opt/bin/external/llvm-project/mlir/lib/Conversion/TosaToLinalg -isystem external/llvm-project/mlir/lib/Conversion/TosaToSCF -isystem bazel-out/k8-opt/bin/external/llvm-project/mlir/lib/Conversion/TosaToSCF -isystem external/llvm-project/mlir/lib/Conversion/TosaToStandard -isystem bazel-out/k8-opt/bin/external/llvm-project/mlir/lib/Conversion/TosaToStandard -isystem external/llvm-project/llvm/lib/Target/X86 -isystem bazel-out/k8-opt/bin/external/llvm-project/llvm/lib/Target/X86 -isystem external/local_config_python/numpy_include -isystem bazel-out/k8-opt/bin/external/local_config_python/numpy_include -isystem /opt/pyrefcon/lib/pyrefcon/models/python3.8 -isystem /opt/pyrefcon/lib/pyrefcon/models/python3.8 -isystem external/pybind11/include -isystem bazel-out/k8-opt/bin/external/pybind11/include -U _FORTIFY_SOURCE -D _FORTIFY_SOURCE=1 -D NDEBUG -D SQLITE_OMIT_DEPRECATED -D EIGEN_ALTIVEC_USE_CUSTOM_PACK=0 -D GRPC_ARES=0 -D TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL -D 
TENSORFLOW_USE_MKLDNN_CONTRACTION_KERNEL -D HAVE_SYS_UIO_H -D TF_USE_SNAPPY -D CURL_STATICLIB -D EIGEN_MPL2_ONLY -D EIGEN_MAX_ALIGN_BYTES=64 -D LLVM_ON_UNIX=1 -D HAVE_BACKTRACE=1 -D BACKTRACE_HEADER=<execinfo.h> -D LTDL_SHLIB_EXT=".so" -D LLVM_PLUGIN_EXT=".so" -D LLVM_ENABLE_THREADS=1 -D HAVE_SYSEXITS_H=1 -D HAVE_UNISTD_H=1 -D HAVE_STRERROR_R=1 -D HAVE_LIBPTHREAD=1 -D HAVE_PTHREAD_GETNAME_NP=1 -D HAVE_PTHREAD_SETNAME_NP=1 -D HAVE_PTHREAD_GETSPECIFIC=1 -D HAVE_REGISTER_FRAME=1 -D HAVE_DEREGISTER_FRAME=1 -D _GNU_SOURCE -D HAVE_LINK_H=1 -D HAVE_LSEEK64=1 -D HAVE_MALLINFO=1 -D HAVE_POSIX_FALLOCATE=1 -D HAVE_SBRK=1 -D HAVE_STRUCT_STAT_ST_MTIM_TV_NSEC=1 -D LLVM_NATIVE_ARCH="X86" -D LLVM_NATIVE_ASMPARSER=LLVMInitializeX86AsmParser -D LLVM_NATIVE_ASMPRINTER=LLVMInitializeX86AsmPrinter -D LLVM_NATIVE_DISASSEMBLER=LLVMInitializeX86Disassembler -D LLVM_NATIVE_TARGET=LLVMInitializeX86Target -D LLVM_NATIVE_TARGETINFO=LLVMInitializeX86TargetInfo -D LLVM_NATIVE_TARGETMC=LLVMInitializeX86TargetMC -D LLVM_NATIVE_TARGETMCA=LLVMInitializeX86TargetMCA -D LLVM_HOST_TRIPLE="x86_64-unknown-linux-gnu" -D LLVM_DEFAULT_TARGET_TRIPLE="x86_64-unknown-linux-gnu" -D __STDC_LIMIT_MACROS -D __STDC_CONSTANT_MACROS -D __STDC_FORMAT_MACROS -I bazel-out/k8-opt/bin/external/llvm-project/mlir/_virtual_includes/BuiltinAttributesIncGen -I bazel-out/k8-opt/bin/external/llvm-project/mlir/_virtual_includes/BuiltinDialectIncGen -I bazel-out/k8-opt/bin/external/llvm-project/mlir/_virtual_includes/BuiltinLocationAttributesIncGen -I bazel-out/k8-opt/bin/external/llvm-project/mlir/_virtual_includes/BuiltinOpsIncGen -I bazel-out/k8-opt/bin/external/llvm-project/mlir/_virtual_includes/BuiltinTypeInterfacesIncGen -I bazel-out/k8-opt/bin/external/llvm-project/mlir/_virtual_includes/BuiltinTypesIncGen -I bazel-out/k8-opt/bin/external/llvm-project/mlir/_virtual_includes/CallOpInterfacesIncGen -I bazel-out/k8-opt/bin/external/llvm-project/mlir/_virtual_includes/CastOpInterfacesIncGen -I 
bazel-out/k8-opt/bin/external/llvm-project/mlir/_virtual_includes/InferTypeOpInterfaceIncGen -I bazel-out/k8-opt/bin/external/llvm-project/mlir/_virtual_includes/OpAsmInterfaceIncGen -I bazel-out/k8-opt/bin/external/llvm-project/mlir/_virtual_includes/RegionKindInterfaceIncGen -I bazel-out/k8-opt/bin/external/llvm-project/mlir/_virtual_includes/SideEffectInterfacesIncGen -I bazel-out/k8-opt/bin/external/llvm-project/mlir/_virtual_includes/SubElementInterfacesIncGen -I bazel-out/k8-opt/bin/external/llvm-project/mlir/_virtual_includes/SymbolInterfacesIncGen -I bazel-out/k8-opt/bin/external/llvm-project/mlir/_virtual_includes/TensorEncodingIncGen -I bazel-out/k8-opt/bin/external/llvm-project/mlir/_virtual_includes/ControlFlowInterfacesIncGen -I bazel-out/k8-opt/bin/external/local_config_cuda/cuda/_virtual_includes/cuda_headers_virtual -I bazel-out/k8-opt/bin/external/local_config_tensorrt/_virtual_includes/tensorrt_headers -I bazel-out/k8-opt/bin/external/llvm-project/mlir/_virtual_includes/ParserTokenKinds -I bazel-out/k8-opt/bin/external/llvm-project/mlir/_virtual_includes/DerivedAttributeOpInterfaceIncGen -I bazel-out/k8-opt/bin/external/llvm-project/mlir/_virtual_includes/LoopLikeInterfaceIncGen -I bazel-out/k8-opt/bin/external/llvm-project/mlir/_virtual_includes/StandardOpsIncGen -I bazel-out/k8-opt/bin/external/llvm-project/mlir/_virtual_includes/VectorInterfacesIncGen -I bazel-out/k8-opt/bin/external/llvm-project/mlir/_virtual_includes/AffineMemoryOpInterfacesIncGen -I bazel-out/k8-opt/bin/external/llvm-project/mlir/_virtual_includes/AffineOpsIncGen -I bazel-out/k8-opt/bin/external/llvm-project/mlir/_virtual_includes/CopyOpInterfaceIncGen -I bazel-out/k8-opt/bin/external/llvm-project/mlir/_virtual_includes/MemRefBaseIncGen -I bazel-out/k8-opt/bin/external/llvm-project/mlir/_virtual_includes/MemRefOpsIncGen -I bazel-out/k8-opt/bin/external/llvm-project/mlir/_virtual_includes/TensorOpsIncGen -I 
bazel-out/k8-opt/bin/external/llvm-project/mlir/_virtual_includes/ViewLikeInterfaceIncGen -I bazel-out/k8-opt/bin/external/llvm-project/mlir/_virtual_includes/LinalgInterfacesIncGen -I bazel-out/k8-opt/bin/external/llvm-project/mlir/_virtual_includes/LinalgStructuredOpsIncGen -I bazel-out/k8-opt/bin/external/llvm-project/mlir/_virtual_includes/LinalgOpsIncGen -I bazel-out/k8-opt/bin/external/llvm-project/mlir/_virtual_includes/MathBaseIncGen -I bazel-out/k8-opt/bin/external/llvm-project/mlir/_virtual_includes/MathOpsIncGen -I bazel-out/k8-opt/bin/external/llvm-project/mlir/_virtual_includes/SCFIncGen -I bazel-out/k8-opt/bin/external/llvm-project/mlir/_virtual_includes/SCFPassIncGen -I bazel-out/k8-opt/bin/external/llvm-project/mlir/_virtual_includes/TilingInterfaceIncGen -I bazel-out/k8-opt/bin/external/llvm-project/mlir/_virtual_includes/ComplexBaseIncGen -I bazel-out/k8-opt/bin/external/llvm-project/mlir/_virtual_includes/ComplexOpsIncGen -I bazel-out/k8-opt/bin/external/llvm-project/mlir/_virtual_includes/PDLOpsIncGen -I bazel-out/k8-opt/bin/external/llvm-project/mlir/_virtual_includes/PDLTypesIncGen -I bazel-out/k8-opt/bin/external/llvm-project/mlir/_virtual_includes/PDLInterpOpsIncGen -I bazel-out/k8-opt/bin/external/llvm-project/mlir/_virtual_includes/ConversionPassIncGen -I bazel-out/k8-opt/bin/external/llvm-project/mlir/_virtual_includes/TransformsPassIncGen -I bazel-out/k8-opt/bin/external/llvm-project/mlir/_virtual_includes/QuantOpsIncGen -I bazel-out/k8-opt/bin/external/llvm-project/mlir/_virtual_includes/QuantPassIncGen -I bazel-out/k8-opt/bin/external/llvm-project/mlir/_virtual_includes/MLIRShapeCanonicalizationIncGen -I bazel-out/k8-opt/bin/external/llvm-project/mlir/_virtual_includes/ShapeOpsIncGen -I bazel-out/k8-opt/bin/external/llvm-project/mlir/_virtual_includes/LLVMDialectAttributesIncGen -I bazel-out/k8-opt/bin/external/llvm-project/mlir/_virtual_includes/LLVMDialectInterfaceIncGen -I 
bazel-out/k8-opt/bin/external/llvm-project/mlir/_virtual_includes/LLVMOpsIncGen -I bazel-out/k8-opt/bin/tensorflow/compiler/mlir/hlo/_virtual_includes/canonicalize_inc_gen -I bazel-out/k8-opt/bin/tensorflow/compiler/mlir/hlo/_virtual_includes/chlo_ops_inc_gen -I bazel-out/k8-opt/bin/tensorflow/compiler/mlir/hlo/_virtual_includes/hlo_ops_base_inc_gen -I bazel-out/k8-opt/bin/tensorflow/compiler/mlir/hlo/_virtual_includes/hlo_ops_inc_gen -I bazel-out/k8-opt/bin/tensorflow/compiler/mlir/hlo/_virtual_includes/hlo_ops_pattern_gen -I bazel-out/k8-opt/bin/tensorflow/compiler/mlir/hlo/_virtual_includes/lhlo_ops_inc_gen -I bazel-out/k8-opt/bin/tensorflow/compiler/mlir/hlo/_virtual_includes/lhlo_ops_structs_inc_gen -I bazel-out/k8-opt/bin/tensorflow/compiler/mlir/hlo/_virtual_includes/chlo_legalize_to_hlo_inc_gen -I bazel-out/k8-opt/bin/tensorflow/compiler/mlir/hlo/_virtual_includes/disc_ral_ops_inc_gen -I bazel-out/k8-opt/bin/tensorflow/compiler/mlir/hlo/_virtual_includes/lhlo_gpu_ops_enums_inc_gen -I bazel-out/k8-opt/bin/tensorflow/compiler/mlir/hlo/_virtual_includes/lhlo_gpu_ops_inc_gen -I bazel-out/k8-opt/bin/tensorflow/compiler/mlir/hlo/_virtual_includes/lhlo_gpu_ops_structs_inc_gen -I bazel-out/k8-opt/bin/tensorflow/compiler/mlir/hlo/_virtual_includes/DiscRalPassIncGen -I bazel-out/k8-opt/bin/tensorflow/compiler/mlir/hlo/_virtual_includes/LmhloPassIncGen -I bazel-out/k8-opt/bin/tensorflow/compiler/mlir/hlo/_virtual_includes/MhloPassIncGen -I bazel-out/k8-opt/bin/external/llvm-project/llvm/_virtual_includes/InstCombineTableGen -I bazel-out/k8-opt/bin/external/llvm-project/mlir/_virtual_includes/VectorOpsIncGen -I bazel-out/k8-opt/bin/external/llvm-project/mlir/_virtual_includes/LinalgPassIncGen -I bazel-out/k8-opt/bin/external/llvm-project/mlir/_virtual_includes/StandardOpsTransformsPassIncGen -I bazel-out/k8-opt/bin/external/llvm-project/mlir/_virtual_includes/LLVMConversionIncGen -I bazel-out/k8-opt/bin/external/llvm-project/mlir/_virtual_includes/LLVMPassIncGen -I 
bazel-out/k8-opt/bin/external/llvm-project/mlir/_virtual_includes/OpenMPOpsIncGen -I bazel-out/k8-opt/bin/external/llvm-project/mlir/_virtual_includes/AMXIncGen -I bazel-out/k8-opt/bin/external/llvm-project/mlir/_virtual_includes/ArmNeonIncGen -I bazel-out/k8-opt/bin/external/llvm-project/mlir/_virtual_includes/ArmSVEIncGen -I bazel-out/k8-opt/bin/external/llvm-project/mlir/_virtual_includes/X86VectorIncGen -I bazel-out/k8-opt/bin/external/llvm-project/mlir/_virtual_includes/AffinePassIncGen -I bazel-out/k8-opt/bin/external/llvm-project/mlir/_virtual_includes/AsyncOpsIncGen -I bazel-out/k8-opt/bin/external/llvm-project/mlir/_virtual_includes/AsyncPassIncGen -I bazel-out/k8-opt/bin/external/llvm-project/mlir/_virtual_includes/OpenACCOpsIncGen -I bazel-out/k8-opt/bin/external/llvm-project/mlir/_virtual_includes/DLTIBaseIncGen -I bazel-out/k8-opt/bin/external/llvm-project/mlir/_virtual_includes/GPUBaseIncGen -I bazel-out/k8-opt/bin/external/llvm-project/mlir/_virtual_includes/GPUOpsIncGen -I bazel-out/k8-opt/bin/external/llvm-project/mlir/_virtual_includes/GPUPassIncGen -I bazel-out/k8-opt/bin/external/llvm-project/mlir/_virtual_includes/ParallelLoopMapperAttrGen -I bazel-out/k8-opt/bin/external/llvm-project/mlir/_virtual_includes/NVVMConversionIncGen -I bazel-out/k8-opt/bin/external/llvm-project/mlir/_virtual_includes/NVVMOpsIncGen -I bazel-out/k8-opt/bin/external/llvm-project/mlir/_virtual_includes/GPUToNVVMGen -I bazel-out/k8-opt/bin/external/llvm-project/mlir/_virtual_includes/GPUToROCDLTGen -I bazel-out/k8-opt/bin/external/llvm-project/mlir/_virtual_includes/ROCDLOpsIncGen -I bazel-out/k8-opt/bin/external/llvm-project/mlir/_virtual_includes/SPIRVAttrUtilsGen -I bazel-out/k8-opt/bin/external/llvm-project/mlir/_virtual_includes/SPIRVAvailabilityIncGen -I bazel-out/k8-opt/bin/external/llvm-project/mlir/_virtual_includes/SPIRVCanonicalizationIncGen -I bazel-out/k8-opt/bin/external/llvm-project/mlir/_virtual_includes/SPIRVOpsIncGen -I 
bazel-out/k8-opt/bin/external/llvm-project/mlir/_virtual_includes/SPIRVSerializationGen -I bazel-out/k8-opt/bin/external/llvm-project/mlir/_virtual_includes/ShapeToStandardGen -I bazel-out/k8-opt/bin/external/llvm-project/mlir/_virtual_includes/TosaDialectIncGen -I bazel-out/k8-opt/bin/external/llvm-project/mlir/_virtual_includes/TosaInterfacesIncGen -I bazel-out/k8-opt/bin/external/llvm-project/mlir/_virtual_includes/TosaPassIncGen -I bazel-out/k8-opt/bin/external/llvm-project/mlir/_virtual_includes/EmitCAttributesIncGen -I bazel-out/k8-opt/bin/external/llvm-project/mlir/_virtual_includes/EmitCOpsIncGen -I bazel-out/k8-opt/bin/external/llvm-project/mlir/_virtual_includes/MemRefPassIncGen -I bazel-out/k8-opt/bin/external/llvm-project/mlir/_virtual_includes/SPIRVPassIncGen -I bazel-out/k8-opt/bin/external/llvm-project/mlir/_virtual_includes/ShapeTransformsPassIncGen -I bazel-out/k8-opt/bin/external/llvm-project/mlir/_virtual_includes/SparseTensorAttrDefsIncGen -I bazel-out/k8-opt/bin/external/llvm-project/mlir/_virtual_includes/SparseTensorOpsIncGen -I bazel-out/k8-opt/bin/external/llvm-project/mlir/_virtual_includes/SparseTensorPassIncGen -I bazel-out/k8-opt/bin/external/llvm-project/mlir/_virtual_includes/TensorPassIncGen -I bazel-out/k8-opt/bin/external/llvm-project/llvm/_virtual_includes/X86CodeGen -I bazel-out/k8-opt/bin/external/llvm-project/llvm/_virtual_includes/X86CommonTableGen -I bazel-out/k8-opt/bin/external/llvm-project/llvm/_virtual_includes/X86Info -I bazel-out/k8-opt/bin/external/llvm-project/llvm/_virtual_includes/X86UtilsAndDesc -I bazel-out/k8-opt/bin/external/pybind11/_virtual_includes/pybind11 -D AUTOLOAD_DYNAMIC_KERNELS -D __DATE__="redacted" -D __TIMESTAMP__="redacted" -D __TIME__="redacted" -internal-isystem /usr/lib/gcc/x86_64-linux-gnu/10/../../../../include/c++/10 -internal-isystem /usr/lib/gcc/x86_64-linux-gnu/10/../../../../include/x86_64-linux-gnu/c++/10 -internal-isystem 
/usr/lib/gcc/x86_64-linux-gnu/10/../../../../include/c++/10/backward -internal-isystem /opt/pyrefcon/lib/clang/13.0.0/include -internal-isystem /usr/local/include -internal-isystem /usr/lib/gcc/x86_64-linux-gnu/10/../../../../x86_64-linux-gnu/include -internal-externc-isystem /usr/include/x86_64-linux-gnu -internal-externc-isystem /include -internal-externc-isystem /usr/include -O2 -Wall -Wunused-but-set-parameter -Wno-free-nonheap-object -Wno-builtin-macro-redefined -w -std=c++14 -fdeprecated-macro -fdebug-compilation-dir=/home/pyrefcon/.cache/bazel/_bazel_alan/39be661231df2a680c9b74265384c13c/execroot/org_tensorflow -ferror-limit 19 -stack-protector 1 -fgnuc-version=4.2.1 -fcxx-exceptions -fexceptions -vectorize-loops -vectorize-slp -faddrsig -D__GCC_HAVE_DWARF2_CFI_ASM=1 -o /tmp/pyrefcon/tensorflow/csa-scan/reports -x c++ tensorflow/python/eager/pywrap_tfe_src.cc

tensorflow/python/eager/pywrap_tfe_src.cc

1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include <atomic>
17#include <cstring>
18#include <unordered_map>
19
20#include "absl/debugging/leak_check.h"
21#include "absl/strings/str_cat.h"
22#include "absl/types/variant.h"
23#include "tensorflow/c/c_api.h"
24#include "tensorflow/c/c_api_internal.h"
25#include "tensorflow/c/eager/c_api.h"
26#include "tensorflow/c/eager/c_api_internal.h"
27#include "tensorflow/c/eager/tape.h"
28#include "tensorflow/c/eager/tfe_context_internal.h"
29#include "tensorflow/c/eager/tfe_op_internal.h"
30#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
31#include "tensorflow/c/tf_status.h"
32#include "tensorflow/core/framework/types.pb.h"
33#include "tensorflow/core/lib/core/errors.h"
34#include "tensorflow/core/lib/gtl/cleanup.h"
35#include "tensorflow/core/lib/gtl/compactptrset.h"
36#include "tensorflow/core/lib/gtl/flatmap.h"
37#include "tensorflow/core/lib/gtl/flatset.h"
38#include "tensorflow/core/lib/strings/strcat.h"
39#include "tensorflow/core/lib/strings/stringprintf.h"
40#include "tensorflow/core/platform/casts.h"
41#include "tensorflow/core/platform/errors.h"
42#include "tensorflow/core/platform/mutex.h"
43#include "tensorflow/core/platform/protobuf.h"
44#include "tensorflow/core/platform/status.h"
45#include "tensorflow/core/platform/types.h"
46#include "tensorflow/core/profiler/lib/traceme.h"
47#include "tensorflow/core/util/managed_stack_trace.h"
48#include "tensorflow/python/eager/pywrap_gradient_exclusions.h"
49#include "tensorflow/python/eager/pywrap_tensor.h"
50#include "tensorflow/python/eager/pywrap_tfe.h"
51#include "tensorflow/python/lib/core/py_util.h"
52#include "tensorflow/python/lib/core/safe_ptr.h"
53#include "tensorflow/python/util/stack_trace.h"
54#include "tensorflow/python/util/util.h"
55
56using tensorflow::Status;
57using tensorflow::string;
58using tensorflow::strings::Printf;
59
60namespace {
61// NOTE: Items are retrieved from and returned to these unique_ptrs, and they
62// act as arenas. This is important if the same thread requests 2 items without
63// releasing one.
64// The following sequence of events on the same thread will still succeed:
65// - GetOp <- Returns existing.
66// - GetOp <- Allocates and returns a new pointer.
67// - ReleaseOp <- Sets the item in the unique_ptr.
68// - ReleaseOp <- Sets the item in the unique_ptr, deleting the old one.
69// This occurs when a PyFunc kernel is run. This behavior makes it safe in that
70// case, as well as the case where python decides to reuse the underlying
71// C++ thread in 2 python threads case.
72struct OpDeleter {
73 void operator()(TFE_Op* op) const { TFE_DeleteOp(op); }
74};
75thread_local std::unordered_map<TFE_Context*,
76 std::unique_ptr<TFE_Op, OpDeleter>>
77 thread_local_eager_operation_map; // NOLINT
78thread_local std::unique_ptr<TF_Status> thread_local_tf_status = // NOLINT
79 nullptr;
80
81std::unique_ptr<TFE_Op, OpDeleter> ReleaseThreadLocalOp(TFE_Context* ctx) {
82 auto it = thread_local_eager_operation_map.find(ctx);
83 if (it == thread_local_eager_operation_map.end()) {
84 return nullptr;
85 }
86 return std::move(it->second);
87}
88
89TFE_Op* GetOp(TFE_Context* ctx, const char* op_or_function_name,
90 const char* raw_device_name, TF_Status* status) {
91 auto op = ReleaseThreadLocalOp(ctx);
92 if (!op) {
93 op.reset(tensorflow::wrap(tensorflow::unwrap(ctx)->CreateOperation()));
94 }
95 status->status =
96 tensorflow::unwrap(op.get())->Reset(op_or_function_name, raw_device_name);
97 if (!status->status.ok()) {
98 op.reset();
99 }
100 return op.release();
101}
102
103void ReturnOp(TFE_Context* ctx, TFE_Op* op) {
104 if (op) {
105 tensorflow::unwrap(op)->Clear();
106 thread_local_eager_operation_map[ctx].reset(op);
107 }
108}
109
110TF_Status* ReleaseThreadLocalStatus() {
111 if (thread_local_tf_status == nullptr) {
112 return nullptr;
113 }
114 return thread_local_tf_status.release();
115}
116
// Identifies one op input whose dtype is governed by some type attr, and
// records whether that input is a list of tensors.
struct InputInfo {
  InputInfo(int index, bool list) : i(index), is_list(list) {}

  int i;         // Position of the input in the op definition.
  bool is_list;  // True if the input has a number_attr (a list input).
};
123
124// Takes in output gradients, returns input gradients.
125typedef std::function<PyObject*(PyObject*, const std::vector<int64_t>&)>
126 PyBackwardFunction;
127
128using AttrToInputsMap =
129 tensorflow::gtl::FlatMap<string,
130 tensorflow::gtl::InlinedVector<InputInfo, 4>>;
131
132tensorflow::gtl::FlatMap<string, AttrToInputsMap*>* GetAllAttrToInputsMaps() {
133 static auto* all_attr_to_input_maps =
134 new tensorflow::gtl::FlatMap<string, AttrToInputsMap*>;
135 return all_attr_to_input_maps;
136}
137
138// This function doesn't use a lock, since we depend on the GIL directly.
139AttrToInputsMap* GetAttrToInputsMapHoldingGIL(const tensorflow::OpDef& op_def) {
140#if PY_MAJOR_VERSION3 >= 3 && PY_MINOR_VERSION8 >= 4
141 DCHECK(PyGILState_Check())while (false && (PyGILState_Check())) ::tensorflow::internal
::LogMessageFatal("tensorflow/python/eager/pywrap_tfe_src.cc"
, 141)
142 << "This function needs to hold the GIL when called.";
143#endif
144 auto* all_attr_to_input_maps = GetAllAttrToInputsMaps();
145 auto* output =
146 tensorflow::gtl::FindPtrOrNull(*all_attr_to_input_maps, op_def.name());
147 if (output != nullptr) {
148 return output;
149 }
150
151 std::unique_ptr<AttrToInputsMap> m(new AttrToInputsMap);
152
153 // Store a list of InputIndex -> List of corresponding inputs.
154 for (int i = 0; i < op_def.input_arg_size(); i++) {
155 if (!op_def.input_arg(i).type_attr().empty()) {
156 auto it = m->find(op_def.input_arg(i).type_attr());
157 if (it == m->end()) {
158 it = m->insert({op_def.input_arg(i).type_attr(), {}}).first;
159 }
160 it->second.emplace_back(i, !op_def.input_arg(i).number_attr().empty());
161 }
162 }
163
164 auto* retval = m.get();
165 (*all_attr_to_input_maps)[op_def.name()] = m.release();
166
167 return retval;
168}
169
170// This function doesn't use a lock, since we depend on the GIL directly.
171tensorflow::gtl::FlatMap<
172 string, tensorflow::gtl::FlatMap<string, tensorflow::DataType>*>*
173GetAllAttrToDefaultsMaps() {
174 static auto* all_attr_to_defaults_maps = new tensorflow::gtl::FlatMap<
175 string, tensorflow::gtl::FlatMap<string, tensorflow::DataType>*>;
176 return all_attr_to_defaults_maps;
177}
178
179tensorflow::gtl::FlatMap<string, tensorflow::DataType>*
180GetAttrToDefaultsMapHoldingGIL(const tensorflow::OpDef& op_def) {
181#if PY_MAJOR_VERSION3 >= 3 && PY_MINOR_VERSION8 >= 4
182 DCHECK(PyGILState_Check())while (false && (PyGILState_Check())) ::tensorflow::internal
::LogMessageFatal("tensorflow/python/eager/pywrap_tfe_src.cc"
, 182)
183 << "This function needs to hold the GIL when called.";
184#endif
185 auto* all_attr_to_defaults_maps = GetAllAttrToDefaultsMaps();
186 auto* output =
187 tensorflow::gtl::FindPtrOrNull(*all_attr_to_defaults_maps, op_def.name());
188 if (output != nullptr) {
189 return output;
190 }
191
192 auto* new_map = new tensorflow::gtl::FlatMap<string, tensorflow::DataType>;
193
194 for (const auto& attr : op_def.attr()) {
195 if (attr.type() == "type" && attr.has_default_value()) {
196 new_map->insert({attr.name(), attr.default_value().type()});
197 }
198 }
199
200 (*all_attr_to_defaults_maps)[op_def.name()] = new_map;
201
202 return new_map;
203}
204
205struct FastPathOpExecInfo {
206 TFE_Context* ctx;
207 const char* device_name;
208
209 bool run_callbacks;
210 bool run_post_exec_callbacks;
211 bool run_gradient_callback;
212
213 // The op name of the main op being executed.
214 PyObject* name;
215 // The op type name of the main op being executed.
216 PyObject* op_name;
217 PyObject* callbacks;
218
219 // All the args passed into the FastPathOpExecInfo.
220 PyObject* args;
221
222 // DTypes can come from another input that has the same attr. So build that
223 // map.
224 const AttrToInputsMap* attr_to_inputs_map;
225 const tensorflow::gtl::FlatMap<string, tensorflow::DataType>* default_dtypes;
226 tensorflow::gtl::FlatMap<string, tensorflow::DataType> cached_dtypes;
227};
228
229#define PARSE_VALUE(fn_name, type, check_fn, parse_fn) \
230 bool fn_name(const string& key, PyObject* py_value, TF_Status* status, \
231 type* value) { \
232 if (check_fn(py_value)) { \
233 *value = static_cast<type>(parse_fn(py_value)); \
234 return true; \
235 } else { \
236 TF_SetStatus(status, TF_INVALID_ARGUMENT, \
237 tensorflow::strings::StrCat( \
238 "Expecting " #type " value for attr ", key, ", got ", \
239 py_value->ob_type->tp_name) \
240 .c_str()); \
241 return false; \
242 } \
243 }
244
245#if PY_MAJOR_VERSION3 >= 3
246PARSE_VALUE(ParseIntValue, int, PyLong_Check, PyLong_AsLong)
247PARSE_VALUE(ParseInt64Value, int64_t, PyLong_Check, PyLong_AsLongLong)
248#else
249PARSE_VALUE(ParseIntValue, int, PyInt_Check, PyInt_AsLong)
250#endif
251PARSE_VALUE(ParseFloatValue, float, PyFloat_Check, PyFloat_AsDouble)
252#undef PARSE_VALUE
253
254#if PY_MAJOR_VERSION3 < 3
255bool ParseInt64Value(const string& key, PyObject* py_value, TF_Status* status,
256 int64_t* value) {
257 if (PyInt_Check(py_value)) {
258 *value = static_cast<int64_t>(PyInt_AsLong(py_value));
259 return true;
260 } else if (PyLong_Check(py_value)((((((PyObject*)(py_value))->ob_type))->tp_flags & (
(1UL << 24))) != 0)
) {
261 *value = static_cast<int64_t>(PyLong_AsLong(py_value));
262 return true;
263 }
264 TF_SetStatus(
265 status, TF_INVALID_ARGUMENT,
266 tensorflow::strings::StrCat("Expecting int or long value for attr ", key,
267 ", got ", py_value->ob_type->tp_name)
268 .c_str());
269 return false;
270}
271#endif
272
273Py_ssize_t TensorShapeNumDims(PyObject* value) {
274 const auto size = PySequence_Size(value);
275 if (size == -1) {
276 // TensorShape.__len__ raises an error in the scenario where the shape is an
277 // unknown, which needs to be cleared.
278 // TODO(nareshmodi): ensure that this is actually a TensorShape.
279 PyErr_Clear();
280 }
281 return size;
282}
283
284bool IsInteger(PyObject* py_value) {
285#if PY_MAJOR_VERSION3 >= 3
286 return PyLong_Check(py_value)((((((PyObject*)(py_value))->ob_type))->tp_flags & (
(1UL << 24))) != 0)
;
287#else
288 return PyInt_Check(py_value) || PyLong_Check(py_value)((((((PyObject*)(py_value))->ob_type))->tp_flags & (
(1UL << 24))) != 0)
;
289#endif
290}
291
292// This function considers a Dimension._value of None to be valid, and sets the
293// value to be -1 in that case.
294bool ParseDimensionValue(const string& key, PyObject* py_value,
295 TF_Status* status, int64_t* value) {
296 if (IsInteger(py_value)) {
297 return ParseInt64Value(key, py_value, status, value);
298 }
299
300 tensorflow::Safe_PyObjectPtr dimension_value(
301 PyObject_GetAttrString(py_value, "_value"));
302 if (dimension_value == nullptr) {
303 PyErr_Clear();
304 TF_SetStatus(
305 status, TF_INVALID_ARGUMENT,
306 tensorflow::strings::StrCat("Expecting a Dimension for attr ", key,
307 ", got ", py_value->ob_type->tp_name)
308 .c_str());
309 return false;
310 }
311
312 if (dimension_value.get() == Py_None(&_Py_NoneStruct)) {
313 *value = -1;
314 return true;
315 }
316
317 return ParseInt64Value(key, dimension_value.get(), status, value);
318}
319
320bool ParseStringValue(const string& key, PyObject* py_value, TF_Status* status,
321 tensorflow::StringPiece* value) {
322 if (PyBytes_Check(py_value)((((((PyObject*)(py_value))->ob_type))->tp_flags & (
(1UL << 27))) != 0)
) {
323 Py_ssize_t size = 0;
324 char* buf = nullptr;
325 if (PyBytes_AsStringAndSize(py_value, &buf, &size) < 0) return false;
326 *value = tensorflow::StringPiece(buf, size);
327 return true;
328 }
329#if PY_MAJOR_VERSION3 >= 3
330 if (PyUnicode_Check(py_value)((((((PyObject*)(py_value))->ob_type))->tp_flags & (
(1UL << 28))) != 0)
) {
331 Py_ssize_t size = 0;
332 const char* buf = PyUnicode_AsUTF8AndSize(py_value, &size);
333 if (buf == nullptr) return false;
334 *value = tensorflow::StringPiece(buf, size);
335 return true;
336 }
337#endif
338 TF_SetStatus(
339 status, TF_INVALID_ARGUMENT,
340 tensorflow::strings::StrCat("Expecting a string value for attr ", key,
341 ", got ", py_value->ob_type->tp_name)
342 .c_str());
343 return false;
344}
345
346bool ParseBoolValue(const string& key, PyObject* py_value, TF_Status* status,
347 unsigned char* value) {
348 *value = PyObject_IsTrue(py_value);
349 return true;
350}
351
352// The passed in py_value is expected to be an object of the python type
353// dtypes.DType or an int.
354bool ParseTypeValue(const string& key, PyObject* py_value, TF_Status* status,
355 int* value) {
356 if (IsInteger(py_value)) {
357 return ParseIntValue(key, py_value, status, value);
358 }
359
360 tensorflow::Safe_PyObjectPtr py_type_enum(
361 PyObject_GetAttrString(py_value, "_type_enum"));
362 if (py_type_enum == nullptr) {
363 PyErr_Clear();
364 TF_SetStatus(
365 status, TF_INVALID_ARGUMENT,
366 tensorflow::strings::StrCat("Expecting a DType.dtype for attr ", key,
367 ", got ", py_value->ob_type->tp_name)
368 .c_str());
369 return false;
370 }
371
372 return ParseIntValue(key, py_type_enum.get(), status, value);
373}
374
375bool SetOpAttrList(TFE_Context* ctx, TFE_Op* op, const char* key,
376 PyObject* py_list, TF_AttrType type,
377 tensorflow::gtl::FlatMap<string, int64_t>* attr_list_sizes,
378 TF_Status* status) {
379 if (!PySequence_Check(py_list)) {
380 TF_SetStatus(
381 status, TF_INVALID_ARGUMENT,
382 tensorflow::strings::StrCat("Expecting sequence value for attr ", key,
383 ", got ", py_list->ob_type->tp_name)
384 .c_str());
385 return false;
386 }
387 const int num_values = PySequence_Size(py_list);
388 if (attr_list_sizes != nullptr) (*attr_list_sizes)[key] = num_values;
389
390#define PARSE_LIST(c_type, parse_fn) \
391 std::unique_ptr<c_type[]> values(new c_type[num_values]); \
392 for (int i = 0; i < num_values; ++i) { \
393 tensorflow::Safe_PyObjectPtr py_value(PySequence_ITEM(py_list, i)( (((PyObject*)(py_list))->ob_type)->tp_as_sequence->
sq_item(py_list, i) )
); \
394 if (!parse_fn(key, py_value.get(), status, &values[i])) return false; \
395 }
396
397 if (type == TF_ATTR_STRING) {
398 std::unique_ptr<const void*[]> values(new const void*[num_values]);
399 std::unique_ptr<size_t[]> lengths(new size_t[num_values]);
400 for (int i = 0; i < num_values; ++i) {
401 tensorflow::StringPiece value;
402 tensorflow::Safe_PyObjectPtr py_value(PySequence_ITEM(py_list, i)( (((PyObject*)(py_list))->ob_type)->tp_as_sequence->
sq_item(py_list, i) )
);
403 if (!ParseStringValue(key, py_value.get(), status, &value)) return false;
404 values[i] = value.data();
405 lengths[i] = value.size();
406 }
407 TFE_OpSetAttrStringList(op, key, values.get(), lengths.get(), num_values);
408 } else if (type == TF_ATTR_INT) {
409 PARSE_LIST(int64_t, ParseInt64Value);
410 TFE_OpSetAttrIntList(op, key, values.get(), num_values);
411 } else if (type == TF_ATTR_FLOAT) {
412 PARSE_LIST(float, ParseFloatValue);
413 TFE_OpSetAttrFloatList(op, key, values.get(), num_values);
414 } else if (type == TF_ATTR_BOOL) {
415 PARSE_LIST(unsigned char, ParseBoolValue);
416 TFE_OpSetAttrBoolList(op, key, values.get(), num_values);
417 } else if (type == TF_ATTR_TYPE) {
418 PARSE_LIST(int, ParseTypeValue);
419 TFE_OpSetAttrTypeList(op, key,
420 reinterpret_cast<const TF_DataType*>(values.get()),
421 num_values);
422 } else if (type == TF_ATTR_SHAPE) {
423 // Make one pass through the input counting the total number of
424 // dims across all the input lists.
425 int total_dims = 0;
426 for (int i = 0; i < num_values; ++i) {
427 tensorflow::Safe_PyObjectPtr py_value(PySequence_ITEM(py_list, i)( (((PyObject*)(py_list))->ob_type)->tp_as_sequence->
sq_item(py_list, i) )
);
428 if (py_value.get() != Py_None(&_Py_NoneStruct)) {
429 if (!PySequence_Check(py_value.get())) {
430 TF_SetStatus(
431 status, TF_INVALID_ARGUMENT,
432 tensorflow::strings::StrCat(
433 "Expecting None or sequence value for element", i,
434 " of attr ", key, ", got ", py_value->ob_type->tp_name)
435 .c_str());
436 return false;
437 }
438 const auto size = TensorShapeNumDims(py_value.get());
439 if (size >= 0) {
440 total_dims += size;
441 }
442 }
443 }
444 // Allocate a buffer that can fit all of the dims together.
445 std::unique_ptr<int64_t[]> buffer(new int64_t[total_dims]);
446 // Copy the input dims into the buffer and set dims to point to
447 // the start of each list's dims.
448 std::unique_ptr<const int64_t*[]> dims(new const int64_t*[num_values]);
449 std::unique_ptr<int[]> num_dims(new int[num_values]);
450 int64_t* offset = buffer.get();
451 for (int i = 0; i < num_values; ++i) {
452 tensorflow::Safe_PyObjectPtr py_value(PySequence_ITEM(py_list, i)( (((PyObject*)(py_list))->ob_type)->tp_as_sequence->
sq_item(py_list, i) )
);
453 if (py_value.get() == Py_None(&_Py_NoneStruct)) {
454 dims[i] = nullptr;
455 num_dims[i] = -1;
456 } else {
457 const auto size = TensorShapeNumDims(py_value.get());
458 if (size == -1) {
459 dims[i] = nullptr;
460 num_dims[i] = -1;
461 continue;
462 }
463 dims[i] = offset;
464 num_dims[i] = size;
465 for (int j = 0; j < size; ++j) {
466 tensorflow::Safe_PyObjectPtr inner_py_value(
467 PySequence_ITEM(py_value.get(), j)( (((PyObject*)(py_value.get()))->ob_type)->tp_as_sequence
->sq_item(py_value.get(), j) )
);
468 if (inner_py_value.get() == Py_None(&_Py_NoneStruct)) {
469 *offset = -1;
470 } else if (!ParseDimensionValue(key, inner_py_value.get(), status,
471 offset)) {
472 return false;
473 }
474 ++offset;
475 }
476 }
477 }
478 TFE_OpSetAttrShapeList(op, key, dims.get(), num_dims.get(), num_values,
479 status);
480 if (!status->status.ok()) return false;
481 } else if (type == TF_ATTR_FUNC) {
482 std::unique_ptr<const TFE_Op*[]> funcs(new const TFE_Op*[num_values]);
483 for (int i = 0; i < num_values; ++i) {
484 tensorflow::Safe_PyObjectPtr py_value(PySequence_ITEM(py_list, i)( (((PyObject*)(py_list))->ob_type)->tp_as_sequence->
sq_item(py_list, i) )
);
485 // Allow:
486 // (1) String function name, OR
487 // (2) A Python object with a .name attribute
488 // (A crude test for being a
489 // tensorflow.python.framework.function._DefinedFunction)
490 // (which is what the various "defun" or "Defun" decorators do).
491 // And in the future also allow an object that can encapsulate
492 // the function name and its attribute values.
493 tensorflow::StringPiece func_name;
494 if (!ParseStringValue(key, py_value.get(), status, &func_name)) {
495 PyObject* name_attr = PyObject_GetAttrString(py_value.get(), "name");
496 if (name_attr == nullptr ||
497 !ParseStringValue(key, name_attr, status, &func_name)) {
498 TF_SetStatus(
499 status, TF_INVALID_ARGUMENT,
500 tensorflow::strings::StrCat(
501 "unable to set function value attribute from a ",
502 py_value.get()->ob_type->tp_name,
503 " object. If you think this is an error, please file an "
504 "issue at "
505 "https://github.com/tensorflow/tensorflow/issues/new")
506 .c_str());
507 return false;
508 }
509 }
510 funcs[i] = TFE_NewOp(ctx, func_name.data(), status);
511 if (!status->status.ok()) return false;
512 }
513 TFE_OpSetAttrFunctionList(op, key, funcs.get(), num_values);
514 if (!status->status.ok()) return false;
515 } else {
516 TF_SetStatus(status, TF_UNIMPLEMENTED,
517 tensorflow::strings::StrCat("Attr ", key,
518 " has unhandled list type ", type)
519 .c_str());
520 return false;
521 }
522#undef PARSE_LIST
523 return true;
524}
525
526TFE_Op* GetFunc(TFE_Context* ctx, const tensorflow::NameAttrList& func,
527 TF_Status* status) {
528 TFE_Op* func_op = TFE_NewOp(ctx, func.name().data(), status);
529 for (const auto& attr : func.attr()) {
530 if (!status->status.ok()) return nullptr;
531 SetOpAttrValueScalar(ctx, func_op, attr.second, attr.first.data(), status);
532 if (!status->status.ok()) return nullptr;
533 }
534 return func_op;
535}
536
537void SetOpAttrListDefault(
538 TFE_Context* ctx, TFE_Op* op, const tensorflow::OpDef::AttrDef& attr,
539 const char* key, TF_AttrType type,
540 tensorflow::gtl::FlatMap<string, int64_t>* attr_list_sizes,
541 TF_Status* status) {
542 if (type == TF_ATTR_STRING) {
543 int num_values = attr.default_value().list().s_size();
544 std::unique_ptr<const void*[]> values(new const void*[num_values]);
545 std::unique_ptr<size_t[]> lengths(new size_t[num_values]);
546 (*attr_list_sizes)[key] = num_values;
547 for (int i = 0; i < num_values; i++) {
548 const string& v = attr.default_value().list().s(i);
549 values[i] = v.data();
550 lengths[i] = v.size();
551 }
552 TFE_OpSetAttrStringList(op, key, values.get(), lengths.get(), num_values);
553 } else if (type == TF_ATTR_INT) {
554 int num_values = attr.default_value().list().i_size();
555 std::unique_ptr<int64_t[]> values(new int64_t[num_values]);
556 (*attr_list_sizes)[key] = num_values;
557 for (int i = 0; i < num_values; i++) {
558 values[i] = attr.default_value().list().i(i);
559 }
560 TFE_OpSetAttrIntList(op, key, values.get(), num_values);
561 } else if (type == TF_ATTR_FLOAT) {
562 int num_values = attr.default_value().list().f_size();
563 std::unique_ptr<float[]> values(new float[num_values]);
564 (*attr_list_sizes)[key] = num_values;
565 for (int i = 0; i < num_values; i++) {
566 values[i] = attr.default_value().list().f(i);
567 }
568 TFE_OpSetAttrFloatList(op, key, values.get(), num_values);
569 } else if (type == TF_ATTR_BOOL) {
570 int num_values = attr.default_value().list().b_size();
571 std::unique_ptr<unsigned char[]> values(new unsigned char[num_values]);
572 (*attr_list_sizes)[key] = num_values;
573 for (int i = 0; i < num_values; i++) {
574 values[i] = attr.default_value().list().b(i);
575 }
576 TFE_OpSetAttrBoolList(op, key, values.get(), num_values);
577 } else if (type == TF_ATTR_TYPE) {
578 int num_values = attr.default_value().list().type_size();
579 std::unique_ptr<int[]> values(new int[num_values]);
580 (*attr_list_sizes)[key] = num_values;
581 for (int i = 0; i < num_values; i++) {
582 values[i] = attr.default_value().list().type(i);
583 }
584 TFE_OpSetAttrTypeList(op, key,
585 reinterpret_cast<const TF_DataType*>(values.get()),
586 attr.default_value().list().type_size());
587 } else if (type == TF_ATTR_SHAPE) {
588 int num_values = attr.default_value().list().shape_size();
589 (*attr_list_sizes)[key] = num_values;
590 int total_dims = 0;
591 for (int i = 0; i < num_values; ++i) {
592 if (!attr.default_value().list().shape(i).unknown_rank()) {
593 total_dims += attr.default_value().list().shape(i).dim_size();
594 }
595 }
596 // Allocate a buffer that can fit all of the dims together.
597 std::unique_ptr<int64_t[]> buffer(new int64_t[total_dims]);
598 // Copy the input dims into the buffer and set dims to point to
599 // the start of each list's dims.
600 std::unique_ptr<const int64_t*[]> dims(new const int64_t*[num_values]);
601 std::unique_ptr<int[]> num_dims(new int[num_values]);
602 int64_t* offset = buffer.get();
603 for (int i = 0; i < num_values; ++i) {
604 const auto& shape = attr.default_value().list().shape(i);
605 if (shape.unknown_rank()) {
606 dims[i] = nullptr;
607 num_dims[i] = -1;
608 } else {
609 for (int j = 0; j < shape.dim_size(); j++) {
610 *offset = shape.dim(j).size();
611 ++offset;
612 }
613 }
614 }
615 TFE_OpSetAttrShapeList(op, key, dims.get(), num_dims.get(), num_values,
616 status);
617 } else if (type == TF_ATTR_FUNC) {
618 int num_values = attr.default_value().list().func_size();
619 (*attr_list_sizes)[key] = num_values;
620 std::unique_ptr<const TFE_Op*[]> funcs(new const TFE_Op*[num_values]);
621 for (int i = 0; i < num_values; i++) {
622 funcs[i] = GetFunc(ctx, attr.default_value().list().func(i), status);
623 }
624 TFE_OpSetAttrFunctionList(op, key, funcs.get(), num_values);
625 } else {
626 TF_SetStatus(status, TF_UNIMPLEMENTED,
627 "Lists of tensors are not yet implemented for default valued "
628 "attributes for an operation.");
629 }
630}
631
632bool SetOpAttrScalar(TFE_Context* ctx, TFE_Op* op, const char* key,
633 PyObject* py_value, TF_AttrType type,
634 tensorflow::gtl::FlatMap<string, int64_t>* attr_list_sizes,
635 TF_Status* status) {
636 if (type == TF_ATTR_STRING) {
637 tensorflow::StringPiece value;
638 if (!ParseStringValue(key, py_value, status, &value)) return false;
639 TFE_OpSetAttrString(op, key, value.data(), value.size());
640 } else if (type == TF_ATTR_INT) {
641 int64_t value;
642 if (!ParseInt64Value(key, py_value, status, &value)) return false;
643 TFE_OpSetAttrInt(op, key, value);
644 // attr_list_sizes is set for all int attributes (since at this point we are
645 // not aware if that attribute might be used to calculate the size of an
646 // output list or not).
647 if (attr_list_sizes != nullptr) (*attr_list_sizes)[key] = value;
648 } else if (type == TF_ATTR_FLOAT) {
649 float value;
650 if (!ParseFloatValue(key, py_value, status, &value)) return false;
651 TFE_OpSetAttrFloat(op, key, value);
652 } else if (type == TF_ATTR_BOOL) {
653 unsigned char value;
654 if (!ParseBoolValue(key, py_value, status, &value)) return false;
655 TFE_OpSetAttrBool(op, key, value);
656 } else if (type == TF_ATTR_TYPE) {
657 int value;
658 if (!ParseTypeValue(key, py_value, status, &value)) return false;
659 TFE_OpSetAttrType(op, key, static_cast<TF_DataType>(value));
660 } else if (type == TF_ATTR_SHAPE) {
661 if (py_value == Py_None(&_Py_NoneStruct)) {
662 TFE_OpSetAttrShape(op, key, nullptr, -1, status);
663 } else {
664 if (!PySequence_Check(py_value)) {
665 TF_SetStatus(status, TF_INVALID_ARGUMENT,
666 tensorflow::strings::StrCat(
667 "Expecting None or sequence value for attr", key,
668 ", got ", py_value->ob_type->tp_name)
669 .c_str());
670 return false;
671 }
672 const auto num_dims = TensorShapeNumDims(py_value);
673 if (num_dims == -1) {
674 TFE_OpSetAttrShape(op, key, nullptr, -1, status);
675 return true;
676 }
677 std::unique_ptr<int64_t[]> dims(new int64_t[num_dims]);
678 for (int i = 0; i < num_dims; ++i) {
679 tensorflow::Safe_PyObjectPtr inner_py_value(
680 PySequence_ITEM(py_value, i)( (((PyObject*)(py_value))->ob_type)->tp_as_sequence->
sq_item(py_value, i) )
);
681 if (inner_py_value.get() == Py_None(&_Py_NoneStruct)) {
682 dims[i] = -1;
683 } else if (!ParseDimensionValue(key, inner_py_value.get(), status,
684 &dims[i])) {
685 return false;
686 }
687 }
688 TFE_OpSetAttrShape(op, key, dims.get(), num_dims, status);
689 }
690 if (!status->status.ok()) return false;
691 } else if (type == TF_ATTR_FUNC) {
692 // Allow:
693 // (1) String function name, OR
694 // (2) A Python object with a .name attribute
695 // (A crude test for being a
696 // tensorflow.python.framework.function._DefinedFunction)
697 // (which is what the various "defun" or "Defun" decorators do).
698 // And in the future also allow an object that can encapsulate
699 // the function name and its attribute values.
700 tensorflow::StringPiece func_name;
701 if (!ParseStringValue(key, py_value, status, &func_name)) {
702 PyObject* name_attr = PyObject_GetAttrString(py_value, "name");
703 if (name_attr == nullptr ||
704 !ParseStringValue(key, name_attr, status, &func_name)) {
705 TF_SetStatus(
706 status, TF_INVALID_ARGUMENT,
707 tensorflow::strings::StrCat(
708 "unable to set function value attribute from a ",
709 py_value->ob_type->tp_name,
710 " object. If you think this is an error, please file an issue "
711 "at https://github.com/tensorflow/tensorflow/issues/new")
712 .c_str());
713 return false;
714 }
715 }
716 TF_SetStatus(status, TF_OK, "");
717 TFE_OpSetAttrFunctionName(op, key, func_name.data(), func_name.size());
718 } else {
719 TF_SetStatus(
720 status, TF_UNIMPLEMENTED,
721 tensorflow::strings::StrCat("Attr ", key, " has unhandled type ", type)
722 .c_str());
723 return false;
724 }
725 return true;
726}
727
728void SetOpAttrScalarDefault(
729 TFE_Context* ctx, TFE_Op* op, const tensorflow::AttrValue& default_value,
730 const char* attr_name,
731 tensorflow::gtl::FlatMap<string, int64_t>* attr_list_sizes,
732 TF_Status* status) {
733 SetOpAttrValueScalar(ctx, op, default_value, attr_name, status);
734 if (default_value.value_case() == tensorflow::AttrValue::kI) {
735 (*attr_list_sizes)[attr_name] = default_value.i();
736 }
737}
738
739// start_index is the index at which the Tuple/List attrs will start getting
740// processed.
741void SetOpAttrs(TFE_Context* ctx, TFE_Op* op, PyObject* attrs, int start_index,
742 TF_Status* out_status) {
743 if (attrs == Py_None(&_Py_NoneStruct)) return;
744 Py_ssize_t len = PyTuple_GET_SIZE(attrs)(((PyVarObject*)(((static_cast<void> (0)), (PyTupleObject
*)(attrs))))->ob_size)
- start_index;
745 if ((len & 1) != 0) {
746 TF_SetStatus(out_status, TF_INVALID_ARGUMENT,
747 "Expecting attrs tuple to have even length.");
748 return;
749 }
750 // Parse attrs
751 for (Py_ssize_t i = 0; i < len; i += 2) {
752 PyObject* py_key = PyTuple_GET_ITEM(attrs, start_index + i)(((static_cast<void> (0)), (PyTupleObject *)(attrs))->
ob_item[start_index + i])
;
753 PyObject* py_value = PyTuple_GET_ITEM(attrs, start_index + i + 1)(((static_cast<void> (0)), (PyTupleObject *)(attrs))->
ob_item[start_index + i + 1])
;
754#if PY_MAJOR_VERSION3 >= 3
755 const char* key = PyBytes_Check(py_key)((((((PyObject*)(py_key))->ob_type))->tp_flags & ((
1UL << 27))) != 0)
? PyBytes_AsString(py_key)
756 : PyUnicode_AsUTF8(py_key);
757#else
758 const char* key = PyBytes_AsString(py_key);
759#endif
760 unsigned char is_list = 0;
761 const TF_AttrType type = TFE_OpGetAttrType(op, key, &is_list, out_status);
762 if (!out_status->status.ok()) return;
763 if (is_list != 0) {
764 if (!SetOpAttrList(ctx, op, key, py_value, type, nullptr, out_status))
765 return;
766 } else {
767 if (!SetOpAttrScalar(ctx, op, key, py_value, type, nullptr, out_status))
768 return;
769 }
770 }
771}
772
773// This function will set the op attrs required. If an attr has the value of
774// None, then it will read the AttrDef to get the default value and set that
775// instead. Any failure in this function will simply fall back to the slow
776// path.
777void SetOpAttrWithDefaults(
778 TFE_Context* ctx, TFE_Op* op, const tensorflow::OpDef::AttrDef& attr,
779 const char* attr_name, PyObject* attr_value,
780 tensorflow::gtl::FlatMap<string, int64_t>* attr_list_sizes,
781 TF_Status* status) {
782 unsigned char is_list = 0;
783 const TF_AttrType type = TFE_OpGetAttrType(op, attr_name, &is_list, status);
784 if (!status->status.ok()) return;
785 if (attr_value == Py_None(&_Py_NoneStruct)) {
786 if (is_list != 0) {
787 SetOpAttrListDefault(ctx, op, attr, attr_name, type, attr_list_sizes,
788 status);
789 } else {
790 SetOpAttrScalarDefault(ctx, op, attr.default_value(), attr_name,
791 attr_list_sizes, status);
792 }
793 } else {
794 if (is_list != 0) {
795 SetOpAttrList(ctx, op, attr_name, attr_value, type, attr_list_sizes,
796 status);
797 } else {
798 SetOpAttrScalar(ctx, op, attr_name, attr_value, type, attr_list_sizes,
799 status);
800 }
801 }
802}
803
804PyObject* GetPythonObjectFromInt(int num) {
805#if PY_MAJOR_VERSION3 >= 3
806 return PyLong_FromLong(num);
807#else
808 return PyInt_FromLong(num);
809#endif
810}
811
812// Python subclass of Exception that is created on not ok Status.
813tensorflow::mutex exception_class_mutex(tensorflow::LINKER_INITIALIZED);
814PyObject* exception_class TF_GUARDED_BY(exception_class_mutex)__attribute__((guarded_by(exception_class_mutex))) = nullptr;
815
816// Python subclass of Exception that is created to signal fallback.
817PyObject* fallback_exception_class = nullptr;
818
819// Python function that returns input gradients given output gradients.
820PyObject* gradient_function = nullptr;
821
822// Python function that returns output gradients given input gradients.
823PyObject* forward_gradient_function = nullptr;
824
825static std::atomic<int64_t> _uid;
826
827} // namespace
828
829TF_Status* GetStatus() {
830 TF_Status* maybe_status = ReleaseThreadLocalStatus();
831 if (maybe_status) {
832 TF_SetStatus(maybe_status, TF_OK, "");
833 return maybe_status;
834 } else {
835 return TF_NewStatus();
836 }
837}
838
839void ReturnStatus(TF_Status* status) {
840 TF_SetStatus(status, TF_OK, "");
841 thread_local_tf_status.reset(status);
842}
843
844void TFE_Py_Execute(TFE_Context* ctx, const char* device_name,
845 const char* op_name, TFE_InputTensorHandles* inputs,
846 PyObject* attrs, TFE_OutputTensorHandles* outputs,
847 TF_Status* out_status) {
848 TFE_Py_ExecuteCancelable(ctx, device_name, op_name, inputs, attrs,
849 /*cancellation_manager=*/nullptr, outputs,
850 out_status);
851}
852
853void TFE_Py_ExecuteCancelable(TFE_Context* ctx, const char* device_name,
854 const char* op_name,
855 TFE_InputTensorHandles* inputs, PyObject* attrs,
856 TFE_CancellationManager* cancellation_manager,
857 TFE_OutputTensorHandles* outputs,
858 TF_Status* out_status) {
859 tensorflow::profiler::TraceMe activity(
860 "TFE_Py_ExecuteCancelable", tensorflow::profiler::TraceMeLevel::kInfo);
861
862 TFE_Op* op = GetOp(ctx, op_name, device_name, out_status);
863
864 auto cleaner = tensorflow::gtl::MakeCleanup([ctx, op] { ReturnOp(ctx, op); });
865 if (!out_status->status.ok()) return;
866
867 tensorflow::unwrap(op)->SetStackTrace(tensorflow::GetStackTrace(
868 tensorflow::StackTrace::kStackTraceInitialSize));
869
870 for (int i = 0; i < inputs->size() && out_status->status.ok(); ++i) {
871 TFE_OpAddInput(op, inputs->at(i), out_status);
872 }
873 if (cancellation_manager && out_status->status.ok()) {
874 TFE_OpSetCancellationManager(op, cancellation_manager, out_status);
875 }
876 if (out_status->status.ok()) {
877 SetOpAttrs(ctx, op, attrs, 0, out_status);
878 }
879 Py_BEGIN_ALLOW_THREADS{ PyThreadState *_save; _save = PyEval_SaveThread();;
880
881 int num_outputs = outputs->size();
882
883 if (out_status->status.ok()) {
884 TFE_Execute(op, outputs->data(), &num_outputs, out_status);
885 }
886
887 if (out_status->status.ok()) {
888 outputs->resize(num_outputs);
889 } else {
890 TF_SetStatus(out_status, TF_GetCode(out_status),
891 tensorflow::strings::StrCat(TF_Message(out_status),
892 " [Op:", op_name, "]")
893 .c_str());
894 }
895
896 Py_END_ALLOW_THREADSPyEval_RestoreThread(_save); };
897}
898
899PyObject* TFE_Py_RegisterExceptionClass(PyObject* e) {
900 tensorflow::mutex_lock l(exception_class_mutex);
901 if (exception_class != nullptr) {
902 Py_DECREF(exception_class)_Py_DECREF(((PyObject*)(exception_class)));
903 }
904 if (PyObject_IsSubclass(e, PyExc_Exception) <= 0) {
905 exception_class = nullptr;
906 PyErr_SetString(PyExc_TypeError,
907 "TFE_Py_RegisterExceptionClass: "
908 "Registered class should be subclass of Exception.");
909 return nullptr;
910 }
911
912 Py_INCREF(e)_Py_INCREF(((PyObject*)(e)));
913 exception_class = e;
914 Py_RETURN_NONEreturn _Py_INCREF(((PyObject*)((&_Py_NoneStruct)))), (&
_Py_NoneStruct)
;
915}
916
917PyObject* TFE_Py_RegisterFallbackExceptionClass(PyObject* e) {
918 if (fallback_exception_class != nullptr) {
919 Py_DECREF(fallback_exception_class)_Py_DECREF(((PyObject*)(fallback_exception_class)));
920 }
921 if (PyObject_IsSubclass(e, PyExc_Exception) <= 0) {
922 fallback_exception_class = nullptr;
923 PyErr_SetString(PyExc_TypeError,
924 "TFE_Py_RegisterFallbackExceptionClass: "
925 "Registered class should be subclass of Exception.");
926 return nullptr;
927 } else {
928 Py_INCREF(e)_Py_INCREF(((PyObject*)(e)));
929 fallback_exception_class = e;
930 Py_RETURN_NONEreturn _Py_INCREF(((PyObject*)((&_Py_NoneStruct)))), (&
_Py_NoneStruct)
;
931 }
932}
933
934PyObject* TFE_Py_RegisterGradientFunction(PyObject* e) {
935 if (gradient_function != nullptr) {
936 Py_DECREF(gradient_function)_Py_DECREF(((PyObject*)(gradient_function)));
937 }
938 if (!PyCallable_Check(e)) {
939 gradient_function = nullptr;
940 PyErr_SetString(PyExc_TypeError,
941 "TFE_Py_RegisterGradientFunction: "
942 "Registered object should be function.");
943 return nullptr;
944 } else {
945 Py_INCREF(e)_Py_INCREF(((PyObject*)(e)));
946 gradient_function = e;
947 Py_RETURN_NONEreturn _Py_INCREF(((PyObject*)((&_Py_NoneStruct)))), (&
_Py_NoneStruct)
;
948 }
949}
950
951PyObject* TFE_Py_RegisterJVPFunction(PyObject* e) {
952 if (forward_gradient_function != nullptr) {
953 Py_DECREF(forward_gradient_function)_Py_DECREF(((PyObject*)(forward_gradient_function)));
954 }
955 if (!PyCallable_Check(e)) {
956 forward_gradient_function = nullptr;
957 PyErr_SetString(PyExc_TypeError,
958 "TFE_Py_RegisterJVPFunction: "
959 "Registered object should be function.");
960 return nullptr;
961 } else {
962 Py_INCREF(e)_Py_INCREF(((PyObject*)(e)));
963 forward_gradient_function = e;
964 Py_RETURN_NONEreturn _Py_INCREF(((PyObject*)((&_Py_NoneStruct)))), (&
_Py_NoneStruct)
;
965 }
966}
967
968void RaiseFallbackException(const char* message) {
969 if (fallback_exception_class != nullptr) {
970 PyErr_SetString(fallback_exception_class, message);
971 return;
972 }
973
974 PyErr_SetString(
975 PyExc_RuntimeError,
976 tensorflow::strings::StrCat(
977 "Fallback exception type not set, attempting to fallback due to ",
978 message)
979 .data());
980}
981
982// Format and return `status`' error message with the attached stack trace if
983// available. `status` must have an error.
984std::string FormatErrorStatusStackTrace(const tensorflow::Status& status) {
985 tensorflow::DCheckPyGilState();
986 DCHECK(!status.ok())while (false && (!status.ok())) ::tensorflow::internal
::LogMessageFatal("tensorflow/python/eager/pywrap_tfe_src.cc"
, 986)
;
987
988 if (status.stack_trace().empty()) return status.error_message();
7
Assuming the condition is false
8
Taking false branch
989
990 const std::vector<tensorflow::StackFrame>& stack_trace = status.stack_trace();
991
992 PyObject* linecache = PyImport_ImportModule("linecache");
993 PyObject* getline =
994 PyObject_GetAttr(linecache, PyUnicode_FromString("getline"));
9
Calling 'PyUnicode_FromString'
11
Returning from 'PyUnicode_FromString'
12
PyObject ownership leak with reference count of 1
995 DCHECK(getline)while (false && (getline)) ::tensorflow::internal::LogMessageFatal
("tensorflow/python/eager/pywrap_tfe_src.cc", 995)
;
996
997 std::ostringstream result;
998 result << "Exception originated from\n\n";
999
1000 for (const tensorflow::StackFrame& stack_frame : stack_trace) {
1001 PyObject* line_str_obj = PyObject_CallFunction(
1002 getline, const_cast<char*>("si"), stack_frame.file_name.c_str(),
1003 stack_frame.line_number);
1004 tensorflow::StringPiece line_str = TFE_GetPythonString(line_str_obj);
1005 tensorflow::str_util::RemoveWhitespaceContext(&line_str);
1006 result << " File \"" << stack_frame.file_name << "\", line "
1007 << stack_frame.line_number << ", in " << stack_frame.function_name
1008 << '\n';
1009
1010 if (!line_str.empty()) result << " " << line_str << '\n';
1011 Py_XDECREF(line_str_obj)_Py_XDECREF(((PyObject*)(line_str_obj)));
1012 }
1013
1014 Py_DecRef(getline);
1015 Py_DecRef(linecache);
1016
1017 result << '\n' << status.error_message();
1018 return result.str();
1019}
1020
1021int MaybeRaiseExceptionFromTFStatus(TF_Status* status, PyObject* exception) {
1022 if (status->status.ok()) return 0;
1023 const char* msg = TF_Message(status);
1024 if (exception == nullptr) {
1025 tensorflow::mutex_lock l(exception_class_mutex);
1026 if (exception_class != nullptr) {
1027 tensorflow::Safe_PyObjectPtr payloads(PyDict_New());
1028 for (const auto& payload :
1029 tensorflow::errors::GetPayloads(status->status)) {
1030 PyDict_SetItem(payloads.get(),
1031 PyBytes_FromString(payload.first.c_str()),
1032 PyBytes_FromString(payload.second.c_str()));
1033 }
1034 tensorflow::Safe_PyObjectPtr val(Py_BuildValue(
1035 "siO", FormatErrorStatusStackTrace(status->status).c_str(),
1036 TF_GetCode(status), payloads.get()));
1037 if (PyErr_Occurred()) {
1038 // NOTE: This hides the actual error (i.e. the reason `status` was not
1039 // TF_OK), but there is nothing we can do at this point since we can't
1040 // generate a reasonable error from the status.
1041 // Consider adding a message explaining this.
1042 return -1;
1043 }
1044 PyErr_SetObject(exception_class, val.get());
1045 return -1;
1046 } else {
1047 exception = PyExc_RuntimeError;
1048 }
1049 }
1050 // May be update already set exception.
1051 PyErr_SetString(exception, msg);
1052 return -1;
1053}
1054
1055int MaybeRaiseExceptionFromStatus(const tensorflow::Status& status,
1056 PyObject* exception) {
1057 if (status.ok()) return 0;
2
Taking false branch
1058 const char* msg = status.error_message().c_str();
1059 if (exception == nullptr) {
3
Taking true branch
1060 tensorflow::mutex_lock l(exception_class_mutex);
1061 if (exception_class != nullptr) {
4
Assuming the condition is true
5
Taking true branch
1062 tensorflow::Safe_PyObjectPtr payloads(PyDict_New());
1063 for (const auto& element : tensorflow::errors::GetPayloads(status)) {
1064 PyDict_SetItem(payloads.get(),
1065 PyBytes_FromString(element.first.c_str()),
1066 PyBytes_FromString(element.second.c_str()));
1067 }
1068 tensorflow::Safe_PyObjectPtr val(
1069 Py_BuildValue("siO", FormatErrorStatusStackTrace(status).c_str(),
6
Calling 'FormatErrorStatusStackTrace'
1070 status.code(), payloads.get()));
1071 PyErr_SetObject(exception_class, val.get());
1072 return -1;
1073 } else {
1074 exception = PyExc_RuntimeError;
1075 }
1076 }
1077 // May be update already set exception.
1078 PyErr_SetString(exception, msg);
1079 return -1;
1080}
1081
1082const char* TFE_GetPythonString(PyObject* o) {
1083#if PY_MAJOR_VERSION3 >= 3
1084 if (PyBytes_Check(o)((((((PyObject*)(o))->ob_type))->tp_flags & ((1UL <<
27))) != 0)
) {
1085 return PyBytes_AsString(o);
1086 } else {
1087 return PyUnicode_AsUTF8(o);
1088 }
1089#else
1090 return PyBytes_AsString(o);
1091#endif
1092}
1093
1094int64_t get_uid() { return _uid++; }
1095
1096PyObject* TFE_Py_UID() { return PyLong_FromLongLong(get_uid()); }
1097
1098void TFE_DeleteContextCapsule(PyObject* context) {
1099 TFE_Context* ctx =
1100 reinterpret_cast<TFE_Context*>(PyCapsule_GetPointer(context, nullptr));
1101 auto op = ReleaseThreadLocalOp(ctx);
1102 op.reset();
1103 TFE_DeleteContext(ctx);
1104}
1105
1106static int64_t MakeInt(PyObject* integer) {
1107#if PY_MAJOR_VERSION3 >= 3
1108 return PyLong_AsLong(integer);
1109#else
1110 return PyInt_AsLong(integer);
1111#endif
1112}
1113
1114static int64_t FastTensorId(PyObject* tensor) {
1115 if (EagerTensor_CheckExact(tensor)) {
1116 return PyEagerTensor_ID(tensor);
1117 }
1118 PyObject* id_field = PyObject_GetAttrString(tensor, "_id");
1119 if (id_field == nullptr) {
1120 return -1;
1121 }
1122 int64_t id = MakeInt(id_field);
1123 Py_DECREF(id_field)_Py_DECREF(((PyObject*)(id_field)));
1124 return id;
1125}
1126
1127namespace tensorflow {
1128DataType PyTensor_DataType(PyObject* tensor) {
1129 if (EagerTensor_CheckExact(tensor)) {
1130 return PyEagerTensor_Dtype(tensor);
1131 } else {
1132#if PY_MAJOR_VERSION3 < 3
1133 // Python 2.x:
1134 static PyObject* dtype_attr = PyString_InternFromString("dtype");
1135 static PyObject* type_enum_attr = PyString_InternFromString("_type_enum");
1136#else
1137 // Python 3.x:
1138 static PyObject* dtype_attr = PyUnicode_InternFromString("dtype");
1139 static PyObject* type_enum_attr = PyUnicode_InternFromString("_type_enum");
1140#endif
1141 Safe_PyObjectPtr dtype_field(PyObject_GetAttr(tensor, dtype_attr));
1142 if (!dtype_field) {
1143 return DT_INVALID;
1144 }
1145
1146 Safe_PyObjectPtr enum_field(
1147 PyObject_GetAttr(dtype_field.get(), type_enum_attr));
1148 if (!enum_field) {
1149 return DT_INVALID;
1150 }
1151
1152 return static_cast<DataType>(MakeInt(enum_field.get()));
1153 }
1154}
1155} // namespace tensorflow
1156
1157class PyTapeTensor {
1158 public:
1159 PyTapeTensor(int64_t id, tensorflow::DataType dtype,
1160 const tensorflow::TensorShape& shape)
1161 : id_(id), dtype_(dtype), shape_(shape) {}
1162 PyTapeTensor(int64_t id, tensorflow::DataType dtype, PyObject* shape)
1163 : id_(id), dtype_(dtype), shape_(shape) {
1164 Py_INCREF(absl::get<1>(shape_))_Py_INCREF(((PyObject*)(absl::get<1>(shape_))));
1165 }
1166 PyTapeTensor(const PyTapeTensor& other) {
1167 id_ = other.id_;
1168 dtype_ = other.dtype_;
1169 shape_ = other.shape_;
1170 if (shape_.index() == 1) {
1171 Py_INCREF(absl::get<1>(shape_))_Py_INCREF(((PyObject*)(absl::get<1>(shape_))));
1172 }
1173 }
1174
1175 ~PyTapeTensor() {
1176 if (shape_.index() == 1) {
1177 Py_DECREF(absl::get<1>(shape_))_Py_DECREF(((PyObject*)(absl::get<1>(shape_))));
1178 }
1179 }
1180 PyObject* GetShape() const;
1181 PyObject* GetPyDType() const { return PyLong_FromLong(dtype_); }
1182 int64_t GetID() const { return id_; }
1183 tensorflow::DataType GetDType() const { return dtype_; }
1184
1185 PyObject* OnesLike() const;
1186 PyObject* ZerosLike() const;
1187
1188 private:
1189 int64_t id_;
1190 tensorflow::DataType dtype_;
1191
1192 // Note that if shape_.index() == 1, meaning shape_ contains a PyObject, that
1193 // PyObject is the tensor itself. This is used to support tf.shape(tensor) for
1194 // partially-defined shapes and tf.zeros_like(tensor) for variant-dtype
1195 // tensors.
1196 absl::variant<tensorflow::TensorShape, PyObject*> shape_;
1197};
1198
1199static PyTapeTensor TapeTensorFromTensor(PyObject* tensor);
1200
1201class PyVSpace : public tensorflow::eager::VSpace<PyObject, PyBackwardFunction,
1202 PyTapeTensor> {
1203 public:
1204 explicit PyVSpace(PyObject* py_vspace) : py_vspace_(py_vspace) {
1205 Py_INCREF(py_vspace_)_Py_INCREF(((PyObject*)(py_vspace_)));
1206 }
1207
1208 tensorflow::Status Initialize() {
1209 num_elements_ = PyObject_GetAttrString(py_vspace_, "num_elements_fn");
1210 if (num_elements_ == nullptr) {
1211 return tensorflow::errors::InvalidArgument("invalid vspace");
1212 }
1213 aggregate_fn_ = PyObject_GetAttrString(py_vspace_, "aggregate_fn");
1214 if (aggregate_fn_ == nullptr) {
1215 return tensorflow::errors::InvalidArgument("invalid vspace");
1216 }
1217 zeros_fn_ = PyObject_GetAttrString(py_vspace_, "zeros_fn");
1218 if (zeros_fn_ == nullptr) {
1219 return tensorflow::errors::InvalidArgument("invalid vspace");
1220 }
1221 zeros_like_fn_ = PyObject_GetAttrString(py_vspace_, "zeros_like_fn");
1222 if (zeros_like_fn_ == nullptr) {
1223 return tensorflow::errors::InvalidArgument("invalid vspace");
1224 }
1225 ones_fn_ = PyObject_GetAttrString(py_vspace_, "ones_fn");
1226 if (ones_fn_ == nullptr) {
1227 return tensorflow::errors::InvalidArgument("invalid vspace");
1228 }
1229 ones_like_fn_ = PyObject_GetAttrString(py_vspace_, "ones_like_fn");
1230 if (ones_like_fn_ == nullptr) {
1231 return tensorflow::errors::InvalidArgument("invalid vspace");
1232 }
1233 graph_shape_fn_ = PyObject_GetAttrString(py_vspace_, "graph_shape_fn");
1234 if (graph_shape_fn_ == nullptr) {
1235 return tensorflow::errors::InvalidArgument("invalid vspace");
1236 }
1237 return tensorflow::Status::OK();
1238 }
1239
1240 ~PyVSpace() override {
1241 Py_XDECREF(num_elements_)_Py_XDECREF(((PyObject*)(num_elements_)));
1242 Py_XDECREF(aggregate_fn_)_Py_XDECREF(((PyObject*)(aggregate_fn_)));
1243 Py_XDECREF(zeros_fn_)_Py_XDECREF(((PyObject*)(zeros_fn_)));
1244 Py_XDECREF(zeros_like_fn_)_Py_XDECREF(((PyObject*)(zeros_like_fn_)));
1245 Py_XDECREF(ones_fn_)_Py_XDECREF(((PyObject*)(ones_fn_)));
1246 Py_XDECREF(ones_like_fn_)_Py_XDECREF(((PyObject*)(ones_like_fn_)));
1247 Py_XDECREF(graph_shape_fn_)_Py_XDECREF(((PyObject*)(graph_shape_fn_)));
1248
1249 Py_DECREF(py_vspace_)_Py_DECREF(((PyObject*)(py_vspace_)));
1250 }
1251
1252 int64_t NumElements(PyObject* tensor) const final {
1253 if (EagerTensor_CheckExact(tensor)) {
1254 return PyEagerTensor_NumElements(tensor);
1255 }
1256 PyObject* arglist =
1257 Py_BuildValue("(O)", reinterpret_cast<PyObject*>(tensor));
1258 PyObject* result = PyEval_CallObject(num_elements_, arglist)PyEval_CallObjectWithKeywords(num_elements_, arglist, (PyObject
*)__null)
;
1259 Py_DECREF(arglist)_Py_DECREF(((PyObject*)(arglist)));
1260 if (result == nullptr) {
1261 // The caller detects whether a python exception has been raised.
1262 return -1;
1263 }
1264 int64_t r = MakeInt(result);
1265 Py_DECREF(result)_Py_DECREF(((PyObject*)(result)));
1266 return r;
1267 }
1268
1269 PyObject* AggregateGradients(
1270 tensorflow::gtl::ArraySlice<PyObject*> gradient_tensors) const final {
1271 PyObject* list = PyList_New(gradient_tensors.size());
1272 for (int i = 0; i < gradient_tensors.size(); ++i) {
1273 // Note: stealing a reference to the gradient tensors.
1274 CHECK(gradient_tensors[i] != nullptr)if ((__builtin_expect(!(gradient_tensors[i] != nullptr), 0)))
::tensorflow::internal::LogMessageFatal("tensorflow/python/eager/pywrap_tfe_src.cc"
, 1274) << "Check failed: " "gradient_tensors[i] != nullptr"
" "
;
1275 CHECK(gradient_tensors[i] != Py_None)if ((__builtin_expect(!(gradient_tensors[i] != (&_Py_NoneStruct
)), 0))) ::tensorflow::internal::LogMessageFatal("tensorflow/python/eager/pywrap_tfe_src.cc"
, 1275) << "Check failed: " "gradient_tensors[i] != Py_None"
" "
;
1276 PyList_SET_ITEM(list, i,PyList_SetItem(list, i, reinterpret_cast<PyObject*>(gradient_tensors
[i]))
1277 reinterpret_cast<PyObject*>(gradient_tensors[i]))PyList_SetItem(list, i, reinterpret_cast<PyObject*>(gradient_tensors
[i]))
;
1278 }
1279 PyObject* arglist = Py_BuildValue("(O)", list);
1280 CHECK(arglist != nullptr)if ((__builtin_expect(!(arglist != nullptr), 0))) ::tensorflow
::internal::LogMessageFatal("tensorflow/python/eager/pywrap_tfe_src.cc"
, 1280) << "Check failed: " "arglist != nullptr" " "
;
1281 PyObject* result = PyEval_CallObject(aggregate_fn_, arglist)PyEval_CallObjectWithKeywords(aggregate_fn_, arglist, (PyObject
*)__null)
;
1282 Py_DECREF(arglist)_Py_DECREF(((PyObject*)(arglist)));
1283 Py_DECREF(list)_Py_DECREF(((PyObject*)(list)));
1284 return result;
1285 }
1286
1287 int64_t TensorId(PyObject* tensor) const final {
1288 return FastTensorId(tensor);
1289 }
1290
1291 void MarkAsResult(PyObject* gradient) const final { Py_INCREF(gradient)_Py_INCREF(((PyObject*)(gradient))); }
1292
1293 PyObject* Ones(PyObject* shape, PyObject* dtype) const {
1294 if (PyErr_Occurred()) {
1295 return nullptr;
1296 }
1297 PyObject* arg_list = Py_BuildValue("OO", shape, dtype);
1298 PyObject* result = PyEval_CallObject(ones_fn_, arg_list)PyEval_CallObjectWithKeywords(ones_fn_, arg_list, (PyObject *
)__null)
;
1299 Py_DECREF(arg_list)_Py_DECREF(((PyObject*)(arg_list)));
1300 return result;
1301 }
1302
1303 PyObject* OnesLike(PyObject* tensor) const {
1304 if (PyErr_Occurred()) {
1305 return nullptr;
1306 }
1307 return PyObject_CallFunctionObjArgs(ones_like_fn_, tensor, NULL__null);
1308 }
1309
1310 // Builds a tensor filled with ones with the same shape and dtype as `t`.
1311 Status BuildOnesLike(const PyTapeTensor& t,
1312 PyObject** result) const override {
1313 *result = t.OnesLike();
1314 return Status::OK();
1315 }
1316
1317 PyObject* Zeros(PyObject* shape, PyObject* dtype) const {
1318 if (PyErr_Occurred()) {
1319 return nullptr;
1320 }
1321 PyObject* arg_list = Py_BuildValue("OO", shape, dtype);
1322 PyObject* result = PyEval_CallObject(zeros_fn_, arg_list)PyEval_CallObjectWithKeywords(zeros_fn_, arg_list, (PyObject *
)__null)
;
1323 Py_DECREF(arg_list)_Py_DECREF(((PyObject*)(arg_list)));
1324 return result;
1325 }
1326
1327 PyObject* ZerosLike(PyObject* tensor) const {
1328 if (PyErr_Occurred()) {
1329 return nullptr;
1330 }
1331 return PyObject_CallFunctionObjArgs(zeros_like_fn_, tensor, NULL__null);
1332 }
1333
1334 PyObject* GraphShape(PyObject* tensor) const {
1335 PyObject* arg_list = Py_BuildValue("(O)", tensor);
1336 PyObject* result = PyEval_CallObject(graph_shape_fn_, arg_list)PyEval_CallObjectWithKeywords(graph_shape_fn_, arg_list, (PyObject
*)__null)
;
1337 Py_DECREF(arg_list)_Py_DECREF(((PyObject*)(arg_list)));
1338 return result;
1339 }
1340
1341 tensorflow::Status CallBackwardFunction(
1342 const string& op_type, PyBackwardFunction* backward_function,
1343 const std::vector<int64_t>& unneeded_gradients,
1344 tensorflow::gtl::ArraySlice<PyObject*> output_gradients,
1345 absl::Span<PyObject*> result) const final {
1346 PyObject* grads = PyTuple_New(output_gradients.size());
1347 for (int i = 0; i < output_gradients.size(); ++i) {
1348 if (output_gradients[i] == nullptr) {
1349 Py_INCREF(Py_None)_Py_INCREF(((PyObject*)((&_Py_NoneStruct))));
1350 PyTuple_SET_ITEM(grads, i, Py_None)PyTuple_SetItem(grads, i, (&_Py_NoneStruct));
1351 } else {
1352 PyTuple_SET_ITEM(grads, i,PyTuple_SetItem(grads, i, reinterpret_cast<PyObject*>(output_gradients
[i]))
1353 reinterpret_cast<PyObject*>(output_gradients[i]))PyTuple_SetItem(grads, i, reinterpret_cast<PyObject*>(output_gradients
[i]))
;
1354 }
1355 }
1356 PyObject* py_result = (*backward_function)(grads, unneeded_gradients);
1357 Py_DECREF(grads)_Py_DECREF(((PyObject*)(grads)));
1358 if (py_result == nullptr) {
1359 return tensorflow::errors::Internal("gradient function threw exceptions");
1360 }
1361 PyObject* seq =
1362 PySequence_Fast(py_result, "expected a sequence of gradients");
1363 if (seq == nullptr) {
1364 return tensorflow::errors::InvalidArgument(
1365 "gradient function did not return a list");
1366 }
1367 int len = PySequence_Fast_GET_SIZE(seq)(((((((PyObject*)(seq))->ob_type))->tp_flags & ((1UL
<< 25))) != 0) ? ((static_cast<void> (0)), (((PyVarObject
*)(seq))->ob_size)) : (((PyVarObject*)(((static_cast<void
> (0)), (PyTupleObject *)(seq))))->ob_size))
;
1368 if (len != result.size()) {
1369 return tensorflow::errors::Internal(
1370 "Recorded operation '", op_type,
1371 "' returned too few gradients. Expected ", result.size(),
1372 " but received ", len);
1373 }
1374 PyObject** seq_array = PySequence_Fast_ITEMS(seq)(((((((PyObject*)(seq))->ob_type))->tp_flags & ((1UL
<< 25))) != 0) ? ((PyListObject *)(seq))->ob_item :
((PyTupleObject *)(seq))->ob_item)
;
1375 VLOG(1)(__builtin_expect(!!(!(([](int level, const char* fname) { static
const bool vmodule_activated = ::tensorflow::internal::LogMessage
::VmoduleActivated(fname, level); return vmodule_activated; }
)(1, "tensorflow/python/eager/pywrap_tfe_src.cc"))), 1)) ? (void
)0 : ::tensorflow::internal::Voidifier() & ::tensorflow::
internal::LogMessage("tensorflow/python/eager/pywrap_tfe_src.cc"
, 1375, tensorflow::INFO)
<< "Gradient length is " << len;
1376 for (int i = 0; i < len; ++i) {
1377 PyObject* item = seq_array[i];
1378 if (item == Py_None(&_Py_NoneStruct)) {
1379 result[i] = nullptr;
1380 } else {
1381 Py_INCREF(item)_Py_INCREF(((PyObject*)(item)));
1382 result[i] = item;
1383 }
1384 }
1385 Py_DECREF(seq)_Py_DECREF(((PyObject*)(seq)));
1386 Py_DECREF(py_result)_Py_DECREF(((PyObject*)(py_result)));
1387 return tensorflow::Status::OK();
1388 }
1389
1390 void DeleteGradient(PyObject* tensor) const final { Py_XDECREF(tensor)_Py_XDECREF(((PyObject*)(tensor))); }
1391
1392 PyTapeTensor TapeTensorFromGradient(PyObject* tensor) const final {
1393 return TapeTensorFromTensor(tensor);
1394 }
1395
1396 private:
1397 PyObject* py_vspace_;
1398
1399 PyObject* num_elements_;
1400 PyObject* aggregate_fn_;
1401 PyObject* zeros_fn_;
1402 PyObject* zeros_like_fn_;
1403 PyObject* ones_fn_;
1404 PyObject* ones_like_fn_;
1405 PyObject* graph_shape_fn_;
1406};
1407PyVSpace* py_vspace = nullptr;
1408
1409bool HasAccumulator();
1410
1411PyObject* TFE_Py_RegisterVSpace(PyObject* e) {
1412 if (py_vspace != nullptr) {
1413 if (HasAccumulator()) {
1414 // Accumulators reference py_vspace, so we can't swap it out while one is
1415 // active. This is unlikely to ever happen.
1416 MaybeRaiseExceptionFromStatus(
1417 tensorflow::errors::Internal(
1418 "Can't change the vspace implementation while a "
1419 "forward accumulator is active."),
1420 nullptr);
1421 }
1422 delete py_vspace;
1423 }
1424
1425 py_vspace = new PyVSpace(e);
1426 auto status = py_vspace->Initialize();
1427 if (MaybeRaiseExceptionFromStatus(status, nullptr)) {
1428 delete py_vspace;
1429 return nullptr;
1430 }
1431
1432 Py_RETURN_NONEreturn _Py_INCREF(((PyObject*)((&_Py_NoneStruct)))), (&
_Py_NoneStruct)
;
1433}
1434
1435PyObject* PyTapeTensor::GetShape() const {
1436 if (shape_.index() == 0) {
1437 auto& shape = absl::get<0>(shape_);
1438 PyObject* py_shape = PyTuple_New(shape.dims());
1439 for (int i = 0; i < shape.dims(); ++i) {
1440 PyTuple_SET_ITEM(py_shape, i, PyLong_FromLong(shape.dim_size(i)))PyTuple_SetItem(py_shape, i, PyLong_FromLong(shape.dim_size(i
)))
;
1441 }
1442
1443 return py_shape;
1444 }
1445
1446 return py_vspace->GraphShape(absl::get<1>(shape_));
1447}
1448
1449PyObject* PyTapeTensor::OnesLike() const {
1450 if (shape_.index() == 1) {
1451 PyObject* tensor = absl::get<1>(shape_);
1452 return py_vspace->OnesLike(tensor);
1453 }
1454 PyObject* py_shape = GetShape();
1455 PyObject* dtype_field = GetPyDType();
1456 PyObject* result = py_vspace->Ones(py_shape, dtype_field);
1457 Py_DECREF(dtype_field)_Py_DECREF(((PyObject*)(dtype_field)));
1458 Py_DECREF(py_shape)_Py_DECREF(((PyObject*)(py_shape)));
1459 return result;
1460}
1461
1462PyObject* PyTapeTensor::ZerosLike() const {
1463 if (GetDType() == tensorflow::DT_RESOURCE) {
1464 // Gradient functions for ops which return resource tensors accept
1465 // None. This is the behavior of py_vspace->Zeros, but checking here avoids
1466 // issues with ZerosLike.
1467 Py_RETURN_NONEreturn _Py_INCREF(((PyObject*)((&_Py_NoneStruct)))), (&
_Py_NoneStruct)
;
1468 }
1469 if (shape_.index() == 1) {
1470 PyObject* tensor = absl::get<1>(shape_);
1471 return py_vspace->ZerosLike(tensor);
1472 }
1473 PyObject* py_shape = GetShape();
1474 PyObject* dtype_field = GetPyDType();
1475 PyObject* result = py_vspace->Zeros(py_shape, dtype_field);
1476 Py_DECREF(dtype_field)_Py_DECREF(((PyObject*)(dtype_field)));
1477 Py_DECREF(py_shape)_Py_DECREF(((PyObject*)(py_shape)));
1478 return result;
1479}
1480
1481// Keeps track of all variables that have been accessed during execution.
1482class VariableWatcher {
1483 public:
1484 VariableWatcher() {}
1485
1486 ~VariableWatcher() {
1487 for (const IdAndVariable& v : watched_variables_) {
1488 Py_DECREF(v.variable)_Py_DECREF(((PyObject*)(v.variable)));
1489 }
1490 }
1491
1492 int64_t WatchVariable(PyObject* v) {
1493 tensorflow::Safe_PyObjectPtr handle(PyObject_GetAttrString(v, "handle"));
1494 if (handle == nullptr) {
1495 return -1;
1496 }
1497 int64_t id = FastTensorId(handle.get());
1498
1499 tensorflow::mutex_lock l(watched_variables_mu_);
1500 auto insert_result = watched_variables_.emplace(id, v);
1501
1502 if (insert_result.second) {
1503 // Only increment the reference count if we aren't already watching this
1504 // variable.
1505 Py_INCREF(v)_Py_INCREF(((PyObject*)(v)));
1506 }
1507
1508 return id;
1509 }
1510
1511 PyObject* GetVariablesAsPyTuple() {
1512 tensorflow::mutex_lock l(watched_variables_mu_);
1513 PyObject* result = PyTuple_New(watched_variables_.size());
1514 Py_ssize_t pos = 0;
1515 for (const IdAndVariable& id_and_variable : watched_variables_) {
1516 PyTuple_SET_ITEM(result, pos++, id_and_variable.variable)PyTuple_SetItem(result, pos++, id_and_variable.variable);
1517 Py_INCREF(id_and_variable.variable)_Py_INCREF(((PyObject*)(id_and_variable.variable)));
1518 }
1519 return result;
1520 }
1521
1522 private:
1523 // We store an IdAndVariable in the map since the map needs to be locked
1524 // during insert, but should not call back into python during insert to avoid
1525 // deadlocking with the GIL.
1526 struct IdAndVariable {
1527 int64_t id;
1528 PyObject* variable;
1529
1530 IdAndVariable(int64_t id, PyObject* variable)
1531 : id(id), variable(variable) {}
1532 };
1533 struct CompareById {
1534 bool operator()(const IdAndVariable& lhs, const IdAndVariable& rhs) const {
1535 return lhs.id < rhs.id;
1536 }
1537 };
1538
1539 tensorflow::mutex watched_variables_mu_;
1540 std::set<IdAndVariable, CompareById> watched_variables_
1541 TF_GUARDED_BY(watched_variables_mu_)__attribute__((guarded_by(watched_variables_mu_)));
1542};
1543
1544class GradientTape
1545 : public tensorflow::eager::GradientTape<PyObject, PyBackwardFunction,
1546 PyTapeTensor> {
1547 public:
1548 explicit GradientTape(bool persistent, bool watch_accessed_variables)
1549 : tensorflow::eager::GradientTape<PyObject, PyBackwardFunction,
1550 PyTapeTensor>(persistent),
1551 watch_accessed_variables_(watch_accessed_variables) {}
1552
1553 virtual ~GradientTape() {}
1554
1555 void VariableAccessed(PyObject* v) {
1556 if (watch_accessed_variables_) {
1557 WatchVariable(v);
1558 }
1559 }
1560
1561 void WatchVariable(PyObject* v) {
1562 int64_t id = variable_watcher_.WatchVariable(v);
1563
1564 if (!PyErr_Occurred()) {
1565 this->Watch(id);
1566 }
1567 }
1568
1569 PyObject* GetVariablesAsPyTuple() {
1570 return variable_watcher_.GetVariablesAsPyTuple();
1571 }
1572
1573 private:
1574 bool watch_accessed_variables_;
1575 VariableWatcher variable_watcher_;
1576};
1577
1578typedef tensorflow::eager::ForwardAccumulator<PyObject, PyBackwardFunction,
1579 PyTapeTensor>
1580 ForwardAccumulator;
1581
1582// Incremented when a GradientTape or accumulator is newly added to a set, and
1583// used to enforce an ordering between them.
1584std::atomic_uint_fast64_t tape_nesting_id_counter(0);
1585
1586typedef struct {
1587 PyObject_HEADPyObject ob_base;
1588 /* Type-specific fields go here. */
1589 GradientTape* tape;
1590 // A nesting order between GradientTapes and ForwardAccumulators, used to
1591 // ensure that GradientTapes do not watch the products of outer
1592 // ForwardAccumulators.
1593 int64_t nesting_id;
1594} TFE_Py_Tape;
1595
1596static void TFE_Py_Tape_Delete(PyObject* tape) {
1597 delete reinterpret_cast<TFE_Py_Tape*>(tape)->tape;
1598 Py_TYPE(tape)(((PyObject*)(tape))->ob_type)->tp_free(tape);
1599}
1600
1601static PyTypeObject TFE_Py_Tape_Type = {
1602 PyVarObject_HEAD_INIT(nullptr, 0){ { 1, nullptr }, 0 }, "tfe.Tape", /* tp_name */
1603 sizeof(TFE_Py_Tape), /* tp_basicsize */
1604 0, /* tp_itemsize */
1605 &TFE_Py_Tape_Delete, /* tp_dealloc */
1606#if PY_VERSION_HEX((3 << 24) | (8 << 16) | (5 << 8) | (0xF <<
4) | (0 << 0))
< 0x03080000
1607 nullptr, /* tp_print */
1608#else
1609 0, /* tp_vectorcall_offset */
1610#endif
1611 nullptr, /* tp_getattr */
1612 nullptr, /* tp_setattr */
1613 nullptr, /* tp_reserved */
1614 nullptr, /* tp_repr */
1615 nullptr, /* tp_as_number */
1616 nullptr, /* tp_as_sequence */
1617 nullptr, /* tp_as_mapping */
1618 nullptr, /* tp_hash */
1619 nullptr, /* tp_call */
1620 nullptr, /* tp_str */
1621 nullptr, /* tp_getattro */
1622 nullptr, /* tp_setattro */
1623 nullptr, /* tp_as_buffer */
1624 Py_TPFLAGS_DEFAULT( 0 | (1UL << 18) | 0), /* tp_flags */
1625 "TFE_Py_Tape objects", /* tp_doc */
1626};
1627
1628typedef struct {
1629 PyObject_HEADPyObject ob_base;
1630 /* Type-specific fields go here. */
1631 ForwardAccumulator* accumulator;
1632 // A nesting order between GradientTapes and ForwardAccumulators, used to
1633 // ensure that GradientTapes do not watch the products of outer
1634 // ForwardAccumulators.
1635 int64_t nesting_id;
1636} TFE_Py_ForwardAccumulator;
1637
1638static void TFE_Py_ForwardAccumulatorDelete(PyObject* accumulator) {
1639 delete reinterpret_cast<TFE_Py_ForwardAccumulator*>(accumulator)->accumulator;
1640 Py_TYPE(accumulator)(((PyObject*)(accumulator))->ob_type)->tp_free(accumulator);
1641}
1642
1643static PyTypeObject TFE_Py_ForwardAccumulator_Type = {
1644 PyVarObject_HEAD_INIT(nullptr, 0){ { 1, nullptr }, 0 }, "ForwardAccumulator", /* tp_name */
1645 sizeof(TFE_Py_ForwardAccumulator), /* tp_basicsize */
1646 0, /* tp_itemsize */
1647 &TFE_Py_ForwardAccumulatorDelete, /* tp_dealloc */
1648#if PY_VERSION_HEX((3 << 24) | (8 << 16) | (5 << 8) | (0xF <<
4) | (0 << 0))
< 0x03080000
1649 nullptr, /* tp_print */
1650#else
1651 0, /* tp_vectorcall_offset */
1652#endif
1653 nullptr, /* tp_getattr */
1654 nullptr, /* tp_setattr */
1655 nullptr, /* tp_reserved */
1656 nullptr, /* tp_repr */
1657 nullptr, /* tp_as_number */
1658 nullptr, /* tp_as_sequence */
1659 nullptr, /* tp_as_mapping */
1660 nullptr, /* tp_hash */
1661 nullptr, /* tp_call */
1662 nullptr, /* tp_str */
1663 nullptr, /* tp_getattro */
1664 nullptr, /* tp_setattro */
1665 nullptr, /* tp_as_buffer */
1666 Py_TPFLAGS_DEFAULT( 0 | (1UL << 18) | 0), /* tp_flags */
1667 "TFE_Py_ForwardAccumulator objects", /* tp_doc */
1668};
1669
1670typedef struct {
1671 PyObject_HEADPyObject ob_base;
1672 /* Type-specific fields go here. */
1673 VariableWatcher* variable_watcher;
1674} TFE_Py_VariableWatcher;
1675
1676static void TFE_Py_VariableWatcher_Delete(PyObject* variable_watcher) {
1677 delete reinterpret_cast<TFE_Py_VariableWatcher*>(variable_watcher)
1678 ->variable_watcher;
1679 Py_TYPE(variable_watcher)(((PyObject*)(variable_watcher))->ob_type)->tp_free(variable_watcher);
1680}
1681
1682static PyTypeObject TFE_Py_VariableWatcher_Type = {
1683 PyVarObject_HEAD_INIT(nullptr, 0){ { 1, nullptr }, 0 }, "tfe.VariableWatcher", /* tp_name */
1684 sizeof(TFE_Py_VariableWatcher), /* tp_basicsize */
1685 0, /* tp_itemsize */
1686 &TFE_Py_VariableWatcher_Delete, /* tp_dealloc */
1687#if PY_VERSION_HEX((3 << 24) | (8 << 16) | (5 << 8) | (0xF <<
4) | (0 << 0))
< 0x03080000
1688 nullptr, /* tp_print */
1689#else
1690 0, /* tp_vectorcall_offset */
1691#endif
1692 nullptr, /* tp_getattr */
1693 nullptr, /* tp_setattr */
1694 nullptr, /* tp_reserved */
1695 nullptr, /* tp_repr */
1696 nullptr, /* tp_as_number */
1697 nullptr, /* tp_as_sequence */
1698 nullptr, /* tp_as_mapping */
1699 nullptr, /* tp_hash */
1700 nullptr, /* tp_call */
1701 nullptr, /* tp_str */
1702 nullptr, /* tp_getattro */
1703 nullptr, /* tp_setattro */
1704 nullptr, /* tp_as_buffer */
1705 Py_TPFLAGS_DEFAULT( 0 | (1UL << 18) | 0), /* tp_flags */
1706 "TFE_Py_VariableWatcher objects", /* tp_doc */
1707};
1708
1709// Note: in the current design no mutex is needed here because of the python
1710// GIL, which is always held when any TFE_Py_* methods are called. We should
1711// revisit this if/when decide to not hold the GIL while manipulating the tape
1712// stack.
1713tensorflow::gtl::CompactPointerSet<TFE_Py_Tape*>* GetTapeSet() {
1714 thread_local std::unique_ptr<tensorflow::gtl::CompactPointerSet<TFE_Py_Tape*>>
1715 tape_set = nullptr;
1716 if (tape_set == nullptr) {
1717 tape_set.reset(new tensorflow::gtl::CompactPointerSet<TFE_Py_Tape*>);
1718 }
1719 return tape_set.get();
1720}
1721
1722tensorflow::gtl::CompactPointerSet<TFE_Py_VariableWatcher*>*
1723GetVariableWatcherSet() {
1724 thread_local std::unique_ptr<
1725 tensorflow::gtl::CompactPointerSet<TFE_Py_VariableWatcher*>>
1726 variable_watcher_set = nullptr;
1727 if (variable_watcher_set == nullptr) {
1728 variable_watcher_set.reset(
1729 new tensorflow::gtl::CompactPointerSet<TFE_Py_VariableWatcher*>);
1730 }
1731 return variable_watcher_set.get();
1732}
1733
1734// A linked hash set, where iteration is in insertion order.
1735//
1736// Nested accumulators rely on op recording happening in insertion order, so an
1737// unordered data structure like CompactPointerSet is not suitable. Outer
1738// accumulators need to observe operations first so they know to watch the inner
1739// accumulator's jvp computation.
1740//
1741// Not thread safe.
1742class AccumulatorSet {
1743 public:
1744 // Returns true if `element` was newly inserted, false if it already exists.
1745 bool insert(TFE_Py_ForwardAccumulator* element) {
1746 if (map_.find(element) != map_.end()) {
1747 return false;
1748 }
1749 ListType::iterator it = ordered_.insert(ordered_.end(), element);
1750 map_.insert(std::make_pair(element, it));
1751 return true;
1752 }
1753
1754 void erase(TFE_Py_ForwardAccumulator* element) {
1755 MapType::iterator existing = map_.find(element);
1756 if (existing == map_.end()) {
1757 return;
1758 }
1759 ListType::iterator list_position = existing->second;
1760 map_.erase(existing);
1761 ordered_.erase(list_position);
1762 }
1763
1764 bool empty() const { return ordered_.empty(); }
1765
1766 size_t size() const { return ordered_.size(); }
1767
1768 private:
1769 typedef std::list<TFE_Py_ForwardAccumulator*> ListType;
1770 typedef tensorflow::gtl::FlatMap<TFE_Py_ForwardAccumulator*,
1771 ListType::iterator>
1772 MapType;
1773
1774 public:
1775 typedef ListType::const_iterator const_iterator;
1776 typedef ListType::const_reverse_iterator const_reverse_iterator;
1777
1778 const_iterator begin() const { return ordered_.begin(); }
1779 const_iterator end() const { return ordered_.end(); }
1780
1781 const_reverse_iterator rbegin() const { return ordered_.rbegin(); }
1782 const_reverse_iterator rend() const { return ordered_.rend(); }
1783
1784 private:
1785 MapType map_;
1786 ListType ordered_;
1787};
1788
1789AccumulatorSet* GetAccumulatorSet() {
1790 thread_local std::unique_ptr<AccumulatorSet> accumulator_set{nullptr};
1791 if (accumulator_set == nullptr) {
1792 accumulator_set.reset(new AccumulatorSet);
1793 }
1794 return accumulator_set.get();
1795}
1796
1797inline bool HasAccumulator() { return !GetAccumulatorSet()->empty(); }
1798
1799inline bool HasGradientTape() { return !GetTapeSet()->empty(); }
1800
1801inline bool HasAccumulatorOrTape() {
1802 return HasGradientTape() || HasAccumulator();
1803}
1804
// A safe copy of a set, used for tapes and accumulators. The copy is not
// affected by other python threads changing the set of active tapes.
//
// Holds one strong Python reference per member for its whole lifetime. It is
// non-copyable: an implicit copy would duplicate the pointers without taking
// references, and both destructors would then release them.
template <typename ContainerType>
class SafeSetCopy {
 public:
  explicit SafeSetCopy(const ContainerType& to_copy) : set_copy_(to_copy) {
    for (auto* member : set_copy_) {
      Py_INCREF(member);
    }
  }

  SafeSetCopy(const SafeSetCopy&) = delete;
  SafeSetCopy& operator=(const SafeSetCopy&) = delete;

  ~SafeSetCopy() {
    for (auto* member : set_copy_) {
      Py_DECREF(member);
    }
  }

  typename ContainerType::const_iterator begin() const {
    return set_copy_.begin();
  }

  typename ContainerType::const_iterator end() const { return set_copy_.end(); }

  bool empty() const { return set_copy_.empty(); }
  size_t size() const { return set_copy_.size(); }

 protected:
  ContainerType set_copy_;
};
1834
1835class SafeTapeSet
1836 : public SafeSetCopy<tensorflow::gtl::CompactPointerSet<TFE_Py_Tape*>> {
1837 public:
1838 SafeTapeSet()
1839 : SafeSetCopy<tensorflow::gtl::CompactPointerSet<TFE_Py_Tape*>>(
1840 *GetTapeSet()) {}
1841};
1842
1843class SafeAccumulatorSet : public SafeSetCopy<AccumulatorSet> {
1844 public:
1845 SafeAccumulatorSet() : SafeSetCopy<AccumulatorSet>(*GetAccumulatorSet()) {}
1846
1847 typename AccumulatorSet::const_reverse_iterator rbegin() const {
1848 return set_copy_.rbegin();
1849 }
1850
1851 typename AccumulatorSet::const_reverse_iterator rend() const {
1852 return set_copy_.rend();
1853 }
1854};
1855
1856class SafeVariableWatcherSet
1857 : public SafeSetCopy<
1858 tensorflow::gtl::CompactPointerSet<TFE_Py_VariableWatcher*>> {
1859 public:
1860 SafeVariableWatcherSet()
1861 : SafeSetCopy<
1862 tensorflow::gtl::CompactPointerSet<TFE_Py_VariableWatcher*>>(
1863 *GetVariableWatcherSet()) {}
1864};
1865
// Returns a pointer to this thread's "tape recording stopped" flag. The flag
// starts out false and the same address is returned on every call from a
// given thread.
bool* ThreadTapeIsStopped() {
  thread_local bool stopped_on_this_thread = false;
  return &stopped_on_this_thread;
}
1870
1871void TFE_Py_TapeSetStopOnThread() { *ThreadTapeIsStopped() = true; }
1872
1873void TFE_Py_TapeSetRestartOnThread() { *ThreadTapeIsStopped() = false; }
1874
1875PyObject* TFE_Py_TapeSetIsStopped() {
1876 if (*ThreadTapeIsStopped()) {
1877 Py_RETURN_TRUEreturn _Py_INCREF(((PyObject*)(((PyObject *) &_Py_TrueStruct
)))), ((PyObject *) &_Py_TrueStruct)
;
1878 }
1879 Py_RETURN_FALSEreturn _Py_INCREF(((PyObject*)(((PyObject *) &_Py_FalseStruct
)))), ((PyObject *) &_Py_FalseStruct)
;
1880}
1881
1882PyObject* TFE_Py_TapeSetNew(PyObject* persistent,
1883 PyObject* watch_accessed_variables) {
1884 TFE_Py_Tape_Type.tp_new = PyType_GenericNew;
1885 if (PyType_Ready(&TFE_Py_Tape_Type) < 0) return nullptr;
1886 TFE_Py_Tape* tape = PyObject_NEW(TFE_Py_Tape, &TFE_Py_Tape_Type)( (TFE_Py_Tape *) PyObject_Init( (PyObject *) PyObject_Malloc
( ( (&TFE_Py_Tape_Type)->tp_basicsize ) ), (&TFE_Py_Tape_Type
)) )
;
1887 tape->tape = new GradientTape(persistent == Py_True((PyObject *) &_Py_TrueStruct),
1888 watch_accessed_variables == Py_True((PyObject *) &_Py_TrueStruct));
1889 Py_INCREF(tape)_Py_INCREF(((PyObject*)(tape)));
1890 tape->nesting_id = tape_nesting_id_counter.fetch_add(1);
1891 GetTapeSet()->insert(tape);
1892 return reinterpret_cast<PyObject*>(tape);
1893}
1894
1895void TFE_Py_TapeSetAdd(PyObject* tape) {
1896 Py_INCREF(tape)_Py_INCREF(((PyObject*)(tape)));
1897 TFE_Py_Tape* tfe_tape = reinterpret_cast<TFE_Py_Tape*>(tape);
1898 if (!GetTapeSet()->insert(tfe_tape).second) {
1899 // Already exists in the tape set.
1900 Py_DECREF(tape)_Py_DECREF(((PyObject*)(tape)));
1901 } else {
1902 tfe_tape->nesting_id = tape_nesting_id_counter.fetch_add(1);
1903 }
1904}
1905
1906PyObject* TFE_Py_TapeSetIsEmpty() {
1907 if (*ThreadTapeIsStopped() || !HasAccumulatorOrTape()) {
1908 Py_RETURN_TRUEreturn _Py_INCREF(((PyObject*)(((PyObject *) &_Py_TrueStruct
)))), ((PyObject *) &_Py_TrueStruct)
;
1909 }
1910 Py_RETURN_FALSEreturn _Py_INCREF(((PyObject*)(((PyObject *) &_Py_FalseStruct
)))), ((PyObject *) &_Py_FalseStruct)
;
1911}
1912
1913void TFE_Py_TapeSetRemove(PyObject* tape) {
1914 auto* stack = GetTapeSet();
1915 stack->erase(reinterpret_cast<TFE_Py_Tape*>(tape));
1916 // We kept a reference to the tape in the set to ensure it wouldn't get
1917 // deleted under us; cleaning it up here.
1918 Py_DECREF(tape)_Py_DECREF(((PyObject*)(tape)));
1919}
1920
1921static std::vector<int64_t> MakeIntList(PyObject* list) {
1922 if (list == Py_None(&_Py_NoneStruct)) {
1923 return {};
1924 }
1925 PyObject* seq = PySequence_Fast(list, "expected a sequence");
1926 if (seq == nullptr) {
1927 return {};
1928 }
1929 int len = PySequence_Size(list);
1930 PyObject** seq_array = PySequence_Fast_ITEMS(seq)(((((((PyObject*)(seq))->ob_type))->tp_flags & ((1UL
<< 25))) != 0) ? ((PyListObject *)(seq))->ob_item :
((PyTupleObject *)(seq))->ob_item)
;
1931 std::vector<int64_t> tensor_ids;
1932 tensor_ids.reserve(len);
1933 for (int i = 0; i < len; ++i) {
1934 PyObject* item = seq_array[i];
1935#if PY_MAJOR_VERSION3 >= 3
1936 if (PyLong_Check(item)((((((PyObject*)(item))->ob_type))->tp_flags & ((1UL
<< 24))) != 0)
) {
1937#else
1938 if (PyLong_Check(item)((((((PyObject*)(item))->ob_type))->tp_flags & ((1UL
<< 24))) != 0)
|| PyInt_Check(item)) {
1939#endif
1940 int64_t id = MakeInt(item);
1941 tensor_ids.push_back(id);
1942 } else {
1943 tensor_ids.push_back(-1);
1944 }
1945 }
1946 Py_DECREF(seq)_Py_DECREF(((PyObject*)(seq)));
1947 return tensor_ids;
1948}
1949
1950// Fill `tensor_ids` and `dtypes` from `tensors`, none of which may be
1951// null. Returns true on success and false on a Python exception.
1952bool TensorShapesAndDtypes(PyObject* tensors, std::vector<int64_t>* tensor_ids,
1953 std::vector<tensorflow::DataType>* dtypes) {
1954 tensorflow::Safe_PyObjectPtr seq(
1955 PySequence_Fast(tensors, "expected a sequence"));
1956 if (seq == nullptr) {
1957 return false;
1958 }
1959 int len = PySequence_Fast_GET_SIZE(seq.get())(((((((PyObject*)(seq.get()))->ob_type))->tp_flags &
((1UL << 25))) != 0) ? ((static_cast<void> (0)),
(((PyVarObject*)(seq.get()))->ob_size)) : (((PyVarObject*
)(((static_cast<void> (0)), (PyTupleObject *)(seq.get()
))))->ob_size))
;
1960 PyObject** seq_array = PySequence_Fast_ITEMS(seq.get())(((((((PyObject*)(seq.get()))->ob_type))->tp_flags &
((1UL << 25))) != 0) ? ((PyListObject *)(seq.get()))->
ob_item : ((PyTupleObject *)(seq.get()))->ob_item)
;
1961 tensor_ids->reserve(len);
1962 dtypes->reserve(len);
1963 for (int i = 0; i < len; ++i) {
1964 PyObject* item = seq_array[i];
1965 tensor_ids->push_back(FastTensorId(item));
1966 dtypes->push_back(tensorflow::PyTensor_DataType(item));
1967 }
1968 return true;
1969}
1970
1971bool TapeCouldPossiblyRecord(PyObject* tensors) {
1972 if (tensors == Py_None(&_Py_NoneStruct)) {
1973 return false;
1974 }
1975 if (*ThreadTapeIsStopped()) {
1976 return false;
1977 }
1978 if (!HasAccumulatorOrTape()) {
1979 return false;
1980 }
1981 return true;
1982}
1983
1984bool CouldBackprop() { return !*ThreadTapeIsStopped() && HasGradientTape(); }
1985
1986bool CouldForwardprop() { return !*ThreadTapeIsStopped() && HasAccumulator(); }
1987
1988PyObject* TFE_Py_TapeSetShouldRecordBackprop(PyObject* tensors) {
1989 if (!TapeCouldPossiblyRecord(tensors) || !CouldBackprop()) {
1990 Py_RETURN_FALSEreturn _Py_INCREF(((PyObject*)(((PyObject *) &_Py_FalseStruct
)))), ((PyObject *) &_Py_FalseStruct)
;
1991 }
1992 // TODO(apassos) consider not building a list and changing the API to check
1993 // each tensor individually.
1994 std::vector<int64_t> tensor_ids;
1995 std::vector<tensorflow::DataType> dtypes;
1996 if (!TensorShapesAndDtypes(tensors, &tensor_ids, &dtypes)) {
1997 return nullptr;
1998 }
1999 auto tape_set = *GetTapeSet();
2000 for (TFE_Py_Tape* tape : tape_set) {
2001 if (tape->tape->ShouldRecord(tensor_ids, dtypes)) {
2002 Py_RETURN_TRUEreturn _Py_INCREF(((PyObject*)(((PyObject *) &_Py_TrueStruct
)))), ((PyObject *) &_Py_TrueStruct)
;
2003 }
2004 }
2005
2006 Py_RETURN_FALSEreturn _Py_INCREF(((PyObject*)(((PyObject *) &_Py_FalseStruct
)))), ((PyObject *) &_Py_FalseStruct)
;
2007}
2008
2009PyObject* TFE_Py_ForwardAccumulatorPushState() {
2010 auto forward_accumulators = *GetAccumulatorSet();
2011 for (TFE_Py_ForwardAccumulator* accumulator : forward_accumulators) {
2012 accumulator->accumulator->PushState();
2013 }
2014 Py_RETURN_NONEreturn _Py_INCREF(((PyObject*)((&_Py_NoneStruct)))), (&
_Py_NoneStruct)
;
2015}
2016
2017PyObject* TFE_Py_ForwardAccumulatorPopState() {
2018 auto forward_accumulators = *GetAccumulatorSet();
2019 for (TFE_Py_ForwardAccumulator* accumulator : forward_accumulators) {
2020 accumulator->accumulator->PopState();
2021 }
2022 Py_RETURN_NONEreturn _Py_INCREF(((PyObject*)((&_Py_NoneStruct)))), (&
_Py_NoneStruct)
;
2023}
2024
2025PyObject* TFE_Py_TapeSetPossibleGradientTypes(PyObject* tensors) {
2026 if (!TapeCouldPossiblyRecord(tensors)) {
2027 return GetPythonObjectFromInt(0);
2028 }
2029 std::vector<int64_t> tensor_ids;
2030 std::vector<tensorflow::DataType> dtypes;
2031 if (!TensorShapesAndDtypes(tensors, &tensor_ids, &dtypes)) {
2032 return nullptr;
2033 }
2034
2035 // If there is a persistent tape watching, or if there are multiple tapes
2036 // watching, we'll return immediately indicating that higher-order tape
2037 // gradients are possible.
2038 bool some_tape_watching = false;
2039 if (CouldBackprop()) {
2040 auto tape_set = *GetTapeSet();
2041 for (TFE_Py_Tape* tape : tape_set) {
2042 if (tape->tape->ShouldRecord(tensor_ids, dtypes)) {
2043 if (tape->tape->IsPersistent() || some_tape_watching) {
2044 // Either this is the second tape watching, or this tape is
2045 // persistent: higher-order gradients are possible.
2046 return GetPythonObjectFromInt(2);
2047 }
2048 some_tape_watching = true;
2049 }
2050 }
2051 }
2052 if (CouldForwardprop()) {
2053 auto forward_accumulators = *GetAccumulatorSet();
2054 for (TFE_Py_ForwardAccumulator* accumulator : forward_accumulators) {
2055 if (accumulator->accumulator->ShouldRecord(tensor_ids, dtypes)) {
2056 if (some_tape_watching) {
2057 // This is the second tape watching: higher-order gradients are
2058 // possible. Note that there's no equivalent of persistence for
2059 // forward-mode.
2060 return GetPythonObjectFromInt(2);
2061 }
2062 some_tape_watching = true;
2063 }
2064 }
2065 }
2066 if (some_tape_watching) {
2067 // There's exactly one non-persistent tape. The user can request first-order
2068 // gradients but won't be able to get higher-order tape gradients.
2069 return GetPythonObjectFromInt(1);
2070 } else {
2071 // There are no tapes. The user can't request tape gradients.
2072 return GetPythonObjectFromInt(0);
2073 }
2074}
2075
2076void TFE_Py_TapeWatch(PyObject* tape, PyObject* tensor) {
2077 if (!CouldBackprop()) {
2078 return;
2079 }
2080 int64_t tensor_id = FastTensorId(tensor);
2081 if (PyErr_Occurred()) {
2082 return;
2083 }
2084 reinterpret_cast<TFE_Py_Tape*>(tape)->tape->Watch(tensor_id);
2085}
2086
2087bool ListContainsNone(PyObject* list) {
2088 if (list == Py_None(&_Py_NoneStruct)) return true;
2089 tensorflow::Safe_PyObjectPtr seq(
2090 PySequence_Fast(list, "expected a sequence"));
2091 if (seq == nullptr) {
2092 return false;
2093 }
2094
2095 int len = PySequence_Size(list);
2096 PyObject** seq_array = PySequence_Fast_ITEMS(seq.get())(((((((PyObject*)(seq.get()))->ob_type))->tp_flags &
((1UL << 25))) != 0) ? ((PyListObject *)(seq.get()))->
ob_item : ((PyTupleObject *)(seq.get()))->ob_item)
;
2097 for (int i = 0; i < len; ++i) {
2098 PyObject* item = seq_array[i];
2099 if (item == Py_None(&_Py_NoneStruct)) return true;
2100 }
2101
2102 return false;
2103}
2104
2105// As an optimization, the tape generally keeps only the shape and dtype of
2106// tensors, and uses this information to generate ones/zeros tensors. However,
2107// some tensors require OnesLike/ZerosLike because their gradients do not match
2108// their inference shape/dtype.
2109bool DTypeNeedsHandleData(tensorflow::DataType dtype) {
2110 return dtype == tensorflow::DT_VARIANT || dtype == tensorflow::DT_RESOURCE;
2111}
2112
2113static PyTapeTensor TapeTensorFromTensor(PyObject* tensor) {
2114 if (EagerTensor_CheckExact(tensor)) {
2115 tensorflow::ImmediateExecutionTensorHandle* handle =
2116 tensorflow::unwrap(EagerTensor_Handle(tensor));
2117 int64_t id = PyEagerTensor_ID(tensor);
2118 tensorflow::DataType dtype =
2119 static_cast<tensorflow::DataType>(handle->DataType());
2120 if (DTypeNeedsHandleData(dtype)) {
2121 return PyTapeTensor(id, dtype, tensor);
2122 }
2123
2124 tensorflow::TensorShape tensor_shape;
2125 int num_dims;
2126 tensorflow::Status status = handle->NumDims(&num_dims);
2127 if (status.ok()) {
2128 for (int i = 0; i < num_dims; ++i) {
2129 int64_t dim_size;
2130 status = handle->Dim(i, &dim_size);
2131 if (!status.ok()) break;
2132 tensor_shape.AddDim(dim_size);
2133 }
2134 }
2135
2136 if (MaybeRaiseExceptionFromStatus(status, nullptr)) {
2137 return PyTapeTensor(id, static_cast<tensorflow::DataType>(0),
2138 tensorflow::TensorShape({}));
2139 } else {
2140 return PyTapeTensor(id, dtype, tensor_shape);
2141 }
2142 }
2143 int64_t id = FastTensorId(tensor);
2144 if (PyErr_Occurred()) {
2145 return PyTapeTensor(id, static_cast<tensorflow::DataType>(0),
2146 tensorflow::TensorShape({}));
2147 }
2148 PyObject* dtype_object = PyObject_GetAttrString(tensor, "dtype");
2149 PyObject* dtype_enum = PyObject_GetAttrString(dtype_object, "_type_enum");
2150 Py_DECREF(dtype_object)_Py_DECREF(((PyObject*)(dtype_object)));
2151 tensorflow::DataType dtype =
2152 static_cast<tensorflow::DataType>(MakeInt(dtype_enum));
2153 Py_DECREF(dtype_enum)_Py_DECREF(((PyObject*)(dtype_enum)));
2154 if (PyErr_Occurred()) {
2155 return PyTapeTensor(id, static_cast<tensorflow::DataType>(0),
2156 tensorflow::TensorShape({}));
2157 }
2158 static char _shape_tuple[] = "_shape_tuple";
2159 tensorflow::Safe_PyObjectPtr shape_tuple(
2160 PyObject_CallMethod(tensor, _shape_tuple, nullptr));
2161 if (PyErr_Occurred()) {
2162 return PyTapeTensor(id, static_cast<tensorflow::DataType>(0),
2163 tensorflow::TensorShape({}));
2164 }
2165
2166 if (ListContainsNone(shape_tuple.get()) || DTypeNeedsHandleData(dtype)) {
2167 return PyTapeTensor(id, dtype, tensor);
2168 }
2169
2170 auto l = MakeIntList(shape_tuple.get());
2171 // Replace -1, which represents accidental Nones which can occur in graph mode
2172 // and can cause errors in shape construction with 0s.
2173 for (auto& c : l) {
2174 if (c < 0) {
2175 c = 0;
2176 }
2177 }
2178 tensorflow::TensorShape shape(l);
2179 return PyTapeTensor(id, dtype, shape);
2180}
2181
2182// Populates output_info from output_seq, which must come from PySequence_Fast.
2183//
2184// Does not take ownership of output_seq. Returns true on success and false if a
2185// Python exception has been set.
2186bool TapeTensorsFromTensorSequence(PyObject* output_seq,
2187 std::vector<PyTapeTensor>* output_info) {
2188 Py_ssize_t output_len = PySequence_Fast_GET_SIZE(output_seq)(((((((PyObject*)(output_seq))->ob_type))->tp_flags &
((1UL << 25))) != 0) ? ((static_cast<void> (0)),
(((PyVarObject*)(output_seq))->ob_size)) : (((PyVarObject
*)(((static_cast<void> (0)), (PyTupleObject *)(output_seq
))))->ob_size))
;
2189 PyObject** output_seq_array = PySequence_Fast_ITEMS(output_seq)(((((((PyObject*)(output_seq))->ob_type))->tp_flags &
((1UL << 25))) != 0) ? ((PyListObject *)(output_seq))->
ob_item : ((PyTupleObject *)(output_seq))->ob_item)
;
2190 output_info->reserve(output_len);
2191 for (Py_ssize_t i = 0; i < output_len; ++i) {
2192 output_info->push_back(TapeTensorFromTensor(output_seq_array[i]));
2193 if (PyErr_Occurred() != nullptr) {
2194 return false;
2195 }
2196 }
2197 return true;
2198}
2199
2200std::vector<int64_t> MakeTensorIDList(PyObject* tensors) {
2201 PyObject* seq = PySequence_Fast(tensors, "expected a sequence");
2202 if (seq == nullptr) {
2203 return {};
2204 }
2205 int len = PySequence_Fast_GET_SIZE(seq)(((((((PyObject*)(seq))->ob_type))->tp_flags & ((1UL
<< 25))) != 0) ? ((static_cast<void> (0)), (((PyVarObject
*)(seq))->ob_size)) : (((PyVarObject*)(((static_cast<void
> (0)), (PyTupleObject *)(seq))))->ob_size))
;
2206 PyObject** seq_array = PySequence_Fast_ITEMS(seq)(((((((PyObject*)(seq))->ob_type))->tp_flags & ((1UL
<< 25))) != 0) ? ((PyListObject *)(seq))->ob_item :
((PyTupleObject *)(seq))->ob_item)
;
2207 std::vector<int64_t> list;
2208 list.reserve(len);
2209 for (int i = 0; i < len; ++i) {
2210 PyObject* tensor = seq_array[i];
2211 list.push_back(FastTensorId(tensor));
2212 if (PyErr_Occurred()) {
2213 Py_DECREF(seq)_Py_DECREF(((PyObject*)(seq)));
2214 return list;
2215 }
2216 }
2217 Py_DECREF(seq)_Py_DECREF(((PyObject*)(seq)));
2218 return list;
2219}
2220
2221void TFE_Py_TapeVariableAccessed(PyObject* variable) {
2222 if (!CouldBackprop()) {
2223 return;
2224 }
2225 for (TFE_Py_Tape* tape : SafeTapeSet()) {
2226 tape->tape->VariableAccessed(variable);
2227 }
2228}
2229
2230void TFE_Py_TapeWatchVariable(PyObject* tape, PyObject* variable) {
2231 if (!CouldBackprop()) {
2232 return;
2233 }
2234 reinterpret_cast<TFE_Py_Tape*>(tape)->tape->WatchVariable(variable);
2235}
2236
2237PyObject* TFE_Py_TapeWatchedVariables(PyObject* tape) {
2238 return reinterpret_cast<TFE_Py_Tape*>(tape)->tape->GetVariablesAsPyTuple();
2239}
2240
2241PyObject* TFE_Py_VariableWatcherNew() {
2242 TFE_Py_VariableWatcher_Type.tp_new = PyType_GenericNew;
2243 if (PyType_Ready(&TFE_Py_VariableWatcher_Type) < 0) return nullptr;
2244 TFE_Py_VariableWatcher* variable_watcher =
2245 PyObject_NEW(TFE_Py_VariableWatcher, &TFE_Py_VariableWatcher_Type)( (TFE_Py_VariableWatcher *) PyObject_Init( (PyObject *) PyObject_Malloc
( ( (&TFE_Py_VariableWatcher_Type)->tp_basicsize ) ), (
&TFE_Py_VariableWatcher_Type)) )
;
2246 variable_watcher->variable_watcher = new VariableWatcher();
2247 Py_INCREF(variable_watcher)_Py_INCREF(((PyObject*)(variable_watcher)));
2248 GetVariableWatcherSet()->insert(variable_watcher);
2249 return reinterpret_cast<PyObject*>(variable_watcher);
2250}
2251
2252void TFE_Py_VariableWatcherRemove(PyObject* variable_watcher) {
2253 auto* stack = GetVariableWatcherSet();
2254 stack->erase(reinterpret_cast<TFE_Py_VariableWatcher*>(variable_watcher));
2255 // We kept a reference to the variable watcher in the set to ensure it
2256 // wouldn't get deleted under us; cleaning it up here.
2257 Py_DECREF(variable_watcher)_Py_DECREF(((PyObject*)(variable_watcher)));
2258}
2259
2260void TFE_Py_VariableWatcherVariableAccessed(PyObject* variable) {
2261 for (TFE_Py_VariableWatcher* variable_watcher : SafeVariableWatcherSet()) {
2262 variable_watcher->variable_watcher->WatchVariable(variable);
2263 }
2264}
2265
2266PyObject* TFE_Py_VariableWatcherWatchedVariables(PyObject* variable_watcher) {
2267 return reinterpret_cast<TFE_Py_VariableWatcher*>(variable_watcher)
2268 ->variable_watcher->GetVariablesAsPyTuple();
2269}
2270
2271namespace {
2272std::vector<tensorflow::DataType> MakeTensorDtypeList(PyObject* tensors) {
2273 PyObject* seq = PySequence_Fast(tensors, "expected a sequence");
2274 if (seq == nullptr) {
2275 return {};
2276 }
2277 int len = PySequence_Fast_GET_SIZE(seq)(((((((PyObject*)(seq))->ob_type))->tp_flags & ((1UL
<< 25))) != 0) ? ((static_cast<void> (0)), (((PyVarObject
*)(seq))->ob_size)) : (((PyVarObject*)(((static_cast<void
> (0)), (PyTupleObject *)(seq))))->ob_size))
;
2278 PyObject** seq_array = PySequence_Fast_ITEMS(seq)(((((((PyObject*)(seq))->ob_type))->tp_flags & ((1UL
<< 25))) != 0) ? ((PyListObject *)(seq))->ob_item :
((PyTupleObject *)(seq))->ob_item)
;
2279 std::vector<tensorflow::DataType> list;
2280 list.reserve(len);
2281 for (int i = 0; i < len; ++i) {
2282 PyObject* tensor = seq_array[i];
2283 list.push_back(tensorflow::PyTensor_DataType(tensor));
2284 }
2285 Py_DECREF(seq)_Py_DECREF(((PyObject*)(seq)));
2286 return list;
2287}
2288
2289PyObject* ForwardAccumulatorDeleteGradient(PyObject* tensor_id,
2290 PyObject* weak_tensor_ref) {
2291 int64_t parsed_tensor_id = MakeInt(tensor_id);
2292 for (TFE_Py_ForwardAccumulator* accumulator : *GetAccumulatorSet()) {
2293 accumulator->accumulator->DeleteGradient(parsed_tensor_id);
2294 }
2295 Py_DECREF(weak_tensor_ref)_Py_DECREF(((PyObject*)(weak_tensor_ref)));
2296 Py_DECREF(tensor_id)_Py_DECREF(((PyObject*)(tensor_id)));
2297 Py_INCREF(Py_None)_Py_INCREF(((PyObject*)((&_Py_NoneStruct))));
2298 return Py_None(&_Py_NoneStruct);
2299}
2300
2301static PyMethodDef forward_accumulator_delete_gradient_method_def = {
2302 "ForwardAccumulatorDeleteGradient", ForwardAccumulatorDeleteGradient,
2303 METH_O0x0008, "ForwardAccumulatorDeleteGradient"};
2304
2305void RegisterForwardAccumulatorCleanup(PyObject* tensor, int64_t tensor_id) {
2306 tensorflow::Safe_PyObjectPtr callback(
2307 PyCFunction_New(&forward_accumulator_delete_gradient_method_def,PyCFunction_NewEx((&forward_accumulator_delete_gradient_method_def
), (PyLong_FromLong(tensor_id)), __null)
2308 PyLong_FromLong(tensor_id))PyCFunction_NewEx((&forward_accumulator_delete_gradient_method_def
), (PyLong_FromLong(tensor_id)), __null)
);
2309 // We need to keep a reference to the weakref active if we want our callback
2310 // called. The callback itself now owns the weakref object and the tensor ID
2311 // object.
2312 PyWeakref_NewRef(tensor, callback.get());
2313}
2314
2315void TapeSetRecordBackprop(
2316 const string& op_type, const std::vector<PyTapeTensor>& output_info,
2317 const std::vector<int64_t>& input_ids,
2318 const std::vector<tensorflow::DataType>& input_dtypes,
2319 const std::function<PyBackwardFunction*()>& backward_function_getter,
2320 const std::function<void(PyBackwardFunction*)>& backward_function_killer,
2321 tensorflow::uint64 max_gradient_tape_id) {
2322 if (!CouldBackprop()) {
2323 return;
2324 }
2325 for (TFE_Py_Tape* tape : SafeTapeSet()) {
2326 if (tape->nesting_id < max_gradient_tape_id) {
2327 tape->tape->RecordOperation(op_type, output_info, input_ids, input_dtypes,
2328 backward_function_getter,
2329 backward_function_killer);
2330 }
2331 }
2332}
2333
2334bool TapeSetRecordForwardprop(
2335 const string& op_type, PyObject* output_seq,
2336 const std::vector<PyTapeTensor>& output_info, PyObject* input_tensors,
2337 const std::vector<int64_t>& input_ids,
2338 const std::vector<tensorflow::DataType>& input_dtypes,
2339 const std::function<PyBackwardFunction*()>& backward_function_getter,
2340 const std::function<void(PyBackwardFunction*)>& backward_function_killer,
2341 const tensorflow::eager::ForwardFunction<PyObject>* forward_function,
2342 PyObject* forwardprop_output_indices,
2343 tensorflow::uint64* max_gradient_tape_id) {
2344 *max_gradient_tape_id = std::numeric_limits<tensorflow::uint64>::max();
2345 if (!CouldForwardprop()) {
2346 return true;
2347 }
2348 auto accumulator_set = SafeAccumulatorSet();
2349 tensorflow::Safe_PyObjectPtr input_seq(
2350 PySequence_Fast(input_tensors, "expected a sequence of tensors"));
2351 if (input_seq == nullptr || PyErr_Occurred()) return false;
2352 Py_ssize_t input_len = PySequence_Fast_GET_SIZE(input_seq.get())(((((((PyObject*)(input_seq.get()))->ob_type))->tp_flags
& ((1UL << 25))) != 0) ? ((static_cast<void>
(0)), (((PyVarObject*)(input_seq.get()))->ob_size)) : (((
PyVarObject*)(((static_cast<void> (0)), (PyTupleObject *
)(input_seq.get()))))->ob_size))
;
2353 PyObject** output_seq_array = PySequence_Fast_ITEMS(output_seq)(((((((PyObject*)(output_seq))->ob_type))->tp_flags &
((1UL << 25))) != 0) ? ((PyListObject *)(output_seq))->
ob_item : ((PyTupleObject *)(output_seq))->ob_item)
;
2354 for (int i = 0; i < output_info.size(); ++i) {
2355 RegisterForwardAccumulatorCleanup(output_seq_array[i],
2356 output_info[i].GetID());
2357 }
2358 if (forwardprop_output_indices != nullptr &&
2359 forwardprop_output_indices != Py_None(&_Py_NoneStruct)) {
2360 tensorflow::Safe_PyObjectPtr indices_fast(PySequence_Fast(
2361 forwardprop_output_indices, "Expected a sequence of indices"));
2362 if (indices_fast == nullptr || PyErr_Occurred()) {
2363 return false;
2364 }
2365 if (PySequence_Fast_GET_SIZE(indices_fast.get())(((((((PyObject*)(indices_fast.get()))->ob_type))->tp_flags
& ((1UL << 25))) != 0) ? ((static_cast<void>
(0)), (((PyVarObject*)(indices_fast.get()))->ob_size)) : (
((PyVarObject*)(((static_cast<void> (0)), (PyTupleObject
*)(indices_fast.get()))))->ob_size))
!=
2366 accumulator_set.size()) {
2367 MaybeRaiseExceptionFromStatus(
2368 tensorflow::errors::Internal(
2369 "Accumulators were added or removed from the active set "
2370 "between packing and unpacking."),
2371 nullptr);
2372 }
2373 PyObject** indices_fast_array = PySequence_Fast_ITEMS(indices_fast.get())(((((((PyObject*)(indices_fast.get()))->ob_type))->tp_flags
& ((1UL << 25))) != 0) ? ((PyListObject *)(indices_fast
.get()))->ob_item : ((PyTupleObject *)(indices_fast.get())
)->ob_item)
;
2374 Py_ssize_t accumulator_index = 0;
2375 for (AccumulatorSet::const_reverse_iterator it = accumulator_set.rbegin();
2376 it != accumulator_set.rend(); ++it, ++accumulator_index) {
2377 tensorflow::Safe_PyObjectPtr jvp_index_seq(
2378 PySequence_Fast(indices_fast_array[accumulator_index],
2379 "Expected a sequence of jvp indices."));
2380 if (jvp_index_seq == nullptr || PyErr_Occurred()) {
2381 return false;
2382 }
2383 Py_ssize_t num_jvps = PySequence_Fast_GET_SIZE(jvp_index_seq.get())(((((((PyObject*)(jvp_index_seq.get()))->ob_type))->tp_flags
& ((1UL << 25))) != 0) ? ((static_cast<void>
(0)), (((PyVarObject*)(jvp_index_seq.get()))->ob_size)) :
(((PyVarObject*)(((static_cast<void> (0)), (PyTupleObject
*)(jvp_index_seq.get()))))->ob_size))
;
2384 PyObject** jvp_index_seq_array =
2385 PySequence_Fast_ITEMS(jvp_index_seq.get())(((((((PyObject*)(jvp_index_seq.get()))->ob_type))->tp_flags
& ((1UL << 25))) != 0) ? ((PyListObject *)(jvp_index_seq
.get()))->ob_item : ((PyTupleObject *)(jvp_index_seq.get()
))->ob_item)
;
2386 for (Py_ssize_t jvp_index = 0; jvp_index < num_jvps; ++jvp_index) {
2387 PyObject* tuple = jvp_index_seq_array[jvp_index];
2388 int64_t primal_tensor_id =
2389 output_info[MakeInt(PyTuple_GetItem(tuple, 0))].GetID();
2390 (*it)->accumulator->Watch(
2391 primal_tensor_id,
2392 output_seq_array[MakeInt(PyTuple_GetItem(tuple, 1))]);
2393 }
2394 }
2395 } else {
2396 std::vector<PyTapeTensor> input_info;
2397 input_info.reserve(input_len);
2398 PyObject** input_seq_array = PySequence_Fast_ITEMS(input_seq.get())(((((((PyObject*)(input_seq.get()))->ob_type))->tp_flags
& ((1UL << 25))) != 0) ? ((PyListObject *)(input_seq
.get()))->ob_item : ((PyTupleObject *)(input_seq.get()))->
ob_item)
;
2399 for (Py_ssize_t i = 0; i < input_len; ++i) {
2400 input_info.push_back(TapeTensorFromTensor(input_seq_array[i]));
2401 }
2402 for (TFE_Py_ForwardAccumulator* accumulator : accumulator_set) {
2403 tensorflow::Status status = accumulator->accumulator->Accumulate(
2404 op_type, input_info, output_info, input_ids, input_dtypes,
2405 forward_function, backward_function_getter, backward_function_killer);
2406 if (PyErr_Occurred()) return false; // Don't swallow Python exceptions.
2407 if (MaybeRaiseExceptionFromStatus(status, nullptr)) {
2408 return false;
2409 }
2410 if (accumulator->accumulator->BusyAccumulating()) {
2411 // Ensure inner accumulators don't see outer accumulators' jvps. This
2412 // mostly happens on its own, with some potentially surprising
2413 // exceptions, so the blanket policy is for consistency.
2414 *max_gradient_tape_id = accumulator->nesting_id;
2415 break;
2416 }
2417 }
2418 }
2419 return true;
2420}
2421
2422PyObject* TangentsAsPyTuple(const std::vector<PyObject*>& input_tangents) {
2423 PyObject* py_input_tangents = PyTuple_New(input_tangents.size());
2424 for (int i = 0; i < input_tangents.size(); ++i) {
2425 PyObject* element;
2426 if (input_tangents[i] == nullptr) {
2427 element = Py_None(&_Py_NoneStruct);
2428 } else {
2429 element = input_tangents[i];
2430 }
2431 Py_INCREF(element)_Py_INCREF(((PyObject*)(element)));
2432 PyTuple_SET_ITEM(py_input_tangents, i, element)PyTuple_SetItem(py_input_tangents, i, element);
2433 }
2434 return py_input_tangents;
2435}
2436
2437tensorflow::Status ParseTangentOutputs(
2438 PyObject* user_output, std::vector<PyObject*>* output_tangents) {
2439 if (user_output == Py_None(&_Py_NoneStruct)) {
2440 // No connected gradients.
2441 return tensorflow::Status::OK();
2442 }
2443 tensorflow::Safe_PyObjectPtr fast_result(
2444 PySequence_Fast(user_output, "expected a sequence of forward gradients"));
2445 if (fast_result == nullptr) {
2446 return tensorflow::errors::InvalidArgument(
2447 "forward gradient function did not return a sequence.");
2448 }
2449 int len = PySequence_Fast_GET_SIZE(fast_result.get())(((((((PyObject*)(fast_result.get()))->ob_type))->tp_flags
& ((1UL << 25))) != 0) ? ((static_cast<void>
(0)), (((PyVarObject*)(fast_result.get()))->ob_size)) : (
((PyVarObject*)(((static_cast<void> (0)), (PyTupleObject
*)(fast_result.get()))))->ob_size))
;
2450 PyObject** fast_result_array = PySequence_Fast_ITEMS(fast_result.get())(((((((PyObject*)(fast_result.get()))->ob_type))->tp_flags
& ((1UL << 25))) != 0) ? ((PyListObject *)(fast_result
.get()))->ob_item : ((PyTupleObject *)(fast_result.get()))
->ob_item)
;
2451 output_tangents->reserve(len);
2452 for (int i = 0; i < len; ++i) {
2453 PyObject* item = fast_result_array[i];
2454 if (item == Py_None(&_Py_NoneStruct)) {
2455 output_tangents->push_back(nullptr);
2456 } else {
2457 Py_INCREF(item)_Py_INCREF(((PyObject*)(item)));
2458 output_tangents->push_back(item);
2459 }
2460 }
2461 return tensorflow::Status::OK();
2462}
2463
2464// Calls the registered forward_gradient_function, computing `output_tangents`
2465// from `input_tangents`. `output_tangents` must not be null.
2466//
2467// `op_name`, `attrs`, `inputs`, and `results` describe the operation for which
2468// the forward function is being called.
2469tensorflow::Status CallJVPFunction(PyObject* op_name, PyObject* attrs,
2470 PyObject* inputs, PyObject* results,
2471 const std::vector<PyObject*>& input_tangents,
2472 std::vector<PyObject*>* output_tangents,
2473 bool use_batch) {
2474 if (forward_gradient_function == nullptr) {
2475 return tensorflow::errors::Internal(
2476 "No forward gradient function registered.");
2477 }
2478 tensorflow::Safe_PyObjectPtr py_input_tangents(
2479 TangentsAsPyTuple(input_tangents));
2480
2481 // Normalize the input sequence to a tuple so it works with function
2482 // caching; otherwise it may be an opaque _InputList object.
2483 tensorflow::Safe_PyObjectPtr input_tuple(PySequence_Tuple(inputs));
2484 PyObject* to_batch = (use_batch) ? Py_True((PyObject *) &_Py_TrueStruct) : Py_False((PyObject *) &_Py_FalseStruct);
2485 tensorflow::Safe_PyObjectPtr callback_args(
2486 Py_BuildValue("OOOOOO", op_name, attrs, input_tuple.get(), results,
2487 py_input_tangents.get(), to_batch));
2488 tensorflow::Safe_PyObjectPtr py_result(
2489 PyObject_CallObject(forward_gradient_function, callback_args.get()));
2490 if (py_result == nullptr || PyErr_Occurred()) {
2491 return tensorflow::errors::Internal(
2492 "forward gradient function threw exceptions");
2493 }
2494 return ParseTangentOutputs(py_result.get(), output_tangents);
2495}
2496
2497// Like CallJVPFunction, but calls a pre-bound forward function.
2498// These are passed in from a record_gradient argument.
2499tensorflow::Status CallOpSpecificJVPFunction(
2500 PyObject* op_specific_forward_function,
2501 const std::vector<PyObject*>& input_tangents,
2502 std::vector<PyObject*>* output_tangents) {
2503 tensorflow::Safe_PyObjectPtr py_input_tangents(
2504 TangentsAsPyTuple(input_tangents));
2505
2506 tensorflow::Safe_PyObjectPtr py_result(PyObject_CallObject(
2507 op_specific_forward_function, py_input_tangents.get()));
2508 if (py_result == nullptr || PyErr_Occurred()) {
2509 return tensorflow::errors::Internal(
2510 "forward gradient function threw exceptions");
2511 }
2512 return ParseTangentOutputs(py_result.get(), output_tangents);
2513}
2514
2515bool ParseOpTypeString(PyObject* op_type, string* op_type_string) {
2516 if (PyBytes_Check(op_type)((((((PyObject*)(op_type))->ob_type))->tp_flags & (
(1UL << 27))) != 0)
) {
2517 *op_type_string = PyBytes_AsString(op_type);
2518 } else if (PyUnicode_Check(op_type)((((((PyObject*)(op_type))->ob_type))->tp_flags & (
(1UL << 28))) != 0)
) {
2519#if PY_MAJOR_VERSION3 >= 3
2520 *op_type_string = PyUnicode_AsUTF8(op_type);
2521#else
2522 PyObject* py_str = PyUnicode_AsUTF8String(op_type);
2523 if (py_str == nullptr) {
2524 return false;
2525 }
2526 *op_type_string = PyBytes_AS_STRING(py_str)((static_cast<void> (0)), (((PyBytesObject *)(py_str))->
ob_sval))
;
2527 Py_DECREF(py_str)_Py_DECREF(((PyObject*)(py_str)));
2528#endif
2529 } else {
2530 PyErr_SetString(PyExc_RuntimeError, "op_type should be a string.");
2531 return false;
2532 }
2533 return true;
2534}
2535
2536bool TapeSetRecordOperation(
2537 PyObject* op_type, PyObject* input_tensors, PyObject* output_tensors,
2538 const std::vector<int64_t>& input_ids,
2539 const std::vector<tensorflow::DataType>& input_dtypes,
2540 const std::function<PyBackwardFunction*()>& backward_function_getter,
2541 const std::function<void(PyBackwardFunction*)>& backward_function_killer,
2542 const tensorflow::eager::ForwardFunction<PyObject>* forward_function) {
2543 std::vector<PyTapeTensor> output_info;
2544 tensorflow::Safe_PyObjectPtr output_seq(PySequence_Fast(
2545 output_tensors, "expected a sequence of integer tensor ids"));
2546 if (PyErr_Occurred() ||
2547 !TapeTensorsFromTensorSequence(output_seq.get(), &output_info)) {
2548 return false;
2549 }
2550 string op_type_str;
2551 if (!ParseOpTypeString(op_type, &op_type_str)) {
2552 return false;
2553 }
2554 tensorflow::uint64 max_gradient_tape_id;
2555 if (!TapeSetRecordForwardprop(
2556 op_type_str, output_seq.get(), output_info, input_tensors, input_ids,
2557 input_dtypes, backward_function_getter, backward_function_killer,
2558 forward_function, nullptr /* No special-cased jvps. */,
2559 &max_gradient_tape_id)) {
2560 return false;
2561 }
2562 TapeSetRecordBackprop(op_type_str, output_info, input_ids, input_dtypes,
2563 backward_function_getter, backward_function_killer,
2564 max_gradient_tape_id);
2565 return true;
2566}
2567} // namespace
2568
2569PyObject* TFE_Py_TapeSetRecordOperation(PyObject* op_type,
2570 PyObject* output_tensors,
2571 PyObject* input_tensors,
2572 PyObject* backward_function,
2573 PyObject* forward_function) {
2574 if (!HasAccumulatorOrTape() || *ThreadTapeIsStopped()) {
2575 Py_RETURN_NONEreturn _Py_INCREF(((PyObject*)((&_Py_NoneStruct)))), (&
_Py_NoneStruct)
;
2576 }
2577 std::vector<int64_t> input_ids = MakeTensorIDList(input_tensors);
2578 if (PyErr_Occurred()) return nullptr;
2579
2580 std::vector<tensorflow::DataType> input_dtypes =
2581 MakeTensorDtypeList(input_tensors);
2582 if (PyErr_Occurred()) return nullptr;
2583
2584 std::function<PyBackwardFunction*()> backward_function_getter(
2585 [backward_function]() {
2586 Py_INCREF(backward_function)_Py_INCREF(((PyObject*)(backward_function)));
2587 PyBackwardFunction* function = new PyBackwardFunction(
2588 [backward_function](PyObject* out_grads,
2589 const std::vector<int64_t>& unused) {
2590 return PyObject_CallObject(backward_function, out_grads);
2591 });
2592 return function;
2593 });
2594 std::function<void(PyBackwardFunction*)> backward_function_killer(
2595 [backward_function](PyBackwardFunction* py_backward_function) {
2596 Py_DECREF(backward_function)_Py_DECREF(((PyObject*)(backward_function)));
2597 delete py_backward_function;
2598 });
2599
2600 if (forward_function == Py_None(&_Py_NoneStruct)) {
2601 if (!TapeSetRecordOperation(
2602 op_type, input_tensors, output_tensors, input_ids, input_dtypes,
2603 backward_function_getter, backward_function_killer,
2604 nullptr /* No special-cased forward function */)) {
2605 return nullptr;
2606 }
2607 } else {
2608 tensorflow::eager::ForwardFunction<PyObject> wrapped_forward_function(
2609 [forward_function](const std::vector<PyObject*>& input_tangents,
2610 std::vector<PyObject*>* output_tangents,
2611 bool use_batch = false) {
2612 return CallOpSpecificJVPFunction(forward_function, input_tangents,
2613 output_tangents);
2614 });
2615 if (!TapeSetRecordOperation(
2616 op_type, input_tensors, output_tensors, input_ids, input_dtypes,
2617 backward_function_getter, backward_function_killer,
2618 &wrapped_forward_function)) {
2619 return nullptr;
2620 }
2621 }
2622 Py_RETURN_NONEreturn _Py_INCREF(((PyObject*)((&_Py_NoneStruct)))), (&
_Py_NoneStruct)
;
2623}
2624
2625PyObject* TFE_Py_TapeSetRecordOperationForwardprop(
2626 PyObject* op_type, PyObject* output_tensors, PyObject* input_tensors,
2627 PyObject* backward_function, PyObject* forwardprop_output_indices) {
2628 if (!HasAccumulator() || *ThreadTapeIsStopped()) {
2629 Py_RETURN_NONEreturn _Py_INCREF(((PyObject*)((&_Py_NoneStruct)))), (&
_Py_NoneStruct)
;
2630 }
2631 std::vector<int64_t> input_ids = MakeTensorIDList(input_tensors);
2632 if (PyErr_Occurred()) return nullptr;
2633
2634 std::vector<tensorflow::DataType> input_dtypes =
2635 MakeTensorDtypeList(input_tensors);
2636 if (PyErr_Occurred()) return nullptr;
2637
2638 std::function<PyBackwardFunction*()> backward_function_getter(
2639 [backward_function]() {
2640 Py_INCREF(backward_function)_Py_INCREF(((PyObject*)(backward_function)));
2641 PyBackwardFunction* function = new PyBackwardFunction(
2642 [backward_function](PyObject* out_grads,
2643 const std::vector<int64_t>& unused) {
2644 return PyObject_CallObject(backward_function, out_grads);
2645 });
2646 return function;
2647 });
2648 std::function<void(PyBackwardFunction*)> backward_function_killer(
2649 [backward_function](PyBackwardFunction* py_backward_function) {
2650 Py_DECREF(backward_function)_Py_DECREF(((PyObject*)(backward_function)));
2651 delete py_backward_function;
2652 });
2653 std::vector<PyTapeTensor> output_info;
2654 tensorflow::Safe_PyObjectPtr output_seq(PySequence_Fast(
2655 output_tensors, "expected a sequence of integer tensor ids"));
2656 if (PyErr_Occurred() ||
2657 !TapeTensorsFromTensorSequence(output_seq.get(), &output_info)) {
2658 return nullptr;
2659 }
2660 string op_type_str;
2661 if (!ParseOpTypeString(op_type, &op_type_str)) {
2662 return nullptr;
2663 }
2664 tensorflow::uint64 max_gradient_tape_id;
2665 if (!TapeSetRecordForwardprop(
2666 op_type_str, output_seq.get(), output_info, input_tensors, input_ids,
2667 input_dtypes, backward_function_getter, backward_function_killer,
2668 nullptr /* no special-cased forward function */,
2669 forwardprop_output_indices, &max_gradient_tape_id)) {
2670 return nullptr;
2671 }
2672 Py_RETURN_NONEreturn _Py_INCREF(((PyObject*)((&_Py_NoneStruct)))), (&
_Py_NoneStruct)
;
2673}
2674
2675PyObject* TFE_Py_TapeSetRecordOperationBackprop(PyObject* op_type,
2676 PyObject* output_tensors,
2677 PyObject* input_tensors,
2678 PyObject* backward_function) {
2679 if (!CouldBackprop()) {
2680 Py_RETURN_NONEreturn _Py_INCREF(((PyObject*)((&_Py_NoneStruct)))), (&
_Py_NoneStruct)
;
2681 }
2682 std::vector<int64_t> input_ids = MakeTensorIDList(input_tensors);
2683 if (PyErr_Occurred()) return nullptr;
2684
2685 std::vector<tensorflow::DataType> input_dtypes =
2686 MakeTensorDtypeList(input_tensors);
2687 if (PyErr_Occurred()) return nullptr;
2688
2689 std::function<PyBackwardFunction*()> backward_function_getter(
2690 [backward_function]() {
2691 Py_INCREF(backward_function)_Py_INCREF(((PyObject*)(backward_function)));
2692 PyBackwardFunction* function = new PyBackwardFunction(
2693 [backward_function](PyObject* out_grads,
2694 const std::vector<int64_t>& unused) {
2695 return PyObject_CallObject(backward_function, out_grads);
2696 });
2697 return function;
2698 });
2699 std::function<void(PyBackwardFunction*)> backward_function_killer(
2700 [backward_function](PyBackwardFunction* py_backward_function) {
2701 Py_DECREF(backward_function)_Py_DECREF(((PyObject*)(backward_function)));
2702 delete py_backward_function;
2703 });
2704 std::vector<PyTapeTensor> output_info;
2705 tensorflow::Safe_PyObjectPtr output_seq(PySequence_Fast(
2706 output_tensors, "expected a sequence of integer tensor ids"));
2707 if (PyErr_Occurred() ||
2708 !TapeTensorsFromTensorSequence(output_seq.get(), &output_info)) {
2709 return nullptr;
2710 }
2711 string op_type_str;
2712 if (!ParseOpTypeString(op_type, &op_type_str)) {
2713 return nullptr;
2714 }
2715 TapeSetRecordBackprop(op_type_str, output_info, input_ids, input_dtypes,
2716 backward_function_getter, backward_function_killer,
2717 // No filtering based on relative ordering with forward
2718 // accumulators.
2719 std::numeric_limits<tensorflow::uint64>::max());
2720 Py_RETURN_NONEreturn _Py_INCREF(((PyObject*)((&_Py_NoneStruct)))), (&
_Py_NoneStruct)
;
2721}
2722
2723void TFE_Py_TapeSetDeleteTrace(int64_t tensor_id) {
2724 for (TFE_Py_Tape* tape : *GetTapeSet()) {
2725 tape->tape->DeleteTrace(tensor_id);
2726 }
2727}
2728
2729std::vector<PyObject*> MakeTensorList(PyObject* tensors) {
2730 PyObject* seq = PySequence_Fast(tensors, "expected a sequence");
2731 if (seq == nullptr) {
2732 return {};
2733 }
2734 int len = PySequence_Fast_GET_SIZE(seq)(((((((PyObject*)(seq))->ob_type))->tp_flags & ((1UL
<< 25))) != 0) ? ((static_cast<void> (0)), (((PyVarObject
*)(seq))->ob_size)) : (((PyVarObject*)(((static_cast<void
> (0)), (PyTupleObject *)(seq))))->ob_size))
;
2735 PyObject** seq_array = PySequence_Fast_ITEMS(seq)(((((((PyObject*)(seq))->ob_type))->tp_flags & ((1UL
<< 25))) != 0) ? ((PyListObject *)(seq))->ob_item :
((PyTupleObject *)(seq))->ob_item)
;
2736 std::vector<PyObject*> list(seq_array, seq_array + len);
2737 Py_DECREF(seq)_Py_DECREF(((PyObject*)(seq)));
2738 return list;
2739}
2740
2741PyObject* TFE_Py_TapeGradient(PyObject* tape, PyObject* target,
2742 PyObject* sources, PyObject* output_gradients,
2743 PyObject* sources_raw,
2744 PyObject* unconnected_gradients,
2745 TF_Status* status) {
2746 TFE_Py_Tape* tape_obj = reinterpret_cast<TFE_Py_Tape*>(tape);
2747 if (!tape_obj->tape->IsPersistent()) {
2748 auto* tape_set = GetTapeSet();
2749 if (tape_set->find(tape_obj) != tape_set->end()) {
2750 PyErr_SetString(PyExc_RuntimeError,
2751 "gradient() cannot be invoked within the "
2752 "GradientTape context (i.e., while operations are being "
2753 "recorded). Either move the call to gradient() to be "
2754 "outside the 'with tf.GradientTape' block, or "
2755 "use a persistent tape: "
2756 "'with tf.GradientTape(persistent=true)'");
2757 return nullptr;
2758 }
2759 }
2760
2761 std::vector<int64_t> target_vec = MakeTensorIDList(target);
2762 if (PyErr_Occurred()) {
2763 return nullptr;
2764 }
2765 std::vector<int64_t> sources_vec = MakeTensorIDList(sources);
2766 if (PyErr_Occurred()) {
2767 return nullptr;
2768 }
2769 tensorflow::gtl::FlatSet<int64_t> sources_set(sources_vec.begin(),
2770 sources_vec.end());
2771
2772 tensorflow::Safe_PyObjectPtr seq =
2773 tensorflow::make_safe(PySequence_Fast(target, "expected a sequence"));
2774 int len = PySequence_Fast_GET_SIZE(seq.get())(((((((PyObject*)(seq.get()))->ob_type))->tp_flags &
((1UL << 25))) != 0) ? ((static_cast<void> (0)),
(((PyVarObject*)(seq.get()))->ob_size)) : (((PyVarObject*
)(((static_cast<void> (0)), (PyTupleObject *)(seq.get()
))))->ob_size))
;
2775 PyObject** seq_array = PySequence_Fast_ITEMS(seq.get())(((((((PyObject*)(seq.get()))->ob_type))->tp_flags &
((1UL << 25))) != 0) ? ((PyListObject *)(seq.get()))->
ob_item : ((PyTupleObject *)(seq.get()))->ob_item)
;
2776 std::unordered_map<int64_t, PyTapeTensor> source_tensors_that_are_targets;
2777 for (int i = 0; i < len; ++i) {
2778 int64_t target_id = target_vec[i];
2779 if (sources_set.find(target_id) != sources_set.end()) {
2780 auto tensor = seq_array[i];
2781 source_tensors_that_are_targets.insert(
2782 std::make_pair(target_id, TapeTensorFromTensor(tensor)));
2783 }
2784 if (PyErr_Occurred()) {
2785 return nullptr;
2786 }
2787 }
2788 if (PyErr_Occurred()) {
2789 return nullptr;
2790 }
2791
2792 std::vector<PyObject*> outgrad_vec;
2793 if (output_gradients != Py_None(&_Py_NoneStruct)) {
2794 outgrad_vec = MakeTensorList(output_gradients);
2795 if (PyErr_Occurred()) {
2796 return nullptr;
2797 }
2798 for (PyObject* tensor : outgrad_vec) {
2799 // Calling the backward function will eat a reference to the tensors in
2800 // outgrad_vec, so we need to increase their reference count.
2801 Py_INCREF(tensor)_Py_INCREF(((PyObject*)(tensor)));
2802 }
2803 }
2804 std::vector<PyObject*> result(sources_vec.size());
2805 status->status = tape_obj->tape->ComputeGradient(
2806 *py_vspace, target_vec, sources_vec, source_tensors_that_are_targets,
2807 outgrad_vec, absl::MakeSpan(result));
2808 if (!status->status.ok()) {
2809 if (PyErr_Occurred()) {
2810 // Do not propagate the erroneous status as that would swallow the
2811 // exception which caused the problem.
2812 status->status = tensorflow::Status::OK();
2813 }
2814 return nullptr;
2815 }
2816
2817 bool unconnected_gradients_zero =
2818 strcmp(TFE_GetPythonString(unconnected_gradients), "zero") == 0;
2819 std::vector<PyObject*> sources_obj;
2820 if (unconnected_gradients_zero) {
2821 // Uses the "raw" sources here so it can properly make a zeros tensor even
2822 // if there are resource variables as sources.
2823 sources_obj = MakeTensorList(sources_raw);
2824 }
2825
2826 if (!result.empty()) {
2827 PyObject* py_result = PyList_New(result.size());
2828 tensorflow::gtl::FlatSet<PyObject*> seen_results(result.size());
2829 for (int i = 0; i < result.size(); ++i) {
2830 if (result[i] == nullptr) {
2831 if (unconnected_gradients_zero) {
2832 // generate a zeros tensor in the shape of sources[i]
2833 tensorflow::DataType dtype =
2834 tensorflow::PyTensor_DataType(sources_obj[i]);
2835 PyTapeTensor tensor =
2836 PyTapeTensor(sources_vec[i], dtype, sources_obj[i]);
2837 result[i] = tensor.ZerosLike();
2838 } else {
2839 Py_INCREF(Py_None)_Py_INCREF(((PyObject*)((&_Py_NoneStruct))));
2840 result[i] = Py_None(&_Py_NoneStruct);
2841 }
2842 } else if (seen_results.find(result[i]) != seen_results.end()) {
2843 Py_INCREF(result[i])_Py_INCREF(((PyObject*)(result[i])));
2844 }
2845 seen_results.insert(result[i]);
2846 PyList_SET_ITEM(py_result, i, reinterpret_cast<PyObject*>(result[i]))PyList_SetItem(py_result, i, reinterpret_cast<PyObject*>
(result[i]))
;
2847 }
2848 return py_result;
2849 }
2850 return PyList_New(0);
2851}
2852
2853PyObject* TFE_Py_ForwardAccumulatorNew(bool use_batch) {
2854 TFE_Py_ForwardAccumulator_Type.tp_new = PyType_GenericNew;
2855 if (PyType_Ready(&TFE_Py_ForwardAccumulator_Type) < 0) return nullptr;
2856 TFE_Py_ForwardAccumulator* accumulator =
2857 PyObject_NEW(TFE_Py_ForwardAccumulator, &TFE_Py_ForwardAccumulator_Type)( (TFE_Py_ForwardAccumulator *) PyObject_Init( (PyObject *) PyObject_Malloc
( ( (&TFE_Py_ForwardAccumulator_Type)->tp_basicsize ) )
, (&TFE_Py_ForwardAccumulator_Type)) )
;
2858 if (py_vspace == nullptr) {
2859 MaybeRaiseExceptionFromStatus(
2860 tensorflow::errors::Internal(
2861 "ForwardAccumulator requires a PyVSpace to be registered."),
2862 nullptr);
2863 }
2864 accumulator->accumulator = new ForwardAccumulator(*py_vspace, use_batch);
2865 return reinterpret_cast<PyObject*>(accumulator);
2866}
2867
2868PyObject* TFE_Py_ForwardAccumulatorSetAdd(PyObject* accumulator) {
2869 TFE_Py_ForwardAccumulator* c_accumulator(
2870 reinterpret_cast<TFE_Py_ForwardAccumulator*>(accumulator));
2871 c_accumulator->nesting_id = tape_nesting_id_counter.fetch_add(1);
2872 if (GetAccumulatorSet()->insert(c_accumulator)) {
2873 Py_INCREF(accumulator)_Py_INCREF(((PyObject*)(accumulator)));
2874 Py_RETURN_NONEreturn _Py_INCREF(((PyObject*)((&_Py_NoneStruct)))), (&
_Py_NoneStruct)
;
2875 } else {
2876 MaybeRaiseExceptionFromStatus(
2877 tensorflow::errors::Internal(
2878 "A ForwardAccumulator was added to the active set twice."),
2879 nullptr);
2880 return nullptr;
2881 }
2882}
2883
2884void TFE_Py_ForwardAccumulatorSetRemove(PyObject* accumulator) {
2885 GetAccumulatorSet()->erase(
2886 reinterpret_cast<TFE_Py_ForwardAccumulator*>(accumulator));
2887 Py_DECREF(accumulator)_Py_DECREF(((PyObject*)(accumulator)));
2888}
2889
2890void TFE_Py_ForwardAccumulatorWatch(PyObject* accumulator, PyObject* tensor,
2891 PyObject* tangent) {
2892 int64_t tensor_id = FastTensorId(tensor);
2893 reinterpret_cast<TFE_Py_ForwardAccumulator*>(accumulator)
2894 ->accumulator->Watch(tensor_id, tangent);
2895 RegisterForwardAccumulatorCleanup(tensor, tensor_id);
2896}
2897
2898// Returns a new reference to the JVP Tensor.
2899PyObject* TFE_Py_ForwardAccumulatorJVP(PyObject* accumulator,
2900 PyObject* tensor) {
2901 PyObject* jvp = reinterpret_cast<TFE_Py_ForwardAccumulator*>(accumulator)
2902 ->accumulator->FetchJVP(FastTensorId(tensor));
2903 if (jvp == nullptr) {
2904 jvp = Py_None(&_Py_NoneStruct);
2905 }
2906 Py_INCREF(jvp)_Py_INCREF(((PyObject*)(jvp)));
2907 return jvp;
2908}
2909
2910PyObject* TFE_Py_PackJVPs(PyObject* tensors) {
2911 if (!TapeCouldPossiblyRecord(tensors)) {
2912 tensorflow::Safe_PyObjectPtr empty_tuple(PyTuple_New(0));
2913 tensorflow::Safe_PyObjectPtr empty_list(PyList_New(0));
2914 return PyTuple_Pack(2, empty_tuple.get(), empty_list.get());
2915 }
2916 auto accumulators = *GetAccumulatorSet();
2917 tensorflow::Safe_PyObjectPtr tensors_fast(
2918 PySequence_Fast(tensors, "Expected a sequence of input Tensors."));
2919 if (tensors_fast == nullptr || PyErr_Occurred()) {
2920 return nullptr;
2921 }
2922 std::vector<int64_t> augmented_input_ids;
2923 int len = PySequence_Fast_GET_SIZE(tensors_fast.get())(((((((PyObject*)(tensors_fast.get()))->ob_type))->tp_flags
& ((1UL << 25))) != 0) ? ((static_cast<void>
(0)), (((PyVarObject*)(tensors_fast.get()))->ob_size)) : (
((PyVarObject*)(((static_cast<void> (0)), (PyTupleObject
*)(tensors_fast.get()))))->ob_size))
;
2924 PyObject** tensors_fast_array = PySequence_Fast_ITEMS(tensors_fast.get())(((((((PyObject*)(tensors_fast.get()))->ob_type))->tp_flags
& ((1UL << 25))) != 0) ? ((PyListObject *)(tensors_fast
.get()))->ob_item : ((PyTupleObject *)(tensors_fast.get())
)->ob_item)
;
2925 for (Py_ssize_t position = 0; position < len; ++position) {
2926 PyObject* input = tensors_fast_array[position];
2927 if (input == Py_None(&_Py_NoneStruct)) {
2928 continue;
2929 }
2930 tensorflow::DataType input_dtype(tensorflow::PyTensor_DataType(input));
2931 if (input_dtype == tensorflow::DT_INVALID) {
2932 return nullptr;
2933 }
2934 augmented_input_ids.push_back(FastTensorId(input));
2935 }
2936 if (PyErr_Occurred()) {
2937 return nullptr;
2938 }
2939 // Find the innermost accumulator such that all outer accumulators are
2940 // recording. Any more deeply nested accumulators will not have their JVPs
2941 // saved.
2942 AccumulatorSet::const_iterator innermost_all_recording = accumulators.begin();
2943 for (; innermost_all_recording != accumulators.end();
2944 ++innermost_all_recording) {
2945 if ((*innermost_all_recording)->accumulator->BusyAccumulating()) {
2946 break;
2947 }
2948 }
2949 AccumulatorSet::const_reverse_iterator reverse_innermost_all_recording(
2950 innermost_all_recording);
2951
2952 bool saving_jvps = false;
2953 tensorflow::Safe_PyObjectPtr all_indices(PyTuple_New(accumulators.size()));
2954 std::vector<PyObject*> new_tensors;
2955 Py_ssize_t accumulator_index = 0;
2956 // Start with the innermost accumulators to give outer accumulators a chance
2957 // to find their higher-order JVPs.
2958 for (AccumulatorSet::const_reverse_iterator it = accumulators.rbegin();
2959 it != accumulators.rend(); ++it, ++accumulator_index) {
2960 std::vector<int64_t> new_input_ids;
2961 std::vector<std::pair<int64_t, int64_t>> accumulator_indices;
2962 if (it == reverse_innermost_all_recording) {
2963 saving_jvps = true;
2964 }
2965 if (saving_jvps) {
2966 for (int input_index = 0; input_index < augmented_input_ids.size();
2967 ++input_index) {
2968 int64_t existing_input = augmented_input_ids[input_index];
2969 PyObject* jvp = (*it)->accumulator->FetchJVP(existing_input);
2970 if (jvp != nullptr) {
2971 new_tensors.push_back(jvp);
2972 new_input_ids.push_back(FastTensorId(jvp));
2973 accumulator_indices.emplace_back(
2974 input_index,
2975 augmented_input_ids.size() + new_input_ids.size() - 1);
2976 }
2977 }
2978 }
2979 tensorflow::Safe_PyObjectPtr accumulator_indices_py(
2980 PyTuple_New(accumulator_indices.size()));
2981 for (int i = 0; i < accumulator_indices.size(); ++i) {
2982 tensorflow::Safe_PyObjectPtr from_index(
2983 GetPythonObjectFromInt(accumulator_indices[i].first));
2984 tensorflow::Safe_PyObjectPtr to_index(
2985 GetPythonObjectFromInt(accumulator_indices[i].second));
2986 PyTuple_SetItem(accumulator_indices_py.get(), i,
2987 PyTuple_Pack(2, from_index.get(), to_index.get()));
2988 }
2989 PyTuple_SetItem(all_indices.get(), accumulator_index,
2990 accumulator_indices_py.release());
2991 augmented_input_ids.insert(augmented_input_ids.end(), new_input_ids.begin(),
2992 new_input_ids.end());
2993 }
2994
2995 tensorflow::Safe_PyObjectPtr new_tensors_py(PyList_New(new_tensors.size()));
2996 for (int i = 0; i < new_tensors.size(); ++i) {
2997 PyObject* jvp = new_tensors[i];
2998 Py_INCREF(jvp)_Py_INCREF(((PyObject*)(jvp)));
2999 PyList_SET_ITEM(new_tensors_py.get(), i, jvp)PyList_SetItem(new_tensors_py.get(), i, jvp);
3000 }
3001 return PyTuple_Pack(2, all_indices.get(), new_tensors_py.get());
3002}
3003
3004namespace {
3005
// Positions within the "args" tuple passed to TFE_Py_FastPathExecute_C.
// Op inputs start at FAST_PATH_EXECUTE_ARG_INPUT_START.
enum FastPathExecuteArgIndex {
  FAST_PATH_EXECUTE_ARG_CONTEXT = 0,
  FAST_PATH_EXECUTE_ARG_OP_NAME = FAST_PATH_EXECUTE_ARG_CONTEXT + 1,
  FAST_PATH_EXECUTE_ARG_NAME = FAST_PATH_EXECUTE_ARG_OP_NAME + 1,
  FAST_PATH_EXECUTE_ARG_INPUT_START = FAST_PATH_EXECUTE_ARG_NAME + 1
};
3013
3014PyObject* GetPythonObjectFromString(tensorflow::StringPiece s) {
3015#if PY_MAJOR_VERSION3 >= 3
3016 return PyUnicode_FromStringAndSize(s.data(), s.size());
3017#else
3018 return PyBytes_FromStringAndSize(s.data(), s.size());
3019#endif
3020}
3021
3022bool CheckResourceVariable(PyObject* item) {
3023 if (tensorflow::swig::IsResourceVariable(item)) {
3024 tensorflow::Safe_PyObjectPtr handle(
3025 PyObject_GetAttrString(item, "_handle"));
3026 return EagerTensor_CheckExact(handle.get());
3027 }
3028
3029 return false;
3030}
3031
3032bool IsNumberType(PyObject* item) {
3033#if PY_MAJOR_VERSION3 >= 3
3034 return PyFloat_Check(item)((((PyObject*)(item))->ob_type) == (&PyFloat_Type) || PyType_IsSubtype
((((PyObject*)(item))->ob_type), (&PyFloat_Type)))
|| PyLong_Check(item)((((((PyObject*)(item))->ob_type))->tp_flags & ((1UL
<< 24))) != 0)
;
3035#else
3036 return PyFloat_Check(item)((((PyObject*)(item))->ob_type) == (&PyFloat_Type) || PyType_IsSubtype
((((PyObject*)(item))->ob_type), (&PyFloat_Type)))
|| PyInt_Check(item) || PyLong_Check(item)((((((PyObject*)(item))->ob_type))->tp_flags & ((1UL
<< 24))) != 0)
;
3037#endif
3038}
3039
3040bool CheckOneInput(PyObject* item) {
3041 if (EagerTensor_CheckExact(item) || CheckResourceVariable(item) ||
3042 PyArray_Check(item)((((PyObject*)(item))->ob_type) == (&(*(PyTypeObject *
)_tensorflow_numpy_api[2])) || PyType_IsSubtype((((PyObject*)
(item))->ob_type), (&(*(PyTypeObject *)_tensorflow_numpy_api
[2]))))
|| IsNumberType(item)) {
3043 return true;
3044 }
3045
3046 // Sequences are not properly handled. Sequences with purely python numeric
3047 // types work, but sequences with mixes of EagerTensors and python numeric
3048 // types don't work.
3049 // TODO(nareshmodi): fix
3050 return false;
3051}
3052
3053bool CheckInputsOk(PyObject* seq, int start_index,
3054 const tensorflow::OpDef& op_def) {
3055 for (int i = 0; i < op_def.input_arg_size(); i++) {
3056 PyObject* item = PyTuple_GET_ITEM(seq, i + start_index)(((static_cast<void> (0)), (PyTupleObject *)(seq))->
ob_item[i + start_index])
;
3057 if (!op_def.input_arg(i).number_attr().empty() ||
3058 !op_def.input_arg(i).type_list_attr().empty()) {
3059 // This item should be a seq input.
3060 if (!PySequence_Check(item)) {
3061 VLOG(1)(__builtin_expect(!!(!(([](int level, const char* fname) { static
const bool vmodule_activated = ::tensorflow::internal::LogMessage
::VmoduleActivated(fname, level); return vmodule_activated; }
)(1, "tensorflow/python/eager/pywrap_tfe_src.cc"))), 1)) ? (void
)0 : ::tensorflow::internal::Voidifier() & ::tensorflow::
internal::LogMessage("tensorflow/python/eager/pywrap_tfe_src.cc"
, 3061, tensorflow::INFO)
<< "Falling back to slow path for Op \"" << op_def.name()
3062 << "\", Input \"" << op_def.input_arg(i).name()
3063 << "\" since we expected a sequence, but got "
3064 << item->ob_type->tp_name;
3065 return false;
3066 }
3067 tensorflow::Safe_PyObjectPtr fast_item(
3068 PySequence_Fast(item, "Could not parse sequence."));
3069 if (fast_item.get() == nullptr) {
3070 return false;
3071 }
3072 int len = PySequence_Fast_GET_SIZE(fast_item.get())(((((((PyObject*)(fast_item.get()))->ob_type))->tp_flags
& ((1UL << 25))) != 0) ? ((static_cast<void>
(0)), (((PyVarObject*)(fast_item.get()))->ob_size)) : (((
PyVarObject*)(((static_cast<void> (0)), (PyTupleObject *
)(fast_item.get()))))->ob_size))
;
3073 PyObject** fast_item_array = PySequence_Fast_ITEMS(fast_item.get())(((((((PyObject*)(fast_item.get()))->ob_type))->tp_flags
& ((1UL << 25))) != 0) ? ((PyListObject *)(fast_item
.get()))->ob_item : ((PyTupleObject *)(fast_item.get()))->
ob_item)
;
3074 for (Py_ssize_t j = 0; j < len; j++) {
3075 PyObject* inner_item = fast_item_array[j];
3076 if (!CheckOneInput(inner_item)) {
3077 VLOG(1)(__builtin_expect(!!(!(([](int level, const char* fname) { static
const bool vmodule_activated = ::tensorflow::internal::LogMessage
::VmoduleActivated(fname, level); return vmodule_activated; }
)(1, "tensorflow/python/eager/pywrap_tfe_src.cc"))), 1)) ? (void
)0 : ::tensorflow::internal::Voidifier() & ::tensorflow::
internal::LogMessage("tensorflow/python/eager/pywrap_tfe_src.cc"
, 3077, tensorflow::INFO)
<< "Falling back to slow path for Op \"" << op_def.name()
3078 << "\", Input \"" << op_def.input_arg(i).name()
3079 << "\", Index " << j
3080 << " since we expected an EagerTensor/ResourceVariable, "
3081 "but got "
3082 << inner_item->ob_type->tp_name;
3083 return false;
3084 }
3085 }
3086 } else if (!CheckOneInput(item)) {
3087 VLOG(1)(__builtin_expect(!!(!(([](int level, const char* fname) { static
const bool vmodule_activated = ::tensorflow::internal::LogMessage
::VmoduleActivated(fname, level); return vmodule_activated; }
)(1, "tensorflow/python/eager/pywrap_tfe_src.cc"))), 1)) ? (void
)0 : ::tensorflow::internal::Voidifier() & ::tensorflow::
internal::LogMessage("tensorflow/python/eager/pywrap_tfe_src.cc"
, 3087, tensorflow::INFO)
3088 << "Falling back to slow path for Op \"" << op_def.name()
3089 << "\", Input \"" << op_def.input_arg(i).name()
3090 << "\" since we expected an EagerTensor/ResourceVariable, but got "
3091 << item->ob_type->tp_name;
3092 return false;
3093 }
3094 }
3095
3096 return true;
3097}
3098
3099tensorflow::DataType MaybeGetDType(PyObject* item) {
3100 if (EagerTensor_CheckExact(item) || CheckResourceVariable(item)) {
3101 return tensorflow::PyTensor_DataType(item);
3102 }
3103
3104 return tensorflow::DT_INVALID;
3105}
3106
3107tensorflow::DataType MaybeGetDTypeForAttr(const string& attr,
3108 FastPathOpExecInfo* op_exec_info) {
3109 auto cached_it = op_exec_info->cached_dtypes.find(attr);
3110 if (cached_it != op_exec_info->cached_dtypes.end()) {
3111 return cached_it->second;
3112 }
3113
3114 auto it = op_exec_info->attr_to_inputs_map->find(attr);
3115 if (it == op_exec_info->attr_to_inputs_map->end()) {
3116 // No other inputs - this should never happen.
3117 return tensorflow::DT_INVALID;
3118 }
3119
3120 for (const auto& input_info : it->second) {
3121 PyObject* item = PyTuple_GET_ITEM((((static_cast<void> (0)), (PyTupleObject *)(op_exec_info
->args))->ob_item[FAST_PATH_EXECUTE_ARG_INPUT_START + input_info
.i])
3122 op_exec_info->args, FAST_PATH_EXECUTE_ARG_INPUT_START + input_info.i)(((static_cast<void> (0)), (PyTupleObject *)(op_exec_info
->args))->ob_item[FAST_PATH_EXECUTE_ARG_INPUT_START + input_info
.i])
;
3123 if (input_info.is_list) {
3124 tensorflow::Safe_PyObjectPtr fast_item(
3125 PySequence_Fast(item, "Unable to allocate"));
3126 int len = PySequence_Fast_GET_SIZE(fast_item.get())(((((((PyObject*)(fast_item.get()))->ob_type))->tp_flags
& ((1UL << 25))) != 0) ? ((static_cast<void>
(0)), (((PyVarObject*)(fast_item.get()))->ob_size)) : (((
PyVarObject*)(((static_cast<void> (0)), (PyTupleObject *
)(fast_item.get()))))->ob_size))
;
3127 PyObject** fast_item_array = PySequence_Fast_ITEMS(fast_item.get())(((((((PyObject*)(fast_item.get()))->ob_type))->tp_flags
& ((1UL << 25))) != 0) ? ((PyListObject *)(fast_item
.get()))->ob_item : ((PyTupleObject *)(fast_item.get()))->
ob_item)
;
3128 for (int i = 0; i < len; i++) {
3129 auto dtype = MaybeGetDType(fast_item_array[i]);
3130 if (dtype != tensorflow::DT_INVALID) return dtype;
3131 }
3132 } else {
3133 auto dtype = MaybeGetDType(item);
3134 if (dtype != tensorflow::DT_INVALID) return dtype;
3135 }
3136 }
3137
3138 auto default_it = op_exec_info->default_dtypes->find(attr);
3139 if (default_it != op_exec_info->default_dtypes->end()) {
3140 return default_it->second;
3141 }
3142
3143 return tensorflow::DT_INVALID;
3144}
3145
3146PyObject* CopySequenceSettingIndicesToNull(
3147 PyObject* seq, const tensorflow::gtl::FlatSet<int>& indices) {
3148 tensorflow::Safe_PyObjectPtr fast_seq(
3149 PySequence_Fast(seq, "unable to allocate"));
3150 int len = PySequence_Fast_GET_SIZE(fast_seq.get())(((((((PyObject*)(fast_seq.get()))->ob_type))->tp_flags
& ((1UL << 25))) != 0) ? ((static_cast<void>
(0)), (((PyVarObject*)(fast_seq.get()))->ob_size)) : (((PyVarObject
*)(((static_cast<void> (0)), (PyTupleObject *)(fast_seq
.get()))))->ob_size))
;
3151 PyObject** fast_seq_array = PySequence_Fast_ITEMS(fast_seq.get())(((((((PyObject*)(fast_seq.get()))->ob_type))->tp_flags
& ((1UL << 25))) != 0) ? ((PyListObject *)(fast_seq
.get()))->ob_item : ((PyTupleObject *)(fast_seq.get()))->
ob_item)
;
3152 PyObject* result = PyTuple_New(len);
3153 for (int i = 0; i < len; i++) {
3154 PyObject* item;
3155 if (indices.find(i) != indices.end()) {
3156 item = Py_None(&_Py_NoneStruct);
3157 } else {
3158 item = fast_seq_array[i];
3159 }
3160 Py_INCREF(item)_Py_INCREF(((PyObject*)(item)));
3161 PyTuple_SET_ITEM(result, i, item)PyTuple_SetItem(result, i, item);
3162 }
3163 return result;
3164}
3165
3166PyObject* RecordGradient(PyObject* op_name, PyObject* inputs, PyObject* attrs,
3167 PyObject* results,
3168 PyObject* forward_pass_name_scope = nullptr) {
3169 std::vector<int64_t> input_ids = MakeTensorIDList(inputs);
3170 if (PyErr_Occurred()) return nullptr;
3171 std::vector<tensorflow::DataType> input_dtypes = MakeTensorDtypeList(inputs);
3172 if (PyErr_Occurred()) return nullptr;
3173
3174 bool should_record = false;
3175 for (TFE_Py_Tape* tape : SafeTapeSet()) {
3176 if (tape->tape->ShouldRecord(input_ids, input_dtypes)) {
3177 should_record = true;
3178 break;
3179 }
3180 }
3181 if (!should_record) {
3182 for (TFE_Py_ForwardAccumulator* accumulator : SafeAccumulatorSet()) {
3183 if (accumulator->accumulator->ShouldRecord(input_ids, input_dtypes)) {
3184 should_record = true;
3185 break;
3186 }
3187 }
3188 }
3189 if (!should_record) Py_RETURN_NONEreturn _Py_INCREF(((PyObject*)((&_Py_NoneStruct)))), (&
_Py_NoneStruct)
;
3190
3191 string c_op_name = TFE_GetPythonString(op_name);
3192
3193 PyObject* op_outputs;
3194 bool op_outputs_tuple_created = false;
3195
3196 if (const auto unused_output_indices =
3197 OpGradientUnusedOutputIndices(c_op_name)) {
3198 if (unused_output_indices->empty()) {
3199 op_outputs = Py_None(&_Py_NoneStruct);
3200 } else {
3201 op_outputs_tuple_created = true;
3202 op_outputs =
3203 CopySequenceSettingIndicesToNull(results, *unused_output_indices);
3204 }
3205 } else {
3206 op_outputs = results;
3207 }
3208
3209 PyObject* op_inputs;
3210 bool op_inputs_tuple_created = false;
3211
3212 if (const auto unused_input_indices =
3213 OpGradientUnusedInputIndices(c_op_name)) {
3214 if (unused_input_indices->empty()) {
3215 op_inputs = Py_None(&_Py_NoneStruct);
3216 } else {
3217 op_inputs_tuple_created = true;
3218 op_inputs =
3219 CopySequenceSettingIndicesToNull(inputs, *unused_input_indices);
3220 }
3221 } else {
3222 op_inputs = inputs;
3223 }
3224
3225 tensorflow::eager::ForwardFunction<PyObject> py_forward_function(
3226 [op_name, attrs, inputs, results](
3227 const std::vector<PyObject*>& input_tangents,
3228 std::vector<PyObject*>* output_tangents, bool use_batch) {
3229 return CallJVPFunction(op_name, attrs, inputs, results, input_tangents,
3230 output_tangents, use_batch);
3231 });
3232 tensorflow::eager::ForwardFunction<PyObject>* forward_function;
3233 if (c_op_name == "While" || c_op_name == "StatelessWhile" ||
3234 c_op_name == "If" || c_op_name == "StatelessIf") {
3235 // Control flow contains non-hashable attributes. Handling them in Python is
3236 // a headache, so instead we'll stay as close to GradientTape's handling as
3237 // possible (a null forward function means the accumulator forwards to a
3238 // tape).
3239 //
3240 // This is safe to do since we'll only see control flow when graph building,
3241 // in which case we can rely on pruning.
3242 forward_function = nullptr;
3243 } else {
3244 forward_function = &py_forward_function;
3245 }
3246
3247 PyObject* num_inputs = PyLong_FromLong(PySequence_Size(inputs));
3248
3249 if (!forward_pass_name_scope) forward_pass_name_scope = Py_None(&_Py_NoneStruct);
3250
3251 TapeSetRecordOperation(
3252 op_name, inputs, results, input_ids, input_dtypes,
3253 [op_name, attrs, num_inputs, op_inputs, op_outputs,
3254 forward_pass_name_scope]() {
3255 Py_INCREF(op_name)_Py_INCREF(((PyObject*)(op_name)));
3256 Py_INCREF(attrs)_Py_INCREF(((PyObject*)(attrs)));
3257 Py_INCREF(num_inputs)_Py_INCREF(((PyObject*)(num_inputs)));
3258 Py_INCREF(op_inputs)_Py_INCREF(((PyObject*)(op_inputs)));
3259 Py_INCREF(op_outputs)_Py_INCREF(((PyObject*)(op_outputs)));
3260 Py_INCREF(forward_pass_name_scope)_Py_INCREF(((PyObject*)(forward_pass_name_scope)));
3261 PyBackwardFunction* function = new PyBackwardFunction(
3262 [op_name, attrs, num_inputs, op_inputs, op_outputs,
3263 forward_pass_name_scope](
3264 PyObject* output_grads,
3265 const std::vector<int64_t>& unneeded_gradients) {
3266 if (PyErr_Occurred()) {
3267 return static_cast<PyObject*>(nullptr);
3268 }
3269 tensorflow::Safe_PyObjectPtr skip_input_indices;
3270 if (!unneeded_gradients.empty()) {
3271 skip_input_indices.reset(
3272 PyTuple_New(unneeded_gradients.size()));
3273 for (int i = 0; i < unneeded_gradients.size(); i++) {
3274 PyTuple_SET_ITEM(PyTuple_SetItem(skip_input_indices.get(), i, GetPythonObjectFromInt
(unneeded_gradients[i]))
3275 skip_input_indices.get(), i,PyTuple_SetItem(skip_input_indices.get(), i, GetPythonObjectFromInt
(unneeded_gradients[i]))
3276 GetPythonObjectFromInt(unneeded_gradients[i]))PyTuple_SetItem(skip_input_indices.get(), i, GetPythonObjectFromInt
(unneeded_gradients[i]))
;
3277 }
3278 } else {
3279 Py_INCREF(Py_None)_Py_INCREF(((PyObject*)((&_Py_NoneStruct))));
3280 skip_input_indices.reset(Py_None(&_Py_NoneStruct));
3281 }
3282 tensorflow::Safe_PyObjectPtr callback_args(Py_BuildValue(
3283 "OOOOOOOO", op_name, attrs, num_inputs, op_inputs, op_outputs,
3284 output_grads, skip_input_indices.get(),
3285 forward_pass_name_scope));
3286
3287 tensorflow::Safe_PyObjectPtr result(
3288 PyObject_CallObject(gradient_function, callback_args.get()));
3289
3290 if (PyErr_Occurred()) return static_cast<PyObject*>(nullptr);
3291
3292 return tensorflow::swig::Flatten(result.get());
3293 });
3294 return function;
3295 },
3296 [op_name, attrs, num_inputs, op_inputs, op_outputs,
3297 forward_pass_name_scope](PyBackwardFunction* backward_function) {
3298 Py_DECREF(op_name)_Py_DECREF(((PyObject*)(op_name)));
3299 Py_DECREF(attrs)_Py_DECREF(((PyObject*)(attrs)));
3300 Py_DECREF(num_inputs)_Py_DECREF(((PyObject*)(num_inputs)));
3301 Py_DECREF(op_inputs)_Py_DECREF(((PyObject*)(op_inputs)));
3302 Py_DECREF(op_outputs)_Py_DECREF(((PyObject*)(op_outputs)));
3303 Py_DECREF(forward_pass_name_scope)_Py_DECREF(((PyObject*)(forward_pass_name_scope)));
3304
3305 delete backward_function;
3306 },
3307 forward_function);
3308
3309 Py_DECREF(num_inputs)_Py_DECREF(((PyObject*)(num_inputs)));
3310 if (op_outputs_tuple_created) Py_DECREF(op_outputs)_Py_DECREF(((PyObject*)(op_outputs)));
3311 if (op_inputs_tuple_created) Py_DECREF(op_inputs)_Py_DECREF(((PyObject*)(op_inputs)));
3312
3313 if (PyErr_Occurred()) {
3314 return nullptr;
3315 }
3316
3317 Py_RETURN_NONEreturn _Py_INCREF(((PyObject*)((&_Py_NoneStruct)))), (&
_Py_NoneStruct)
;
3318}
3319
3320void MaybeNotifyVariableAccessed(PyObject* input) {
3321 DCHECK(CheckResourceVariable(input))while (false && (CheckResourceVariable(input))) ::tensorflow
::internal::LogMessageFatal("tensorflow/python/eager/pywrap_tfe_src.cc"
, 3321)
;
3322 DCHECK(PyObject_HasAttrString(input, "_trainable"))while (false && (PyObject_HasAttrString(input, "_trainable"
))) ::tensorflow::internal::LogMessageFatal("tensorflow/python/eager/pywrap_tfe_src.cc"
, 3322)
;
3323
3324 tensorflow::Safe_PyObjectPtr trainable(
3325 PyObject_GetAttrString(input, "_trainable"));
3326 if (trainable.get() == Py_False((PyObject *) &_Py_FalseStruct)) return;
3327 TFE_Py_TapeVariableAccessed(input);
3328 TFE_Py_VariableWatcherVariableAccessed(input);
3329}
3330
3331bool ReadVariableOp(const FastPathOpExecInfo& parent_op_exec_info,
3332 PyObject* input, tensorflow::Safe_PyObjectPtr* output,
3333 TF_Status* status) {
3334 MaybeNotifyVariableAccessed(input);
3335
3336 TFE_Op* op = TFE_NewOp(parent_op_exec_info.ctx, "ReadVariableOp", status);
3337 auto cleaner = tensorflow::gtl::MakeCleanup([op] { TFE_DeleteOp(op); });
3338 if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) return false;
3339
3340 TFE_OpSetDevice(op, parent_op_exec_info.device_name, status);
3341 if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) return false;
3342
3343 // Set dtype
3344 DCHECK(PyObject_HasAttrString(input, "_dtype"))while (false && (PyObject_HasAttrString(input, "_dtype"
))) ::tensorflow::internal::LogMessageFatal("tensorflow/python/eager/pywrap_tfe_src.cc"
, 3344)
;
3345 tensorflow::Safe_PyObjectPtr dtype(PyObject_GetAttrString(input, "_dtype"));
3346 int value;
3347 if (!ParseTypeValue("_dtype", dtype.get(), status, &value)) {
3348 return false;
3349 }
3350 TFE_OpSetAttrType(op, "dtype", static_cast<TF_DataType>(value));
3351
3352 // Get handle
3353 tensorflow::Safe_PyObjectPtr handle(PyObject_GetAttrString(input, "_handle"));
3354 if (!EagerTensor_CheckExact(handle.get())) return false;
3355 TFE_OpAddInput(op, EagerTensor_Handle(handle.get()), status);
3356 if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) return false;
3357
3358 int num_retvals = 1;
3359 TFE_TensorHandle* output_handle;
3360 TFE_Execute(op, &output_handle, &num_retvals, status);
3361 if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) return false;
3362
3363 // Always create the py object (and correctly DECREF it) from the returned
3364 // value, else the data will leak.
3365 output->reset(EagerTensorFromHandle(output_handle));
3366
3367 // TODO(nareshmodi): Should we run post exec callbacks here?
3368 if (parent_op_exec_info.run_gradient_callback) {
3369 tensorflow::Safe_PyObjectPtr inputs(PyTuple_New(1));
3370 PyTuple_SET_ITEM(inputs.get(), 0, handle.release())PyTuple_SetItem(inputs.get(), 0, handle.release());
3371
3372 tensorflow::Safe_PyObjectPtr outputs(PyTuple_New(1));
3373 Py_INCREF(output->get())_Py_INCREF(((PyObject*)(output->get()))); // stay alive after since tuple steals.
3374 PyTuple_SET_ITEM(outputs.get(), 0, output->get())PyTuple_SetItem(outputs.get(), 0, output->get());
3375
3376 tensorflow::Safe_PyObjectPtr op_string(
3377 GetPythonObjectFromString("ReadVariableOp"));
3378 if (!RecordGradient(op_string.get(), inputs.get(), Py_None(&_Py_NoneStruct),
3379 outputs.get())) {
3380 return false;
3381 }
3382 }
3383
3384 return true;
3385}
3386
3387// Supports 3 cases at the moment:
3388// i) input is an EagerTensor.
3389// ii) input is a ResourceVariable - in this case, the is_variable param is
3390// set to true.
3391// iii) input is an arbitrary python list/tuple (note, this handling doesn't
3392// support packing).
3393//
3394// NOTE: dtype_hint_getter must *always* return a PyObject that can be
3395// decref'd. So if no hint is found, Py_RETURN_NONE (which correctly
3396// increfs Py_None).
3397//
3398// NOTE: This function sets a python error directly, and returns false.
3399// TF_Status is only passed since we don't want to have to reallocate it.
3400bool ConvertToTensor(
3401 const FastPathOpExecInfo& op_exec_info, PyObject* input,
3402 tensorflow::Safe_PyObjectPtr* output_handle,
3403 // This gets a hint for this particular input.
3404 const std::function<tensorflow::DataType()>& dtype_hint_getter,
3405 // This sets the dtype after conversion is complete.
3406 const std::function<void(const tensorflow::DataType dtype)>& dtype_setter,
3407 TF_Status* status) {
3408 if (EagerTensor_CheckExact(input)) {
3409 Py_INCREF(input)_Py_INCREF(((PyObject*)(input)));
3410 output_handle->reset(input);
3411 return true;
3412 } else if (CheckResourceVariable(input)) {
3413 return ReadVariableOp(op_exec_info, input, output_handle, status);
3414 }
3415
3416 // The hint comes from a supposedly similarly typed tensor.
3417 tensorflow::DataType dtype_hint = dtype_hint_getter();
3418
3419 TFE_TensorHandle* handle = tensorflow::ConvertToEagerTensor(
3420 op_exec_info.ctx, input, dtype_hint, op_exec_info.device_name);
3421 if (handle == nullptr) {
3422 return MaybeRaiseExceptionFromTFStatus(status, nullptr);
3423 }
3424
3425 output_handle->reset(EagerTensorFromHandle(handle));
3426 dtype_setter(
3427 static_cast<tensorflow::DataType>(TFE_TensorHandleDataType(handle)));
3428
3429 return true;
3430}
3431
3432// Adds input and type attr to the op, and to the list of flattened
3433// inputs/attrs.
3434bool AddInputToOp(FastPathOpExecInfo* op_exec_info, PyObject* input,
3435 const bool add_type_attr,
3436 const tensorflow::OpDef::ArgDef& input_arg,
3437 std::vector<tensorflow::Safe_PyObjectPtr>* flattened_attrs,
3438 std::vector<tensorflow::Safe_PyObjectPtr>* flattened_inputs,
3439 TFE_Op* op, TF_Status* status) {
3440 // py_eager_tensor's ownership is transferred to flattened_inputs if it is
3441 // required, else the object is destroyed and DECREF'd when the object goes
3442 // out of scope in this function.
3443 tensorflow::Safe_PyObjectPtr py_eager_tensor = nullptr;
3444
3445 if (!ConvertToTensor(
3446 *op_exec_info, input, &py_eager_tensor,
3447 [&]() {
3448 if (input_arg.type() != tensorflow::DataType::DT_INVALID) {
3449 return input_arg.type();
3450 }
3451 return MaybeGetDTypeForAttr(input_arg.type_attr(), op_exec_info);
3452 },
3453 [&](const tensorflow::DataType dtype) {
3454 op_exec_info->cached_dtypes[input_arg.type_attr()] = dtype;
3455 },
3456 status)) {
3457 return false;
3458 }
3459
3460 TFE_TensorHandle* input_handle = EagerTensor_Handle(py_eager_tensor.get());
3461
3462 if (add_type_attr && !input_arg.type_attr().empty()) {
3463 auto dtype = TFE_TensorHandleDataType(input_handle);
3464 TFE_OpSetAttrType(op, input_arg.type_attr().data(), dtype);
3465 if (flattened_attrs != nullptr) {
3466 flattened_attrs->emplace_back(
3467 GetPythonObjectFromString(input_arg.type_attr()));
3468 flattened_attrs->emplace_back(PyLong_FromLong(dtype));
3469 }
3470 }
3471
3472 if (flattened_inputs != nullptr) {
3473 flattened_inputs->emplace_back(std::move(py_eager_tensor));
3474 }
3475
3476 TFE_OpAddInput(op, input_handle, status);
3477 if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) {
3478 return false;
3479 }
3480
3481 return true;
3482}
3483
3484const char* GetDeviceName(PyObject* py_device_name) {
3485 if (py_device_name != Py_None(&_Py_NoneStruct)) {
3486 return TFE_GetPythonString(py_device_name);
3487 }
3488 return nullptr;
3489}
3490
3491bool RaiseIfNotPySequence(PyObject* seq, const string& attr_name) {
3492 if (!PySequence_Check(seq)) {
3493 PyErr_SetString(PyExc_TypeError,
3494 Printf("expected a sequence for attr %s, got %s instead",
3495 attr_name.data(), seq->ob_type->tp_name)
3496 .data());
3497
3498 return false;
3499 }
3500 if (PyArray_Check(seq)((((PyObject*)(seq))->ob_type) == (&(*(PyTypeObject *)
_tensorflow_numpy_api[2])) || PyType_IsSubtype((((PyObject*)(
seq))->ob_type), (&(*(PyTypeObject *)_tensorflow_numpy_api
[2]))))
&&
3501 PyArray_NDIM(reinterpret_cast<PyArrayObject*>(seq)) != 1) {
3502 PyErr_SetString(PyExc_ValueError,
3503 Printf("expected a sequence for attr %s, got an ndarray "
3504 "with rank %d instead",
3505 attr_name.data(),
3506 PyArray_NDIM(reinterpret_cast<PyArrayObject*>(seq)))
3507 .data());
3508 return false;
3509 }
3510 return true;
3511}
3512
3513bool RunCallbacks(
3514 const FastPathOpExecInfo& op_exec_info, PyObject* args,
3515 int num_inferred_attrs,
3516 const std::vector<tensorflow::Safe_PyObjectPtr>& flattened_inputs,
3517 const std::vector<tensorflow::Safe_PyObjectPtr>& flattened_attrs,
3518 PyObject* flattened_result) {
3519 DCHECK(op_exec_info.run_callbacks)while (false && (op_exec_info.run_callbacks)) ::tensorflow
::internal::LogMessageFatal("tensorflow/python/eager/pywrap_tfe_src.cc"
, 3519)
;
3520
3521 tensorflow::Safe_PyObjectPtr inputs(PyTuple_New(flattened_inputs.size()));
3522 for (int i = 0; i < flattened_inputs.size(); i++) {
3523 PyObject* input = flattened_inputs[i].get();
3524 Py_INCREF(input)_Py_INCREF(((PyObject*)(input)));
3525 PyTuple_SET_ITEM(inputs.get(), i, input)PyTuple_SetItem(inputs.get(), i, input);
3526 }
3527
3528 int num_non_inferred_attrs = PyTuple_GET_SIZE(args)(((PyVarObject*)(((static_cast<void> (0)), (PyTupleObject
*)(args))))->ob_size)
- num_inferred_attrs;
3529 int num_attrs = flattened_attrs.size() + num_non_inferred_attrs;
3530 tensorflow::Safe_PyObjectPtr attrs(PyTuple_New(num_attrs));
3531
3532 for (int i = 0; i < num_non_inferred_attrs; i++) {
3533 auto* attr = PyTuple_GET_ITEM(args, num_inferred_attrs + i)(((static_cast<void> (0)), (PyTupleObject *)(args))->
ob_item[num_inferred_attrs + i])
;
3534 Py_INCREF(attr)_Py_INCREF(((PyObject*)(attr)));
3535 PyTuple_SET_ITEM(attrs.get(), i, attr)PyTuple_SetItem(attrs.get(), i, attr);
3536 }
3537
3538 for (int i = num_non_inferred_attrs; i < num_attrs; i++) {
3539 PyObject* attr_or_name =
3540 flattened_attrs.at(i - num_non_inferred_attrs).get();
3541 Py_INCREF(attr_or_name)_Py_INCREF(((PyObject*)(attr_or_name)));
3542 PyTuple_SET_ITEM(attrs.get(), i, attr_or_name)PyTuple_SetItem(attrs.get(), i, attr_or_name);
3543 }
3544
3545 if (op_exec_info.run_gradient_callback) {
3546 if (!RecordGradient(op_exec_info.op_name, inputs.get(), attrs.get(),
3547 flattened_result)) {
3548 return false;
3549 }
3550 }
3551
3552 if (op_exec_info.run_post_exec_callbacks) {
3553 tensorflow::Safe_PyObjectPtr callback_args(
3554 Py_BuildValue("OOOOO", op_exec_info.op_name, inputs.get(), attrs.get(),
3555 flattened_result, op_exec_info.name));
3556 for (Py_ssize_t i = 0; i < PyList_Size(op_exec_info.callbacks); i++) {
3557 PyObject* callback_fn = PyList_GET_ITEM(op_exec_info.callbacks, i)(((PyListObject *)(op_exec_info.callbacks))->ob_item[i]);
3558 if (!PyCallable_Check(callback_fn)) {
3559 PyErr_SetString(
3560 PyExc_TypeError,
3561 Printf("expected a function for "
3562 "post execution callback in index %ld, got %s instead",
3563 i, callback_fn->ob_type->tp_name)
3564 .c_str());
3565 return false;
3566 }
3567 PyObject* callback_result =
3568 PyObject_CallObject(callback_fn, callback_args.get());
3569 if (!callback_result) {
3570 return false;
3571 }
3572 Py_DECREF(callback_result)_Py_DECREF(((PyObject*)(callback_result)));
3573 }
3574 }
3575
3576 return true;
3577}
3578
3579} // namespace
3580
3581PyObject* TFE_Py_FastPathExecute_C(PyObject* args) {
3582 tensorflow::profiler::TraceMe activity(
3583 "TFE_Py_FastPathExecute_C", tensorflow::profiler::TraceMeLevel::kInfo);
3584 Py_ssize_t args_size = PyTuple_GET_SIZE(args)(((PyVarObject*)(((static_cast<void> (0)), (PyTupleObject
*)(args))))->ob_size)
;
3585 if (args_size < FAST_PATH_EXECUTE_ARG_INPUT_START) {
3586 PyErr_SetString(
3587 PyExc_ValueError,
3588 Printf("There must be at least %d items in the input tuple.",
3589 FAST_PATH_EXECUTE_ARG_INPUT_START)
3590 .c_str());
3591 return nullptr;
3592 }
3593
3594 FastPathOpExecInfo op_exec_info;
3595
3596 PyObject* py_eager_context =
3597 PyTuple_GET_ITEM(args, FAST_PATH_EXECUTE_ARG_CONTEXT)(((static_cast<void> (0)), (PyTupleObject *)(args))->
ob_item[FAST_PATH_EXECUTE_ARG_CONTEXT])
;
3598
3599 // TODO(edoper): Use interned string here
3600 PyObject* eager_context_handle =
3601 PyObject_GetAttrString(py_eager_context, "_context_handle");
3602
3603 TFE_Context* ctx = reinterpret_cast<TFE_Context*>(
3604 PyCapsule_GetPointer(eager_context_handle, nullptr));
3605 op_exec_info.ctx = ctx;
3606 op_exec_info.args = args;
3607
3608 if (ctx == nullptr) {
3609 // The context hasn't been initialized. It will be in the slow path.
3610 RaiseFallbackException(
3611 "This function does not handle the case of the path where "
3612 "all inputs are not already EagerTensors.");
3613 return nullptr;
3614 }
3615
3616 auto* tld = tensorflow::GetEagerContextThreadLocalData(py_eager_context);
3617 if (tld == nullptr) {
3618 return nullptr;
3619 }
3620 op_exec_info.device_name = GetDeviceName(tld->device_name.get());
3621 op_exec_info.callbacks = tld->op_callbacks.get();
3622
3623 op_exec_info.op_name = PyTuple_GET_ITEM(args, FAST_PATH_EXECUTE_ARG_OP_NAME)(((static_cast<void> (0)), (PyTupleObject *)(args))->
ob_item[FAST_PATH_EXECUTE_ARG_OP_NAME])
;
3624 op_exec_info.name = PyTuple_GET_ITEM(args, FAST_PATH_EXECUTE_ARG_NAME)(((static_cast<void> (0)), (PyTupleObject *)(args))->
ob_item[FAST_PATH_EXECUTE_ARG_NAME])
;
3625
3626 // TODO(nareshmodi): Add a benchmark for the fast-path with gradient callbacks
3627 // (similar to benchmark_tf_gradient_function_*). Also consider using an
3628 // InlinedVector for flattened_attrs and flattened_inputs if the benchmarks
3629 // point out problems with heap allocs.
3630 op_exec_info.run_gradient_callback =
3631 !*ThreadTapeIsStopped() && HasAccumulatorOrTape();
3632 op_exec_info.run_post_exec_callbacks =
3633 op_exec_info.callbacks != Py_None(&_Py_NoneStruct) &&
3634 PyList_Size(op_exec_info.callbacks) > 0;
3635 op_exec_info.run_callbacks = op_exec_info.run_gradient_callback ||
3636 op_exec_info.run_post_exec_callbacks;
3637
3638 TF_Status* status = GetStatus();
3639 const char* op_name = TFE_GetPythonString(op_exec_info.op_name);
3640 if (op_name == nullptr) {
3641 PyErr_SetString(PyExc_TypeError,
3642 Printf("expected a string for op_name, got %s instead",
3643 op_exec_info.op_name->ob_type->tp_name)
3644 .c_str());
3645 return nullptr;
3646 }
3647
3648 TFE_Op* op = GetOp(ctx, op_name, op_exec_info.device_name, status);
3649
3650 auto cleaner = tensorflow::gtl::MakeCleanup([status, ctx, op] {
3651 ReturnStatus(status);
3652 ReturnOp(ctx, op);
3653 });
3654
3655 if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) {
3656 return nullptr;
3657 }
3658
3659 tensorflow::unwrap(op)->SetStackTrace(tensorflow::GetStackTrace(
3660 tensorflow::StackTrace::kStackTraceInitialSize));
3661
3662 const tensorflow::OpDef* op_def = tensorflow::unwrap(op)->OpDef();
3663 if (op_def == nullptr) return nullptr;
3664
3665 if (args_size <
3666 FAST_PATH_EXECUTE_ARG_INPUT_START + op_def->input_arg_size()) {
3667 PyErr_SetString(
3668 PyExc_ValueError,
3669 Printf("Tuple size smaller than intended. Expected to be at least %d, "
3670 "was %ld",
3671 FAST_PATH_EXECUTE_ARG_INPUT_START + op_def->input_arg_size(),
3672 args_size)
3673 .c_str());
3674 return nullptr;
3675 }
3676
3677 if (!CheckInputsOk(args, FAST_PATH_EXECUTE_ARG_INPUT_START, *op_def)) {
3678 RaiseFallbackException(
3679 "This function does not handle the case of the path where "
3680 "all inputs are not already EagerTensors.");
3681 return nullptr;
3682 }
3683
3684 op_exec_info.attr_to_inputs_map = GetAttrToInputsMapHoldingGIL(*op_def);
3685 op_exec_info.default_dtypes = GetAttrToDefaultsMapHoldingGIL(*op_def);
3686
3687 // Mapping of attr name to size - used to calculate the number of values
3688 // to be expected by the TFE_Execute run.
3689 tensorflow::gtl::FlatMap<string, int64_t> attr_list_sizes;
3690
3691 // Set non-inferred attrs, including setting defaults if the attr is passed in
3692 // as None.
3693 for (int i = FAST_PATH_EXECUTE_ARG_INPUT_START + op_def->input_arg_size();
3694 i < args_size; i += 2) {
3695 PyObject* py_attr_name = PyTuple_GET_ITEM(args, i)(((static_cast<void> (0)), (PyTupleObject *)(args))->
ob_item[i])
;
3696 const char* attr_name = TFE_GetPythonString(py_attr_name);
3697 PyObject* py_attr_value = PyTuple_GET_ITEM(args, i + 1)(((static_cast<void> (0)), (PyTupleObject *)(args))->
ob_item[i + 1])
;
3698
3699 // Not creating an index since most of the time there are not more than a
3700 // few attrs.
3701 // TODO(nareshmodi): Maybe include the index as part of the
3702 // OpRegistrationData.
3703 for (const auto& attr : op_def->attr()) {
3704 if (tensorflow::StringPiece(attr_name) == attr.name()) {
3705 SetOpAttrWithDefaults(ctx, op, attr, attr_name, py_attr_value,
3706 &attr_list_sizes, status);
3707
3708 if (!status->status.ok()) {
3709 VLOG(1)(__builtin_expect(!!(!(([](int level, const char* fname) { static
const bool vmodule_activated = ::tensorflow::internal::LogMessage
::VmoduleActivated(fname, level); return vmodule_activated; }
)(1, "tensorflow/python/eager/pywrap_tfe_src.cc"))), 1)) ? (void
)0 : ::tensorflow::internal::Voidifier() & ::tensorflow::
internal::LogMessage("tensorflow/python/eager/pywrap_tfe_src.cc"
, 3709, tensorflow::INFO)
<< "Falling back to slow path for Op \"" << op_def->name()
3710 << "\" since we are unable to set the value for attr \""
3711 << attr.name() << "\" due to: " << TF_Message(status);
3712 RaiseFallbackException(TF_Message(status));
3713 return nullptr;
3714 }
3715
3716 break;
3717 }
3718 }
3719 }
3720
3721 // Flat attrs and inputs as required by the record_gradient call. The attrs
3722 // here only contain inferred attrs (non-inferred attrs are added directly
3723 // from the input args).
3724 // All items in flattened_attrs and flattened_inputs contain
3725 // Safe_PyObjectPtr - any time something steals a reference to this, it must
3726 // INCREF.
3727 // TODO(nareshmodi): figure out why PyList_New/PyList_Append don't work
3728 // directly.
3729 std::unique_ptr<std::vector<tensorflow::Safe_PyObjectPtr>> flattened_attrs =
3730 nullptr;
3731 std::unique_ptr<std::vector<tensorflow::Safe_PyObjectPtr>> flattened_inputs =
3732 nullptr;
3733
3734 // TODO(nareshmodi): Encapsulate callbacks information into a struct.
3735 if (op_exec_info.run_callbacks) {
3736 flattened_attrs.reset(new std::vector<tensorflow::Safe_PyObjectPtr>);
3737 flattened_inputs.reset(new std::vector<tensorflow::Safe_PyObjectPtr>);
3738 }
3739
3740 // Add inferred attrs and inputs.
3741 // The following code might set duplicate type attrs. This will result in
3742 // the CacheKey for the generated AttrBuilder possibly differing from
3743 // those where the type attrs are correctly set. Inconsistent CacheKeys
3744 // for ops means that there might be unnecessarily duplicated kernels.
3745 // TODO(nareshmodi): Fix this.
3746 for (int i = 0; i < op_def->input_arg_size(); i++) {
3747 const auto& input_arg = op_def->input_arg(i);
3748
3749 PyObject* input =
3750 PyTuple_GET_ITEM(args, FAST_PATH_EXECUTE_ARG_INPUT_START + i)(((static_cast<void> (0)), (PyTupleObject *)(args))->
ob_item[FAST_PATH_EXECUTE_ARG_INPUT_START + i])
;
3751 if (!input_arg.number_attr().empty()) {
3752 // The item is a homogeneous list.
3753 if (!RaiseIfNotPySequence(input, input_arg.number_attr())) return nullptr;
3754 tensorflow::Safe_PyObjectPtr fast_input(
3755 PySequence_Fast(input, "Could not parse sequence."));
3756 if (fast_input.get() == nullptr) {
3757 return nullptr;
3758 }
3759 Py_ssize_t len = PySequence_Fast_GET_SIZE(fast_input.get())(((((((PyObject*)(fast_input.get()))->ob_type))->tp_flags
& ((1UL << 25))) != 0) ? ((static_cast<void>
(0)), (((PyVarObject*)(fast_input.get()))->ob_size)) : ((
(PyVarObject*)(((static_cast<void> (0)), (PyTupleObject
*)(fast_input.get()))))->ob_size))
;
3760 PyObject** fast_input_array = PySequence_Fast_ITEMS(fast_input.get())(((((((PyObject*)(fast_input.get()))->ob_type))->tp_flags
& ((1UL << 25))) != 0) ? ((PyListObject *)(fast_input
.get()))->ob_item : ((PyTupleObject *)(fast_input.get()))->
ob_item)
;
3761
3762 TFE_OpSetAttrInt(op, input_arg.number_attr().data(), len);
3763 if (op_exec_info.run_callbacks) {
3764 flattened_attrs->emplace_back(
3765 GetPythonObjectFromString(input_arg.number_attr()));
3766 flattened_attrs->emplace_back(PyLong_FromLong(len));
3767 }
3768 attr_list_sizes[input_arg.number_attr()] = len;
3769
3770 if (len > 0) {
3771 // First item adds the type attr.
3772 if (!AddInputToOp(&op_exec_info, fast_input_array[0], true, input_arg,
3773 flattened_attrs.get(), flattened_inputs.get(), op,
3774 status)) {
3775 return nullptr;
3776 }
3777
3778 for (Py_ssize_t j = 1; j < len; j++) {
3779 // Since the list is homogeneous, we don't need to re-add the attr.
3780 if (!AddInputToOp(&op_exec_info, fast_input_array[j], false,
3781 input_arg, nullptr /* flattened_attrs */,
3782 flattened_inputs.get(), op, status)) {
3783 return nullptr;
3784 }
3785 }
3786 }
3787 } else if (!input_arg.type_list_attr().empty()) {
3788 // The item is a heterogeneous list.
3789 if (!RaiseIfNotPySequence(input, input_arg.type_list_attr())) {
3790 return nullptr;
3791 }
3792 tensorflow::Safe_PyObjectPtr fast_input(
3793 PySequence_Fast(input, "Could not parse sequence."));
3794 if (fast_input.get() == nullptr) {
3795 return nullptr;
3796 }
3797 const string& attr_name = input_arg.type_list_attr();
3798 Py_ssize_t len = PySequence_Fast_GET_SIZE(fast_input.get())(((((((PyObject*)(fast_input.get()))->ob_type))->tp_flags
& ((1UL << 25))) != 0) ? ((static_cast<void>
(0)), (((PyVarObject*)(fast_input.get()))->ob_size)) : ((
(PyVarObject*)(((static_cast<void> (0)), (PyTupleObject
*)(fast_input.get()))))->ob_size))
;
3799 PyObject** fast_input_array = PySequence_Fast_ITEMS(fast_input.get())(((((((PyObject*)(fast_input.get()))->ob_type))->tp_flags
& ((1UL << 25))) != 0) ? ((PyListObject *)(fast_input
.get()))->ob_item : ((PyTupleObject *)(fast_input.get()))->
ob_item)
;
3800 tensorflow::gtl::InlinedVector<TF_DataType, 4> attr_value(len);
3801 PyObject* py_attr_value = nullptr;
3802 if (op_exec_info.run_callbacks) {
3803 py_attr_value = PyTuple_New(len);
3804 }
3805 for (Py_ssize_t j = 0; j < len; j++) {
3806 PyObject* py_input = fast_input_array[j];
3807 tensorflow::Safe_PyObjectPtr py_eager_tensor;
3808 if (!ConvertToTensor(
3809 op_exec_info, py_input, &py_eager_tensor,
3810 []() { return tensorflow::DT_INVALID; },
3811 [](const tensorflow::DataType dtype) {}, status)) {
3812 return nullptr;
3813 }
3814
3815 TFE_TensorHandle* input_handle =
3816 EagerTensor_Handle(py_eager_tensor.get());
3817
3818 attr_value[j] = TFE_TensorHandleDataType(input_handle);
3819
3820 TFE_OpAddInput(op, input_handle, status);
3821 if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) {
3822 return nullptr;
3823 }
3824
3825 if (op_exec_info.run_callbacks) {
3826 flattened_inputs->emplace_back(std::move(py_eager_tensor));
3827
3828 PyTuple_SET_ITEM(py_attr_value, j, PyLong_FromLong(attr_value[j]))PyTuple_SetItem(py_attr_value, j, PyLong_FromLong(attr_value[
j]))
;
3829 }
3830 }
3831 if (op_exec_info.run_callbacks) {
3832 flattened_attrs->emplace_back(GetPythonObjectFromString(attr_name));
3833 flattened_attrs->emplace_back(py_attr_value);
3834 }
3835 TFE_OpSetAttrTypeList(op, attr_name.data(), attr_value.data(),
3836 attr_value.size());
3837 attr_list_sizes[attr_name] = len;
3838 } else {
3839 // The item is a single item.
3840 if (!AddInputToOp(&op_exec_info, input, true, input_arg,
3841 flattened_attrs.get(), flattened_inputs.get(), op,
3842 status)) {
3843 return nullptr;
3844 }
3845 }
3846 }
3847
3848 int64_t num_outputs = 0;
3849 for (int i = 0; i < op_def->output_arg_size(); i++) {
3850 const auto& output_arg = op_def->output_arg(i);
3851 int64_t delta = 1;
3852 if (!output_arg.number_attr().empty()) {
3853 delta = attr_list_sizes[output_arg.number_attr()];
3854 } else if (!output_arg.type_list_attr().empty()) {
3855 delta = attr_list_sizes[output_arg.type_list_attr()];
3856 }
3857 if (delta < 0) {
3858 RaiseFallbackException(
3859 "Attributes suggest that the size of an output list is less than 0");
3860 return nullptr;
3861 }
3862 num_outputs += delta;
3863 }
3864
3865 // If number of retvals is larger than int32, we error out.
3866 if (static_cast<int64_t>(static_cast<int32_t>(num_outputs)) != num_outputs) {
3867 PyErr_SetString(
3868 PyExc_ValueError,
3869 Printf("Number of outputs is too big: %ld", num_outputs).c_str());
3870 return nullptr;
3871 }
3872 int num_retvals = num_outputs;
3873
3874 tensorflow::gtl::InlinedVector<TFE_TensorHandle*, 2> retvals(num_retvals);
3875
3876 Py_BEGIN_ALLOW_THREADS{ PyThreadState *_save; _save = PyEval_SaveThread();;
3877 TFE_Execute(op, retvals.data(), &num_retvals, status);
3878 Py_END_ALLOW_THREADSPyEval_RestoreThread(_save); };
3879
3880 if (!status->status.ok()) {
3881 // Augment the status with the op_name for easier debugging similar to
3882 // TFE_Py_Execute.
3883 status->status = tensorflow::errors::CreateWithUpdatedMessage(
3884 status->status, tensorflow::strings::StrCat(
3885 TF_Message(status), " [Op:",
3886 TFE_GetPythonString(op_exec_info.op_name), "]"));
3887 MaybeRaiseExceptionFromTFStatus(status, nullptr);
3888 return nullptr;
3889 }
3890
3891 tensorflow::Safe_PyObjectPtr flat_result(PyList_New(num_retvals));
3892 for (int i = 0; i < num_retvals; ++i) {
3893 PyList_SET_ITEM(flat_result.get(), i, EagerTensorFromHandle(retvals[i]))PyList_SetItem(flat_result.get(), i, EagerTensorFromHandle(retvals
[i]))
;
3894 }
3895
3896 if (op_exec_info.run_callbacks) {
3897 if (!RunCallbacks(
3898 op_exec_info, args,
3899 FAST_PATH_EXECUTE_ARG_INPUT_START + op_def->input_arg_size(),
3900 *flattened_inputs, *flattened_attrs, flat_result.get())) {
3901 return nullptr;
3902 }
3903 }
3904
3905 // Unflatten results.
3906 if (op_def->output_arg_size() == 0) {
3907 Py_RETURN_NONEreturn _Py_INCREF(((PyObject*)((&_Py_NoneStruct)))), (&
_Py_NoneStruct)
;
3908 }
3909
3910 if (op_def->output_arg_size() == 1) {
3911 if (!op_def->output_arg(0).number_attr().empty() ||
3912 !op_def->output_arg(0).type_list_attr().empty()) {
3913 return flat_result.release();
3914 } else {
3915 auto* result = PyList_GET_ITEM(flat_result.get(), 0)(((PyListObject *)(flat_result.get()))->ob_item[0]);
3916 Py_INCREF(result)_Py_INCREF(((PyObject*)(result)));
3917 return result;
3918 }
3919 }
3920
3921 // Correctly output the results that are made into a namedtuple.
3922 PyObject* result = PyList_New(op_def->output_arg_size());
3923 int flat_result_index = 0;
3924 for (int i = 0; i < op_def->output_arg_size(); i++) {
3925 if (!op_def->output_arg(i).number_attr().empty()) {
3926 int list_length = attr_list_sizes[op_def->output_arg(i).number_attr()];
3927 PyObject* inner_list = PyList_New(list_length);
3928 for (int j = 0; j < list_length; j++) {
3929 PyObject* obj = PyList_GET_ITEM(flat_result.get(), flat_result_index++)(((PyListObject *)(flat_result.get()))->ob_item[flat_result_index
++])
;
3930 Py_INCREF(obj)_Py_INCREF(((PyObject*)(obj)));
3931 PyList_SET_ITEM(inner_list, j, obj)PyList_SetItem(inner_list, j, obj);
3932 }
3933 PyList_SET_ITEM(result, i, inner_list)PyList_SetItem(result, i, inner_list);
3934 } else if (!op_def->output_arg(i).type_list_attr().empty()) {
3935 int list_length = attr_list_sizes[op_def->output_arg(i).type_list_attr()];
3936 PyObject* inner_list = PyList_New(list_length);
3937 for (int j = 0; j < list_length; j++) {
3938 PyObject* obj = PyList_GET_ITEM(flat_result.get(), flat_result_index++)(((PyListObject *)(flat_result.get()))->ob_item[flat_result_index
++])
;
3939 Py_INCREF(obj)_Py_INCREF(((PyObject*)(obj)));
3940 PyList_SET_ITEM(inner_list, j, obj)PyList_SetItem(inner_list, j, obj);
3941 }
3942 PyList_SET_ITEM(result, i, inner_list)PyList_SetItem(result, i, inner_list);
3943 } else {
3944 PyObject* obj = PyList_GET_ITEM(flat_result.get(), flat_result_index++)(((PyListObject *)(flat_result.get()))->ob_item[flat_result_index
++])
;
3945 Py_INCREF(obj)_Py_INCREF(((PyObject*)(obj)));
3946 PyList_SET_ITEM(result, i, obj)PyList_SetItem(result, i, obj);
3947 }
3948 }
3949 return result;
3950}
3951
3952PyObject* TFE_Py_RecordGradient(PyObject* op_name, PyObject* inputs,
3953 PyObject* attrs, PyObject* results,
3954 PyObject* forward_pass_name_scope) {
3955 if (*ThreadTapeIsStopped() || !HasAccumulatorOrTape()) {
3956 Py_RETURN_NONEreturn _Py_INCREF(((PyObject*)((&_Py_NoneStruct)))), (&
_Py_NoneStruct)
;
3957 }
3958
3959 return RecordGradient(op_name, inputs, attrs, results,
3960 forward_pass_name_scope);
3961}
3962
3963namespace {
// Single-character string tokens used to build the string encoding stored in
// EncodeResult::str (e.g. kTensor marks a tensor / TensorSpec entry and
// kName...kNameEnd wrap a TensorSpec name).
constexpr char kTensor[] = "T";
constexpr char kList[] = "L";
constexpr char kListEnd[] = "l";
constexpr char kTuple[] = "U";
constexpr char kTupleEnd[] = "u";
constexpr char kDIter[] = "I";
constexpr char kDict[] = "D";
constexpr char kRaw[] = "R";
constexpr char kResourceVariable[] = "r";
constexpr char kShape[] = "s";
constexpr char kShapeDelim[] = "-";
constexpr char kDType[] = "d";
constexpr char kNone[] = "n";
constexpr char kCompositeTensor[] = "C";
constexpr char kAttrs[] = "A";
constexpr char kAttrsEnd[] = "a";
constexpr char kName[] = "'";
constexpr char kNameEnd[] = "'";
constexpr char kLocalIdDelim[] = "_";
3983
3984// Container for storing generated string encoding as well as the raw python
3985// objects that were not included in the string.
3986struct EncodeResult {
3987 string str;
3988 std::vector<PyObject*> objects;
3989
3990 PyObject* ToPyTuple() {
3991 PyObject* result = PyTuple_New(2);
3992
3993 PyTuple_SET_ITEM(result, 0, GetPythonObjectFromString(str))PyTuple_SetItem(result, 0, GetPythonObjectFromString(str));
3994
3995 if (objects.empty()) {
3996 Py_INCREF(Py_None)_Py_INCREF(((PyObject*)((&_Py_NoneStruct))));
3997 PyTuple_SET_ITEM(result, 1, Py_None)PyTuple_SetItem(result, 1, (&_Py_NoneStruct));
3998 } else {
3999 PyObject* objects_tuple = PyTuple_New(objects.size());
4000
4001 for (int i = 0; i < objects.size(); i++) {
4002 PyTuple_SET_ITEM(objects_tuple, i, objects[i])PyTuple_SetItem(objects_tuple, i, objects[i]);
4003 }
4004
4005 PyTuple_SET_ITEM(result, 1, objects_tuple)PyTuple_SetItem(result, 1, objects_tuple);
4006 }
4007
4008 return result;
4009 }
4010};
4011
4012// Gives each unique resource_id a unique incremental local_id. Provides a
4013// string encoding that informs an order and uniqueness sensitive input
4014// signature.
4015// This class is not thread safe and is not meant to be shared across threads.
4016class LocalResourceIdMap {
4017 public:
4018 // When the resource ID is known (such as for OwnedIterator).
4019 // Returns the existing local ID (if present) or a new unique one.
4020 int AddResourceId(int resource_id) {
4021 const auto& it = resource_id_to_local_id_.find(resource_id);
4022 if (it == resource_id_to_local_id_.end()) {
4023 resource_id_to_local_id_[resource_id] = next_local_id_;
4024 return next_local_id_++;
4025 } else {
4026 return it->second;
4027 }
4028 }
4029
4030 // When the resource ID is not known (such as for IteratorSpec).
4031 // Returns a new unique local ID.
4032 int AddUnknownResource() { return next_local_id_++; }
4033
4034 private:
4035 absl::flat_hash_map<int, int> resource_id_to_local_id_;
4036 int next_local_id_ = 0;
4037};
4038
4039// Contains encoding configuration, intermediary data and result.
4040struct EncodingContext {
4041 bool include_tensor_ranks_only;
4042 bool encode_variable_by_resource_id;
4043
4044 LocalResourceIdMap resource_id_map;
4045 EncodeResult result;
4046};
4047
4048tensorflow::Status EncodeTensorOrTensorSpec(PyObject* arg, bool is_tensor_spec,
4049 EncodingContext& context) {
4050 absl::StrAppend(&context.result.str, kTensor);
4051
4052 if (is_tensor_spec) {
4053 tensorflow::Safe_PyObjectPtr name(PyObject_GetAttrString(arg, "name"));
4054 if (name != nullptr && name.get() != Py_None(&_Py_NoneStruct)) {
4055 absl::StrAppend(&context.result.str, kName,
4056 TFE_GetPythonString(name.get()), kNameEnd);
4057 }
4058 }
4059
4060 tensorflow::Safe_PyObjectPtr dtype_object(
4061 PyObject_GetAttrString(arg, "dtype"));
4062 if (dtype_object == nullptr) {
4063 return tensorflow::errors::InvalidArgument(
4064 "tf.TensorSpec object doesn't have dtype() attr.");
4065 }
4066
4067 tensorflow::Safe_PyObjectPtr dtype_enum(
4068 PyObject_GetAttrString(dtype_object.get(), "_type_enum"));
4069 if (dtype_enum == nullptr) {
4070 return tensorflow::errors::InvalidArgument(
4071 "tf.TensorSpec's dtype object doesn't have _type_enum() "
4072 "attr.");
4073 }
4074
4075 tensorflow::DataType dtype =
4076 static_cast<tensorflow::DataType>(MakeInt(dtype_enum.get()));
4077 absl::StrAppend(&context.result.str, kDType, dtype);
4078
4079 tensorflow::Safe_PyObjectPtr shape_tuple(
4080 PyObject_GetAttrString(arg, "shape"));
4081 if (shape_tuple == nullptr) {
4082 return tensorflow::errors::InvalidArgument(
4083 "tf.TensorSpec object doesn't have shape() attr.");
4084 }
4085
4086 tensorflow::Safe_PyObjectPtr rank(
4087 PyObject_GetAttr(shape_tuple.get(), PyUnicode_FromString("rank")));
4088 if (rank == nullptr || rank.get() == Py_None(&_Py_NoneStruct)) {
4089 // Unknown shape, encode that directly.
4090 absl::StrAppend(&context.result.str, kNone);
4091 return tensorflow::Status::OK();
4092 }
4093
4094 absl::StrAppend(&context.result.str, kShape);
4095
4096 tensorflow::Safe_PyObjectPtr shape_seq(PySequence_Fast(
4097 shape_tuple.get(), "shape_tuple didn't return a sequence"));
4098
4099 int len = MakeInt(rank.get());
4100 PyObject** shape_seq_array = PySequence_Fast_ITEMS(shape_seq.get())(((((((PyObject*)(shape_seq.get()))->ob_type))->tp_flags
& ((1UL << 25))) != 0) ? ((PyListObject *)(shape_seq
.get()))->ob_item : ((PyTupleObject *)(shape_seq.get()))->
ob_item)
;
4101
4102 if (context.include_tensor_ranks_only) {
4103 absl::StrAppend(&context.result.str, len);
4104 } else {
4105 for (int i = 0; i < len; ++i) {
4106 // Can be None, int or a Dimension object.
4107 PyObject* dimension = shape_seq_array[i];
4108
4109 // If it is a Dimension object, then we must extract value from it first.
4110 bool is_dimension_class = PyObject_HasAttrString(dimension, "value");
4111 tensorflow::Safe_PyObjectPtr dimension_holder;
4112 if (is_dimension_class) {
4113 dimension_holder =
4114 tensorflow::make_safe(PyObject_GetAttrString(dimension, "value"));
4115 dimension = dimension_holder.get();
4116 }
4117
4118 if (dimension == Py_None(&_Py_NoneStruct)) {
4119 absl::StrAppend(&context.result.str, kNone);
4120 } else {
4121 absl::StrAppend(&context.result.str, MakeInt(dimension), kShapeDelim);
4122 }
4123 }
4124 }
4125
4126 return tensorflow::Status::OK();
4127}
4128
4129// TODO(b/199534088): Remove this function by using EncodeResource instead.
4130tensorflow::Status EncodeOwnedIterator(PyObject* arg,
4131 EncodingContext& context) {
4132 PyObject* type_spec(PyObject_GetAttrString(arg, "_type_spec"));
4133 if (type_spec == nullptr) {
4134 return tensorflow::errors::InvalidArgument(
4135 "Error while reading OwnedIterator._type_spec.");
4136 }
4137 context.result.objects.push_back(type_spec);
4138
4139 // Add resource tracking
4140 tensorflow::Safe_PyObjectPtr itr_res(
4141 PyObject_GetAttrString(arg, "_iterator_resource"));
4142 if (itr_res == nullptr) {
4143 return tensorflow::errors::InvalidArgument(
4144 "Error while reading Dataset iterator resource.");
4145 }
4146 // OwnedIterator should ideally always provide a unique resource id.
4147 // TODO(b/199534088) Cases where resource_id is not provided need to be fixed.
4148 if (tensorflow::swig::IsTensor(itr_res.get())) {
4149 absl::StrAppend(&context.result.str, kDIter);
4150 tensorflow::Safe_PyObjectPtr py_resource_id(
4151 PyObject_GetAttrString(itr_res.get(), "_id"));
4152 if (py_resource_id == nullptr) {
4153 return tensorflow::errors::InvalidArgument(
4154 "Error while reading Dataset iterator resouce id.");
4155 }
4156 int resource_id = PyLong_AsSize_t(py_resource_id.get());
4157 if (resource_id < 0) {
4158 return tensorflow::errors::InvalidArgument("PyLong_AsSize_t failure");
4159 }
4160 int local_id = context.resource_id_map.AddResourceId(resource_id);
4161 absl::StrAppend(&context.result.str, local_id, kLocalIdDelim);
4162 } else {
4163 // If '_iterator_resource' is not a Tensor, there is no resource id.
4164 // Instead we treat it the same way as a CompositeTensor
4165 absl::StrAppend(&context.result.str, kCompositeTensor);
4166 }
4167 return tensorflow::Status::OK();
4168}
4169
4170tensorflow::Status EncodeResource(PyObject* arg, EncodingContext& context) {
4171 absl::StrAppend(&context.result.str, kResourceVariable);
4172 tensorflow::Safe_PyObjectPtr py_resource_id(
4173 PyObject_CallMethod(arg, "__tf_resource_id__", nullptr));
4174 DCHECK(py_resource_id != nullptr)while (false && (py_resource_id != nullptr)) ::tensorflow
::internal::LogMessageFatal("tensorflow/python/eager/pywrap_tfe_src.cc"
, 4174)
;
4175
4176 int resource_id = PyLong_AsSize_t(py_resource_id.get());
4177 DCHECK_GE(resource_id, 0)while (false && ((void)(resource_id), (void)(0), 0)) ::
tensorflow::internal::LogMessageFatal("tensorflow/python/eager/pywrap_tfe_src.cc"
, 4177)
;
4178 int local_id = context.resource_id_map.AddResourceId(resource_id);
4179 absl::StrAppend(&context.result.str, local_id, kLocalIdDelim);
4180
4181 tensorflow::Safe_PyObjectPtr type_spec(
4182 PyObject_CallMethod(arg, "__tf_function_cache_spec__", nullptr));
4183 absl::StrAppend(&context.result.str, PyUnicode_AsUTF8(type_spec.get()));
4184 DCHECK(type_spec != nullptr)while (false && (type_spec != nullptr)) ::tensorflow::
internal::LogMessageFatal("tensorflow/python/eager/pywrap_tfe_src.cc"
, 4184)
;
4185
4186 return tensorflow::Status::OK();
4187}
4188
4189tensorflow::Status EncodeArgHelperInternal(PyObject* arg,
4190 EncodingContext& context);
4191
4192// This function doesn't set the type of sequence before
4193tensorflow::Status EncodeSequence(PyObject* arg, const char* type,
4194 const char* end_type,
4195 EncodingContext& context) {
4196 tensorflow::Safe_PyObjectPtr arg_seq(
4197 PySequence_Fast(arg, "unable to create seq from list/tuple"));
4198
4199 absl::StrAppend(&context.result.str, type);
4200 int len = PySequence_Fast_GET_SIZE(arg_seq.get())(((((((PyObject*)(arg_seq.get()))->ob_type))->tp_flags &
((1UL << 25))) != 0) ? ((static_cast<void> (0)),
(((PyVarObject*)(arg_seq.get()))->ob_size)) : (((PyVarObject
*)(((static_cast<void> (0)), (PyTupleObject *)(arg_seq.
get()))))->ob_size))
;
4201 PyObject** arg_seq_array = PySequence_Fast_ITEMS(arg_seq.get())(((((((PyObject*)(arg_seq.get()))->ob_type))->tp_flags &
((1UL << 25))) != 0) ? ((PyListObject *)(arg_seq.get()
))->ob_item : ((PyTupleObject *)(arg_seq.get()))->ob_item
)
;
4202 for (int i = 0; i < len; ++i) {
4203 PyObject* item = arg_seq_array[i];
4204 if (item == Py_None(&_Py_NoneStruct)) {
4205 absl::StrAppend(&context.result.str, kNone);
4206 } else {
4207 TF_RETURN_IF_ERROR(EncodeArgHelperInternal(item, context))do { ::tensorflow::Status _status = (EncodeArgHelperInternal(
item, context)); if ((__builtin_expect(!_status.ok(), 0))) return
_status; } while (0)
;
4208 }
4209 }
4210 absl::StrAppend(&context.result.str, end_type);
4211
4212 return tensorflow::Status::OK();
4213}
4214
4215tensorflow::Status EncodeMapping(PyObject* arg, EncodingContext& context) {
4216 tensorflow::Safe_PyObjectPtr keys(tensorflow::swig::MappingKeys(arg));
4217 if (PyList_Sort(keys.get()) == -1) {
4218 return tensorflow::errors::Internal("Unable to sort keys");
4219 }
4220
4221 absl::StrAppend(&context.result.str, kDict);
4222 int len = PyList_Size(keys.get());
4223
4224 for (int i = 0; i < len; i++) {
4225 PyObject* key = PyList_GetItem(keys.get(), i);
4226 TF_RETURN_IF_ERROR(EncodeArgHelperInternal(key, context))do { ::tensorflow::Status _status = (EncodeArgHelperInternal(
key, context)); if ((__builtin_expect(!_status.ok(), 0))) return
_status; } while (0)
;
4227 tensorflow::Safe_PyObjectPtr value(PyObject_GetItem(arg, key));
4228 TF_RETURN_IF_ERROR(EncodeArgHelperInternal(value.get(), context))do { ::tensorflow::Status _status = (EncodeArgHelperInternal(
value.get(), context)); if ((__builtin_expect(!_status.ok(), 0
))) return _status; } while (0)
;
4229 }
4230
4231 return tensorflow::Status::OK();
4232}
4233
4234tensorflow::Status EncodeCompositeTensor(PyObject* arg,
4235 EncodingContext& context) {
4236 absl::StrAppend(&context.result.str, kCompositeTensor);
4237 PyObject* type_spec(PyObject_GetAttrString(arg, "_type_spec"));
4238 if (type_spec == nullptr) {
4239 return tensorflow::errors::InvalidArgument(
4240 "Error while reading CompositeTensor._type_spec.");
4241 }
4242 context.result.objects.push_back(type_spec);
4243
4244 return tensorflow::Status::OK();
4245}
4246
4247tensorflow::Status EncodeTypeSpec(PyObject* arg, EncodingContext& context) {
4248 absl::StrAppend(&context.result.str, kRaw);
4249 Py_INCREF(arg)_Py_INCREF(((PyObject*)(arg)));
4250 context.result.objects.push_back(arg);
4251 return tensorflow::Status::OK();
4252}
4253
4254tensorflow::Status EncodeAttrs(PyObject* arg, EncodingContext& context) {
4255 absl::StrAppend(&context.result.str, kAttrs);
4256 tensorflow::Safe_PyObjectPtr attrs(
4257 PyObject_GetAttrString(arg, "__attrs_attrs__"));
4258 tensorflow::Safe_PyObjectPtr iter(PyObject_GetIter(attrs.get()));
4259 for (tensorflow::Safe_PyObjectPtr item(PyIter_Next(iter.get())); item;
4260 item.reset(PyIter_Next(iter.get()))) {
4261 tensorflow::Safe_PyObjectPtr name(
4262 PyObject_GetAttrString(item.get(), "name"));
4263 tensorflow::Safe_PyObjectPtr attr_arg(PyObject_GetAttr(arg, name.get()));
4264 TF_RETURN_IF_ERROR(EncodeArgHelperInternal(attr_arg.get(), context))do { ::tensorflow::Status _status = (EncodeArgHelperInternal(
attr_arg.get(), context)); if ((__builtin_expect(!_status.ok(
), 0))) return _status; } while (0)
;
4265 }
4266 absl::StrAppend(&context.result.str, kAttrsEnd);
4267
4268 return tensorflow::Status::OK();
4269}
4270
4271tensorflow::Status EncodeUnidentified(PyObject* arg, EncodingContext& context) {
4272 // We hold a weak reference because cache keys live practically forever, and
4273 // this may leak heavy objects.
4274 PyObject* object = PyWeakref_NewRef(arg, nullptr);
4275 if (object == nullptr) {
4276 PyErr_Clear();
4277 object = arg;
4278 Py_INCREF(object)_Py_INCREF(((PyObject*)(object)));
4279 }
4280
4281 absl::StrAppend(&context.result.str, kRaw);
4282 context.result.objects.push_back(object);
4283 return tensorflow::Status::OK();
4284}
4285
4286tensorflow::Status EncodeArgHelperInternal(PyObject* arg,
4287 EncodingContext& context) {
4288 if (tensorflow::swig::IsTensorSpec(arg)) {
4289 TF_RETURN_IF_ERROR(EncodeTensorOrTensorSpec(arg, true, context))do { ::tensorflow::Status _status = (EncodeTensorOrTensorSpec
(arg, true, context)); if ((__builtin_expect(!_status.ok(), 0
))) return _status; } while (0)
;
4290 } else if (tensorflow::swig::IsTensor(arg)) {
4291 TF_RETURN_IF_ERROR(EncodeTensorOrTensorSpec(arg, false, context))do { ::tensorflow::Status _status = (EncodeTensorOrTensorSpec
(arg, false, context)); if ((__builtin_expect(!_status.ok(), 0
))) return _status; } while (0)
;
4292 } else if (tensorflow::swig::IsOwnedIterator(arg)) {
4293 TF_RETURN_IF_ERROR(EncodeOwnedIterator(arg, context))do { ::tensorflow::Status _status = (EncodeOwnedIterator(arg,
context)); if ((__builtin_expect(!_status.ok(), 0))) return _status
; } while (0)
;
4294 } else if (PyList_Check(arg)((((((PyObject*)(arg))->ob_type))->tp_flags & ((1UL
<< 25))) != 0)
) {
4295 TF_RETURN_IF_ERROR(EncodeSequence(arg, kList, kListEnd, context))do { ::tensorflow::Status _status = (EncodeSequence(arg, kList
, kListEnd, context)); if ((__builtin_expect(!_status.ok(), 0
))) return _status; } while (0)
;
4296 } else if (tensorflow::swig::IsTuple(arg)) {
4297 TF_RETURN_IF_ERROR(EncodeSequence(arg, kTuple, kTupleEnd, context))do { ::tensorflow::Status _status = (EncodeSequence(arg, kTuple
, kTupleEnd, context)); if ((__builtin_expect(!_status.ok(), 0
))) return _status; } while (0)
;
4298 } else if (tensorflow::swig::IsMapping(arg)) {
4299 TF_RETURN_IF_ERROR(EncodeMapping(arg, context))do { ::tensorflow::Status _status = (EncodeMapping(arg, context
)); if ((__builtin_expect(!_status.ok(), 0))) return _status;
} while (0)
;
4300 } else if (tensorflow::swig::IsCompositeTensor(arg)) {
4301 TF_RETURN_IF_ERROR(EncodeCompositeTensor(arg, context))do { ::tensorflow::Status _status = (EncodeCompositeTensor(arg
, context)); if ((__builtin_expect(!_status.ok(), 0))) return
_status; } while (0)
;
4302 } else if (tensorflow::swig::IsTypeSpec(arg)) {
4303 TF_RETURN_IF_ERROR(EncodeTypeSpec(arg, context))do { ::tensorflow::Status _status = (EncodeTypeSpec(arg, context
)); if ((__builtin_expect(!_status.ok(), 0))) return _status;
} while (0)
;
4304 } else if (tensorflow::swig::IsAttrs(arg)) {
4305 TF_RETURN_IF_ERROR(EncodeAttrs(arg, context))do { ::tensorflow::Status _status = (EncodeAttrs(arg, context
)); if ((__builtin_expect(!_status.ok(), 0))) return _status;
} while (0)
;
4306 } else if (tensorflow::swig::IsResourceVariable(arg) &&
4307 context.encode_variable_by_resource_id) {
4308 TF_RETURN_IF_ERROR(EncodeResource(arg, context))do { ::tensorflow::Status _status = (EncodeResource(arg, context
)); if ((__builtin_expect(!_status.ok(), 0))) return _status;
} while (0)
;
4309 } else {
4310 TF_RETURN_IF_ERROR(EncodeUnidentified(arg, context))do { ::tensorflow::Status _status = (EncodeUnidentified(arg, context
)); if ((__builtin_expect(!_status.ok(), 0))) return _status;
} while (0)
;
4311 }
4312
4313 return tensorflow::Status::OK();
4314}
4315
4316tensorflow::Status TFE_Py_EncodeArgHelper(PyObject* arg,
4317 EncodingContext& context) {
4318 auto status = EncodeArgHelperInternal(arg, context);
4319 return status;
4320}
4321
4322} // namespace
4323
4324// `defun` uses dtypes and shapes instead of `Tensors` as cache keys. Dtypes
4325// are used because TensorFlow graphs are not parametric w.r.t. dtypes. Shapes
4326// are used for both performance reasons, as much TensorFlow code specializes
4327// on known shapes to produce slimmer graphs, and correctness, as some
4328// high-level APIs require shapes to be fully-known.
4329//
4330// `include_tensor_ranks_only` allows caching on arguments excluding shape info,
4331// so that a slow path using relaxed shape can rely on a cache key that excludes
4332// shapes.
4333PyObject* TFE_Py_EncodeArg(PyObject* arg, bool include_tensor_ranks_only,
4334 bool encode_variable_by_resource_id) {
4335 EncodingContext context;
4336 context.include_tensor_ranks_only = include_tensor_ranks_only;
4337 context.encode_variable_by_resource_id = encode_variable_by_resource_id;
4338 const auto status = TFE_Py_EncodeArgHelper(arg, context);
4339 if (MaybeRaiseExceptionFromStatus(status, nullptr)) {
1
Calling 'MaybeRaiseExceptionFromStatus'
4340 return nullptr;
4341 }
4342
4343 return context.result.ToPyTuple();
4344}
4345
4346// A method prints incoming messages directly to Python's
4347// stdout using Python's C API. This is necessary in Jupyter notebooks
4348// and colabs where messages to the C stdout don't go to the notebook
4349// cell outputs, but calls to Python's stdout do.
4350void PrintToPythonStdout(const char* msg) {
4351 if (Py_IsInitialized()) {
4352 PyGILState_STATE py_threadstate;
4353 py_threadstate = PyGILState_Ensure();
4354
4355 string string_msg = msg;
4356 // PySys_WriteStdout truncates strings over 1000 bytes, so
4357 // we write the message in chunks small enough to not be truncated.
4358 int CHUNK_SIZE = 900;
4359 auto len = string_msg.length();
4360 for (int i = 0; i < len; i += CHUNK_SIZE) {
4361 PySys_WriteStdout("%s", string_msg.substr(i, CHUNK_SIZE).c_str());
4362 }
4363
4364 // Force flushing to make sure print newlines aren't interleaved in
4365 // some colab environments
4366 PyRun_SimpleString("import sys; sys.stdout.flush()")PyRun_SimpleStringFlags("import sys; sys.stdout.flush()", __null
)
;
4367
4368 PyGILState_Release(py_threadstate);
4369 }
4370}
4371
4372// Register PrintToPythonStdout as a log listener, to allow
4373// printing in colabs and jupyter notebooks to work.
4374void TFE_Py_EnableInteractivePythonLogging() {
4375 static bool enabled_interactive_logging = false;
4376 if (!enabled_interactive_logging) {
4377 enabled_interactive_logging = true;
4378 TF_RegisterLogListener(PrintToPythonStdout);
4379 }
4380}
4381
4382namespace {
4383// weak reference to Python Context object currently active
4384PyObject* weak_eager_context = nullptr;
4385} // namespace
4386
4387PyObject* TFE_Py_SetEagerContext(PyObject* py_context) {
4388 Py_XDECREF(weak_eager_context)_Py_XDECREF(((PyObject*)(weak_eager_context)));
4389 weak_eager_context = PyWeakref_NewRef(py_context, nullptr);
4390 if (weak_eager_context == nullptr) {
4391 return nullptr;
4392 }
4393 Py_RETURN_NONEreturn _Py_INCREF(((PyObject*)((&_Py_NoneStruct)))), (&
_Py_NoneStruct)
;
4394}
4395
4396PyObject* GetPyEagerContext() {
4397 if (weak_eager_context == nullptr) {
4398 PyErr_SetString(PyExc_RuntimeError, "Python eager context is not set");
4399 return nullptr;
4400 }
4401 PyObject* py_context = PyWeakref_GET_OBJECT(weak_eager_context)((((PyObject*)(((PyWeakReference *)(weak_eager_context))->
wr_object))->ob_refcnt) > 0 ? ((PyWeakReference *)(weak_eager_context
))->wr_object : (&_Py_NoneStruct))
;
4402 if (py_context == Py_None(&_Py_NoneStruct)) {
4403 PyErr_SetString(PyExc_RuntimeError, "Eager context has been destroyed");
4404 return nullptr;
4405 }
4406 Py_INCREF(py_context)_Py_INCREF(((PyObject*)(py_context)));
4407 return py_context;
4408}
4409
4410namespace {
4411
4412// Default values for thread_local_data fields.
4413struct EagerContextThreadLocalDataDefaults {
4414 tensorflow::Safe_PyObjectPtr is_eager;
4415 tensorflow::Safe_PyObjectPtr device_spec;
4416};
4417
4418// Maps each py_eager_context object to its thread_local_data.
4419//
4420// Note: we need to use the python Context object as the key here (and not
4421// its handle object), because the handle object isn't created until the
4422// context is initialized; but thread_local_data is potentially accessed
4423// before then.
4424using EagerContextThreadLocalDataMap = absl::flat_hash_map<
4425 PyObject*, std::unique_ptr<tensorflow::EagerContextThreadLocalData>>;
4426thread_local EagerContextThreadLocalDataMap*
4427 eager_context_thread_local_data_map = nullptr;
4428
4429// Maps each py_eager_context object to default values.
4430using EagerContextThreadLocalDataDefaultsMap =
4431 absl::flat_hash_map<PyObject*, EagerContextThreadLocalDataDefaults>;
4432EagerContextThreadLocalDataDefaultsMap*
4433 eager_context_thread_local_data_defaults = nullptr;
4434
4435} // namespace
4436
4437namespace tensorflow {
4438
4439void MakeEagerContextThreadLocalData(PyObject* py_eager_context,
4440 PyObject* is_eager,
4441 PyObject* device_spec) {
4442 DCheckPyGilState();
4443 if (eager_context_thread_local_data_defaults == nullptr) {
4444 absl::LeakCheckDisabler disabler;
4445 eager_context_thread_local_data_defaults =
4446 new EagerContextThreadLocalDataDefaultsMap();
4447 }
4448 if (eager_context_thread_local_data_defaults->count(py_eager_context) > 0) {
4449 PyErr_SetString(PyExc_AssertionError,
4450 "MakeEagerContextThreadLocalData may not be called "
4451 "twice on the same eager Context object.");
4452 }
4453
4454 auto& defaults =
4455 (*eager_context_thread_local_data_defaults)[py_eager_context];
4456 Py_INCREF(is_eager)_Py_INCREF(((PyObject*)(is_eager)));
4457 defaults.is_eager.reset(is_eager);
4458 Py_INCREF(device_spec)_Py_INCREF(((PyObject*)(device_spec)));
4459 defaults.device_spec.reset(device_spec);
4460}
4461
4462EagerContextThreadLocalData* GetEagerContextThreadLocalData(
4463 PyObject* py_eager_context) {
4464 if (eager_context_thread_local_data_defaults == nullptr) {
4465 PyErr_SetString(PyExc_AssertionError,
4466 "MakeEagerContextThreadLocalData must be called "
4467 "before GetEagerContextThreadLocalData.");
4468 return nullptr;
4469 }
4470 auto defaults =
4471 eager_context_thread_local_data_defaults->find(py_eager_context);
4472 if (defaults == eager_context_thread_local_data_defaults->end()) {
4473 PyErr_SetString(PyExc_AssertionError,
4474 "MakeEagerContextThreadLocalData must be called "
4475 "before GetEagerContextThreadLocalData.");
4476 return nullptr;
4477 }
4478
4479 if (eager_context_thread_local_data_map == nullptr) {
4480 absl::LeakCheckDisabler disabler;
4481 eager_context_thread_local_data_map = new EagerContextThreadLocalDataMap();
4482 }
4483 auto& thread_local_data =
4484 (*eager_context_thread_local_data_map)[py_eager_context];
4485
4486 if (!thread_local_data) {
4487 thread_local_data.reset(new EagerContextThreadLocalData());
4488
4489 Safe_PyObjectPtr is_eager(
4490 PyObject_CallFunctionObjArgs(defaults->second.is_eager.get(), nullptr));
4491 if (!is_eager) return nullptr;
4492 thread_local_data->is_eager = PyObject_IsTrue(is_eager.get());
4493
4494#if PY_MAJOR_VERSION3 >= 3
4495 PyObject* scope_name = PyUnicode_FromString("");
4496#else
4497 PyObject* scope_name = PyString_FromString("");
4498#endif
4499 thread_local_data->scope_name.reset(scope_name);
4500
4501#if PY_MAJOR_VERSION3 >= 3
4502 PyObject* device_name = PyUnicode_FromString("");
4503#else
4504 PyObject* device_name = PyString_FromString("");
4505#endif
4506 thread_local_data->device_name.reset(device_name);
4507
4508 Py_INCREF(defaults->second.device_spec.get())_Py_INCREF(((PyObject*)(defaults->second.device_spec.get()
)))
;
4509 thread_local_data->device_spec.reset(defaults->second.device_spec.get());
4510
4511 Py_INCREF(Py_None)_Py_INCREF(((PyObject*)((&_Py_NoneStruct))));
4512 thread_local_data->function_call_options.reset(Py_None(&_Py_NoneStruct));
4513
4514 Py_INCREF(Py_None)_Py_INCREF(((PyObject*)((&_Py_NoneStruct))));
4515 thread_local_data->executor.reset(Py_None(&_Py_NoneStruct));
4516
4517 thread_local_data->op_callbacks.reset(PyList_New(0));
4518 }
4519 return thread_local_data.get();
4520}
4521
4522void DestroyEagerContextThreadLocalData(PyObject* py_eager_context) {
4523 DCheckPyGilState();
4524 if (eager_context_thread_local_data_defaults) {
4525 eager_context_thread_local_data_defaults->erase(py_eager_context);
4526 }
4527 if (eager_context_thread_local_data_map) {
4528 eager_context_thread_local_data_map->erase(py_eager_context);
4529 }
4530}
4531
4532} // namespace tensorflow

/opt/pyrefcon/lib/pyrefcon/models/models/PyUnicode_FromString.model

/* Static-analyzer model of PyUnicode_FromString: each call is modeled as
 * returning a new reference (reference count 1) owned by the caller. The
 * model is only provided when PyUnicode_FromString is a real function; if it
 * is a macro, a build warning is emitted instead. */
#ifndef PyUnicode_FromString
struct _object;
typedef struct _object PyObject;
PyObject* clang_analyzer_PyObject_New_Reference();

PyObject* PyUnicode_FromString(const char* u) {
  PyObject* fresh_reference = clang_analyzer_PyObject_New_Reference();
  return fresh_reference;
}
#else
#warning "API PyUnicode_FromString is defined as a macro."
#endif