unofficial mirror of guix-patches@gnu.org 
 help / color / mirror / code / Atom feed
blob b30094de0996ff4d77c2a98808a3c4c638de0040 6634 bytes (raw)
name: gnu/packages/patches/python-pytorch-fix-codegen.patch 	 # note: path name is non-authoritative(*)

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
 
This patch fixes some scripts for generating source files.  For
gen_jit_decompositions.py, gen_mobile_upgraders.py and
gen_jit_shape_functions.py, which depend on the compiled PyTorch library, the
option to generate "dummy" source files is added for the initial build, which
is later corrected.  codegen_external.py is patched to avoid duplicate
functions and add the static keyword as in the existing generated file.

diff --git a/tools/gen_flatbuffers.sh b/tools/gen_flatbuffers.sh
index cc0263d..ac34e84 100644
--- a/tools/gen_flatbuffers.sh
+++ b/tools/gen_flatbuffers.sh
@@ -1,13 +1,13 @@
 #!/bin/bash
 ROOT=$(pwd)
-FF_LOCATION="$ROOT/third_party/flatbuffers"
-cd "$FF_LOCATION" || exit
-mkdir build
-cd build || exit
-cmake ..
-cmake --build . --target flatc
-mkdir -p "$ROOT/build/torch/csrc/jit/serialization"
-./flatc --cpp --gen-mutable --scoped-enums \
+#FF_LOCATION="$ROOT/third_party/flatbuffers"
+#cd "$FF_LOCATION" || exit
+#mkdir build
+#cd build || exit
+#cmake ..
+#cmake --build . --target flatc
+#mkdir -p "$ROOT/build/torch/csrc/jit/serialization"
+flatc --cpp --gen-mutable --scoped-enums \
      -o "$ROOT/torch/csrc/jit/serialization" \
      -c "$ROOT/torch/csrc/jit/serialization/mobile_bytecode.fbs"
 echo '// @generated' >> "$ROOT/torch/csrc/jit/serialization/mobile_bytecode_generated.h"
diff --git a/torch/csrc/jit/tensorexpr/codegen_external.py b/torch/csrc/jit/tensorexpr/codegen_external.py
index 5dcf1b2..0e20b0c 100644
--- a/torch/csrc/jit/tensorexpr/codegen_external.py
+++ b/torch/csrc/jit/tensorexpr/codegen_external.py
@@ -21,9 +21,14 @@ def gen_external(native_functions_path, tags_path, external_path):
     native_functions = parse_native_yaml(native_functions_path, tags_path)
     func_decls = []
     func_registrations = []
-    for func in native_functions:
+    done_names = set()
+    for func in native_functions[0]:
         schema = func.func
         name = schema.name.name.base
+        if name in done_names:
+            continue
+        else:
+            done_names.add(name)
         args = schema.arguments
         # Only supports extern calls for functions with out variants
         if not schema.is_out_fn():
@@ -63,7 +68,7 @@ def gen_external(native_functions_path, tags_path, external_path):
 
         # print(tensor_decls, name, arg_names)
         func_decl = f"""\
-void nnc_aten_{name}(
+static void nnc_aten_{name}(
     int64_t bufs_num,
     void** buf_data,
     int64_t* buf_ranks,
diff --git a/torchgen/decompositions/gen_jit_decompositions.py b/torchgen/decompositions/gen_jit_decompositions.py
index 7a0024f..6b2445f 100644
--- a/torchgen/decompositions/gen_jit_decompositions.py
+++ b/torchgen/decompositions/gen_jit_decompositions.py
@@ -1,8 +1,12 @@
 #!/usr/bin/env python3
 import os
 from pathlib import Path
+import sys
 
-from torch.jit._decompositions import decomposition_table
+if len(sys.argv) < 2 or sys.argv[1] != "dummy":
+    from torch.jit._decompositions import decomposition_table
+else:
+    decomposition_table = {}
 
 # from torchgen.code_template import CodeTemplate
 
@@ -85,7 +89,7 @@ def write_decomposition_util_file(path: str) -> None:
 
 
 def main() -> None:
-    pytorch_dir = Path(__file__).resolve().parents[3]
+    pytorch_dir = Path(__file__).resolve().parents[2]
     upgrader_path = pytorch_dir / "torch" / "csrc" / "jit" / "runtime"
     write_decomposition_util_file(str(upgrader_path))
 
diff --git a/torchgen/operator_versions/gen_mobile_upgraders.py b/torchgen/operator_versions/gen_mobile_upgraders.py
index 2907076..6866332 100644
--- a/torchgen/operator_versions/gen_mobile_upgraders.py
+++ b/torchgen/operator_versions/gen_mobile_upgraders.py
@@ -3,10 +3,12 @@ import os
 from enum import Enum
 from operator import itemgetter
 from pathlib import Path
+import sys
 from typing import Any, Dict, List
 
-import torch
-from torch.jit.generate_bytecode import generate_upgraders_bytecode
+if len(sys.argv) < 2 or sys.argv[1] != "dummy":
+    import torch
+    from torch.jit.generate_bytecode import generate_upgraders_bytecode
 
 from torchgen.code_template import CodeTemplate
 from torchgen.operator_versions.gen_mobile_upgraders_constant import (
@@ -263,7 +265,10 @@ def construct_register_size(register_size_from_yaml: int) -> str:
 def construct_version_maps(
     upgrader_bytecode_function_to_index_map: Dict[str, Any]
 ) -> str:
-    version_map = torch._C._get_operator_version_map()
+    if len(sys.argv) < 2 or sys.argv[1] != "dummy":
+        version_map = torch._C._get_operator_version_map()
+    else:
+        version_map = {}
     sorted_version_map_ = sorted(version_map.items(), key=itemgetter(0))  # type: ignore[no-any-return]
     sorted_version_map = dict(sorted_version_map_)
 
@@ -379,7 +384,10 @@ def sort_upgrader(upgrader_list: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
 
 
 def main() -> None:
-    upgrader_list = generate_upgraders_bytecode()
+    if len(sys.argv) < 2 or sys.argv[1] != "dummy":
+        upgrader_list = generate_upgraders_bytecode()
+    else:
+        upgrader_list = []
     sorted_upgrader_list = sort_upgrader(upgrader_list)
     for up in sorted_upgrader_list:
         print("after sort upgrader : ", next(iter(up)))
diff --git a/torchgen/shape_functions/gen_jit_shape_functions.py b/torchgen/shape_functions/gen_jit_shape_functions.py
index bdfd5c7..72b237a 100644
--- a/torchgen/shape_functions/gen_jit_shape_functions.py
+++ b/torchgen/shape_functions/gen_jit_shape_functions.py
@@ -18,16 +18,20 @@ you are in the root directory of the Pytorch git repo"""
 if not file_path.exists():
     raise Exception(err_msg)  # noqa: TRY002
 
-spec = importlib.util.spec_from_file_location(module_name, file_path)
-assert spec is not None
-module = importlib.util.module_from_spec(spec)
-sys.modules[module_name] = module
-assert spec.loader is not None
-assert module is not None
-spec.loader.exec_module(module)
-
-bounded_compute_graph_mapping = module.bounded_compute_graph_mapping
-shape_compute_graph_mapping = module.shape_compute_graph_mapping
+if len(sys.argv) < 2 or sys.argv[1] != "dummy":
+    spec = importlib.util.spec_from_file_location(module_name, file_path)
+    assert spec is not None
+    module = importlib.util.module_from_spec(spec)
+    sys.modules[module_name] = module
+    assert spec.loader is not None
+    assert module is not None
+    spec.loader.exec_module(module)
+
+    bounded_compute_graph_mapping = module.bounded_compute_graph_mapping
+    shape_compute_graph_mapping = module.shape_compute_graph_mapping
+else:
+    bounded_compute_graph_mapping = {}
+    shape_compute_graph_mapping = {}
 
 
 SHAPE_HEADER = r"""

debug log:

solving b30094de09 ...
found b30094de09 in https://git.savannah.gnu.org/cgit/guix.git

(*) Git path names are given by the tree(s) the blob belongs to.
    Blobs themselves have no identifier aside from the hash of its contents.^

Code repositories for project(s) associated with this public inbox

	https://git.savannah.gnu.org/cgit/guix.git

This is a public inbox, see mirroring instructions
for how to clone and mirror all data and code used for this inbox;
as well as URLs for read-only IMAP folder(s) and NNTP newsgroup(s).