From 9883649ab269179de40726cb9d4bdfbbe7d59e55 Mon Sep 17 00:00:00 2001 From: Alex Kasko Date: Sun, 22 Feb 2026 23:49:19 +0000 Subject: [PATCH] VARIANT columns and parameters support This PR adds support for `VARIANT` columns and query parameters. Underlying values of the `VARIANT` result columns are returned from `ResultSet#getObject(col)` calls. --- src/jni/duckdb_java.cpp | 19 ++ src/jni/refs.cpp | 2 + src/jni/refs.hpp | 1 + .../java/org/duckdb/DuckDBColumnType.java | 3 +- src/main/java/org/duckdb/DuckDBVector.java | 15 + src/test/java/org/duckdb/TestDuckDBJDBC.java | 12 +- src/test/java/org/duckdb/TestVariant.java | 271 ++++++++++++++++++ 7 files changed, 316 insertions(+), 7 deletions(-) create mode 100644 src/test/java/org/duckdb/TestVariant.java diff --git a/src/jni/duckdb_java.cpp b/src/jni/duckdb_java.cpp index 0a5cbcce4..8911b6ed5 100644 --- a/src/jni/duckdb_java.cpp +++ b/src/jni/duckdb_java.cpp @@ -7,6 +7,7 @@ extern "C" { #include "duckdb/common/arrow/result_arrow_wrapper.hpp" #include "duckdb/common/operator/cast_operators.hpp" #include "duckdb/common/shared_ptr.hpp" +#include "duckdb/function/scalar/variant_utils.hpp" #include "duckdb/function/table/arrow.hpp" #include "duckdb/main/appender.hpp" #include "duckdb/main/client_context.hpp" @@ -685,6 +686,24 @@ jobject ProcessVector(JNIEnv *env, Connection *conn_ref, Vector &vec, idx_t row_ } break; } + case LogicalTypeId::VARIANT: { + RecursiveUnifiedVectorFormat format; + Vector::RecursiveToUnifiedFormat(vec, 1, format); + UnifiedVariantVectorData vector_data(format); + varlen_data = env->NewObjectArray(row_count, J_Object, nullptr); + for (idx_t row_idx = 0; row_idx < row_count; row_idx++) { + auto variant_val = VariantUtils::ConvertVariantToValue(vector_data, row_idx, 0); + if (variant_val.IsNull()) { + continue; + } + Vector variant_vec(variant_val); + variant_vec.Flatten(1); + jobject variant_j_vec = ProcessVector(env, conn_ref, variant_vec, 1); + env->CallVoidMethod(variant_j_vec, J_DuckVector_retainConstlenData); + env->SetObjectArrayElement(varlen_data, row_idx, variant_j_vec); + } + break; + } default: { Vector string_vec(LogicalType::VARCHAR); VectorOperations::Cast(*conn_ref->context, vec, string_vec, row_count); diff --git a/src/jni/refs.cpp b/src/jni/refs.cpp index bc20f691f..f5a0a795e 100644 --- a/src/jni/refs.cpp +++ b/src/jni/refs.cpp @@ -57,6 +57,7 @@ jmethodID J_DuckResultSetMeta_init; jclass J_DuckVector; jmethodID J_DuckVector_init; +jmethodID J_DuckVector_retainConstlenData; jfieldID J_DuckVector_constlen; jfieldID J_DuckVector_varlen; @@ -270,6 +271,7 @@ void create_refs(JNIEnv *env) { J_String_getBytes = get_method_id(env, J_String, "getBytes", "(Ljava/nio/charset/Charset;)[B"); J_DuckVector_init = get_method_id(env, J_DuckVector, "", "(Ljava/lang/String;I[Z)V"); + J_DuckVector_retainConstlenData = get_method_id(env, J_DuckVector, "retainConstlenData", "()V"); J_DuckVector_constlen = get_field_id(env, J_DuckVector, "constlen_data", "Ljava/nio/ByteBuffer;"); J_DuckVector_varlen = get_field_id(env, J_DuckVector, "varlen_data", "[Ljava/lang/Object;"); diff --git a/src/jni/refs.hpp b/src/jni/refs.hpp index 006afa36d..cd7b20121 100644 --- a/src/jni/refs.hpp +++ b/src/jni/refs.hpp @@ -54,6 +54,7 @@ extern jmethodID J_DuckResultSetMeta_init; extern jclass J_DuckVector; extern jmethodID J_DuckVector_init; +extern jmethodID J_DuckVector_retainConstlenData; extern jfieldID J_DuckVector_constlen; extern jfieldID J_DuckVector_varlen; diff --git a/src/main/java/org/duckdb/DuckDBColumnType.java b/src/main/java/org/duckdb/DuckDBColumnType.java index d46df3fd7..bc6731e81 100644 --- a/src/main/java/org/duckdb/DuckDBColumnType.java +++ b/src/main/java/org/duckdb/DuckDBColumnType.java @@ -36,5 +36,6 @@ public enum DuckDBColumnType { MAP, ARRAY, UNKNOWN, - UNION; + UNION, + VARIANT; } diff --git a/src/main/java/org/duckdb/DuckDBVector.java b/src/main/java/org/duckdb/DuckDBVector.java index f16536b3f..c629eb32b 100644 --- a/src/main/java/org/duckdb/DuckDBVector.java +++ b/src/main/java/org/duckdb/DuckDBVector.java @@ -59,6 +59,14 @@ class DuckDBVector { this.nullmask = nullmask; } + private void retainConstlenData() { + if (null != constlen_data) { + byte[] constlenBytes = new byte[constlen_data.capacity()]; + constlen_data.get(constlenBytes); + this.constlen_data = ByteBuffer.wrap(constlenBytes); + } + } + Object getObject(int idx) throws SQLException { if (check_and_null(idx)) { return null; @@ -121,6 +129,8 @@ Object getObject(int idx) throws SQLException { return getStruct(idx); case UNION: return getUnion(idx); + case VARIANT: + return getVariant(idx); default: return getLazyString(idx); } @@ -721,4 +731,9 @@ Object getUnion(int idx) throws SQLException { return attributes[1 + tag]; } + + Object getVariant(int idx) throws SQLException { + DuckDBVector vec = (DuckDBVector) varlen_data[idx]; + return vec.getObject(0); + } } diff --git a/src/test/java/org/duckdb/TestDuckDBJDBC.java b/src/test/java/org/duckdb/TestDuckDBJDBC.java index f83a32536..4a62057f0 100644 --- a/src/test/java/org/duckdb/TestDuckDBJDBC.java +++ b/src/test/java/org/duckdb/TestDuckDBJDBC.java @@ -2243,12 +2243,12 @@ public static void main(String[] args) throws Exception { Class clazz = Class.forName("org.duckdb." + arg1); statusCode = runTests(new String[0], clazz); } else { - statusCode = - runTests(args, TestDuckDBJDBC.class, TestAppender.class, TestAppenderCollection.class, - TestAppenderCollection2D.class, TestAppenderComposite.class, TestSingleValueAppender.class, - TestBatch.class, TestBindings.class, TestClosure.class, TestExtensionTypes.class, - TestMetadata.class, TestNoLib.class, /* TestSpatial.class, */ TestParameterMetadata.class, - TestPrepare.class, TestResults.class, TestSessionInit.class, TestTimestamp.class); + statusCode = runTests(args, TestDuckDBJDBC.class, TestAppender.class, TestAppenderCollection.class, + TestAppenderCollection2D.class, TestAppenderComposite.class, + TestSingleValueAppender.class, TestBatch.class, TestBindings.class, TestClosure.class, + TestExtensionTypes.class, TestMetadata.class, TestNoLib.class, + /* TestSpatial.class, */ TestParameterMetadata.class, TestPrepare.class, + TestResults.class, TestSessionInit.class, TestTimestamp.class, TestVariant.class); } System.exit(statusCode); } diff --git a/src/test/java/org/duckdb/TestVariant.java b/src/test/java/org/duckdb/TestVariant.java new file mode 100644 index 000000000..052802eb1 --- /dev/null +++ b/src/test/java/org/duckdb/TestVariant.java @@ -0,0 +1,271 @@ +package org.duckdb; + +import static org.duckdb.TestDuckDBJDBC.JDBC_URL; +import static org.duckdb.test.Assertions.*; + +import java.math.BigDecimal; +import java.math.BigInteger; +import java.sql.*; +import java.util.LinkedHashMap; +import java.util.Map; + +public class TestVariant { + + public static void test_variant_varchar() throws Exception { + try (Connection conn = DriverManager.getConnection(JDBC_URL); Statement stmt = conn.createStatement(); + ResultSet rs = stmt.executeQuery("SELECT 'foo'::VARCHAR::VARIANT AS col1")) { + assertTrue(rs.next()); + assertEquals(rs.getMetaData().getColumnType(1), Types.OTHER); + assertEquals(rs.getObject(1).getClass(), String.class); + assertEquals(rs.getObject(1), "foo"); + assertFalse(rs.next()); + } + } + + public static void test_variant_bool() throws Exception { + try (Connection conn = DriverManager.getConnection(JDBC_URL); Statement stmt = conn.createStatement(); + ResultSet rs = stmt.executeQuery("SELECT TRUE::BOOL::VARIANT AS col1")) { + assertTrue(rs.next()); + assertEquals(rs.getObject(1).getClass(), Boolean.class); + assertEquals(rs.getObject(1), true); + assertFalse(rs.next()); + } + } + + public static void test_variant_integrals() throws Exception { + try (Connection conn = DriverManager.getConnection(JDBC_URL); Statement stmt = conn.createStatement(); + ResultSet rs = stmt.executeQuery("SELECT 41::TINYINT::VARIANT AS col1" + + " UNION ALL " + + "SELECT 42::SMALLINT::VARIANT AS col1" + + " UNION ALL " + + "SELECT 43::INTEGER::VARIANT AS col1" + + " UNION ALL " + + "SELECT 44::BIGINT::VARIANT AS col1" + + " UNION ALL " + + "SELECT 45::HUGEINT::VARIANT AS col1")) { + assertTrue(rs.next()); + assertEquals(rs.getObject(1).getClass(), Byte.class); + assertEquals(rs.getObject(1), (byte) 41); + assertTrue(rs.next()); + assertEquals(rs.getObject(1).getClass(), Short.class); + assertEquals(rs.getObject(1), (short) 42); + assertTrue(rs.next()); + assertEquals(rs.getObject(1).getClass(), Integer.class); + assertEquals(rs.getObject(1), 43); + assertTrue(rs.next()); + assertEquals(rs.getObject(1).getClass(), Long.class); + assertEquals(rs.getObject(1), (long) 44); + assertTrue(rs.next()); + assertEquals(rs.getObject(1).getClass(), BigInteger.class); + assertEquals(rs.getObject(1), BigInteger.valueOf(45)); + assertFalse(rs.next()); + } + } + + public static void test_variant_floats() throws Exception { + try (Connection conn = DriverManager.getConnection(JDBC_URL); Statement stmt = conn.createStatement(); + ResultSet rs = stmt.executeQuery("SELECT 41.1::FLOAT::VARIANT AS col1" + + " UNION ALL " + + "SELECT 42.2::DOUBLE::VARIANT AS col1")) { + assertTrue(rs.next()); + assertEquals(rs.getObject(1).getClass(), Float.class); + assertEquals(rs.getObject(1), (float) 41.1); + assertTrue(rs.next()); + assertEquals(rs.getObject(1).getClass(), Double.class); + assertEquals(rs.getObject(1), 42.2); + assertFalse(rs.next()); + } + } + + public static void test_variant_decimals() throws Exception { + try (Connection conn = DriverManager.getConnection(JDBC_URL); Statement stmt = conn.createStatement(); + ResultSet rs = stmt.executeQuery("SELECT 41.1::DECIMAL(8,1)::VARIANT AS col1" + + " UNION ALL " + + "SELECT 42.2::DECIMAL(38,1)::VARIANT AS col1")) { + assertTrue(rs.next()); + assertEquals(rs.getObject(1).getClass(), BigDecimal.class); + assertEquals(rs.getObject(1), BigDecimal.valueOf(41.1)); + // assertEquals(rs.getMetaData().getPrecision(1), 8); + // assertEquals(rs.getMetaData().getScale(1), 1); + assertTrue(rs.next()); + assertEquals(rs.getObject(1).getClass(), BigDecimal.class); + assertEquals(rs.getObject(1), BigDecimal.valueOf(42.2)); + // assertEquals(rs.getMetaData().getPrecision(1), 38); + // assertEquals(rs.getMetaData().getScale(1), 1); + assertFalse(rs.next()); + } + } + + public static void test_variant_null() throws Exception { + try (Connection conn = DriverManager.getConnection(JDBC_URL); Statement stmt = conn.createStatement(); + ResultSet rs = stmt.executeQuery("SELECT 'foo'::VARCHAR::VARIANT AS col1" + + " UNION ALL " + + "SELECT NULL::VARIANT AS col1" + + " UNION ALL " + + "SELECT 42::INTEGER::VARIANT AS col1")) { + assertTrue(rs.next()); + assertEquals(rs.getObject(1).getClass(), String.class); + assertEquals(rs.getObject(1), "foo"); + assertTrue(rs.next()); + assertEquals(rs.getObject(1), null); + assertTrue(rs.wasNull()); + assertTrue(rs.next()); + assertEquals(rs.getObject(1).getClass(), Integer.class); + assertEquals(rs.getObject(1), 42); + assertFalse(rs.next()); + } + } + + public static void test_variant_query_params() throws Exception { + try (Connection conn = DriverManager.getConnection(JDBC_URL); + PreparedStatement ps = conn.prepareStatement("SELECT ?::VARCHAR::VARIANT AS col1" + + " UNION ALL " + + "SELECT ?::INTEGER::VARIANT AS col1")) { + ps.setString(1, "foo"); + ps.setInt(2, 42); + try (ResultSet rs = ps.executeQuery()) { + assertTrue(rs.next()); + assertEquals(rs.getObject(1).getClass(), String.class); + assertEquals(rs.getObject(1), "foo"); + assertTrue(rs.next()); + assertEquals(rs.getObject(1).getClass(), Integer.class); + assertEquals(rs.getObject(1), 42); + assertFalse(rs.next()); + } + } + } + + public static void test_variant_columns() throws Exception { + try (Connection conn = DriverManager.getConnection(JDBC_URL); Statement stmt = conn.createStatement(); + ResultSet rs = stmt.executeQuery("SELECT 'foo'::VARCHAR::VARIANT AS col1, 42::INTEGER::VARIANT AS col2")) { + assertTrue(rs.next()); + assertEquals(rs.getObject(1).getClass(), String.class); + assertEquals(rs.getObject(1), "foo"); + assertEquals(rs.getObject(2).getClass(), Integer.class); + assertEquals(rs.getObject(2), 42); + assertFalse(rs.next()); + } + } + + public static void test_variant_array() throws Exception { + try (Connection conn = DriverManager.getConnection(JDBC_URL); Statement stmt = conn.createStatement(); + ResultSet rs = stmt.executeQuery("SELECT [41, 42, 43]::INTEGER[3]::VARIANT AS col1")) { + assertTrue(rs.next()); + assertEquals(rs.getObject(1).getClass(), DuckDBArray.class); + Array arrayWrapper = (Array) rs.getObject(1); + Object[] array = (Object[]) arrayWrapper.getArray(); + assertEquals(array.length, 3); + assertEquals(array[0].getClass(), Integer.class); + assertEquals(array[0], 41); + assertEquals(array[1].getClass(), Integer.class); + assertEquals(array[1], 42); + assertEquals(array[2].getClass(), Integer.class); + assertEquals(array[2], 43); + assertFalse(rs.next()); + } + } + + public static void test_variant_list() throws Exception { + try (Connection conn = DriverManager.getConnection(JDBC_URL); Statement stmt = conn.createStatement(); + ResultSet rs = stmt.executeQuery("SELECT [41, 42, 43]::INTEGER[]::VARIANT AS col1")) { + assertTrue(rs.next()); + assertEquals(rs.getObject(1).getClass(), DuckDBArray.class); + Array arrayWrapper = (Array) rs.getObject(1); + Object[] array = (Object[]) arrayWrapper.getArray(); + assertEquals(array.length, 3); + assertEquals(array[0].getClass(), Integer.class); + assertEquals(array[0], 41); + assertEquals(array[1].getClass(), Integer.class); + assertEquals(array[1], 42); + assertEquals(array[2].getClass(), Integer.class); + assertEquals(array[2], 43); + assertFalse(rs.next()); + } + } + + public static void test_variant_list_of_variants() throws Exception { + try (Connection conn = DriverManager.getConnection(JDBC_URL); Statement stmt = conn.createStatement(); + ResultSet rs = + stmt.executeQuery("SELECT [41::VARIANT, NULL::VARIANT, 'foo'::VARIANT]::VARIANT[]::VARIANT AS col1")) { + assertTrue(rs.next()); + assertEquals(rs.getObject(1).getClass(), DuckDBArray.class); + Array arrayWrapper = (Array) rs.getObject(1); + Object[] array = (Object[]) arrayWrapper.getArray(); + assertEquals(array.length, 3); + assertEquals(array[0].getClass(), Integer.class); + assertEquals(array[0], 41); + assertNull(array[1]); + assertEquals(array[2].getClass(), String.class); + assertEquals(array[2], "foo"); + assertFalse(rs.next()); + } + } + + public static void test_variant_map() throws Exception { + try (Connection conn = DriverManager.getConnection(JDBC_URL); Statement stmt = conn.createStatement(); + ResultSet rs = stmt.executeQuery("SELECT MAP {'foo': 41, 'bar': 42}::VARIANT AS col1")) { + assertTrue(rs.next()); + assertEquals(rs.getObject(1).getClass(), DuckDBArray.class); + Array arrayWrapper = (Array) rs.getObject(1); + Object[] array = (Object[]) arrayWrapper.getArray(); + assertEquals(array.length, 2); + { + DuckDBStruct struct = (DuckDBStruct) array[0]; + Map map = struct.getMap(); + assertEquals(map.size(), 2); + assertEquals(map.get("key"), "foo"); + assertEquals(map.get("value"), 41); + } + { + DuckDBStruct struct = (DuckDBStruct) array[1]; + Map map = struct.getMap(); + assertEquals(map.size(), 2); + assertEquals(map.get("key"), "bar"); + assertEquals(map.get("value"), 42); + } + assertFalse(rs.next()); + } + } + + public static void test_variant_struct() throws Exception { + try (Connection conn = DriverManager.getConnection(JDBC_URL); Statement stmt = conn.createStatement(); + ResultSet rs = stmt.executeQuery("SELECT {'foo': 41, 'bar': 42}::VARIANT AS col1")) { + assertTrue(rs.next()); + DuckDBStruct struct = (DuckDBStruct) rs.getObject(1); + Map map = struct.getMap(); + assertEquals(map.size(), 2); + assertEquals(map.get("foo"), 41); + assertEquals(map.get("bar"), 42); + assertFalse(rs.next()); + } + } + + public static void test_variant_struct_with_variant() throws Exception { + try (Connection conn = DriverManager.getConnection(JDBC_URL); Statement stmt = conn.createStatement()) { + + stmt.execute("CREATE TABLE tab1 (col1 INTEGER, col2 STRUCT(s1 INTEGER, s2 VARIANT))"); + stmt.execute("INSERT INTO tab1 VALUES(41, row(42, 43))"); + stmt.execute("INSERT INTO tab1 VALUES(44, row(45, 'foo'))"); + + try (ResultSet rs = stmt.executeQuery("SELECT col2 FROM tab1 ORDER BY col1")) { + assertTrue(rs.next()); + { + DuckDBStruct struct = (DuckDBStruct) rs.getObject(1); + Map map = struct.getMap(); + assertEquals(map.size(), 2); + assertEquals(map.get("s1"), 42); + assertEquals(map.get("s2"), 43); + } + assertTrue(rs.next()); + { + DuckDBStruct struct = (DuckDBStruct) rs.getObject(1); + Map map = struct.getMap(); + assertEquals(map.size(), 2); + assertEquals(map.get("s1"), 45); + assertEquals(map.get("s2"), "foo"); + } + assertFalse(rs.next()); + } + } + } +}