Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions src/jni/duckdb_java.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ extern "C" {
#include "duckdb/common/arrow/result_arrow_wrapper.hpp"
#include "duckdb/common/operator/cast_operators.hpp"
#include "duckdb/common/shared_ptr.hpp"
#include "duckdb/function/scalar/variant_utils.hpp"
#include "duckdb/function/table/arrow.hpp"
#include "duckdb/main/appender.hpp"
#include "duckdb/main/client_context.hpp"
Expand Down Expand Up @@ -685,6 +686,24 @@ jobject ProcessVector(JNIEnv *env, Connection *conn_ref, Vector &vec, idx_t row_
}
break;
}
case LogicalTypeId::VARIANT: {
RecursiveUnifiedVectorFormat format;
Vector::RecursiveToUnifiedFormat(vec, 1, format);
UnifiedVariantVectorData vector_data(format);
varlen_data = env->NewObjectArray(row_count, J_Object, nullptr);
for (idx_t row_idx = 0; row_idx < row_count; row_idx++) {
auto variant_val = VariantUtils::ConvertVariantToValue(vector_data, row_idx, 0);
if (variant_val.IsNull()) {
continue;
}
Vector variant_vec(variant_val);
variant_vec.Flatten(1);
jobject variant_j_vec = ProcessVector(env, conn_ref, variant_vec, 1);
env->CallVoidMethod(variant_j_vec, J_DuckVector_retainConstlenData);
env->SetObjectArrayElement(varlen_data, row_idx, variant_j_vec);
}
break;
}
default: {
Vector string_vec(LogicalType::VARCHAR);
VectorOperations::Cast(*conn_ref->context, vec, string_vec, row_count);
Expand Down
2 changes: 2 additions & 0 deletions src/jni/refs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ jmethodID J_DuckResultSetMeta_init;

jclass J_DuckVector;
jmethodID J_DuckVector_init;
jmethodID J_DuckVector_retainConstlenData;
jfieldID J_DuckVector_constlen;
jfieldID J_DuckVector_varlen;

Expand Down Expand Up @@ -270,6 +271,7 @@ void create_refs(JNIEnv *env) {
J_String_getBytes = get_method_id(env, J_String, "getBytes", "(Ljava/nio/charset/Charset;)[B");

J_DuckVector_init = get_method_id(env, J_DuckVector, "<init>", "(Ljava/lang/String;I[Z)V");
J_DuckVector_retainConstlenData = get_method_id(env, J_DuckVector, "retainConstlenData", "()V");
J_DuckVector_constlen = get_field_id(env, J_DuckVector, "constlen_data", "Ljava/nio/ByteBuffer;");
J_DuckVector_varlen = get_field_id(env, J_DuckVector, "varlen_data", "[Ljava/lang/Object;");

Expand Down
1 change: 1 addition & 0 deletions src/jni/refs.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ extern jmethodID J_DuckResultSetMeta_init;

extern jclass J_DuckVector;
extern jmethodID J_DuckVector_init;
extern jmethodID J_DuckVector_retainConstlenData;
extern jfieldID J_DuckVector_constlen;
extern jfieldID J_DuckVector_varlen;

Expand Down
3 changes: 2 additions & 1 deletion src/main/java/org/duckdb/DuckDBColumnType.java
Original file line number Diff line number Diff line change
Expand Up @@ -36,5 +36,6 @@ public enum DuckDBColumnType {
MAP,
ARRAY,
UNKNOWN,
UNION;
UNION,
VARIANT;
}
15 changes: 15 additions & 0 deletions src/main/java/org/duckdb/DuckDBVector.java
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,14 @@ class DuckDBVector {
this.nullmask = nullmask;
}

private void retainConstlenData() {
if (null != constlen_data) {
byte[] constlenBytes = new byte[constlen_data.capacity()];
constlen_data.get(constlenBytes);
this.constlen_data = ByteBuffer.wrap(constlenBytes);
}
}

Object getObject(int idx) throws SQLException {
if (check_and_null(idx)) {
return null;
Expand Down Expand Up @@ -121,6 +129,8 @@ Object getObject(int idx) throws SQLException {
return getStruct(idx);
case UNION:
return getUnion(idx);
case VARIANT:
return getVariant(idx);
default:
return getLazyString(idx);
}
Expand Down Expand Up @@ -721,4 +731,9 @@ Object getUnion(int idx) throws SQLException {

return attributes[1 + tag];
}

Object getVariant(int idx) throws SQLException {
DuckDBVector vec = (DuckDBVector) varlen_data[idx];
return vec.getObject(0);
}
}
12 changes: 6 additions & 6 deletions src/test/java/org/duckdb/TestDuckDBJDBC.java
Original file line number Diff line number Diff line change
Expand Up @@ -2243,12 +2243,12 @@ public static void main(String[] args) throws Exception {
Class<?> clazz = Class.forName("org.duckdb." + arg1);
statusCode = runTests(new String[0], clazz);
} else {
statusCode =
runTests(args, TestDuckDBJDBC.class, TestAppender.class, TestAppenderCollection.class,
TestAppenderCollection2D.class, TestAppenderComposite.class, TestSingleValueAppender.class,
TestBatch.class, TestBindings.class, TestClosure.class, TestExtensionTypes.class,
TestMetadata.class, TestNoLib.class, /* TestSpatial.class, */ TestParameterMetadata.class,
TestPrepare.class, TestResults.class, TestSessionInit.class, TestTimestamp.class);
statusCode = runTests(args, TestDuckDBJDBC.class, TestAppender.class, TestAppenderCollection.class,
TestAppenderCollection2D.class, TestAppenderComposite.class,
TestSingleValueAppender.class, TestBatch.class, TestBindings.class, TestClosure.class,
TestExtensionTypes.class, TestMetadata.class, TestNoLib.class,
/* TestSpatial.class, */ TestParameterMetadata.class, TestPrepare.class,
TestResults.class, TestSessionInit.class, TestTimestamp.class, TestVariant.class);
}
System.exit(statusCode);
}
Expand Down
271 changes: 271 additions & 0 deletions src/test/java/org/duckdb/TestVariant.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,271 @@
package org.duckdb;

import static org.duckdb.TestDuckDBJDBC.JDBC_URL;
import static org.duckdb.test.Assertions.*;

import java.math.BigDecimal;
import java.math.BigInteger;
import java.sql.*;
import java.util.LinkedHashMap;
import java.util.Map;

public class TestVariant {

public static void test_variant_varchar() throws Exception {
try (Connection conn = DriverManager.getConnection(JDBC_URL); Statement stmt = conn.createStatement();
ResultSet rs = stmt.executeQuery("SELECT 'foo'::VARCHAR::VARIANT AS col1")) {
assertTrue(rs.next());
assertEquals(rs.getMetaData().getColumnType(1), Types.OTHER);
assertEquals(rs.getObject(1).getClass(), String.class);
assertEquals(rs.getObject(1), "foo");
assertFalse(rs.next());
}
}

public static void test_variant_bool() throws Exception {
try (Connection conn = DriverManager.getConnection(JDBC_URL); Statement stmt = conn.createStatement();
ResultSet rs = stmt.executeQuery("SELECT TRUE::BOOL::VARIANT AS col1")) {
assertTrue(rs.next());
assertEquals(rs.getObject(1).getClass(), Boolean.class);
assertEquals(rs.getObject(1), true);
assertFalse(rs.next());
}
}

public static void test_variant_integrals() throws Exception {
try (Connection conn = DriverManager.getConnection(JDBC_URL); Statement stmt = conn.createStatement();
ResultSet rs = stmt.executeQuery("SELECT 41::TINYINT::VARIANT AS col1"
+ " UNION ALL "
+ "SELECT 42::SMALLINT::VARIANT AS col1"
+ " UNION ALL "
+ "SELECT 43::INTEGER::VARIANT AS col1"
+ " UNION ALL "
+ "SELECT 44::BIGINT::VARIANT AS col1"
+ " UNION ALL "
+ "SELECT 45::HUGEINT::VARIANT AS col1")) {
assertTrue(rs.next());
assertEquals(rs.getObject(1).getClass(), Byte.class);
assertEquals(rs.getObject(1), (byte) 41);
assertTrue(rs.next());
assertEquals(rs.getObject(1).getClass(), Short.class);
assertEquals(rs.getObject(1), (short) 42);
assertTrue(rs.next());
assertEquals(rs.getObject(1).getClass(), Integer.class);
assertEquals(rs.getObject(1), 43);
assertTrue(rs.next());
assertEquals(rs.getObject(1).getClass(), Long.class);
assertEquals(rs.getObject(1), (long) 44);
assertTrue(rs.next());
assertEquals(rs.getObject(1).getClass(), BigInteger.class);
assertEquals(rs.getObject(1), BigInteger.valueOf(45));
assertFalse(rs.next());
}
}

public static void test_variant_floats() throws Exception {
try (Connection conn = DriverManager.getConnection(JDBC_URL); Statement stmt = conn.createStatement();
ResultSet rs = stmt.executeQuery("SELECT 41.1::FLOAT::VARIANT AS col1"
+ " UNION ALL "
+ "SELECT 42.2::DOUBLE::VARIANT AS col1")) {
assertTrue(rs.next());
assertEquals(rs.getObject(1).getClass(), Float.class);
assertEquals(rs.getObject(1), (float) 41.1);
assertTrue(rs.next());
assertEquals(rs.getObject(1).getClass(), Double.class);
assertEquals(rs.getObject(1), 42.2);
assertFalse(rs.next());
}
}

public static void test_variant_decimals() throws Exception {
try (Connection conn = DriverManager.getConnection(JDBC_URL); Statement stmt = conn.createStatement();
ResultSet rs = stmt.executeQuery("SELECT 41.1::DECIMAL(8,1)::VARIANT AS col1"
+ " UNION ALL "
+ "SELECT 42.2::DECIMAL(38,1)::VARIANT AS col1")) {
assertTrue(rs.next());
assertEquals(rs.getObject(1).getClass(), BigDecimal.class);
assertEquals(rs.getObject(1), BigDecimal.valueOf(41.1));
// assertEquals(rs.getMetaData().getPrecision(1), 8);
// assertEquals(rs.getMetaData().getScale(1), 1);
assertTrue(rs.next());
assertEquals(rs.getObject(1).getClass(), BigDecimal.class);
assertEquals(rs.getObject(1), BigDecimal.valueOf(42.2));
// assertEquals(rs.getMetaData().getPrecision(1), 38);
// assertEquals(rs.getMetaData().getScale(1), 1);
assertFalse(rs.next());
}
}

public static void test_variant_null() throws Exception {
try (Connection conn = DriverManager.getConnection(JDBC_URL); Statement stmt = conn.createStatement();
ResultSet rs = stmt.executeQuery("SELECT 'foo'::VARCHAR::VARIANT AS col1"
+ " UNION ALL "
+ "SELECT NULL::VARIANT AS col1"
+ " UNION ALL "
+ "SELECT 42::INTEGER::VARIANT AS col1")) {
assertTrue(rs.next());
assertEquals(rs.getObject(1).getClass(), String.class);
assertEquals(rs.getObject(1), "foo");
assertTrue(rs.next());
assertEquals(rs.getObject(1), null);
assertTrue(rs.wasNull());
assertTrue(rs.next());
assertEquals(rs.getObject(1).getClass(), Integer.class);
assertEquals(rs.getObject(1), 42);
assertFalse(rs.next());
}
}

public static void test_variant_query_params() throws Exception {
try (Connection conn = DriverManager.getConnection(JDBC_URL);
PreparedStatement ps = conn.prepareStatement("SELECT ?::VARCHAR::VARIANT AS col1"
+ " UNION ALL "
+ "SELECT ?::INTEGER::VARIANT AS col1")) {
ps.setString(1, "foo");
ps.setInt(2, 42);
try (ResultSet rs = ps.executeQuery()) {
assertTrue(rs.next());
assertEquals(rs.getObject(1).getClass(), String.class);
assertEquals(rs.getObject(1), "foo");
assertTrue(rs.next());
assertEquals(rs.getObject(1).getClass(), Integer.class);
assertEquals(rs.getObject(1), 42);
assertFalse(rs.next());
}
}
}

public static void test_variant_columns() throws Exception {
try (Connection conn = DriverManager.getConnection(JDBC_URL); Statement stmt = conn.createStatement();
ResultSet rs = stmt.executeQuery("SELECT 'foo'::VARCHAR::VARIANT AS col1, 42::INTEGER::VARIANT AS col2")) {
assertTrue(rs.next());
assertEquals(rs.getObject(1).getClass(), String.class);
assertEquals(rs.getObject(1), "foo");
assertEquals(rs.getObject(2).getClass(), Integer.class);
assertEquals(rs.getObject(2), 42);
assertFalse(rs.next());
}
}

public static void test_variant_array() throws Exception {
try (Connection conn = DriverManager.getConnection(JDBC_URL); Statement stmt = conn.createStatement();
ResultSet rs = stmt.executeQuery("SELECT [41, 42, 43]::INTEGER[3]::VARIANT AS col1")) {
assertTrue(rs.next());
assertEquals(rs.getObject(1).getClass(), DuckDBArray.class);
Array arrayWrapper = (Array) rs.getObject(1);
Object[] array = (Object[]) arrayWrapper.getArray();
assertEquals(array.length, 3);
assertEquals(array[0].getClass(), Integer.class);
assertEquals(array[0], 41);
assertEquals(array[1].getClass(), Integer.class);
assertEquals(array[1], 42);
assertEquals(array[2].getClass(), Integer.class);
assertEquals(array[2], 43);
assertFalse(rs.next());
}
}

public static void test_variant_list() throws Exception {
try (Connection conn = DriverManager.getConnection(JDBC_URL); Statement stmt = conn.createStatement();
ResultSet rs = stmt.executeQuery("SELECT [41, 42, 43]::INTEGER[]::VARIANT AS col1")) {
assertTrue(rs.next());
assertEquals(rs.getObject(1).getClass(), DuckDBArray.class);
Array arrayWrapper = (Array) rs.getObject(1);
Object[] array = (Object[]) arrayWrapper.getArray();
assertEquals(array.length, 3);
assertEquals(array[0].getClass(), Integer.class);
assertEquals(array[0], 41);
assertEquals(array[1].getClass(), Integer.class);
assertEquals(array[1], 42);
assertEquals(array[2].getClass(), Integer.class);
assertEquals(array[2], 43);
assertFalse(rs.next());
}
}

public static void test_variant_list_of_variants() throws Exception {
try (Connection conn = DriverManager.getConnection(JDBC_URL); Statement stmt = conn.createStatement();
ResultSet rs =
stmt.executeQuery("SELECT [41::VARIANT, NULL::VARIANT, 'foo'::VARIANT]::VARIANT[]::VARIANT AS col1")) {
assertTrue(rs.next());
assertEquals(rs.getObject(1).getClass(), DuckDBArray.class);
Array arrayWrapper = (Array) rs.getObject(1);
Object[] array = (Object[]) arrayWrapper.getArray();
assertEquals(array.length, 3);
assertEquals(array[0].getClass(), Integer.class);
assertEquals(array[0], 41);
assertNull(array[1]);
assertEquals(array[2].getClass(), String.class);
assertEquals(array[2], "foo");
assertFalse(rs.next());
}
}

public static void test_variant_map() throws Exception {
try (Connection conn = DriverManager.getConnection(JDBC_URL); Statement stmt = conn.createStatement();
ResultSet rs = stmt.executeQuery("SELECT MAP {'foo': 41, 'bar': 42}::VARIANT AS col1")) {
assertTrue(rs.next());
assertEquals(rs.getObject(1).getClass(), DuckDBArray.class);
Array arrayWrapper = (Array) rs.getObject(1);
Object[] array = (Object[]) arrayWrapper.getArray();
assertEquals(array.length, 2);
{
DuckDBStruct struct = (DuckDBStruct) array[0];
Map<?, ?> map = struct.getMap();
assertEquals(map.size(), 2);
assertEquals(map.get("key"), "foo");
assertEquals(map.get("value"), 41);
}
{
DuckDBStruct struct = (DuckDBStruct) array[1];
Map<?, ?> map = struct.getMap();
assertEquals(map.size(), 2);
assertEquals(map.get("key"), "bar");
assertEquals(map.get("value"), 42);
}
assertFalse(rs.next());
}
}

public static void test_variant_struct() throws Exception {
try (Connection conn = DriverManager.getConnection(JDBC_URL); Statement stmt = conn.createStatement();
ResultSet rs = stmt.executeQuery("SELECT {'foo': 41, 'bar': 42}::VARIANT AS col1")) {
assertTrue(rs.next());
DuckDBStruct struct = (DuckDBStruct) rs.getObject(1);
Map<?, ?> map = struct.getMap();
assertEquals(map.size(), 2);
assertEquals(map.get("foo"), 41);
assertEquals(map.get("bar"), 42);
assertFalse(rs.next());
}
}

public static void test_variant_struct_with_variant() throws Exception {
try (Connection conn = DriverManager.getConnection(JDBC_URL); Statement stmt = conn.createStatement()) {

stmt.execute("CREATE TABLE tab1 (col1 INTEGER, col2 STRUCT(s1 INTEGER, s2 VARIANT))");
stmt.execute("INSERT INTO tab1 VALUES(41, row(42, 43))");
stmt.execute("INSERT INTO tab1 VALUES(44, row(45, 'foo'))");

try (ResultSet rs = stmt.executeQuery("SELECT col2 FROM tab1 ORDER BY col1")) {
assertTrue(rs.next());
{
DuckDBStruct struct = (DuckDBStruct) rs.getObject(1);
Map<?, ?> map = struct.getMap();
assertEquals(map.size(), 2);
assertEquals(map.get("s1"), 42);
assertEquals(map.get("s2"), 43);
}
assertTrue(rs.next());
{
DuckDBStruct struct = (DuckDBStruct) rs.getObject(1);
Map<?, ?> map = struct.getMap();
assertEquals(map.size(), 2);
assertEquals(map.get("s1"), 45);
assertEquals(map.get("s2"), "foo");
}
assertFalse(rs.next());
}
}
}
}
Loading