diff --git a/writers/streamingbatchwriter/streamingbatchwriter.go b/writers/streamingbatchwriter/streamingbatchwriter.go
index 3951a01a9c..ff4d9a75d8 100644
--- a/writers/streamingbatchwriter/streamingbatchwriter.go
+++ b/writers/streamingbatchwriter/streamingbatchwriter.go
@@ -154,6 +154,35 @@ func (w *StreamingBatchWriter) Flush(context.Context) error {
 	return nil // not checked below
 }
 
+// flushInsertWorkers asks every active insert worker to flush its buffered
+// rows and waits for each acknowledgement. The worker set is snapshotted
+// under RLock so no lock is held while blocking on channel operations.
+func (w *StreamingBatchWriter) flushInsertWorkers(ctx context.Context) error {
+	w.workersLock.RLock()
+	workers := make([]*streamingWorkerManager[*message.WriteInsert], 0, len(w.insertWorkers))
+	for _, worker := range w.insertWorkers {
+		workers = append(workers, worker)
+	}
+	w.workersLock.RUnlock()
+
+	for _, worker := range workers {
+		// Buffered so the worker's ack never blocks forever if we abandon
+		// the wait below due to context cancellation.
+		done := make(chan bool, 1)
+		select {
+		case <-ctx.Done():
+			return ctx.Err()
+		case worker.flush <- done:
+		}
+		select {
+		case <-ctx.Done():
+			return ctx.Err()
+		case <-done:
+		}
+	}
+	return nil
+}
+
 func (w *StreamingBatchWriter) Close(context.Context) error {
 	w.workersLock.Lock()
 	defer w.workersLock.Unlock()
@@ -323,6 +352,10 @@ func (w *StreamingBatchWriter) startWorker(ctx context.Context, errCh chan<- err
 		return nil
 
 	case *message.WriteDeleteRecord:
+		// Flush pending inserts (including nested-table buffers) before
+		// applying deletions, so deletes never outrun buffered rows.
+		if err := w.flushInsertWorkers(ctx); err != nil {
+			return err
+		}
 		w.workersLock.Lock()
 		defer w.workersLock.Unlock()
 
@@ -331,7 +364,6 @@ func (w *StreamingBatchWriter) startWorker(ctx context.Context, errCh chan<- err
 			return nil
 		}
 
-		// TODO: flush all workers for nested tables as well (See https://github.com/cloudquery/plugin-sdk/issues/1296)
 		w.deleteRecordWorker = &streamingWorkerManager[*message.WriteDeleteRecord]{
 			ch:        make(chan *message.WriteDeleteRecord),
 			writeFunc: w.client.DeleteRecords,
diff --git a/writers/streamingbatchwriter/streamingbatchwriter_test.go b/writers/streamingbatchwriter/streamingbatchwriter_test.go
index 9718c21919..95494b92e0 100644
--- a/writers/streamingbatchwriter/streamingbatchwriter_test.go
+++ b/writers/streamingbatchwriter/streamingbatchwriter_test.go
@@ -624,3 +624,65 @@ func requireErrorCount(t *testing.T, errCh chan error, expectedMin, expectedMax
 	}
 	return -1
 }
+
+// TestDeleteRecordFlushesPendingInserts verifies that a WriteDeleteRecord
+// message forces buffered inserts to be flushed before the delete is handled.
+func TestDeleteRecordFlushesPendingInserts(t *testing.T) {
+	t.Parallel()
+
+	ctx := context.Background()
+	errCh := make(chan error, 10)
+
+	testClient := newClient()
+	wr, err := New(testClient, WithBatchSizeRows(1000000)) // large batch to avoid auto-flush
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	// Create a table for the insert.
+	insertTable := &schema.Table{
+		Name: "child_table",
+		Columns: []schema.Column{
+			{
+				Name: "id",
+				Type: arrow.PrimitiveTypes.Int64,
+			},
+		},
+	}
+
+	// Build an insert record carrying the table name in its schema metadata.
+	bldr := array.NewRecordBuilder(memory.DefaultAllocator, insertTable.ToArrowSchema())
+	bldr.Field(0).(*array.Int64Builder).Append(1)
+	record := bldr.NewRecord()
+
+	md := arrow.NewMetadata(
+		[]string{schema.MetadataTableName},
+		[]string{insertTable.Name},
+	)
+	newSchema := arrow.NewSchema(
+		record.Schema().Fields(),
+		&md,
+	)
+
+	record = array.NewRecord(newSchema, record.Columns(), record.NumRows())
+
+	// Send the insert; with the huge batch size it stays buffered.
+	if err := wr.startWorker(ctx, errCh, &message.WriteInsert{Record: record}); err != nil {
+		t.Fatal(err)
+	}
+
+	// Send a delete record to trigger the flush.
+	del := &message.WriteDeleteRecord{
+		DeleteRecord: message.DeleteRecord{
+			TableName: insertTable.Name,
+		},
+	}
+
+	if err := wr.startWorker(ctx, errCh, del); err != nil {
+		t.Fatal(err)
+	}
+	waitForLength(t, testClient.MessageLen, messageTypeInsert, 1)
+	if err := wr.Close(ctx); err != nil {
+		t.Fatal(err)
+	}
+}