diff --git a/demos/common/export_models/export_model.py b/demos/common/export_models/export_model.py index 341a770c77..af5f1869aa 100644 --- a/demos/common/export_models/export_model.py +++ b/demos/common/export_models/export_model.py @@ -448,7 +448,6 @@ def export_text_generation_model(model_repository_path, source_model, model_name raise ValueError("max_prompt_len is only supported for NPU target device") if task_parameters['max_prompt_len'] <= 0: raise ValueError("max_prompt_len should be a positive integer") - plugin_config['MAX_PROMPT_LEN'] = task_parameters['max_prompt_len'] if task_parameters['ov_cache_dir'] is not None: plugin_config['CACHE_DIR'] = task_parameters['ov_cache_dir'] @@ -459,6 +458,16 @@ def export_text_generation_model(model_repository_path, source_model, model_name if "HETERO" in task_parameters['target_device']: plugin_config['MODEL_DISTRIBUTION_POLICY'] = 'PIPELINE_PARALLEL' + if task_parameters['target_device'] == 'NPU': + max_prompt_len = task_parameters['max_prompt_len'] + npu_properties = {} + if max_prompt_len is not None: + npu_properties['MAX_PROMPT_LEN'] = max_prompt_len + if task_parameters['enable_prefix_caching']: + npu_properties['NPUW_LLM_ENABLE_PREFIX_CACHING'] = True + device_properties = { "NPU": npu_properties } + plugin_config['DEVICE_PROPERTIES'] = device_properties + plugin_config_str = json.dumps(plugin_config) task_parameters['plugin_config'] = plugin_config_str diff --git a/src/graph_export/graph_export.cpp b/src/graph_export/graph_export.cpp index 41d7f9fabb..dadbd57777 100644 --- a/src/graph_export/graph_export.cpp +++ b/src/graph_export/graph_export.cpp @@ -597,12 +597,43 @@ std::variant, Status> GraphExport::createPluginString rapidjson::Value value; value.SetBool(pluginConfig.useNpuPrefixCaching.value()); auto itr = d.FindMember("NPUW_LLM_ENABLE_PREFIX_CACHING"); - if (itr != d.MemberEnd()) { + bool foundInTopLevel = (itr != d.MemberEnd()); + + bool foundInDeviceProperties = false; + if (!foundInTopLevel) { + 
+ auto devicePropsItr = d.FindMember("DEVICE_PROPERTIES"); + if (devicePropsItr != d.MemberEnd() && devicePropsItr->value.IsObject()) { + auto npuItr = devicePropsItr->value.FindMember("NPU"); + if (npuItr != devicePropsItr->value.MemberEnd() && npuItr->value.IsObject()) { + auto npuPrefixCachingItr = npuItr->value.FindMember("NPUW_LLM_ENABLE_PREFIX_CACHING"); + foundInDeviceProperties = (npuPrefixCachingItr != npuItr->value.MemberEnd()); + } + } + } + + if (foundInTopLevel || foundInDeviceProperties) { return Status(StatusCode::PLUGIN_CONFIG_CONFLICTING_PARAMETERS, "Doubled NPUW_LLM_ENABLE_PREFIX_CACHING parameter in plugin config."); } - d.AddMember("NPUW_LLM_ENABLE_PREFIX_CACHING", value, d.GetAllocator()); + + // Add to nested structure DEVICE_PROPERTIES.NPU + auto devicePropsItr = d.FindMember("DEVICE_PROPERTIES"); + if (devicePropsItr == d.MemberEnd()) { + rapidjson::Value deviceProps(rapidjson::kObjectType); + d.AddMember("DEVICE_PROPERTIES", deviceProps, d.GetAllocator()); + devicePropsItr = d.FindMember("DEVICE_PROPERTIES"); + } + + auto npuItr = devicePropsItr->value.FindMember("NPU"); + if (npuItr == devicePropsItr->value.MemberEnd()) { + rapidjson::Value npuObj(rapidjson::kObjectType); + devicePropsItr->value.AddMember("NPU", npuObj, d.GetAllocator()); + npuItr = devicePropsItr->value.FindMember("NPU"); + } + + npuItr->value.AddMember("NPUW_LLM_ENABLE_PREFIX_CACHING", value, d.GetAllocator()); configNotEmpty = true; } + if (configNotEmpty) { // Serialize the document to a JSON string rapidjson::StringBuffer buffer;