diff --git a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeConcurrentForecastIT.java b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeConcurrentForecastIT.java index fe19f991e574..7b465d10051c 100644 --- a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeConcurrentForecastIT.java +++ b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeConcurrentForecastIT.java @@ -36,8 +36,9 @@ import java.sql.Connection; import java.sql.SQLException; import java.sql.Statement; +import java.util.Arrays; +import java.util.List; -import static org.apache.iotdb.ainode.utils.AINodeTestUtils.BUILTIN_LTSM_MAP; import static org.apache.iotdb.ainode.utils.AINodeTestUtils.checkModelNotOnSpecifiedDevice; import static org.apache.iotdb.ainode.utils.AINodeTestUtils.checkModelOnSpecifiedDevice; import static org.apache.iotdb.ainode.utils.AINodeTestUtils.concurrentInference; @@ -48,6 +49,11 @@ public class AINodeConcurrentForecastIT { private static final Logger LOGGER = LoggerFactory.getLogger(AINodeConcurrentForecastIT.class); + private static final List MODEL_LIST = + Arrays.asList( + new AINodeTestUtils.FakeModelInfo("sundial", "sundial", "builtin", "active"), + new AINodeTestUtils.FakeModelInfo("timer_xl", "timer", "builtin", "active")); + private static final String FORECAST_TABLE_FUNCTION_SQL_TEMPLATE = "SELECT * FROM FORECAST(model_id=>'%s', targets=>(SELECT time,s FROM root.AI) ORDER BY time, output_length=>%d)"; @@ -78,7 +84,7 @@ private static void prepareDataForTableModel() throws SQLException { @Test public void concurrentGPUForecastTest() throws SQLException, InterruptedException { - for (AINodeTestUtils.FakeModelInfo modelInfo : BUILTIN_LTSM_MAP.values()) { + for (AINodeTestUtils.FakeModelInfo modelInfo : MODEL_LIST) { concurrentGPUForecastTest(modelInfo); } } diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/model_info.py b/iotdb-core/ainode/iotdb/ainode/core/model/model_info.py index 697b36712759..5d86e7c588f5 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/model/model_info.py +++ b/iotdb-core/ainode/iotdb/ainode/core/model/model_info.py @@ -116,6 +116,7 @@ def __repr__(self): "AutoConfig": "configuration_timer.TimerConfig", "AutoModelForCausalLM": "modeling_timer.TimerForPrediction", }, + _transformers_registered=True, ), "sundial": ModelInfo( model_id="sundial", @@ -128,5 +129,6 @@ def __repr__(self): "AutoConfig": "configuration_sundial.SundialConfig", "AutoModelForCausalLM": "modeling_sundial.SundialForPrediction", }, + _transformers_registered=True, ), } diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/model_storage.py b/iotdb-core/ainode/iotdb/ainode/core/model/model_storage.py index a79371d2e791..ee09cfd75bb3 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/model/model_storage.py +++ b/iotdb-core/ainode/iotdb/ainode/core/model/model_storage.py @@ -196,9 +196,12 @@ def _callback_model_download_result(self, future, model_id: str): if os.path.exists(config_path): with open(config_path, "r", encoding="utf-8") as f: config = json.load(f) - if model_info.model_type == "": - model_info.model_type = config.get("model_type", "") - model_info.auto_map = config.get("auto_map", None) + model_info.model_type = config.get( + "model_type", model_info.model_type + ) + model_info.auto_map = config.get( + "auto_map", model_info.auto_map + ) logger.info( f"Model {model_id} downloaded successfully and is ready to use." )