diff --git a/datafusion/spark/Cargo.toml b/datafusion/spark/Cargo.toml index ad2620a532f24..43cdb4d1cba88 100644 --- a/datafusion/spark/Cargo.toml +++ b/datafusion/spark/Cargo.toml @@ -79,3 +79,7 @@ name = "slice" [[bench]] harness = false name = "substring" + +[[bench]] +harness = false +name = "unhex" diff --git a/datafusion/spark/benches/unhex.rs b/datafusion/spark/benches/unhex.rs new file mode 100644 index 0000000000000..f5ded8d8d7b85 --- /dev/null +++ b/datafusion/spark/benches/unhex.rs @@ -0,0 +1,148 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+
+extern crate criterion;
+
+use arrow::array::{
+    Array, LargeStringArray, LargeStringBuilder, StringArray, StringBuilder,
+    StringViewArray, StringViewBuilder,
+};
+use arrow::datatypes::{DataType, Field};
+use criterion::{Criterion, criterion_group, criterion_main};
+use datafusion_common::config::ConfigOptions;
+use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl};
+use datafusion_spark::function::math::unhex::SparkUnhex;
+use rand::rngs::StdRng;
+use rand::{Rng, SeedableRng};
+use std::hint::black_box;
+use std::sync::Arc;
+
+fn generate_hex_string_data(size: usize, null_density: f32) -> StringArray {
+    let mut rng = StdRng::seed_from_u64(42);
+    let mut builder = StringBuilder::with_capacity(size, 0);
+    let hex_chars = b"0123456789abcdefABCDEF";
+
+    for _ in 0..size {
+        if rng.random::<f32>() < null_density {
+            builder.append_null();
+        } else {
+            let len = rng.random_range::<usize, _>(2..=100);
+            let s: String = std::iter::repeat_with(|| {
+                hex_chars[rng.random_range(0..hex_chars.len())] as char
+            })
+            .take(len)
+            .collect();
+            builder.append_value(&s);
+        }
+    }
+    builder.finish()
+}
+
+fn generate_hex_large_string_data(size: usize, null_density: f32) -> LargeStringArray {
+    let mut rng = StdRng::seed_from_u64(42);
+    let mut builder = LargeStringBuilder::with_capacity(size, 0);
+    let hex_chars = b"0123456789abcdefABCDEF";
+
+    for _ in 0..size {
+        if rng.random::<f32>() < null_density {
+            builder.append_null();
+        } else {
+            let len = rng.random_range::<usize, _>(2..=100);
+            let s: String = std::iter::repeat_with(|| {
+                hex_chars[rng.random_range(0..hex_chars.len())] as char
+            })
+            .take(len)
+            .collect();
+            builder.append_value(&s);
+        }
+    }
+    builder.finish()
+}
+
+fn generate_hex_utf8view_data(size: usize, null_density: f32) -> StringViewArray {
+    let mut rng = StdRng::seed_from_u64(42);
+    let mut builder = StringViewBuilder::with_capacity(size);
+    let hex_chars = b"0123456789abcdefABCDEF";
+
+    for _ in 0..size {
+        if rng.random::<f32>() < null_density {
+            
builder.append_null();
+        } else {
+            let len = rng.random_range::<usize, _>(2..=100);
+            let s: String = std::iter::repeat_with(|| {
+                hex_chars[rng.random_range(0..hex_chars.len())] as char
+            })
+            .take(len)
+            .collect();
+            builder.append_value(&s);
+        }
+    }
+    builder.finish()
+}
+
+fn run_benchmark(c: &mut Criterion, name: &str, size: usize, array: Arc<dyn Array>) {
+    let unhex_func = SparkUnhex::new();
+    let args = vec![ColumnarValue::Array(array)];
+    let arg_fields: Vec<_> = args
+        .iter()
+        .enumerate()
+        .map(|(idx, arg)| Field::new(format!("arg_{idx}"), arg.data_type(), true).into())
+        .collect();
+    let config_options = Arc::new(ConfigOptions::default());
+
+    c.bench_function(&format!("{name}/size={size}"), |b| {
+        b.iter(|| {
+            black_box(
+                unhex_func
+                    .invoke_with_args(ScalarFunctionArgs {
+                        args: args.clone(),
+                        arg_fields: arg_fields.clone(),
+                        number_rows: size,
+                        return_field: Arc::new(Field::new("f", DataType::Binary, true)),
+                        config_options: Arc::clone(&config_options),
+                    })
+                    .unwrap(),
+            )
+        })
+    });
+}
+
+fn criterion_benchmark(c: &mut Criterion) {
+    let sizes = vec![1024, 4096, 8192];
+    let null_density = 0.1;
+
+    // Benchmark with hex string
+    for &size in &sizes {
+        let data = generate_hex_string_data(size, null_density);
+        run_benchmark(c, "unhex_utf8", size, Arc::new(data));
+    }
+
+    // Benchmark with hex large string
+    for &size in &sizes {
+        let data = generate_hex_large_string_data(size, null_density);
+        run_benchmark(c, "unhex_large_utf8", size, Arc::new(data));
+    }
+
+    // Benchmark with hex Utf8View
+    for &size in &sizes {
+        let data = generate_hex_utf8view_data(size, null_density);
+        run_benchmark(c, "unhex_utf8view", size, Arc::new(data));
+    }
+}
+
+criterion_group!(benches, criterion_benchmark);
+criterion_main!(benches);
diff --git a/datafusion/spark/src/function/math/mod.rs b/datafusion/spark/src/function/math/mod.rs
index 1422eb250d939..bf212a8219d02 100644
--- a/datafusion/spark/src/function/math/mod.rs
+++ b/datafusion/spark/src/function/math/mod.rs
@@
-22,6 +22,7 @@ pub mod hex;
 pub mod modulus;
 pub mod rint;
 pub mod trigonometry;
+pub mod unhex;
 pub mod width_bucket;
 
 use datafusion_expr::ScalarUDF;
@@ -35,6 +36,7 @@ make_udf_function!(hex::SparkHex, hex);
 make_udf_function!(modulus::SparkMod, modulus);
 make_udf_function!(modulus::SparkPmod, pmod);
 make_udf_function!(rint::SparkRint, rint);
+make_udf_function!(unhex::SparkUnhex, unhex);
 make_udf_function!(width_bucket::SparkWidthBucket, width_bucket);
 make_udf_function!(trigonometry::SparkCsc, csc);
 make_udf_function!(trigonometry::SparkSec, sec);
@@ -57,6 +59,7 @@ pub mod expr_fn {
         "Returns the double value that is closest in value to the argument and is equal to a mathematical integer.",
         arg1
     ));
+    export_functions!((unhex, "Converts hexadecimal string to binary.", arg1));
     export_functions!((width_bucket, "Returns the bucket number into which the value of this expression would fall after being evaluated.", arg1 arg2 arg3 arg4));
     export_functions!((csc, "Returns the cosecant of expr.", arg1));
     export_functions!((sec, "Returns the secant of expr.", arg1));
@@ -71,6 +74,7 @@ pub fn functions() -> Vec<Arc<ScalarUDF>> {
         modulus(),
         pmod(),
         rint(),
+        unhex(),
         width_bucket(),
         csc(),
         sec(),
diff --git a/datafusion/spark/src/function/math/unhex.rs b/datafusion/spark/src/function/math/unhex.rs
new file mode 100644
index 0000000000000..dee532d818f83
--- /dev/null
+++ b/datafusion/spark/src/function/math/unhex.rs
@@ -0,0 +1,214 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.
You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use arrow::array::{Array, ArrayRef, BinaryBuilder};
+use arrow::datatypes::DataType;
+use datafusion_common::cast::{
+    as_large_string_array, as_string_array, as_string_view_array,
+};
+use datafusion_common::types::logical_string;
+use datafusion_common::utils::take_function_args;
+use datafusion_common::{DataFusionError, Result, ScalarValue, exec_err};
+use datafusion_expr::{
+    Coercion, ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature,
+    TypeSignatureClass, Volatility,
+};
+use std::any::Any;
+use std::sync::Arc;
+
+/// <https://spark.apache.org/docs/latest/api/sql/index.html#unhex>
+#[derive(Debug, PartialEq, Eq, Hash)]
+pub struct SparkUnhex {
+    signature: Signature,
+}
+
+impl Default for SparkUnhex {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+impl SparkUnhex {
+    pub fn new() -> Self {
+        let string = Coercion::new_exact(TypeSignatureClass::Native(logical_string()));
+
+        Self {
+            signature: Signature::coercible(vec![string], Volatility::Immutable),
+        }
+    }
+}
+
+impl ScalarUDFImpl for SparkUnhex {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn name(&self) -> &str {
+        "unhex"
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
+        Ok(DataType::Binary)
+    }
+
+    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
+        spark_unhex(&args.args)
+    }
+}
+
+#[inline]
+fn hex_nibble(c: u8) -> Option<u8> {
+    match c {
+        b'0'..=b'9' => Some(c - b'0'),
+        b'a'..=b'f' => Some(c - b'a' + 10),
+        b'A'..=b'F' => Some(c - b'A' + 10),
+        _ => None,
+    }
+}
+
+/// Decodes a hex-encoded byte slice into binary data.
+/// Returns `true` if decoding succeeded, `false` if the input contains invalid hex characters.
+fn unhex_common(bytes: &[u8], out: &mut Vec<u8>) -> bool {
+    if bytes.is_empty() {
+        return true;
+    }
+
+    let mut i = 0usize;
+
+    // If the hex string length is odd, implicitly left-pad with '0'.
+    if (bytes.len() & 1) == 1 {
+        match hex_nibble(bytes[0]) {
+            // Equivalent to (0 << 4) | lo
+            Some(lo) => out.push(lo),
+            None => return false,
+        }
+        i = 1;
+    }
+
+    while i + 1 < bytes.len() {
+        match (hex_nibble(bytes[i]), hex_nibble(bytes[i + 1])) {
+            (Some(hi), Some(lo)) => out.push((hi << 4) | lo),
+            _ => return false,
+        }
+        i += 2;
+    }
+
+    true
+}
+
+/// Converts an iterator of hex strings to a binary array.
+fn unhex_array<I, T>(
+    iter: I,
+    len: usize,
+    capacity: usize,
+) -> Result<ArrayRef>
+where
+    I: Iterator<Item = Option<T>>,
+    T: AsRef<str>,
+{
+    let mut builder = BinaryBuilder::with_capacity(len, capacity);
+    let mut buffer = Vec::new();
+
+    for v in iter {
+        if let Some(s) = v {
+            buffer.clear();
+            buffer.reserve(s.as_ref().len().div_ceil(2));
+            if unhex_common(s.as_ref().as_bytes(), &mut buffer) {
+                builder.append_value(&buffer);
+            } else {
+                builder.append_null();
+            }
+        } else {
+            builder.append_null();
+        }
+    }
+
+    Ok(Arc::new(builder.finish()))
+}
+
+/// Convert a single hex string to binary
+fn unhex_scalar(s: &str) -> Option<Vec<u8>> {
+    let mut buffer = Vec::with_capacity(s.len().div_ceil(2));
+    if unhex_common(s.as_bytes(), &mut buffer) {
+        Some(buffer)
+    } else {
+        None
+    }
+}
+
+fn spark_unhex(args: &[ColumnarValue]) -> Result<ColumnarValue> {
+    let [args] = take_function_args("unhex", args)?;
+
+    match args {
+        ColumnarValue::Array(array) => match array.data_type() {
+            DataType::Utf8 => {
+                let array = as_string_array(array)?;
+                let capacity = array.values().len().div_ceil(2);
+                Ok(ColumnarValue::Array(unhex_array(
+                    array.iter(),
+                    array.len(),
+                    capacity,
+                )?))
+            }
+            DataType::Utf8View => {
+                let array = as_string_view_array(array)?;
+                // Estimate capacity since StringViewArray data can be scattered or
inlined. + let capacity = array.len() * 32; + Ok(ColumnarValue::Array(unhex_array( + array.iter(), + array.len(), + capacity, + )?)) + } + DataType::LargeUtf8 => { + let array = as_large_string_array(array)?; + let capacity = array.values().len().div_ceil(2); + Ok(ColumnarValue::Array(unhex_array( + array.iter(), + array.len(), + capacity, + )?)) + } + _ => exec_err!( + "unhex only supports string argument, but got: {}", + array.data_type() + ), + }, + ColumnarValue::Scalar(sv) => match sv { + ScalarValue::Utf8(None) + | ScalarValue::Utf8View(None) + | ScalarValue::LargeUtf8(None) => { + Ok(ColumnarValue::Scalar(ScalarValue::Binary(None))) + } + ScalarValue::Utf8(Some(s)) + | ScalarValue::Utf8View(Some(s)) + | ScalarValue::LargeUtf8(Some(s)) => { + Ok(ColumnarValue::Scalar(ScalarValue::Binary(unhex_scalar(s)))) + } + _ => { + exec_err!( + "unhex only supports string argument, but got: {}", + sv.data_type() + ) + } + }, + } +} diff --git a/datafusion/sqllogictest/test_files/spark/math/unhex.slt b/datafusion/sqllogictest/test_files/spark/math/unhex.slt new file mode 100644 index 0000000000000..051d8826c8a6c --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/math/unhex.slt @@ -0,0 +1,98 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. + +# Basic hex string +query ? +SELECT unhex('537061726B2053514C'); +---- +537061726b2053514c + +query T +SELECT arrow_cast(unhex('537061726B2053514C'), 'Utf8'); +---- +Spark SQL + +# Lowercase hex +query ? +SELECT unhex('616263'); +---- +616263 + +query T +SELECT arrow_cast(unhex('616263'), 'Utf8'); +---- +abc + +# Odd length hex (left pad with 0) +query ? +SELECT unhex(a) FROM VALUES ('1A2B3'), ('1'), ('ABC'), ('123') AS t(a); +---- +01a2b3 +01 +0abc +0123 + +# Null input +query ? +SELECT unhex(NULL); +---- +NULL + +# Invalid hex characters +query ? +SELECT unhex('GGHH'); +---- +NULL + +# Empty hex string +query T +SELECT arrow_cast(unhex(''), 'Utf8'); +---- +(empty) + +# Array with mixed case +query ? +SELECT unhex(a) FROM VALUES ('4a4B4c'), ('F'), ('A'), ('AbCdEf'), ('123abc'), ('41 42'), ('00'), ('FF') AS t(a); +---- +4a4b4c +0f +0a +abcdef +123abc +NULL +00 +ff + +# LargeUtf8 type +statement ok +CREATE TABLE t_large_utf8 AS VALUES (arrow_cast('414243', 'LargeUtf8')), (NULL); + +query ? +SELECT unhex(column1) FROM t_large_utf8; +---- +414243 +NULL + +# Utf8View type +statement ok +CREATE TABLE t_utf8view AS VALUES (arrow_cast('414243', 'Utf8View')), (NULL); + +query ? +SELECT unhex(column1) FROM t_utf8view; +---- +414243 +NULL