diff --git a/spark/src/main/scala/org/apache/comet/serde/operator/CometDataWritingCommand.scala b/spark/src/main/scala/org/apache/comet/serde/operator/CometDataWritingCommand.scala
index 8349329841..149a98d8bb 100644
--- a/spark/src/main/scala/org/apache/comet/serde/operator/CometDataWritingCommand.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/operator/CometDataWritingCommand.scala
@@ -24,7 +24,7 @@ import java.util.Locale
 import scala.jdk.CollectionConverters._
 
 import org.apache.spark.SparkException
-import org.apache.spark.sql.comet.{CometNativeExec, CometNativeWriteExec}
+import org.apache.spark.sql.comet.{CometNativeExec, CometNativeWriteExec, CometPlan, CometSparkToColumnarExec}
 import org.apache.spark.sql.execution.command.DataWritingCommandExec
 import org.apache.spark.sql.execution.datasources.{InsertIntoHadoopFsRelationCommand, WriteFilesExec}
 import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
@@ -160,13 +160,21 @@ object CometDataWritingCommand extends CometOperatorSerde[DataWritingCommandExec
     // Get the child plan from the WriteFilesExec or use the child directly
     val childPlan = op.child match {
       case writeFiles: WriteFilesExec =>
-        // The WriteFilesExec child should already be a Comet operator
         writeFiles.child
       case other =>
-        // Fallback: use the child directly
         other
     }
 
+    // Wrap with CometSparkToColumnarExec if the child is not already a Comet operator.
+    // This ensures the input to CometNativeWriteExec is always in Arrow format.
+    // CometSparkToColumnarExec handles conversion from both:
+    // - Row-based Spark operators (via rowToArrowBatchIter)
+    // - Columnar Spark operators like RangeExec (via columnarBatchToArrowBatchIter)
+    val wrappedChild = childPlan match {
+      case _: CometPlan => childPlan // Already produces Arrow batches
+      case _ => CometSparkToColumnarExec(childPlan) // Convert Spark format to Arrow
+    }
+
     // Create FileCommitProtocol for atomic writes
     val jobId = java.util.UUID.randomUUID().toString
     val committer =
@@ -189,7 +197,7 @@ object CometDataWritingCommand extends CometOperatorSerde[DataWritingCommandExec
         throw new SparkException(s"Could not instantiate FileCommitProtocol: ${e.getMessage}")
       }
 
-    CometNativeWriteExec(nativeOp, childPlan, outputPath, committer, jobId)
+    CometNativeWriteExec(nativeOp, wrappedChild, outputPath, committer, jobId)
   }
 
   private def parseCompressionCodec(cmd: InsertIntoHadoopFsRelationCommand) = {
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometNativeWriteExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometNativeWriteExec.scala
index f153a691ef..d2b2793572 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/CometNativeWriteExec.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometNativeWriteExec.scala
@@ -138,18 +138,10 @@ case class CometNativeWriteExec(
   }
 
   override def doExecuteColumnar(): RDD[ColumnarBatch] = {
-    // Get the input data from the child operator
-    val childRDD = if (child.supportsColumnar) {
-      child.executeColumnar()
-    } else {
-      // If child doesn't support columnar, convert to columnar
-      child.execute().mapPartitionsInternal { _ =>
-        // TODO this could delegate to CometRowToColumnar, but maybe Comet
-        // does not need to support this case?
-        throw new UnsupportedOperationException(
-          "Row-based child operators not yet supported for native write")
-      }
-    }
+    // Child is guaranteed to be a CometPlan (either already a Comet operator or wrapped
+    // with CometSparkToColumnarExec in CometDataWritingCommand.createExec()).
+    // This ensures we always receive Arrow-format batches.
+    val childRDD = child.executeColumnar()
 
     // Capture metadata before the transformation
     val numPartitions = childRDD.getNumPartitions
@@ -203,7 +195,7 @@ case class CometNativeWriteExec(
 
     val execIterator = new CometExecIterator(
       CometExec.newIterId,
-      Seq(iter),
+      Seq(iter), // Child produces Arrow batches (a Comet operator or CometSparkToColumnarExec)
       numOutputCols,
       planBytes,
       nativeMetrics,
diff --git a/spark/src/test/scala/org/apache/comet/parquet/CometParquetWriterSuite.scala b/spark/src/test/scala/org/apache/comet/parquet/CometParquetWriterSuite.scala
index 3ae7f949ab..95412b7497 100644
--- a/spark/src/test/scala/org/apache/comet/parquet/CometParquetWriterSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/parquet/CometParquetWriterSuite.scala
@@ -228,4 +228,113 @@ class CometParquetWriterSuite extends CometTestBase {
       }
     }
   }
+
+  test("parquet write with spark.range() as data source - with spark-to-arrow conversion") {
+    // Test that spark.range() works when CometSparkToColumnarExec is enabled to convert
+    // Spark's OnHeapColumnVector to Arrow format
+    withTempPath { dir =>
+      val outputPath = new File(dir, "output.parquet").getAbsolutePath
+
+      withSQLConf(
+        CometConf.COMET_NATIVE_PARQUET_WRITE_ENABLED.key -> "true",
+        SQLConf.SESSION_LOCAL_TIMEZONE.key -> "America/Halifax",
+        CometConf.getOperatorAllowIncompatConfigKey(classOf[DataWritingCommandExec]) -> "true",
+        CometConf.COMET_EXEC_ENABLED.key -> "true",
+        CometConf.COMET_SPARK_TO_ARROW_ENABLED.key -> "true",
+        CometConf.COMET_SPARK_TO_ARROW_SUPPORTED_OPERATOR_LIST.key -> "Range") {
+
+        // Use a listener to capture the execution plan during the write
+        var capturedPlan: Option[org.apache.spark.sql.execution.QueryExecution] = None
+
+        val listener = new org.apache.spark.sql.util.QueryExecutionListener {
+          override def onSuccess(
+              funcName: String,
+              qe: org.apache.spark.sql.execution.QueryExecution,
+              durationNs: Long): Unit = {
+            if (funcName == "save" || funcName.contains("command")) {
+              capturedPlan = Some(qe)
+            }
+          }
+
+          override def onFailure(
+              funcName: String,
+              qe: org.apache.spark.sql.execution.QueryExecution,
+              exception: Exception): Unit = {}
+        }
+
+        spark.listenerManager.register(listener)
+
+        try {
+          // spark.range() uses RangeExec, which produces OnHeapColumnVector;
+          // CometSparkToColumnarExec converts these to Arrow format
+          spark.range(1000).write.mode("overwrite").parquet(outputPath)
+
+          // Wait for the listener to capture the plan
+          val maxWaitTimeMs = 15000
+          val checkIntervalMs = 100
+          var iterations = 0
+
+          while (capturedPlan.isEmpty && iterations < maxWaitTimeMs / checkIntervalMs) {
+            Thread.sleep(checkIntervalMs)
+            iterations += 1
+          }
+
+          // Verify that CometNativeWriteExec was used
+          capturedPlan.foreach { qe =>
+            val executedPlan = stripAQEPlan(qe.executedPlan)
+
+            var nativeWriteCount = 0
+            executedPlan.foreach {
+              case _: CometNativeWriteExec =>
+                nativeWriteCount += 1
+              case d: DataWritingCommandExec =>
+                d.child.foreach {
+                  case _: CometNativeWriteExec =>
+                    nativeWriteCount += 1
+                  case _ =>
+                }
+              case _ =>
+            }
+
+            assert(
+              nativeWriteCount == 1,
+              s"Expected exactly one CometNativeWriteExec in the plan, but found $nativeWriteCount:\n${executedPlan.treeString}")
+          }
+
+          // Verify the data was written correctly
+          val resultDf = spark.read.parquet(outputPath)
+          assert(resultDf.count() == 1000, "Expected 1000 rows to be written")
+        } finally {
+          spark.listenerManager.unregister(listener)
+        }
+      }
+    }
+  }
+
+  test("parquet write with spark.range() - issue #2944") {
+    // This test verifies the fix for https://github.com/apache/datafusion-comet/issues/2944
+    // CometDataWritingCommand.createExec() wraps non-Comet child operators with
+    // CometSparkToColumnarExec to convert Spark columnar batches to Arrow format.
+    // This ensures the native writer always receives Arrow-format data.
+    withTempPath { dir =>
+      val outputPath = new File(dir, "output.parquet").getAbsolutePath
+
+      withSQLConf(
+        CometConf.COMET_NATIVE_PARQUET_WRITE_ENABLED.key -> "true",
+        SQLConf.SESSION_LOCAL_TIMEZONE.key -> "America/Halifax",
+        CometConf.getOperatorAllowIncompatConfigKey(classOf[DataWritingCommandExec]) -> "true",
+        CometConf.COMET_EXEC_ENABLED.key -> "true") {
+
+        // spark.range() uses RangeExec, which produces OnHeapColumnVector (not Arrow).
+        // CometSparkToColumnarExec is automatically inserted to convert to Arrow format.
+        // Without the fix, this would fail with:
+        // "Comet execution only takes Arrow Arrays, but got OnHeapColumnVector"
+        spark.range(1000).write.mode("overwrite").parquet(outputPath)
+
+        // Verify the data was written correctly
+        val resultDf = spark.read.parquet(outputPath)
+        assert(resultDf.count() == 1000, "Expected 1000 rows to be written")
+      }
+    }
+  }
 }
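For readers who want to try the scenario outside the test suite, the sketch below mirrors the configuration used by the new tests to route a spark.range() write through the native Parquet writer. It is a minimal, illustrative example only: it assumes a SparkSession with the Comet plugin on the classpath, and the application name and output path are placeholders; the CometConf constants and the getOperatorAllowIncompatConfigKey call are taken from CometParquetWriterSuite.

import org.apache.comet.CometConf
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.command.DataWritingCommandExec

object NativeRangeWriteSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("comet-native-write-sketch").getOrCreate()

    // Same switches as the "spark-to-arrow conversion" test: enable Comet execution,
    // native Parquet writes, and Spark-to-Arrow conversion for RangeExec.
    spark.conf.set(CometConf.COMET_EXEC_ENABLED.key, "true")
    spark.conf.set(CometConf.COMET_NATIVE_PARQUET_WRITE_ENABLED.key, "true")
    spark.conf.set(
      CometConf.getOperatorAllowIncompatConfigKey(classOf[DataWritingCommandExec]),
      "true")
    spark.conf.set(CometConf.COMET_SPARK_TO_ARROW_ENABLED.key, "true")
    spark.conf.set(CometConf.COMET_SPARK_TO_ARROW_SUPPORTED_OPERATOR_LIST.key, "Range")

    // RangeExec emits OnHeapColumnVector batches; with this change the write plan is
    // wrapped in CometSparkToColumnarExec so CometNativeWriteExec receives Arrow data.
    val outputPath = "/tmp/comet-range-output.parquet" // placeholder path
    spark.range(1000).write.mode("overwrite").parquet(outputPath)

    assert(spark.read.parquet(outputPath).count() == 1000)
    spark.stop()
  }
}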
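The comment added in CometDataWritingCommand names two conversion paths inside CometSparkToColumnarExec (rowToArrowBatchIter and columnarBatchToArrowBatchIter). The snippet below is only a rough sketch of that dispatch, under the assumption that the choice is driven by child.supportsColumnar; it is not the operator's actual implementation, and the helper signatures are invented for illustration.

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.vectorized.ColumnarBatch

object SparkToArrowDispatchSketch {
  // Hypothetical stand-ins for the converters named in the diff comment.
  def rowToArrowBatchIter(rows: Iterator[InternalRow]): Iterator[ColumnarBatch] =
    throw new UnsupportedOperationException("illustrative stub")
  def columnarBatchToArrowBatchIter(batches: Iterator[ColumnarBatch]): Iterator[ColumnarBatch] =
    throw new UnsupportedOperationException("illustrative stub")

  // Sketch of how a Spark-to-Arrow conversion step could pick a path per child plan.
  def toArrowBatches(child: SparkPlan): RDD[ColumnarBatch] = {
    if (child.supportsColumnar) {
      // Columnar Spark operators, e.g. RangeExec emitting OnHeapColumnVector
      child.executeColumnar().mapPartitions(columnarBatchToArrowBatchIter)
    } else {
      // Row-based Spark operators
      child.execute().mapPartitions(rowToArrowBatchIter)
    }
  }
}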