@@ -72,7 +72,13 @@ protected boolean shouldSetResourceName() {

private final DataStreamsTransactionTracker.TransactionSourceReader
DSM_TRANSACTION_SOURCE_READER =
(source, headerName) -> getRequestHeader((REQUEST) source, headerName);
(source, headerName) -> {
try {
return getRequestHeader((REQUEST) source, headerName);
} catch (Throwable ignored) {
return null;
}
};

public AgentSpan onRequest(final AgentSpan span, final REQUEST request) {
if (request != null) {
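
The guard above maps any failure while reading a transaction header (for example a bad cast on source, or an exception thrown by the underlying header accessor) to null, so a missing or unreadable header no longer breaks the instrumented request path. A minimal, self-contained sketch of the pattern; the interface shape and the getHeaderSomehow helper are hypothetical stand-ins inferred from the lambdas in this diff, not the tracer's actual types:

final class GuardedHeaderReaderSketch {
  // Hypothetical reader contract, inferred from the (source, headerName) lambdas in this PR.
  interface TransactionSourceReader {
    String read(Object source, String headerName);
  }

  // Hypothetical stand-in for the decorator's getRequestHeader((REQUEST) source, headerName).
  static String getHeaderSomehow(Object source, String headerName) {
    throw new ClassCastException("simulated failure for an unexpected source type");
  }

  // Guarded reader: any failure during the read degrades to "no transaction id".
  static final TransactionSourceReader GUARDED =
      (source, headerName) -> {
        try {
          return getHeaderSomehow(source, headerName);
        } catch (Throwable ignored) {
          return null;
        }
      };

  public static void main(String[] args) {
    // Prints "null": the simulated cast failure is swallowed by the guard.
    System.out.println(GUARDED.read(new Object(), "x-transaction-id"));
  }
}
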
@@ -185,7 +185,13 @@ protected AgentSpanContext startInferredProxySpan(Context context, AgentSpanCont

private final DataStreamsTransactionTracker.TransactionSourceReader
DSM_TRANSACTION_SOURCE_READER =
(source, headerName) -> getRequestHeader((REQUEST) source, headerName);
(source, headerName) -> {
try {
return getRequestHeader((REQUEST) source, headerName);
} catch (Throwable ignored) {
return null;
}
};

public AgentSpan onRequest(
final AgentSpan span,
@@ -211,7 +211,7 @@ record =
.trackTransaction(
span,
DataStreamsTransactionExtractor.Type.KAFKA_PRODUCE_HEADERS,
record,
record.headers(),

Contributor:

ok, but I would rather protect this:

DSM_TRANSACTION_SOURCE_READER =
          (source, headerName) -> new String(((Headers) source).lastHeader(headerName).value());

or, better, type the argument to prevent this kind of issue in the future.

Contributor Author:

Added a protective try/catch for all header readers to keep it simple for now.

DSM_TRANSACTION_SOURCE_READER);
return activateSpan(span);
}
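
For reference, the reviewer's second suggestion above (typing the source argument so that a mismatch such as passing the record instead of its headers fails at compile time rather than at runtime) could look roughly like the sketch below. The generic interface is a hypothetical illustration of that idea, not the shape used in this PR:

import java.nio.charset.StandardCharsets;
import org.apache.kafka.common.header.Header;
import org.apache.kafka.common.header.Headers;

final class TypedReaderSketch {
  // Hypothetical generically typed reader: the source type is fixed per call site.
  interface TransactionSourceReader<S> {
    String read(S source, String headerName);
  }

  // With Headers as the source type, passing a ProducerRecord here would not compile,
  // and an absent header is handled with an explicit null check instead of an NPE.
  static final TransactionSourceReader<Headers> KAFKA_HEADER_READER =
      (headers, headerName) -> {
        Header header = headers.lastHeader(headerName);
        return header == null ? null : new String(header.value(), StandardCharsets.UTF_8);
      };
}

Either approach closes the gap: the try/catch added in this PR is the smaller change, while a typed reader would also have turned the original record vs record.headers() mix-up into a compile error.
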
@@ -1,4 +1,5 @@
import datadog.trace.api.datastreams.DataStreamsTags
import datadog.trace.api.datastreams.DataStreamsTransactionExtractor
import datadog.trace.instrumentation.kafka_common.ClusterIdHolder

import static datadog.trace.agent.test.utils.TraceUtils.basicSpan
@@ -1047,6 +1048,126 @@ abstract class KafkaClientTestBase extends VersionedNamingTestBase {
producer?.close()
}

def "test producer DSM transaction tracking extracts transaction id from headers"() {
setup:
if (!isDataStreamsEnabled()) {
return
}

injectEnvConfig("DD_DATA_STREAMS_ENABLED", "true")

// Configure a DSM transaction extractor for KAFKA_PRODUCE_HEADERS
def extractorsByTypeField = TEST_DATA_STREAMS_MONITORING.getClass().getDeclaredField("extractorsByType")
extractorsByTypeField.setAccessible(true)
def oldExtractorsByType = extractorsByTypeField.get(TEST_DATA_STREAMS_MONITORING)

def extractor = new DataStreamsTransactionExtractor() {
String getName() {
return "kafka-produce-test"
}
DataStreamsTransactionExtractor.Type getType() {
return DataStreamsTransactionExtractor.Type.KAFKA_PRODUCE_HEADERS
}
String getValue() {
return "x-transaction-id"
}
}
def extractorsByType = new EnumMap<>(DataStreamsTransactionExtractor.Type)
extractorsByType.put(DataStreamsTransactionExtractor.Type.KAFKA_PRODUCE_HEADERS, [extractor])
extractorsByTypeField.set(TEST_DATA_STREAMS_MONITORING, extractorsByType)

def senderProps = KafkaTestUtils.senderProps(embeddedKafka.getBrokersAsString())
def producer = new KafkaProducer<>(senderProps, new StringSerializer(), new StringSerializer())

def headers = new RecordHeaders()
headers.add(new RecordHeader("x-transaction-id", "txn-123".getBytes(StandardCharsets.UTF_8)))

when:
def record = new ProducerRecord(SHARED_TOPIC, 0, null, "test-dsm-transaction", headers)
producer.send(record).get()

then:
TEST_WRITER.waitForTraces(1)
def producedSpan = TEST_WRITER[0][0]
producedSpan.getTag(Tags.DSM_TRANSACTION_ID) == "txn-123"
producedSpan.getTag(Tags.DSM_TRANSACTION_CHECKPOINT) == "kafka-produce-test"

cleanup:
extractorsByTypeField?.set(TEST_DATA_STREAMS_MONITORING, oldExtractorsByType)
producer?.close()
}

def "test consumer DSM transaction tracking extracts transaction id from headers"() {
setup:
if (!isDataStreamsEnabled()) {
return
}

injectEnvConfig("DD_DATA_STREAMS_ENABLED", "true")

// Configure a DSM transaction extractor for KAFKA_CONSUME_HEADERS
def extractorsByTypeField = TEST_DATA_STREAMS_MONITORING.getClass().getDeclaredField("extractorsByType")
extractorsByTypeField.setAccessible(true)
def oldExtractorsByType = extractorsByTypeField.get(TEST_DATA_STREAMS_MONITORING)

def extractor = new DataStreamsTransactionExtractor() {
String getName() {
return "kafka-consume-test"
}
DataStreamsTransactionExtractor.Type getType() {
return DataStreamsTransactionExtractor.Type.KAFKA_CONSUME_HEADERS
}
String getValue() {
return "x-transaction-id"
}
}
def extractorsByType = new EnumMap<>(DataStreamsTransactionExtractor.Type)
extractorsByType.put(DataStreamsTransactionExtractor.Type.KAFKA_CONSUME_HEADERS, [extractor])
extractorsByTypeField.set(TEST_DATA_STREAMS_MONITORING, extractorsByType)

def kafkaPartition = 0
def consumerProperties = KafkaTestUtils.consumerProps("sender", "false", embeddedKafka)
consumerProperties.put(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "earliest")
def consumer = new KafkaConsumer<String, String>(consumerProperties)

def senderProps = KafkaTestUtils.senderProps(embeddedKafka.getBrokersAsString())
def producer = new KafkaProducer<>(senderProps, new StringSerializer(), new StringSerializer())

consumer.assign(Arrays.asList(new TopicPartition(SHARED_TOPIC, kafkaPartition)))

def headers = new RecordHeaders()
headers.add(new RecordHeader("x-transaction-id", "txn-456".getBytes(StandardCharsets.UTF_8)))

when:
def record = new ProducerRecord(SHARED_TOPIC, kafkaPartition, null, "test-dsm-consume-transaction", headers)
producer.send(record).get()

then:
TEST_WRITER.waitForTraces(1)
def pollResult = KafkaTestUtils.getRecords(consumer)
def recs = pollResult.records(new TopicPartition(SHARED_TOPIC, kafkaPartition)).iterator()
recs.hasNext()
recs.next().value() == "test-dsm-consume-transaction"

// The consume span is created by TracingIterator when iterating over records
// Find the consumer span with the DSM transaction tags
TEST_WRITER.waitForTraces(2)
def allTraces = TEST_WRITER.toArray() as List<List<DDSpan>>
def consumerSpan = allTraces.collectMany {
it
}.find {
it.getTag(Tags.DSM_TRANSACTION_ID) == "txn-456"
}
consumerSpan != null
consumerSpan.getTag(Tags.DSM_TRANSACTION_ID) == "txn-456"
consumerSpan.getTag(Tags.DSM_TRANSACTION_CHECKPOINT) == "kafka-consume-test"

cleanup:
extractorsByTypeField?.set(TEST_DATA_STREAMS_MONITORING, oldExtractorsByType)
consumer?.close()
producer?.close()
}

def containerProperties() {
try {
// Different class names for test and latestDepTest.
@@ -1057,12 +1178,12 @@ abstract class KafkaClientTestBase extends VersionedNamingTestBase {
}

def producerSpan(
TraceAssert trace,
Map<String, ?> config,
DDSpan parentSpan = null,
boolean partitioned = true,
boolean tombstone = false,
String schema = null
) {
trace.span {
serviceName service()
@@ -1104,8 +1225,8 @@ abstract class KafkaClientTestBase extends VersionedNamingTestBase {
}

def queueSpan(
TraceAssert trace,
DDSpan parentSpan = null
) {
trace.span {
serviceName splitByDestination() ? "$SHARED_TOPIC" : serviceForTimeInQueue()
@@ -1128,12 +1249,12 @@ abstract class KafkaClientTestBase extends VersionedNamingTestBase {
}

def consumerSpan(
TraceAssert trace,
Map<String, Object> config,
DDSpan parentSpan = null,
Range offset = 0..0,
boolean tombstone = false,
boolean distributedRootSpan = !hasQueueSpan()
) {
trace.span {
serviceName service()
@@ -1169,12 +1290,12 @@ abstract class KafkaClientTestBase extends VersionedNamingTestBase {
}

def pollSpan(
TraceAssert trace,
int recordCount = 1,
DDSpan parentSpan = null,
Range offset = 0..0,
boolean tombstone = false,
boolean distributedRootSpan = !hasQueueSpan()
) {
trace.span {
serviceName Config.get().getServiceName()
@@ -11,7 +11,13 @@ private Utils() {} // prevent instantiation

public static DataStreamsTransactionTracker.TransactionSourceReader
DSM_TRANSACTION_SOURCE_READER =
(source, headerName) -> new String(((Headers) source).lastHeader(headerName).value());
(source, headerName) -> {
try {
return new String(((Headers) source).lastHeader(headerName).value());
} catch (Throwable ignored) {
return null;
}
};
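
In this reader the guard mainly covers an absent header: Headers.lastHeader returns null when the key is not present, so calling value() on the result would throw a NullPointerException. A minimal sketch of an equivalent reader that makes that case explicit, assuming the same (source, headerName) contract; the class and method names below are hypothetical:

import java.nio.charset.StandardCharsets;
import org.apache.kafka.common.header.Header;
import org.apache.kafka.common.header.Headers;

final class KafkaHeaderReadSketch {
  // Explicit null check for a missing header; a non-Headers source would still throw a
  // ClassCastException, which the try/catch in the change above also absorbs.
  static String readTransactionHeader(Object source, String headerName) {
    Header header = ((Headers) source).lastHeader(headerName);
    return header == null ? null : new String(header.value(), StandardCharsets.UTF_8);
  }
}
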

// this method is used in kafka-clients and kafka-streams instrumentations
public static long computePayloadSizeBytes(ConsumerRecord<?, ?> val) {