Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ public class AINodeWrapper extends AbstractNodeWrapper {
public static final String CONFIG_PATH = "conf";
public static final String SCRIPT_PATH = "sbin";
public static final String BUILT_IN_MODEL_PATH = "data/ainode/models/builtin";
public static final String CACHE_BUILT_IN_MODEL_PATH = "/data/ainode/models/weights";
public static final String CACHE_BUILT_IN_MODEL_PATH = "/data/ainode/models";

private void replaceAttribute(String[] keys, String[] values, String filePath) {
Properties props = new Properties();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,21 +40,12 @@

import static org.apache.iotdb.ainode.utils.AINodeTestUtils.BUILTIN_MODEL_MAP;
import static org.apache.iotdb.ainode.utils.AINodeTestUtils.checkHeader;
import static org.apache.iotdb.db.it.utils.TestUtils.prepareData;
import static org.apache.iotdb.ainode.utils.AINodeTestUtils.prepareDataInTree;

@RunWith(IoTDBTestRunner.class)
@Category({AIClusterIT.class})
public class AINodeCallInferenceIT {

private static final String[] WRITE_SQL_IN_TREE =
new String[] {
"CREATE DATABASE root.AI",
"CREATE TIMESERIES root.AI.s0 WITH DATATYPE=FLOAT, ENCODING=RLE",
"CREATE TIMESERIES root.AI.s1 WITH DATATYPE=DOUBLE, ENCODING=RLE",
"CREATE TIMESERIES root.AI.s2 WITH DATATYPE=INT32, ENCODING=RLE",
"CREATE TIMESERIES root.AI.s3 WITH DATATYPE=INT64, ENCODING=RLE",
};

private static final String CALL_INFERENCE_SQL_TEMPLATE =
"CALL INFERENCE(%s, \"SELECT s%d FROM root.AI LIMIT %d\", generateTime=true, outputLength=%d)";
private static final int DEFAULT_INPUT_LENGTH = 256;
Expand All @@ -64,16 +55,7 @@ public class AINodeCallInferenceIT {
public static void setUp() throws Exception {
// Init 1C1D1A cluster environment
EnvFactory.getEnv().initClusterEnvironment(1, 1);
prepareData(WRITE_SQL_IN_TREE);
try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TREE_SQL_DIALECT);
Statement statement = connection.createStatement()) {
for (int i = 0; i < 2880; i++) {
statement.execute(
String.format(
"INSERT INTO root.AI(timestamp,s0,s1,s2,s3) VALUES(%d,%f,%f,%d,%d)",
i, (float) i, (double) i, i, i));
}
}
prepareDataInTree();
}

@AfterClass
Expand All @@ -91,7 +73,7 @@ public void callInferenceTest() throws SQLException {
}
}

public void callInferenceTest(Statement statement, AINodeTestUtils.FakeModelInfo modelInfo)
public static void callInferenceTest(Statement statement, AINodeTestUtils.FakeModelInfo modelInfo)
throws SQLException {
// Invoke call inference for specified models, there should exist result.
for (int i = 0; i < 4; i++) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@

import static org.apache.iotdb.ainode.utils.AINodeTestUtils.BUILTIN_MODEL_MAP;
import static org.apache.iotdb.ainode.utils.AINodeTestUtils.errorTest;
import static org.apache.iotdb.ainode.utils.AINodeTestUtils.prepareDataInTable;

@RunWith(IoTDBTestRunner.class)
@Category({AIClusterIT.class})
Expand All @@ -58,18 +59,7 @@ public class AINodeForecastIT {
public static void setUp() throws Exception {
// Init 1C1D1A cluster environment
EnvFactory.getEnv().initClusterEnvironment(1, 1);
try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TABLE_SQL_DIALECT);
Statement statement = connection.createStatement()) {
statement.execute("CREATE DATABASE db");
statement.execute(
"CREATE TABLE db.AI (s0 FLOAT FIELD, s1 DOUBLE FIELD, s2 INT32 FIELD, s3 INT64 FIELD)");
for (int i = 0; i < 5760; i++) {
statement.execute(
String.format(
"INSERT INTO db.AI(time,s0,s1,s2,s3) VALUES(%d,%f,%f,%d,%d)",
i, (float) i, (double) i, i, i));
}
}
prepareDataInTable();
}

@AfterClass
Expand All @@ -87,7 +77,7 @@ public void forecastTableFunctionTest() throws SQLException {
}
}

public void forecastTableFunctionTest(
public static void forecastTableFunctionTest(
Statement statement, AINodeTestUtils.FakeModelInfo modelInfo) throws SQLException {
// Invoke forecast table function for specified models, there should exist result.
for (int i = 0; i < 4; i++) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,12 @@
import java.sql.Statement;
import java.util.concurrent.TimeUnit;

import static org.apache.iotdb.ainode.it.AINodeCallInferenceIT.callInferenceTest;
import static org.apache.iotdb.ainode.it.AINodeForecastIT.forecastTableFunctionTest;
import static org.apache.iotdb.ainode.utils.AINodeTestUtils.checkHeader;
import static org.apache.iotdb.ainode.utils.AINodeTestUtils.errorTest;
import static org.apache.iotdb.ainode.utils.AINodeTestUtils.prepareDataInTable;
import static org.apache.iotdb.ainode.utils.AINodeTestUtils.prepareDataInTree;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
Expand All @@ -54,54 +58,70 @@ public class AINodeModelManageIT {
public static void setUp() throws Exception {
// Init 1C1D1A cluster environment
EnvFactory.getEnv().initClusterEnvironment(1, 1);
prepareDataInTree();
prepareDataInTable();
}

@AfterClass
public static void tearDown() throws Exception {
EnvFactory.getEnv().cleanClusterEnvironment();
}

// @Test
@Test
public void userDefinedModelManagementTestInTree() throws SQLException, InterruptedException {
try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TREE_SQL_DIALECT);
Statement statement = connection.createStatement()) {
userDefinedModelManagementTest(statement);
registerUserDefinedModel(statement);
callInferenceTest(
statement, new FakeModelInfo("user_chronos", "custom_t5", "user_defined", "active"));
dropUserDefinedModel(statement);
errorTest(
statement,
"create model origin_chronos using uri \"file:///data/chronos2_origin\"",
"1505: 't5' is already used by a Transformers config, pick another name.");
statement.execute("drop model origin_chronos");
}
}

// @Test
@Test
public void userDefinedModelManagementTestInTable() throws SQLException, InterruptedException {
try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TABLE_SQL_DIALECT);
Statement statement = connection.createStatement()) {
userDefinedModelManagementTest(statement);
registerUserDefinedModel(statement);
forecastTableFunctionTest(
statement, new FakeModelInfo("user_chronos", "custom_t5", "user_defined", "active"));
dropUserDefinedModel(statement);
errorTest(
statement,
"create model origin_chronos using uri \"file:///data/chronos2_origin\"",
"1505: 't5' is already used by a Transformers config, pick another name.");
statement.execute("drop model origin_chronos");
}
}

private void userDefinedModelManagementTest(Statement statement)
private void registerUserDefinedModel(Statement statement)
throws SQLException, InterruptedException {
final String alterConfigSQL = "set configuration \"trusted_uri_pattern\"='.*'";
final String registerSql = "create model operationTest using uri \"" + "\"";
final String showSql = "SHOW MODELS operationTest";
final String dropSql = "DROP MODEL operationTest";

final String registerSql = "create model user_chronos using uri \"file:///data/chronos2\"";
final String showSql = "SHOW MODELS user_chronos";
statement.execute(alterConfigSQL);
statement.execute(registerSql);
boolean loading = true;
int count = 0;
for (int retryCnt = 0; retryCnt < 100; retryCnt++) {
try (ResultSet resultSet = statement.executeQuery(showSql)) {
ResultSetMetaData resultSetMetaData = resultSet.getMetaData();
checkHeader(resultSetMetaData, "ModelId,ModelType,Category,State");
while (resultSet.next()) {
String modelId = resultSet.getString(1);
String modelType = resultSet.getString(2);
String category = resultSet.getString(3);
String state = resultSet.getString(4);
assertEquals("operationTest", modelId);
assertEquals("USER-DEFINED", category);
if (state.equals("ACTIVE")) {
assertEquals("user_chronos", modelId);
assertEquals("custom_t5", modelType);
assertEquals("user_defined", category);
if (state.equals("active")) {
loading = false;
count++;
} else if (state.equals("LOADING")) {
} else if (state.equals("loading")) {
break;
} else {
fail("Unexpected status of model: " + state);
Expand All @@ -114,12 +134,16 @@ private void userDefinedModelManagementTest(Statement statement)
TimeUnit.SECONDS.sleep(1);
}
assertFalse(loading);
assertEquals(1, count);
}

private void dropUserDefinedModel(Statement statement) throws SQLException {
final String showSql = "SHOW MODELS user_chronos";
final String dropSql = "DROP MODEL user_chronos";
statement.execute(dropSql);
try (ResultSet resultSet = statement.executeQuery(showSql)) {
ResultSetMetaData resultSetMetaData = resultSet.getMetaData();
checkHeader(resultSetMetaData, "ModelId,ModelType,Category,State");
count = 0;
int count = 0;
while (resultSet.next()) {
count++;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,15 @@

package org.apache.iotdb.ainode.utils;

import org.apache.iotdb.it.env.EnvFactory;
import org.apache.iotdb.itbase.env.BaseEnv;

import com.google.common.collect.ImmutableSet;
import org.junit.Assert;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.sql.Connection;
import java.sql.ResultSet;
import java.sql.ResultSetMetaData;
import java.sql.SQLException;
Expand All @@ -39,6 +43,7 @@
import java.util.stream.Collectors;
import java.util.stream.Stream;

import static org.apache.iotdb.db.it.utils.TestUtils.prepareData;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.fail;

Expand Down Expand Up @@ -206,6 +211,45 @@ public static void checkModelNotOnSpecifiedDevice(
fail("Model " + modelId + " is still loaded on device " + device);
}

private static final String[] WRITE_SQL_IN_TREE =
new String[] {
"CREATE DATABASE root.AI",
"CREATE TIMESERIES root.AI.s0 WITH DATATYPE=FLOAT, ENCODING=RLE",
"CREATE TIMESERIES root.AI.s1 WITH DATATYPE=DOUBLE, ENCODING=RLE",
"CREATE TIMESERIES root.AI.s2 WITH DATATYPE=INT32, ENCODING=RLE",
"CREATE TIMESERIES root.AI.s3 WITH DATATYPE=INT64, ENCODING=RLE",
};

/** Prepare root.AI(s0 FLOAT, s1 DOUBLE, s2 INT32, s3 INT64) with 5760 rows of data in tree. */
public static void prepareDataInTree() throws SQLException {
prepareData(WRITE_SQL_IN_TREE);
try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TREE_SQL_DIALECT);
Statement statement = connection.createStatement()) {
for (int i = 0; i < 5760; i++) {
statement.execute(
String.format(
"INSERT INTO root.AI(timestamp,s0,s1,s2,s3) VALUES(%d,%f,%f,%d,%d)",
i, (float) i, (double) i, i, i));
}
}
}

/** Prepare db.AI(s0 FLOAT, s1 DOUBLE, s2 INT32, s3 INT64) with 5760 rows of data in table. */
public static void prepareDataInTable() throws SQLException {
try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TABLE_SQL_DIALECT);
Statement statement = connection.createStatement()) {
statement.execute("CREATE DATABASE db");
statement.execute(
"CREATE TABLE db.AI (s0 FLOAT FIELD, s1 DOUBLE FIELD, s2 INT32 FIELD, s3 INT64 FIELD)");
for (int i = 0; i < 5760; i++) {
statement.execute(
String.format(
"INSERT INTO db.AI(time,s0,s1,s2,s3) VALUES(%d,%f,%f,%d,%d)",
i, (float) i, (double) i, i, i));
}
}
}

public static class FakeModelInfo {

private final String modelId;
Expand Down
6 changes: 5 additions & 1 deletion iotdb-core/ainode/iotdb/ainode/core/manager/model_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,13 @@ def register_model(
return TRegisterModelResp(
get_status(TSStatusCode.CREATE_MODEL_ERROR, str(e))
)
except Exception as e:
# Catch-all for other exceptions (mainly from transformers implementation)
return TRegisterModelResp(
get_status(TSStatusCode.CREATE_MODEL_ERROR, str(e))
)

def show_models(self, req: TShowModelsReq) -> TShowModelsResp:
self._refresh()
return self._model_storage.show_models(req)

def delete_model(self, req: TDeleteModelReq) -> TSStatus:
Expand Down
12 changes: 7 additions & 5 deletions iotdb-core/ainode/iotdb/ainode/core/model/model_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def __init__(
pipeline_cls: str = "",
repo_id: str = "",
auto_map: Optional[Dict] = None,
_transformers_registered: bool = False,
transformers_registered: bool = False,
):
self.model_id = model_id
self.model_type = model_type
Expand All @@ -40,7 +40,9 @@ def __init__(
self.pipeline_cls = pipeline_cls
self.repo_id = repo_id
self.auto_map = auto_map # If exists, indicates it's a Transformers model
self._transformers_registered = _transformers_registered # Internal flag: whether registered to Transformers
self.transformers_registered = (
transformers_registered # Internal flag: whether registered to Transformers
)

def __repr__(self):
return (
Expand Down Expand Up @@ -116,7 +118,7 @@ def __repr__(self):
"AutoConfig": "configuration_timer.TimerConfig",
"AutoModelForCausalLM": "modeling_timer.TimerForPrediction",
},
_transformers_registered=True,
transformers_registered=True,
),
"sundial": ModelInfo(
model_id="sundial",
Expand All @@ -129,7 +131,7 @@ def __repr__(self):
"AutoConfig": "configuration_sundial.SundialConfig",
"AutoModelForCausalLM": "modeling_sundial.SundialForPrediction",
},
_transformers_registered=True,
transformers_registered=True,
),
"chronos2": ModelInfo(
model_id="chronos2",
Expand All @@ -139,7 +141,7 @@ def __repr__(self):
pipeline_cls="pipeline_chronos2.Chronos2Pipeline",
repo_id="amazon/chronos-2",
auto_map={
"AutoConfig": "config.Chronos2ForecastingConfig",
"AutoConfig": "config.Chronos2CoreConfig",
"AutoModelForCausalLM": "model.Chronos2Model",
},
),
Expand Down
Loading
Loading