1313import shutil
1414import sys
1515import subprocess
16+ import tempfile
1617import unittest
1718import unittest .mock
1819
3940
4041OUTPUT_PATH .mkdir (parents = True , exist_ok = True )
4142
43+ MOCK_MAXIMUM_PROTOBUF = 1 << 20
44+
4245
4346def const_model () -> onnx .ModelProto :
4447 # Note: data_path must be relative to model_file
@@ -87,7 +90,26 @@ def linear_model() -> onnx.ModelProto:
8790 return onnx_model
8891
8992
90- ALL_MODELS = [const_model , linear_model ]
93+ def path_based_shape_inference_model () -> onnx .ModelProto :
94+ # Create a model with a serialized form that's large enough to require
95+ # path-based shape inference.
96+ dtype = numpy .float32
97+ byte_size = numpy .dtype (dtype ).itemsize
98+ tensor_size = MOCK_MAXIMUM_PROTOBUF // byte_size + 1
99+ large_tensor = numpy .random .rand (tensor_size ).astype (dtype )
100+ assert large_tensor .nbytes > MOCK_MAXIMUM_PROTOBUF
101+ node1 = make_node (
102+ "Constant" ,
103+ [],
104+ ["large_const" ],
105+ value = numpy_helper .from_array (large_tensor , name = "large_const" ),
106+ )
107+ X = make_tensor_value_info ("large_const" , TensorProto .FLOAT , [tensor_size ])
108+ graph = make_graph ([node1 ], "large_const_graph" , [], [X ])
109+ return make_model (graph )
110+
111+
112+ ALL_MODELS = [const_model , linear_model , path_based_shape_inference_model ]
91113
92114
93115class CommandLineTest (unittest .TestCase ):
@@ -110,7 +132,12 @@ def run_model_intern(self, onnx_model: onnx.ModelProto, model_name: str):
110132 args = __main__ .parse_arguments ([str (model_file ), "-o" , str (mlir_file )])
111133 __main__ .main (args )
112134
113- def run_model_extern (self , onnx_model : onnx .ModelProto , model_name : str ):
135+ def run_model_extern (
136+ self ,
137+ onnx_model : onnx .ModelProto ,
138+ model_name : str ,
139+ extra_args : list [str ] | None = None ,
140+ ):
114141 run_path = self .get_run_path (model_name )
115142 model_file = run_path / f"{ model_name } -e.onnx"
116143 mlir_file = run_path / f"{ model_name } -e.torch.mlir"
@@ -127,20 +154,41 @@ def run_model_extern(self, onnx_model: onnx.ModelProto, model_name: str):
127154 onnx .save (onnx_model , model_file )
128155 temp_dir = run_path / "temp"
129156 temp_dir .mkdir (exist_ok = True )
130- args = __main__ .parse_arguments (
131- [
132- str (model_file ),
133- "-o" ,
134- str (mlir_file ),
135- "--keep-temps" ,
136- "--temp-dir" ,
137- str (temp_dir ),
138- "--data-dir" ,
139- str (run_path ),
140- ]
141- )
157+ raw_args = [
158+ str (model_file ),
159+ "-o" ,
160+ str (mlir_file ),
161+ "--keep-temps" ,
162+ "--temp-dir" ,
163+ str (temp_dir ),
164+ "--data-dir" ,
165+ str (run_path ),
166+ ]
167+ if extra_args :
168+ raw_args .extend (extra_args )
169+ args = __main__ .parse_arguments (raw_args )
142170 __main__ .main (args )
143171
172+ @unittest .mock .patch ("onnx.checker.MAXIMUM_PROTOBUF" , MOCK_MAXIMUM_PROTOBUF )
173+ def run_model_explicit_temp_implicit_data (
174+ self , onnx_model : onnx .ModelProto , model_name : str
175+ ):
176+ run_path = self .get_run_path (model_name )
177+ model_file = run_path / f"{ model_name } -explicit_temp_implicit_data.onnx"
178+ mlir_file = run_path / f"{ model_name } -explicit_temp_implicit_data.torch.mlir"
179+ onnx .save (onnx_model , model_file )
180+ with tempfile .TemporaryDirectory (dir = run_path ) as temp_dir :
181+ args = __main__ .parse_arguments (
182+ [
183+ str (model_file ),
184+ "-o" ,
185+ str (mlir_file ),
186+ "--temp-dir" ,
187+ str (temp_dir ),
188+ ]
189+ )
190+ __main__ .main (args )
191+
144192 def test_all (self ):
145193 for model_func in ALL_MODELS :
146194 model_name = model_func .__name__
@@ -150,6 +198,12 @@ def test_all(self):
150198 self .run_model_intern (model , model_name )
151199 with self .subTest ("External data" ):
152200 self .run_model_extern (model , model_name )
201+ with self .subTest ("External data, raw model modified" ):
202+ self .run_model_extern (
203+ model , model_name , extra_args = ["--clear-domain" ]
204+ )
205+ with self .subTest ("Explicit temp dir, implicit data dir" ):
206+ self .run_model_explicit_temp_implicit_data (model , model_name )
153207
154208
155209if __name__ == "__main__" :
0 commit comments