Skip to content

Commit b54d281

Browse files
committed
Fix #4210.
Fix data_dir assignment and add tests for 2 new cases: 1. When `--temp-dir` is set and `--data-dir` is not 2. When `raw_model_modified` and model uses external data
1 parent 750bdb5 commit b54d281

File tree

2 files changed

+72
-19
lines changed

2 files changed

+72
-19
lines changed

python/torch_mlir/tools/import_onnx/__main__.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -71,10 +71,7 @@ def load_onnx_model(args: argparse.Namespace) -> onnx.ModelProto:
7171
# temp directory instead of a hard-coded path in order to avoid data races
7272
# by default.
7373
input_dir = os.path.dirname(os.path.abspath(args.input_file))
74-
temp_dir = (
75-
Path(input_dir if args.temp_dir is None else args.temp_dir)
76-
/ "onnx-importer-temp"
77-
)
74+
temp_dir = Path(args.temp_dir or input_dir) / "onnx-importer-temp"
7875
shutil.rmtree(temp_dir, ignore_errors=True)
7976
temp_dir.mkdir(exist_ok=True)
8077

@@ -121,10 +118,13 @@ def load_onnx_model(args: argparse.Namespace) -> onnx.ModelProto:
121118
# onnx.shape_inference.infer_shapes_path(temp_raw_file, temp_inferred_file)
122119
# inferred_model = onnx.load(temp_inferred_file)
123120

121+
data_dir = Path(args.data_dir or input_dir)
122+
124123
# Model is too big for in-memory inference: do file-based shape inference
125124
# to a temp file.
126125
# First need to save as model when it has been changed (e.g. version conversion).
127126
if raw_model_modified:
127+
data_dir = temp_dir
128128
temp_raw_file = temp_dir / "raw.onnx"
129129
onnx.save(raw_model, temp_raw_file, save_as_external_data=True)
130130
temp_inferred_file = temp_dir / "inferred.onnx"
@@ -146,7 +146,6 @@ def load_onnx_model(args: argparse.Namespace) -> onnx.ModelProto:
146146

147147
# Load the temp file and the external data.
148148
inferred_model = onnx.load(temp_inferred_file, load_external_data=False)
149-
data_dir = Path(input_dir if args.temp_dir is None else args.data_dir)
150149
onnx.load_external_data_for_model(inferred_model, str(data_dir))
151150

152151
# Remove the inferred shape file unless asked to keep it

test/python/onnx_importer/command_line_test.py

Lines changed: 68 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import shutil
1414
import sys
1515
import subprocess
16+
import tempfile
1617
import unittest
1718
import unittest.mock
1819

@@ -39,6 +40,8 @@
3940

4041
OUTPUT_PATH.mkdir(parents=True, exist_ok=True)
4142

43+
MOCK_MAXIMUM_PROTOBUF = 1 << 20
44+
4245

4346
def const_model() -> onnx.ModelProto:
4447
# Note: data_path must be relative to model_file
@@ -87,7 +90,26 @@ def linear_model() -> onnx.ModelProto:
8790
return onnx_model
8891

8992

90-
ALL_MODELS = [const_model, linear_model]
93+
def path_based_shape_inference_model() -> onnx.ModelProto:
94+
# Create a model with a serialized form that's large enough to require
95+
# path-based shape inference.
96+
dtype = numpy.float32
97+
byte_size = numpy.dtype(dtype).itemsize
98+
tensor_size = MOCK_MAXIMUM_PROTOBUF // byte_size + 1
99+
large_tensor = numpy.random.rand(tensor_size).astype(dtype)
100+
assert large_tensor.nbytes > MOCK_MAXIMUM_PROTOBUF
101+
node1 = make_node(
102+
"Constant",
103+
[],
104+
["large_const"],
105+
value=numpy_helper.from_array(large_tensor, name="large_const"),
106+
)
107+
X = make_tensor_value_info("large_const", TensorProto.FLOAT, [tensor_size])
108+
graph = make_graph([node1], "large_const_graph", [], [X])
109+
return make_model(graph)
110+
111+
112+
ALL_MODELS = [const_model, linear_model, path_based_shape_inference_model]
91113

92114

93115
class CommandLineTest(unittest.TestCase):
@@ -110,7 +132,12 @@ def run_model_intern(self, onnx_model: onnx.ModelProto, model_name: str):
110132
args = __main__.parse_arguments([str(model_file), "-o", str(mlir_file)])
111133
__main__.main(args)
112134

113-
def run_model_extern(self, onnx_model: onnx.ModelProto, model_name: str):
135+
def run_model_extern(
136+
self,
137+
onnx_model: onnx.ModelProto,
138+
model_name: str,
139+
extra_args: list[str] | None = None,
140+
):
114141
run_path = self.get_run_path(model_name)
115142
model_file = run_path / f"{model_name}-e.onnx"
116143
mlir_file = run_path / f"{model_name}-e.torch.mlir"
@@ -127,20 +154,41 @@ def run_model_extern(self, onnx_model: onnx.ModelProto, model_name: str):
127154
onnx.save(onnx_model, model_file)
128155
temp_dir = run_path / "temp"
129156
temp_dir.mkdir(exist_ok=True)
130-
args = __main__.parse_arguments(
131-
[
132-
str(model_file),
133-
"-o",
134-
str(mlir_file),
135-
"--keep-temps",
136-
"--temp-dir",
137-
str(temp_dir),
138-
"--data-dir",
139-
str(run_path),
140-
]
141-
)
157+
raw_args = [
158+
str(model_file),
159+
"-o",
160+
str(mlir_file),
161+
"--keep-temps",
162+
"--temp-dir",
163+
str(temp_dir),
164+
"--data-dir",
165+
str(run_path),
166+
]
167+
if extra_args:
168+
raw_args.extend(extra_args)
169+
args = __main__.parse_arguments(raw_args)
142170
__main__.main(args)
143171

172+
@unittest.mock.patch("onnx.checker.MAXIMUM_PROTOBUF", MOCK_MAXIMUM_PROTOBUF)
173+
def run_model_explicit_temp_implicit_data(
174+
self, onnx_model: onnx.ModelProto, model_name: str
175+
):
176+
run_path = self.get_run_path(model_name)
177+
model_file = run_path / f"{model_name}-explicit_temp_implicit_data.onnx"
178+
mlir_file = run_path / f"{model_name}-explicit_temp_implicit_data.torch.mlir"
179+
onnx.save(onnx_model, model_file)
180+
with tempfile.TemporaryDirectory(dir=run_path) as temp_dir:
181+
args = __main__.parse_arguments(
182+
[
183+
str(model_file),
184+
"-o",
185+
str(mlir_file),
186+
"--temp-dir",
187+
str(temp_dir),
188+
]
189+
)
190+
__main__.main(args)
191+
144192
def test_all(self):
145193
for model_func in ALL_MODELS:
146194
model_name = model_func.__name__
@@ -150,6 +198,12 @@ def test_all(self):
150198
self.run_model_intern(model, model_name)
151199
with self.subTest("External data"):
152200
self.run_model_extern(model, model_name)
201+
with self.subTest("External data, raw model modified"):
202+
self.run_model_extern(
203+
model, model_name, extra_args=["--clear-domain"]
204+
)
205+
with self.subTest("Explicit temp dir, implicit data dir"):
206+
self.run_model_explicit_temp_implicit_data(model, model_name)
153207

154208

155209
if __name__ == "__main__":

0 commit comments

Comments
 (0)