From 9868b417676e4344c4cf0b9944a41e835379c6ee Mon Sep 17 00:00:00 2001 From: Yongzao <532741407@qq.com> Date: Sun, 21 Dec 2025 10:07:21 +0800 Subject: [PATCH 1/4] finish --- .../it/env/cluster/node/AINodeWrapper.java | 2 +- .../ainode/it/AINodeCallInferenceIT.java | 24 ++-------- .../iotdb/ainode/it/AINodeForecastIT.java | 16 ++----- .../iotdb/ainode/it/AINodeModelManageIT.java | 48 ++++++++++++------- .../iotdb/ainode/utils/AINodeTestUtils.java | 44 +++++++++++++++++ .../iotdb/ainode/core/model/model_info.py | 12 +++-- .../iotdb/ainode/core/model/model_storage.py | 17 ++++--- .../src/main/thrift/ainode.thrift | 4 +- 8 files changed, 99 insertions(+), 68 deletions(-) diff --git a/integration-test/src/main/java/org/apache/iotdb/it/env/cluster/node/AINodeWrapper.java b/integration-test/src/main/java/org/apache/iotdb/it/env/cluster/node/AINodeWrapper.java index 34fd7e85240cf..15c2e4761dda8 100644 --- a/integration-test/src/main/java/org/apache/iotdb/it/env/cluster/node/AINodeWrapper.java +++ b/integration-test/src/main/java/org/apache/iotdb/it/env/cluster/node/AINodeWrapper.java @@ -60,7 +60,7 @@ public class AINodeWrapper extends AbstractNodeWrapper { public static final String CONFIG_PATH = "conf"; public static final String SCRIPT_PATH = "sbin"; public static final String BUILT_IN_MODEL_PATH = "data/ainode/models/builtin"; - public static final String CACHE_BUILT_IN_MODEL_PATH = "/data/ainode/models/weights"; + public static final String CACHE_BUILT_IN_MODEL_PATH = "/data/ainode/models"; private void replaceAttribute(String[] keys, String[] values, String filePath) { Properties props = new Properties(); diff --git a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeCallInferenceIT.java b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeCallInferenceIT.java index 44e280eca169b..3131e398059e6 100644 --- a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeCallInferenceIT.java +++ b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeCallInferenceIT.java @@ -40,21 +40,12 @@ import static org.apache.iotdb.ainode.utils.AINodeTestUtils.BUILTIN_MODEL_MAP; import static org.apache.iotdb.ainode.utils.AINodeTestUtils.checkHeader; -import static org.apache.iotdb.db.it.utils.TestUtils.prepareData; +import static org.apache.iotdb.ainode.utils.AINodeTestUtils.prepareDataInTree; @RunWith(IoTDBTestRunner.class) @Category({AIClusterIT.class}) public class AINodeCallInferenceIT { - private static final String[] WRITE_SQL_IN_TREE = - new String[] { - "CREATE DATABASE root.AI", - "CREATE TIMESERIES root.AI.s0 WITH DATATYPE=FLOAT, ENCODING=RLE", - "CREATE TIMESERIES root.AI.s1 WITH DATATYPE=DOUBLE, ENCODING=RLE", - "CREATE TIMESERIES root.AI.s2 WITH DATATYPE=INT32, ENCODING=RLE", - "CREATE TIMESERIES root.AI.s3 WITH DATATYPE=INT64, ENCODING=RLE", - }; - private static final String CALL_INFERENCE_SQL_TEMPLATE = "CALL INFERENCE(%s, \"SELECT s%d FROM root.AI LIMIT %d\", generateTime=true, outputLength=%d)"; private static final int DEFAULT_INPUT_LENGTH = 256; @@ -64,16 +55,7 @@ public class AINodeCallInferenceIT { public static void setUp() throws Exception { // Init 1C1D1A cluster environment EnvFactory.getEnv().initClusterEnvironment(1, 1); - prepareData(WRITE_SQL_IN_TREE); - try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TREE_SQL_DIALECT); - Statement statement = connection.createStatement()) { - for (int i = 0; i < 2880; i++) { - statement.execute( - String.format( - "INSERT INTO root.AI(timestamp,s0,s1,s2,s3) VALUES(%d,%f,%f,%d,%d)", - i, (float) i, (double) i, i, i)); - } - } + prepareDataInTree(); } @AfterClass @@ -91,7 +73,7 @@ public void callInferenceTest() throws SQLException { } } - public void callInferenceTest(Statement statement, AINodeTestUtils.FakeModelInfo modelInfo) + public static void callInferenceTest(Statement statement, AINodeTestUtils.FakeModelInfo modelInfo) throws SQLException { // Invoke call inference for specified models, there should exist result. for (int i = 0; i < 4; i++) { diff --git a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeForecastIT.java b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeForecastIT.java index bb0de13ed4969..c2114ac949954 100644 --- a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeForecastIT.java +++ b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeForecastIT.java @@ -39,6 +39,7 @@ import static org.apache.iotdb.ainode.utils.AINodeTestUtils.BUILTIN_MODEL_MAP; import static org.apache.iotdb.ainode.utils.AINodeTestUtils.errorTest; +import static org.apache.iotdb.ainode.utils.AINodeTestUtils.prepareDataInTable; @RunWith(IoTDBTestRunner.class) @Category({AIClusterIT.class}) @@ -58,18 +59,7 @@ public class AINodeForecastIT { public static void setUp() throws Exception { // Init 1C1D1A cluster environment EnvFactory.getEnv().initClusterEnvironment(1, 1); - try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TABLE_SQL_DIALECT); - Statement statement = connection.createStatement()) { - statement.execute("CREATE DATABASE db"); - statement.execute( - "CREATE TABLE db.AI (s0 FLOAT FIELD, s1 DOUBLE FIELD, s2 INT32 FIELD, s3 INT64 FIELD)"); - for (int i = 0; i < 5760; i++) { - statement.execute( - String.format( - "INSERT INTO db.AI(time,s0,s1,s2,s3) VALUES(%d,%f,%f,%d,%d)", - i, (float) i, (double) i, i, i)); - } - } + prepareDataInTable(); } @AfterClass @@ -87,7 +77,7 @@ public void forecastTableFunctionTest() throws SQLException { } } - public void forecastTableFunctionTest( + public static void forecastTableFunctionTest( Statement statement, AINodeTestUtils.FakeModelInfo modelInfo) throws SQLException { // Invoke forecast table function for specified models, there should exist result. for (int i = 0; i < 4; i++) { diff --git a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeModelManageIT.java b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeModelManageIT.java index 3315617e7fda8..6b2cffd0636fd 100644 --- a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeModelManageIT.java +++ b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeModelManageIT.java @@ -39,8 +39,12 @@ import java.sql.Statement; import java.util.concurrent.TimeUnit; +import static org.apache.iotdb.ainode.it.AINodeCallInferenceIT.callInferenceTest; +import static org.apache.iotdb.ainode.it.AINodeForecastIT.forecastTableFunctionTest; import static org.apache.iotdb.ainode.utils.AINodeTestUtils.checkHeader; import static org.apache.iotdb.ainode.utils.AINodeTestUtils.errorTest; +import static org.apache.iotdb.ainode.utils.AINodeTestUtils.prepareDataInTable; +import static org.apache.iotdb.ainode.utils.AINodeTestUtils.prepareDataInTree; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertTrue; @@ -54,6 +58,8 @@ public class AINodeModelManageIT { public static void setUp() throws Exception { // Init 1C1D1A cluster environment EnvFactory.getEnv().initClusterEnvironment(1, 1); + prepareDataInTree(); + prepareDataInTable(); } @AfterClass @@ -61,47 +67,51 @@ public static void tearDown() throws Exception { EnvFactory.getEnv().cleanClusterEnvironment(); } - // @Test + @Test public void userDefinedModelManagementTestInTree() throws SQLException, InterruptedException { try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TREE_SQL_DIALECT); Statement statement = connection.createStatement()) { - userDefinedModelManagementTest(statement); + registerUserDefinedModel(statement); + callInferenceTest( + statement, new FakeModelInfo("user_chronos", "custom_t5", "user_defined", "active")); + dropUserDefinedModel(statement); } } - // @Test + @Test public void userDefinedModelManagementTestInTable() throws SQLException, InterruptedException { try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TABLE_SQL_DIALECT); Statement statement = connection.createStatement()) { - userDefinedModelManagementTest(statement); + registerUserDefinedModel(statement); + forecastTableFunctionTest( + statement, new FakeModelInfo("user_chronos", "custom_t5", "user_defined", "active")); + dropUserDefinedModel(statement); } } - private void userDefinedModelManagementTest(Statement statement) + private void registerUserDefinedModel(Statement statement) throws SQLException, InterruptedException { final String alterConfigSQL = "set configuration \"trusted_uri_pattern\"='.*'"; - final String registerSql = "create model operationTest using uri \"" + "\""; - final String showSql = "SHOW MODELS operationTest"; - final String dropSql = "DROP MODEL operationTest"; - + final String registerSql = "create model user_chronos using uri \"file:///data/chronos2\""; + final String showSql = "SHOW MODELS user_chronos"; statement.execute(alterConfigSQL); statement.execute(registerSql); boolean loading = true; - int count = 0; for (int retryCnt = 0; retryCnt < 100; retryCnt++) { try (ResultSet resultSet = statement.executeQuery(showSql)) { ResultSetMetaData resultSetMetaData = resultSet.getMetaData(); checkHeader(resultSetMetaData, "ModelId,ModelType,Category,State"); while (resultSet.next()) { String modelId = resultSet.getString(1); + String modelType = resultSet.getString(2); String category = resultSet.getString(3); String state = resultSet.getString(4); - assertEquals("operationTest", modelId); - assertEquals("USER-DEFINED", category); - if (state.equals("ACTIVE")) { + assertEquals("user_chronos", modelId); + assertEquals("user_defined", category); + assertEquals("custom_t5", modelType); + if (state.equals("active")) { loading = false; - count++; - } else if (state.equals("LOADING")) { + } else if (state.equals("loading")) { break; } else { fail("Unexpected status of model: " + state); @@ -114,12 +124,16 @@ private void userDefinedModelManagementTest(Statement statement) TimeUnit.SECONDS.sleep(1); } assertFalse(loading); - assertEquals(1, count); + } + + private void dropUserDefinedModel(Statement statement) throws SQLException { + final String showSql = "SHOW MODELS user_chronos"; + final String dropSql = "DROP MODEL user_chronos"; statement.execute(dropSql); try (ResultSet resultSet = statement.executeQuery(showSql)) { ResultSetMetaData resultSetMetaData = resultSet.getMetaData(); checkHeader(resultSetMetaData, "ModelId,ModelType,Category,State"); - count = 0; + int count = 0; while (resultSet.next()) { count++; } diff --git a/integration-test/src/test/java/org/apache/iotdb/ainode/utils/AINodeTestUtils.java b/integration-test/src/test/java/org/apache/iotdb/ainode/utils/AINodeTestUtils.java index 35fb51598b71b..d620efacc263d 100644 --- a/integration-test/src/test/java/org/apache/iotdb/ainode/utils/AINodeTestUtils.java +++ b/integration-test/src/test/java/org/apache/iotdb/ainode/utils/AINodeTestUtils.java @@ -19,11 +19,15 @@ package org.apache.iotdb.ainode.utils; +import org.apache.iotdb.it.env.EnvFactory; +import org.apache.iotdb.itbase.env.BaseEnv; + import com.google.common.collect.ImmutableSet; import org.junit.Assert; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import java.sql.Connection; import java.sql.ResultSet; import java.sql.ResultSetMetaData; import java.sql.SQLException; @@ -39,6 +43,7 @@ import java.util.stream.Collectors; import java.util.stream.Stream; +import static org.apache.iotdb.db.it.utils.TestUtils.prepareData; import static org.junit.Assert.assertEquals; import static org.junit.Assert.fail; @@ -206,6 +211,45 @@ public static void checkModelNotOnSpecifiedDevice( fail("Model " + modelId + " is still loaded on device " + device); } + private static final String[] WRITE_SQL_IN_TREE = + new String[] { + "CREATE DATABASE root.AI", + "CREATE TIMESERIES root.AI.s0 WITH DATATYPE=FLOAT, ENCODING=RLE", + "CREATE TIMESERIES root.AI.s1 WITH DATATYPE=DOUBLE, ENCODING=RLE", + "CREATE TIMESERIES root.AI.s2 WITH DATATYPE=INT32, ENCODING=RLE", + "CREATE TIMESERIES root.AI.s3 WITH DATATYPE=INT64, ENCODING=RLE", + }; + + /** Prepare root.AI(s0 FLOAT, s1 DOUBLE, s2 INT32, s3 INT64) with 5760 rows of data in tree. */ + public static void prepareDataInTree() throws SQLException { + prepareData(WRITE_SQL_IN_TREE); + try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TREE_SQL_DIALECT); + Statement statement = connection.createStatement()) { + for (int i = 0; i < 5760; i++) { + statement.execute( + String.format( + "INSERT INTO root.AI(timestamp,s0,s1,s2,s3) VALUES(%d,%f,%f,%d,%d)", + i, (float) i, (double) i, i, i)); + } + } + } + + /** Prepare db.AI(s0 FLOAT, s1 DOUBLE, s2 INT32, s3 INT64) with 5760 rows of data in table. */ + public static void prepareDataInTable() throws SQLException { + try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TABLE_SQL_DIALECT); + Statement statement = connection.createStatement()) { + statement.execute("CREATE DATABASE db"); + statement.execute( + "CREATE TABLE db.AI (s0 FLOAT FIELD, s1 DOUBLE FIELD, s2 INT32 FIELD, s3 INT64 FIELD)"); + for (int i = 0; i < 5760; i++) { + statement.execute( + String.format( + "INSERT INTO db.AI(time,s0,s1,s2,s3) VALUES(%d,%f,%f,%d,%d)", + i, (float) i, (double) i, i, i)); + } + } + } + public static class FakeModelInfo { private final String modelId; diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/model_info.py b/iotdb-core/ainode/iotdb/ainode/core/model/model_info.py index bcb4a5e2056eb..d0da371bfd5cf 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/model/model_info.py +++ b/iotdb-core/ainode/iotdb/ainode/core/model/model_info.py @@ -31,7 +31,7 @@ def __init__( pipeline_cls: str = "", repo_id: str = "", auto_map: Optional[Dict] = None, - _transformers_registered: bool = False, + transformers_registered: bool = False, ): self.model_id = model_id self.model_type = model_type @@ -40,7 +40,9 @@ def __init__( self.pipeline_cls = pipeline_cls self.repo_id = repo_id self.auto_map = auto_map # If exists, indicates it's a Transformers model - self._transformers_registered = _transformers_registered # Internal flag: whether registered to Transformers + self.transformers_registered = ( + transformers_registered # Internal flag: whether registered to Transformers + ) def __repr__(self): return ( @@ -116,7 +118,7 @@ def __repr__(self): "AutoConfig": "configuration_timer.TimerConfig", "AutoModelForCausalLM": "modeling_timer.TimerForPrediction", }, - _transformers_registered=True, + transformers_registered=True, ), "sundial": ModelInfo( model_id="sundial", @@ -129,7 +131,7 @@ def __repr__(self): "AutoConfig": "configuration_sundial.SundialConfig", "AutoModelForCausalLM": "modeling_sundial.SundialForPrediction", }, - _transformers_registered=True, + transformers_registered=True, ), "chronos2": ModelInfo( model_id="chronos2", @@ -139,7 +141,7 @@ def __repr__(self): pipeline_cls="pipeline_chronos2.Chronos2Pipeline", repo_id="amazon/chronos-2", auto_map={ - "AutoConfig": "config.Chronos2ForecastingConfig", + "AutoConfig": "config.Chronos2CoreConfig", "AutoModelForCausalLM": "model.Chronos2Model", }, ), diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/model_storage.py b/iotdb-core/ainode/iotdb/ainode/core/model/model_storage.py index ee09cfd75bb3b..910a0620fac5b 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/model/model_storage.py +++ b/iotdb-core/ainode/iotdb/ainode/core/model/model_storage.py @@ -236,7 +236,7 @@ def _process_user_defined_model_directory(self, model_dir: str, model_id: str): state=ModelStates.ACTIVE, pipeline_cls=pipeline_cls, auto_map=auto_map, - _transformers_registered=False, # Lazy registration + transformers_registered=False, # Lazy registration ) self._models[ModelCategory.USER_DEFINED.value][model_id] = model_info @@ -287,7 +287,7 @@ def register_model(self, model_id: str, uri: str): state=ModelStates.ACTIVE, pipeline_cls=pipeline_cls, auto_map=auto_map, - _transformers_registered=False, # Register later + transformers_registered=False, # Register later ) self._models[ModelCategory.USER_DEFINED.value][model_id] = model_info @@ -296,7 +296,7 @@ def register_model(self, model_id: str, uri: str): success = self._register_transformers_model(model_info) if success: with self._lock_pool.get_lock(model_id).write_lock(): - model_info._transformers_registered = True + model_info.transformers_registered = True else: with self._lock_pool.get_lock(model_id).write_lock(): model_info.state = ModelStates.INACTIVE @@ -352,7 +352,7 @@ def _register_other_model(self, model_info: ModelInfo): f"Registered other type model: {model_info.model_id} ({model_info.model_type})" ) - def ensure_transformers_registered(self, model_id: str) -> ModelInfo: + def ensure_transformers_registered(self, model_id: str) -> ModelInfo | None: """ Ensure Transformers model is registered (called for lazy registration) This method uses locks to ensure thread safety. All check logic is within lock protection. @@ -369,11 +369,10 @@ def ensure_transformers_registered(self, model_id: str) -> ModelInfo: break if not model_info: - logger.warning(f"Model {model_id} does not exist, cannot register") return None # If already registered, return directly - if model_info._transformers_registered: + if model_info.transformers_registered: return model_info # If no auto_map, not a Transformers model, mark as registered (avoid duplicate checks) @@ -381,14 +380,14 @@ def ensure_transformers_registered(self, model_id: str) -> ModelInfo: not model_info.auto_map or model_id in BUILTIN_HF_TRANSFORMERS_MODEL_MAP.keys() ): - model_info._transformers_registered = True + model_info.transformers_registered = True return model_info # Execute registration (under lock protection) try: success = self._register_transformers_model(model_info) if success: - model_info._transformers_registered = True + model_info.transformers_registered = True logger.info( f"Model {model_id} successfully registered to Transformers" ) @@ -401,7 +400,7 @@ def ensure_transformers_registered(self, model_id: str) -> ModelInfo: except Exception as e: # Ensure state consistency in exception cases model_info.state = ModelStates.INACTIVE - model_info._transformers_registered = False + model_info.transformers_registered = False logger.error( f"Exception occurred while registering model {model_id} to Transformers: {e}" ) diff --git a/iotdb-protocol/thrift-ainode/src/main/thrift/ainode.thrift b/iotdb-protocol/thrift-ainode/src/main/thrift/ainode.thrift index cda356a948e1f..ea32f01b6e2dc 100644 --- a/iotdb-protocol/thrift-ainode/src/main/thrift/ainode.thrift +++ b/iotdb-protocol/thrift-ainode/src/main/thrift/ainode.thrift @@ -40,8 +40,8 @@ struct TAIHeartbeatResp { } struct TRegisterModelReq { - 1: required string uri - 2: required string modelId + 1: required string modelId + 2: required string uri } struct TConfigs { From 1e8dd7a913220fc4eecc0420d87de337a954e410 Mon Sep 17 00:00:00 2001 From: Yongzao <532741407@qq.com> Date: Mon, 22 Dec 2025 11:04:43 +0800 Subject: [PATCH 2/4] update error handler --- .../iotdb/ainode/it/AINodeModelManageIT.java | 8 +++ .../ainode/core/manager/model_manager.py | 6 +- .../iotdb/ainode/core/model/model_storage.py | 67 +++++++++++-------- 3 files changed, 51 insertions(+), 30 deletions(-) diff --git a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeModelManageIT.java b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeModelManageIT.java index 6b2cffd0636fd..fdb75e8deee71 100644 --- a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeModelManageIT.java +++ b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeModelManageIT.java @@ -75,6 +75,10 @@ public void userDefinedModelManagementTestInTree() throws SQLException, Interrup callInferenceTest( statement, new FakeModelInfo("user_chronos", "custom_t5", "user_defined", "active")); dropUserDefinedModel(statement); + errorTest( + statement, + "create model origin_chronos using uri \"file:///data/chronos2_origin\"", + "1505: 't5' is already used by a Transformers config, pick another name."); } } @@ -86,6 +90,10 @@ public void userDefinedModelManagementTestInTable() throws SQLException, Interru forecastTableFunctionTest( statement, new FakeModelInfo("user_chronos", "custom_t5", "user_defined", "active")); dropUserDefinedModel(statement); + errorTest( + statement, + "create model origin_chronos using uri \"file:///data/chronos2_origin\"", + "1505: 't5' is already used by a Transformers config, pick another name."); } } diff --git a/iotdb-core/ainode/iotdb/ainode/core/manager/model_manager.py b/iotdb-core/ainode/iotdb/ainode/core/manager/model_manager.py index ef0846c3d786c..ff4226e734f18 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/manager/model_manager.py +++ b/iotdb-core/ainode/iotdb/ainode/core/manager/model_manager.py @@ -61,9 +61,13 @@ def register_model( return TRegisterModelResp( get_status(TSStatusCode.CREATE_MODEL_ERROR, str(e)) ) + except Exception as e: + # Catch-all for other exceptions (mainly from transformers implementation) + return TRegisterModelResp( + get_status(TSStatusCode.CREATE_MODEL_ERROR, str(e)) + ) def show_models(self, req: TShowModelsReq) -> TShowModelsResp: - self._refresh() return self._model_storage.show_models(req) def delete_model(self, req: TDeleteModelReq) -> TSStatus: diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/model_storage.py b/iotdb-core/ainode/iotdb/ainode/core/model/model_storage.py index 910a0620fac5b..2cfb07fb56a70 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/model/model_storage.py +++ b/iotdb-core/ainode/iotdb/ainode/core/model/model_storage.py @@ -227,17 +227,22 @@ def _process_user_defined_model_directory(self, model_dir: str, model_id: str): model_type = config.get("model_type", "") auto_map = config.get("auto_map", None) pipeline_cls = config.get("pipeline_cls", "") - + model_info = ModelInfo( + model_id=model_id, + model_type=model_type, + category=ModelCategory.USER_DEFINED, + state=ModelStates.ACTIVE, + pipeline_cls=pipeline_cls, + auto_map=auto_map, + transformers_registered=False, # Lazy registration + ) + with self._lock_pool.get_lock(model_id).write_lock(): + self._models[ModelCategory.USER_DEFINED.value][model_id] = model_info + if self.ensure_transformers_registered(model_id) is None: + model_info.state = ModelStates.INACTIVE + else: + model_info.transformers_registered = True with self._lock_pool.get_lock(model_id).write_lock(): - model_info = ModelInfo( - model_id=model_id, - model_type=model_type, - category=ModelCategory.USER_DEFINED, - state=ModelStates.ACTIVE, - pipeline_cls=pipeline_cls, - auto_map=auto_map, - transformers_registered=False, # Lazy registration - ) self._models[ModelCategory.USER_DEFINED.value][model_id] = model_info # ==================== Registration Methods ==================== @@ -254,6 +259,7 @@ def register_model(self, model_id: str, uri: str): Raises: ModelExistedException: If the model_id already exists. InvalidModelUriException: If the URI format is invalid. + Exception: For other errors during transformers model registration. """ if self.is_model_registered(model_id): @@ -291,25 +297,30 @@ def register_model(self, model_id: str, uri: str): ) self._models[ModelCategory.USER_DEFINED.value][model_id] = model_info - if auto_map: - # Transformers model: immediately register to Transformers autoloading mechanism - success = self._register_transformers_model(model_info) - if success: - with self._lock_pool.get_lock(model_id).write_lock(): - model_info.transformers_registered = True - else: - with self._lock_pool.get_lock(model_id).write_lock(): + if auto_map: + # Transformers model: immediately register to Transformers autoloading mechanism + try: + if self._register_transformers_model(model_info): + model_info.transformers_registered = True + except Exception as e: model_info.state = ModelStates.INACTIVE - logger.error(f"Failed to register Transformers model {model_id}") - else: - # Other type models: only log - self._register_other_model(model_info) + logger.error( + f"Failed to register Transformers model {model_id}, because {e}" + ) + raise e + else: + # Other type models: only log + self._register_other_model(model_info) logger.info(f"Successfully registered model {model_id} from URI: {uri}") - def _register_transformers_model(self, model_info: ModelInfo): + def _register_transformers_model(self, model_info: ModelInfo) -> bool: """ Register Transformers model to autoloading mechanism (internal method) + Returns: + True if registration is successful + Raises: + Exception: Transformers internal exception if registration fails """ auto_map = model_info.auto_map if not auto_map: @@ -344,7 +355,7 @@ def _register_transformers_model(self, model_info: ModelInfo): logger.warning( f"Failed to register Transformers model {model_info.model_id}: {e}. Model may still work via auto_map, but ensure module path is correct." ) - return False + raise e def _register_other_model(self, model_info: ModelInfo): """Register other type models (non-Transformers models)""" @@ -354,10 +365,9 @@ def _register_other_model(self, model_info: ModelInfo): def ensure_transformers_registered(self, model_id: str) -> ModelInfo | None: """ - Ensure Transformers model is registered (called for lazy registration) - This method uses locks to ensure thread safety. All check logic is within lock protection. + Ensure Transformers model is registered. Returns: - str: If None, registration failed, otherwise returns model path + ModelInfo | None: None if registration failed, otherwise returns the corresponding ModelInfo """ # Use lock to protect entire check-execute process with self._lock_pool.get_lock(model_id).write_lock(): @@ -385,8 +395,7 @@ def ensure_transformers_registered(self, model_id: str) -> ModelInfo | None: # Execute registration (under lock protection) try: - success = self._register_transformers_model(model_info) - if success: + if self._register_transformers_model(model_info): model_info.transformers_registered = True logger.info( f"Model {model_id} successfully registered to Transformers" From 6ad3da0e3cd2b540e0cb895204446650fb7acf85 Mon Sep 17 00:00:00 2001 From: Yongzao Date: Mon, 22 Dec 2025 12:14:09 +0800 Subject: [PATCH 3/4] Update integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeModelManageIT.java Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- .../java/org/apache/iotdb/ainode/it/AINodeModelManageIT.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeModelManageIT.java b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeModelManageIT.java index fdb75e8deee71..d213749e17068 100644 --- a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeModelManageIT.java +++ b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeModelManageIT.java @@ -115,8 +115,8 @@ private void registerUserDefinedModel(Statement statement) String category = resultSet.getString(3); String state = resultSet.getString(4); assertEquals("user_chronos", modelId); - assertEquals("user_defined", category); assertEquals("custom_t5", modelType); + assertEquals("user_defined", category); if (state.equals("active")) { loading = false; } else if (state.equals("loading")) { From fef2e10624c39d38e0f4cb32f69b611fd976c39f Mon Sep 17 00:00:00 2001 From: Yongzao <532741407@qq.com> Date: Tue, 23 Dec 2025 09:37:55 +0800 Subject: [PATCH 4/4] Fix CI bug --- .../java/org/apache/iotdb/ainode/it/AINodeModelManageIT.java | 2 ++ 1 file changed, 2 insertions(+) diff --git a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeModelManageIT.java b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeModelManageIT.java index d213749e17068..8ece0ba7523e0 100644 --- a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeModelManageIT.java +++ b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeModelManageIT.java @@ -79,6 +79,7 @@ public void userDefinedModelManagementTestInTree() throws SQLException, Interrup statement, "create model origin_chronos using uri \"file:///data/chronos2_origin\"", "1505: 't5' is already used by a Transformers config, pick another name."); + statement.execute("drop model origin_chronos"); } } @@ -94,6 +95,7 @@ public void userDefinedModelManagementTestInTable() throws SQLException, Interru statement, "create model origin_chronos using uri \"file:///data/chronos2_origin\"", "1505: 't5' is already used by a Transformers config, pick another name."); + statement.execute("drop model origin_chronos"); } }