From 703cba3eef90e991205ed99f7019b033970daf8a Mon Sep 17 00:00:00 2001 From: Khaled Yakdan Date: Wed, 18 Feb 2026 13:36:21 +0100 Subject: [PATCH] feat: add float/double comparison tracking to instrumentor --- docs/advanced.md | 2 +- .../instrumentor/TraceDataFlowInstrumentor.kt | 46 ++++++++++++++--- .../jazzer/runtime/TraceCmpHooks.java | 50 +++++++++++++++++++ .../runtime/TraceDataFlowNativeCallbacks.java | 23 +++++++++ .../MockTraceDataFlowCallbacks.java | 16 ++++++ .../TraceDataFlowInstrumentationTarget.java | 27 ++++++++++ .../TraceDataFlowInstrumentationTest.kt | 17 ++++++- .../jazzer/runtime/TraceCmpHooksTest.java | 44 ++++++++++++++++ tests/BUILD.bazel | 22 ++++++++ .../com/example/FloatDoubleCmpFuzzer.java | 31 ++++++++++++ 10 files changed, 269 insertions(+), 9 deletions(-) create mode 100644 tests/src/test/java/com/example/FloatDoubleCmpFuzzer.java diff --git a/docs/advanced.md b/docs/advanced.md index 96ea08e24..a3432685f 100644 --- a/docs/advanced.md +++ b/docs/advanced.md @@ -93,7 +93,7 @@ These hooks correspond to [clang's data flow hooks](https://clang.llvm.org/docs/ The particular instrumentation types to apply can be specified using the `--trace` flag, which accepts the following values: * `cov`: AFL-style edge coverage -* `cmp`: compares (int, long, String) and switch cases +* `cmp`: compares (int, long, float, double, String) and switch cases * `div`: divisors in integer divisions * `gep`: constant array indexes * `indir`: call through `Method#invoke` diff --git a/src/main/java/com/code_intelligence/jazzer/instrumentor/TraceDataFlowInstrumentor.kt b/src/main/java/com/code_intelligence/jazzer/instrumentor/TraceDataFlowInstrumentor.kt index d4399c6a8..301715b8a 100644 --- a/src/main/java/com/code_intelligence/jazzer/instrumentor/TraceDataFlowInstrumentor.kt +++ b/src/main/java/com/code_intelligence/jazzer/instrumentor/TraceDataFlowInstrumentor.kt @@ -32,6 +32,13 @@ import org.objectweb.asm.tree.MethodNode import org.objectweb.asm.tree.TableSwitchInsnNode import org.objectweb.asm.tree.VarInsnNode +private val TRACE_COMPARISON_METHODS = + setOf( + "traceCmpLongWrapper", + "traceCmpDoubleWrapper", + "traceCmpFloatWrapper", + ) + internal class TraceDataFlowInstrumentor( private val types: Set, private val callbackInternalClassName: String = "com/code_intelligence/jazzer/runtime/TraceDataFlowNativeCallbacks", @@ -65,6 +72,18 @@ internal class TraceDataFlowInstrumentor( method.instructions.insertBefore(inst, longCmpInstrumentation()) method.instructions.remove(inst) } + Opcodes.DCMPG, Opcodes.DCMPL -> { + if (InstrumentationType.CMP !in types) continue@loop + val nanResult = if (inst.opcode == Opcodes.DCMPG) 1 else -1 + method.instructions.insertBefore(inst, doubleCmpInstrumentation(nanResult)) + method.instructions.remove(inst) + } + Opcodes.FCMPG, Opcodes.FCMPL -> { + if (InstrumentationType.CMP !in types) continue@loop + val nanResult = if (inst.opcode == Opcodes.FCMPG) 1 else -1 + method.instructions.insertBefore(inst, floatCmpInstrumentation(nanResult)) + method.instructions.remove(inst) + } Opcodes.IF_ICMPEQ, Opcodes.IF_ICMPNE, Opcodes.IF_ICMPLT, Opcodes.IF_ICMPLE, Opcodes.IF_ICMPGT, Opcodes.IF_ICMPGE, @@ -78,15 +97,14 @@ internal class TraceDataFlowInstrumentor( -> { if (InstrumentationType.CMP !in types) continue@loop // The IF* opcodes are often used to branch based on the result of a compare - // instruction for a type other than int. The operands of this compare will - // already be reported via the instrumentation above (for non-floating point - // numbers) and the follow-up compare does not provide a good signal as all - // operands will be in {-1, 0, 1}. Skip instrumentation for it. - if (inst.previous?.opcode in listOf(Opcodes.DCMPG, Opcodes.DCMPL, Opcodes.FCMPG, Opcodes.DCMPL) || - (inst.previous as? MethodInsnNode)?.name == "traceCmpLongWrapper" - ) { + // instruction for a type other than int (long, float, double). The operands + // of this compare will already be reported via the instrumentation above and + // the follow-up compare does not provide a good signal as all operands will + // be in {-1, 0, 1}. Skip instrumentation for it. + if ((inst.previous as? MethodInsnNode)?.name in TRACE_COMPARISON_METHODS) { continue@loop } + method.instructions.insertBefore(inst, ifInstrumentation()) } Opcodes.LOOKUPSWITCH, Opcodes.TABLESWITCH -> { @@ -256,6 +274,20 @@ internal class TraceDataFlowInstrumentor( add(MethodInsnNode(Opcodes.INVOKESTATIC, callbackInternalClassName, "traceCmpLongWrapper", "(JJI)I", false)) } + private fun doubleCmpInstrumentation(nanResult: Int) = + InsnList().apply { + add(LdcInsnNode(nanResult)) + pushFakePc() + add(MethodInsnNode(Opcodes.INVOKESTATIC, callbackInternalClassName, "traceCmpDoubleWrapper", "(DDII)I", false)) + } + + private fun floatCmpInstrumentation(nanResult: Int) = + InsnList().apply { + add(LdcInsnNode(nanResult)) + pushFakePc() + add(MethodInsnNode(Opcodes.INVOKESTATIC, callbackInternalClassName, "traceCmpFloatWrapper", "(FFII)I", false)) + } + private fun intCmpInstrumentation() = InsnList().apply { add(InsnNode(Opcodes.DUP2)) diff --git a/src/main/java/com/code_intelligence/jazzer/runtime/TraceCmpHooks.java b/src/main/java/com/code_intelligence/jazzer/runtime/TraceCmpHooks.java index dd72027ad..43f513fa9 100644 --- a/src/main/java/com/code_intelligence/jazzer/runtime/TraceCmpHooks.java +++ b/src/main/java/com/code_intelligence/jazzer/runtime/TraceCmpHooks.java @@ -85,6 +85,56 @@ public static void integerCompareTo( ((Number) thisObject).intValue(), ((Number) arguments[0]).intValue(), hookId); } + @MethodHook( + type = HookType.BEFORE, + targetClassName = "java.lang.Float", + targetMethod = "compare", + targetMethodDescriptor = "(FF)I") + public static void floatCompare( + MethodHandle method, Object alwaysNull, Object[] arguments, int hookId) { + TraceDataFlowNativeCallbacks.traceCmpInt( + Float.floatToRawIntBits((float) arguments[0]), + Float.floatToRawIntBits((float) arguments[1]), + hookId); + } + + @MethodHook( + type = HookType.BEFORE, + targetClassName = "java.lang.Double", + targetMethod = "compare", + targetMethodDescriptor = "(DD)I") + public static void doubleCompare( + MethodHandle method, Object alwaysNull, Object[] arguments, int hookId) { + TraceDataFlowNativeCallbacks.traceCmpLong( + Double.doubleToRawLongBits((double) arguments[0]), + Double.doubleToRawLongBits((double) arguments[1]), + hookId); + } + + @MethodHook( + type = HookType.BEFORE, + targetClassName = "java.lang.Float", + targetMethod = "compareTo", + targetMethodDescriptor = "(Ljava/lang/Float;)I") + public static void floatCompareTo( + MethodHandle method, Float thisObject, Object[] arguments, int hookId) { + TraceDataFlowNativeCallbacks.traceCmpInt( + Float.floatToRawIntBits(thisObject), Float.floatToRawIntBits((float) arguments[0]), hookId); + } + + @MethodHook( + type = HookType.BEFORE, + targetClassName = "java.lang.Double", + targetMethod = "compareTo", + targetMethodDescriptor = "(Ljava/lang/Double;)I") + public static void doubleCompareTo( + MethodHandle method, Double thisObject, Object[] arguments, int hookId) { + TraceDataFlowNativeCallbacks.traceCmpLong( + Double.doubleToRawLongBits(thisObject), + Double.doubleToRawLongBits((double) arguments[0]), + hookId); + } + @MethodHook( type = HookType.BEFORE, targetClassName = "java.lang.Long", diff --git a/src/main/java/com/code_intelligence/jazzer/runtime/TraceDataFlowNativeCallbacks.java b/src/main/java/com/code_intelligence/jazzer/runtime/TraceDataFlowNativeCallbacks.java index 708db267c..71d0bd38f 100644 --- a/src/main/java/com/code_intelligence/jazzer/runtime/TraceDataFlowNativeCallbacks.java +++ b/src/main/java/com/code_intelligence/jazzer/runtime/TraceDataFlowNativeCallbacks.java @@ -81,6 +81,24 @@ public static int traceCmpLongWrapper(long arg1, long arg2, int pc) { return Long.compare(arg1, arg2); } + public static int traceCmpDoubleWrapper(double arg1, double arg2, int nanResult, int pc) { + traceCmpLong(Double.doubleToRawLongBits(arg1), Double.doubleToRawLongBits(arg2), pc); + if (Double.isNaN(arg1) || Double.isNaN(arg2)) return nanResult; + // Mirror DCMPG/DCMPL semantics: in particular, -0.0 == +0.0 must yield 0. + if (arg1 > arg2) return 1; + if (arg1 == arg2) return 0; + return -1; + } + + public static int traceCmpFloatWrapper(float arg1, float arg2, int nanResult, int pc) { + traceCmpInt(Float.floatToRawIntBits(arg1), Float.floatToRawIntBits(arg2), pc); + if (Float.isNaN(arg1) || Float.isNaN(arg2)) return nanResult; + // Mirror FCMPG/FCMPL semantics: in particular, -0.0 == +0.0 must yield 0. + if (arg1 > arg2) return 1; + if (arg1 == arg2) return 0; + return -1; + } + // The caller has to ensure that arg1 and arg2 have the same class. public static void traceGenericCmp(Object arg1, Object arg2, int pc) { if (arg1 instanceof CharSequence) { @@ -89,6 +107,11 @@ public static void traceGenericCmp(Object arg1, Object arg2, int pc) { traceCmpInt((int) arg1, (int) arg2, pc); } else if (arg1 instanceof Long) { traceCmpLong((long) arg1, (long) arg2, pc); + } else if (arg1 instanceof Float) { + traceCmpInt(Float.floatToRawIntBits((float) arg1), Float.floatToRawIntBits((float) arg2), pc); + } else if (arg1 instanceof Double) { + traceCmpLong( + Double.doubleToRawLongBits((double) arg1), Double.doubleToRawLongBits((double) arg2), pc); } else if (arg1 instanceof Short) { traceCmpInt((short) arg1, (short) arg2, pc); } else if (arg1 instanceof Byte) { diff --git a/src/test/java/com/code_intelligence/jazzer/instrumentor/MockTraceDataFlowCallbacks.java b/src/test/java/com/code_intelligence/jazzer/instrumentor/MockTraceDataFlowCallbacks.java index b9d2341ab..f57b14269 100644 --- a/src/test/java/com/code_intelligence/jazzer/instrumentor/MockTraceDataFlowCallbacks.java +++ b/src/test/java/com/code_intelligence/jazzer/instrumentor/MockTraceDataFlowCallbacks.java @@ -109,4 +109,20 @@ public static int traceCmpLongWrapper(long value1, long value2, int pc) { // (behaviour is the same) return Long.compare(value1, value2); } + + public static int traceCmpDoubleWrapper(double value1, double value2, int nanResult, int pc) { + traceCmpLong(Double.doubleToRawLongBits(value1), Double.doubleToRawLongBits(value2), pc); + if (Double.isNaN(value1) || Double.isNaN(value2)) return nanResult; + if (value1 > value2) return 1; + if (value1 == value2) return 0; + return -1; + } + + public static int traceCmpFloatWrapper(float value1, float value2, int nanResult, int pc) { + traceCmpInt(Float.floatToRawIntBits(value1), Float.floatToRawIntBits(value2), pc); + if (Float.isNaN(value1) || Float.isNaN(value2)) return nanResult; + if (value1 > value2) return 1; + if (value1 == value2) return 0; + return -1; + } } diff --git a/src/test/java/com/code_intelligence/jazzer/instrumentor/TraceDataFlowInstrumentationTarget.java b/src/test/java/com/code_intelligence/jazzer/instrumentor/TraceDataFlowInstrumentationTarget.java index 118e7d3fa..b7e334419 100644 --- a/src/test/java/com/code_intelligence/jazzer/instrumentor/TraceDataFlowInstrumentationTarget.java +++ b/src/test/java/com/code_intelligence/jazzer/instrumentor/TraceDataFlowInstrumentationTarget.java @@ -37,6 +37,22 @@ public class TraceDataFlowInstrumentationTarget implements DynamicTestContract { volatile int int3 = 6; volatile int int4 = 5; + volatile double double1 = 1.5; + volatile double double2 = 1.5; + volatile double double3 = 2.5; + volatile double double4 = 3.5; + volatile double doubleNaN = Double.NaN; + volatile double doubleNegZero = -0.0; + volatile double doublePosZero = 0.0; + + volatile float float1 = 1.5f; + volatile float float2 = 1.5f; + volatile float float3 = 2.5f; + volatile float float4 = 3.5f; + volatile float floatNaN = Float.NaN; + volatile float floatNegZero = -0.0f; + volatile float floatPosZero = 0.0f; + volatile int switchValue = 1200; @SuppressWarnings("ReturnValueIgnored") @@ -47,6 +63,17 @@ public Map selfCheck() { results.put("longCompareEq", long1 == long2); results.put("longCompareNe", long3 != long4); + results.put("doubleCompareEq", double1 == double2); + results.put("doubleCompareNe", double3 != double4); + results.put("floatCompareEq", float1 == float2); + results.put("floatCompareNe", float3 != float4); + results.put("doubleCompareSignedZeroEq", doubleNegZero == doublePosZero); + results.put("floatCompareSignedZeroEq", floatNegZero == floatPosZero); + results.put("doubleCompareNaNLessFalse", !(doubleNaN < double1)); + results.put("doubleCompareNaNGreaterFalse", !(doubleNaN > double1)); + results.put("floatCompareNaNLessFalse", !(floatNaN < float1)); + results.put("floatCompareNaNGreaterFalse", !(floatNaN > float1)); + results.put("intCompareEq", int1 == int2); results.put("intCompareNe", int3 != int4); results.put("intCompareLt", int4 < int3); diff --git a/src/test/java/com/code_intelligence/jazzer/instrumentor/TraceDataFlowInstrumentationTest.kt b/src/test/java/com/code_intelligence/jazzer/instrumentor/TraceDataFlowInstrumentationTest.kt index 074486270..e109a5451 100644 --- a/src/test/java/com/code_intelligence/jazzer/instrumentor/TraceDataFlowInstrumentationTest.kt +++ b/src/test/java/com/code_intelligence/jazzer/instrumentor/TraceDataFlowInstrumentationTest.kt @@ -58,6 +58,20 @@ class TraceDataFlowInstrumentationTest { // long compares "LCMP: 1, 1", "LCMP: 2, 3", + // double compares + "LCMP: 4609434218613702656, 4609434218613702656", + "LCMP: 4612811918334230528, 4615063718147915776", + // float compares + "ICMP: 1069547520, 1069547520", + "ICMP: 1075838976, 1080033280", + // signed zero compares + "LCMP: -9223372036854775808, 0", + "ICMP: -2147483648, 0", + // NaN compares + "LCMP: 4609434218613702656, 9221120237041090560", + "LCMP: 4609434218613702656, 9221120237041090560", + "ICMP: 1069547520, 2143289344", + "ICMP: 1069547520, 2143289344", // int compares "ICMP: 4, 4", "ICMP: 5, 6", @@ -87,9 +101,10 @@ class TraceDataFlowInstrumentationTest { "ICMP: 3, 3", // doubleArray[4] == 4 "GEP: 4", + "LCMP: 4616189618054758400, 4616189618054758400", // floatArray[5] == 5 "GEP: 5", - "CICMP: 0, 0", + "ICMP: 1084227584, 1084227584", // intArray[6] == 6 "GEP: 6", "ICMP: 6, 6", diff --git a/src/test/java/com/code_intelligence/jazzer/runtime/TraceCmpHooksTest.java b/src/test/java/com/code_intelligence/jazzer/runtime/TraceCmpHooksTest.java index 776cd197c..521097b41 100644 --- a/src/test/java/com/code_intelligence/jazzer/runtime/TraceCmpHooksTest.java +++ b/src/test/java/com/code_intelligence/jazzer/runtime/TraceCmpHooksTest.java @@ -16,6 +16,8 @@ package com.code_intelligence.jazzer.runtime; +import static org.junit.Assert.assertEquals; + import java.util.HashMap; import java.util.Map; import java.util.concurrent.ExecutorService; @@ -61,4 +63,46 @@ public void handlesNullValuesInArrayCompare() { TraceCmpHooks.arraysEquals(null, null, new Object[] {b1, b2}, 1, false); TraceCmpHooks.arraysCompare(null, null, new Object[] {b1, b2}, 1, 1); } + + @Test + public void traceCmpDoubleWrapperShouldMatchDcmpSemantics() { + assertEquals(0, invokeTraceCmpDoubleWrapper(-0.0d, +0.0d, /* nanResult= */ -1)); + assertEquals(0, invokeTraceCmpDoubleWrapper(+0.0d, -0.0d, /* nanResult= */ 1)); + assertEquals(-1, invokeTraceCmpDoubleWrapper(Double.NaN, 1.0d, /* nanResult= */ -1)); + assertEquals(1, invokeTraceCmpDoubleWrapper(Double.NaN, 1.0d, /* nanResult= */ 1)); + } + + @Test + public void traceCmpFloatWrapperShouldMatchFcmpSemantics() { + assertEquals(0, invokeTraceCmpFloatWrapper(-0.0f, +0.0f, /* nanResult= */ -1)); + assertEquals(0, invokeTraceCmpFloatWrapper(+0.0f, -0.0f, /* nanResult= */ 1)); + assertEquals(-1, invokeTraceCmpFloatWrapper(Float.NaN, 1.0f, /* nanResult= */ -1)); + assertEquals(1, invokeTraceCmpFloatWrapper(Float.NaN, 1.0f, /* nanResult= */ 1)); + } + + private static int invokeTraceCmpDoubleWrapper(double arg1, double arg2, int nanResult) { + try { + Class callbacksClass = + Class.forName("com.code_intelligence.jazzer.runtime.TraceDataFlowNativeCallbacks"); + return (int) + callbacksClass + .getMethod("traceCmpDoubleWrapper", double.class, double.class, int.class, int.class) + .invoke(null, arg1, arg2, nanResult, 1); + } catch (ReflectiveOperationException e) { + throw new AssertionError(e); + } + } + + private static int invokeTraceCmpFloatWrapper(float arg1, float arg2, int nanResult) { + try { + Class callbacksClass = + Class.forName("com.code_intelligence.jazzer.runtime.TraceDataFlowNativeCallbacks"); + return (int) + callbacksClass + .getMethod("traceCmpFloatWrapper", float.class, float.class, int.class, int.class) + .invoke(null, arg1, arg2, nanResult, 1); + } catch (ReflectiveOperationException e) { + throw new AssertionError(e); + } + } } diff --git a/tests/BUILD.bazel b/tests/BUILD.bazel index 7a7fe9e0c..b3d88e1a0 100644 --- a/tests/BUILD.bazel +++ b/tests/BUILD.bazel @@ -846,6 +846,28 @@ java_fuzz_target_test( ], ) +java_fuzz_target_test( + name = "FloatDoubleCmpFuzzer", + timeout = "short", + srcs = ["src/test/java/com/example/FloatDoubleCmpFuzzer.java"], + allowed_findings = [ + "com.code_intelligence.jazzer.api.FuzzerSecurityIssueLow", + ], + fuzzer_args = [ + "-use_value_profile=1", + ], + target_class = "com.example.FloatDoubleCmpFuzzer", + verify_crash_reproducer = False, + runtime_deps = [ + "@maven//:org_junit_jupiter_junit_jupiter_engine", + ], + deps = [ + "//deploy:jazzer-junit", + "@maven//:org_junit_jupiter_junit_jupiter_api", + "@maven//:org_junit_jupiter_junit_jupiter_params", + ], +) + java_fuzz_target_test( name = "LocalDateTimeFuzzer", timeout = "short", diff --git a/tests/src/test/java/com/example/FloatDoubleCmpFuzzer.java b/tests/src/test/java/com/example/FloatDoubleCmpFuzzer.java new file mode 100644 index 000000000..059dc8c85 --- /dev/null +++ b/tests/src/test/java/com/example/FloatDoubleCmpFuzzer.java @@ -0,0 +1,31 @@ +/* + * Copyright 2026 Code Intelligence GmbH + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.example; + +import com.code_intelligence.jazzer.api.FuzzerSecurityIssueLow; +import com.code_intelligence.jazzer.junit.FuzzTest; + +public class FloatDoubleCmpFuzzer { + @FuzzTest + void floatDoubleCmp(float f, double d) { + float fx = f * f - 8.625f * f; + double dx = d * d - 10.8125 * d; + if (fx == -12.65625f && dx == -24.3046875) { + throw new FuzzerSecurityIssueLow("Float/double comparison tracking works!"); + } + } +}