Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions backends/vulkan/cmake/ShaderLibrary.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,13 @@ function(gen_vulkan_shader_lib_cpp shaders_path)
endif()
endif()

# Add nthreads argument for shader compilation
if(DEFINED EXECUTORCH_VULKAN_SHADER_COMPILE_NTHREADS)
list(APPEND GEN_SPV_ARGS "--nthreads"
"${EXECUTORCH_VULKAN_SHADER_COMPILE_NTHREADS}"
)
endif()

add_custom_command(
COMMENT "Generating Vulkan Compute Shaders"
OUTPUT ${VULKAN_SHADERGEN_OUT_PATH}/spv.cpp
Expand Down
30 changes: 25 additions & 5 deletions backends/vulkan/runtime/gen_vulkan_spv.py
Original file line number Diff line number Diff line change
Expand Up @@ -892,6 +892,7 @@ def generateSPV( # noqa: C901
output_dir: str,
cache_dir: Optional[str] = None,
force_rebuild: bool = False,
nthreads: int = -1,
) -> Dict[str, str]:
# The key of this dictionary is the full path to a generated source file. The
# value is a tuple that contains 3 entries:
Expand Down Expand Up @@ -1118,11 +1119,21 @@ def compile_spirv(shader_paths_pair) -> Tuple[str, str]:
gen_file_meta[gen_out_path] = (file_changed, include_list)

# Parallelize SPIR-V compilation to optimize build time
with ThreadPool(os.cpu_count()) as pool:
for spv_out_path, glsl_out_path in pool.map(
compile_spirv, self.output_file_map.items()
):
# Determine number of threads: -1 means use all CPU cores, 1 means sequential
num_processes = os.cpu_count() if nthreads == -1 else nthreads

if num_processes == 1:
# Sequential compilation (single-threaded)
for shader_pair in self.output_file_map.items():
spv_out_path, glsl_out_path = compile_spirv(shader_pair)
spv_to_glsl_map[spv_out_path] = glsl_out_path
else:
# Parallel compilation
with ThreadPool(num_processes) as pool:
for spv_out_path, glsl_out_path in pool.map(
compile_spirv, self.output_file_map.items()
):
spv_to_glsl_map[spv_out_path] = glsl_out_path

return spv_to_glsl_map

Expand Down Expand Up @@ -1443,6 +1454,12 @@ def main(argv: List[str]) -> int:
parser.add_argument(
"--env", metavar="KEY=VALUE", nargs="*", help="Set a number of key-value pairs"
)
parser.add_argument(
"--nthreads",
type=int,
default=-1,
help="Number of threads for shader compilation. -1 (default) uses all available CPU cores, 1 uses sequential compilation.",
)
options = parser.parse_args()

env = DEFAULT_ENV
Expand Down Expand Up @@ -1477,7 +1494,10 @@ def main(argv: List[str]) -> int:
replace_u16vecn=options.replace_u16vecn,
)
output_spv_files = shader_generator.generateSPV(
options.output_path, options.tmp_dir_path, options.force_rebuild
options.output_path,
options.tmp_dir_path,
options.force_rebuild,
options.nthreads,
)

genCppFiles(
Expand Down
3 changes: 3 additions & 0 deletions backends/vulkan/targets.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -93,12 +93,15 @@ def vulkan_spv_shader_lib(name, spv_filegroups, is_fbcode = False, no_volk = Fal
for target, subpath in spv_filegroups.items():
glsl_paths.append("$(location {})/{}".format(target, subpath))

nthreads = read_config("etvk", "shader_compile_nthreads", "-1")

genrule_cmd = (
"$(exe {}) ".format(gen_vulkan_spv_target) +
"--glsl-paths {} ".format(" ".join(glsl_paths)) +
"--output-path $OUT " +
"--glslc-path=$(exe {}) ".format(glslc_path) +
"--tmp-dir-path=shader_cache " +
"--nthreads {} ".format(nthreads) +
("-f " if read_config("etvk", "force_shader_rebuild", "0") == "1" else " ") +
select({
"DEFAULT": "",
Expand Down
Loading