From 891b22ff7c6fd643f40e76640976cd0bf5725403 Mon Sep 17 00:00:00 2001 From: Marc Gravell Date: Fri, 13 Feb 2026 10:41:54 +0000 Subject: [PATCH 01/11] import *just* the resp-reader parts from the IO rewrite from side branch --- Directory.Build.props | 2 +- Directory.Packages.props | 3 + StackExchange.Redis.sln | 14 + docs/exp/SER004.md | 15 + src/RESPite/Internal/Raw.cs | 138 ++ src/RESPite/Internal/RespConstants.cs | 53 + src/RESPite/Messages/RespAttributeReader.cs | 72 + src/RESPite/Messages/RespFrameScanner.cs | 195 ++ src/RESPite/Messages/RespPrefix.cs | 101 + .../RespReader.AggregateEnumerator.cs | 214 ++ src/RESPite/Messages/RespReader.Debug.cs | 33 + .../Messages/RespReader.ScalarEnumerator.cs | 105 + src/RESPite/Messages/RespReader.Span.cs | 84 + src/RESPite/Messages/RespReader.Utils.cs | 317 +++ src/RESPite/Messages/RespReader.cs | 1774 +++++++++++++++++ src/RESPite/Messages/RespScanState.cs | 162 ++ src/RESPite/PublicAPI/PublicAPI.Shipped.txt | 1 + src/RESPite/PublicAPI/PublicAPI.Unshipped.txt | 134 ++ .../PublicAPI/net8.0/PublicAPI.Shipped.txt | 1 + .../PublicAPI/net8.0/PublicAPI.Unshipped.txt | 3 + src/RESPite/RESPite.csproj | 51 + src/RESPite/RespException.cs | 12 + src/RESPite/readme.md | 6 + src/StackExchange.Redis/Experiments.cs | 5 +- src/StackExchange.Redis/FrameworkShims.cs | 3 +- tests/RESPite.Tests/RESPite.Tests.csproj | 21 + tests/RESPite.Tests/RespReaderTests.cs | 863 ++++++++ 27 files changed, 4378 insertions(+), 4 deletions(-) create mode 100644 docs/exp/SER004.md create mode 100644 src/RESPite/Internal/Raw.cs create mode 100644 src/RESPite/Internal/RespConstants.cs create mode 100644 src/RESPite/Messages/RespAttributeReader.cs create mode 100644 src/RESPite/Messages/RespFrameScanner.cs create mode 100644 src/RESPite/Messages/RespPrefix.cs create mode 100644 src/RESPite/Messages/RespReader.AggregateEnumerator.cs create mode 100644 src/RESPite/Messages/RespReader.Debug.cs create mode 100644 src/RESPite/Messages/RespReader.ScalarEnumerator.cs create mode 100644 src/RESPite/Messages/RespReader.Span.cs create mode 100644 src/RESPite/Messages/RespReader.Utils.cs create mode 100644 src/RESPite/Messages/RespReader.cs create mode 100644 src/RESPite/Messages/RespScanState.cs create mode 100644 src/RESPite/PublicAPI/PublicAPI.Shipped.txt create mode 100644 src/RESPite/PublicAPI/PublicAPI.Unshipped.txt create mode 100644 src/RESPite/PublicAPI/net8.0/PublicAPI.Shipped.txt create mode 100644 src/RESPite/PublicAPI/net8.0/PublicAPI.Unshipped.txt create mode 100644 src/RESPite/RESPite.csproj create mode 100644 src/RESPite/RespException.cs create mode 100644 src/RESPite/readme.md create mode 100644 tests/RESPite.Tests/RESPite.Tests.csproj create mode 100644 tests/RESPite.Tests/RespReaderTests.cs diff --git a/Directory.Build.props b/Directory.Build.props index e36f0f7d1..4b45e0f1d 100644 --- a/Directory.Build.props +++ b/Directory.Build.props @@ -10,7 +10,7 @@ true $(MSBuildThisFileDirectory)Shared.ruleset NETSDK1069 - $(NoWarn);NU5105;NU1507;SER001;SER002;SER003 + $(NoWarn);NU5105;NU1507;SER001;SER002;SER003;SER004 https://stackexchange.github.io/StackExchange.Redis/ReleaseNotes https://stackexchange.github.io/StackExchange.Redis/ MIT diff --git a/Directory.Packages.props b/Directory.Packages.props index 3fa9e0e3d..9767a0ab1 100644 --- a/Directory.Packages.props +++ b/Directory.Packages.props @@ -10,6 +10,9 @@ + + + diff --git a/StackExchange.Redis.sln b/StackExchange.Redis.sln index adb1291de..ca5e3a60d 100644 --- a/StackExchange.Redis.sln +++ b/StackExchange.Redis.sln @@ -127,6 +127,10 @@ Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "eng", "eng", "{5FA0958E-6EB EndProject Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "StackExchange.Redis.Build", "eng\StackExchange.Redis.Build\StackExchange.Redis.Build.csproj", "{190742E1-FA50-4E36-A8C4-88AE87654340}" EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "RESPite", "src\RESPite\RESPite.csproj", "{AEA77181-DDD2-4E43-828B-908C7460A12D}" +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "RESPite.Tests", "tests\RESPite.Tests\RESPite.Tests.csproj", "{1D324077-A15E-4EE2-9AD6-A9045636CEAC}" +EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug|Any CPU = Debug|Any CPU @@ -189,6 +193,14 @@ Global {190742E1-FA50-4E36-A8C4-88AE87654340}.Debug|Any CPU.Build.0 = Debug|Any CPU {190742E1-FA50-4E36-A8C4-88AE87654340}.Release|Any CPU.ActiveCfg = Release|Any CPU {190742E1-FA50-4E36-A8C4-88AE87654340}.Release|Any CPU.Build.0 = Release|Any CPU + {AEA77181-DDD2-4E43-828B-908C7460A12D}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {AEA77181-DDD2-4E43-828B-908C7460A12D}.Debug|Any CPU.Build.0 = Debug|Any CPU + {AEA77181-DDD2-4E43-828B-908C7460A12D}.Release|Any CPU.ActiveCfg = Release|Any CPU + {AEA77181-DDD2-4E43-828B-908C7460A12D}.Release|Any CPU.Build.0 = Release|Any CPU + {1D324077-A15E-4EE2-9AD6-A9045636CEAC}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {1D324077-A15E-4EE2-9AD6-A9045636CEAC}.Debug|Any CPU.Build.0 = Debug|Any CPU + {1D324077-A15E-4EE2-9AD6-A9045636CEAC}.Release|Any CPU.ActiveCfg = Release|Any CPU + {1D324077-A15E-4EE2-9AD6-A9045636CEAC}.Release|Any CPU.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE @@ -212,6 +224,8 @@ Global {69A0ACF2-DF1F-4F49-B554-F732DCA938A3} = {73A5C363-CA1F-44C4-9A9B-EF791A76BA6A} {59889284-FFEE-82E7-94CB-3B43E87DA6CF} = {73A5C363-CA1F-44C4-9A9B-EF791A76BA6A} {190742E1-FA50-4E36-A8C4-88AE87654340} = {5FA0958E-6EBD-45F4-808E-3447A293F96F} + {AEA77181-DDD2-4E43-828B-908C7460A12D} = {00CA0876-DA9F-44E8-B0DC-A88716BF347A} + {1D324077-A15E-4EE2-9AD6-A9045636CEAC} = {73A5C363-CA1F-44C4-9A9B-EF791A76BA6A} EndGlobalSection GlobalSection(ExtensibilityGlobals) = postSolution SolutionGuid = {193AA352-6748-47C1-A5FC-C9AA6B5F000B} diff --git a/docs/exp/SER004.md b/docs/exp/SER004.md new file mode 100644 index 000000000..91f5d87c4 --- /dev/null +++ b/docs/exp/SER004.md @@ -0,0 +1,15 @@ +# RESPite + +RESPite is an experimental library that provides high-performance low-level RESP (Redis, etc) parsing and serialization. +It is used as the IO core for StackExchange.Redis v3+. You should not (yet) use it directly unless you have a very +good reason to do so. + +```xml +$(NoWarn);SER004 +``` + +or more granularly / locally in C#: + +``` c# +#pragma warning disable SER004 +``` diff --git a/src/RESPite/Internal/Raw.cs b/src/RESPite/Internal/Raw.cs new file mode 100644 index 000000000..65d0c5059 --- /dev/null +++ b/src/RESPite/Internal/Raw.cs @@ -0,0 +1,138 @@ +using System.Diagnostics; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; +using System.Text; + +#if NETCOREAPP3_0_OR_GREATER +using System.Runtime.Intrinsics; +using System.Runtime.Intrinsics.X86; +#endif + +namespace RESPite.Internal; + +/// +/// Pre-computed payload fragments, for high-volume scenarios / common values. +/// +/// +/// CPU-endianness applies here; we can't just use "const" - however, modern JITs treat "static readonly" *almost* the same as "const", so: meh. +/// +internal static class Raw +{ + public static ulong Create64(ReadOnlySpan bytes, int length) + { + if (length != bytes.Length) + { + throw new ArgumentException($"Length check failed: {length} vs {bytes.Length}, value: {RespConstants.UTF8.GetString(bytes)}", nameof(length)); + } + if (length < 0 || length > sizeof(ulong)) + { + throw new ArgumentOutOfRangeException(nameof(length), $"Invalid length {length} - must be 0-{sizeof(ulong)}"); + } + + // this *will* be aligned; this approach intentionally chosen for parity with write + Span scratch = stackalloc byte[sizeof(ulong)]; + if (length != sizeof(ulong)) scratch.Slice(length).Clear(); + bytes.CopyTo(scratch); + return Unsafe.ReadUnaligned(ref MemoryMarshal.GetReference(scratch)); + } + + public static uint Create32(ReadOnlySpan bytes, int length) + { + if (length != bytes.Length) + { + throw new ArgumentException($"Length check failed: {length} vs {bytes.Length}, value: {RespConstants.UTF8.GetString(bytes)}", nameof(length)); + } + if (length < 0 || length > sizeof(uint)) + { + throw new ArgumentOutOfRangeException(nameof(length), $"Invalid length {length} - must be 0-{sizeof(uint)}"); + } + + // this *will* be aligned; this approach intentionally chosen for parity with write + Span scratch = stackalloc byte[sizeof(uint)]; + if (length != sizeof(uint)) scratch.Slice(length).Clear(); + bytes.CopyTo(scratch); + return Unsafe.ReadUnaligned(ref MemoryMarshal.GetReference(scratch)); + } + + public static ulong BulkStringEmpty_6 = Create64("$0\r\n\r\n"u8, 6); + + public static ulong BulkStringInt32_M1_8 = Create64("$2\r\n-1\r\n"u8, 8); + public static ulong BulkStringInt32_0_7 = Create64("$1\r\n0\r\n"u8, 7); + public static ulong BulkStringInt32_1_7 = Create64("$1\r\n1\r\n"u8, 7); + public static ulong BulkStringInt32_2_7 = Create64("$1\r\n2\r\n"u8, 7); + public static ulong BulkStringInt32_3_7 = Create64("$1\r\n3\r\n"u8, 7); + public static ulong BulkStringInt32_4_7 = Create64("$1\r\n4\r\n"u8, 7); + public static ulong BulkStringInt32_5_7 = Create64("$1\r\n5\r\n"u8, 7); + public static ulong BulkStringInt32_6_7 = Create64("$1\r\n6\r\n"u8, 7); + public static ulong BulkStringInt32_7_7 = Create64("$1\r\n7\r\n"u8, 7); + public static ulong BulkStringInt32_8_7 = Create64("$1\r\n8\r\n"u8, 7); + public static ulong BulkStringInt32_9_7 = Create64("$1\r\n9\r\n"u8, 7); + public static ulong BulkStringInt32_10_8 = Create64("$2\r\n10\r\n"u8, 8); + + public static ulong BulkStringPrefix_M1_5 = Create64("$-1\r\n"u8, 5); + public static uint BulkStringPrefix_0_4 = Create32("$0\r\n"u8, 4); + public static uint BulkStringPrefix_1_4 = Create32("$1\r\n"u8, 4); + public static uint BulkStringPrefix_2_4 = Create32("$2\r\n"u8, 4); + public static uint BulkStringPrefix_3_4 = Create32("$3\r\n"u8, 4); + public static uint BulkStringPrefix_4_4 = Create32("$4\r\n"u8, 4); + public static uint BulkStringPrefix_5_4 = Create32("$5\r\n"u8, 4); + public static uint BulkStringPrefix_6_4 = Create32("$6\r\n"u8, 4); + public static uint BulkStringPrefix_7_4 = Create32("$7\r\n"u8, 4); + public static uint BulkStringPrefix_8_4 = Create32("$8\r\n"u8, 4); + public static uint BulkStringPrefix_9_4 = Create32("$9\r\n"u8, 4); + public static ulong BulkStringPrefix_10_5 = Create64("$10\r\n"u8, 5); + + public static ulong ArrayPrefix_M1_5 = Create64("*-1\r\n"u8, 5); + public static uint ArrayPrefix_0_4 = Create32("*0\r\n"u8, 4); + public static uint ArrayPrefix_1_4 = Create32("*1\r\n"u8, 4); + public static uint ArrayPrefix_2_4 = Create32("*2\r\n"u8, 4); + public static uint ArrayPrefix_3_4 = Create32("*3\r\n"u8, 4); + public static uint ArrayPrefix_4_4 = Create32("*4\r\n"u8, 4); + public static uint ArrayPrefix_5_4 = Create32("*5\r\n"u8, 4); + public static uint ArrayPrefix_6_4 = Create32("*6\r\n"u8, 4); + public static uint ArrayPrefix_7_4 = Create32("*7\r\n"u8, 4); + public static uint ArrayPrefix_8_4 = Create32("*8\r\n"u8, 4); + public static uint ArrayPrefix_9_4 = Create32("*9\r\n"u8, 4); + public static ulong ArrayPrefix_10_5 = Create64("*10\r\n"u8, 5); + +#if NETCOREAPP3_0_OR_GREATER + private static uint FirstAndLast(char first, char last) + { + Debug.Assert(first < 128 && last < 128, "ASCII please"); + Span scratch = [(byte)first, 0, 0, (byte)last]; + // this *will* be aligned; this approach intentionally chosen for how we read + return Unsafe.ReadUnaligned(ref MemoryMarshal.GetReference(scratch)); + } + + public const int CommonRespIndex_Success = 0; + public const int CommonRespIndex_SingleDigitInteger = 1; + public const int CommonRespIndex_DoubleDigitInteger = 2; + public const int CommonRespIndex_SingleDigitString = 3; + public const int CommonRespIndex_DoubleDigitString = 4; + public const int CommonRespIndex_SingleDigitArray = 5; + public const int CommonRespIndex_DoubleDigitArray = 6; + public const int CommonRespIndex_Error = 7; + + public static readonly Vector256 CommonRespPrefixes = Vector256.Create( + FirstAndLast('+', '\r'), // success +OK\r\n + FirstAndLast(':', '\n'), // single-digit integer :4\r\n + FirstAndLast(':', '\r'), // double-digit integer :42\r\n + FirstAndLast('$', '\n'), // 0-9 char string $0\r\n\r\n + FirstAndLast('$', '\r'), // null/10-99 char string $-1\r\n or $10\r\nABCDEFGHIJ\r\n + FirstAndLast('*', '\n'), // 0-9 length array *0\r\n + FirstAndLast('*', '\r'), // null/10-99 length array *-1\r\n or *10\r\n:0\r\n:0\r\n:0\r\n:0\r\n:0\r\n:0\r\n:0\r\n:0\r\n:0\r\n:0\r\n + FirstAndLast('-', 'R')); // common errors -ERR something bad happened + + public static readonly Vector256 FirstLastMask = CreateUInt32(0xFF0000FF); + + private static Vector256 CreateUInt32(uint value) + { +#if NET7_0_OR_GREATER + return Vector256.Create(value); +#else + return Vector256.Create(value, value, value, value, value, value, value, value); +#endif + } + +#endif +} diff --git a/src/RESPite/Internal/RespConstants.cs b/src/RESPite/Internal/RespConstants.cs new file mode 100644 index 000000000..accb8400b --- /dev/null +++ b/src/RESPite/Internal/RespConstants.cs @@ -0,0 +1,53 @@ +using System.Buffers.Binary; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; +using System.Text; +// ReSharper disable InconsistentNaming +namespace RESPite.Internal; + +internal static class RespConstants +{ + public static readonly UTF8Encoding UTF8 = new(false); + + public static ReadOnlySpan CrlfBytes => "\r\n"u8; + + public static readonly ushort CrLfUInt16 = UnsafeCpuUInt16(CrlfBytes); + + public static ReadOnlySpan OKBytes_LC => "ok"u8; + public static ReadOnlySpan OKBytes => "OK"u8; + public static readonly ushort OKUInt16 = UnsafeCpuUInt16(OKBytes); + public static readonly ushort OKUInt16_LC = UnsafeCpuUInt16(OKBytes_LC); + + public static readonly uint BulkStringStreaming = UnsafeCpuUInt32("$?\r\n"u8); + public static readonly uint BulkStringNull = UnsafeCpuUInt32("$-1\r"u8); + + public static readonly uint ArrayStreaming = UnsafeCpuUInt32("*?\r\n"u8); + public static readonly uint ArrayNull = UnsafeCpuUInt32("*-1\r"u8); + + public static ushort UnsafeCpuUInt16(ReadOnlySpan bytes) + => Unsafe.ReadUnaligned(ref MemoryMarshal.GetReference(bytes)); + public static ushort UnsafeCpuUInt16(ReadOnlySpan bytes, int offset) + => Unsafe.ReadUnaligned(ref Unsafe.Add(ref MemoryMarshal.GetReference(bytes), offset)); + public static byte UnsafeCpuByte(ReadOnlySpan bytes, int offset) + => Unsafe.Add(ref MemoryMarshal.GetReference(bytes), offset); + public static uint UnsafeCpuUInt32(ReadOnlySpan bytes) + => Unsafe.ReadUnaligned(ref MemoryMarshal.GetReference(bytes)); + public static uint UnsafeCpuUInt32(ReadOnlySpan bytes, int offset) + => Unsafe.ReadUnaligned(ref Unsafe.Add(ref MemoryMarshal.GetReference(bytes), offset)); + public static ulong UnsafeCpuUInt64(ReadOnlySpan bytes) + => Unsafe.ReadUnaligned(ref MemoryMarshal.GetReference(bytes)); + public static ushort CpuUInt16(ushort bigEndian) + => BitConverter.IsLittleEndian ? BinaryPrimitives.ReverseEndianness(bigEndian) : bigEndian; + public static uint CpuUInt32(uint bigEndian) + => BitConverter.IsLittleEndian ? BinaryPrimitives.ReverseEndianness(bigEndian) : bigEndian; + public static ulong CpuUInt64(ulong bigEndian) + => BitConverter.IsLittleEndian ? BinaryPrimitives.ReverseEndianness(bigEndian) : bigEndian; + + public const int MaxRawBytesInt32 = 11, // "-2147483648" + MaxRawBytesInt64 = 20, // "-9223372036854775808", + MaxProtocolBytesIntegerInt32 = MaxRawBytesInt32 + 3, // ?X10X\r\n where ? could be $, *, etc - usually a length prefix + MaxProtocolBytesBulkStringIntegerInt32 = MaxRawBytesInt32 + 7, // $NN\r\nX11X\r\n for NN (length) 1-11 + MaxProtocolBytesBulkStringIntegerInt64 = MaxRawBytesInt64 + 7, // $NN\r\nX20X\r\n for NN (length) 1-20 + MaxRawBytesNumber = 20, // note G17 format, allow 20 for payload + MaxProtocolBytesBytesNumber = MaxRawBytesNumber + 7; // $NN\r\nX...X\r\n for NN (length) 1-20 +} diff --git a/src/RESPite/Messages/RespAttributeReader.cs b/src/RESPite/Messages/RespAttributeReader.cs new file mode 100644 index 000000000..46fd26a19 --- /dev/null +++ b/src/RESPite/Messages/RespAttributeReader.cs @@ -0,0 +1,72 @@ +using System.Diagnostics.CodeAnalysis; +using StackExchange.Redis; + +namespace RESPite.Messages; + +/// +/// Allows attribute data to be parsed conveniently. +/// +/// The type of data represented by this reader. +[Experimental(Experiments.Respite, UrlFormat = Experiments.UrlFormat)] +public abstract class RespAttributeReader +{ + /// + /// Parse a group of attributes. + /// + public virtual void Read(ref RespReader reader, ref T value) + { + reader.Demand(RespPrefix.Attribute); + _ = ReadKeyValuePairs(ref reader, ref value); + } + + /// + /// Parse an aggregate as a set of key/value pairs. + /// + /// The number of pairs successfully processed. + protected virtual int ReadKeyValuePairs(ref RespReader reader, ref T value) + { + var iterator = reader.AggregateChildren(); + + byte[] pooledBuffer = []; + Span localBuffer = stackalloc byte[128]; + int count = 0; + while (iterator.MoveNext() && iterator.Value.TryReadNext()) + { + if (iterator.Value.IsScalar) + { + var key = iterator.Value.Buffer(ref pooledBuffer, localBuffer); + + if (iterator.MoveNext() && iterator.Value.TryReadNext()) + { + if (ReadKeyValuePair(key, ref iterator.Value, ref value)) + { + count++; + } + } + else + { + break; // no matching value for this key + } + } + else + { + if (iterator.MoveNext() && iterator.Value.TryReadNext()) + { + // we won't try to handle aggregate keys; skip the value + } + else + { + break; // no matching value for this key + } + } + } + iterator.MovePast(out reader); + return count; + } + + /// + /// Parse an individual key/value pair. + /// + /// True if the pair was successfully processed. + public virtual bool ReadKeyValuePair(scoped ReadOnlySpan key, ref RespReader reader, ref T value) => false; +} diff --git a/src/RESPite/Messages/RespFrameScanner.cs b/src/RESPite/Messages/RespFrameScanner.cs new file mode 100644 index 000000000..5034e994a --- /dev/null +++ b/src/RESPite/Messages/RespFrameScanner.cs @@ -0,0 +1,195 @@ +using System.Buffers; +using System.Diagnostics.CodeAnalysis; +using StackExchange.Redis; +using static RESPite.Internal.RespConstants; +namespace RESPite.Messages; + +/// +/// Scans RESP frames. +/// . +[Experimental(Experiments.Respite, UrlFormat = Experiments.UrlFormat)] +public sealed class RespFrameScanner // : IFrameSacanner, IFrameValidator +{ + /// + /// Gets a frame scanner for RESP2 request/response connections, or RESP3 connections. + /// + public static RespFrameScanner Default { get; } = new(false); + + /// + /// Gets a frame scanner that identifies RESP2 pub/sub messages. + /// + public static RespFrameScanner Subscription { get; } = new(true); + private RespFrameScanner(bool pubsub) => _pubsub = pubsub; + private readonly bool _pubsub; + + private static readonly uint FastNull = UnsafeCpuUInt32("_\r\n\0"u8), + SingleCharScalarMask = CpuUInt32(0xFF00FFFF), + SingleDigitInteger = UnsafeCpuUInt32(":\0\r\n"u8), + EitherBoolean = UnsafeCpuUInt32("#\0\r\n"u8), + FirstThree = CpuUInt32(0xFFFFFF00); + private static readonly ulong OK = UnsafeCpuUInt64("+OK\r\n\0\0\0"u8), + PONG = UnsafeCpuUInt64("+PONG\r\n\0"u8), + DoubleCharScalarMask = CpuUInt64(0xFF0000FFFF000000), + DoubleDigitInteger = UnsafeCpuUInt64(":\0\0\r\n"u8), + FirstFive = CpuUInt64(0xFFFFFFFFFF000000), + FirstSeven = CpuUInt64(0xFFFFFFFFFFFFFF00); + + private const OperationStatus UseReader = (OperationStatus)(-1); + private static OperationStatus TryFastRead(ReadOnlySpan data, ref RespScanState info) + { + // use silly math to detect the most common short patterns without needing + // to access a reader, or use indexof etc; handles: + // +OK\r\n + // +PONG\r\n + // :N\r\n for any single-digit N (integer) + // :NN\r\n for any double-digit N (integer) + // #N\r\n for any single-digit N (boolean) + // _\r\n (null) + uint hi, lo; + switch (data.Length) + { + case 0: + case 1: + case 2: + return OperationStatus.NeedMoreData; + case 3: + hi = (((uint)UnsafeCpuUInt16(data)) << 16) | (((uint)UnsafeCpuByte(data, 2)) << 8); + break; + default: + hi = UnsafeCpuUInt32(data); + break; + } + if ((hi & FirstThree) == FastNull) + { + info.SetComplete(3, RespPrefix.Null); + return OperationStatus.Done; + } + + var masked = hi & SingleCharScalarMask; + if (masked == SingleDigitInteger) + { + info.SetComplete(4, RespPrefix.Integer); + return OperationStatus.Done; + } + else if (masked == EitherBoolean) + { + info.SetComplete(4, RespPrefix.Boolean); + return OperationStatus.Done; + } + + switch (data.Length) + { + case 3: + return OperationStatus.NeedMoreData; + case 4: + return UseReader; + case 5: + lo = ((uint)data[4]) << 24; + break; + case 6: + lo = ((uint)UnsafeCpuUInt16(data, 4)) << 16; + break; + case 7: + lo = ((uint)UnsafeCpuUInt16(data, 4)) << 16 | ((uint)UnsafeCpuByte(data, 6)) << 8; + break; + default: + lo = UnsafeCpuUInt32(data, 4); + break; + } + var u64 = BitConverter.IsLittleEndian ? ((((ulong)lo) << 32) | hi) : ((((ulong)hi) << 32) | lo); + if (((u64 & FirstFive) == OK) | ((u64 & DoubleCharScalarMask) == DoubleDigitInteger)) + { + info.SetComplete(5, RespPrefix.SimpleString); + return OperationStatus.Done; + } + if ((u64 & FirstSeven) == PONG) + { + info.SetComplete(7, RespPrefix.SimpleString); + return OperationStatus.Done; + } + return UseReader; + } + + /// + /// Attempt to read more data as part of the current frame. + /// + public OperationStatus TryRead(ref RespScanState state, in ReadOnlySequence data) + { + if (!_pubsub & state.TotalBytes == 0 & data.IsSingleSegment) + { +#if NETCOREAPP3_1_OR_GREATER + var status = TryFastRead(data.FirstSpan, ref state); +#else + var status = TryFastRead(data.First.Span, ref state); +#endif + if (status != UseReader) return status; + } + + return TryReadViaReader(ref state, in data); + + static OperationStatus TryReadViaReader(ref RespScanState state, in ReadOnlySequence data) + { + var reader = new RespReader(in data); + var complete = state.TryRead(ref reader, out var consumed); + if (complete) + { + return OperationStatus.Done; + } + return OperationStatus.NeedMoreData; + } + } + + /// + /// Attempt to read more data as part of the current frame. + /// + public OperationStatus TryRead(ref RespScanState state, ReadOnlySpan data) + { + if (!_pubsub & state.TotalBytes == 0) + { +#if NETCOREAPP3_1_OR_GREATER + var status = TryFastRead(data, ref state); +#else + var status = TryFastRead(data, ref state); +#endif + if (status != UseReader) return status; + } + + return TryReadViaReader(ref state, data); + + static OperationStatus TryReadViaReader(ref RespScanState state, ReadOnlySpan data) + { + var reader = new RespReader(data); + var complete = state.TryRead(ref reader, out var consumed); + if (complete) + { + return OperationStatus.Done; + } + return OperationStatus.NeedMoreData; + } + } + + /// + /// Validate that the supplied message is a valid RESP request, specifically: that it contains a single + /// top-level array payload with bulk-string elements, the first of which is non-empty (the command). + /// + public void ValidateRequest(in ReadOnlySequence message) + { + if (message.IsEmpty) Throw("Empty RESP frame"); + RespReader reader = new(in message); + reader.MoveNext(RespPrefix.Array); + reader.DemandNotNull(); + if (reader.IsStreaming) Throw("Streaming is not supported in this context"); + var count = reader.AggregateLength(); + for (int i = 0; i < count; i++) + { + reader.MoveNext(RespPrefix.BulkString); + reader.DemandNotNull(); + if (reader.IsStreaming) Throw("Streaming is not supported in this context"); + + if (i == 0 && reader.ScalarIsEmpty()) Throw("command must be non-empty"); + } + reader.DemandEnd(); + + static void Throw(string message) => throw new InvalidOperationException(message); + } +} diff --git a/src/RESPite/Messages/RespPrefix.cs b/src/RESPite/Messages/RespPrefix.cs new file mode 100644 index 000000000..828c01d88 --- /dev/null +++ b/src/RESPite/Messages/RespPrefix.cs @@ -0,0 +1,101 @@ +using System.Diagnostics.CodeAnalysis; +using StackExchange.Redis; + +namespace RESPite.Messages; + +/// +/// RESP protocol prefix. +/// +[Experimental(Experiments.Respite, UrlFormat = Experiments.UrlFormat)] +public enum RespPrefix : byte +{ + /// + /// Invalid. + /// + None = 0, + + /// + /// Simple strings: +OK\r\n. + /// + SimpleString = (byte)'+', + + /// + /// Simple errors: -ERR message\r\n. + /// + SimpleError = (byte)'-', + + /// + /// Integers: :123\r\n. + /// + Integer = (byte)':', + + /// + /// String with support for binary data: $7\r\nmessage\r\n. + /// + BulkString = (byte)'$', + + /// + /// Multiple inner messages: *1\r\n+message\r\n. + /// + Array = (byte)'*', + + /// + /// Null strings/arrays: _\r\n. + /// + Null = (byte)'_', + + /// + /// Boolean values: #T\r\n. + /// + Boolean = (byte)'#', + + /// + /// Floating-point number: ,123.45\r\n. + /// + Double = (byte)',', + + /// + /// Large integer number: (12...89\r\n. + /// + BigInteger = (byte)'(', + + /// + /// Error with support for binary data: !7\r\nmessage\r\n. + /// + BulkError = (byte)'!', + + /// + /// String that should be interpreted verbatim: =11\r\ntxt:message\r\n. + /// + VerbatimString = (byte)'=', + + /// + /// Multiple sub-items that represent a map. + /// + Map = (byte)'%', + + /// + /// Multiple sub-items that represent a set. + /// + Set = (byte)'~', + + /// + /// Out-of band messages. + /// + Push = (byte)'>', + + /// + /// Continuation of streaming scalar values. + /// + StreamContinuation = (byte)';', + + /// + /// End sentinel for streaming aggregate values. + /// + StreamTerminator = (byte)'.', + + /// + /// Metadata about the next element. + /// + Attribute = (byte)'|', +} diff --git a/src/RESPite/Messages/RespReader.AggregateEnumerator.cs b/src/RESPite/Messages/RespReader.AggregateEnumerator.cs new file mode 100644 index 000000000..1853d2ee6 --- /dev/null +++ b/src/RESPite/Messages/RespReader.AggregateEnumerator.cs @@ -0,0 +1,214 @@ +using System.Collections; +using System.ComponentModel; + +#pragma warning disable IDE0079 // Remove unnecessary suppression +#pragma warning disable CS0282 // There is no defined ordering between fields in multiple declarations of partial struct +#pragma warning restore IDE0079 // Remove unnecessary suppression + +namespace RESPite.Messages; + +public ref partial struct RespReader +{ + /// + /// Reads the sub-elements associated with an aggregate value. + /// + public readonly AggregateEnumerator AggregateChildren() => new(in this); + + /// + /// Reads the sub-elements associated with an aggregate value. + /// + public ref struct AggregateEnumerator + { + // Note that _reader is the overall reader that can see outside this aggregate, as opposed + // to Current which is the sub-tree of the current element *only* + private RespReader _reader; + private int _remaining; + + /// + /// Create a new enumerator for the specified . + /// + /// The reader containing the data for this operation. + public AggregateEnumerator(scoped in RespReader reader) + { + reader.DemandAggregate(); + _remaining = reader.IsStreaming ? -1 : reader._length; + _reader = reader; + Value = default; + } + + /// + public readonly AggregateEnumerator GetEnumerator() => this; + + /// + [Browsable(false), EditorBrowsable(EditorBrowsableState.Never)] + public RespReader Current => Value; + + /// + /// Gets the current element associated with this reader. + /// + public RespReader Value; // intentionally a field, because of ref-semantics + + /// + /// Move to the next child if possible, and move the child element into the next node. + /// + public bool MoveNext(RespPrefix prefix) + { + bool result = MoveNext(); + if (result) + { + Value.MoveNext(prefix); + } + return result; + } + + /// + /// Move to the next child if possible, and move the child element into the next node. + /// + /// The type of data represented by this reader. + public bool MoveNext(RespPrefix prefix, RespAttributeReader respAttributeReader, ref T attributes) + { + bool result = MoveNext(respAttributeReader, ref attributes); + if (result) + { + Value.MoveNext(prefix); + } + return result; + } + + /// > + public bool MoveNext() + { + object? attributes = null; + return MoveNextCore(null, ref attributes); + } + + /// > + /// The type of data represented by this reader. + public bool MoveNext(RespAttributeReader respAttributeReader, ref T attributes) + => MoveNextCore(respAttributeReader, ref attributes); + + /// > + private bool MoveNextCore(RespAttributeReader? attributeReader, ref T attributes) + { + if (_remaining == 0) + { + Value = default; + return false; + } + + // in order to provide access to attributes etc, we want Current to be positioned + // *before* the next element; for that, we'll take a snapshot before we read + _reader.MovePastCurrent(); + var snapshot = _reader.Clone(); + + if (attributeReader is null) + { + _reader.MoveNext(); + } + else + { + _reader.MoveNext(attributeReader, ref attributes); + } + if (_remaining > 0) + { + // non-streaming, decrement + _remaining--; + } + else if (_reader.Prefix == RespPrefix.StreamTerminator) + { + // end of streaming aggregate + _remaining = 0; + Value = default; + return false; + } + + // move past that sub-tree and trim the "snapshot" state, giving + // us a scoped reader that is *just* that sub-tree + _reader.SkipChildren(); + snapshot.TrimToTotal(_reader.BytesConsumed); + + Value = snapshot; + return true; + } + + /// + /// Move to the end of this aggregate and export the state of the . + /// + /// The reader positioned at the end of the data; this is commonly + /// used to update a tree reader, to get to the next data after the aggregate. + public void MovePast(out RespReader reader) + { + while (MoveNext()) { } + reader = _reader; + } + + public void DemandNext() + { + if (!MoveNext()) ThrowEof(); + Value.MoveNext(); // skip any attributes etc + } + + public T ReadOne(Projection projection) + { + DemandNext(); + return projection(ref Value); + } + + public void FillAll(scoped Span target, Projection projection) + { + for (int i = 0; i < target.Length; i++) + { + if (!MoveNext()) ThrowEof(); + + Value.MoveNext(); // skip any attributes etc + target[i] = projection(ref Value); + } + } + + public void FillAll( + scoped Span target, + Projection first, + Projection second, + Func combine) + { + for (int i = 0; i < target.Length; i++) + { + if (!MoveNext()) ThrowEof(); + + Value.MoveNext(); // skip any attributes etc + var x = first(ref Value); + + if (!MoveNext()) ThrowEof(); + + Value.MoveNext(); // skip any attributes etc + var y = second(ref Value); + target[i] = combine(x, y); + } + } + } + + internal void TrimToTotal(long length) => TrimToRemaining(length - BytesConsumed); + + internal void TrimToRemaining(long bytes) + { + if (_prefix != RespPrefix.None || bytes < 0) Throw(); + + var current = CurrentAvailable; + if (bytes <= current) + { + UnsafeTrimCurrentBy(current - (int)bytes); + _remainingTailLength = 0; + return; + } + + bytes -= current; + if (bytes <= _remainingTailLength) + { + _remainingTailLength = bytes; + return; + } + + Throw(); + static void Throw() => throw new ArgumentOutOfRangeException(nameof(bytes)); + } +} diff --git a/src/RESPite/Messages/RespReader.Debug.cs b/src/RESPite/Messages/RespReader.Debug.cs new file mode 100644 index 000000000..3f471bbd1 --- /dev/null +++ b/src/RESPite/Messages/RespReader.Debug.cs @@ -0,0 +1,33 @@ +using System.Diagnostics; + +#pragma warning disable IDE0079 // Remove unnecessary suppression +#pragma warning disable CS0282 // There is no defined ordering between fields in multiple declarations of partial struct +#pragma warning restore IDE0079 // Remove unnecessary suppression + +namespace RESPite.Messages; + +[DebuggerDisplay($"{{{nameof(GetDebuggerDisplay)}(),nq}}")] +public ref partial struct RespReader +{ + internal bool DebugEquals(in RespReader other) + => _prefix == other._prefix + && _length == other._length + && _flags == other._flags + && _bufferIndex == other._bufferIndex + && _positionBase == other._positionBase + && _remainingTailLength == other._remainingTailLength; + + internal new string ToString() => $"{Prefix} ({_flags}); length {_length}, {TotalAvailable} remaining"; + + internal void DebugReset() + { + _bufferIndex = 0; + _length = 0; + _flags = 0; + _prefix = RespPrefix.None; + } + +#if DEBUG + internal bool VectorizeDisabled { get; set; } +#endif +} diff --git a/src/RESPite/Messages/RespReader.ScalarEnumerator.cs b/src/RESPite/Messages/RespReader.ScalarEnumerator.cs new file mode 100644 index 000000000..9e8ffbe70 --- /dev/null +++ b/src/RESPite/Messages/RespReader.ScalarEnumerator.cs @@ -0,0 +1,105 @@ +using System.Buffers; +using System.Collections; + +#pragma warning disable IDE0079 // Remove unnecessary suppression +#pragma warning disable CS0282 // There is no defined ordering between fields in multiple declarations of partial struct +#pragma warning restore IDE0079 // Remove unnecessary suppression + +namespace RESPite.Messages; + +public ref partial struct RespReader +{ + /// + /// Gets the chunks associated with a scalar value. + /// + public readonly ScalarEnumerator ScalarChunks() => new(in this); + + /// + /// Allows enumeration of chunks in a scalar value; this includes simple values + /// that span multiple segments, and streaming + /// scalar RESP values. + /// + public ref struct ScalarEnumerator + { + /// + public readonly ScalarEnumerator GetEnumerator() => this; + + private RespReader _reader; + + private ReadOnlySpan _current; + private ReadOnlySequenceSegment? _tail; + private int _offset, _remaining; + + /// + /// Create a new enumerator for the specified . + /// + /// The reader containing the data for this operation. + public ScalarEnumerator(scoped in RespReader reader) + { + reader.DemandScalar(); + _reader = reader; + InitSegment(); + } + + private void InitSegment() + { + _current = _reader.CurrentSpan(); + _tail = _reader._tail; + _offset = CurrentLength = 0; + _remaining = _reader._length; + if (_reader.TotalAvailable < _remaining) ThrowEof(); + } + + /// + public bool MoveNext() + { + while (true) // for each streaming element + { + _offset += CurrentLength; + while (_remaining > 0) // for each span in the current element + { + // look in the active span + var take = Math.Min(_remaining, _current.Length - _offset); + if (take > 0) // more in the current chunk + { + _remaining -= take; + CurrentLength = take; + return true; + } + + // otherwise, we expect more tail data + if (_tail is null) ThrowEof(); + + _current = _tail.Memory.Span; + _offset = 0; + _tail = _tail.Next; + } + + if (!_reader.MoveNextStreamingScalar()) break; + InitSegment(); + } + + CurrentLength = 0; + return false; + } + + /// + public readonly ReadOnlySpan Current => _current.Slice(_offset, CurrentLength); + + /// + /// Gets the or . + /// + public int CurrentLength { readonly get; private set; } + + /// + /// Move to the end of this aggregate and export the state of the . + /// + /// The reader positioned at the end of the data; this is commonly + /// used to update a tree reader, to get to the next data after the aggregate. + public void MovePast(out RespReader reader) + { + while (MoveNext()) { } + reader = _reader; + } + } +} diff --git a/src/RESPite/Messages/RespReader.Span.cs b/src/RESPite/Messages/RespReader.Span.cs new file mode 100644 index 000000000..fd3870ef3 --- /dev/null +++ b/src/RESPite/Messages/RespReader.Span.cs @@ -0,0 +1,84 @@ +#define USE_UNSAFE_SPAN + +using System.Diagnostics; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; + +#pragma warning disable IDE0079 // Remove unnecessary suppression +#pragma warning disable CS0282 // There is no defined ordering between fields in multiple declarations of partial struct +#pragma warning restore IDE0079 // Remove unnecessary suppression + +namespace RESPite.Messages; + +/* + How we actually implement the underlying buffer depends on the capabilities of the runtime. + */ + +#if NET7_0_OR_GREATER && USE_UNSAFE_SPAN + +public ref partial struct RespReader +{ + // intent: avoid lots of slicing by dealing with everything manually, and accepting the "don't get it wrong" rule + private ref byte _bufferRoot; + private int _bufferLength; + + private partial void UnsafeTrimCurrentBy(int count) + { + Debug.Assert(count >= 0 && count <= _bufferLength, "Unsafe trim length"); + _bufferLength -= count; + } + + private readonly partial ref byte UnsafeCurrent + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + get => ref Unsafe.Add(ref _bufferRoot, _bufferIndex); + } + + private readonly partial int CurrentLength + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + get => _bufferLength; + } + + private readonly partial ReadOnlySpan CurrentSpan() => MemoryMarshal.CreateReadOnlySpan( + ref UnsafeCurrent, CurrentAvailable); + + private readonly partial ReadOnlySpan UnsafePastPrefix() => MemoryMarshal.CreateReadOnlySpan( + ref Unsafe.Add(ref _bufferRoot, _bufferIndex + 1), + _bufferLength - (_bufferIndex + 1)); + + private partial void SetCurrent(ReadOnlySpan value) + { + _bufferRoot = ref MemoryMarshal.GetReference(value); + _bufferLength = value.Length; + } +} +#else +public ref partial struct RespReader // much more conservative - uses slices etc +{ + private ReadOnlySpan _buffer; + + private partial void UnsafeTrimCurrentBy(int count) + { + _buffer = _buffer.Slice(0, _buffer.Length - count); + } + + private readonly partial ref byte UnsafeCurrent + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + get => ref Unsafe.AsRef(in _buffer[_bufferIndex]); // hack around CS8333 + } + + private readonly partial int CurrentLength + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + get => _buffer.Length; + } + + private readonly partial ReadOnlySpan UnsafePastPrefix() => _buffer.Slice(_bufferIndex + 1); + + private readonly partial ReadOnlySpan CurrentSpan() => _buffer.Slice(_bufferIndex); + + private partial void SetCurrent(ReadOnlySpan value) => _buffer = value; +} +#endif diff --git a/src/RESPite/Messages/RespReader.Utils.cs b/src/RESPite/Messages/RespReader.Utils.cs new file mode 100644 index 000000000..da6b641d8 --- /dev/null +++ b/src/RESPite/Messages/RespReader.Utils.cs @@ -0,0 +1,317 @@ +using System.Buffers.Text; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.Runtime.CompilerServices; +using RESPite.Internal; + +#pragma warning disable IDE0079 // Remove unnecessary suppression +#pragma warning disable CS0282 // There is no defined ordering between fields in multiple declarations of partial struct +#pragma warning restore IDE0079 // Remove unnecessary suppression + +namespace RESPite.Messages; + +public ref partial struct RespReader +{ + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private void UnsafeAssertClLf(int offset) => UnsafeAssertClLf(ref UnsafeCurrent, offset); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static void UnsafeAssertClLf(scoped ref byte source, int offset) + { + if (Unsafe.ReadUnaligned(ref Unsafe.Add(ref source, offset)) != RespConstants.CrLfUInt16) + { + ThrowProtocolFailure("Expected CR/LF"); + } + } + + private enum LengthPrefixResult + { + NeedMoreData, + Length, + Null, + Streaming, + } + + /// + /// Asserts that the current element is a scalar type. + /// + public readonly void DemandScalar() + { + if (!IsScalar) Throw(Prefix); + static void Throw(RespPrefix prefix) => throw new InvalidOperationException($"This operation requires a scalar element, got {prefix}"); + } + + /// + /// Asserts that the current element is a scalar type. + /// + public readonly void DemandAggregate() + { + if (!IsAggregate) Throw(Prefix); + static void Throw(RespPrefix prefix) => throw new InvalidOperationException($"This operation requires an aggregate element, got {prefix}"); + } + + private static LengthPrefixResult TryReadLengthPrefix(ReadOnlySpan bytes, out int value, out int byteCount) + { + var end = bytes.IndexOf(RespConstants.CrlfBytes); + if (end < 0) + { + byteCount = value = 0; + if (bytes.Length >= RespConstants.MaxRawBytesInt32 + 2) + { + ThrowProtocolFailure("Unterminated or over-length integer"); // should have failed; report failure to prevent infinite loop + } + return LengthPrefixResult.NeedMoreData; + } + byteCount = end + 2; + switch (end) + { + case 0: + ThrowProtocolFailure("Length prefix expected"); + goto case default; // not reached, just satisfying definite assignment + case 1 when bytes[0] == (byte)'?': + value = 0; + return LengthPrefixResult.Streaming; + default: + if (end > RespConstants.MaxRawBytesInt32 || !(Utf8Parser.TryParse(bytes, out value, out var consumed) && consumed == end)) + { + ThrowProtocolFailure("Unable to parse integer"); + value = 0; + } + if (value < 0) + { + if (value == -1) + { + value = 0; + return LengthPrefixResult.Null; + } + ThrowProtocolFailure("Invalid negative length prefix"); + } + return LengthPrefixResult.Length; + } + } + + private readonly RespReader Clone() => this; // useful for performing streaming operations without moving the primary + + [MethodImpl(MethodImplOptions.NoInlining), DoesNotReturn] + private static void ThrowProtocolFailure(string message) + => throw new InvalidOperationException("RESP protocol failure: " + message); // protocol exception? + + [MethodImpl(MethodImplOptions.NoInlining), DoesNotReturn] + internal static void ThrowEof() => throw new EndOfStreamException(); + + [MethodImpl(MethodImplOptions.NoInlining), DoesNotReturn] + private static void ThrowFormatException() => throw new FormatException(); + + private int RawTryReadByte() + { + if (_bufferIndex < CurrentLength || TryMoveToNextSegment()) + { + var result = UnsafeCurrent; + _bufferIndex++; + return result; + } + return -1; + } + + private int RawPeekByte() + { + return (CurrentLength < _bufferIndex || TryMoveToNextSegment()) ? UnsafeCurrent : -1; + } + + private bool RawAssertCrLf() + { + if (CurrentAvailable >= 2) + { + UnsafeAssertClLf(0); + _bufferIndex += 2; + return true; + } + else + { + int next = RawTryReadByte(); + if (next < 0) return false; + if (next == '\r') + { + next = RawTryReadByte(); + if (next < 0) return false; + if (next == '\n') return true; + } + ThrowProtocolFailure("Expected CR/LF"); + return false; + } + } + + private LengthPrefixResult RawTryReadLengthPrefix() + { + _length = 0; + if (!RawTryFindCrLf(out int end)) + { + if (TotalAvailable >= RespConstants.MaxRawBytesInt32 + 2) + { + ThrowProtocolFailure("Unterminated or over-length integer"); // should have failed; report failure to prevent infinite loop + } + return LengthPrefixResult.NeedMoreData; + } + + switch (end) + { + case 0: + ThrowProtocolFailure("Length prefix expected"); + goto case default; // not reached, just satisfying definite assignment + case 1: + var b = (byte)RawTryReadByte(); + RawAssertCrLf(); + if (b == '?') + { + return LengthPrefixResult.Streaming; + } + else + { + _length = ParseSingleDigit(b); + return LengthPrefixResult.Length; + } + default: + if (end > RespConstants.MaxRawBytesInt32) + { + ThrowProtocolFailure("Unable to parse integer"); + } + Span bytes = stackalloc byte[end]; + RawFillBytes(bytes); + RawAssertCrLf(); + if (!(Utf8Parser.TryParse(bytes, out _length, out var consumed) && consumed == end)) + { + ThrowProtocolFailure("Unable to parse integer"); + } + + if (_length < 0) + { + if (_length == -1) + { + _length = 0; + return LengthPrefixResult.Null; + } + ThrowProtocolFailure("Invalid negative length prefix"); + } + + return LengthPrefixResult.Length; + } + } + + private void RawFillBytes(scoped Span target) + { + do + { + var current = CurrentSpan(); + if (current.Length >= target.Length) + { + // more than enough, need to trim + current.Slice(0, target.Length).CopyTo(target); + _bufferIndex += target.Length; + return; // we're done + } + else + { + // take what we can + current.CopyTo(target); + target = target.Slice(current.Length); + // we could move _bufferIndex here, but we're about to trash that in TryMoveToNextSegment + } + } + while (TryMoveToNextSegment()); + ThrowEof(); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static int ParseSingleDigit(byte value) + { + return value switch + { + (byte)'0' or (byte)'1' or (byte)'2' or (byte)'3' or (byte)'4' or (byte)'5' or (byte)'6' or (byte)'7' or (byte)'8' or (byte)'9' => value - (byte)'0', + _ => Invalid(value), + }; + + [MethodImpl(MethodImplOptions.NoInlining), DoesNotReturn] + static int Invalid(byte value) => throw new FormatException($"Unable to parse integer: '{(char)value}'"); + } + + private readonly bool RawTryAssertInlineScalarPayloadCrLf() + { + Debug.Assert(IsInlineScalar, "should be inline scalar"); + + var reader = Clone(); + var len = reader._length; + if (len == 0) return reader.RawAssertCrLf(); + + do + { + var current = reader.CurrentSpan(); + if (current.Length >= len) + { + reader._bufferIndex += len; + return reader.RawAssertCrLf(); // we're done + } + else + { + // take what we can + len -= current.Length; + // we could move _bufferIndex here, but we're about to trash that in TryMoveToNextSegment + } + } + while (reader.TryMoveToNextSegment()); + return false; // EOF + } + + private readonly bool RawTryFindCrLf(out int length) + { + length = 0; + RespReader reader = Clone(); + do + { + var span = reader.CurrentSpan(); + var index = span.IndexOf((byte)'\r'); + if (index >= 0) + { + checked + { + length += index; + } + // move past the CR and assert the LF + reader._bufferIndex += index + 1; + var next = reader.RawTryReadByte(); + if (next < 0) break; // we don't know + if (next != '\n') ThrowProtocolFailure("CR/LF expected"); + + return true; + } + checked + { + length += span.Length; + } + } + while (reader.TryMoveToNextSegment()); + length = 0; + return false; + } + + private string GetDebuggerDisplay() + { + return ToString(); + } + + internal readonly int GetInitialScanCount(out ushort streamingAggregateDepth) + { + // this is *similar* to GetDelta, but: without any discount for attributes + switch (_flags & (RespFlags.IsAggregate | RespFlags.IsStreaming)) + { + case RespFlags.IsAggregate: + streamingAggregateDepth = 0; + return _length - 1; + case RespFlags.IsAggregate | RespFlags.IsStreaming: + streamingAggregateDepth = 1; + return 0; + default: + streamingAggregateDepth = 0; + return -1; + } + } +} diff --git a/src/RESPite/Messages/RespReader.cs b/src/RESPite/Messages/RespReader.cs new file mode 100644 index 000000000..56e4ddefa --- /dev/null +++ b/src/RESPite/Messages/RespReader.cs @@ -0,0 +1,1774 @@ +using System.Buffers; +using System.Buffers.Text; +using System.ComponentModel; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.Globalization; +using System.Runtime.CompilerServices; +using System.Text; +using RESPite.Internal; +using StackExchange.Redis; + +#if NETCOREAPP3_0_OR_GREATER +using System.Runtime.Intrinsics; +using System.Runtime.Intrinsics.X86; +#endif + +#pragma warning disable IDE0079 // Remove unnecessary suppression +#pragma warning disable CS0282 // There is no defined ordering between fields in multiple declarations of partial struct +#pragma warning restore IDE0079 // Remove unnecessary suppression + +namespace RESPite.Messages; + +/// +/// Provides low level RESP parsing functionality. +/// +[Experimental(Experiments.Respite, UrlFormat = Experiments.UrlFormat)] +public ref partial struct RespReader +{ + [Flags] + private enum RespFlags : byte + { + None = 0, + IsScalar = 1 << 0, // simple strings, bulk strings, etc + IsAggregate = 1 << 1, // arrays, maps, sets, etc + IsNull = 1 << 2, // explicit null RESP types, or bulk-strings/aggregates with length -1 + IsInlineScalar = 1 << 3, // a non-null scalar, i.e. with payload+CrLf + IsAttribute = 1 << 4, // is metadata for following elements + IsStreaming = 1 << 5, // unknown length + IsError = 1 << 6, // an explicit error reported inside the protocol + } + + // relates to the element we're currently reading + private RespFlags _flags; + private RespPrefix _prefix; + + private int _length; // for null: 0; for scalars: the length of the payload; for aggregates: the child count + + // the current buffer that we're observing + private int _bufferIndex; // after TryRead, this should be positioned immediately before the actual data + + // the position in a multi-segment payload + private long _positionBase; // total data we've already moved past in *previous* buffers + private ReadOnlySequenceSegment? _tail; // the next tail node + private long _remainingTailLength; // how much more can we consume from the tail? + + public long ProtocolBytesRemaining => TotalAvailable; + + private readonly int CurrentAvailable => CurrentLength - _bufferIndex; + + private readonly long TotalAvailable => CurrentAvailable + _remainingTailLength; + private partial void UnsafeTrimCurrentBy(int count); + private readonly partial ref byte UnsafeCurrent { get; } + private readonly partial int CurrentLength { get; } + private partial void SetCurrent(ReadOnlySpan value); + private RespPrefix UnsafePeekPrefix() => (RespPrefix)UnsafeCurrent; + private readonly partial ReadOnlySpan UnsafePastPrefix(); + private readonly partial ReadOnlySpan CurrentSpan(); + + /// + /// Get the scalar value as a single-segment span. + /// + /// True if this is a non-streaming scalar element that covers a single span only, otherwise False. + /// If a scalar reports False, can be used to iterate the entire payload. + /// When True, the contents of the scalar value. + public readonly bool TryGetSpan(out ReadOnlySpan value) + { + if (IsInlineScalar && CurrentAvailable >= _length) + { + value = CurrentSpan().Slice(0, _length); + return true; + } + + value = default; + return IsNullScalar; + } + + /// + /// Returns the position after the end of the current element. + /// + public readonly long BytesConsumed => _positionBase + _bufferIndex + TrailingLength; + + /// + /// Body length of scalar values, plus any terminating sentinels. + /// + private readonly int TrailingLength => (_flags & RespFlags.IsInlineScalar) == 0 ? 0 : (_length + 2); + + /// + /// Gets the RESP kind of the current element. + /// + public readonly RespPrefix Prefix => _prefix; + + /// + /// The payload length of this scalar element (includes combined length for streaming scalars). + /// + public readonly int ScalarLength() => + IsInlineScalar ? _length : IsNullScalar ? 0 : checked((int)ScalarLengthSlow()); + + /// + /// Indicates whether this scalar value is zero-length. + /// + public readonly bool ScalarIsEmpty() => + IsInlineScalar ? _length == 0 : (IsNullScalar || !ScalarChunks().MoveNext()); + + /// + /// The payload length of this scalar element (includes combined length for streaming scalars). + /// + public readonly long ScalarLongLength() => IsInlineScalar ? _length : IsNullScalar ? 0 : ScalarLengthSlow(); + + private readonly long ScalarLengthSlow() + { + DemandScalar(); + long length = 0; + var iterator = ScalarChunks(); + while (iterator.MoveNext()) + { + length += iterator.CurrentLength; + } + + return length; + } + + /// + /// The number of child elements associated with an aggregate. + /// + /// For + /// and aggregates, this is twice the value reported in the RESP protocol, + /// i.e. a map of the form %2\r\n... will report 4 as the length. + /// Note that if the data could be streaming (), it may be preferable to use + /// the API, using the API to update the outer reader. + public readonly int AggregateLength() => + (_flags & (RespFlags.IsAggregate | RespFlags.IsStreaming)) == RespFlags.IsAggregate + ? _length + : AggregateLengthSlow(); + + public delegate T Projection(ref RespReader value); + + public void FillAll(scoped Span target, Projection projection) + { + DemandNotNull(); + AggregateChildren().FillAll(target, projection); + } + + private readonly int AggregateLengthSlow() + { + switch (_flags & (RespFlags.IsAggregate | RespFlags.IsStreaming)) + { + case RespFlags.IsAggregate: + return _length; + case RespFlags.IsAggregate | RespFlags.IsStreaming: + break; + default: + DemandAggregate(); // we expect this to throw + break; + } + + int count = 0; + var reader = Clone(); + while (true) + { + if (!reader.TryMoveNext()) ThrowEof(); + if (reader.Prefix == RespPrefix.StreamTerminator) + { + return count; + } + + reader.SkipChildren(); + count++; + } + } + + /// + /// Indicates whether this is a scalar value, i.e. with a potential payload body. + /// + public readonly bool IsScalar => (_flags & RespFlags.IsScalar) != 0; + + internal readonly bool IsInlineScalar => (_flags & RespFlags.IsInlineScalar) != 0; + + internal readonly bool IsNullScalar => + (_flags & (RespFlags.IsScalar | RespFlags.IsNull)) == (RespFlags.IsScalar | RespFlags.IsNull); + + /// + /// Indicates whether this is an aggregate value, i.e. represents a collection of sub-values. + /// + public readonly bool IsAggregate => (_flags & RespFlags.IsAggregate) != 0; + + /// + /// Indicates whether this is a null value; this could be an explicit , + /// or a scalar or aggregate a negative reported length. + /// + public readonly bool IsNull => (_flags & RespFlags.IsNull) != 0; + + /// + /// Indicates whether this is an attribute value, i.e. metadata relating to later element data. + /// + public readonly bool IsAttribute => (_flags & RespFlags.IsAttribute) != 0; + + /// + /// Indicates whether this represents streaming content, where the or is not known in advance. + /// + public readonly bool IsStreaming => (_flags & RespFlags.IsStreaming) != 0; + + /// + /// Equivalent to both and . + /// + internal readonly bool IsStreamingScalar => (_flags & (RespFlags.IsScalar | RespFlags.IsStreaming)) == + (RespFlags.IsScalar | RespFlags.IsStreaming); + + /// + /// Indicates errors reported inside the protocol. + /// + public readonly bool IsError => (_flags & RespFlags.IsError) != 0; + + /// + /// Gets the effective change (in terms of how many RESP nodes we expect to see) from consuming this element. + /// For simple scalars, this is -1 because we have one less node to read; for simple aggregates, this is + /// AggregateLength-1 because we will have consumed one element, but now need to read the additional + /// child elements. Attributes report 0, since they supplement data + /// we still need to consume. The final terminator for streaming data reports a delta of -1, otherwise: 0. + /// + /// This does not account for being nested inside a streaming aggregate; the caller must deal with that manually. + internal int Delta() => + (_flags & (RespFlags.IsScalar | RespFlags.IsAggregate | RespFlags.IsStreaming | RespFlags.IsAttribute)) switch + { + RespFlags.IsScalar => -1, + RespFlags.IsAggregate => _length - 1, + RespFlags.IsAggregate | RespFlags.IsAttribute => _length, + _ => 0, + }; + + /// + /// Assert that this is the final element in the current payload. + /// + /// If additional elements are available. + public void DemandEnd() + { + while (IsStreamingScalar) + { + if (!TryReadNext()) ThrowEof(); + } + + if (TryReadNext()) + { + Throw(Prefix); + } + + static void Throw(RespPrefix prefix) => + throw new InvalidOperationException($"Expected end of payload, but found {prefix}"); + } + + private bool TryReadNextSkipAttributes() + { + while (TryReadNext()) + { + if (IsAttribute) + { + SkipChildren(); + } + else + { + return true; + } + } + + return false; + } + + private bool TryReadNextProcessAttributes(RespAttributeReader respAttributeReader, ref T attributes) + { + while (TryReadNext()) + { + if (IsAttribute) + { + respAttributeReader.Read(ref this, ref attributes); + } + else + { + return true; + } + } + + return false; + } + + /// + /// Move to the next content element; this skips attribute metadata, checking for RESP error messages by default. + /// + /// If the data is exhausted before a streaming scalar is exhausted. + /// If the data contains an explicit error element. + public bool TryMoveNext() + { + while (IsStreamingScalar) // close out the current streaming scalar + { + if (!TryReadNextSkipAttributes()) ThrowEof(); + } + + if (TryReadNextSkipAttributes()) + { + if (IsError) ThrowError(); + return true; + } + + return false; + } + + /// + /// Move to the next content element; this skips attribute metadata, checking for RESP error messages by default. + /// + /// Whether to check and throw for error messages. + /// If the data is exhausted before a streaming scalar is exhausted. + /// If the data contains an explicit error element. + public bool TryMoveNext(bool checkError) + { + while (IsStreamingScalar) // close out the current streaming scalar + { + if (!TryReadNextSkipAttributes()) ThrowEof(); + } + + if (TryReadNextSkipAttributes()) + { + if (checkError && IsError) ThrowError(); + return true; + } + + return false; + } + + /// + /// Move to the next content element; this skips attribute metadata, checking for RESP error messages by default. + /// + /// Parser for attribute data preceding the data. + /// The state for attributes encountered. + /// If the data is exhausted before a streaming scalar is exhausted. + /// If the data contains an explicit error element. + /// The type of data represented by this reader. + public bool TryMoveNext(RespAttributeReader respAttributeReader, ref T attributes) + { + while (IsStreamingScalar) // close out the current streaming scalar + { + if (!TryReadNextSkipAttributes()) ThrowEof(); + } + + if (TryReadNextProcessAttributes(respAttributeReader, ref attributes)) + { + if (IsError) ThrowError(); + return true; + } + + return false; + } + + /// + /// Move to the next content element, asserting that it is of the expected type; this skips attribute metadata, checking for RESP error messages by default. + /// + /// The expected data type. + /// If the data is exhausted before a streaming scalar is exhausted. + /// If the data contains an explicit error element. + /// If the data is not of the expected type. + public bool TryMoveNext(RespPrefix prefix) + { + bool result = TryMoveNext(); + if (result) Demand(prefix); + return result; + } + + /// + /// Move to the next content element; this skips attribute metadata, checking for RESP error messages by default. + /// + /// If the data is exhausted before content is found. + /// If the data contains an explicit error element. + public void MoveNext() + { + if (!TryMoveNext()) ThrowEof(); + } + + /// + /// Move to the next content element; this skips attribute metadata, checking for RESP error messages by default. + /// + /// Parser for attribute data preceding the data. + /// The state for attributes encountered. + /// If the data is exhausted before content is found. + /// If the data contains an explicit error element. + /// The type of data represented by this reader. + public void MoveNext(RespAttributeReader respAttributeReader, ref T attributes) + { + if (!TryMoveNext(respAttributeReader, ref attributes)) ThrowEof(); + } + + private bool MoveNextStreamingScalar() + { + if (IsStreamingScalar) + { + while (TryReadNext()) + { + if (IsAttribute) + { + SkipChildren(); + } + else + { + if (Prefix != RespPrefix.StreamContinuation) + ThrowProtocolFailure("Streaming continuation expected"); + return _length > 0; + } + } + + ThrowEof(); // we should have found something! + } + + return false; + } + + /// + /// Move to the next content element () and assert that it is a scalar (). + /// + /// If the data is exhausted before content is found. + /// If the data contains an explicit error element. + /// If the data is not a scalar type. + public void MoveNextScalar() + { + MoveNext(); + DemandScalar(); + } + + /// + /// Move to the next content element () and assert that it is an aggregate (). + /// + /// If the data is exhausted before content is found. + /// If the data contains an explicit error element. + /// If the data is not an aggregate type. + public void MoveNextAggregate() + { + MoveNext(); + DemandAggregate(); + } + + /// + /// Move to the next content element () and assert that it of type specified + /// in . + /// + /// The expected data type. + /// Parser for attribute data preceding the data. + /// The state for attributes encountered. + /// If the data is exhausted before content is found. + /// If the data contains an explicit error element. + /// If the data is not of the expected type. + /// The type of data represented by this reader. + public void MoveNext(RespPrefix prefix, RespAttributeReader respAttributeReader, ref T attributes) + { + MoveNext(respAttributeReader, ref attributes); + Demand(prefix); + } + + /// + /// Move to the next content element () and assert that it of type specified + /// in . + /// + /// The expected data type. + /// If the data is exhausted before content is found. + /// If the data contains an explicit error element. + /// If the data is not of the expected type. + public void MoveNext(RespPrefix prefix) + { + MoveNext(); + Demand(prefix); + } + + internal void Demand(RespPrefix prefix) + { + if (Prefix != prefix) Throw(prefix, Prefix); + + static void Throw(RespPrefix expected, RespPrefix actual) => + throw new InvalidOperationException($"Expected {expected} element, but found {actual}."); + } + + private readonly void ThrowError() => throw new RespException(ReadString()!); + + /// + /// Skip all sub elements of the current node; this includes both aggregate children and scalar streaming elements. + /// + public void SkipChildren() + { + // if this is a simple non-streaming scalar, then: there's nothing complex to do; otherwise, re-use the + // frame scanner logic to seek past the noise (this way, we avoid recursion etc) + switch (_flags & (RespFlags.IsScalar | RespFlags.IsAggregate | RespFlags.IsStreaming)) + { + case RespFlags.None: + // no current element + break; + case RespFlags.IsScalar: + // simple scalar + MovePastCurrent(); + break; + default: + // something more complex + RespScanState state = new(in this); + if (!state.TryRead(ref this, out _)) ThrowEof(); + break; + } + } + + /// + /// Reads the current element as a string value. + /// + public readonly string? ReadString() => ReadString(out _); + + /// + /// Reads the current element as a string value. + /// + public readonly string? ReadString(out string prefix) + { + byte[] pooled = []; + try + { + var span = Buffer(ref pooled, stackalloc byte[256]); + prefix = ""; + if (span.IsEmpty) + { + return IsNull ? null : ""; + } + + if (Prefix == RespPrefix.VerbatimString + && span.Length >= 4 && span[3] == ':') + { + // "the first three bytes provide information about the format of the following string, + // which can be txt for plain text, or mkd for markdown. The fourth byte is always :. + // Then the real string follows." + var prefixValue = RespConstants.UnsafeCpuUInt32(span); + if (prefixValue == PrefixTxt) + { + prefix = "txt"; + } + else if (prefixValue == PrefixMkd) + { + prefix = "mkd"; + } + else + { + prefix = RespConstants.UTF8.GetString(span.Slice(0, 3)); + } + + span = span.Slice(4); + } + + return RespConstants.UTF8.GetString(span); + } + finally + { + ArrayPool.Shared.Return(pooled); + } + } + + private static readonly uint + PrefixTxt = RespConstants.UnsafeCpuUInt32("txt:"u8), + PrefixMkd = RespConstants.UnsafeCpuUInt32("mkd:"u8); + + /// + /// Reads the current element as a string value. + /// + public readonly byte[]? ReadByteArray() + { + byte[] pooled = []; + try + { + var span = Buffer(ref pooled, stackalloc byte[256]); + if (span.IsEmpty) + { + return IsNull ? null : []; + } + + return span.ToArray(); + } + finally + { + ArrayPool.Shared.Return(pooled); + } + } + + /// + /// Reads the current element using a general purpose text parser. + /// + /// The type of data being parsed. + public readonly T ParseBytes(Parser parser) + { + byte[] pooled = []; + var span = Buffer(ref pooled, stackalloc byte[256]); + try + { + return parser(span); + } + finally + { + ArrayPool.Shared.Return(pooled); + } + } + + /// + /// Reads the current element using a general purpose text parser. + /// + /// The type of data being parsed. + /// State required by the parser. + public readonly T ParseBytes(Parser parser, TState? state) + { + byte[] pooled = []; + var span = Buffer(ref pooled, stackalloc byte[256]); + try + { + return parser(span, default); + } + finally + { + ArrayPool.Shared.Return(pooled); + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal readonly ReadOnlySpan Buffer(Span target) + { + if (TryGetSpan(out var simple)) + { + return simple; + } + +#if NET6_0_OR_GREATER + return BufferSlow(ref Unsafe.NullRef(), target, usePool: false); +#else + byte[] pooled = []; + return BufferSlow(ref pooled, target, usePool: false); +#endif + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal readonly ReadOnlySpan Buffer(scoped ref byte[] pooled, Span target = default) + => TryGetSpan(out var simple) ? simple : BufferSlow(ref pooled, target, true); + + [MethodImpl(MethodImplOptions.NoInlining)] + private readonly ReadOnlySpan BufferSlow(scoped ref byte[] pooled, Span target, bool usePool) + { + DemandScalar(); + + if (IsInlineScalar && usePool) + { + // grow to the correct size in advance, if needed + var length = ScalarLength(); + if (length > target.Length) + { + var bigger = ArrayPool.Shared.Rent(length); + ArrayPool.Shared.Return(pooled); + target = pooled = bigger; + } + } + + var iterator = ScalarChunks(); + ReadOnlySpan current; + int offset = 0; + while (iterator.MoveNext()) + { + // will the current chunk fit? + current = iterator.Current; + if (current.TryCopyTo(target.Slice(offset))) + { + // fits into the current buffer + offset += current.Length; + } + else if (!usePool) + { + // rent disallowed; fill what we can + var available = target.Slice(offset); + current.Slice(0, available.Length).CopyTo(available); + return target; // we filled it + } + else + { + // rent a bigger buffer, copy and recycle + var bigger = ArrayPool.Shared.Rent(offset + current.Length); + if (offset != 0) + { + target.Slice(0, offset).CopyTo(bigger); + } + + ArrayPool.Shared.Return(pooled); + target = pooled = bigger; + current.CopyTo(target.Slice(offset)); + } + } + + return target.Slice(0, offset); + } + + /// + /// Reads the current element using a general purpose byte parser. + /// + /// The type of data being parsed. + public readonly T ParseChars(Parser parser) + { + byte[] bArr = []; + char[] cArr = []; + try + { + var bSpan = Buffer(ref bArr, stackalloc byte[128]); + var maxChars = RespConstants.UTF8.GetMaxCharCount(bSpan.Length); + Span cSpan = maxChars <= 128 ? stackalloc char[128] : (cArr = ArrayPool.Shared.Rent(maxChars)); + int chars = RespConstants.UTF8.GetChars(bSpan, cSpan); + return parser(cSpan.Slice(0, chars)); + } + finally + { + ArrayPool.Shared.Return(bArr); + ArrayPool.Shared.Return(cArr); + } + } + + /// + /// Reads the current element using a general purpose byte parser. + /// + /// The type of data being parsed. + /// State required by the parser. + public readonly T ParseChars(Parser parser, TState? state) + { + byte[] bArr = []; + char[] cArr = []; + try + { + var bSpan = Buffer(ref bArr, stackalloc byte[128]); + var maxChars = RespConstants.UTF8.GetMaxCharCount(bSpan.Length); + Span cSpan = maxChars <= 128 ? stackalloc char[128] : (cArr = ArrayPool.Shared.Rent(maxChars)); + int chars = RespConstants.UTF8.GetChars(bSpan, cSpan); + return parser(cSpan.Slice(0, chars), state); + } + finally + { + ArrayPool.Shared.Return(bArr); + ArrayPool.Shared.Return(cArr); + } + } + +#if NET7_0_OR_GREATER + /// + /// Reads the current element using . + /// + /// The type of data being parsed. +#pragma warning disable RS0016, RS0027 // back-compat overload + public readonly T ParseChars(IFormatProvider? formatProvider = null) where T : ISpanParsable +#pragma warning restore RS0016, RS0027 // back-compat overload + { + byte[] bArr = []; + char[] cArr = []; + try + { + var bSpan = Buffer(ref bArr, stackalloc byte[128]); + var maxChars = RespConstants.UTF8.GetMaxCharCount(bSpan.Length); + Span cSpan = maxChars <= 128 ? stackalloc char[128] : (cArr = ArrayPool.Shared.Rent(maxChars)); + int chars = RespConstants.UTF8.GetChars(bSpan, cSpan); + return T.Parse(cSpan.Slice(0, chars), formatProvider ?? CultureInfo.InvariantCulture); + } + finally + { + ArrayPool.Shared.Return(bArr); + ArrayPool.Shared.Return(cArr); + } + } +#endif + +#if NET8_0_OR_GREATER + /// + /// Reads the current element using . + /// + /// The type of data being parsed. +#pragma warning disable RS0016, RS0027 // back-compat overload + public readonly T ParseBytes(IFormatProvider? formatProvider = null) where T : IUtf8SpanParsable +#pragma warning restore RS0016, RS0027 // back-compat overload + { + byte[] bArr = []; + try + { + var bSpan = Buffer(ref bArr, stackalloc byte[128]); + return T.Parse(bSpan, formatProvider ?? CultureInfo.InvariantCulture); + } + finally + { + ArrayPool.Shared.Return(bArr); + } + } +#endif + + /// + /// General purpose parsing callback. + /// + /// The type of source data being parsed. + /// State required by the parser. + /// The output type of data being parsed. + public delegate TValue Parser(ReadOnlySpan value, TState? state); + + /// + /// General purpose parsing callback. + /// + /// The type of source data being parsed. + /// The output type of data being parsed. + public delegate TValue Parser(ReadOnlySpan value); + + /// + /// Initializes a new instance of the struct. + /// + /// The raw contents to parse with this instance. + public RespReader(ReadOnlySpan value) + { + _length = 0; + _flags = RespFlags.None; + _prefix = RespPrefix.None; + SetCurrent(value); + + _remainingTailLength = _positionBase = 0; + _tail = null; + } + + private void MovePastCurrent() + { + // skip past the trailing portion of a value, if any + var skip = TrailingLength; + if (_bufferIndex + skip <= CurrentLength) + { + _bufferIndex += skip; // available in the current buffer + } + else + { + AdvanceSlow(skip); + } + + // reset the current state + _length = 0; + _flags = 0; + _prefix = RespPrefix.None; + } + + /// + public RespReader(scoped in ReadOnlySequence value) +#if NETCOREAPP3_0_OR_GREATER + : this(value.FirstSpan) +#else + : this(value.First.Span) +#endif + { + if (!value.IsSingleSegment) + { + _remainingTailLength = value.Length - CurrentLength; + _tail = (value.Start.GetObject() as ReadOnlySequenceSegment)?.Next ?? MissingNext(); + } + + [MethodImpl(MethodImplOptions.NoInlining), DoesNotReturn] + static ReadOnlySequenceSegment MissingNext() => + throw new ArgumentException("Unable to extract tail segment", nameof(value)); + } + + /// + /// Attempt to move to the next RESP element. + /// + /// Unless you are intentionally handling errors, attributes and streaming data, should be preferred. + [EditorBrowsable(EditorBrowsableState.Never), Browsable(false)] + public unsafe bool TryReadNext() + { + MovePastCurrent(); + +#if NETCOREAPP3_0_OR_GREATER + // check what we have available; don't worry about zero/fetching the next segment; this is only + // for SIMD lookup, and zero would only apply when data ends exactly on segment boundaries, which + // is incredible niche + var available = CurrentAvailable; + + if (Avx2.IsSupported && Bmi1.IsSupported && available >= sizeof(uint)) + { + // read the first 4 bytes + ref byte origin = ref UnsafeCurrent; + var comparand = Unsafe.ReadUnaligned(ref origin); + + // broadcast those 4 bytes into a vector, mask to get just the first and last byte, and apply a SIMD equality test with our known cases + var eqs = + Avx2.CompareEqual(Avx2.And(Avx2.BroadcastScalarToVector256(&comparand), Raw.FirstLastMask), Raw.CommonRespPrefixes); + + // reinterpret that as floats, and pick out the sign bits (which will be 1 for "equal", 0 for "not equal"); since the + // test cases are mutually exclusive, we expect zero or one matches, so: lzcount tells us which matched + var index = + Bmi1.TrailingZeroCount((uint)Avx.MoveMask(Unsafe.As, Vector256>(ref eqs))); + int len; +#if DEBUG + if (VectorizeDisabled) index = uint.MaxValue; // just to break the switch +#endif + switch (index) + { + case Raw.CommonRespIndex_Success when available >= 5 && Unsafe.Add(ref origin, 4) == (byte)'\n': + _prefix = RespPrefix.SimpleString; + _length = 2; + _bufferIndex++; + _flags = RespFlags.IsScalar | RespFlags.IsInlineScalar; + return true; + case Raw.CommonRespIndex_SingleDigitInteger when Unsafe.Add(ref origin, 2) == (byte)'\r': + _prefix = RespPrefix.Integer; + _length = 1; + _bufferIndex++; + _flags = RespFlags.IsScalar | RespFlags.IsInlineScalar; + return true; + case Raw.CommonRespIndex_DoubleDigitInteger when available >= 5 && Unsafe.Add(ref origin, 4) == (byte)'\n': + _prefix = RespPrefix.Integer; + _length = 2; + _bufferIndex++; + _flags = RespFlags.IsScalar | RespFlags.IsInlineScalar; + return true; + case Raw.CommonRespIndex_SingleDigitString when Unsafe.Add(ref origin, 2) == (byte)'\r': + if (comparand == RespConstants.BulkStringStreaming) + { + _flags = RespFlags.IsScalar | RespFlags.IsStreaming; + } + else + { + len = ParseSingleDigit(Unsafe.Add(ref origin, 1)); + if (available < len + 6) break; // need more data + + UnsafeAssertClLf(4 + len); + _length = len; + _flags = RespFlags.IsScalar | RespFlags.IsInlineScalar; + } + _prefix = RespPrefix.BulkString; + _bufferIndex += 4; + return true; + case Raw.CommonRespIndex_DoubleDigitString when available >= 5 && Unsafe.Add(ref origin, 4) == (byte)'\n': + if (comparand == RespConstants.BulkStringNull) + { + _length = 0; + _flags = RespFlags.IsScalar | RespFlags.IsNull; + } + else + { + len = ParseDoubleDigitsNonNegative(ref Unsafe.Add(ref origin, 1)); + if (available < len + 7) break; // need more data + + UnsafeAssertClLf(5 + len); + _length = len; + _flags = RespFlags.IsScalar | RespFlags.IsInlineScalar; + } + _prefix = RespPrefix.BulkString; + _bufferIndex += 5; + return true; + case Raw.CommonRespIndex_SingleDigitArray when Unsafe.Add(ref origin, 2) == (byte)'\r': + if (comparand == RespConstants.ArrayStreaming) + { + _flags = RespFlags.IsAggregate | RespFlags.IsStreaming; + } + else + { + _flags = RespFlags.IsAggregate; + _length = ParseSingleDigit(Unsafe.Add(ref origin, 1)); + } + _prefix = RespPrefix.Array; + _bufferIndex += 4; + return true; + case Raw.CommonRespIndex_DoubleDigitArray when available >= 5 && Unsafe.Add(ref origin, 4) == (byte)'\n': + if (comparand == RespConstants.ArrayNull) + { + _flags = RespFlags.IsAggregate | RespFlags.IsNull; + } + else + { + _length = ParseDoubleDigitsNonNegative(ref Unsafe.Add(ref origin, 1)); + _flags = RespFlags.IsAggregate; + } + _prefix = RespPrefix.Array; + _bufferIndex += 5; + return true; + case Raw.CommonRespIndex_Error: + len = UnsafePastPrefix().IndexOf(RespConstants.CrlfBytes); + if (len < 0) break; // need more data + + _prefix = RespPrefix.SimpleError; + _flags = RespFlags.IsScalar | RespFlags.IsInlineScalar | RespFlags.IsError; + _length = len; + _bufferIndex++; + return true; + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static int ParseDoubleDigitsNonNegative(ref byte value) => (10 * ParseSingleDigit(value)) + ParseSingleDigit(Unsafe.Add(ref value, 1)); +#endif + + // no fancy vectorization, but: we can still try to find the payload the fast way in a single segment + if (_bufferIndex + 3 <= CurrentLength) // shortest possible RESP fragment is length 3 + { + var remaining = UnsafePastPrefix(); + switch (_prefix = UnsafePeekPrefix()) + { + case RespPrefix.SimpleString: + case RespPrefix.SimpleError: + case RespPrefix.Integer: + case RespPrefix.Boolean: + case RespPrefix.Double: + case RespPrefix.BigInteger: + // CRLF-terminated + _length = remaining.IndexOf(RespConstants.CrlfBytes); + if (_length < 0) break; // can't find, need more data + _bufferIndex++; // payload follows prefix directly + _flags = RespFlags.IsScalar | RespFlags.IsInlineScalar; + if (_prefix == RespPrefix.SimpleError) _flags |= RespFlags.IsError; + return true; + case RespPrefix.BulkError: + case RespPrefix.BulkString: + case RespPrefix.VerbatimString: + // length prefix with value payload; first, the length + switch (TryReadLengthPrefix(remaining, out _length, out int consumed)) + { + case LengthPrefixResult.Length: + // still need to valid terminating CRLF + if (remaining.Length < consumed + _length + 2) break; // need more data + UnsafeAssertClLf(1 + consumed + _length); + + _flags = RespFlags.IsScalar | RespFlags.IsInlineScalar; + break; + case LengthPrefixResult.Null: + _flags = RespFlags.IsScalar | RespFlags.IsNull; + break; + case LengthPrefixResult.Streaming: + _flags = RespFlags.IsScalar | RespFlags.IsStreaming; + break; + } + + if (_flags == 0) break; // will need more data to know + if (_prefix == RespPrefix.BulkError) _flags |= RespFlags.IsError; + _bufferIndex += 1 + consumed; + return true; + case RespPrefix.StreamContinuation: + // length prefix, possibly with value payload; first, the length + switch (TryReadLengthPrefix(remaining, out _length, out consumed)) + { + case LengthPrefixResult.Length when _length == 0: + // EOF, no payload + _flags = RespFlags + .IsScalar; // don't claim as streaming, we want this to count towards delta-decrement + break; + case LengthPrefixResult.Length: + // still need to valid terminating CRLF + if (remaining.Length < consumed + _length + 2) break; // need more data + UnsafeAssertClLf(1 + consumed + _length); + + _flags = RespFlags.IsScalar | RespFlags.IsInlineScalar | RespFlags.IsStreaming; + break; + case LengthPrefixResult.Null: + case LengthPrefixResult.Streaming: + ThrowProtocolFailure("Invalid streaming scalar length prefix"); + break; + } + + if (_flags == 0) break; // will need more data to know + _bufferIndex += 1 + consumed; + return true; + case RespPrefix.Array: + case RespPrefix.Set: + case RespPrefix.Map: + case RespPrefix.Push: + case RespPrefix.Attribute: + // length prefix without value payload (child values follow) + switch (TryReadLengthPrefix(remaining, out _length, out consumed)) + { + case LengthPrefixResult.Length: + _flags = RespFlags.IsAggregate; + if (AggregateLengthNeedsDoubling()) _length *= 2; + break; + case LengthPrefixResult.Null: + _flags = RespFlags.IsAggregate | RespFlags.IsNull; + break; + case LengthPrefixResult.Streaming: + _flags = RespFlags.IsAggregate | RespFlags.IsStreaming; + break; + } + + if (_flags == 0) break; // will need more data to know + if (_prefix is RespPrefix.Attribute) _flags |= RespFlags.IsAttribute; + _bufferIndex += consumed + 1; + return true; + case RespPrefix.Null: // null + // note we already checked we had 3 bytes + UnsafeAssertClLf(1); + _flags = RespFlags.IsScalar | RespFlags.IsNull; + _bufferIndex += 3; // skip prefix+terminator + return true; + case RespPrefix.StreamTerminator: + // note we already checked we had 3 bytes + UnsafeAssertClLf(1); + _flags = RespFlags.IsAggregate; // don't claim as streaming - this counts towards delta + _bufferIndex += 3; // skip prefix+terminator + return true; + default: + ThrowProtocolFailure("Unexpected protocol prefix: " + _prefix); + return false; + } + } + + return TryReadNextSlow(ref this); + } + + private static bool TryReadNextSlow(ref RespReader live) + { + // in the case of failure, we don't want to apply any changes, + // so we work against an isolated copy until we're happy + live.MovePastCurrent(); + RespReader isolated = live; + + int next = isolated.RawTryReadByte(); + if (next < 0) return false; + + switch (isolated._prefix = (RespPrefix)next) + { + case RespPrefix.SimpleString: + case RespPrefix.SimpleError: + case RespPrefix.Integer: + case RespPrefix.Boolean: + case RespPrefix.Double: + case RespPrefix.BigInteger: + // CRLF-terminated + if (!isolated.RawTryFindCrLf(out isolated._length)) return false; + isolated._flags = RespFlags.IsScalar | RespFlags.IsInlineScalar; + if (isolated._prefix == RespPrefix.SimpleError) isolated._flags |= RespFlags.IsError; + break; + case RespPrefix.BulkError: + case RespPrefix.BulkString: + case RespPrefix.VerbatimString: + // length prefix with value payload + switch (isolated.RawTryReadLengthPrefix()) + { + case LengthPrefixResult.Length: + // still need to valid terminating CRLF + isolated._flags = RespFlags.IsScalar | RespFlags.IsInlineScalar; + if (!isolated.RawTryAssertInlineScalarPayloadCrLf()) return false; + break; + case LengthPrefixResult.Null: + isolated._flags = RespFlags.IsScalar | RespFlags.IsNull; + break; + case LengthPrefixResult.Streaming: + isolated._flags = RespFlags.IsScalar | RespFlags.IsStreaming; + break; + case LengthPrefixResult.NeedMoreData: + return false; + default: + ThrowProtocolFailure("Unexpected length prefix"); + return false; + } + + if (isolated._prefix == RespPrefix.BulkError) isolated._flags |= RespFlags.IsError; + break; + case RespPrefix.Array: + case RespPrefix.Set: + case RespPrefix.Map: + case RespPrefix.Push: + case RespPrefix.Attribute: + // length prefix without value payload (child values follow) + switch (isolated.RawTryReadLengthPrefix()) + { + case LengthPrefixResult.Length: + isolated._flags = RespFlags.IsAggregate; + if (isolated.AggregateLengthNeedsDoubling()) isolated._length *= 2; + break; + case LengthPrefixResult.Null: + isolated._flags = RespFlags.IsAggregate | RespFlags.IsNull; + break; + case LengthPrefixResult.Streaming: + isolated._flags = RespFlags.IsAggregate | RespFlags.IsStreaming; + break; + case LengthPrefixResult.NeedMoreData: + return false; + default: + ThrowProtocolFailure("Unexpected length prefix"); + return false; + } + + if (isolated._prefix is RespPrefix.Attribute) isolated._flags |= RespFlags.IsAttribute; + break; + case RespPrefix.Null: // null + if (!isolated.RawAssertCrLf()) return false; + isolated._flags = RespFlags.IsScalar | RespFlags.IsNull; + break; + case RespPrefix.StreamTerminator: + if (!isolated.RawAssertCrLf()) return false; + isolated._flags = RespFlags.IsAggregate; // don't claim as streaming - this counts towards delta + break; + case RespPrefix.StreamContinuation: + // length prefix, possibly with value payload; first, the length + switch (isolated.RawTryReadLengthPrefix()) + { + case LengthPrefixResult.Length when isolated._length == 0: + // EOF, no payload + isolated._flags = + RespFlags + .IsScalar; // don't claim as streaming, we want this to count towards delta-decrement + break; + case LengthPrefixResult.Length: + // still need to valid terminating CRLF + isolated._flags = RespFlags.IsScalar | RespFlags.IsInlineScalar | RespFlags.IsStreaming; + if (!isolated.RawTryAssertInlineScalarPayloadCrLf()) return false; // need more data + break; + case LengthPrefixResult.Null: + case LengthPrefixResult.Streaming: + ThrowProtocolFailure("Invalid streaming scalar length prefix"); + break; + case LengthPrefixResult.NeedMoreData: + default: + return false; + } + + break; + default: + ThrowProtocolFailure("Unexpected protocol prefix: " + isolated._prefix); + return false; + } + + // commit the speculative changes back, and accept + live = isolated; + return true; + } + + private void AdvanceSlow(long bytes) + { + while (bytes > 0) + { + var available = CurrentLength - _bufferIndex; + if (bytes <= available) + { + _bufferIndex += (int)bytes; + return; + } + + bytes -= available; + + if (!TryMoveToNextSegment()) Throw(); + } + + [DoesNotReturn] + static void Throw() => throw new EndOfStreamException( + "Unexpected end of payload; this is unexpected because we already validated that it was available!"); + } + + private bool AggregateLengthNeedsDoubling() => _prefix is RespPrefix.Map or RespPrefix.Attribute; + + private bool TryMoveToNextSegment() + { + while (_tail is not null && _remainingTailLength > 0) + { + var memory = _tail.Memory; + _tail = _tail.Next; + if (!memory.IsEmpty) + { + var span = memory.Span; // check we can get this before mutating anything + _positionBase += CurrentLength; + if (span.Length > _remainingTailLength) + { + span = span.Slice(0, (int)_remainingTailLength); + _remainingTailLength = 0; + } + else + { + _remainingTailLength -= span.Length; + } + + SetCurrent(span); + _bufferIndex = 0; + return true; + } + } + + return false; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal readonly bool IsOK() // go mad with this, because it is used so often + { + if (TryGetSpan(out var span) && span.Length == 2) + { + var u16 = Unsafe.ReadUnaligned(ref UnsafeCurrent); + return u16 == RespConstants.OKUInt16 | u16 == RespConstants.OKUInt16_LC; + } + + return IsSlow(RespConstants.OKBytes, RespConstants.OKBytes_LC); + } + + /// + /// Indicates whether the current element is a scalar with a value that matches the provided . + /// + /// The payload value to verify. + public readonly bool Is(ReadOnlySpan value) + => TryGetSpan(out var span) ? span.SequenceEqual(value) : IsSlow(value); + + /// + /// Indicates whether the current element is a scalar with a value that matches the provided . + /// + /// The payload value to verify. + public readonly bool Is(ReadOnlySpan value) + { + var bytes = RespConstants.UTF8.GetMaxByteCount(value.Length); + byte[]? oversized = null; + Span buffer = bytes <= 128 ? stackalloc byte[128] : (oversized = ArrayPool.Shared.Rent(bytes)); + bytes = RespConstants.UTF8.GetBytes(value, buffer); + bool result = Is(buffer.Slice(0, bytes)); + if (oversized is not null) ArrayPool.Shared.Return(oversized); + return result; + } + + internal readonly bool IsInlneCpuUInt32(uint value) + { + if (IsInlineScalar && _length == sizeof(uint)) + { + return CurrentAvailable >= sizeof(uint) + ? Unsafe.ReadUnaligned(ref UnsafeCurrent) == value + : SlowIsInlneCpuUInt32(value); + } + + return false; + } + + private readonly bool SlowIsInlneCpuUInt32(uint value) + { + Debug.Assert(IsInlineScalar && _length == sizeof(uint), "should be inline scalar of length 4"); + Span buffer = stackalloc byte[sizeof(uint)]; + var copy = this; + copy.RawFillBytes(buffer); + return RespConstants.UnsafeCpuUInt32(buffer) == value; + } + + /// + /// Indicates whether the current element is a scalar with a value that matches the provided . + /// + /// The payload value to verify. + public readonly bool Is(byte value) + { + if (IsInlineScalar && _length == 1 && CurrentAvailable >= 1) + { + return UnsafeCurrent == value; + } + + ReadOnlySpan span = [value]; + return IsSlow(span); + } + + private readonly bool IsSlow(ReadOnlySpan testValue0, ReadOnlySpan testValue2) + => IsSlow(testValue0) || IsSlow(testValue2); + + private readonly bool IsSlow(ReadOnlySpan testValue) + { + DemandScalar(); + if (IsNull) return false; // nothing equals null + if (TotalAvailable < testValue.Length) return false; + + if (!IsStreaming && testValue.Length != ScalarLength()) return false; + + var iterator = ScalarChunks(); + while (true) + { + if (testValue.IsEmpty) + { + // nothing left to test; if also nothing left to read, great! + return !iterator.MoveNext(); + } + + if (!iterator.MoveNext()) + { + return false; // test is longer + } + + var current = iterator.Current; + if (testValue.Length < current.Length) return false; // payload is longer + + if (!current.SequenceEqual(testValue.Slice(0, current.Length))) return false; // payload is different + + testValue = testValue.Slice(current.Length); // validated; continue + } + } + + /// + /// Copy the current scalar value out into the supplied , or as much as can be copied. + /// + /// The destination for the copy operation. + /// The number of bytes successfully copied. + public readonly int CopyTo(Span target) + { + if (TryGetSpan(out var value)) + { + if (target.Length < value.Length) value = value.Slice(0, target.Length); + + value.CopyTo(target); + return value.Length; + } + + int totalBytes = 0; + var iterator = ScalarChunks(); + while (iterator.MoveNext()) + { + value = iterator.Current; + if (target.Length <= value.Length) + { + value.Slice(0, target.Length).CopyTo(target); + return totalBytes + target.Length; + } + + value.CopyTo(target); + target = target.Slice(value.Length); + totalBytes += value.Length; + } + + return totalBytes; + } + + /// + /// Copy the current scalar value out into the supplied , or as much as can be copied. + /// + /// The destination for the copy operation. + /// The number of bytes successfully copied. + public readonly int CopyTo(IBufferWriter target) + { + if (TryGetSpan(out var value)) + { + target.Write(value); + return value.Length; + } + + int totalBytes = 0; + var iterator = ScalarChunks(); + while (iterator.MoveNext()) + { + value = iterator.Current; + target.Write(value); + totalBytes += value.Length; + } + + return totalBytes; + } + + /// + /// Asserts that the current element is not null. + /// + public void DemandNotNull() + { + if (IsNull) Throw(); + static void Throw() => throw new InvalidOperationException("A non-null element was expected"); + } + + /// + /// Read the current element as a value. + /// + [SuppressMessage("Style", "IDE0018:Inline variable declaration", Justification = "No it can't - conditional")] + public readonly long ReadInt64() + { + var span = Buffer(stackalloc byte[RespConstants.MaxRawBytesInt64 + 1]); + long value; + if (!(span.Length <= RespConstants.MaxRawBytesInt64 + && Utf8Parser.TryParse(span, out value, out int bytes) + && bytes == span.Length)) + { + ThrowFormatException(); + value = 0; + } + + return value; + } + + /// + /// Try to read the current element as a value. + /// + public readonly bool TryReadInt64(out long value) + { + var span = Buffer(stackalloc byte[RespConstants.MaxRawBytesInt64 + 1]); + if (span.Length <= RespConstants.MaxRawBytesInt64) + { + return Utf8Parser.TryParse(span, out value, out int bytes) & bytes == span.Length; + } + + value = 0; + return false; + } + + /// + /// Read the current element as a value. + /// + [SuppressMessage("Style", "IDE0018:Inline variable declaration", Justification = "No it can't - conditional")] + public readonly int ReadInt32() + { + var span = Buffer(stackalloc byte[RespConstants.MaxRawBytesInt32 + 1]); + int value; + if (!(span.Length <= RespConstants.MaxRawBytesInt32 + && Utf8Parser.TryParse(span, out value, out int bytes) + && bytes == span.Length)) + { + ThrowFormatException(); + value = 0; + } + + return value; + } + + /// + /// Try to read the current element as a value. + /// + public readonly bool TryReadInt32(out int value) + { + var span = Buffer(stackalloc byte[RespConstants.MaxRawBytesInt32 + 1]); + if (span.Length <= RespConstants.MaxRawBytesInt32) + { + return Utf8Parser.TryParse(span, out value, out int bytes) & bytes == span.Length; + } + + value = 0; + return false; + } + + /// + /// Read the current element as a value. + /// + public readonly double ReadDouble() + { + var span = Buffer(stackalloc byte[RespConstants.MaxRawBytesNumber + 1]); + + if (span.Length <= RespConstants.MaxRawBytesNumber + && Utf8Parser.TryParse(span, out double value, out int bytes) + && bytes == span.Length) + { + return value; + } + + switch (span.Length) + { + case 3 when "inf"u8.SequenceEqual(span): + return double.PositiveInfinity; + case 3 when "nan"u8.SequenceEqual(span): + return double.NaN; + case 4 when "+inf"u8.SequenceEqual(span): // not actually mentioned in spec, but: we'll allow it + return double.PositiveInfinity; + case 4 when "-inf"u8.SequenceEqual(span): + return double.NegativeInfinity; + } + + ThrowFormatException(); + return 0; + } + + /// + /// Try to read the current element as a value. + /// + public bool TryReadDouble(out double value, bool allowTokens = true) + { + var span = Buffer(stackalloc byte[RespConstants.MaxRawBytesNumber + 1]); + + if (span.Length <= RespConstants.MaxRawBytesNumber + && Utf8Parser.TryParse(span, out value, out int bytes) + && bytes == span.Length) + { + return true; + } + + if (allowTokens) + { + switch (span.Length) + { + case 3 when "inf"u8.SequenceEqual(span): + value = double.PositiveInfinity; + return true; + case 3 when "nan"u8.SequenceEqual(span): + value = double.NaN; + return true; + case 4 when "+inf"u8.SequenceEqual(span): // not actually mentioned in spec, but: we'll allow it + value = double.PositiveInfinity; + return true; + case 4 when "-inf"u8.SequenceEqual(span): + value = double.NegativeInfinity; + return true; + } + } + + value = 0; + return false; + } + + /// + /// Note this uses a stackalloc buffer; requesting too much may overflow the stack. + /// + internal readonly bool UnsafeTryReadShortAscii(out string value, int maxLength = 127) + { + var span = Buffer(stackalloc byte[maxLength + 1]); + value = ""; + if (span.IsEmpty) return true; + + if (span.Length <= maxLength) + { + // check for anything that looks binary or unicode + foreach (var b in span) + { + // allow [SPACE]-thru-[DEL], plus CR/LF + if (!(b < 127 & (b >= 32 | (b is 12 or 13)))) + { + return false; + } + } + + value = Encoding.UTF8.GetString(span); + return true; + } + + return false; + } + + /// + /// Read the current element as a value. + /// + [SuppressMessage("Style", "IDE0018:Inline variable declaration", Justification = "No it can't - conditional")] + public readonly decimal ReadDecimal() + { + var span = Buffer(stackalloc byte[RespConstants.MaxRawBytesNumber + 1]); + decimal value; + if (!(span.Length <= RespConstants.MaxRawBytesNumber + && Utf8Parser.TryParse(span, out value, out int bytes) + && bytes == span.Length)) + { + ThrowFormatException(); + value = 0; + } + + return value; + } + + /// + /// Read the current element as a value. + /// + public readonly bool ReadBoolean() + { + var span = Buffer(stackalloc byte[2]); + switch (span.Length) + { + case 1: + switch (span[0]) + { + case (byte)'0' when Prefix == RespPrefix.Integer: return false; + case (byte)'1' when Prefix == RespPrefix.Integer: return true; + case (byte)'f' when Prefix == RespPrefix.Boolean: return false; + case (byte)'t' when Prefix == RespPrefix.Boolean: return true; + } + + break; + case 2 when Prefix == RespPrefix.SimpleString && IsOK(): return true; + } + + ThrowFormatException(); + return false; + } + + /// + /// Parse a scalar value as an enum of type . + /// + /// The value to report if the value is not recognized. + /// The type of enum being parsed. + public readonly T ReadEnum(T unknownValue = default) where T : struct, Enum + { +#if NET6_0_OR_GREATER + return ParseChars(static (chars, state) => Enum.TryParse(chars, true, out T value) ? value : state, unknownValue); +#else + return Enum.TryParse(ReadString(), true, out T value) ? value : unknownValue; +#endif + } + + public TResult[]? ReadArray(Projection projection, bool scalar = false) + { + DemandAggregate(); + if (IsNull) return null; + var len = AggregateLength(); + if (len == 0) return []; + var result = new TResult[len]; + if (scalar) + { + // if the data to be consumed is simple (scalar), we can use + // a simpler path that doesn't need to worry about RESP subtrees + for (int i = 0; i < result.Length; i++) + { + MoveNextScalar(); + result[i] = projection(ref this); + } + } + else + { + var agg = AggregateChildren(); + agg.FillAll(result, projection); + agg.MovePast(out this); + } + + return result; + } + + public TResult[]? ReadPairArray( + Projection first, + Projection second, + Func combine, + bool scalar = true) + { + DemandAggregate(); + if (IsNull) return null; + int sourceLength = AggregateLength(); + if (sourceLength is 0 or 1) return []; + var result = new TResult[sourceLength >> 1]; + if (scalar) + { + // if the data to be consumed is simple (scalar), we can use + // a simpler path that doesn't need to worry about RESP subtrees + for (int i = 0; i < result.Length; i++) + { + MoveNextScalar(); + var x = first(ref this); + MoveNextScalar(); + var y = second(ref this); + result[i] = combine(x, y); + } + // if we have an odd number of source elements, skip the last one + if ((sourceLength & 1) != 0) MoveNextScalar(); + } + else + { + var agg = AggregateChildren(); + agg.FillAll(result, first, second, combine); + agg.MovePast(out this); + } + return result; + } + internal TResult[]? ReadLeasedPairArray( + Projection first, + Projection second, + Func combine, + out int count, + bool scalar = true) + { + DemandAggregate(); + if (IsNull) + { + count = 0; + return null; + } + int sourceLength = AggregateLength(); + count = sourceLength >> 1; + if (count is 0) return []; + + var oversized = ArrayPool.Shared.Rent(count); + var result = oversized.AsSpan(0, count); + if (scalar) + { + // if the data to be consumed is simple (scalar), we can use + // a simpler path that doesn't need to worry about RESP subtrees + for (int i = 0; i < result.Length; i++) + { + MoveNextScalar(); + var x = first(ref this); + MoveNextScalar(); + var y = second(ref this); + result[i] = combine(x, y); + } + // if we have an odd number of source elements, skip the last one + if ((sourceLength & 1) != 0) MoveNextScalar(); + } + else + { + var agg = AggregateChildren(); + agg.FillAll(result, first, second, combine); + agg.MovePast(out this); + } + return oversized; + } +} diff --git a/src/RESPite/Messages/RespScanState.cs b/src/RESPite/Messages/RespScanState.cs new file mode 100644 index 000000000..d7038c30d --- /dev/null +++ b/src/RESPite/Messages/RespScanState.cs @@ -0,0 +1,162 @@ +using System.Buffers; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using StackExchange.Redis; + +namespace RESPite.Messages; + +/// +/// Holds state used for RESP frame parsing, i.e. detecting the RESP for an entire top-level message. +/// +[Experimental(Experiments.Respite, UrlFormat = Experiments.UrlFormat)] +public struct RespScanState +{ + /* + The key point of ScanState is to skim over a RESP stream with minimal frame processing, to find the + end of a single top-level RESP message. We start by expecting 1 message, and then just read, with the + rules that the end of a message subtracts one, and aggregates add N. Streaming scalars apply zero offset + until the scalar stream terminator. Attributes also apply zero offset. + Note that streaming aggregates change the rules - when at least one streaming aggregate is in effect, + no offsets are applied until we get back out of the outermost streaming aggregate - we achieve this + by simply counting the streaming aggregate depth, which is usually zero. + Note that in reality streaming (scalar and aggregates) and attributes are non-existent; in addition + to being specific to RESP3, no known server currently implements these parts of the RESP3 specification, + so everything here is theoretical, but: works according to the spec. + */ + private int _delta; // when this becomes -1, we have fully read a top-level message; + private ushort _streamingAggregateDepth; + private RespPrefix _prefix; + + public RespPrefix Prefix => _prefix; + + private long _totalBytes; +#if DEBUG + private int _elementCount; + + /// + public override string ToString() => $"{_prefix}, consumed: {_totalBytes} bytes, {_elementCount} nodes, complete: {IsComplete}"; +#else + /// + public override string ToString() => _prefix.ToString(); +#endif + + /// + public override bool Equals([NotNullWhen(true)] object? obj) => throw new NotSupportedException(); + + /// + public override int GetHashCode() => throw new NotSupportedException(); + + /// + /// Gets whether an entire top-level RESP message has been consumed. + /// + public bool IsComplete => _delta == -1; + + /// + /// Gets the total length of the payload read (or read so far, if it is not yet complete); this combines payloads from multiple + /// TryRead operations. + /// + public long TotalBytes => _totalBytes; + + // used when spotting common replies - we entirely bypass the usual reader/delta mechanism + internal void SetComplete(int totalBytes, RespPrefix prefix) + { + _totalBytes = totalBytes; + _delta = -1; + _prefix = prefix; +#if DEBUG + _elementCount = 1; +#endif + } + + /// + /// The amount of data, in bytes, to read before attempting to read the next frame. + /// + public const int MinBytes = 3; // minimum legal RESP frame is: _\r\n + + /// + /// Create a new value that can parse the supplied node (and subtree). + /// + internal RespScanState(in RespReader reader) + { + Debug.Assert(reader.Prefix != RespPrefix.None, "missing RESP prefix"); + _totalBytes = 0; + _delta = reader.GetInitialScanCount(out _streamingAggregateDepth); + } + + /// + /// Scan as far as possible, stopping when an entire top-level RESP message has been consumed or the data is exhausted. + /// + /// True if a top-level RESP message has been consumed. + public bool TryRead(ref RespReader reader, out long bytesRead) + { + bytesRead = ReadCore(ref reader, reader.BytesConsumed); + return IsComplete; + } + + /// + /// Scan as far as possible, stopping when an entire top-level RESP message has been consumed or the data is exhausted. + /// + /// True if a top-level RESP message has been consumed. + public bool TryRead(ReadOnlySpan value, out int bytesRead) + { + var reader = new RespReader(value); + bytesRead = (int)ReadCore(ref reader); + return IsComplete; + } + + /// + /// Scan as far as possible, stopping when an entire top-level RESP message has been consumed or the data is exhausted. + /// + /// True if a top-level RESP message has been consumed. + public bool TryRead(in ReadOnlySequence value, out long bytesRead) + { + var reader = new RespReader(in value); + bytesRead = ReadCore(ref reader); + return IsComplete; + } + + /// + /// Scan as far as possible, stopping when an entire top-level RESP message has been consumed or the data is exhausted. + /// + /// The number of bytes consumed in this operation. + private long ReadCore(ref RespReader reader, long startOffset = 0) + { + while (_delta >= 0 && reader.TryReadNext()) + { +#if DEBUG + _elementCount++; +#endif + if (!reader.IsAttribute & _prefix == RespPrefix.None) + { + _prefix = reader.Prefix; + } + + if (reader.IsAggregate) ApplyAggregateRules(ref reader); + + if (_streamingAggregateDepth == 0) _delta += reader.Delta(); + } + + var bytesRead = reader.BytesConsumed - startOffset; + _totalBytes += bytesRead; + return bytesRead; + } + + private void ApplyAggregateRules(ref RespReader reader) + { + Debug.Assert(reader.IsAggregate, "RESP aggregate expected"); + if (reader.IsStreaming) + { + // entering an aggregate stream + if (_streamingAggregateDepth == ushort.MaxValue) ThrowTooDeep(); + _streamingAggregateDepth++; + } + else if (reader.Prefix == RespPrefix.StreamTerminator) + { + // exiting an aggregate stream + if (_streamingAggregateDepth == 0) ThrowUnexpectedTerminator(); + _streamingAggregateDepth--; + } + static void ThrowTooDeep() => throw new InvalidOperationException("Maximum streaming aggregate depth exceeded."); + static void ThrowUnexpectedTerminator() => throw new InvalidOperationException("Unexpected streaming aggregate terminator."); + } +} diff --git a/src/RESPite/PublicAPI/PublicAPI.Shipped.txt b/src/RESPite/PublicAPI/PublicAPI.Shipped.txt new file mode 100644 index 000000000..ab058de62 --- /dev/null +++ b/src/RESPite/PublicAPI/PublicAPI.Shipped.txt @@ -0,0 +1 @@ +#nullable enable diff --git a/src/RESPite/PublicAPI/PublicAPI.Unshipped.txt b/src/RESPite/PublicAPI/PublicAPI.Unshipped.txt new file mode 100644 index 000000000..d45262688 --- /dev/null +++ b/src/RESPite/PublicAPI/PublicAPI.Unshipped.txt @@ -0,0 +1,134 @@ +#nullable enable +[SER004]const RESPite.Messages.RespScanState.MinBytes = 3 -> int +[SER004]override RESPite.Messages.RespScanState.Equals(object? obj) -> bool +[SER004]override RESPite.Messages.RespScanState.GetHashCode() -> int +[SER004]override RESPite.Messages.RespScanState.ToString() -> string! +[SER004]RESPite.Messages.RespAttributeReader +[SER004]RESPite.Messages.RespAttributeReader.RespAttributeReader() -> void +[SER004]RESPite.Messages.RespFrameScanner +[SER004]RESPite.Messages.RespFrameScanner.TryRead(ref RESPite.Messages.RespScanState state, in System.Buffers.ReadOnlySequence data) -> System.Buffers.OperationStatus +[SER004]RESPite.Messages.RespFrameScanner.TryRead(ref RESPite.Messages.RespScanState state, System.ReadOnlySpan data) -> System.Buffers.OperationStatus +[SER004]RESPite.Messages.RespFrameScanner.ValidateRequest(in System.Buffers.ReadOnlySequence message) -> void +[SER004]RESPite.Messages.RespPrefix +[SER004]RESPite.Messages.RespPrefix.Array = 42 -> RESPite.Messages.RespPrefix +[SER004]RESPite.Messages.RespPrefix.Attribute = 124 -> RESPite.Messages.RespPrefix +[SER004]RESPite.Messages.RespPrefix.BigInteger = 40 -> RESPite.Messages.RespPrefix +[SER004]RESPite.Messages.RespPrefix.Boolean = 35 -> RESPite.Messages.RespPrefix +[SER004]RESPite.Messages.RespPrefix.BulkError = 33 -> RESPite.Messages.RespPrefix +[SER004]RESPite.Messages.RespPrefix.BulkString = 36 -> RESPite.Messages.RespPrefix +[SER004]RESPite.Messages.RespPrefix.Double = 44 -> RESPite.Messages.RespPrefix +[SER004]RESPite.Messages.RespPrefix.Integer = 58 -> RESPite.Messages.RespPrefix +[SER004]RESPite.Messages.RespPrefix.Map = 37 -> RESPite.Messages.RespPrefix +[SER004]RESPite.Messages.RespPrefix.None = 0 -> RESPite.Messages.RespPrefix +[SER004]RESPite.Messages.RespPrefix.Null = 95 -> RESPite.Messages.RespPrefix +[SER004]RESPite.Messages.RespPrefix.Push = 62 -> RESPite.Messages.RespPrefix +[SER004]RESPite.Messages.RespPrefix.Set = 126 -> RESPite.Messages.RespPrefix +[SER004]RESPite.Messages.RespPrefix.SimpleError = 45 -> RESPite.Messages.RespPrefix +[SER004]RESPite.Messages.RespPrefix.SimpleString = 43 -> RESPite.Messages.RespPrefix +[SER004]RESPite.Messages.RespPrefix.StreamContinuation = 59 -> RESPite.Messages.RespPrefix +[SER004]RESPite.Messages.RespPrefix.StreamTerminator = 46 -> RESPite.Messages.RespPrefix +[SER004]RESPite.Messages.RespPrefix.VerbatimString = 61 -> RESPite.Messages.RespPrefix +[SER004]RESPite.Messages.RespReader +[SER004]RESPite.Messages.RespReader.AggregateChildren() -> RESPite.Messages.RespReader.AggregateEnumerator +[SER004]RESPite.Messages.RespReader.AggregateEnumerator +[SER004]RESPite.Messages.RespReader.AggregateEnumerator.AggregateEnumerator() -> void +[SER004]RESPite.Messages.RespReader.AggregateEnumerator.AggregateEnumerator(scoped in RESPite.Messages.RespReader reader) -> void +[SER004]RESPite.Messages.RespReader.AggregateEnumerator.Current.get -> RESPite.Messages.RespReader +[SER004]RESPite.Messages.RespReader.AggregateEnumerator.DemandNext() -> void +[SER004]RESPite.Messages.RespReader.AggregateEnumerator.FillAll(scoped System.Span target, RESPite.Messages.RespReader.Projection! projection) -> void +[SER004]RESPite.Messages.RespReader.AggregateEnumerator.FillAll(scoped System.Span target, RESPite.Messages.RespReader.Projection! first, RESPite.Messages.RespReader.Projection! second, System.Func! combine) -> void +[SER004]RESPite.Messages.RespReader.AggregateEnumerator.GetEnumerator() -> RESPite.Messages.RespReader.AggregateEnumerator +[SER004]RESPite.Messages.RespReader.AggregateEnumerator.MoveNext() -> bool +[SER004]RESPite.Messages.RespReader.AggregateEnumerator.MoveNext(RESPite.Messages.RespPrefix prefix) -> bool +[SER004]RESPite.Messages.RespReader.AggregateEnumerator.MoveNext(RESPite.Messages.RespAttributeReader! respAttributeReader, ref T attributes) -> bool +[SER004]RESPite.Messages.RespReader.AggregateEnumerator.MoveNext(RESPite.Messages.RespPrefix prefix, RESPite.Messages.RespAttributeReader! respAttributeReader, ref T attributes) -> bool +[SER004]RESPite.Messages.RespReader.AggregateEnumerator.MovePast(out RESPite.Messages.RespReader reader) -> void +[SER004]RESPite.Messages.RespReader.AggregateEnumerator.ReadOne(RESPite.Messages.RespReader.Projection! projection) -> T +[SER004]RESPite.Messages.RespReader.AggregateEnumerator.Value -> RESPite.Messages.RespReader +[SER004]RESPite.Messages.RespReader.AggregateLength() -> int +[SER004]RESPite.Messages.RespReader.BytesConsumed.get -> long +[SER004]RESPite.Messages.RespReader.CopyTo(System.Buffers.IBufferWriter! target) -> int +[SER004]RESPite.Messages.RespReader.CopyTo(System.Span target) -> int +[SER004]RESPite.Messages.RespReader.DemandAggregate() -> void +[SER004]RESPite.Messages.RespReader.DemandEnd() -> void +[SER004]RESPite.Messages.RespReader.DemandNotNull() -> void +[SER004]RESPite.Messages.RespReader.DemandScalar() -> void +[SER004]RESPite.Messages.RespReader.FillAll(scoped System.Span target, RESPite.Messages.RespReader.Projection! projection) -> void +[SER004]RESPite.Messages.RespReader.Is(byte value) -> bool +[SER004]RESPite.Messages.RespReader.Is(System.ReadOnlySpan value) -> bool +[SER004]RESPite.Messages.RespReader.Is(System.ReadOnlySpan value) -> bool +[SER004]RESPite.Messages.RespReader.IsAggregate.get -> bool +[SER004]RESPite.Messages.RespReader.IsAttribute.get -> bool +[SER004]RESPite.Messages.RespReader.IsError.get -> bool +[SER004]RESPite.Messages.RespReader.IsNull.get -> bool +[SER004]RESPite.Messages.RespReader.IsScalar.get -> bool +[SER004]RESPite.Messages.RespReader.IsStreaming.get -> bool +[SER004]RESPite.Messages.RespReader.MoveNext() -> void +[SER004]RESPite.Messages.RespReader.MoveNext(RESPite.Messages.RespPrefix prefix) -> void +[SER004]RESPite.Messages.RespReader.MoveNext(RESPite.Messages.RespAttributeReader! respAttributeReader, ref T attributes) -> void +[SER004]RESPite.Messages.RespReader.MoveNext(RESPite.Messages.RespPrefix prefix, RESPite.Messages.RespAttributeReader! respAttributeReader, ref T attributes) -> void +[SER004]RESPite.Messages.RespReader.MoveNextAggregate() -> void +[SER004]RESPite.Messages.RespReader.MoveNextScalar() -> void +[SER004]RESPite.Messages.RespReader.ParseBytes(RESPite.Messages.RespReader.Parser! parser, TState? state) -> T +[SER004]RESPite.Messages.RespReader.ParseBytes(RESPite.Messages.RespReader.Parser! parser) -> T +[SER004]RESPite.Messages.RespReader.ParseChars(RESPite.Messages.RespReader.Parser! parser, TState? state) -> T +[SER004]RESPite.Messages.RespReader.ParseChars(RESPite.Messages.RespReader.Parser! parser) -> T +[SER004]RESPite.Messages.RespReader.Parser +[SER004]RESPite.Messages.RespReader.Parser +[SER004]RESPite.Messages.RespReader.Prefix.get -> RESPite.Messages.RespPrefix +[SER004]RESPite.Messages.RespReader.Projection +[SER004]RESPite.Messages.RespReader.ProtocolBytesRemaining.get -> long +[SER004]RESPite.Messages.RespReader.ReadArray(RESPite.Messages.RespReader.Projection! projection, bool scalar = false) -> TResult[]? +[SER004]RESPite.Messages.RespReader.ReadBoolean() -> bool +[SER004]RESPite.Messages.RespReader.ReadByteArray() -> byte[]? +[SER004]RESPite.Messages.RespReader.ReadDecimal() -> decimal +[SER004]RESPite.Messages.RespReader.ReadDouble() -> double +[SER004]RESPite.Messages.RespReader.ReadEnum(T unknownValue = default(T)) -> T +[SER004]RESPite.Messages.RespReader.ReadInt32() -> int +[SER004]RESPite.Messages.RespReader.ReadInt64() -> long +[SER004]RESPite.Messages.RespReader.ReadPairArray(RESPite.Messages.RespReader.Projection! first, RESPite.Messages.RespReader.Projection! second, System.Func! combine, bool scalar = true) -> TResult[]? +[SER004]RESPite.Messages.RespReader.ReadString() -> string? +[SER004]RESPite.Messages.RespReader.ReadString(out string! prefix) -> string? +[SER004]RESPite.Messages.RespReader.RespReader() -> void +[SER004]RESPite.Messages.RespReader.RespReader(scoped in System.Buffers.ReadOnlySequence value) -> void +[SER004]RESPite.Messages.RespReader.RespReader(System.ReadOnlySpan value) -> void +[SER004]RESPite.Messages.RespReader.ScalarChunks() -> RESPite.Messages.RespReader.ScalarEnumerator +[SER004]RESPite.Messages.RespReader.ScalarEnumerator +[SER004]RESPite.Messages.RespReader.ScalarEnumerator.Current.get -> System.ReadOnlySpan +[SER004]RESPite.Messages.RespReader.ScalarEnumerator.CurrentLength.get -> int +[SER004]RESPite.Messages.RespReader.ScalarEnumerator.GetEnumerator() -> RESPite.Messages.RespReader.ScalarEnumerator +[SER004]RESPite.Messages.RespReader.ScalarEnumerator.MoveNext() -> bool +[SER004]RESPite.Messages.RespReader.ScalarEnumerator.MovePast(out RESPite.Messages.RespReader reader) -> void +[SER004]RESPite.Messages.RespReader.ScalarEnumerator.ScalarEnumerator() -> void +[SER004]RESPite.Messages.RespReader.ScalarEnumerator.ScalarEnumerator(scoped in RESPite.Messages.RespReader reader) -> void +[SER004]RESPite.Messages.RespReader.ScalarIsEmpty() -> bool +[SER004]RESPite.Messages.RespReader.ScalarLength() -> int +[SER004]RESPite.Messages.RespReader.ScalarLongLength() -> long +[SER004]RESPite.Messages.RespReader.SkipChildren() -> void +[SER004]RESPite.Messages.RespReader.TryGetSpan(out System.ReadOnlySpan value) -> bool +[SER004]RESPite.Messages.RespReader.TryMoveNext() -> bool +[SER004]RESPite.Messages.RespReader.TryMoveNext(bool checkError) -> bool +[SER004]RESPite.Messages.RespReader.TryMoveNext(RESPite.Messages.RespPrefix prefix) -> bool +[SER004]RESPite.Messages.RespReader.TryMoveNext(RESPite.Messages.RespAttributeReader! respAttributeReader, ref T attributes) -> bool +[SER004]RESPite.Messages.RespReader.TryReadDouble(out double value, bool allowTokens = true) -> bool +[SER004]RESPite.Messages.RespReader.TryReadInt32(out int value) -> bool +[SER004]RESPite.Messages.RespReader.TryReadInt64(out long value) -> bool +[SER004]RESPite.Messages.RespReader.TryReadNext() -> bool +[SER004]RESPite.Messages.RespScanState +[SER004]RESPite.Messages.RespScanState.IsComplete.get -> bool +[SER004]RESPite.Messages.RespScanState.Prefix.get -> RESPite.Messages.RespPrefix +[SER004]RESPite.Messages.RespScanState.RespScanState() -> void +[SER004]RESPite.Messages.RespScanState.TotalBytes.get -> long +[SER004]RESPite.Messages.RespScanState.TryRead(in System.Buffers.ReadOnlySequence value, out long bytesRead) -> bool +[SER004]RESPite.Messages.RespScanState.TryRead(ref RESPite.Messages.RespReader reader, out long bytesRead) -> bool +[SER004]RESPite.Messages.RespScanState.TryRead(System.ReadOnlySpan value, out int bytesRead) -> bool +[SER004]RESPite.RespException +[SER004]RESPite.RespException.RespException(string! message) -> void +[SER004]static RESPite.Messages.RespFrameScanner.Default.get -> RESPite.Messages.RespFrameScanner! +[SER004]static RESPite.Messages.RespFrameScanner.Subscription.get -> RESPite.Messages.RespFrameScanner! +[SER004]virtual RESPite.Messages.RespAttributeReader.Read(ref RESPite.Messages.RespReader reader, ref T value) -> void +[SER004]virtual RESPite.Messages.RespAttributeReader.ReadKeyValuePair(scoped System.ReadOnlySpan key, ref RESPite.Messages.RespReader reader, ref T value) -> bool +[SER004]virtual RESPite.Messages.RespAttributeReader.ReadKeyValuePairs(ref RESPite.Messages.RespReader reader, ref T value) -> int +[SER004]virtual RESPite.Messages.RespReader.Parser.Invoke(System.ReadOnlySpan value, TState? state) -> TValue +[SER004]virtual RESPite.Messages.RespReader.Parser.Invoke(System.ReadOnlySpan value) -> TValue +[SER004]virtual RESPite.Messages.RespReader.Projection.Invoke(ref RESPite.Messages.RespReader value) -> T diff --git a/src/RESPite/PublicAPI/net8.0/PublicAPI.Shipped.txt b/src/RESPite/PublicAPI/net8.0/PublicAPI.Shipped.txt new file mode 100644 index 000000000..ab058de62 --- /dev/null +++ b/src/RESPite/PublicAPI/net8.0/PublicAPI.Shipped.txt @@ -0,0 +1 @@ +#nullable enable diff --git a/src/RESPite/PublicAPI/net8.0/PublicAPI.Unshipped.txt b/src/RESPite/PublicAPI/net8.0/PublicAPI.Unshipped.txt new file mode 100644 index 000000000..c43af2e5e --- /dev/null +++ b/src/RESPite/PublicAPI/net8.0/PublicAPI.Unshipped.txt @@ -0,0 +1,3 @@ +#nullable enable +[SER004]RESPite.Messages.RespReader.ParseBytes(System.IFormatProvider? formatProvider = null) -> T +[SER004]RESPite.Messages.RespReader.ParseChars(System.IFormatProvider? formatProvider = null) -> T \ No newline at end of file diff --git a/src/RESPite/RESPite.csproj b/src/RESPite/RESPite.csproj new file mode 100644 index 000000000..4ad8a0634 --- /dev/null +++ b/src/RESPite/RESPite.csproj @@ -0,0 +1,51 @@ + + + + true + net461;netstandard2.0;net472;net6.0;net8.0;net10.0 + enable + enable + false + 2025 - $([System.DateTime]::Now.Year) Marc Gravell + readme.md + $(DefineConstants);RESPITE + + + + + + + + + + + + + + + + + RespReader.cs + + + Shared/Experiments.cs + + + Shared/FrameworkShims.cs + + + Shared/NullableHacks.cs + + + Shared/SkipLocalsInit.cs + + + + + + + + + + diff --git a/src/RESPite/RespException.cs b/src/RESPite/RespException.cs new file mode 100644 index 000000000..a6cb0c66a --- /dev/null +++ b/src/RESPite/RespException.cs @@ -0,0 +1,12 @@ +using System.Diagnostics.CodeAnalysis; +using StackExchange.Redis; + +namespace RESPite; + +/// +/// Represents a RESP error message. +/// +[Experimental(Experiments.Respite, UrlFormat = Experiments.UrlFormat)] +public sealed class RespException(string message) : Exception(message) +{ +} diff --git a/src/RESPite/readme.md b/src/RESPite/readme.md new file mode 100644 index 000000000..034cae8d3 --- /dev/null +++ b/src/RESPite/readme.md @@ -0,0 +1,6 @@ +# RESPite + +RESPite is a high-performance low-level RESP (Redis, etc) library, used as the IO core for +StackExchange.Redis v3+. It is also available for direct use from other places! + +For now: you probably shouldn't be using this. \ No newline at end of file diff --git a/src/StackExchange.Redis/Experiments.cs b/src/StackExchange.Redis/Experiments.cs index 547838873..1ec2b6f09 100644 --- a/src/StackExchange.Redis/Experiments.cs +++ b/src/StackExchange.Redis/Experiments.cs @@ -9,11 +9,12 @@ internal static class Experiments { public const string UrlFormat = "https://stackexchange.github.io/StackExchange.Redis/exp/"; + // ReSharper disable InconsistentNaming public const string VectorSets = "SER001"; - // ReSharper disable once InconsistentNaming public const string Server_8_4 = "SER002"; - // ReSharper disable once InconsistentNaming public const string Server_8_6 = "SER003"; + public const string Respite = "SER004"; + // ReSharper restore InconsistentNaming } } diff --git a/src/StackExchange.Redis/FrameworkShims.cs b/src/StackExchange.Redis/FrameworkShims.cs index c0fe4cb1d..c1c1bcfe2 100644 --- a/src/StackExchange.Redis/FrameworkShims.cs +++ b/src/StackExchange.Redis/FrameworkShims.cs @@ -1,6 +1,7 @@ #pragma warning disable SA1403 // single namespace -#if NET5_0_OR_GREATER +#if RESPITE // add nothing +#elif NET5_0_OR_GREATER // context: https://github.com/StackExchange/StackExchange.Redis/issues/2619 [assembly: System.Runtime.CompilerServices.TypeForwardedTo(typeof(System.Runtime.CompilerServices.IsExternalInit))] #else diff --git a/tests/RESPite.Tests/RESPite.Tests.csproj b/tests/RESPite.Tests/RESPite.Tests.csproj new file mode 100644 index 000000000..4f46712af --- /dev/null +++ b/tests/RESPite.Tests/RESPite.Tests.csproj @@ -0,0 +1,21 @@ + + + + net481;net8.0;net10.0 + enable + false + true + Exe + + + + + + + + + + + + + diff --git a/tests/RESPite.Tests/RespReaderTests.cs b/tests/RESPite.Tests/RespReaderTests.cs new file mode 100644 index 000000000..4b250e7ec --- /dev/null +++ b/tests/RESPite.Tests/RespReaderTests.cs @@ -0,0 +1,863 @@ +using System; +using System.Buffers; +using System.Collections.Generic; +using System.Globalization; +using System.Linq; +using System.Numerics; +using System.Reflection; +using System.Text; +using System.Threading.Tasks; +using RESPite.Internal; +using RESPite.Messages; +using Xunit; +using Xunit.Sdk; +using Xunit.v3; + +namespace RESPite.Tests; + +public class RespReaderTests(ITestOutputHelper logger) +{ + public readonly struct RespPayload(string label, ReadOnlySequence payload, byte[] expected, bool? outOfBand, int count) + { + public override string ToString() => Label; + public string Label { get; } = label; + public ReadOnlySequence PayloadRaw { get; } = payload; + public int Length { get; } = CheckPayload(payload, expected, outOfBand, count); + private static int CheckPayload(scoped in ReadOnlySequence actual, byte[] expected, bool? outOfBand, int count) + { + Assert.Equal(expected.LongLength, actual.Length); + var pool = ArrayPool.Shared.Rent(expected.Length); + actual.CopyTo(pool); + bool isSame = pool.AsSpan(0, expected.Length).SequenceEqual(expected); + ArrayPool.Shared.Return(pool); + Assert.True(isSame, "Data mismatch"); + + // verify that the data exactly passes frame-scanning + long totalBytes = 0; + RespReader reader = new(actual); + while (count > 0) + { + RespScanState state = default; + Assert.True(state.TryRead(ref reader, out long bytesRead)); + totalBytes += bytesRead; + Assert.True(state.IsComplete, nameof(state.IsComplete)); + if (outOfBand.HasValue) + { + if (outOfBand.Value) + { + Assert.Equal(RespPrefix.Push, state.Prefix); + } + else + { + Assert.NotEqual(RespPrefix.Push, state.Prefix); + } + } + count--; + } + Assert.Equal(expected.Length, totalBytes); + reader.DemandEnd(); + return expected.Length; + } + + public RespReader Reader() => new(PayloadRaw); + } + + public sealed class RespAttribute : DataAttribute + { + public override bool SupportsDiscoveryEnumeration() => true; + + private readonly object _value; + public bool OutOfBand { get; set; } = false; + + private bool? EffectiveOutOfBand => Count == 1 ? OutOfBand : default(bool?); + public int Count { get; set; } = 1; + + public RespAttribute(string value) => _value = value; + public RespAttribute(params string[] values) => _value = values; + + public override ValueTask> GetData(MethodInfo testMethod, DisposalTracker disposalTracker) + => new(GetData(testMethod).ToArray()); + + public IEnumerable GetData(MethodInfo testMethod) + { + switch (_value) + { + case string s: + foreach (var item in GetVariants(s, EffectiveOutOfBand, Count)) + { + yield return new TheoryDataRow(item); + } + break; + case string[] arr: + foreach (string s in arr) + { + foreach (var item in GetVariants(s, EffectiveOutOfBand, Count)) + { + yield return new TheoryDataRow(item); + } + } + break; + } + } + + private static IEnumerable GetVariants(string value, bool? outOfBand, int count) + { + var bytes = Encoding.UTF8.GetBytes(value); + + // all in one + yield return new("Right-sized", new(bytes), bytes, outOfBand, count); + + var bigger = new byte[bytes.Length + 4]; + bytes.CopyTo(bigger.AsSpan(2, bytes.Length)); + bigger.AsSpan(0, 2).Fill(0xFF); + bigger.AsSpan(bytes.Length + 2, 2).Fill(0xFF); + + // all in one, oversized + yield return new("Oversized", new(bigger, 2, bytes.Length), bytes, outOfBand, count); + + // two-chunks + for (int i = 0; i <= bytes.Length; i++) + { + int offset = 2 + i; + var left = new Segment(new ReadOnlyMemory(bigger, 0, offset), null); + var right = new Segment(new ReadOnlyMemory(bigger, offset, bigger.Length - offset), left); + yield return new($"Split:{i}", new ReadOnlySequence(left, 2, right, right.Length - 2), bytes, outOfBand, count); + } + + // N-chunks + Segment head = new(new(bytes, 0, 1), null), tail = head; + for (int i = 1; i < bytes.Length; i++) + { + tail = new(new(bytes, i, 1), tail); + } + yield return new("Chunk-per-byte", new(head, 0, tail, 1), bytes, outOfBand, count); + } + } + + [Theory, Resp("$3\r\n128\r\n")] + public void HandleSplitTokens(RespPayload payload) + { + RespReader reader = payload.Reader(); + RespScanState scan = default; + bool readResult = scan.TryRead(ref reader, out _); + logger.WriteLine(scan.ToString()); + Assert.Equal(payload.Length, reader.BytesConsumed); + Assert.True(readResult); + } + + // the examples from https://github.com/redis/redis-specifications/blob/master/protocol/RESP3.md + [Theory, Resp("$11\r\nhello world\r\n", "$?\r\n;6\r\nhello \r\n;5\r\nworld\r\n;0\r\n")] + public void BlobString(RespPayload payload) + { + var reader = payload.Reader(); + reader.MoveNext(RespPrefix.BulkString); + Assert.True(reader.Is("hello world"u8)); + Assert.Equal("hello world", reader.ReadString()); + Assert.Equal("hello world", reader.ReadString(out var prefix)); + Assert.Equal("", prefix); +#if NET7_0_OR_GREATER + Assert.Equal("hello world", reader.ParseChars()); +#endif + /* interestingly, string does not implement IUtf8SpanParsable +#if NET8_0_OR_GREATER + Assert.Equal("hello world", reader.ParseBytes()); +#endif + */ + reader.DemandEnd(); + } + + [Theory, Resp("$0\r\n\r\n", "$?\r\n;0\r\n")] + public void EmptyBlobString(RespPayload payload) + { + var reader = payload.Reader(); + reader.MoveNext(RespPrefix.BulkString); + Assert.True(reader.Is(""u8)); + Assert.Equal("", reader.ReadString()); + reader.DemandEnd(); + } + + [Theory, Resp("+hello world\r\n")] + public void SimpleString(RespPayload payload) + { + var reader = payload.Reader(); + reader.MoveNext(RespPrefix.SimpleString); + Assert.True(reader.Is("hello world"u8)); + Assert.Equal("hello world", reader.ReadString()); + Assert.Equal("hello world", reader.ReadString(out var prefix)); + Assert.Equal("", prefix); + reader.DemandEnd(); + } + + [Theory, Resp("-ERR this is the error description\r\n")] + public void SimpleError_ImplicitErrors(RespPayload payload) + { + var ex = Assert.Throws(() => + { + var reader = payload.Reader(); + reader.MoveNext(); + }); + Assert.Equal("ERR this is the error description", ex.Message); + } + + [Theory, Resp("-ERR this is the error description\r\n")] + public void SimpleError_Careful(RespPayload payload) + { + var reader = payload.Reader(); + Assert.True(reader.TryReadNext()); + Assert.Equal(RespPrefix.SimpleError, reader.Prefix); + Assert.True(reader.Is("ERR this is the error description"u8)); + Assert.Equal("ERR this is the error description", reader.ReadString()); + reader.DemandEnd(); + } + + [Theory, Resp(":1234\r\n")] + public void Number(RespPayload payload) + { + var reader = payload.Reader(); + reader.MoveNext(RespPrefix.Integer); + Assert.True(reader.Is("1234"u8)); + Assert.Equal("1234", reader.ReadString()); + Assert.Equal(1234, reader.ReadInt32()); + Assert.Equal(1234D, reader.ReadDouble()); + Assert.Equal(1234M, reader.ReadDecimal()); +#if NET7_0_OR_GREATER + Assert.Equal(1234, reader.ParseChars()); + Assert.Equal(1234D, reader.ParseChars()); + Assert.Equal(1234M, reader.ParseChars()); +#endif +#if NET8_0_OR_GREATER + Assert.Equal(1234, reader.ParseBytes()); + Assert.Equal(1234D, reader.ParseBytes()); + Assert.Equal(1234M, reader.ParseBytes()); +#endif + reader.DemandEnd(); + } + + [Theory, Resp("_\r\n")] + public void Null(RespPayload payload) + { + var reader = payload.Reader(); + reader.MoveNext(RespPrefix.Null); + Assert.True(reader.Is(""u8)); + Assert.Null(reader.ReadString()); + reader.DemandEnd(); + } + + [Theory, Resp("$-1\r\n")] + public void NullString(RespPayload payload) + { + var reader = payload.Reader(); + reader.MoveNext(RespPrefix.BulkString); + Assert.True(reader.IsNull); + Assert.Null(reader.ReadString()); + Assert.Equal(0, reader.ScalarLength()); + Assert.True(reader.Is(""u8)); + Assert.True(reader.ScalarIsEmpty()); + + var iterator = reader.ScalarChunks(); + Assert.False(iterator.MoveNext()); + iterator.MovePast(out reader); + reader.DemandEnd(); + } + + [Theory, Resp(",1.23\r\n")] + public void Double(RespPayload payload) + { + var reader = payload.Reader(); + reader.MoveNext(RespPrefix.Double); + Assert.True(reader.Is("1.23"u8)); + Assert.Equal("1.23", reader.ReadString()); + Assert.Equal(1.23D, reader.ReadDouble()); + Assert.Equal(1.23M, reader.ReadDecimal()); + reader.DemandEnd(); + } + + [Theory, Resp(":10\r\n")] + public void Integer_Simple(RespPayload payload) + { + var reader = payload.Reader(); + reader.MoveNext(RespPrefix.Integer); + Assert.True(reader.Is("10"u8)); + Assert.Equal("10", reader.ReadString()); + Assert.Equal(10, reader.ReadInt32()); + Assert.Equal(10D, reader.ReadDouble()); + Assert.Equal(10M, reader.ReadDecimal()); + reader.DemandEnd(); + } + + [Theory, Resp(",10\r\n")] + public void Double_Simple(RespPayload payload) + { + var reader = payload.Reader(); + reader.MoveNext(RespPrefix.Double); + Assert.True(reader.Is("10"u8)); + Assert.Equal("10", reader.ReadString()); + Assert.Equal(10, reader.ReadInt32()); + Assert.Equal(10D, reader.ReadDouble()); + Assert.Equal(10M, reader.ReadDecimal()); + reader.DemandEnd(); + } + + [Theory, Resp(",inf\r\n")] + public void Double_Infinity(RespPayload payload) + { + var reader = payload.Reader(); + reader.MoveNext(RespPrefix.Double); + Assert.True(reader.Is("inf"u8)); + Assert.Equal("inf", reader.ReadString()); + var val = reader.ReadDouble(); + Assert.True(double.IsInfinity(val)); + Assert.True(double.IsPositiveInfinity(val)); + reader.DemandEnd(); + } + + [Theory, Resp(",+inf\r\n")] + public void Double_PosInfinity(RespPayload payload) + { + var reader = payload.Reader(); + reader.MoveNext(RespPrefix.Double); + Assert.True(reader.Is("+inf"u8)); + Assert.Equal("+inf", reader.ReadString()); + var val = reader.ReadDouble(); + Assert.True(double.IsInfinity(val)); + Assert.True(double.IsPositiveInfinity(val)); + reader.DemandEnd(); + } + + [Theory, Resp(",-inf\r\n")] + public void Double_NegInfinity(RespPayload payload) + { + var reader = payload.Reader(); + reader.MoveNext(RespPrefix.Double); + Assert.True(reader.Is("-inf"u8)); + Assert.Equal("-inf", reader.ReadString()); + var val = reader.ReadDouble(); + Assert.True(double.IsInfinity(val)); + Assert.True(double.IsNegativeInfinity(val)); + reader.DemandEnd(); + } + + [Theory, Resp(",nan\r\n")] + public void Double_NaN(RespPayload payload) + { + var reader = payload.Reader(); + reader.MoveNext(RespPrefix.Double); + Assert.True(reader.Is("nan"u8)); + Assert.Equal("nan", reader.ReadString()); + var val = reader.ReadDouble(); + Assert.True(double.IsNaN(val)); + reader.DemandEnd(); + } + + [Theory, Resp("#t\r\n")] + public void Boolean_T(RespPayload payload) + { + var reader = payload.Reader(); + reader.MoveNext(RespPrefix.Boolean); + Assert.True(reader.ReadBoolean()); + reader.DemandEnd(); + } + + [Theory, Resp("#f\r\n")] + public void Boolean_F(RespPayload payload) + { + var reader = payload.Reader(); + reader.MoveNext(RespPrefix.Boolean); + Assert.False(reader.ReadBoolean()); + reader.DemandEnd(); + } + + [Theory, Resp(":1\r\n")] + public void Boolean_1(RespPayload payload) + { + var reader = payload.Reader(); + reader.MoveNext(RespPrefix.Integer); + Assert.True(reader.ReadBoolean()); + reader.DemandEnd(); + } + + [Theory, Resp(":0\r\n")] + public void Boolean_0(RespPayload payload) + { + var reader = payload.Reader(); + reader.MoveNext(RespPrefix.Integer); + Assert.False(reader.ReadBoolean()); + reader.DemandEnd(); + } + + [Theory, Resp("!21\r\nSYNTAX invalid syntax\r\n", "!?\r\n;6\r\nSYNTAX\r\n;15\r\n invalid syntax\r\n;0\r\n")] + public void BlobError_ImplicitErrors(RespPayload payload) + { + var ex = Assert.Throws(() => + { + var reader = payload.Reader(); + reader.MoveNext(); + }); + Assert.Equal("SYNTAX invalid syntax", ex.Message); + } + + [Theory, Resp("!21\r\nSYNTAX invalid syntax\r\n", "!?\r\n;6\r\nSYNTAX\r\n;15\r\n invalid syntax\r\n;0\r\n")] + public void BlobError_Careful(RespPayload payload) + { + var reader = payload.Reader(); + Assert.True(reader.TryReadNext()); + Assert.Equal(RespPrefix.BulkError, reader.Prefix); + Assert.True(reader.Is("SYNTAX invalid syntax"u8)); + Assert.Equal("SYNTAX invalid syntax", reader.ReadString()); + reader.DemandEnd(); + } + + [Theory, Resp("=15\r\ntxt:Some string\r\n", "=?\r\n;4\r\ntxt:\r\n;11\r\nSome string\r\n;0\r\n")] + public void VerbatimString(RespPayload payload) + { + var reader = payload.Reader(); + reader.MoveNext(RespPrefix.VerbatimString); + Assert.Equal("Some string", reader.ReadString()); + Assert.Equal("Some string", reader.ReadString(out var prefix)); + Assert.Equal("txt", prefix); + + Assert.Equal("Some string", reader.ReadString(out var prefix2)); + Assert.Same(prefix, prefix2); // check prefix recognized and reuse literal + reader.DemandEnd(); + } + + [Theory, Resp("(3492890328409238509324850943850943825024385\r\n")] + public void BigIntegers(RespPayload payload) + { + var reader = payload.Reader(); + reader.MoveNext(RespPrefix.BigInteger); + Assert.Equal("3492890328409238509324850943850943825024385", reader.ReadString()); +#if NET8_0_OR_GREATER + var actual = reader.ParseChars(chars => BigInteger.Parse(chars, CultureInfo.InvariantCulture)); + + var expected = BigInteger.Parse("3492890328409238509324850943850943825024385"); + Assert.Equal(expected, actual); +#endif + } + + [Theory, Resp("*3\r\n:1\r\n:2\r\n:3\r\n", "*?\r\n:1\r\n:2\r\n:3\r\n.\r\n")] + public void Array(RespPayload payload) + { + var reader = payload.Reader(); + reader.MoveNext(RespPrefix.Array); + Assert.Equal(3, reader.AggregateLength()); + var iterator = reader.AggregateChildren(); + Assert.True(iterator.MoveNext(RespPrefix.Integer)); + Assert.Equal(1, iterator.Value.ReadInt32()); + iterator.Value.DemandEnd(); + + Assert.True(iterator.MoveNext(RespPrefix.Integer)); + Assert.Equal(2, iterator.Value.ReadInt32()); + iterator.Value.DemandEnd(); + + Assert.True(iterator.MoveNext(RespPrefix.Integer)); + Assert.Equal(3, iterator.Value.ReadInt32()); + iterator.Value.DemandEnd(); + + Assert.False(iterator.MoveNext(RespPrefix.Integer)); + iterator.MovePast(out reader); + reader.DemandEnd(); + + reader = payload.Reader(); + reader.MoveNext(RespPrefix.Array); + int[] arr = new int[reader.AggregateLength()]; + int i = 0; + foreach (var sub in reader.AggregateChildren()) + { + sub.MoveNext(RespPrefix.Integer); + arr[i++] = sub.ReadInt32(); + sub.DemandEnd(); + } + iterator.MovePast(out reader); + reader.DemandEnd(); + + Assert.Equal([1, 2, 3], arr); + } + + [Theory, Resp("*-1\r\n")] + public void NullArray(RespPayload payload) + { + var reader = payload.Reader(); + reader.MoveNext(RespPrefix.Array); + Assert.True(reader.IsNull); + Assert.Equal(0, reader.AggregateLength()); + var iterator = reader.AggregateChildren(); + Assert.False(iterator.MoveNext()); + iterator.MovePast(out reader); + reader.DemandEnd(); + } + + [Theory, Resp("*2\r\n*3\r\n:1\r\n$5\r\nhello\r\n:2\r\n#f\r\n", "*?\r\n*?\r\n:1\r\n$5\r\nhello\r\n:2\r\n.\r\n#f\r\n.\r\n")] + public void NestedArray(RespPayload payload) + { + var reader = payload.Reader(); + reader.MoveNext(RespPrefix.Array); + + Assert.Equal(2, reader.AggregateLength()); + + var iterator = reader.AggregateChildren(); + Assert.True(iterator.MoveNext(RespPrefix.Array)); + + Assert.Equal(3, iterator.Value.AggregateLength()); + var subIterator = iterator.Value.AggregateChildren(); + Assert.True(subIterator.MoveNext(RespPrefix.Integer)); + Assert.Equal(1, subIterator.Value.ReadInt64()); + subIterator.Value.DemandEnd(); + + Assert.True(subIterator.MoveNext(RespPrefix.BulkString)); + Assert.True(subIterator.Value.Is("hello"u8)); + subIterator.Value.DemandEnd(); + + Assert.True(subIterator.MoveNext(RespPrefix.Integer)); + Assert.Equal(2, subIterator.Value.ReadInt64()); + subIterator.Value.DemandEnd(); + + Assert.False(subIterator.MoveNext()); + + Assert.True(iterator.MoveNext(RespPrefix.Boolean)); + Assert.False(iterator.Value.ReadBoolean()); + iterator.Value.DemandEnd(); + + Assert.False(iterator.MoveNext()); + iterator.MovePast(out reader); + + reader.DemandEnd(); + } + + [Theory, Resp("%2\r\n+first\r\n:1\r\n+second\r\n:2\r\n", "%?\r\n+first\r\n:1\r\n+second\r\n:2\r\n.\r\n")] + public void Map(RespPayload payload) + { + var reader = payload.Reader(); + reader.MoveNext(RespPrefix.Map); + + Assert.Equal(4, reader.AggregateLength()); + + var iterator = reader.AggregateChildren(); + + Assert.True(iterator.MoveNext(RespPrefix.SimpleString)); + Assert.True(iterator.Value.Is("first".AsSpan())); + iterator.Value.DemandEnd(); + + Assert.True(iterator.MoveNext(RespPrefix.Integer)); + Assert.Equal(1, iterator.Value.ReadInt32()); + iterator.Value.DemandEnd(); + + Assert.True(iterator.MoveNext(RespPrefix.SimpleString)); + Assert.True(iterator.Value.Is("second"u8)); + iterator.Value.DemandEnd(); + + Assert.True(iterator.MoveNext(RespPrefix.Integer)); + Assert.Equal(2, iterator.Value.ReadInt32()); + iterator.Value.DemandEnd(); + + Assert.False(iterator.MoveNext()); + + iterator.MovePast(out reader); + reader.DemandEnd(); + } + + [Theory, Resp("~5\r\n+orange\r\n+apple\r\n#t\r\n:100\r\n:999\r\n", "~?\r\n+orange\r\n+apple\r\n#t\r\n:100\r\n:999\r\n.\r\n")] + public void Set(RespPayload payload) + { + var reader = payload.Reader(); + reader.MoveNext(RespPrefix.Set); + + Assert.Equal(5, reader.AggregateLength()); + + var iterator = reader.AggregateChildren(); + + Assert.True(iterator.MoveNext(RespPrefix.SimpleString)); + Assert.True(iterator.Value.Is("orange".AsSpan())); + iterator.Value.DemandEnd(); + + Assert.True(iterator.MoveNext(RespPrefix.SimpleString)); + Assert.True(iterator.Value.Is("apple"u8)); + iterator.Value.DemandEnd(); + + Assert.True(iterator.MoveNext(RespPrefix.Boolean)); + Assert.True(iterator.Value.ReadBoolean()); + iterator.Value.DemandEnd(); + + Assert.True(iterator.MoveNext(RespPrefix.Integer)); + Assert.Equal(100, iterator.Value.ReadInt32()); + iterator.Value.DemandEnd(); + + Assert.True(iterator.MoveNext(RespPrefix.Integer)); + Assert.Equal(999, iterator.Value.ReadInt32()); + iterator.Value.DemandEnd(); + + Assert.False(iterator.MoveNext()); + + iterator.MovePast(out reader); + reader.DemandEnd(); + } + + private sealed class TestAttributeReader : RespAttributeReader<(int Count, int Ttl, decimal A, decimal B)> + { + public override void Read(ref RespReader reader, ref (int Count, int Ttl, decimal A, decimal B) value) + { + value.Count += ReadKeyValuePairs(ref reader, ref value); + } + private TestAttributeReader() { } + public static readonly TestAttributeReader Instance = new(); + public static (int Count, int Ttl, decimal A, decimal B) Zero = (0, 0, 0, 0); + public override bool ReadKeyValuePair(scoped ReadOnlySpan key, ref RespReader reader, ref (int Count, int Ttl, decimal A, decimal B) value) + { + if (key.SequenceEqual("ttl"u8) && reader.IsScalar) + { + value.Ttl = reader.ReadInt32(); + } + else if (key.SequenceEqual("key-popularity"u8) && reader.IsAggregate) + { + ReadKeyValuePairs(ref reader, ref value); // recurse to process a/b below + } + else if (key.SequenceEqual("a"u8) && reader.IsScalar) + { + value.A = reader.ReadDecimal(); + } + else if (key.SequenceEqual("b"u8) && reader.IsScalar) + { + value.B = reader.ReadDecimal(); + } + else + { + return false; // not recognized + } + return true; // recognized + } + } + + [Theory, Resp( + "|1\r\n+key-popularity\r\n%2\r\n$1\r\na\r\n,0.1923\r\n$1\r\nb\r\n,0.0012\r\n*2\r\n:2039123\r\n:9543892\r\n", + "|1\r\n+key-popularity\r\n%2\r\n$1\r\na\r\n,0.1923\r\n$1\r\nb\r\n,0.0012\r\n*?\r\n:2039123\r\n:9543892\r\n.\r\n")] + public void AttributeRoot(RespPayload payload) + { + // ignore the attribute data + var reader = payload.Reader(); + reader.MoveNext(RespPrefix.Array); + Assert.Equal(2, reader.AggregateLength()); + var iterator = reader.AggregateChildren(); + + Assert.True(iterator.MoveNext(RespPrefix.Integer)); + Assert.Equal(2039123, iterator.Value.ReadInt32()); + iterator.Value.DemandEnd(); + + Assert.True(iterator.MoveNext(RespPrefix.Integer)); + Assert.Equal(9543892, iterator.Value.ReadInt32()); + iterator.Value.DemandEnd(); + + Assert.False(iterator.MoveNext()); + iterator.MovePast(out reader); + reader.DemandEnd(); + + // process the attribute data + var state = TestAttributeReader.Zero; + reader = payload.Reader(); + reader.MoveNext(RespPrefix.Array, TestAttributeReader.Instance, ref state); + Assert.Equal(1, state.Count); + Assert.Equal(0.1923M, state.A); + Assert.Equal(0.0012M, state.B); + state = TestAttributeReader.Zero; + + Assert.Equal(2, reader.AggregateLength()); + iterator = reader.AggregateChildren(); + + Assert.True(iterator.MoveNext(RespPrefix.Integer, TestAttributeReader.Instance, ref state)); + Assert.Equal(2039123, iterator.Value.ReadInt32()); + Assert.Equal(0, state.Count); + iterator.Value.DemandEnd(); + + Assert.True(iterator.MoveNext(RespPrefix.Integer, TestAttributeReader.Instance, ref state)); + Assert.Equal(9543892, iterator.Value.ReadInt32()); + Assert.Equal(0, state.Count); + iterator.Value.DemandEnd(); + + Assert.False(iterator.MoveNext()); + iterator.MovePast(out reader); + reader.DemandEnd(); + } + + [Theory, Resp("*3\r\n:1\r\n:2\r\n|1\r\n+ttl\r\n:3600\r\n:3\r\n", "*?\r\n:1\r\n:2\r\n|1\r\n+ttl\r\n:3600\r\n:3\r\n.\r\n")] + public void AttributeInner(RespPayload payload) + { + // ignore the attribute data + var reader = payload.Reader(); + reader.MoveNext(RespPrefix.Array); + Assert.Equal(3, reader.AggregateLength()); + var iterator = reader.AggregateChildren(); + + Assert.True(iterator.MoveNext(RespPrefix.Integer)); + Assert.Equal(1, iterator.Value.ReadInt32()); + iterator.Value.DemandEnd(); + + Assert.True(iterator.MoveNext(RespPrefix.Integer)); + Assert.Equal(2, iterator.Value.ReadInt32()); + iterator.Value.DemandEnd(); + + Assert.True(iterator.MoveNext(RespPrefix.Integer)); + Assert.Equal(3, iterator.Value.ReadInt32()); + iterator.Value.DemandEnd(); + + Assert.False(iterator.MoveNext()); + iterator.MovePast(out reader); + reader.DemandEnd(); + + // process the attribute data + var state = TestAttributeReader.Zero; + reader = payload.Reader(); + reader.MoveNext(RespPrefix.Array, TestAttributeReader.Instance, ref state); + Assert.Equal(0, state.Count); + Assert.Equal(3, reader.AggregateLength()); + iterator = reader.AggregateChildren(); + + Assert.True(iterator.MoveNext(RespPrefix.Integer, TestAttributeReader.Instance, ref state)); + Assert.Equal(0, state.Count); + Assert.Equal(1, iterator.Value.ReadInt32()); + iterator.Value.DemandEnd(); + + Assert.True(iterator.MoveNext(RespPrefix.Integer, TestAttributeReader.Instance, ref state)); + Assert.Equal(0, state.Count); + Assert.Equal(2, iterator.Value.ReadInt32()); + iterator.Value.DemandEnd(); + + Assert.True(iterator.MoveNext(RespPrefix.Integer, TestAttributeReader.Instance, ref state)); + Assert.Equal(1, state.Count); + Assert.Equal(3600, state.Ttl); + state = TestAttributeReader.Zero; // reset + Assert.Equal(3, iterator.Value.ReadInt32()); + iterator.Value.DemandEnd(); + + Assert.False(iterator.MoveNext(TestAttributeReader.Instance, ref state)); + Assert.Equal(0, state.Count); + iterator.MovePast(out reader); + reader.DemandEnd(); + } + + [Theory, Resp(">3\r\n+message\r\n+somechannel\r\n+this is the message\r\n", OutOfBand = true)] + public void Push(RespPayload payload) + { + // ignore the attribute data + var reader = payload.Reader(); + reader.MoveNext(RespPrefix.Push); + Assert.Equal(3, reader.AggregateLength()); + var iterator = reader.AggregateChildren(); + + Assert.True(iterator.MoveNext(RespPrefix.SimpleString)); + Assert.True(iterator.Value.Is("message"u8)); + iterator.Value.DemandEnd(); + + Assert.True(iterator.MoveNext(RespPrefix.SimpleString)); + Assert.True(iterator.Value.Is("somechannel"u8)); + iterator.Value.DemandEnd(); + + Assert.True(iterator.MoveNext(RespPrefix.SimpleString)); + Assert.True(iterator.Value.Is("this is the message"u8)); + iterator.Value.DemandEnd(); + + Assert.False(iterator.MoveNext()); + iterator.MovePast(out reader); + reader.DemandEnd(); + } + + [Theory, Resp(">3\r\n+message\r\n+somechannel\r\n+this is the message\r\n$9\r\nGet-Reply\r\n", Count = 2)] + public void PushThenGetReply(RespPayload payload) + { + // ignore the attribute data + var reader = payload.Reader(); + reader.MoveNext(RespPrefix.Push); + Assert.Equal(3, reader.AggregateLength()); + var iterator = reader.AggregateChildren(); + + Assert.True(iterator.MoveNext(RespPrefix.SimpleString)); + Assert.True(iterator.Value.Is("message"u8)); + iterator.Value.DemandEnd(); + + Assert.True(iterator.MoveNext(RespPrefix.SimpleString)); + Assert.True(iterator.Value.Is("somechannel"u8)); + iterator.Value.DemandEnd(); + + Assert.True(iterator.MoveNext(RespPrefix.SimpleString)); + Assert.True(iterator.Value.Is("this is the message"u8)); + iterator.Value.DemandEnd(); + + Assert.False(iterator.MoveNext()); + iterator.MovePast(out reader); + + reader.MoveNext(RespPrefix.BulkString); + Assert.True(reader.Is("Get-Reply"u8)); + reader.DemandEnd(); + } + + [Theory, Resp("$9\r\nGet-Reply\r\n>3\r\n+message\r\n+somechannel\r\n+this is the message\r\n", Count = 2)] + public void GetReplyThenPush(RespPayload payload) + { + // ignore the attribute data + var reader = payload.Reader(); + + reader.MoveNext(RespPrefix.BulkString); + Assert.True(reader.Is("Get-Reply"u8)); + + reader.MoveNext(RespPrefix.Push); + Assert.Equal(3, reader.AggregateLength()); + var iterator = reader.AggregateChildren(); + + Assert.True(iterator.MoveNext(RespPrefix.SimpleString)); + Assert.True(iterator.Value.Is("message"u8)); + iterator.Value.DemandEnd(); + + Assert.True(iterator.MoveNext(RespPrefix.SimpleString)); + Assert.True(iterator.Value.Is("somechannel"u8)); + iterator.Value.DemandEnd(); + + Assert.True(iterator.MoveNext(RespPrefix.SimpleString)); + Assert.True(iterator.Value.Is("this is the message"u8)); + iterator.Value.DemandEnd(); + + Assert.False(iterator.MoveNext()); + iterator.MovePast(out reader); + + reader.DemandEnd(); + } + + [Theory, Resp("*0\r\n$4\r\npass\r\n", "*1\r\n+ok\r\n$4\r\npass\r\n", "*-1\r\n$4\r\npass\r\n", "*?\r\n.\r\n$4\r\npass\r\n", Count = 2)] + public void ArrayThenString(RespPayload payload) + { + var reader = payload.Reader(); + Assert.True(reader.TryMoveNext(RespPrefix.Array)); + reader.SkipChildren(); + + Assert.True(reader.TryMoveNext(RespPrefix.BulkString)); + Assert.True(reader.Is("pass"u8)); + + reader.DemandEnd(); + + // and the same using child iterator + reader = payload.Reader(); + Assert.True(reader.TryMoveNext(RespPrefix.Array)); + var iterator = reader.AggregateChildren(); + iterator.MovePast(out reader); + + Assert.True(reader.TryMoveNext(RespPrefix.BulkString)); + Assert.True(reader.Is("pass"u8)); + + reader.DemandEnd(); + } + + private sealed class Segment : ReadOnlySequenceSegment + { + public override string ToString() => RespConstants.UTF8.GetString(Memory.Span) + .Replace("\r", "\\r").Replace("\n", "\\n"); + + public Segment(ReadOnlyMemory value, Segment? head) + { + Memory = value; + if (head is not null) + { + RunningIndex = head.RunningIndex + head.Memory.Length; + head.Next = this; + } + } + public bool IsEmpty => Memory.IsEmpty; + public int Length => Memory.Length; + } +} From a2720ffb25ff030ba46864a88659f08a43f0e908 Mon Sep 17 00:00:00 2001 From: Marc Gravell Date: Fri, 13 Feb 2026 17:07:12 +0000 Subject: [PATCH 02/11] WIP; migrate read code - main parse loop untested but complete-ish --- src/RESPite/Buffers/CycleBuffer.cs | 701 ++++++++++++++++ src/RESPite/Internal/DebugCounters.cs | 70 ++ .../Internal/RespOperationExtensions.cs | 58 ++ src/RESPite/Messages/RespAttributeReader.cs | 1 - src/RESPite/Messages/RespFrameScanner.cs | 1 - src/RESPite/Messages/RespPrefix.cs | 1 - src/RESPite/Messages/RespReader.cs | 3 +- src/RESPite/Messages/RespScanState.cs | 1 - src/RESPite/PublicAPI/PublicAPI.Unshipped.txt | 24 +- .../PublicAPI/net6.0/PublicAPI.Shipped.txt | 1 + .../PublicAPI/net6.0/PublicAPI.Unshipped.txt | 1 + src/RESPite/RESPite.csproj | 25 +- src/RESPite/RespException.cs | 1 - .../Shared}/Experiments.cs | 4 +- .../Shared/FrameworkShims.Encoding.cs} | 35 +- src/RESPite/Shared/FrameworkShims.Stream.cs | 107 +++ src/RESPite/Shared/FrameworkShims.cs | 15 + .../Shared}/NullableHacks.cs | 0 .../APITypes/StreamInfo.cs | 1 + .../ConfigurationOptions.cs | 9 +- src/StackExchange.Redis/ExceptionFactory.cs | 2 +- .../FrameworkShims.IsExternalInit.cs | 15 + src/StackExchange.Redis/HotKeys.cs | 1 + .../Interfaces/IDatabase.VectorSets.cs | 1 + .../Interfaces/IDatabase.cs | 1 + .../Interfaces/IDatabaseAsync.VectorSets.cs | 1 + .../Interfaces/IDatabaseAsync.cs | 1 + .../KeyPrefixed.VectorSets.cs | 1 + src/StackExchange.Redis/LoggerExtensions.cs | 2 +- src/StackExchange.Redis/Message.cs | 8 +- .../PhysicalConnection.Read.cs | 757 ++++++++++++++++++ src/StackExchange.Redis/PhysicalConnection.cs | 470 +---------- .../PublicAPI/net6.0/PublicAPI.Shipped.txt | 5 +- .../PublicAPI/net8.0/PublicAPI.Shipped.txt | 4 - .../netcoreapp3.1/PublicAPI.Shipped.txt | 2 - src/StackExchange.Redis/RedisSubscriber.cs | 18 - .../RespReaderExtensions.cs | 72 ++ src/StackExchange.Redis/ResultProcessor.cs | 8 +- .../StackExchange.Redis.csproj | 12 +- .../StreamConfiguration.cs | 1 + src/StackExchange.Redis/StreamIdempotentId.cs | 1 + src/StackExchange.Redis/ValueCondition.cs | 1 + .../VectorSetAddRequest.cs | 1 + src/StackExchange.Redis/VectorSetInfo.cs | 1 + src/StackExchange.Redis/VectorSetLink.cs | 1 + .../VectorSetQuantization.cs | 1 + .../VectorSetSimilaritySearchRequest.cs | 1 + .../VectorSetSimilaritySearchResult.cs | 1 + 48 files changed, 1879 insertions(+), 570 deletions(-) create mode 100644 src/RESPite/Buffers/CycleBuffer.cs create mode 100644 src/RESPite/Internal/DebugCounters.cs create mode 100644 src/RESPite/Internal/RespOperationExtensions.cs create mode 100644 src/RESPite/PublicAPI/net6.0/PublicAPI.Shipped.txt create mode 100644 src/RESPite/PublicAPI/net6.0/PublicAPI.Unshipped.txt rename src/{StackExchange.Redis => RESPite/Shared}/Experiments.cs (95%) rename src/{StackExchange.Redis/FrameworkShims.cs => RESPite/Shared/FrameworkShims.Encoding.cs} (57%) create mode 100644 src/RESPite/Shared/FrameworkShims.Stream.cs create mode 100644 src/RESPite/Shared/FrameworkShims.cs rename src/{StackExchange.Redis => RESPite/Shared}/NullableHacks.cs (100%) create mode 100644 src/StackExchange.Redis/FrameworkShims.IsExternalInit.cs create mode 100644 src/StackExchange.Redis/PhysicalConnection.Read.cs delete mode 100644 src/StackExchange.Redis/PublicAPI/net8.0/PublicAPI.Shipped.txt delete mode 100644 src/StackExchange.Redis/PublicAPI/netcoreapp3.1/PublicAPI.Shipped.txt create mode 100644 src/StackExchange.Redis/RespReaderExtensions.cs diff --git a/src/RESPite/Buffers/CycleBuffer.cs b/src/RESPite/Buffers/CycleBuffer.cs new file mode 100644 index 000000000..6ab982776 --- /dev/null +++ b/src/RESPite/Buffers/CycleBuffer.cs @@ -0,0 +1,701 @@ +using System.Buffers; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; +using RESPite.Internal; + +namespace RESPite.Buffers; + +/// +/// Manages the state for a based IO buffer. Unlike Pipe, +/// it is not intended for a separate producer-consumer - there is no thread-safety, and no +/// activation; it just handles the buffers. It is intended to be used as a mutable (non-readonly) +/// field in a type that performs IO; the internal state mutates - it should not be passed around. +/// +/// Notionally, there is an uncommitted area (write) and a committed area (read). Process: +/// - producer loop (*note no concurrency**) +/// - call to get a new scratch +/// - (write to that span) +/// - call to mark complete portions +/// - consumer loop (*note no concurrency**) +/// - call to see if there is a single-span chunk; otherwise +/// - call to get the multi-span chunk +/// - (process none, some, or all of that data) +/// - call to indicate how much data is no longer needed +/// Emphasis: no concurrency! This is intended for a single worker acting as both producer and consumer. +/// +/// There is a *lot* of validation in debug mode; we want to be super sure that we don't corrupt buffer state. +/// +public partial struct CycleBuffer +{ + // note: if someone uses an uninitialized CycleBuffer (via default): that's a skills issue; git gud + public static CycleBuffer Create(MemoryPool? pool = null, int pageSize = DefaultPageSize) + { + pool ??= MemoryPool.Shared; + if (pageSize <= 0) pageSize = DefaultPageSize; + if (pageSize > pool.MaxBufferSize) pageSize = pool.MaxBufferSize; + + return new CycleBuffer(pool, pageSize); + } + + private CycleBuffer(MemoryPool pool, int pageSize) + { + Pool = pool; + PageSize = pageSize; + } + + private const int DefaultPageSize = 8 * 1024; + + public int PageSize { get; } + public MemoryPool Pool { get; } + + private Segment? startSegment, endSegment; + + private int endSegmentCommitted, endSegmentLength; + + public bool TryGetCommitted(out ReadOnlySpan span) + { + DebugAssertValid(); + if (!ReferenceEquals(startSegment, endSegment)) + { + span = default; + return false; + } + + span = startSegment is null ? default : startSegment.Memory.Span.Slice(start: 0, length: endSegmentCommitted); + return true; + } + + /// + /// Commits data written to buffers from , making it available for consumption + /// via . This compares to . + /// + public void Commit(int count) + { + DebugAssertValid(); + if (count <= 0) + { + if (count < 0) Throw(); + return; + } + + var available = endSegmentLength - endSegmentCommitted; + if (count > available) Throw(); + endSegmentCommitted += count; + DebugAssertValid(); + + static void Throw() => throw new ArgumentOutOfRangeException(nameof(count)); + } + + public bool CommittedIsEmpty => ReferenceEquals(startSegment, endSegment) & endSegmentCommitted == 0; + + /// + /// Marks committed data as fully consumed; it will no longer appear in later calls to . + /// + public void DiscardCommitted(int count) + { + DebugAssertValid(); + // optimize for most common case, where we consume everything + if (ReferenceEquals(startSegment, endSegment) + & count == endSegmentCommitted + & count > 0) + { + /* + we are consuming all the data in the single segment; we can + just reset that segment back to full size and re-use as-is; + note that we also know that there must *be* a segment + for the count check to pass + */ + endSegmentCommitted = 0; + endSegmentLength = endSegment!.Untrim(expandBackwards: true); + DebugAssertValid(0); + DebugCounters.OnDiscardFull(count); + } + else if (count == 0) + { + // nothing to do + } + else + { + DiscardCommittedSlow(count); + } + } + + public void DiscardCommitted(long count) + { + DebugAssertValid(); + // optimize for most common case, where we consume everything + if (ReferenceEquals(startSegment, endSegment) + & count == endSegmentCommitted + & count > 0) // checks sign *and* non-trimmed + { + // see for logic + endSegmentCommitted = 0; + endSegmentLength = endSegment!.Untrim(expandBackwards: true); + DebugAssertValid(0); + DebugCounters.OnDiscardFull(count); + } + else if (count == 0) + { + // nothing to do + } + else + { + DiscardCommittedSlow(count); + } + } + + private void DiscardCommittedSlow(long count) + { + DebugCounters.OnDiscardPartial(count); +#if DEBUG + var originalLength = GetCommittedLength(); + var originalCount = count; + var expectedLength = originalLength - originalCount; + string blame = nameof(DiscardCommittedSlow); +#endif + while (count > 0) + { + DebugAssertValid(); + var segment = startSegment; + if (segment is null) break; + if (ReferenceEquals(segment, endSegment)) + { + // first==final==only segment + if (count == endSegmentCommitted) + { + endSegmentLength = startSegment!.Untrim(); + endSegmentCommitted = 0; // = untrimmed and unused +#if DEBUG + blame += ",full-final (t)"; +#endif + } + else + { + // discard from the start + int count32 = checked((int)count); + segment.TrimStart(count32); + endSegmentLength -= count32; + endSegmentCommitted -= count32; +#if DEBUG + blame += ",partial-final"; +#endif + } + + count = 0; + break; + } + else if (count < segment.Length) + { + // multiple, but can take some (not all) of the first buffer +#if DEBUG + var len = segment.Length; +#endif + segment.TrimStart((int)count); + Debug.Assert(segment.Length > 0, "parial trim should have left non-empty segment"); +#if DEBUG + Debug.Assert(segment.Length == len - count, "trim failure"); + blame += ",partial-first"; +#endif + count = 0; + break; + } + else + { + // multiple; discard the entire first segment + count -= segment.Length; + startSegment = + segment.ResetAndGetNext(); // we already did a ref-check, so we know this isn't going past endSegment + endSegment!.AppendOrRecycle(segment, maxDepth: 2); + DebugAssertValid(); +#if DEBUG + blame += ",full-first"; +#endif + } + } + + if (count != 0) ThrowCount(); +#if DEBUG + DebugAssertValid(expectedLength, blame); + _ = originalLength; + _ = originalCount; +#endif + + [DoesNotReturn] + static void ThrowCount() => throw new ArgumentOutOfRangeException(nameof(count)); + } + + [Conditional("DEBUG")] + private void DebugAssertValid(long expectedCommittedLength, [CallerMemberName] string caller = "") + { + DebugAssertValid(); + var actual = GetCommittedLength(); + Debug.Assert( + expectedCommittedLength >= 0, + $"Expected committed length is just... wrong: {expectedCommittedLength} (from {caller})"); + Debug.Assert( + expectedCommittedLength == actual, + $"Committed length mismatch: expected {expectedCommittedLength}, got {actual} (from {caller})"); + } + + [Conditional("DEBUG")] + private void DebugAssertValid() + { + if (startSegment is null) + { + Debug.Assert( + endSegmentLength == 0 & endSegmentCommitted == 0, + "un-init state should be zero"); + return; + } + + Debug.Assert(endSegment is not null, "end segment must not be null if start segment exists"); + Debug.Assert( + endSegmentLength == endSegment!.Length, + $"end segment length is incorrect - expected {endSegmentLength}, got {endSegment.Length}"); + Debug.Assert(endSegmentCommitted <= endSegmentLength, $"end segment is over-committed - {endSegmentCommitted} of {endSegmentLength}"); + + // check running indices + startSegment?.DebugAssertValidChain(); + } + + public long GetCommittedLength() + { + DebugAssertValid(); + if (ReferenceEquals(startSegment, endSegment)) + { + return endSegmentCommitted; + } + + // note that the start-segment is pre-trimmed; we don't need to account for an offset on the left + return (endSegment!.RunningIndex + endSegmentCommitted) - startSegment!.RunningIndex; + } + + /// + /// When used with , this means "any non-empty buffer". + /// + public const int GetAnything = 0; + + /// + /// When used with , this means "any full buffer". + /// + public const int GetFullPagesOnly = -1; + + public bool TryGetFirstCommittedSpan(int minBytes, out ReadOnlySpan span) + { + DebugAssertValid(); + if (TryGetFirstCommittedMemory(minBytes, out var memory)) + { + span = memory.Span; + return true; + } + + span = default; + return false; + } + + /// + /// The minLength arg: -ve means "full segments only" (useful when buffering outbound network data to avoid + /// packet fragmentation); otherwise, it is the minimum length we want. + /// + public bool TryGetFirstCommittedMemory(int minBytes, out ReadOnlyMemory memory) + { + if (minBytes == 0) minBytes = 1; // success always means "at least something" + DebugAssertValid(); + if (ReferenceEquals(startSegment, endSegment)) + { + // single page + var available = endSegmentCommitted; + if (available == 0) + { + // empty (includes uninitialized) + memory = default; + return false; + } + + memory = startSegment!.Memory; + var memLength = memory.Length; + if (available == memLength) + { + // full segment; is it enough to make the caller happy? + return available >= minBytes; + } + + // partial segment (and we know it isn't empty) + memory = memory.Slice(start: 0, length: available); + return available >= minBytes & minBytes > 0; // last check here applies the -ve logic + } + + // multi-page; hand out the first page (which is, by definition: full) + memory = startSegment!.Memory; + return memory.Length >= minBytes; + } + + /// + /// Note that this chain is invalidated by any other operations; no concurrency. + /// + public ReadOnlySequence GetAllCommitted() + { + if (ReferenceEquals(startSegment, endSegment)) + { + // single segment, fine + return startSegment is null + ? default + : new ReadOnlySequence(startSegment.Memory.Slice(start: 0, length: endSegmentCommitted)); + } + +#if PARSE_DETAIL + long length = GetCommittedLength(); +#endif + ReadOnlySequence ros = new(startSegment!, 0, endSegment!, endSegmentCommitted); +#if PARSE_DETAIL + Debug.Assert(ros.Length == length, $"length mismatch: calculated {length}, actual {ros.Length}"); +#endif + return ros; + } + + private Segment GetNextSegment() + { + DebugAssertValid(); + if (endSegment is not null) + { + endSegment.TrimEnd(endSegmentCommitted); + Debug.Assert(endSegment.Length == endSegmentCommitted, "trim failure"); + endSegmentLength = endSegmentCommitted; + DebugAssertValid(); + + var spare = endSegment.Next; + if (spare is not null) + { + // we already have a dangling segment; just update state + endSegment.DebugAssertValidChain(); + endSegment = spare; + endSegmentCommitted = 0; + endSegmentLength = spare.Length; + DebugAssertValid(); + return spare; + } + } + + Segment newSegment = Segment.Create(Pool.Rent(PageSize)); + if (endSegment is null) + { + // tabula rasa + endSegmentLength = newSegment.Length; + endSegment = startSegment = newSegment; + DebugAssertValid(); + return newSegment; + } + + endSegment.Append(newSegment); + endSegmentCommitted = 0; + endSegmentLength = newSegment.Length; + endSegment = newSegment; + DebugAssertValid(); + return newSegment; + } + + /// + /// Gets a scratch area for new data; this compares to . + /// + public Span GetUncommittedSpan(int hint = 0) + => GetUncommittedMemory(hint).Span; + + /// + /// Gets a scratch area for new data; this compares to . + /// + public Memory GetUncommittedMemory(int hint = 0) + { + DebugAssertValid(); + var segment = endSegment; + if (segment is not null) + { + var memory = segment.Memory; + if (endSegmentCommitted != 0) memory = memory.Slice(start: endSegmentCommitted); + if (hint <= 0) // allow anything non-empty + { + if (!memory.IsEmpty) return MemoryMarshal.AsMemory(memory); + } + else if (memory.Length >= Math.Min(hint, PageSize >> 2)) // respect the hint up to 1/4 of the page size + { + return MemoryMarshal.AsMemory(memory); + } + } + + // new segment, will always be entire + return MemoryMarshal.AsMemory(GetNextSegment().Memory); + } + + public int UncommittedAvailable + { + get + { + DebugAssertValid(); + return endSegmentLength - endSegmentCommitted; + } + } + + private sealed class Segment : ReadOnlySequenceSegment + { + private Segment() { } + private IMemoryOwner _lease = NullLease.Instance; + private static Segment? _spare; + private Flags _flags; + + [Flags] + private enum Flags + { + None = 0, + StartTrim = 1 << 0, + EndTrim = 1 << 2, + } + + public static Segment Create(IMemoryOwner lease) + { + Debug.Assert(lease is not null, "null lease"); + var memory = lease!.Memory; + if (memory.IsEmpty) ThrowEmpty(); + + var obj = Interlocked.Exchange(ref _spare, null) ?? new(); + return obj.Init(lease, memory); + static void ThrowEmpty() => throw new InvalidOperationException("leased segment is empty"); + } + + private Segment Init(IMemoryOwner lease, Memory memory) + { + _lease = lease; + Memory = memory; + return this; + } + + public int Length => Memory.Length; + + public void Append(Segment next) + { + Debug.Assert(Next is null, "current segment already has a next"); + Debug.Assert(next.Next is null && next.RunningIndex == 0, "inbound next segment is already in a chain"); + next.RunningIndex = RunningIndex + Length; + Next = next; + DebugAssertValidChain(); + } + + private void ApplyChainDelta(int delta) + { + if (delta != 0) + { + var node = Next; + while (node is not null) + { + node.RunningIndex += delta; + node = node.Next; + } + } + } + + public void TrimEnd(int newLength) + { + var delta = Length - newLength; + if (delta != 0) + { + // buffer wasn't fully used; trim + _flags |= Flags.EndTrim; + Memory = Memory.Slice(0, newLength); + ApplyChainDelta(-delta); + DebugAssertValidChain(); + } + } + + public void TrimStart(int remove) + { + if (remove != 0) + { + _flags |= Flags.StartTrim; + Memory = Memory.Slice(start: remove); + RunningIndex += remove; // so that ROS length keeps working; note we *don't* need to adjust the chain + DebugAssertValidChain(); + } + } + + public new Segment? Next + { + get => (Segment?)base.Next; + private set => base.Next = value; + } + + public Segment? ResetAndGetNext() + { + var next = Next; + Next = null; + RunningIndex = 0; + _flags = Flags.None; + Memory = _lease.Memory; // reset, in case we trimmed it + DebugAssertValidChain(); + return next; + } + + public void Recycle() + { + var lease = _lease; + _lease = NullLease.Instance; + lease.Dispose(); + Next = null; + Memory = default; + RunningIndex = 0; + _flags = Flags.None; + Interlocked.Exchange(ref _spare, this); + DebugAssertValidChain(); + } + + private sealed class NullLease : IMemoryOwner + { + private NullLease() { } + public static readonly NullLease Instance = new NullLease(); + public void Dispose() { } + + public Memory Memory => default; + } + + /// + /// Undo any trimming, returning the new full capacity. + /// + public int Untrim(bool expandBackwards = false) + { + var fullMemory = _lease.Memory; + var fullLength = fullMemory.Length; + var delta = fullLength - Length; + if (delta != 0) + { + _flags &= ~(Flags.StartTrim | Flags.EndTrim); + Memory = fullMemory; + if (expandBackwards & RunningIndex >= delta) + { + // push our origin earlier; only valid if + // we're the first segment, otherwise + // we break someone-else's chain + RunningIndex -= delta; + } + else + { + // push everyone else later + ApplyChainDelta(delta); + } + + DebugAssertValidChain(); + } + return fullLength; + } + + public bool StartTrimmed => (_flags & Flags.StartTrim) != 0; + public bool EndTrimmed => (_flags & Flags.EndTrim) != 0; + + [Conditional("DEBUG")] + public void DebugAssertValidChain([CallerMemberName] string blame = "") + { + var node = this; + var runningIndex = RunningIndex; + int index = 0; + while (node.Next is { } next) + { + index++; + var nextRunningIndex = runningIndex + node.Length; + if (nextRunningIndex != next.RunningIndex) ThrowRunningIndex(blame, index); + node = next; + runningIndex = nextRunningIndex; + static void ThrowRunningIndex(string blame, int index) => throw new InvalidOperationException( + $"Critical running index corruption in dangling chain, from '{blame}', segment {index}"); + } + } + + public void AppendOrRecycle(Segment segment, int maxDepth) + { + segment.Memory.DebugScramble(); + var node = this; + while (maxDepth-- > 0 && node is not null) + { + if (node.Next is null) // found somewhere to attach it + { + if (segment.Untrim() == 0) break; // turned out to be useless + segment.RunningIndex = node.RunningIndex + node.Length; + node.Next = segment; + return; + } + + node = node.Next; + } + + segment.Recycle(); + } + } + + /// + /// Discard all data and buffers. + /// + public void Release() + { + var node = startSegment; + startSegment = endSegment = null; + endSegmentCommitted = endSegmentLength = 0; + while (node is not null) + { + var next = node.Next; + node.Recycle(); + node = next; + } + } + + /// + /// Writes a value to the buffer; comparable to . + /// + public void Write(ReadOnlySpan value) + { + int srcLength = value.Length; + while (srcLength != 0) + { + var target = GetUncommittedSpan(hint: srcLength); + var tgtLength = target.Length; + if (tgtLength >= srcLength) + { + value.CopyTo(target); + Commit(srcLength); + return; + } + + value.Slice(0, tgtLength).CopyTo(target); + Commit(tgtLength); + value = value.Slice(tgtLength); + srcLength -= tgtLength; + } + } + + /// + /// Writes a value to the buffer; comparable to . + /// + public void Write(in ReadOnlySequence value) + { + if (value.IsSingleSegment) + { +#if NETCOREAPP3_0_OR_GREATER || NETSTANDARD2_1 + Write(value.FirstSpan); +#else + Write(value.First.Span); +#endif + } + else + { + WriteMultiSegment(ref this, in value); + } + + static void WriteMultiSegment(ref CycleBuffer @this, in ReadOnlySequence value) + { + foreach (var segment in value) + { +#if NETCOREAPP3_0_OR_GREATER || NETSTANDARD2_1 + @this.Write(value.FirstSpan); +#else + @this.Write(value.First.Span); +#endif + } + } + } +} diff --git a/src/RESPite/Internal/DebugCounters.cs b/src/RESPite/Internal/DebugCounters.cs new file mode 100644 index 000000000..6b0d0866d --- /dev/null +++ b/src/RESPite/Internal/DebugCounters.cs @@ -0,0 +1,70 @@ +using System.Diagnostics; + +namespace RESPite.Internal; + +internal partial class DebugCounters +{ +#if DEBUG + private static int + _tallyAsyncReadCount, + _tallyAsyncReadInlineCount, + _tallyDiscardFullCount, + _tallyDiscardPartialCount; + + private static long + _tallyReadBytes, + _tallyDiscardAverage; +#endif + + [Conditional("DEBUG")] + public static void OnDiscardFull(long count) + { +#if DEBUG + if (count > 0) + { + Interlocked.Increment(ref _tallyDiscardFullCount); + EstimatedMovingRangeAverage(ref _tallyDiscardAverage, count); + } +#endif + } + + [Conditional("DEBUG")] + public static void OnDiscardPartial(long count) + { +#if DEBUG + if (count > 0) + { + Interlocked.Increment(ref _tallyDiscardPartialCount); + EstimatedMovingRangeAverage(ref _tallyDiscardAverage, count); + } +#endif + } + + [Conditional("DEBUG")] + internal static void OnAsyncRead(int bytes, bool inline) + { +#if DEBUG + Interlocked.Increment(ref inline ? ref _tallyAsyncReadInlineCount : ref _tallyAsyncReadCount); + if (bytes > 0) Interlocked.Add(ref _tallyReadBytes, bytes); +#endif + } + +#if DEBUG + private static void EstimatedMovingRangeAverage(ref long field, long value) + { + var oldValue = Volatile.Read(ref field); + var delta = (value - oldValue) >> 3; // is is a 7:1 old:new EMRA, using integer/bit math (alplha=0.125) + if (delta != 0) Interlocked.Add(ref field, delta); + // note: strictly conflicting concurrent calls can skew the value incorrectly; this is, however, + // preferable to getting into a CEX squabble or requiring a lock - it is debug-only and just useful data + } + + public int AsyncReadCount { get; } = Interlocked.Exchange(ref _tallyAsyncReadCount, 0); + public int AsyncReadInlineCount { get; } = Interlocked.Exchange(ref _tallyAsyncReadInlineCount, 0); + public long ReadBytes { get; } = Interlocked.Exchange(ref _tallyReadBytes, 0); + + public long DiscardAverage { get; } = Interlocked.Exchange(ref _tallyDiscardAverage, 32); + public int DiscardFullCount { get; } = Interlocked.Exchange(ref _tallyDiscardFullCount, 0); + public int DiscardPartialCount { get; } = Interlocked.Exchange(ref _tallyDiscardPartialCount, 0); +#endif +} diff --git a/src/RESPite/Internal/RespOperationExtensions.cs b/src/RESPite/Internal/RespOperationExtensions.cs new file mode 100644 index 000000000..0aedccc69 --- /dev/null +++ b/src/RESPite/Internal/RespOperationExtensions.cs @@ -0,0 +1,58 @@ +using System.Buffers; +using System.Diagnostics; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; + +namespace RESPite.Internal; + +internal static class RespOperationExtensions +{ +#if PREVIEW_LANGVER + extension(in RespOperation operation) + { + // since this is valid... + public ref readonly RespOperation Self => ref operation; + + // so is this (the types are layout-identical) + public ref readonly RespOperation Untyped => ref Unsafe.As, RespOperation>( + ref Unsafe.AsRef(in operation)); + } +#endif + + // if we're recycling a buffer, we need to consider it trashable by other threads; for + // debug purposes, force this by overwriting with *****, aka the meaning of life + [Conditional("DEBUG")] + internal static void DebugScramble(this Span value) + => value.Fill(42); + + [Conditional("DEBUG")] + internal static void DebugScramble(this Memory value) + => value.Span.Fill(42); + + [Conditional("DEBUG")] + internal static void DebugScramble(this ReadOnlyMemory value) + => MemoryMarshal.AsMemory(value).Span.Fill(42); + + [Conditional("DEBUG")] + internal static void DebugScramble(this ReadOnlySequence value) + { + if (value.IsSingleSegment) + { + value.First.DebugScramble(); + } + else + { + foreach (var segment in value) + { + segment.DebugScramble(); + } + } + } + + [Conditional("DEBUG")] + internal static void DebugScramble(this byte[]? value) + { + if (value is not null) + value.AsSpan().Fill(42); + } +} diff --git a/src/RESPite/Messages/RespAttributeReader.cs b/src/RESPite/Messages/RespAttributeReader.cs index 46fd26a19..bfeaede79 100644 --- a/src/RESPite/Messages/RespAttributeReader.cs +++ b/src/RESPite/Messages/RespAttributeReader.cs @@ -1,5 +1,4 @@ using System.Diagnostics.CodeAnalysis; -using StackExchange.Redis; namespace RESPite.Messages; diff --git a/src/RESPite/Messages/RespFrameScanner.cs b/src/RESPite/Messages/RespFrameScanner.cs index 5034e994a..da4f9ca63 100644 --- a/src/RESPite/Messages/RespFrameScanner.cs +++ b/src/RESPite/Messages/RespFrameScanner.cs @@ -1,6 +1,5 @@ using System.Buffers; using System.Diagnostics.CodeAnalysis; -using StackExchange.Redis; using static RESPite.Internal.RespConstants; namespace RESPite.Messages; diff --git a/src/RESPite/Messages/RespPrefix.cs b/src/RESPite/Messages/RespPrefix.cs index 828c01d88..d58749120 100644 --- a/src/RESPite/Messages/RespPrefix.cs +++ b/src/RESPite/Messages/RespPrefix.cs @@ -1,5 +1,4 @@ using System.Diagnostics.CodeAnalysis; -using StackExchange.Redis; namespace RESPite.Messages; diff --git a/src/RESPite/Messages/RespReader.cs b/src/RESPite/Messages/RespReader.cs index 56e4ddefa..a44ef520a 100644 --- a/src/RESPite/Messages/RespReader.cs +++ b/src/RESPite/Messages/RespReader.cs @@ -7,7 +7,6 @@ using System.Runtime.CompilerServices; using System.Text; using RESPite.Internal; -using StackExchange.Redis; #if NETCOREAPP3_0_OR_GREATER using System.Runtime.Intrinsics; @@ -1384,7 +1383,7 @@ private readonly bool IsSlow(ReadOnlySpan testValue) /// /// The destination for the copy operation. /// The number of bytes successfully copied. - public readonly int CopyTo(Span target) + public readonly int CopyTo(scoped Span target) { if (TryGetSpan(out var value)) { diff --git a/src/RESPite/Messages/RespScanState.cs b/src/RESPite/Messages/RespScanState.cs index d7038c30d..0b8c99de2 100644 --- a/src/RESPite/Messages/RespScanState.cs +++ b/src/RESPite/Messages/RespScanState.cs @@ -1,7 +1,6 @@ using System.Buffers; using System.Diagnostics; using System.Diagnostics.CodeAnalysis; -using StackExchange.Redis; namespace RESPite.Messages; diff --git a/src/RESPite/PublicAPI/PublicAPI.Unshipped.txt b/src/RESPite/PublicAPI/PublicAPI.Unshipped.txt index d45262688..95b06c251 100644 --- a/src/RESPite/PublicAPI/PublicAPI.Unshipped.txt +++ b/src/RESPite/PublicAPI/PublicAPI.Unshipped.txt @@ -1,4 +1,26 @@ #nullable enable +const RESPite.Buffers.CycleBuffer.GetAnything = 0 -> int +const RESPite.Buffers.CycleBuffer.GetFullPagesOnly = -1 -> int +RESPite.Buffers.CycleBuffer +RESPite.Buffers.CycleBuffer.Commit(int count) -> void +RESPite.Buffers.CycleBuffer.CommittedIsEmpty.get -> bool +RESPite.Buffers.CycleBuffer.CycleBuffer() -> void +RESPite.Buffers.CycleBuffer.DiscardCommitted(int count) -> void +RESPite.Buffers.CycleBuffer.DiscardCommitted(long count) -> void +RESPite.Buffers.CycleBuffer.GetAllCommitted() -> System.Buffers.ReadOnlySequence +RESPite.Buffers.CycleBuffer.GetCommittedLength() -> long +RESPite.Buffers.CycleBuffer.GetUncommittedMemory(int hint = 0) -> System.Memory +RESPite.Buffers.CycleBuffer.GetUncommittedSpan(int hint = 0) -> System.Span +RESPite.Buffers.CycleBuffer.PageSize.get -> int +RESPite.Buffers.CycleBuffer.Pool.get -> System.Buffers.MemoryPool! +RESPite.Buffers.CycleBuffer.Release() -> void +RESPite.Buffers.CycleBuffer.TryGetCommitted(out System.ReadOnlySpan span) -> bool +RESPite.Buffers.CycleBuffer.TryGetFirstCommittedMemory(int minBytes, out System.ReadOnlyMemory memory) -> bool +RESPite.Buffers.CycleBuffer.TryGetFirstCommittedSpan(int minBytes, out System.ReadOnlySpan span) -> bool +RESPite.Buffers.CycleBuffer.UncommittedAvailable.get -> int +RESPite.Buffers.CycleBuffer.Write(in System.Buffers.ReadOnlySequence value) -> void +RESPite.Buffers.CycleBuffer.Write(System.ReadOnlySpan value) -> void +static RESPite.Buffers.CycleBuffer.Create(System.Buffers.MemoryPool? pool = null, int pageSize = 8192) -> RESPite.Buffers.CycleBuffer [SER004]const RESPite.Messages.RespScanState.MinBytes = 3 -> int [SER004]override RESPite.Messages.RespScanState.Equals(object? obj) -> bool [SER004]override RESPite.Messages.RespScanState.GetHashCode() -> int @@ -47,8 +69,8 @@ [SER004]RESPite.Messages.RespReader.AggregateEnumerator.Value -> RESPite.Messages.RespReader [SER004]RESPite.Messages.RespReader.AggregateLength() -> int [SER004]RESPite.Messages.RespReader.BytesConsumed.get -> long +[SER004]RESPite.Messages.RespReader.CopyTo(scoped System.Span target) -> int [SER004]RESPite.Messages.RespReader.CopyTo(System.Buffers.IBufferWriter! target) -> int -[SER004]RESPite.Messages.RespReader.CopyTo(System.Span target) -> int [SER004]RESPite.Messages.RespReader.DemandAggregate() -> void [SER004]RESPite.Messages.RespReader.DemandEnd() -> void [SER004]RESPite.Messages.RespReader.DemandNotNull() -> void diff --git a/src/RESPite/PublicAPI/net6.0/PublicAPI.Shipped.txt b/src/RESPite/PublicAPI/net6.0/PublicAPI.Shipped.txt new file mode 100644 index 000000000..ab058de62 --- /dev/null +++ b/src/RESPite/PublicAPI/net6.0/PublicAPI.Shipped.txt @@ -0,0 +1 @@ +#nullable enable diff --git a/src/RESPite/PublicAPI/net6.0/PublicAPI.Unshipped.txt b/src/RESPite/PublicAPI/net6.0/PublicAPI.Unshipped.txt new file mode 100644 index 000000000..91b0e1a43 --- /dev/null +++ b/src/RESPite/PublicAPI/net6.0/PublicAPI.Unshipped.txt @@ -0,0 +1 @@ +#nullable enable \ No newline at end of file diff --git a/src/RESPite/RESPite.csproj b/src/RESPite/RESPite.csproj index 4ad8a0634..abde624b2 100644 --- a/src/RESPite/RESPite.csproj +++ b/src/RESPite/RESPite.csproj @@ -8,7 +8,6 @@ false 2025 - $([System.DateTime]::Now.Year) Marc Gravell readme.md - $(DefineConstants);RESPITE @@ -25,27 +24,19 @@ - + RespReader.cs - - Shared/Experiments.cs - - - Shared/FrameworkShims.cs - - - Shared/NullableHacks.cs - - - Shared/SkipLocalsInit.cs - - - + + + - + + + + diff --git a/src/RESPite/RespException.cs b/src/RESPite/RespException.cs index a6cb0c66a..6b5fd7c72 100644 --- a/src/RESPite/RespException.cs +++ b/src/RESPite/RespException.cs @@ -1,5 +1,4 @@ using System.Diagnostics.CodeAnalysis; -using StackExchange.Redis; namespace RESPite; diff --git a/src/StackExchange.Redis/Experiments.cs b/src/RESPite/Shared/Experiments.cs similarity index 95% rename from src/StackExchange.Redis/Experiments.cs rename to src/RESPite/Shared/Experiments.cs index 1ec2b6f09..e416e5b4d 100644 --- a/src/StackExchange.Redis/Experiments.cs +++ b/src/RESPite/Shared/Experiments.cs @@ -1,6 +1,4 @@ -using System.Diagnostics.CodeAnalysis; - -namespace StackExchange.Redis +namespace RESPite { // example usage: // [Experimental(Experiments.SomeFeature, UrlFormat = Experiments.UrlFormat)] diff --git a/src/StackExchange.Redis/FrameworkShims.cs b/src/RESPite/Shared/FrameworkShims.Encoding.cs similarity index 57% rename from src/StackExchange.Redis/FrameworkShims.cs rename to src/RESPite/Shared/FrameworkShims.Encoding.cs index c1c1bcfe2..95de016c4 100644 --- a/src/StackExchange.Redis/FrameworkShims.cs +++ b/src/RESPite/Shared/FrameworkShims.Encoding.cs @@ -1,35 +1,5 @@ -#pragma warning disable SA1403 // single namespace - -#if RESPITE // add nothing -#elif NET5_0_OR_GREATER -// context: https://github.com/StackExchange/StackExchange.Redis/issues/2619 -[assembly: System.Runtime.CompilerServices.TypeForwardedTo(typeof(System.Runtime.CompilerServices.IsExternalInit))] -#else -// To support { get; init; } properties -using System.ComponentModel; -using System.Text; - -namespace System.Runtime.CompilerServices -{ - [EditorBrowsable(EditorBrowsableState.Never)] - internal static class IsExternalInit { } -} -#endif - -#if !NET9_0_OR_GREATER -namespace System.Runtime.CompilerServices -{ - // see https://learn.microsoft.com/dotnet/api/system.runtime.compilerservices.overloadresolutionpriorityattribute - [AttributeUsage(AttributeTargets.Constructor | AttributeTargets.Method | AttributeTargets.Property, Inherited = false)] - internal sealed class OverloadResolutionPriorityAttribute(int priority) : Attribute - { - public int Priority => priority; - } -} -#endif - #if !(NETCOREAPP || NETSTANDARD2_1_OR_GREATER) - +// ReSharper disable once CheckNamespace namespace System.Text { internal static class EncodingExtensions @@ -74,6 +44,3 @@ public static unsafe string GetString(this Encoding encoding, ReadOnlySpan } } #endif - - -#pragma warning restore SA1403 diff --git a/src/RESPite/Shared/FrameworkShims.Stream.cs b/src/RESPite/Shared/FrameworkShims.Stream.cs new file mode 100644 index 000000000..56823abc4 --- /dev/null +++ b/src/RESPite/Shared/FrameworkShims.Stream.cs @@ -0,0 +1,107 @@ +using System.Buffers; +using System.Runtime.InteropServices; + +#if !(NETCOREAPP || NETSTANDARD2_1_OR_GREATER) +// ReSharper disable once CheckNamespace +namespace System.IO +{ + internal static class StreamExtensions + { + public static void Write(this Stream stream, ReadOnlyMemory value) + { + if (MemoryMarshal.TryGetArray(value, out var segment)) + { + stream.Write(segment.Array!, segment.Offset, segment.Count); + } + else + { + var leased = ArrayPool.Shared.Rent(value.Length); + value.CopyTo(leased); + stream.Write(leased, 0, value.Length); + ArrayPool.Shared.Return(leased); // on success only + } + } + + public static int Read(this Stream stream, Memory value) + { + if (MemoryMarshal.TryGetArray(value, out var segment)) + { + return stream.Read(segment.Array!, segment.Offset, segment.Count); + } + else + { + var leased = ArrayPool.Shared.Rent(value.Length); + int bytes = stream.Read(leased, 0, value.Length); + if (bytes > 0) + { + leased.AsSpan(0, bytes).CopyTo(value.Span); + } + ArrayPool.Shared.Return(leased); // on success only + return bytes; + } + } + + public static ValueTask ReadAsync(this Stream stream, Memory value, CancellationToken cancellationToken) + { + if (MemoryMarshal.TryGetArray(value, out var segment)) + { + return new(stream.ReadAsync(segment.Array!, segment.Offset, segment.Count, cancellationToken)); + } + else + { + var leased = ArrayPool.Shared.Rent(value.Length); + var pending = stream.ReadAsync(leased, 0, value.Length, cancellationToken); + if (!pending.IsCompleted) + { + return Awaited(pending, value, leased); + } + + var bytes = pending.GetAwaiter().GetResult(); + if (bytes > 0) + { + leased.AsSpan(0, bytes).CopyTo(value.Span); + } + ArrayPool.Shared.Return(leased); // on success only + return new(bytes); + + static async ValueTask Awaited(Task pending, Memory value, byte[] leased) + { + var bytes = await pending.ConfigureAwait(false); + if (bytes > 0) + { + leased.AsSpan(0, bytes).CopyTo(value.Span); + } + ArrayPool.Shared.Return(leased); // on success only + return bytes; + } + } + } + + public static ValueTask WriteAsync(this Stream stream, ReadOnlyMemory value, CancellationToken cancellationToken) + { + if (MemoryMarshal.TryGetArray(value, out var segment)) + { + return new(stream.WriteAsync(segment.Array!, segment.Offset, segment.Count, cancellationToken)); + } + else + { + var leased = ArrayPool.Shared.Rent(value.Length); + value.CopyTo(leased); + var pending = stream.WriteAsync(leased, 0, value.Length, cancellationToken); + if (!pending.IsCompleted) + { + return Awaited(pending, leased); + } + pending.GetAwaiter().GetResult(); + ArrayPool.Shared.Return(leased); // on success only + return default; + } + static async ValueTask Awaited(Task pending, byte[] leased) + { + await pending.ConfigureAwait(false); + ArrayPool.Shared.Return(leased); // on success only + } + } + } +} +#endif diff --git a/src/RESPite/Shared/FrameworkShims.cs b/src/RESPite/Shared/FrameworkShims.cs new file mode 100644 index 000000000..0f7aa641c --- /dev/null +++ b/src/RESPite/Shared/FrameworkShims.cs @@ -0,0 +1,15 @@ +#pragma warning disable SA1403 // single namespace + +#if !NET9_0_OR_GREATER +namespace System.Runtime.CompilerServices +{ + // see https://learn.microsoft.com/dotnet/api/system.runtime.compilerservices.overloadresolutionpriorityattribute + [AttributeUsage(AttributeTargets.Constructor | AttributeTargets.Method | AttributeTargets.Property, Inherited = false)] + internal sealed class OverloadResolutionPriorityAttribute(int priority) : Attribute + { + public int Priority => priority; + } +} +#endif + +#pragma warning restore SA1403 diff --git a/src/StackExchange.Redis/NullableHacks.cs b/src/RESPite/Shared/NullableHacks.cs similarity index 100% rename from src/StackExchange.Redis/NullableHacks.cs rename to src/RESPite/Shared/NullableHacks.cs diff --git a/src/StackExchange.Redis/APITypes/StreamInfo.cs b/src/StackExchange.Redis/APITypes/StreamInfo.cs index e37df5add..1de0526ec 100644 --- a/src/StackExchange.Redis/APITypes/StreamInfo.cs +++ b/src/StackExchange.Redis/APITypes/StreamInfo.cs @@ -1,4 +1,5 @@ using System.Diagnostics.CodeAnalysis; +using RESPite; namespace StackExchange.Redis; diff --git a/src/StackExchange.Redis/ConfigurationOptions.cs b/src/StackExchange.Redis/ConfigurationOptions.cs index c0021f024..c226239ac 100644 --- a/src/StackExchange.Redis/ConfigurationOptions.cs +++ b/src/StackExchange.Redis/ConfigurationOptions.cs @@ -330,9 +330,9 @@ internal static LocalCertificateSelectionCallback CreatePemUserCertificateCallba { // PEM handshakes not universally supported and causes a runtime error about ephemeral certificates; to avoid, export as PFX using var pem = X509Certificate2.CreateFromPemFile(userCertificatePath, userKeyPath); -#pragma warning disable SYSLIB0057 // Type or member is obsolete +#pragma warning disable SYSLIB0057 // X509 loading var pfx = new X509Certificate2(pem.Export(X509ContentType.Pfx)); -#pragma warning restore SYSLIB0057 // Type or member is obsolete +#pragma warning restore SYSLIB0057 // X509 loading return (sender, targetHost, localCertificates, remoteCertificate, acceptableIssuers) => pfx; } @@ -340,7 +340,9 @@ internal static LocalCertificateSelectionCallback CreatePemUserCertificateCallba internal static LocalCertificateSelectionCallback CreatePfxUserCertificateCallback(string userCertificatePath, string? password, X509KeyStorageFlags storageFlags = X509KeyStorageFlags.DefaultKeySet) { +#pragma warning disable SYSLIB0057 // X509 loading var pfx = new X509Certificate2(userCertificatePath, password ?? "", storageFlags); +#pragma warning restore SYSLIB0057 // X509 loading return (sender, targetHost, localCertificates, remoteCertificate, acceptableIssuers) => pfx; } @@ -350,8 +352,11 @@ internal static LocalCertificateSelectionCallback CreatePfxUserCertificateCallba /// The issuer to trust. public void TrustIssuer(X509Certificate2 issuer) => CertificateValidationCallback = TrustIssuerCallback(issuer); +#pragma warning disable SYSLIB0057 // X509 loading internal static RemoteCertificateValidationCallback TrustIssuerCallback(string issuerCertificatePath) => TrustIssuerCallback(new X509Certificate2(issuerCertificatePath)); +#pragma warning restore SYSLIB0057 // X509 loading + private static RemoteCertificateValidationCallback TrustIssuerCallback(X509Certificate2 issuer) { if (issuer == null) throw new ArgumentNullException(nameof(issuer)); diff --git a/src/StackExchange.Redis/ExceptionFactory.cs b/src/StackExchange.Redis/ExceptionFactory.cs index 7e4eca49a..3cfb0268c 100644 --- a/src/StackExchange.Redis/ExceptionFactory.cs +++ b/src/StackExchange.Redis/ExceptionFactory.cs @@ -107,7 +107,7 @@ internal static Exception NoConnectionAvailable( serverSnapshot = new ServerEndPoint[] { server }; } - var innerException = PopulateInnerExceptions(serverSnapshot == default ? multiplexer.GetServerSnapshot() : serverSnapshot); + var innerException = PopulateInnerExceptions(serverSnapshot.IsEmpty ? multiplexer.GetServerSnapshot() : serverSnapshot); // Try to get a useful error message for the user. long attempts = multiplexer._connectAttemptCount, completions = multiplexer._connectCompletedCount; diff --git a/src/StackExchange.Redis/FrameworkShims.IsExternalInit.cs b/src/StackExchange.Redis/FrameworkShims.IsExternalInit.cs new file mode 100644 index 000000000..417975050 --- /dev/null +++ b/src/StackExchange.Redis/FrameworkShims.IsExternalInit.cs @@ -0,0 +1,15 @@ +using System.ComponentModel; + +#if NET5_0_OR_GREATER +// context: https://github.com/StackExchange/StackExchange.Redis/issues/2619 +[assembly: System.Runtime.CompilerServices.TypeForwardedTo(typeof(System.Runtime.CompilerServices.IsExternalInit))] +#else + +// To support { get; init; } properties +// ReSharper disable once CheckNamespace +namespace System.Runtime.CompilerServices +{ + [EditorBrowsable(EditorBrowsableState.Never)] + internal static class IsExternalInit { } +} +#endif diff --git a/src/StackExchange.Redis/HotKeys.cs b/src/StackExchange.Redis/HotKeys.cs index 270bcf9f7..2adb98eba 100644 --- a/src/StackExchange.Redis/HotKeys.cs +++ b/src/StackExchange.Redis/HotKeys.cs @@ -1,6 +1,7 @@ using System; using System.Diagnostics.CodeAnalysis; using System.Threading.Tasks; +using RESPite; namespace StackExchange.Redis; diff --git a/src/StackExchange.Redis/Interfaces/IDatabase.VectorSets.cs b/src/StackExchange.Redis/Interfaces/IDatabase.VectorSets.cs index 039075ec8..91eb1e43c 100644 --- a/src/StackExchange.Redis/Interfaces/IDatabase.VectorSets.cs +++ b/src/StackExchange.Redis/Interfaces/IDatabase.VectorSets.cs @@ -1,5 +1,6 @@ using System; using System.Diagnostics.CodeAnalysis; +using RESPite; // ReSharper disable once CheckNamespace namespace StackExchange.Redis; diff --git a/src/StackExchange.Redis/Interfaces/IDatabase.cs b/src/StackExchange.Redis/Interfaces/IDatabase.cs index cf2ecafac..e26154652 100644 --- a/src/StackExchange.Redis/Interfaces/IDatabase.cs +++ b/src/StackExchange.Redis/Interfaces/IDatabase.cs @@ -3,6 +3,7 @@ using System.ComponentModel; using System.Diagnostics.CodeAnalysis; using System.Net; +using RESPite; // ReSharper disable once CheckNamespace namespace StackExchange.Redis diff --git a/src/StackExchange.Redis/Interfaces/IDatabaseAsync.VectorSets.cs b/src/StackExchange.Redis/Interfaces/IDatabaseAsync.VectorSets.cs index 3ac67d40f..2e3499557 100644 --- a/src/StackExchange.Redis/Interfaces/IDatabaseAsync.VectorSets.cs +++ b/src/StackExchange.Redis/Interfaces/IDatabaseAsync.VectorSets.cs @@ -1,6 +1,7 @@ using System; using System.Diagnostics.CodeAnalysis; using System.Threading.Tasks; +using RESPite; // ReSharper disable once CheckNamespace namespace StackExchange.Redis; diff --git a/src/StackExchange.Redis/Interfaces/IDatabaseAsync.cs b/src/StackExchange.Redis/Interfaces/IDatabaseAsync.cs index 029c7975e..c581470ca 100644 --- a/src/StackExchange.Redis/Interfaces/IDatabaseAsync.cs +++ b/src/StackExchange.Redis/Interfaces/IDatabaseAsync.cs @@ -5,6 +5,7 @@ using System.IO; using System.Net; using System.Threading.Tasks; +using RESPite; // ReSharper disable once CheckNamespace namespace StackExchange.Redis diff --git a/src/StackExchange.Redis/KeyspaceIsolation/KeyPrefixed.VectorSets.cs b/src/StackExchange.Redis/KeyspaceIsolation/KeyPrefixed.VectorSets.cs index 809adad97..f1a08ecd1 100644 --- a/src/StackExchange.Redis/KeyspaceIsolation/KeyPrefixed.VectorSets.cs +++ b/src/StackExchange.Redis/KeyspaceIsolation/KeyPrefixed.VectorSets.cs @@ -1,6 +1,7 @@ using System; using System.Diagnostics.CodeAnalysis; using System.Threading.Tasks; +using RESPite; // ReSharper disable once CheckNamespace namespace StackExchange.Redis.KeyspaceIsolation; diff --git a/src/StackExchange.Redis/LoggerExtensions.cs b/src/StackExchange.Redis/LoggerExtensions.cs index be51733ce..dfd8576b4 100644 --- a/src/StackExchange.Redis/LoggerExtensions.cs +++ b/src/StackExchange.Redis/LoggerExtensions.cs @@ -494,7 +494,7 @@ internal static void LogWithThreadPoolStats(this ILogger? log, string message) Level = LogLevel.Information, EventId = 71, Message = "Response from {BridgeName} / {CommandAndKey}: {Result}")] - internal static partial void LogInformationResponse(this ILogger logger, string? bridgeName, string commandAndKey, RawResult result); + internal static partial void LogInformationResponse(this ILogger logger, string? bridgeName, string commandAndKey, string result); [LoggerMessage( Level = LogLevel.Information, diff --git a/src/StackExchange.Redis/Message.cs b/src/StackExchange.Redis/Message.cs index faf25ba44..441019959 100644 --- a/src/StackExchange.Redis/Message.cs +++ b/src/StackExchange.Redis/Message.cs @@ -7,6 +7,7 @@ using System.Text; using System.Threading; using Microsoft.Extensions.Logging; +using RESPite.Messages; using StackExchange.Redis.Profiling; namespace StackExchange.Redis @@ -601,7 +602,7 @@ internal static CommandFlags SetPrimaryReplicaFlags(CommandFlags everything, Com internal void Cancel() => resultBox?.Cancel(); // true if ready to be completed (i.e. false if re-issued to another server) - internal bool ComputeResult(PhysicalConnection connection, in RawResult result) + internal bool ComputeResult(PhysicalConnection connection, in RespReader frame) { var box = resultBox; try @@ -610,11 +611,12 @@ internal bool ComputeResult(PhysicalConnection connection, in RawResult result) if (resultProcessor == null) return true; // false here would be things like resends (MOVED) - the message is not yet complete - return resultProcessor.SetResult(connection, this, result); + var mutable = frame; + return resultProcessor.SetResult(connection, this, ref mutable); } catch (Exception ex) { - ex.Data.Add("got", result.ToString()); + ex.Data.Add("got", frame.Prefix.ToString()); connection?.BridgeCouldBeNull?.Multiplexer?.OnMessageFaulted(this, ex); box?.SetException(ex); return box != null; // we still want to pulse/complete diff --git a/src/StackExchange.Redis/PhysicalConnection.Read.cs b/src/StackExchange.Redis/PhysicalConnection.Read.cs new file mode 100644 index 000000000..b7c976082 --- /dev/null +++ b/src/StackExchange.Redis/PhysicalConnection.Read.cs @@ -0,0 +1,757 @@ +using System; +using System.Buffers; +using System.Buffers.Binary; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.IO; +using System.Net; +using System.Net.Sockets; +using System.Threading; +using System.Threading.Tasks; +using Pipelines.Sockets.Unofficial; +using RESPite.Buffers; +using RESPite.Internal; +using RESPite.Messages; + +namespace StackExchange.Redis; + +internal sealed partial class PhysicalConnection +{ + private volatile ReadStatus _readStatus = ReadStatus.NotStarted; + internal ReadStatus GetReadStatus() => _readStatus; + + internal void StartReading() => ReadAllAsync(Stream.Null).RedisFireAndForget(); + + private async Task ReadAllAsync(Stream tail) + { + _readStatus = ReadStatus.Init; + RespScanState state = default; + var readBuffer = CycleBuffer.Create(); + try + { + int read; + do + { + _readStatus = ReadStatus.ReadAsync; + var buffer = readBuffer.GetUncommittedMemory(); + var pending = tail.ReadAsync(buffer, CancellationToken.None); +#if DEBUG + bool inline = pending.IsCompleted; +#endif + read = await pending.ConfigureAwait(false); + _readStatus = ReadStatus.UpdateWriteTime; + UpdateLastReadTime(); +#if DEBUG + DebugCounters.OnAsyncRead(read, inline); +#endif + _readStatus = ReadStatus.TryParseResult; + } + // another formatter glitch + while (CommitAndParseFrames(ref state, ref readBuffer, read)); + _readStatus = ReadStatus.ProcessBufferComplete; + + // Volatile.Write(ref _readStatus, ReaderCompleted); + readBuffer.Release(); // clean exit, we can recycle + _readStatus = ReadStatus.RanToCompletion; + RecordConnectionFailed(ConnectionFailureType.SocketClosed); + } + catch (EndOfStreamException) when (_readStatus is ReadStatus.ReadAsync) + { + _readStatus = ReadStatus.RanToCompletion; + RecordConnectionFailed(ConnectionFailureType.SocketClosed); + } + catch (Exception ex) + { + _readStatus = ReadStatus.Faulted; + RecordConnectionFailed(ConnectionFailureType.InternalFailure, ex); + } + } + + private static byte[]? SharedNoLease; + + private bool CommitAndParseFrames(ref RespScanState state, ref CycleBuffer readBuffer, int bytesRead) + { + if (bytesRead <= 0) + { + return false; + } + +#if PARSE_DETAIL + string src = $"parse {bytesRead}"; + try +#endif + { + Debug.Assert(readBuffer.GetCommittedLength() >= 0, "multi-segment running-indices are corrupt"); +#if PARSE_DETAIL + src += $" ({readBuffer.GetCommittedLength()}+{bytesRead}-{state.TotalBytes})"; +#endif + Debug.Assert( + bytesRead <= readBuffer.UncommittedAvailable, + $"Insufficient bytes in {nameof(CommitAndParseFrames)}; got {bytesRead}, Available={readBuffer.UncommittedAvailable}"); + readBuffer.Commit(bytesRead); +#if PARSE_DETAIL + src += $",total {readBuffer.GetCommittedLength()}"; +#endif + var scanner = RespFrameScanner.Default; + + OperationStatus status = OperationStatus.NeedMoreData; + if (readBuffer.TryGetCommitted(out var fullSpan)) + { + int fullyConsumed = 0; + var toParse = fullSpan.Slice((int)state.TotalBytes); // skip what we've already parsed + + Debug.Assert(!toParse.IsEmpty); + while (true) + { +#if PARSE_DETAIL + src += $",span {toParse.Length}"; +#endif + int totalBytesBefore = (int)state.TotalBytes; + if (toParse.Length < RespScanState.MinBytes + || (status = scanner.TryRead(ref state, toParse)) != OperationStatus.Done) + { + break; + } + + Debug.Assert( + state is + { + IsComplete: true, TotalBytes: >= RespScanState.MinBytes, Prefix: not RespPrefix.None + }, + "Invalid RESP read state"); + + // extract the frame + var bytes = (int)state.TotalBytes; +#if PARSE_DETAIL + src += $",frame {bytes}"; +#endif + // send the frame somewhere (note this is the *full* frame, not just the bit we just parsed) + OnResponseFrame(state.Prefix, fullSpan.Slice(fullyConsumed, bytes), ref SharedNoLease); + UpdateBufferStats(bytes, toParse.Length); + + // update our buffers to the unread potions and reset for a new RESP frame + fullyConsumed += bytes; + toParse = toParse.Slice(bytes - totalBytesBefore); // move past the extra bytes we just read + state = default; + status = OperationStatus.NeedMoreData; + } + + readBuffer.DiscardCommitted(fullyConsumed); + } + else // the same thing again, but this time with multi-segment sequence + { + var fullSequence = readBuffer.GetAllCommitted(); + Debug.Assert( + fullSequence is { IsEmpty: false, IsSingleSegment: false }, + "non-trivial sequence expected"); + + long fullyConsumed = 0; + var toParse = fullSequence.Slice((int)state.TotalBytes); // skip what we've already parsed + while (true) + { +#if PARSE_DETAIL + src += $",ros {toParse.Length}"; +#endif + int totalBytesBefore = (int)state.TotalBytes; + if (toParse.Length < RespScanState.MinBytes + || (status = scanner.TryRead(ref state, toParse)) != OperationStatus.Done) + { + break; + } + + Debug.Assert( + state is + { + IsComplete: true, TotalBytes: >= RespScanState.MinBytes, Prefix: not RespPrefix.None + }, + "Invalid RESP read state"); + + // extract the frame + var bytes = (int)state.TotalBytes; +#if PARSE_DETAIL + src += $",frame {bytes}"; +#endif + // send the frame somewhere (note this is the *full* frame, not just the bit we just parsed) + OnResponseFrame(state.Prefix, fullSequence.Slice(fullyConsumed, bytes)); + UpdateBufferStats(bytes, toParse.Length); + + // update our buffers to the unread potions and reset for a new RESP frame + fullyConsumed += bytes; + toParse = toParse.Slice(bytes - totalBytesBefore); // move past the extra bytes we just read + state = default; + status = OperationStatus.NeedMoreData; + } + + readBuffer.DiscardCommitted(fullyConsumed); + } + + if (status != OperationStatus.NeedMoreData) + { + ThrowStatus(status); + + static void ThrowStatus(OperationStatus status) => + throw new InvalidOperationException($"Unexpected operation status: {status}"); + } + + return true; + } +#if PARSE_DETAIL + catch (Exception ex) + { + Debug.WriteLine($"{nameof(CommitAndParseFrames)}: {ex.Message}"); + Debug.WriteLine(src); + ActivationHelper.DebugBreak(); + throw new InvalidOperationException($"{src} lead to {ex.Message}", ex); + } +#endif + } + + private void OnResponseFrame(RespPrefix prefix, ReadOnlySequence payload) + { + if (payload.IsSingleSegment) + { +#if NETCOREAPP || NETSTANDARD2_1_OR_GREATER + OnResponseFrame(prefix, payload.FirstSpan, ref SharedNoLease); +#else + OnResponseFrame(prefix, payload.First.Span, ref SharedNoLease); +#endif + } + else + { + var len = checked((int)payload.Length); + byte[]? oversized = ArrayPool.Shared.Rent(len); + payload.CopyTo(oversized); + OnResponseFrame(prefix, new(oversized, 0, len), ref oversized); + + // the lease could have been claimed by the activation code (to prevent another memcpy); otherwise, free + if (oversized is not null) + { + ArrayPool.Shared.Return(oversized); + } + } + } + + public RespMode Mode { get; set; } = RespMode.Resp2; + + public enum RespMode + { + Resp2, + Resp2PubSub, + Resp3, + } + + private void UpdateBufferStats(int lastResult, long inBuffer) + { + // Track the last result size *after* processing for the *next* error message + bytesInBuffer = inBuffer; + bytesLastResult = lastResult; + } + + private void OnResponseFrame(RespPrefix prefix, ReadOnlySpan frame, ref byte[]? lease) + { + DebugValidateSingleFrame(frame); + _readStatus = ReadStatus.MatchResult; + switch (prefix) + { + case RespPrefix.Push: // explicit push message + case RespPrefix.Array when Mode is RespMode.Resp2PubSub && !IsArrayPong(frame): // likely pub/sub payload + // out-of-band; pub/sub etc + if (OnOutOfBand(frame, ref lease)) + { + return; + } + break; + } + + // request/response; match to inbound + MatchNextResult(frame); + + static bool IsArrayPong(ReadOnlySpan payload) + { + if (payload.Length >= sizeof(ulong)) + { + var hash = payload.Hash64(); + switch (hash) + { + case ArrayPong_LC_Bulk.Hash when payload.StartsWith(ArrayPong_LC_Bulk.U8): + case ArrayPong_UC_Bulk.Hash when payload.StartsWith(ArrayPong_UC_Bulk.U8): + case ArrayPong_LC_Simple.Hash when payload.StartsWith(ArrayPong_LC_Simple.U8): + case ArrayPong_UC_Simple.Hash when payload.StartsWith(ArrayPong_UC_Simple.U8): + var reader = new RespReader(payload); + return reader.SafeTryMoveNext() // have root + && reader.Prefix == RespPrefix.Array // root is array + && reader.SafeTryMoveNext() // have first child + && (reader.IsInlneCpuUInt32(pong) || reader.IsInlneCpuUInt32(PONG)); // pong + } + } + + return false; + } + } + + private enum PushKind + { + None, + Message, + PMessage, + SMessage, + Subscribe, + PSubscribe, + SSubscribe, + Unsubscribe, + PUnsubscribe, + SUnsubscribe, + } + + private bool OnOutOfBand(ReadOnlySpan payload, ref byte[]? lease) + { + static ReadOnlySpan StackCopyLenChecked(scoped in RespReader reader, Span buffer) + { + var len = reader.CopyTo(buffer); + if (len == buffer.Length && reader.ScalarLength() > len) return default; // too small + return buffer.Slice(0, len); + } + + var muxer = BridgeCouldBeNull?.Multiplexer; + if (muxer is null) return true; // consume it blindly + + var reader = new RespReader(payload); + + // read the message kind from the first element + int len; + if (reader.SafeTryMoveNext() & reader.IsAggregate & !reader.IsStreaming + && (len = reader.AggregateLength()) >= 2 + && (reader.SafeTryMoveNext() & reader.IsInlineScalar & !reader.IsError)) + { + const int MAX_TYPE_LEN = 16; + var span = reader.TryGetSpan(out var tmp) + ? tmp : StackCopyLenChecked(in reader, stackalloc byte[MAX_TYPE_LEN]); + + var hash = span.Hash64(); + RedisChannel.RedisChannelOptions channelOptions = RedisChannel.RedisChannelOptions.None; + PushKind kind; + switch (hash) + { + case PushMessage.Hash when PushMessage.Is(hash, span) & len >= 3: + kind = PushKind.Message; + break; + case PushPMessage.Hash when PushPMessage.Is(hash, span) & len >= 4: + channelOptions = RedisChannel.RedisChannelOptions.Pattern; + kind = PushKind.PMessage; + break; + case PushSMessage.Hash when PushSMessage.Is(hash, span) & len >= 3: + channelOptions = RedisChannel.RedisChannelOptions.Sharded; + kind = PushKind.SMessage; + break; + case PushSubscribe.Hash when PushSubscribe.Is(hash, span): + kind = PushKind.Subscribe; + break; + case PushPSubscribe.Hash when PushPSubscribe.Is(hash, span): + channelOptions = RedisChannel.RedisChannelOptions.Pattern; + kind = PushKind.PSubscribe; + break; + case PushSSubscribe.Hash when PushSSubscribe.Is(hash, span): + channelOptions = RedisChannel.RedisChannelOptions.Sharded; + kind = PushKind.SSubscribe; + break; + case PushUnsubscribe.Hash when PushUnsubscribe.Is(hash, span): + kind = PushKind.Unsubscribe; + break; + case PushPUnsubscribe.Hash when PushPUnsubscribe.Is(hash, span): + channelOptions = RedisChannel.RedisChannelOptions.Pattern; + kind = PushKind.PUnsubscribe; + break; + case PushSUnsubscribe.Hash when PushSUnsubscribe.Is(hash, span): + channelOptions = RedisChannel.RedisChannelOptions.Sharded; + kind = PushKind.SUnsubscribe; + break; + default: + kind = PushKind.None; + break; + } + + static bool TryMoveNextString(ref RespReader reader) + => reader.SafeTryMoveNext() & reader.IsInlineScalar & + reader.Prefix is RespPrefix.BulkString or RespPrefix.SimpleString; + + if (kind is PushKind.None || !TryMoveNextString(ref reader)) return false; + + // the channel is always the second element + var subscriptionChannel = AsRedisChannel(reader, channelOptions); + + switch (kind) + { + case (PushKind.Message or PushKind.SMessage) when reader.SafeTryMoveNext(): + _readStatus = kind is PushKind.Message ? ReadStatus.PubSubMessage : ReadStatus.PubSubSMessage; + + // special-case the configuration change broadcasts (we don't keep that in the usual pub/sub registry) + var configChanged = muxer.ConfigurationChangedChannel; + if (configChanged != null && reader.Prefix is RespPrefix.BulkString or RespPrefix.SimpleString && reader.Is(configChanged)) + { + EndPoint? blame = null; + try + { + if (!reader.Is("*"u8)) + { + // We don't want to fail here, just trying to identify + _ = Format.TryParseEndPoint(reader.ReadString(), out blame); + } + } + catch + { + /* no biggie */ + } + + Trace("Configuration changed: " + Format.ToString(blame)); + _readStatus = ReadStatus.Reconfigure; + muxer.ReconfigureIfNeeded(blame, true, "broadcast"); + } + + // invoke the handlers + if (!subscriptionChannel.IsNull) + { + Trace($"{kind}: {subscriptionChannel}"); + OnMessage(muxer, subscriptionChannel, subscriptionChannel, in reader); + } + + return true; + case PushKind.PMessage when TryMoveNextString(ref reader): + _readStatus = ReadStatus.PubSubPMessage; + + var messageChannel = AsRedisChannel(reader, RedisChannel.RedisChannelOptions.None); + if (!messageChannel.IsNull && reader.SafeTryMoveNext()) + { + Trace($"{kind}: {messageChannel} via {subscriptionChannel}"); + OnMessage(muxer, subscriptionChannel, messageChannel, in reader); + } + + return true; + case PushKind.SUnsubscribe when !PeekChannelMessage(RedisCommand.SUNSUBSCRIBE, subscriptionChannel): + // then it was *unsolicited* - this probably means the slot was migrated + // (otherwise, we'll let the command-processor deal with it) + _readStatus = ReadStatus.PubSubUnsubscribe; + var server = BridgeCouldBeNull?.ServerEndPoint; + if (server is not null && muxer.TryGetSubscription(subscriptionChannel, out var subscription)) + { + // wipe and reconnect; but: to where? + // counter-intuitively, the only server we *know* already knows the new route is: + // the outgoing server, since it had to change to MIGRATING etc; the new INCOMING server + // knows, but *we don't know who that is*, and other nodes: aren't guaranteed to know (yet) + muxer.DefaultSubscriber.ResubscribeToServer(subscription, subscriptionChannel, server, cause: PushSUnsubscribe.Text); + } + return true; + } + } + return false; + } + + private void OnMessage( + ConnectionMultiplexer muxer, + in RedisChannel subscriptionChannel, + in RedisChannel messageChannel, + in RespReader reader) + { + // note: this could be multi-message: https://github.com/StackExchange/StackExchange.Redis/issues/2507 + _readStatus = ReadStatus.InvokePubSub; + switch (reader.Prefix) + { + case RespPrefix.BulkString: + case RespPrefix.SimpleString: + muxer.OnMessage(subscriptionChannel, messageChannel, reader.ReadRedisValue()); + break; + case RespPrefix.Array: + var iter = reader.AggregateChildren(); + while (iter.MoveNext()) + { + muxer.OnMessage(subscriptionChannel, messageChannel, iter.Current.ReadRedisValue()); + } + + break; + } + } + + private void MatchNextResult(ReadOnlySpan frame) + { + Trace("Matching result..."); + + Message? msg = null; + // check whether we're waiting for a high-integrity mode post-response checksum (using cheap null-check first) + if (_awaitingToken is not null && (msg = Interlocked.Exchange(ref _awaitingToken, null)) is not null) + { + _readStatus = ReadStatus.ResponseSequenceCheck; + if (!ProcessHighIntegrityResponseToken(msg, frame, BridgeCouldBeNull)) + { + RecordConnectionFailed(ConnectionFailureType.ResponseIntegrityFailure, origin: nameof(ReadStatus.ResponseSequenceCheck)); + } + return; + } + + _readStatus = ReadStatus.DequeueResult; + lock (_writtenAwaitingResponse) + { + if (msg is not null) + { + _awaitingToken = null; + } + + if (!_writtenAwaitingResponse.TryDequeue(out msg)) + { + Throw(frame); + + [DoesNotReturn] + static void Throw(ReadOnlySpan frame) + { + var prefix = RespReaderExtensions.GetRespPrefix(frame); + throw new InvalidOperationException("Received response with no message waiting: " + prefix.ToString()); + } + } + } + _activeMessage = msg; + + Trace("Response to: " + msg); + _readStatus = ReadStatus.ComputeResult; + var reader = new RespReader(frame); + if (msg.ComputeResult(this, reader)) + { + _readStatus = msg.ResultBoxIsAsync ? ReadStatus.CompletePendingMessageAsync : ReadStatus.CompletePendingMessageSync; + if (!msg.IsHighIntegrity) + { + // can't complete yet if needs checksum + msg.Complete(); + } + } + if (msg.IsHighIntegrity) + { + // stash this for the next non-OOB response + Volatile.Write(ref _awaitingToken, msg); + } + + _readStatus = ReadStatus.MatchResultComplete; + _activeMessage = null; + + static bool ProcessHighIntegrityResponseToken(Message message, ReadOnlySpan frame, PhysicalBridge? bridge) + { + bool isValid = false; + var reader = new RespReader(frame); + if (reader.SafeTryMoveNext() + & reader.Resp2PrefixBulkString is RespPrefix.BulkString + & reader.ScalarLength() is 4) + { + uint interpreted; + if (reader.TryGetSpan(out var span)) + { + interpreted = BinaryPrimitives.ReadUInt32LittleEndian(span); + } + else + { + Span tmp = stackalloc byte[4]; + reader.CopyTo(tmp); + interpreted = BinaryPrimitives.ReadUInt32LittleEndian(tmp); + } + isValid = interpreted == message.HighIntegrityToken; + } + if (isValid) + { + message.Complete(); + return true; + } + else + { + message.SetExceptionAndComplete(new InvalidOperationException("High-integrity mode detected possible protocol de-sync"), bridge); + return false; + } + } + } + + private bool PeekChannelMessage(RedisCommand command, in RedisChannel channel) + { + Message? msg; + bool haveMsg; + lock (_writtenAwaitingResponse) + { + haveMsg = _writtenAwaitingResponse.TryPeek(out msg); + } + + return haveMsg && msg is Message.CommandChannelBase typed + && typed.Command == command && typed.Channel == channel; + } + + internal RedisChannel AsRedisChannel(in RespReader reader, RedisChannel.RedisChannelOptions options) + { + var channelPrefix = ChannelPrefix; + if (channelPrefix is null) + { + // no channel-prefix enabled, just use as-is + return new RedisChannel(reader.ReadByteArray(), options); + } + + byte[] lease = []; + var span = reader.TryGetSpan(out var tmp) ? tmp : reader.Buffer(ref lease, stackalloc byte[256]); + + if (span.StartsWith(channelPrefix)) + { + // we have a channel-prefix, and it matches; strip it + span = span.Slice(channelPrefix.Length); + } + else if (span.StartsWith("__keyspace@"u8) || span.StartsWith("__keyevent@"u8)) + { + // we shouldn't get unexpected events, so to get here: we've received a notification + // on a channel that doesn't match our prefix; this *should* be limited to + // key notifications (see: IgnoreChannelPrefix), but: we need to be sure + + // leave alone + } + else + { + // no idea what this is + span = default; + } + + RedisChannel channel = span.IsEmpty ? default : new(span.ToArray(), options); + if (lease.Length != 0) ArrayPool.Shared.Return(lease); + return channel; + } + + [FastHash("message")] + private static partial class PushMessage { } + + [FastHash("pmessage")] + private static partial class PushPMessage { } + + [FastHash("smessage")] + private static partial class PushSMessage { } + + [FastHash("subscribe")] + private static partial class PushSubscribe { } + + [FastHash("psubscribe")] + private static partial class PushPSubscribe { } + + [FastHash("ssubscribe")] + private static partial class PushSSubscribe { } + + [FastHash("unsubscribe")] + private static partial class PushUnsubscribe { } + + [FastHash("punsubscribe")] + private static partial class PushPUnsubscribe { } + + [FastHash("sunsubscribe")] + private static partial class PushSUnsubscribe { } + + [FastHash("*2\r\n$4\r\npong\r\n$")] + private static partial class ArrayPong_LC_Bulk { } + [FastHash("*2\r\n$4\r\nPONG\r\n$")] + private static partial class ArrayPong_UC_Bulk { } + [FastHash("*2\r\n+pong\r\n$")] + private static partial class ArrayPong_LC_Simple { } + [FastHash("*2\r\n+PONG\r\n$")] + private static partial class ArrayPong_UC_Simple { } + + // ReSharper disable InconsistentNaming + private static readonly uint + pong = RespConstants.UnsafeCpuUInt32("pong"u8), + PONG = RespConstants.UnsafeCpuUInt32("PONG"u8); + + // ReSharper restore InconsistentNaming + [Conditional("DEBUG")] + private static void DebugValidateSingleFrame(ReadOnlySpan payload) + { + var reader = new RespReader(payload); + if (!reader.SafeTryMoveNext()) + { + throw new InvalidOperationException("No root RESP element"); + } + reader.SkipChildren(); + + if (reader.SafeTryMoveNext()) + { + throw new InvalidOperationException($"Unexpected trailing {reader.Prefix}"); + } + + if (reader.ProtocolBytesRemaining != 0) + { + var copy = reader; // leave reader alone for inspection + var prefix = copy.SafeTryMoveNext() ? copy.Prefix : RespPrefix.None; + throw new InvalidOperationException( + $"Unexpected additional {reader.ProtocolBytesRemaining} bytes remaining, {prefix}"); + } + } + + /* + private async Task ReadFromPipe() + { + _readBuffer = CycleBuffer.Create(); + bool allowSyncRead = true, isReading = false; + try + { + _readStatus = ReadStatus.Init; + while (true) + { + var input = _ioPipe?.Input; + if (input == null) break; + + // note: TryRead will give us back the same buffer in a tight loop + // - so: only use that if we're making progress + isReading = true; + _readStatus = ReadStatus.ReadSync; + if (!(allowSyncRead && input.TryRead(out var readResult))) + { + _readStatus = ReadStatus.ReadAsync; + readResult = await input.ReadAsync().ForAwait(); + } + isReading = false; + _readStatus = ReadStatus.UpdateWriteTime; + UpdateLastReadTime(); + + _readStatus = ReadStatus.ProcessBuffer; + var buffer = readResult.Buffer; + int handled = 0; + if (!buffer.IsEmpty) + { + handled = ProcessBuffer(ref buffer); // updates buffer.Start + } + + allowSyncRead = handled != 0; + + _readStatus = ReadStatus.MarkProcessed; + Trace($"Processed {handled} messages"); + input.AdvanceTo(buffer.Start, buffer.End); + + if (handled == 0 && readResult.IsCompleted) + { + break; // no more data, or trailing incomplete messages + } + } + Trace("EOF"); + RecordConnectionFailed(ConnectionFailureType.SocketClosed); + _readStatus = ReadStatus.RanToCompletion; + } + catch (Exception ex) + { + _readStatus = ReadStatus.Faulted; + // this CEX is just a hardcore "seriously, read the actual value" - there's no + // convenient "Thread.VolatileRead(ref T field) where T : class", and I don't + // want to make the field volatile just for this one place that needs it + if (isReading) + { + var pipe = Volatile.Read(ref _ioPipe); + if (pipe == null) + { + return; + // yeah, that's fine... don't worry about it; we nuked it + } + + // check for confusing read errors - no need to present "Reading is not allowed after reader was completed." + if (pipe is SocketConnection sc && sc.ShutdownKind == PipeShutdownKind.ReadEndOfStream) + { + RecordConnectionFailed(ConnectionFailureType.SocketClosed, new EndOfStreamException()); + return; + } + } + Trace("Faulted"); + RecordConnectionFailed(ConnectionFailureType.InternalFailure, ex); + } + } + */ +} diff --git a/src/StackExchange.Redis/PhysicalConnection.cs b/src/StackExchange.Redis/PhysicalConnection.cs index 857902f48..30cd4d27b 100644 --- a/src/StackExchange.Redis/PhysicalConnection.cs +++ b/src/StackExchange.Redis/PhysicalConnection.cs @@ -19,6 +19,7 @@ using Microsoft.Extensions.Logging; using Pipelines.Sockets.Unofficial; using Pipelines.Sockets.Unofficial.Arenas; + using static StackExchange.Redis.Message; namespace StackExchange.Redis @@ -315,8 +316,8 @@ public void Dispose() RecordConnectionFailed(ConnectionFailureType.ConnectionDisposed); } OnCloseEcho(); - _arena.Dispose(); _reusableFlushSyncTokenSource?.Dispose(); + // ReSharper disable once GCSuppressFinalizeForTypeWithoutDestructor GC.SuppressFinalize(this); } @@ -912,7 +913,7 @@ internal void WriteHeader(RedisCommand command, int arguments, CommandBytes comm internal void RecordQuit() { // don't blame redis if we fired the first shot - Thread.VolatileWrite(ref clientSentQuit, 1); + Volatile.Write(ref clientSentQuit, 1); (_ioPipe as SocketConnection)?.TrySetProtocolShutdown(PipeShutdownKind.ProtocolExitClient); } @@ -1667,351 +1668,6 @@ internal async ValueTask ConnectedAsync(Socket? socket, ILogger? log, Sock } } - private enum PushKind - { - None, - Message, - PMessage, - SMessage, - Subscribe, - PSubscribe, - SSubscribe, - Unsubscribe, - PUnsubscribe, - SUnsubscribe, - } - private PushKind GetPushKind(in Sequence result, out RedisChannel channel) - { - var len = result.Length; - if (len < 2) - { - // for supported cases, we demand at least the kind and the subscription channel - channel = default; - return PushKind.None; - } - - const int MAX_LEN = 16; - Debug.Assert(MAX_LEN >= Enumerable.Max( - [ - PushMessage.Length, PushPMessage.Length, PushSMessage.Length, - PushSubscribe.Length, PushPSubscribe.Length, PushSSubscribe.Length, - PushUnsubscribe.Length, PushPUnsubscribe.Length, PushSUnsubscribe.Length, - ])); - ref readonly RawResult pushKind = ref result[0]; - var multiSegmentPayload = pushKind.Payload; - if (multiSegmentPayload.Length <= MAX_LEN) - { - var span = multiSegmentPayload.IsSingleSegment - ? multiSegmentPayload.First.Span - : CopyTo(stackalloc byte[MAX_LEN], multiSegmentPayload); - - var hash = FastHash.Hash64(span); - RedisChannel.RedisChannelOptions channelOptions = RedisChannel.RedisChannelOptions.None; - PushKind kind; - switch (hash) - { - case PushMessage.Hash when PushMessage.Is(hash, span) & len >= 3: - kind = PushKind.Message; - break; - case PushPMessage.Hash when PushPMessage.Is(hash, span) & len >= 4: - channelOptions = RedisChannel.RedisChannelOptions.Pattern; - kind = PushKind.PMessage; - break; - case PushSMessage.Hash when PushSMessage.Is(hash, span) & len >= 3: - channelOptions = RedisChannel.RedisChannelOptions.Sharded; - kind = PushKind.SMessage; - break; - case PushSubscribe.Hash when PushSubscribe.Is(hash, span): - kind = PushKind.Subscribe; - break; - case PushPSubscribe.Hash when PushPSubscribe.Is(hash, span): - channelOptions = RedisChannel.RedisChannelOptions.Pattern; - kind = PushKind.PSubscribe; - break; - case PushSSubscribe.Hash when PushSSubscribe.Is(hash, span): - channelOptions = RedisChannel.RedisChannelOptions.Sharded; - kind = PushKind.SSubscribe; - break; - case PushUnsubscribe.Hash when PushUnsubscribe.Is(hash, span): - kind = PushKind.Unsubscribe; - break; - case PushPUnsubscribe.Hash when PushPUnsubscribe.Is(hash, span): - channelOptions = RedisChannel.RedisChannelOptions.Pattern; - kind = PushKind.PUnsubscribe; - break; - case PushSUnsubscribe.Hash when PushSUnsubscribe.Is(hash, span): - channelOptions = RedisChannel.RedisChannelOptions.Sharded; - kind = PushKind.SUnsubscribe; - break; - default: - kind = PushKind.None; - break; - } - if (kind != PushKind.None) - { - // the channel is always the second element - channel = result[1].AsRedisChannel(ChannelPrefix, channelOptions); - return kind; - } - } - channel = default; - return PushKind.None; - - static ReadOnlySpan CopyTo(Span target, in ReadOnlySequence source) - { - source.CopyTo(target); - return target.Slice(0, (int)source.Length); - } - } - - [FastHash("message")] - private static partial class PushMessage { } - - [FastHash("pmessage")] - private static partial class PushPMessage { } - - [FastHash("smessage")] - private static partial class PushSMessage { } - - [FastHash("subscribe")] - private static partial class PushSubscribe { } - - [FastHash("psubscribe")] - private static partial class PushPSubscribe { } - - [FastHash("ssubscribe")] - private static partial class PushSSubscribe { } - - [FastHash("unsubscribe")] - private static partial class PushUnsubscribe { } - - [FastHash("punsubscribe")] - private static partial class PushPUnsubscribe { } - - [FastHash("sunsubscribe")] - private static partial class PushSUnsubscribe { } - - private void MatchResult(in RawResult result) - { - // check to see if it could be an out-of-band pubsub message - if ((connectionType == ConnectionType.Subscription && result.Resp2TypeArray == ResultType.Array) || result.Resp3Type == ResultType.Push) - { - var muxer = BridgeCouldBeNull?.Multiplexer; - if (muxer == null) return; - - // out of band message does not match to a queued message - var items = result.GetItems(); - var kind = GetPushKind(items, out var subscriptionChannel); - switch (kind) - { - case PushKind.Message: - case PushKind.SMessage: - _readStatus = kind is PushKind.Message ? ReadStatus.PubSubMessage : ReadStatus.PubSubSMessage; - - // special-case the configuration change broadcasts (we don't keep that in the usual pub/sub registry) - var configChanged = muxer.ConfigurationChangedChannel; - if (configChanged != null && items[1].IsEqual(configChanged)) - { - EndPoint? blame = null; - try - { - if (!items[2].IsEqual(CommonReplies.wildcard)) - { - // We don't want to fail here, just trying to identify - _ = Format.TryParseEndPoint(items[2].GetString(), out blame); - } - } - catch - { - /* no biggie */ - } - - Trace("Configuration changed: " + Format.ToString(blame)); - _readStatus = ReadStatus.Reconfigure; - muxer.ReconfigureIfNeeded(blame, true, "broadcast"); - } - - // invoke the handlers - if (!subscriptionChannel.IsNull) - { - Trace($"{kind}: {subscriptionChannel}"); - if (TryGetPubSubPayload(items[2], out var payload)) - { - _readStatus = ReadStatus.InvokePubSub; - muxer.OnMessage(subscriptionChannel, subscriptionChannel, payload); - } - // could be multi-message: https://github.com/StackExchange/StackExchange.Redis/issues/2507 - else if (TryGetMultiPubSubPayload(items[2], out var payloads)) - { - _readStatus = ReadStatus.InvokePubSub; - muxer.OnMessage(subscriptionChannel, subscriptionChannel, payloads); - } - } - return; // and stop processing - case PushKind.PMessage: - _readStatus = ReadStatus.PubSubPMessage; - - var messageChannel = items[2].AsRedisChannel(ChannelPrefix, RedisChannel.RedisChannelOptions.None); - if (!messageChannel.IsNull) - { - Trace($"{kind}: {messageChannel} via {subscriptionChannel}"); - if (TryGetPubSubPayload(items[3], out var payload)) - { - _readStatus = ReadStatus.InvokePubSub; - muxer.OnMessage(subscriptionChannel, messageChannel, payload); - } - else if (TryGetMultiPubSubPayload(items[3], out var payloads)) - { - _readStatus = ReadStatus.InvokePubSub; - muxer.OnMessage(subscriptionChannel, messageChannel, payloads); - } - } - return; // and stop processing - case PushKind.SUnsubscribe when !PeekChannelMessage(RedisCommand.SUNSUBSCRIBE, subscriptionChannel): - // then it was *unsolicited* - this probably means the slot was migrated - // (otherwise, we'll let the command-processor deal with it) - _readStatus = ReadStatus.PubSubUnsubscribe; - var server = BridgeCouldBeNull?.ServerEndPoint; - if (server is not null && muxer.TryGetSubscription(subscriptionChannel, out var subscription)) - { - // wipe and reconnect; but: to where? - // counter-intuitively, the only server we *know* already knows the new route is: - // the outgoing server, since it had to change to MIGRATING etc; the new INCOMING server - // knows, but *we don't know who that is*, and other nodes: aren't guaranteed to know (yet) - muxer.DefaultSubscriber.ResubscribeToServer(subscription, subscriptionChannel, server, cause: PushSUnsubscribe.Text); - } - return; // and STOP PROCESSING; unsolicited - } - } - Trace("Matching result..."); - - Message? msg = null; - // check whether we're waiting for a high-integrity mode post-response checksum (using cheap null-check first) - if (_awaitingToken is not null && (msg = Interlocked.Exchange(ref _awaitingToken, null)) is not null) - { - _readStatus = ReadStatus.ResponseSequenceCheck; - if (!ProcessHighIntegrityResponseToken(msg, in result, BridgeCouldBeNull)) - { - RecordConnectionFailed(ConnectionFailureType.ResponseIntegrityFailure, origin: nameof(ReadStatus.ResponseSequenceCheck)); - } - return; - } - - _readStatus = ReadStatus.DequeueResult; - lock (_writtenAwaitingResponse) - { - if (msg is not null) - { - _awaitingToken = null; - } - - if (!_writtenAwaitingResponse.TryDequeue(out msg)) - { - throw new InvalidOperationException("Received response with no message waiting: " + result.ToString()); - } - } - _activeMessage = msg; - - Trace("Response to: " + msg); - _readStatus = ReadStatus.ComputeResult; - if (msg.ComputeResult(this, result)) - { - _readStatus = msg.ResultBoxIsAsync ? ReadStatus.CompletePendingMessageAsync : ReadStatus.CompletePendingMessageSync; - if (!msg.IsHighIntegrity) - { - // can't complete yet if needs checksum - msg.Complete(); - } - } - if (msg.IsHighIntegrity) - { - // stash this for the next non-OOB response - Volatile.Write(ref _awaitingToken, msg); - } - - _readStatus = ReadStatus.MatchResultComplete; - _activeMessage = null; - - static bool ProcessHighIntegrityResponseToken(Message message, in RawResult result, PhysicalBridge? bridge) - { - bool isValid = false; - if (result.Resp2TypeBulkString == ResultType.BulkString) - { - var payload = result.Payload; - if (payload.Length == 4) - { - uint interpreted; - if (payload.IsSingleSegment) - { - interpreted = BinaryPrimitives.ReadUInt32LittleEndian(payload.First.Span); - } - else - { - Span span = stackalloc byte[4]; - payload.CopyTo(span); - interpreted = BinaryPrimitives.ReadUInt32LittleEndian(span); - } - isValid = interpreted == message.HighIntegrityToken; - } - } - if (isValid) - { - message.Complete(); - return true; - } - else - { - message.SetExceptionAndComplete(new InvalidOperationException("High-integrity mode detected possible protocol de-sync"), bridge); - return false; - } - } - - static bool TryGetPubSubPayload(in RawResult value, out RedisValue parsed, bool allowArraySingleton = true) - { - if (value.IsNull) - { - parsed = RedisValue.Null; - return true; - } - switch (value.Resp2TypeBulkString) - { - case ResultType.Integer: - case ResultType.SimpleString: - case ResultType.BulkString: - parsed = value.AsRedisValue(); - return true; - case ResultType.Array when allowArraySingleton && value.ItemsCount == 1: - return TryGetPubSubPayload(in value[0], out parsed, allowArraySingleton: false); - } - parsed = default; - return false; - } - - static bool TryGetMultiPubSubPayload(in RawResult value, out Sequence parsed) - { - if (value.Resp2TypeArray == ResultType.Array && value.ItemsCount != 0) - { - parsed = value.GetItems(); - return true; - } - parsed = default; - return false; - } - } - - private bool PeekChannelMessage(RedisCommand command, in RedisChannel channel) - { - Message? msg; - bool haveMsg; - lock (_writtenAwaitingResponse) - { - haveMsg = _writtenAwaitingResponse.TryPeek(out msg); - } - - return haveMsg && msg is CommandChannelBase typed - && typed.Command == command && typed.Channel == channel; - } - private volatile Message? _activeMessage; internal void GetHeadMessages(out Message? now, out Message? next) @@ -2053,122 +1709,6 @@ private void OnDebugAbort() partial void OnWrapForLogging(ref IDuplexPipe pipe, string name, SocketManager mgr); internal void UpdateLastReadTime() => Interlocked.Exchange(ref lastReadTickCount, Environment.TickCount); - private async Task ReadFromPipe() - { - bool allowSyncRead = true, isReading = false; - try - { - _readStatus = ReadStatus.Init; - while (true) - { - var input = _ioPipe?.Input; - if (input == null) break; - - // note: TryRead will give us back the same buffer in a tight loop - // - so: only use that if we're making progress - isReading = true; - _readStatus = ReadStatus.ReadSync; - if (!(allowSyncRead && input.TryRead(out var readResult))) - { - _readStatus = ReadStatus.ReadAsync; - readResult = await input.ReadAsync().ForAwait(); - } - isReading = false; - _readStatus = ReadStatus.UpdateWriteTime; - UpdateLastReadTime(); - - _readStatus = ReadStatus.ProcessBuffer; - var buffer = readResult.Buffer; - int handled = 0; - if (!buffer.IsEmpty) - { - handled = ProcessBuffer(ref buffer); // updates buffer.Start - } - - allowSyncRead = handled != 0; - - _readStatus = ReadStatus.MarkProcessed; - Trace($"Processed {handled} messages"); - input.AdvanceTo(buffer.Start, buffer.End); - - if (handled == 0 && readResult.IsCompleted) - { - break; // no more data, or trailing incomplete messages - } - } - Trace("EOF"); - RecordConnectionFailed(ConnectionFailureType.SocketClosed); - _readStatus = ReadStatus.RanToCompletion; - } - catch (Exception ex) - { - _readStatus = ReadStatus.Faulted; - // this CEX is just a hardcore "seriously, read the actual value" - there's no - // convenient "Thread.VolatileRead(ref T field) where T : class", and I don't - // want to make the field volatile just for this one place that needs it - if (isReading) - { - var pipe = Volatile.Read(ref _ioPipe); - if (pipe == null) - { - return; - // yeah, that's fine... don't worry about it; we nuked it - } - - // check for confusing read errors - no need to present "Reading is not allowed after reader was completed." - if (pipe is SocketConnection sc && sc.ShutdownKind == PipeShutdownKind.ReadEndOfStream) - { - RecordConnectionFailed(ConnectionFailureType.SocketClosed, new EndOfStreamException()); - return; - } - } - Trace("Faulted"); - RecordConnectionFailed(ConnectionFailureType.InternalFailure, ex); - } - } - - private static readonly ArenaOptions s_arenaOptions = new ArenaOptions(); - private readonly Arena _arena = new Arena(s_arenaOptions); - - private int ProcessBuffer(ref ReadOnlySequence buffer) - { - int messageCount = 0; - bytesInBuffer = buffer.Length; - - while (!buffer.IsEmpty) - { - _readStatus = ReadStatus.TryParseResult; - var reader = new BufferReader(buffer); - var result = TryParseResult(_protocol >= RedisProtocol.Resp3, _arena, in buffer, ref reader, IncludeDetailInExceptions, this); - try - { - if (result.HasValue) - { - buffer = reader.SliceFromCurrent(); - - messageCount++; - Trace(result.ToString()); - _readStatus = ReadStatus.MatchResult; - MatchResult(result); - - // Track the last result size *after* processing for the *next* error message - bytesInBuffer = buffer.Length; - bytesLastResult = result.Payload.Length; - } - else - { - break; // remaining buffer isn't enough; give up - } - } - finally - { - _readStatus = ReadStatus.ResetArena; - _arena.Reset(); - } - } - _readStatus = ReadStatus.ProcessBufferComplete; - return messageCount; - } private static RawResult.ResultFlags AsNull(RawResult.ResultFlags flags) => flags & ~RawResult.ResultFlags.NonNull; @@ -2308,10 +1848,6 @@ internal enum ReadStatus PubSubUnsubscribe, NA = -1, } - private volatile ReadStatus _readStatus; - internal ReadStatus GetReadStatus() => _readStatus; - - internal void StartReading() => ReadFromPipe().RedisFireAndForget(); internal static RawResult TryParseResult( bool isResp3, diff --git a/src/StackExchange.Redis/PublicAPI/net6.0/PublicAPI.Shipped.txt b/src/StackExchange.Redis/PublicAPI/net6.0/PublicAPI.Shipped.txt index fae4f65ce..bec516e5c 100644 --- a/src/StackExchange.Redis/PublicAPI/net6.0/PublicAPI.Shipped.txt +++ b/src/StackExchange.Redis/PublicAPI/net6.0/PublicAPI.Shipped.txt @@ -1,4 +1,5 @@ -StackExchange.Redis.ConfigurationOptions.SslClientAuthenticationOptions.get -> System.Func? +#nullable enable +StackExchange.Redis.ConfigurationOptions.SslClientAuthenticationOptions.get -> System.Func? StackExchange.Redis.ConfigurationOptions.SslClientAuthenticationOptions.set -> void System.Runtime.CompilerServices.IsExternalInit (forwarded, contained in System.Runtime) -StackExchange.Redis.ConfigurationOptions.SetUserPemCertificate(string! userCertificatePath, string? userKeyPath = null) -> void \ No newline at end of file +StackExchange.Redis.ConfigurationOptions.SetUserPemCertificate(string! userCertificatePath, string? userKeyPath = null) -> void diff --git a/src/StackExchange.Redis/PublicAPI/net8.0/PublicAPI.Shipped.txt b/src/StackExchange.Redis/PublicAPI/net8.0/PublicAPI.Shipped.txt deleted file mode 100644 index fae4f65ce..000000000 --- a/src/StackExchange.Redis/PublicAPI/net8.0/PublicAPI.Shipped.txt +++ /dev/null @@ -1,4 +0,0 @@ -StackExchange.Redis.ConfigurationOptions.SslClientAuthenticationOptions.get -> System.Func? -StackExchange.Redis.ConfigurationOptions.SslClientAuthenticationOptions.set -> void -System.Runtime.CompilerServices.IsExternalInit (forwarded, contained in System.Runtime) -StackExchange.Redis.ConfigurationOptions.SetUserPemCertificate(string! userCertificatePath, string? userKeyPath = null) -> void \ No newline at end of file diff --git a/src/StackExchange.Redis/PublicAPI/netcoreapp3.1/PublicAPI.Shipped.txt b/src/StackExchange.Redis/PublicAPI/netcoreapp3.1/PublicAPI.Shipped.txt deleted file mode 100644 index 194e1b51b..000000000 --- a/src/StackExchange.Redis/PublicAPI/netcoreapp3.1/PublicAPI.Shipped.txt +++ /dev/null @@ -1,2 +0,0 @@ -StackExchange.Redis.ConfigurationOptions.SslClientAuthenticationOptions.get -> System.Func? -StackExchange.Redis.ConfigurationOptions.SslClientAuthenticationOptions.set -> void \ No newline at end of file diff --git a/src/StackExchange.Redis/RedisSubscriber.cs b/src/StackExchange.Redis/RedisSubscriber.cs index bd2434771..425fb7534 100644 --- a/src/StackExchange.Redis/RedisSubscriber.cs +++ b/src/StackExchange.Redis/RedisSubscriber.cs @@ -94,24 +94,6 @@ internal void OnMessage(in RedisChannel subscription, in RedisChannel channel, i } } - internal void OnMessage(in RedisChannel subscription, in RedisChannel channel, Sequence payload) - { - if (payload.IsSingleSegment) - { - foreach (var message in payload.FirstSpan) - { - OnMessage(subscription, channel, message.AsRedisValue()); - } - } - else - { - foreach (var message in payload) - { - OnMessage(subscription, channel, message.AsRedisValue()); - } - } - } - /// /// Updates all subscriptions re-evaluating their state. /// This clears the current server if it's not connected, prepping them to reconnect. diff --git a/src/StackExchange.Redis/RespReaderExtensions.cs b/src/StackExchange.Redis/RespReaderExtensions.cs new file mode 100644 index 000000000..cb60c8883 --- /dev/null +++ b/src/StackExchange.Redis/RespReaderExtensions.cs @@ -0,0 +1,72 @@ +using System; +using RESPite.Messages; + +namespace StackExchange.Redis; + +internal static class RespReaderExtensions +{ + extension(in RespReader reader) + { + public RespPrefix Resp2PrefixBulkString => reader.Prefix.ToResp2(RespPrefix.BulkString); + // if null, assume array + public RespPrefix Resp2PrefixArray => reader.Prefix.ToResp2(RespPrefix.Array); + + public RedisValue ReadRedisValue() + { + reader.DemandScalar(); + if (reader.IsNull) return RedisValue.Null; + + return reader.Prefix switch + { + RespPrefix.Boolean => reader.ReadBoolean(), + RespPrefix.Integer => reader.ReadInt64(), + _ => reader.ReadByteArray(), + }; + } + + public string OverviewString() + { + if (reader.IsNull) return "(null)"; + + return reader.Resp2PrefixBulkString switch + { + RespPrefix.SimpleString or RespPrefix.Integer or RespPrefix.SimpleError => $"{reader.Prefix}: {reader.ReadString()}", + _ when reader.IsScalar => $"{reader.Prefix}: {reader.ScalarLength()} bytes", + _ when reader.IsAggregate => $"{reader.Prefix}: {reader.AggregateLength()} items", + _ => $"(unknown: {reader.Prefix})", + }; + } + } + + extension(ref RespReader reader) + { + public bool SafeTryMoveNext() => reader.TryMoveNext(checkError: false) & !reader.IsError; + } + + public static RespPrefix GetRespPrefix(ReadOnlySpan frame) + { + var reader = new RespReader(frame); + reader.SafeTryMoveNext(); + return reader.Prefix; + } + + extension(RespPrefix prefix) + { + public RespPrefix ToResp2(RespPrefix nullValue) + { + return prefix switch + { + // null: map to what the caller prefers + RespPrefix.Null => nullValue, + // RESP 3: map to closest RESP 2 equivalent + RespPrefix.Boolean => RespPrefix.Integer, + RespPrefix.Double or RespPrefix.BigInteger => RespPrefix.SimpleString, + RespPrefix.BulkError => RespPrefix.SimpleError, + RespPrefix.VerbatimString => RespPrefix.BulkString, + RespPrefix.Map or RespPrefix.Set or RespPrefix.Push or RespPrefix.Attribute => RespPrefix.Array, + // RESP 2 or anything exotic: leave alone + _ => prefix, + }; + } + } +} diff --git a/src/StackExchange.Redis/ResultProcessor.cs b/src/StackExchange.Redis/ResultProcessor.cs index 926fe8950..dac25f427 100644 --- a/src/StackExchange.Redis/ResultProcessor.cs +++ b/src/StackExchange.Redis/ResultProcessor.cs @@ -11,6 +11,7 @@ using System.Text.RegularExpressions; using Microsoft.Extensions.Logging; using Pipelines.Sockets.Unofficial.Arenas; +using RESPite.Messages; namespace StackExchange.Redis { @@ -222,18 +223,19 @@ public static void SetException(Message? message, Exception ex) box?.SetException(ex); } // true if ready to be completed (i.e. false if re-issued to another server) - public virtual bool SetResult(PhysicalConnection connection, Message message, in RawResult result) + public bool SetResult(PhysicalConnection connection, Message message, ref RespReader reader) { + reader.SafeTryMoveNext(); var bridge = connection.BridgeCouldBeNull; if (message is LoggingMessage logging) { try { - logging.Log?.LogInformationResponse(bridge?.Name, message.CommandAndKey, result); + logging.Log?.LogInformationResponse(bridge?.Name, message.CommandAndKey, reader.OverviewString()); } catch { } } - if (result.IsError) + if (reader.IsError) { if (result.StartsWith(CommonReplies.NOAUTH)) { diff --git a/src/StackExchange.Redis/StackExchange.Redis.csproj b/src/StackExchange.Redis/StackExchange.Redis.csproj index 2c2e7702a..29620d7ab 100644 --- a/src/StackExchange.Redis/StackExchange.Redis.csproj +++ b/src/StackExchange.Redis/StackExchange.Redis.csproj @@ -2,7 +2,7 @@ enable - net461;netstandard2.0;net472;netcoreapp3.1;net6.0;net8.0 + net461;netstandard2.0;net472;net6.0;net8.0;net10.0 High performance Redis client, incorporating both synchronous and asynchronous usage. StackExchange.Redis StackExchange.Redis @@ -41,10 +41,11 @@ - - - - + + + + + @@ -55,5 +56,6 @@ + \ No newline at end of file diff --git a/src/StackExchange.Redis/StreamConfiguration.cs b/src/StackExchange.Redis/StreamConfiguration.cs index 71bbe483e..46e5d0ba3 100644 --- a/src/StackExchange.Redis/StreamConfiguration.cs +++ b/src/StackExchange.Redis/StreamConfiguration.cs @@ -1,4 +1,5 @@ using System.Diagnostics.CodeAnalysis; +using RESPite; namespace StackExchange.Redis; diff --git a/src/StackExchange.Redis/StreamIdempotentId.cs b/src/StackExchange.Redis/StreamIdempotentId.cs index 601890d1f..1ad331eda 100644 --- a/src/StackExchange.Redis/StreamIdempotentId.cs +++ b/src/StackExchange.Redis/StreamIdempotentId.cs @@ -1,5 +1,6 @@ using System; using System.Diagnostics.CodeAnalysis; +using RESPite; namespace StackExchange.Redis; diff --git a/src/StackExchange.Redis/ValueCondition.cs b/src/StackExchange.Redis/ValueCondition.cs index d61a2f00e..c5cf4bd5a 100644 --- a/src/StackExchange.Redis/ValueCondition.cs +++ b/src/StackExchange.Redis/ValueCondition.cs @@ -4,6 +4,7 @@ using System.Diagnostics.CodeAnalysis; using System.IO.Hashing; using System.Runtime.CompilerServices; +using RESPite; namespace StackExchange.Redis; diff --git a/src/StackExchange.Redis/VectorSetAddRequest.cs b/src/StackExchange.Redis/VectorSetAddRequest.cs index 987118c09..8428d6031 100644 --- a/src/StackExchange.Redis/VectorSetAddRequest.cs +++ b/src/StackExchange.Redis/VectorSetAddRequest.cs @@ -1,5 +1,6 @@ using System; using System.Diagnostics.CodeAnalysis; +using RESPite; namespace StackExchange.Redis; diff --git a/src/StackExchange.Redis/VectorSetInfo.cs b/src/StackExchange.Redis/VectorSetInfo.cs index c9277eae5..afbc3fece 100644 --- a/src/StackExchange.Redis/VectorSetInfo.cs +++ b/src/StackExchange.Redis/VectorSetInfo.cs @@ -1,5 +1,6 @@ using System; using System.Diagnostics.CodeAnalysis; +using RESPite; namespace StackExchange.Redis; diff --git a/src/StackExchange.Redis/VectorSetLink.cs b/src/StackExchange.Redis/VectorSetLink.cs index c18e8a95f..5d58a8d7f 100644 --- a/src/StackExchange.Redis/VectorSetLink.cs +++ b/src/StackExchange.Redis/VectorSetLink.cs @@ -1,4 +1,5 @@ using System.Diagnostics.CodeAnalysis; +using RESPite; namespace StackExchange.Redis; diff --git a/src/StackExchange.Redis/VectorSetQuantization.cs b/src/StackExchange.Redis/VectorSetQuantization.cs index d78f4b34b..688f699e9 100644 --- a/src/StackExchange.Redis/VectorSetQuantization.cs +++ b/src/StackExchange.Redis/VectorSetQuantization.cs @@ -1,4 +1,5 @@ using System.Diagnostics.CodeAnalysis; +using RESPite; namespace StackExchange.Redis; diff --git a/src/StackExchange.Redis/VectorSetSimilaritySearchRequest.cs b/src/StackExchange.Redis/VectorSetSimilaritySearchRequest.cs index d0c0fd4cc..1343fd3f1 100644 --- a/src/StackExchange.Redis/VectorSetSimilaritySearchRequest.cs +++ b/src/StackExchange.Redis/VectorSetSimilaritySearchRequest.cs @@ -1,6 +1,7 @@ using System; using System.ComponentModel; using System.Diagnostics.CodeAnalysis; +using RESPite; using VsimFlags = StackExchange.Redis.VectorSetSimilaritySearchMessage.VsimFlags; namespace StackExchange.Redis; diff --git a/src/StackExchange.Redis/VectorSetSimilaritySearchResult.cs b/src/StackExchange.Redis/VectorSetSimilaritySearchResult.cs index fd912898b..e16f91fdb 100644 --- a/src/StackExchange.Redis/VectorSetSimilaritySearchResult.cs +++ b/src/StackExchange.Redis/VectorSetSimilaritySearchResult.cs @@ -1,4 +1,5 @@ using System.Diagnostics.CodeAnalysis; +using RESPite; namespace StackExchange.Redis; From 0bc33ebd7b88845cb312b7d1c727dff37a140e77 Mon Sep 17 00:00:00 2001 From: Marc Gravell Date: Sun, 15 Feb 2026 09:40:07 +0000 Subject: [PATCH 03/11] nearly compiles --- src/RESPite/Messages/RespReader.cs | 57 +++ src/RESPite/PublicAPI/PublicAPI.Unshipped.txt | 2 + src/StackExchange.Redis/ClientInfo.cs | 2 +- src/StackExchange.Redis/CommandTrace.cs | 35 +- src/StackExchange.Redis/Condition.cs | 2 +- .../HotKeys.ResultProcessor.cs | 2 +- src/StackExchange.Redis/Message.cs | 13 +- .../PhysicalConnection.Read.cs | 2 +- src/StackExchange.Redis/RedisDatabase.cs | 4 +- src/StackExchange.Redis/RedisLiterals.cs | 25 -- src/StackExchange.Redis/RedisServer.cs | 2 +- src/StackExchange.Redis/RedisTransaction.cs | 15 +- .../RespReaderExtensions.cs | 40 +- .../ResultProcessor.Digest.cs | 2 +- .../ResultProcessor.Lease.cs | 10 +- .../ResultProcessor.Literals.cs | 39 ++ .../ResultProcessor.VectorSets.cs | 2 +- src/StackExchange.Redis/ResultProcessor.cs | 376 ++++++++++-------- .../VectorSetSimilaritySearchMessage.cs | 2 +- 19 files changed, 395 insertions(+), 237 deletions(-) create mode 100644 src/StackExchange.Redis/ResultProcessor.Literals.cs diff --git a/src/RESPite/Messages/RespReader.cs b/src/RESPite/Messages/RespReader.cs index a44ef520a..5a023a397 100644 --- a/src/RESPite/Messages/RespReader.cs +++ b/src/RESPite/Messages/RespReader.cs @@ -1293,6 +1293,13 @@ internal readonly bool IsOK() // go mad with this, because it is used so often public readonly bool Is(ReadOnlySpan value) => TryGetSpan(out var span) ? span.SequenceEqual(value) : IsSlow(value); + /// + /// Indicates whether the current element is a scalar with a value that starts with the provided . + /// + /// The payload value to verify. + public readonly bool StartsWith(ReadOnlySpan value) + => TryGetSpan(out var span) ? span.StartsWith(value) : StartsWithSlow(value); + /// /// Indicates whether the current element is a scalar with a value that matches the provided . /// @@ -1378,6 +1385,42 @@ private readonly bool IsSlow(ReadOnlySpan testValue) } } + private readonly bool StartsWithSlow(ReadOnlySpan testValue) + { + DemandScalar(); + if (IsNull) return false; // nothing equals null + if (testValue.IsEmpty) return true; // every non-null scalar starts-with empty + if (TotalAvailable < testValue.Length) return false; + + if (!IsStreaming && testValue.Length < ScalarLength()) return false; + + var iterator = ScalarChunks(); + while (true) + { + if (testValue.IsEmpty) + { + return true; + } + + if (!iterator.MoveNext()) + { + return false; // test is longer + } + + var current = iterator.Current; + if (testValue.Length <= current.Length) + { + // current fragment exhausts the test data; check it with StartsWith + return testValue.StartsWith(current); + } + + // current fragment is longer than the test data; the overlap must match exactly + if (!current.SequenceEqual(testValue.Slice(0, current.Length))) return false; // payload is different + + testValue = testValue.Slice(current.Length); // validated; continue + } + } + /// /// Copy the current scalar value out into the supplied , or as much as can be copied. /// @@ -1667,7 +1710,21 @@ public readonly T ReadEnum(T unknownValue = default) where T : struct, Enum #endif } + /// + /// Reads an aggregate as an array of elements without changing the position. + /// + /// The type of data to be projected. public TResult[]? ReadArray(Projection projection, bool scalar = false) + { + var copy = this; + return copy.ReadPastArray(projection, scalar); + } + + /// + /// Reads an aggregate as an array of elements, moving past the data as a side effect. + /// + /// The type of data to be projected. + public TResult[]? ReadPastArray(Projection projection, bool scalar = false) { DemandAggregate(); if (IsNull) return null; diff --git a/src/RESPite/PublicAPI/PublicAPI.Unshipped.txt b/src/RESPite/PublicAPI/PublicAPI.Unshipped.txt index 95b06c251..d3340f383 100644 --- a/src/RESPite/PublicAPI/PublicAPI.Unshipped.txt +++ b/src/RESPite/PublicAPI/PublicAPI.Unshipped.txt @@ -109,6 +109,7 @@ static RESPite.Buffers.CycleBuffer.Create(System.Buffers.MemoryPool? pool [SER004]RESPite.Messages.RespReader.ReadInt32() -> int [SER004]RESPite.Messages.RespReader.ReadInt64() -> long [SER004]RESPite.Messages.RespReader.ReadPairArray(RESPite.Messages.RespReader.Projection! first, RESPite.Messages.RespReader.Projection! second, System.Func! combine, bool scalar = true) -> TResult[]? +[SER004]RESPite.Messages.RespReader.ReadPastArray(RESPite.Messages.RespReader.Projection! projection, bool scalar = false) -> TResult[]? [SER004]RESPite.Messages.RespReader.ReadString() -> string? [SER004]RESPite.Messages.RespReader.ReadString(out string! prefix) -> string? [SER004]RESPite.Messages.RespReader.RespReader() -> void @@ -127,6 +128,7 @@ static RESPite.Buffers.CycleBuffer.Create(System.Buffers.MemoryPool? pool [SER004]RESPite.Messages.RespReader.ScalarLength() -> int [SER004]RESPite.Messages.RespReader.ScalarLongLength() -> long [SER004]RESPite.Messages.RespReader.SkipChildren() -> void +[SER004]RESPite.Messages.RespReader.StartsWith(System.ReadOnlySpan value) -> bool [SER004]RESPite.Messages.RespReader.TryGetSpan(out System.ReadOnlySpan value) -> bool [SER004]RESPite.Messages.RespReader.TryMoveNext() -> bool [SER004]RESPite.Messages.RespReader.TryMoveNext(bool checkError) -> bool diff --git a/src/StackExchange.Redis/ClientInfo.cs b/src/StackExchange.Redis/ClientInfo.cs index d743affff..15c3f641a 100644 --- a/src/StackExchange.Redis/ClientInfo.cs +++ b/src/StackExchange.Redis/ClientInfo.cs @@ -289,7 +289,7 @@ private static void AddFlag(ref ClientFlags value, string raw, ClientFlags toAdd private sealed class ClientInfoProcessor : ResultProcessor { - protected override bool SetResultCore(PhysicalConnection connection, Message message, in RawResult result) + protected override bool SetResultCore(PhysicalConnection connection, Message message, RawResult result) { switch (result.Resp2TypeBulkString) { diff --git a/src/StackExchange.Redis/CommandTrace.cs b/src/StackExchange.Redis/CommandTrace.cs index a61499f0c..ddf966644 100644 --- a/src/StackExchange.Redis/CommandTrace.cs +++ b/src/StackExchange.Redis/CommandTrace.cs @@ -1,4 +1,5 @@ using System; +using RESPite.Messages; namespace StackExchange.Redis { @@ -71,21 +72,33 @@ internal CommandTrace(long uniqueId, long time, long duration, RedisValue[] argu private sealed class CommandTraceProcessor : ResultProcessor { - protected override bool SetResultCore(PhysicalConnection connection, Message message, in RawResult result) + protected override bool SetResultCore(PhysicalConnection connection, Message message, ref RespReader reader) { - switch (result.Resp2TypeArray) + // see: SLOWLOG GET + switch (reader.Resp2PrefixArray) { - case ResultType.Array: - var parts = result.GetItems(); - CommandTrace[] arr = new CommandTrace[parts.Length]; - int i = 0; - foreach (var item in parts) + case RespPrefix.Array: + + static CommandTrace ParseOne(ref RespReader reader) { - var subParts = item.GetItems(); - if (!subParts[0].TryGetInt64(out long uniqueid) || !subParts[1].TryGetInt64(out long time) || !subParts[2].TryGetInt64(out long duration)) - return false; - arr[i++] = new CommandTrace(uniqueid, time, duration, subParts[3].GetItemsAsValues()!); + CommandTrace result = null!; + if (reader.IsAggregate) + { + long uniqueId = 0, time = 0, duration = 0; + if (reader.TryMoveNext() && reader.IsScalar && reader.TryReadInt64(out uniqueId) + && reader.TryMoveNext() && reader.IsScalar && reader.TryReadInt64(out time) + && reader.TryMoveNext() && reader.IsScalar && reader.TryReadInt64(out duration) + && reader.TryMoveNext() && reader.IsAggregate) + { + var values = reader.ReadPastRedisValues() ?? []; + result = new CommandTrace(uniqueId, time, duration, values); + } + } + return result; } + var arr = reader.ReadPastArray(ParseOne, scalar: false)!; + if (arr.AnyNull()) return false; + SetResult(message, arr); return true; } diff --git a/src/StackExchange.Redis/Condition.cs b/src/StackExchange.Redis/Condition.cs index ec7ee53b6..861abbbd2 100644 --- a/src/StackExchange.Redis/Condition.cs +++ b/src/StackExchange.Redis/Condition.cs @@ -388,7 +388,7 @@ public static Message CreateMessage(Condition condition, int db, CommandFlags fl new ConditionMessage(condition, db, flags, command, key, value, value1, value2, value3, value4); [System.Diagnostics.CodeAnalysis.SuppressMessage("Style", "IDE0071:Simplify interpolation", Justification = "Allocations (string.Concat vs. string.Format)")] - protected override bool SetResultCore(PhysicalConnection connection, Message message, in RawResult result) + protected override bool SetResultCore(PhysicalConnection connection, Message message, RawResult result) { connection?.BridgeCouldBeNull?.Multiplexer?.OnTransactionLog($"condition '{message.CommandAndKey}' got '{result.ToString()}'"); var msg = message as ConditionMessage; diff --git a/src/StackExchange.Redis/HotKeys.ResultProcessor.cs b/src/StackExchange.Redis/HotKeys.ResultProcessor.cs index a0f5b2892..d819e6dee 100644 --- a/src/StackExchange.Redis/HotKeys.ResultProcessor.cs +++ b/src/StackExchange.Redis/HotKeys.ResultProcessor.cs @@ -6,7 +6,7 @@ public sealed partial class HotKeysResult private sealed class HotKeysResultProcessor : ResultProcessor { - protected override bool SetResultCore(PhysicalConnection connection, Message message, in RawResult result) + protected override bool SetResultCore(PhysicalConnection connection, Message message, RawResult result) { if (result.IsNull) { diff --git a/src/StackExchange.Redis/Message.cs b/src/StackExchange.Redis/Message.cs index 441019959..5d5a6f050 100644 --- a/src/StackExchange.Redis/Message.cs +++ b/src/StackExchange.Redis/Message.cs @@ -602,8 +602,14 @@ internal static CommandFlags SetPrimaryReplicaFlags(CommandFlags everything, Com internal void Cancel() => resultBox?.Cancel(); // true if ready to be completed (i.e. false if re-issued to another server) - internal bool ComputeResult(PhysicalConnection connection, in RespReader frame) + internal bool ComputeResult(PhysicalConnection connection, ref RespReader reader) { + // we don't want to mutate reader, so that processors can consume attributes; however, + // we also don't want to force the entire reader to copy each time, so: snapshot + // just the prefix + var prefix = reader.GetFirstPrefix(); + + // intentionally "frame" is an isolated copy var box = resultBox; try { @@ -611,12 +617,11 @@ internal bool ComputeResult(PhysicalConnection connection, in RespReader frame) if (resultProcessor == null) return true; // false here would be things like resends (MOVED) - the message is not yet complete - var mutable = frame; - return resultProcessor.SetResult(connection, this, ref mutable); + return resultProcessor.SetResult(connection, this, ref reader); } catch (Exception ex) { - ex.Data.Add("got", frame.Prefix.ToString()); + ex.Data.Add("got", prefix.ToString()); connection?.BridgeCouldBeNull?.Multiplexer?.OnMessageFaulted(this, ex); box?.SetException(ex); return box != null; // we still want to pulse/complete diff --git a/src/StackExchange.Redis/PhysicalConnection.Read.cs b/src/StackExchange.Redis/PhysicalConnection.Read.cs index b7c976082..b2d790f1b 100644 --- a/src/StackExchange.Redis/PhysicalConnection.Read.cs +++ b/src/StackExchange.Redis/PhysicalConnection.Read.cs @@ -511,7 +511,7 @@ static void Throw(ReadOnlySpan frame) Trace("Response to: " + msg); _readStatus = ReadStatus.ComputeResult; var reader = new RespReader(frame); - if (msg.ComputeResult(this, reader)) + if (msg.ComputeResult(this, ref reader)) { _readStatus = msg.ResultBoxIsAsync ? ReadStatus.CompletePendingMessageAsync : ReadStatus.CompletePendingMessageSync; if (!msg.IsHighIntegrity) diff --git a/src/StackExchange.Redis/RedisDatabase.cs b/src/StackExchange.Redis/RedisDatabase.cs index ac3c14bcc..8c0027d13 100644 --- a/src/StackExchange.Redis/RedisDatabase.cs +++ b/src/StackExchange.Redis/RedisDatabase.cs @@ -5580,7 +5580,7 @@ private abstract class ScanResultProcessor : ResultProcessor Default = new StringGetWithExpiryProcessor(); private StringGetWithExpiryProcessor() { } - protected override bool SetResultCore(PhysicalConnection connection, Message message, in RawResult result) + protected override bool SetResultCore(PhysicalConnection connection, Message message, RawResult result) { switch (result.Resp2TypeBulkString) { diff --git a/src/StackExchange.Redis/RedisLiterals.cs b/src/StackExchange.Redis/RedisLiterals.cs index be79b3267..9a8f54971 100644 --- a/src/StackExchange.Redis/RedisLiterals.cs +++ b/src/StackExchange.Redis/RedisLiterals.cs @@ -39,31 +39,6 @@ public static readonly CommandBytes id = "id"; } - internal static partial class CommonRepliesHash - { -#pragma warning disable CS8981, SA1300, SA1134 // forgive naming - // ReSharper disable InconsistentNaming - [FastHash] internal static partial class length { } - [FastHash] internal static partial class radix_tree_keys { } - [FastHash] internal static partial class radix_tree_nodes { } - [FastHash] internal static partial class last_generated_id { } - [FastHash] internal static partial class max_deleted_entry_id { } - [FastHash] internal static partial class entries_added { } - [FastHash] internal static partial class recorded_first_entry_id { } - [FastHash] internal static partial class idmp_duration { } - [FastHash] internal static partial class idmp_maxsize { } - [FastHash] internal static partial class pids_tracked { } - [FastHash] internal static partial class first_entry { } - [FastHash] internal static partial class last_entry { } - [FastHash] internal static partial class groups { } - [FastHash] internal static partial class iids_tracked { } - [FastHash] internal static partial class iids_added { } - [FastHash] internal static partial class iids_duplicates { } - - // ReSharper restore InconsistentNaming -#pragma warning restore CS8981, SA1300, SA1134 // forgive naming - } - internal static class RedisLiterals { // unlike primary commands, these do not get altered by the command-map; we may as diff --git a/src/StackExchange.Redis/RedisServer.cs b/src/StackExchange.Redis/RedisServer.cs index 2d7e184ad..f5b2814a3 100644 --- a/src/StackExchange.Redis/RedisServer.cs +++ b/src/StackExchange.Redis/RedisServer.cs @@ -891,7 +891,7 @@ private protected override Message CreateMessage(in RedisValue cursor) public static readonly ResultProcessor processor = new ScanResultProcessor(); private sealed class ScanResultProcessor : ResultProcessor { - protected override bool SetResultCore(PhysicalConnection connection, Message message, in RawResult result) + protected override bool SetResultCore(PhysicalConnection connection, Message message, RawResult result) { switch (result.Resp2TypeArray) { diff --git a/src/StackExchange.Redis/RedisTransaction.cs b/src/StackExchange.Redis/RedisTransaction.cs index f0a9600fa..5943238a1 100644 --- a/src/StackExchange.Redis/RedisTransaction.cs +++ b/src/StackExchange.Redis/RedisTransaction.cs @@ -3,6 +3,7 @@ using System.Text; using System.Threading; using System.Threading.Tasks; +using RESPite.Messages; namespace StackExchange.Redis { @@ -201,7 +202,7 @@ private sealed class QueuedProcessor : ResultProcessor { public static readonly ResultProcessor Default = new QueuedProcessor(); - protected override bool SetResultCore(PhysicalConnection connection, Message message, in RawResult result) + protected override bool SetResultCore(PhysicalConnection connection, Message message, RawResult result) { if (result.Resp2TypeBulkString == ResultType.SimpleString && result.IsEqual(CommonReplies.QUEUED)) { @@ -469,11 +470,13 @@ private sealed class TransactionProcessor : ResultProcessor { public static readonly TransactionProcessor Default = new(); - public override bool SetResult(PhysicalConnection connection, Message message, in RawResult result) + public override bool SetResult(PhysicalConnection connection, Message message, ref RespReader reader) { - if (result.IsError && message is TransactionMessage tran) + var copy = reader; + reader.MovePastBof(); + if (reader.IsError && message is TransactionMessage tran) { - string error = result.GetString()!; + string error = reader.ReadString()!; foreach (var op in tran.InnerOperations) { var inner = op.Wrapped; @@ -481,10 +484,10 @@ public override bool SetResult(PhysicalConnection connection, Message message, i inner.Complete(); } } - return base.SetResult(connection, message, result); + return base.SetResult(connection, message, ref copy); } - protected override bool SetResultCore(PhysicalConnection connection, Message message, in RawResult result) + protected override bool SetResultCore(PhysicalConnection connection, Message message, RawResult result) { var muxer = connection.BridgeCouldBeNull?.Multiplexer; muxer?.OnTransactionLog($"got {result} for {message.CommandAndKey}"); diff --git a/src/StackExchange.Redis/RespReaderExtensions.cs b/src/StackExchange.Redis/RespReaderExtensions.cs index cb60c8883..6a0f93be1 100644 --- a/src/StackExchange.Redis/RespReaderExtensions.cs +++ b/src/StackExchange.Redis/RespReaderExtensions.cs @@ -8,9 +8,13 @@ internal static class RespReaderExtensions extension(in RespReader reader) { public RespPrefix Resp2PrefixBulkString => reader.Prefix.ToResp2(RespPrefix.BulkString); - // if null, assume array public RespPrefix Resp2PrefixArray => reader.Prefix.ToResp2(RespPrefix.Array); + [Obsolete("Use Resp2PrefixBulkString instead", error: true)] + public RespPrefix Resp2TypeBulkString => reader.Resp2PrefixBulkString; + [Obsolete("Use Resp2PrefixArray instead", error: true)] + public RespPrefix Resp2TypeArray => reader.Resp2PrefixArray; + public RedisValue ReadRedisValue() { reader.DemandScalar(); @@ -36,11 +40,32 @@ public string OverviewString() _ => $"(unknown: {reader.Prefix})", }; } + + public RespPrefix GetFirstPrefix() + { + var prefix = reader.Prefix; + if (prefix is RespPrefix.None) + { + var mutable = reader; + mutable.MovePastBof(); + prefix = mutable.Prefix; + } + return prefix; + } } extension(ref RespReader reader) { public bool SafeTryMoveNext() => reader.TryMoveNext(checkError: false) & !reader.IsError; + + public void MovePastBof() + { + // if we're at BOF, read the first element, ignoring errors + if (reader.Prefix is RespPrefix.None) reader.SafeTryMoveNext(); + } + + public RedisValue[]? ReadPastRedisValues() + => reader.ReadPastArray(static (ref r) => r.ReadRedisValue(), scalar: true); } public static RespPrefix GetRespPrefix(ReadOnlySpan frame) @@ -69,4 +94,17 @@ public RespPrefix ToResp2(RespPrefix nullValue) }; } } + + extension(T?[] array) where T : class + { + internal bool AnyNull() + { + foreach (var el in array) + { + if (el is null) return true; + } + + return false; + } + } } diff --git a/src/StackExchange.Redis/ResultProcessor.Digest.cs b/src/StackExchange.Redis/ResultProcessor.Digest.cs index 757009ea5..1546fa7b9 100644 --- a/src/StackExchange.Redis/ResultProcessor.Digest.cs +++ b/src/StackExchange.Redis/ResultProcessor.Digest.cs @@ -11,7 +11,7 @@ internal abstract partial class ResultProcessor private sealed class DigestProcessor : ResultProcessor { - protected override bool SetResultCore(PhysicalConnection connection, Message message, in RawResult result) + protected override bool SetResultCore(PhysicalConnection connection, Message message, RawResult result) { if (result.IsNull) // for example, key doesn't exist { diff --git a/src/StackExchange.Redis/ResultProcessor.Lease.cs b/src/StackExchange.Redis/ResultProcessor.Lease.cs index c0f9e6d8e..e86a05c63 100644 --- a/src/StackExchange.Redis/ResultProcessor.Lease.cs +++ b/src/StackExchange.Redis/ResultProcessor.Lease.cs @@ -17,7 +17,7 @@ public static readonly ResultProcessor> private abstract class LeaseProcessor : ResultProcessor?> { - protected override bool SetResultCore(PhysicalConnection connection, Message message, in RawResult result) + protected override bool SetResultCore(PhysicalConnection connection, Message message, RawResult result) { if (result.Resp2TypeArray != ResultType.Array) { @@ -56,7 +56,7 @@ protected override bool SetResultCore(PhysicalConnection connection, Message mes private abstract class InterleavedLeaseProcessor : ResultProcessor?> { - protected override bool SetResultCore(PhysicalConnection connection, Message message, in RawResult result) + protected override bool SetResultCore(PhysicalConnection connection, Message message, RawResult result) { if (result.Resp2TypeArray != ResultType.Array) { @@ -120,7 +120,7 @@ protected virtual bool TryReadOne(in RawResult result, out T value) return false; } - protected override bool SetResultCore(PhysicalConnection connection, Message message, in RawResult result) + protected override bool SetResultCore(PhysicalConnection connection, Message message, RawResult result) { if (result.Resp2TypeArray != ResultType.Array) { @@ -183,7 +183,7 @@ protected override bool TryParse(in RawResult raw, out float parsed) private sealed class LeaseProcessor : ResultProcessor> { - protected override bool SetResultCore(PhysicalConnection connection, Message message, in RawResult result) + protected override bool SetResultCore(PhysicalConnection connection, Message message, RawResult result) { switch (result.Resp2TypeBulkString) { @@ -199,7 +199,7 @@ protected override bool SetResultCore(PhysicalConnection connection, Message mes private sealed class LeaseFromArrayProcessor : ResultProcessor> { - protected override bool SetResultCore(PhysicalConnection connection, Message message, in RawResult result) + protected override bool SetResultCore(PhysicalConnection connection, Message message, RawResult result) { switch (result.Resp2TypeBulkString) { diff --git a/src/StackExchange.Redis/ResultProcessor.Literals.cs b/src/StackExchange.Redis/ResultProcessor.Literals.cs new file mode 100644 index 000000000..79ea83150 --- /dev/null +++ b/src/StackExchange.Redis/ResultProcessor.Literals.cs @@ -0,0 +1,39 @@ +namespace StackExchange.Redis; + +internal partial class ResultProcessor +{ + internal partial class Literals + { +#pragma warning disable CS8981, SA1300, SA1134 // forgive naming etc + // ReSharper disable InconsistentNaming + [FastHash] internal static partial class NOAUTH { } + [FastHash] internal static partial class WRONGPASS { } + [FastHash] internal static partial class NOSCRIPT { } + [FastHash] internal static partial class MOVED { } + [FastHash] internal static partial class ASK { } + [FastHash] internal static partial class READONLY { } + [FastHash] internal static partial class LOADING { } + [FastHash("ERR operation not permitted")] + internal static partial class ERR_not_permitted { } + + [FastHash] internal static partial class length { } + [FastHash] internal static partial class radix_tree_keys { } + [FastHash] internal static partial class radix_tree_nodes { } + [FastHash] internal static partial class last_generated_id { } + [FastHash] internal static partial class max_deleted_entry_id { } + [FastHash] internal static partial class entries_added { } + [FastHash] internal static partial class recorded_first_entry_id { } + [FastHash] internal static partial class idmp_duration { } + [FastHash] internal static partial class idmp_maxsize { } + [FastHash] internal static partial class pids_tracked { } + [FastHash] internal static partial class first_entry { } + [FastHash] internal static partial class last_entry { } + [FastHash] internal static partial class groups { } + [FastHash] internal static partial class iids_tracked { } + [FastHash] internal static partial class iids_added { } + [FastHash] internal static partial class iids_duplicates { } + + // ReSharper restore InconsistentNaming +#pragma warning restore CS8981, SA1300, SA1134 // forgive naming etc + } +} diff --git a/src/StackExchange.Redis/ResultProcessor.VectorSets.cs b/src/StackExchange.Redis/ResultProcessor.VectorSets.cs index 8743ebd0b..70f548264 100644 --- a/src/StackExchange.Redis/ResultProcessor.VectorSets.cs +++ b/src/StackExchange.Redis/ResultProcessor.VectorSets.cs @@ -45,7 +45,7 @@ protected override bool TryReadOne(in RawResult result, out RedisValue value) private sealed partial class VectorSetInfoProcessor : ResultProcessor { - protected override bool SetResultCore(PhysicalConnection connection, Message message, in RawResult result) + protected override bool SetResultCore(PhysicalConnection connection, Message message, RawResult result) { if (result.Resp2TypeArray == ResultType.Array) { diff --git a/src/StackExchange.Redis/ResultProcessor.cs b/src/StackExchange.Redis/ResultProcessor.cs index dac25f427..2d18d7ebe 100644 --- a/src/StackExchange.Redis/ResultProcessor.cs +++ b/src/StackExchange.Redis/ResultProcessor.cs @@ -223,9 +223,9 @@ public static void SetException(Message? message, Exception ex) box?.SetException(ex); } // true if ready to be completed (i.e. false if re-issued to another server) - public bool SetResult(PhysicalConnection connection, Message message, ref RespReader reader) + public virtual bool SetResult(PhysicalConnection connection, Message message, ref RespReader reader) { - reader.SafeTryMoveNext(); + reader.MovePastBof(); var bridge = connection.BridgeCouldBeNull; if (message is LoggingMessage logging) { @@ -233,117 +233,133 @@ public bool SetResult(PhysicalConnection connection, Message message, ref RespRe { logging.Log?.LogInformationResponse(bridge?.Name, message.CommandAndKey, reader.OverviewString()); } - catch { } + catch (Exception ex) + { + Debug.WriteLine(ex.Message); + } } if (reader.IsError) { - if (result.StartsWith(CommonReplies.NOAUTH)) - { - bridge?.Multiplexer.SetAuthSuspect(new RedisServerException("NOAUTH Returned - connection has not yet authenticated")); - } - else if (result.StartsWith(CommonReplies.WRONGPASS)) - { - bridge?.Multiplexer.SetAuthSuspect(new RedisServerException(result.ToString())); - } + return HandleCommonError(message, reader, bridge); + } - var server = bridge?.ServerEndPoint; - bool log = !message.IsInternalCall; - bool isMoved = result.StartsWith(CommonReplies.MOVED); - bool wasNoRedirect = (message.Flags & CommandFlags.NoRedirect) != 0; - string? err = string.Empty; - bool unableToConnectError = false; - if (isMoved || result.StartsWith(CommonReplies.ASK)) - { - message.SetResponseReceived(); + var copy = reader; + if (SetResultCore(connection, message, ref reader)) + { + bridge?.Multiplexer.Trace("Completed with success: " + copy.OverviewString() + " (" + GetType().Name + ")", ToString()); + } + else + { + UnexpectedResponse(message, in copy); + } + return true; + } + + private bool HandleCommonError(Message message, RespReader reader, PhysicalBridge? bridge) + { + if (reader.StartsWith(Literals.NOAUTH.U8)) + { + bridge?.Multiplexer.SetAuthSuspect(new RedisServerException("NOAUTH Returned - connection has not yet authenticated")); + } + else if (reader.StartsWith(Literals.WRONGPASS.U8)) + { + bridge?.Multiplexer.SetAuthSuspect(new RedisServerException(reader.OverviewString())); + } + + var server = bridge?.ServerEndPoint; + bool log = !message.IsInternalCall; + bool isMoved = reader.StartsWith(Literals.MOVED.U8); + bool wasNoRedirect = (message.Flags & CommandFlags.NoRedirect) != 0; + string? err = string.Empty; + bool unableToConnectError = false; + if (isMoved || reader.StartsWith(Literals.ASK.U8)) + { + message.SetResponseReceived(); - log = false; - string[] parts = result.GetString()!.Split(StringSplits.Space, 3); - if (Format.TryParseInt32(parts[1], out int hashSlot) - && Format.TryParseEndPoint(parts[2], out var endpoint)) + log = false; + string[] parts = reader.ReadString()!.Split(StringSplits.Space, 3); + if (Format.TryParseInt32(parts[1], out int hashSlot) + && Format.TryParseEndPoint(parts[2], out var endpoint)) + { + // no point sending back to same server, and no point sending to a dead server + if (!Equals(server?.EndPoint, endpoint)) { - // no point sending back to same server, and no point sending to a dead server - if (!Equals(server?.EndPoint, endpoint)) + if (bridge is null) { - if (bridge is null) - { - // already toast - } - else if (bridge.Multiplexer.TryResend(hashSlot, message, endpoint, isMoved)) + // already toast + } + else if (bridge.Multiplexer.TryResend(hashSlot, message, endpoint, isMoved)) + { + bridge.Multiplexer.Trace(message.Command + " re-issued to " + endpoint, isMoved ? "MOVED" : "ASK"); + return false; + } + else + { + if (isMoved && wasNoRedirect) { - bridge.Multiplexer.Trace(message.Command + " re-issued to " + endpoint, isMoved ? "MOVED" : "ASK"); - return false; + if (bridge.Multiplexer.RawConfig.IncludeDetailInExceptions) + { + err = $"Key has MOVED to Endpoint {endpoint} and hashslot {hashSlot} but CommandFlags.NoRedirect was specified - redirect not followed for {message.CommandAndKey}. "; + } + else + { + err = "Key has MOVED but CommandFlags.NoRedirect was specified - redirect not followed. "; + } } else { - if (isMoved && wasNoRedirect) + unableToConnectError = true; + if (bridge.Multiplexer.RawConfig.IncludeDetailInExceptions) { - if (bridge.Multiplexer.RawConfig.IncludeDetailInExceptions) - { - err = $"Key has MOVED to Endpoint {endpoint} and hashslot {hashSlot} but CommandFlags.NoRedirect was specified - redirect not followed for {message.CommandAndKey}. "; - } - else - { - err = "Key has MOVED but CommandFlags.NoRedirect was specified - redirect not followed. "; - } + err = $"Endpoint {endpoint} serving hashslot {hashSlot} is not reachable at this point of time. Please check connectTimeout value. If it is low, try increasing it to give the ConnectionMultiplexer a chance to recover from the network disconnect. " + + PerfCounterHelper.GetThreadPoolAndCPUSummary(); } else { - unableToConnectError = true; - if (bridge.Multiplexer.RawConfig.IncludeDetailInExceptions) - { - err = $"Endpoint {endpoint} serving hashslot {hashSlot} is not reachable at this point of time. Please check connectTimeout value. If it is low, try increasing it to give the ConnectionMultiplexer a chance to recover from the network disconnect. " - + PerfCounterHelper.GetThreadPoolAndCPUSummary(); - } - else - { - err = "Endpoint is not reachable at this point of time. Please check connectTimeout value. If it is low, try increasing it to give the ConnectionMultiplexer a chance to recover from the network disconnect. "; - } + err = "Endpoint is not reachable at this point of time. Please check connectTimeout value. If it is low, try increasing it to give the ConnectionMultiplexer a chance to recover from the network disconnect. "; } } } } } + } - if (string.IsNullOrWhiteSpace(err)) - { - err = result.GetString()!; - } + if (string.IsNullOrWhiteSpace(err)) + { + err = reader.ReadString()!; + } - if (log && server != null) - { - bridge?.Multiplexer.OnErrorMessage(server.EndPoint, err); - } - bridge?.Multiplexer.Trace("Completed with error: " + err + " (" + GetType().Name + ")", ToString()); - if (unableToConnectError) - { - ConnectionFail(message, ConnectionFailureType.UnableToConnect, err); - } - else - { - ServerFail(message, err); - } + if (log && server != null) + { + bridge?.Multiplexer.OnErrorMessage(server.EndPoint, err); + } + bridge?.Multiplexer.Trace("Completed with error: " + err + " (" + GetType().Name + ")", ToString()); + if (unableToConnectError) + { + ConnectionFail(message, ConnectionFailureType.UnableToConnect, err); } else { - bool coreResult = SetResultCore(connection, message, result); - if (coreResult) - { - bridge?.Multiplexer.Trace("Completed with success: " + result.ToString() + " (" + GetType().Name + ")", ToString()); - } - else - { - UnexpectedResponse(message, result); - } + ServerFail(message, err); } + return true; } - protected abstract bool SetResultCore(PhysicalConnection connection, Message message, in RawResult result); + protected virtual bool SetResultCore(PhysicalConnection connection, Message message, ref RespReader reader) + { + // temp hack so we can compile; this should be abstract + return false; + } - private void UnexpectedResponse(Message message, in RawResult result) + // temp hack so we can compile; this should be removed + protected virtual bool SetResultCore(PhysicalConnection connection, Message message, RawResult result) + => throw new NotImplementedException(GetType().Name + "." + nameof(SetResultCore)); + + private void UnexpectedResponse(Message message, in RespReader reader) { ConnectionMultiplexer.TraceWithoutContext("From " + GetType().Name, "Unexpected Response"); - ConnectionFail(message, ConnectionFailureType.ProtocolFailure, "Unexpected response to " + (message?.CommandString ?? "n/a") + ": " + result.ToString()); + ConnectionFail(message, ConnectionFailureType.ProtocolFailure, "Unexpected response to " + (message?.CommandString ?? "n/a") + ": " + reader.OverviewString()); } public sealed class TimeSpanProcessor : ResultProcessor @@ -385,7 +401,7 @@ public bool TryParse(in RawResult result, out TimeSpan? expiry) return false; } - protected override bool SetResultCore(PhysicalConnection connection, Message message, in RawResult result) + protected override bool SetResultCore(PhysicalConnection connection, Message message, RawResult result) { if (TryParse(result, out TimeSpan? expiry)) { @@ -403,7 +419,7 @@ public sealed class TimingProcessor : ResultProcessor public static TimerMessage CreateMessage(int db, CommandFlags flags, RedisCommand command, RedisValue value = default) => new TimerMessage(db, flags, command, value); - protected override bool SetResultCore(PhysicalConnection connection, Message message, in RawResult result) + protected override bool SetResultCore(PhysicalConnection connection, Message message, RawResult result) { if (result.IsError) { @@ -461,7 +477,7 @@ public sealed class TrackSubscriptionsProcessor : ResultProcessor private ConnectionMultiplexer.Subscription? Subscription { get; } public TrackSubscriptionsProcessor(ConnectionMultiplexer.Subscription? sub) => Subscription = sub; - protected override bool SetResultCore(PhysicalConnection connection, Message message, in RawResult result) + protected override bool SetResultCore(PhysicalConnection connection, Message message, RawResult result) { if (result.Resp2TypeArray == ResultType.Array) { @@ -519,7 +535,7 @@ public static bool TryGet(in RawResult result, out bool value) return false; } - protected override bool SetResultCore(PhysicalConnection connection, Message message, in RawResult result) + protected override bool SetResultCore(PhysicalConnection connection, Message message, RawResult result) { if (TryGet(result, out bool value)) { @@ -572,7 +588,7 @@ static int FromHex(char c) // note that top-level error messages still get handled by SetResult, but nested errors // (is that a thing?) will be wrapped in the RedisResult - protected override bool SetResultCore(PhysicalConnection connection, Message message, in RawResult result) + protected override bool SetResultCore(PhysicalConnection connection, Message message, RawResult result) { switch (result.Resp2TypeBulkString) { @@ -618,7 +634,7 @@ public static bool TryParse(in RawResult result, out SortedSetEntry? entry) } } - protected override bool SetResultCore(PhysicalConnection connection, Message message, in RawResult result) + protected override bool SetResultCore(PhysicalConnection connection, Message message, RawResult result) { if (TryParse(result, out SortedSetEntry? entry)) { @@ -637,7 +653,7 @@ protected override SortedSetEntry Parse(in RawResult first, in RawResult second, internal sealed class SortedSetPopResultProcessor : ResultProcessor { - protected override bool SetResultCore(PhysicalConnection connection, Message message, in RawResult result) + protected override bool SetResultCore(PhysicalConnection connection, Message message, RawResult result) { if (result.Resp2TypeArray == ResultType.Array) { @@ -658,7 +674,7 @@ protected override bool SetResultCore(PhysicalConnection connection, Message mes internal sealed class ListPopResultProcessor : ResultProcessor { - protected override bool SetResultCore(PhysicalConnection connection, Message message, in RawResult result) + protected override bool SetResultCore(PhysicalConnection connection, Message message, RawResult result) { if (result.Resp2TypeArray == ResultType.Array) { @@ -799,7 +815,7 @@ public bool TryParse(in RawResult result, out T[]? pairs, bool allowOversized, o } protected abstract T Parse(in RawResult first, in RawResult second, object? state); - protected override bool SetResultCore(PhysicalConnection connection, Message message, in RawResult result) + protected override bool SetResultCore(PhysicalConnection connection, Message message, RawResult result) { if (TryParse(result, out T[]? arr)) { @@ -815,9 +831,11 @@ internal sealed class AutoConfigureProcessor : ResultProcessor private ILogger? Log { get; } public AutoConfigureProcessor(ILogger? log = null) => Log = log; - public override bool SetResult(PhysicalConnection connection, Message message, in RawResult result) + public override bool SetResult(PhysicalConnection connection, Message message, ref RespReader reader) { - if (result.IsError && result.StartsWith(CommonReplies.READONLY)) + var copy = reader; + reader.MovePastBof(); + if (reader.IsError && reader.StartsWith(Literals.READONLY.U8)) { var bridge = connection.BridgeCouldBeNull; if (bridge != null) @@ -828,10 +846,10 @@ public override bool SetResult(PhysicalConnection connection, Message message, i } } - return base.SetResult(connection, message, result); + return base.SetResult(connection, message, ref copy); } - protected override bool SetResultCore(PhysicalConnection connection, Message message, in RawResult result) + protected override bool SetResultCore(PhysicalConnection connection, Message message, RawResult result) { var server = connection.BridgeCouldBeNull?.ServerEndPoint; if (server == null) return false; @@ -1072,7 +1090,7 @@ private static bool TryParseRole(string? val, out bool isReplica) private sealed class BooleanProcessor : ResultProcessor { - protected override bool SetResultCore(PhysicalConnection connection, Message message, in RawResult result) + protected override bool SetResultCore(PhysicalConnection connection, Message message, RawResult result) { if (result.IsNull) { @@ -1110,7 +1128,7 @@ protected override bool SetResultCore(PhysicalConnection connection, Message mes private sealed class ByteArrayProcessor : ResultProcessor { - protected override bool SetResultCore(PhysicalConnection connection, Message message, in RawResult result) + protected override bool SetResultCore(PhysicalConnection connection, Message message, RawResult result) { switch (result.Resp2TypeBulkString) { @@ -1133,7 +1151,7 @@ internal static ClusterConfiguration Parse(PhysicalConnection connection, string return config; } - protected override bool SetResultCore(PhysicalConnection connection, Message message, in RawResult result) + protected override bool SetResultCore(PhysicalConnection connection, Message message, RawResult result) { switch (result.Resp2TypeBulkString) { @@ -1156,7 +1174,7 @@ protected override bool SetResultCore(PhysicalConnection connection, Message mes private sealed class ClusterNodesRawProcessor : ResultProcessor { - protected override bool SetResultCore(PhysicalConnection connection, Message message, in RawResult result) + protected override bool SetResultCore(PhysicalConnection connection, Message message, RawResult result) { switch (result.Resp2TypeBulkString) { @@ -1181,7 +1199,7 @@ protected override bool SetResultCore(PhysicalConnection connection, Message mes private sealed class ConnectionIdentityProcessor : ResultProcessor { - protected override bool SetResultCore(PhysicalConnection connection, Message message, in RawResult result) + protected override bool SetResultCore(PhysicalConnection connection, Message message, RawResult result) { if (connection.BridgeCouldBeNull is PhysicalBridge bridge) { @@ -1194,7 +1212,7 @@ protected override bool SetResultCore(PhysicalConnection connection, Message mes private sealed class DateTimeProcessor : ResultProcessor { - protected override bool SetResultCore(PhysicalConnection connection, Message message, in RawResult result) + protected override bool SetResultCore(PhysicalConnection connection, Message message, RawResult result) { long unixTime; switch (result.Resp2TypeArray) @@ -1239,7 +1257,7 @@ public sealed class NullableDateTimeProcessor : ResultProcessor private readonly bool isMilliseconds; public NullableDateTimeProcessor(bool fromMilliseconds) => isMilliseconds = fromMilliseconds; - protected override bool SetResultCore(PhysicalConnection connection, Message message, in RawResult result) + protected override bool SetResultCore(PhysicalConnection connection, Message message, RawResult result) { switch (result.Resp2TypeBulkString) { @@ -1264,7 +1282,7 @@ protected override bool SetResultCore(PhysicalConnection connection, Message mes private sealed class DoubleProcessor : ResultProcessor { - protected override bool SetResultCore(PhysicalConnection connection, Message message, in RawResult result) + protected override bool SetResultCore(PhysicalConnection connection, Message message, RawResult result) { switch (result.Resp2TypeBulkString) { @@ -1298,7 +1316,7 @@ public ExpectBasicStringProcessor(CommandBytes expected, bool startsWith = false _startsWith = startsWith; } - protected override bool SetResultCore(PhysicalConnection connection, Message message, in RawResult result) + protected override bool SetResultCore(PhysicalConnection connection, Message message, RawResult result) { if (_startsWith ? result.StartsWith(_expected) : result.IsEqual(_expected)) { @@ -1312,7 +1330,7 @@ protected override bool SetResultCore(PhysicalConnection connection, Message mes private sealed class InfoProcessor : ResultProcessor>[]> { - protected override bool SetResultCore(PhysicalConnection connection, Message message, in RawResult result) + protected override bool SetResultCore(PhysicalConnection connection, Message message, RawResult result) { if (result.Resp2TypeBulkString == ResultType.BulkString) { @@ -1355,7 +1373,7 @@ private sealed class Int64DefaultValueProcessor : ResultProcessor public Int64DefaultValueProcessor(long defaultValue) => _defaultValue = defaultValue; - protected override bool SetResultCore(PhysicalConnection connection, Message message, in RawResult result) + protected override bool SetResultCore(PhysicalConnection connection, Message message, RawResult result) { if (result.IsNull) { @@ -1373,7 +1391,7 @@ protected override bool SetResultCore(PhysicalConnection connection, Message mes private class Int64Processor : ResultProcessor { - protected override bool SetResultCore(PhysicalConnection connection, Message message, in RawResult result) + protected override bool SetResultCore(PhysicalConnection connection, Message message, RawResult result) { switch (result.Resp2TypeBulkString) { @@ -1393,7 +1411,7 @@ protected override bool SetResultCore(PhysicalConnection connection, Message mes private class Int32Processor : ResultProcessor { - protected override bool SetResultCore(PhysicalConnection connection, Message message, in RawResult result) + protected override bool SetResultCore(PhysicalConnection connection, Message message, RawResult result) { switch (result.Resp2TypeBulkString) { @@ -1422,7 +1440,7 @@ private sealed class Int32EnumProcessor : ResultProcessor where T : unmana private Int32EnumProcessor() { } public static readonly Int32EnumProcessor Instance = new(); - protected override bool SetResultCore(PhysicalConnection connection, Message message, in RawResult result) + protected override bool SetResultCore(PhysicalConnection connection, Message message, RawResult result) { switch (result.Resp2TypeBulkString) { @@ -1456,7 +1474,7 @@ private sealed class Int32EnumArrayProcessor : ResultProcessor where T : private Int32EnumArrayProcessor() { } public static readonly Int32EnumArrayProcessor Instance = new(); - protected override bool SetResultCore(PhysicalConnection connection, Message message, in RawResult result) + protected override bool SetResultCore(PhysicalConnection connection, Message message, RawResult result) { switch (result.Resp2TypeArray) { @@ -1484,7 +1502,7 @@ protected override bool SetResultCore(PhysicalConnection connection, Message mes private sealed class PubSubNumSubProcessor : Int64Processor { - protected override bool SetResultCore(PhysicalConnection connection, Message message, in RawResult result) + protected override bool SetResultCore(PhysicalConnection connection, Message message, RawResult result) { if (result.Resp2TypeArray == ResultType.Array) { @@ -1501,7 +1519,7 @@ protected override bool SetResultCore(PhysicalConnection connection, Message mes private sealed class NullableDoubleArrayProcessor : ResultProcessor { - protected override bool SetResultCore(PhysicalConnection connection, Message message, in RawResult result) + protected override bool SetResultCore(PhysicalConnection connection, Message message, RawResult result) { if (result.Resp2TypeArray == ResultType.Array && !result.IsNull) { @@ -1515,7 +1533,7 @@ protected override bool SetResultCore(PhysicalConnection connection, Message mes private sealed class NullableDoubleProcessor : ResultProcessor { - protected override bool SetResultCore(PhysicalConnection connection, Message message, in RawResult result) + protected override bool SetResultCore(PhysicalConnection connection, Message message, RawResult result) { switch (result.Resp2TypeBulkString) { @@ -1540,7 +1558,7 @@ protected override bool SetResultCore(PhysicalConnection connection, Message mes private sealed class NullableInt64Processor : ResultProcessor { - protected override bool SetResultCore(PhysicalConnection connection, Message message, in RawResult result) + protected override bool SetResultCore(PhysicalConnection connection, Message message, RawResult result) { switch (result.Resp2TypeBulkString) { @@ -1576,7 +1594,7 @@ protected override bool SetResultCore(PhysicalConnection connection, Message mes private sealed class ExpireResultArrayProcessor : ResultProcessor { - protected override bool SetResultCore(PhysicalConnection connection, Message message, in RawResult result) + protected override bool SetResultCore(PhysicalConnection connection, Message message, RawResult result) { if (result.Resp2TypeArray == ResultType.Array || result.IsNull) { @@ -1591,7 +1609,7 @@ protected override bool SetResultCore(PhysicalConnection connection, Message mes private sealed class PersistResultArrayProcessor : ResultProcessor { - protected override bool SetResultCore(PhysicalConnection connection, Message message, in RawResult result) + protected override bool SetResultCore(PhysicalConnection connection, Message message, RawResult result) { if (result.Resp2TypeArray == ResultType.Array || result.IsNull) { @@ -1622,7 +1640,7 @@ public ChannelState(byte[]? prefix, RedisChannel.RedisChannelOptions options) Options = options; } } - protected override bool SetResultCore(PhysicalConnection connection, Message message, in RawResult result) + protected override bool SetResultCore(PhysicalConnection connection, Message message, RawResult result) { switch (result.Resp2TypeArray) { @@ -1640,7 +1658,7 @@ protected override bool SetResultCore(PhysicalConnection connection, Message mes private sealed class RedisKeyArrayProcessor : ResultProcessor { - protected override bool SetResultCore(PhysicalConnection connection, Message message, in RawResult result) + protected override bool SetResultCore(PhysicalConnection connection, Message message, RawResult result) { switch (result.Resp2TypeArray) { @@ -1655,7 +1673,7 @@ protected override bool SetResultCore(PhysicalConnection connection, Message mes private sealed class RedisKeyProcessor : ResultProcessor { - protected override bool SetResultCore(PhysicalConnection connection, Message message, in RawResult result) + protected override bool SetResultCore(PhysicalConnection connection, Message message, RawResult result) { switch (result.Resp2TypeBulkString) { @@ -1671,7 +1689,7 @@ protected override bool SetResultCore(PhysicalConnection connection, Message mes private sealed class RedisTypeProcessor : ResultProcessor { - protected override bool SetResultCore(PhysicalConnection connection, Message message, in RawResult result) + protected override bool SetResultCore(PhysicalConnection connection, Message message, RawResult result) { switch (result.Resp2TypeBulkString) { @@ -1690,7 +1708,7 @@ protected override bool SetResultCore(PhysicalConnection connection, Message mes private sealed class RedisValueArrayProcessor : ResultProcessor { - protected override bool SetResultCore(PhysicalConnection connection, Message message, in RawResult result) + protected override bool SetResultCore(PhysicalConnection connection, Message message, RawResult result) { switch (result.Resp2TypeBulkString) { @@ -1713,7 +1731,7 @@ protected override bool SetResultCore(PhysicalConnection connection, Message mes private sealed class Int64ArrayProcessor : ResultProcessor { - protected override bool SetResultCore(PhysicalConnection connection, Message message, in RawResult result) + protected override bool SetResultCore(PhysicalConnection connection, Message message, RawResult result) { if (result.Resp2TypeArray == ResultType.Array && !result.IsNull) { @@ -1728,7 +1746,7 @@ protected override bool SetResultCore(PhysicalConnection connection, Message mes private sealed class NullableStringArrayProcessor : ResultProcessor { - protected override bool SetResultCore(PhysicalConnection connection, Message message, in RawResult result) + protected override bool SetResultCore(PhysicalConnection connection, Message message, RawResult result) { switch (result.Resp2TypeArray) { @@ -1744,7 +1762,7 @@ protected override bool SetResultCore(PhysicalConnection connection, Message mes private sealed class StringArrayProcessor : ResultProcessor { - protected override bool SetResultCore(PhysicalConnection connection, Message message, in RawResult result) + protected override bool SetResultCore(PhysicalConnection connection, Message message, RawResult result) { switch (result.Resp2TypeArray) { @@ -1759,7 +1777,7 @@ protected override bool SetResultCore(PhysicalConnection connection, Message mes private sealed class BooleanArrayProcessor : ResultProcessor { - protected override bool SetResultCore(PhysicalConnection connection, Message message, in RawResult result) + protected override bool SetResultCore(PhysicalConnection connection, Message message, RawResult result) { if (result.Resp2TypeArray == ResultType.Array && !result.IsNull) { @@ -1773,7 +1791,7 @@ protected override bool SetResultCore(PhysicalConnection connection, Message mes private sealed class RedisValueGeoPositionProcessor : ResultProcessor { - protected override bool SetResultCore(PhysicalConnection connection, Message message, in RawResult result) + protected override bool SetResultCore(PhysicalConnection connection, Message message, RawResult result) { switch (result.Resp2TypeArray) { @@ -1789,7 +1807,7 @@ protected override bool SetResultCore(PhysicalConnection connection, Message mes private sealed class RedisValueGeoPositionArrayProcessor : ResultProcessor { - protected override bool SetResultCore(PhysicalConnection connection, Message message, in RawResult result) + protected override bool SetResultCore(PhysicalConnection connection, Message message, RawResult result) { switch (result.Resp2TypeArray) { @@ -1826,7 +1844,7 @@ private GeoRadiusResultArrayProcessor(GeoRadiusOptions options) this.options = options; } - protected override bool SetResultCore(PhysicalConnection connection, Message message, in RawResult result) + protected override bool SetResultCore(PhysicalConnection connection, Message message, RawResult result) { switch (result.Resp2TypeArray) { @@ -1889,7 +1907,7 @@ The geohash integer. /// private sealed class LongestCommonSubsequenceProcessor : ResultProcessor { - protected override bool SetResultCore(PhysicalConnection connection, Message message, in RawResult result) + protected override bool SetResultCore(PhysicalConnection connection, Message message, RawResult result) { switch (result.Resp2TypeArray) { @@ -1923,7 +1941,7 @@ private static LCSMatchResult Parse(in RawResult result) private sealed class RedisValueProcessor : ResultProcessor { - protected override bool SetResultCore(PhysicalConnection connection, Message message, in RawResult result) + protected override bool SetResultCore(PhysicalConnection connection, Message message, RawResult result) { switch (result.Resp2TypeBulkString) { @@ -1939,7 +1957,7 @@ protected override bool SetResultCore(PhysicalConnection connection, Message mes private sealed class RedisValueFromArrayProcessor : ResultProcessor { - protected override bool SetResultCore(PhysicalConnection connection, Message message, in RawResult result) + protected override bool SetResultCore(PhysicalConnection connection, Message message, RawResult result) { switch (result.Resp2TypeBulkString) { @@ -1958,7 +1976,7 @@ protected override bool SetResultCore(PhysicalConnection connection, Message mes private sealed class RoleProcessor : ResultProcessor { - protected override bool SetResultCore(PhysicalConnection connection, Message message, in RawResult result) + protected override bool SetResultCore(PhysicalConnection connection, Message message, RawResult result) { var items = result.GetItems(); if (items.IsEmpty) @@ -2087,20 +2105,22 @@ private static bool TryParsePrimaryReplica(in Sequence items, out Rol private sealed class ScriptResultProcessor : ResultProcessor { - public override bool SetResult(PhysicalConnection connection, Message message, in RawResult result) + public override bool SetResult(PhysicalConnection connection, Message message, ref RespReader reader) { - if (result.IsError && result.StartsWith(CommonReplies.NOSCRIPT)) + var copy = reader; + reader.MovePastBof(); + if (reader.IsError && reader.StartsWith(Literals.NOSCRIPT.U8)) { // scripts are not flushed individually, so assume the entire script cache is toast ("SCRIPT FLUSH") connection.BridgeCouldBeNull?.ServerEndPoint?.FlushScriptCache(); message.SetScriptUnavailable(); } // and apply usual processing for the rest - return base.SetResult(connection, message, result); + return base.SetResult(connection, message, ref copy); } // note that top-level error messages still get handled by SetResult, but nested errors // (is that a thing?) will be wrapped in the RedisResult - protected override bool SetResultCore(PhysicalConnection connection, Message message, in RawResult result) + protected override bool SetResultCore(PhysicalConnection connection, Message message, RawResult result) { if (RedisResult.TryCreate(connection, result, out var value)) { @@ -2123,7 +2143,7 @@ public SingleStreamProcessor(bool skipStreamName = false) /// /// Handles . /// - protected override bool SetResultCore(PhysicalConnection connection, Message message, in RawResult result) + protected override bool SetResultCore(PhysicalConnection connection, Message message, RawResult result) { if (result.IsNull) { @@ -2230,7 +2250,7 @@ Multibulk array. (note that XREADGROUP may include additional interior elements; see ParseRedisStreamEntries) */ - protected override bool SetResultCore(PhysicalConnection connection, Message message, in RawResult result) + protected override bool SetResultCore(PhysicalConnection connection, Message message, RawResult result) { if (result.IsNull) { @@ -2287,7 +2307,7 @@ protected override RedisStream Parse(in RawResult first, in RawResult second, ob /// internal sealed class StreamAutoClaimProcessor : StreamProcessorBase { - protected override bool SetResultCore(PhysicalConnection connection, Message message, in RawResult result) + protected override bool SetResultCore(PhysicalConnection connection, Message message, RawResult result) { // See https://redis.io/commands/xautoclaim for command documentation. // Note that the result should never be null, so intentionally treating it as a failure to parse here @@ -2316,7 +2336,7 @@ protected override bool SetResultCore(PhysicalConnection connection, Message mes /// internal sealed class StreamAutoClaimIdsOnlyProcessor : ResultProcessor { - protected override bool SetResultCore(PhysicalConnection connection, Message message, in RawResult result) + protected override bool SetResultCore(PhysicalConnection connection, Message message, RawResult result) { // See https://redis.io/commands/xautoclaim for command documentation. // Note that the result should never be null, so intentionally treating it as a failure to parse here @@ -2492,7 +2512,7 @@ internal abstract class InterleavedStreamInfoProcessorBase : ResultProcessor< { protected abstract T ParseItem(in RawResult result); - protected override bool SetResultCore(PhysicalConnection connection, Message message, in RawResult result) + protected override bool SetResultCore(PhysicalConnection connection, Message message, RawResult result) { if (result.Resp2TypeArray != ResultType.Array) { @@ -2529,7 +2549,7 @@ internal sealed class StreamInfoProcessor : StreamProcessorBase // 12) 1) 1526569544280-0 // 2) 1) "message" // 2) "banana" - protected override bool SetResultCore(PhysicalConnection connection, Message message, in RawResult result) + protected override bool SetResultCore(PhysicalConnection connection, Message message, RawResult result) { if (result.Resp2TypeArray != ResultType.Array) { @@ -2554,54 +2574,54 @@ protected override bool SetResultCore(PhysicalConnection connection, Message mes var hash = key.Payload.Hash64(); switch (hash) { - case CommonRepliesHash.length.Hash when CommonRepliesHash.length.Is(hash, key): + case Literals.length.Hash when Literals.length.Is(hash, key): if (!value.TryGetInt64(out length)) return false; break; - case CommonRepliesHash.radix_tree_keys.Hash when CommonRepliesHash.radix_tree_keys.Is(hash, key): + case Literals.radix_tree_keys.Hash when Literals.radix_tree_keys.Is(hash, key): if (!value.TryGetInt64(out radixTreeKeys)) return false; break; - case CommonRepliesHash.radix_tree_nodes.Hash when CommonRepliesHash.radix_tree_nodes.Is(hash, key): + case Literals.radix_tree_nodes.Hash when Literals.radix_tree_nodes.Is(hash, key): if (!value.TryGetInt64(out radixTreeNodes)) return false; break; - case CommonRepliesHash.groups.Hash when CommonRepliesHash.groups.Is(hash, key): + case Literals.groups.Hash when Literals.groups.Is(hash, key): if (!value.TryGetInt64(out groups)) return false; break; - case CommonRepliesHash.last_generated_id.Hash when CommonRepliesHash.last_generated_id.Is(hash, key): + case Literals.last_generated_id.Hash when Literals.last_generated_id.Is(hash, key): lastGeneratedId = value.AsRedisValue(); break; - case CommonRepliesHash.first_entry.Hash when CommonRepliesHash.first_entry.Is(hash, key): + case Literals.first_entry.Hash when Literals.first_entry.Is(hash, key): firstEntry = ParseRedisStreamEntry(value); break; - case CommonRepliesHash.last_entry.Hash when CommonRepliesHash.last_entry.Is(hash, key): + case Literals.last_entry.Hash when Literals.last_entry.Is(hash, key): lastEntry = ParseRedisStreamEntry(value); break; // 7.0 - case CommonRepliesHash.max_deleted_entry_id.Hash when CommonRepliesHash.max_deleted_entry_id.Is(hash, key): + case Literals.max_deleted_entry_id.Hash when Literals.max_deleted_entry_id.Is(hash, key): maxDeletedEntryId = value.AsRedisValue(); break; - case CommonRepliesHash.recorded_first_entry_id.Hash when CommonRepliesHash.recorded_first_entry_id.Is(hash, key): + case Literals.recorded_first_entry_id.Hash when Literals.recorded_first_entry_id.Is(hash, key): recordedFirstEntryId = value.AsRedisValue(); break; - case CommonRepliesHash.entries_added.Hash when CommonRepliesHash.entries_added.Is(hash, key): + case Literals.entries_added.Hash when Literals.entries_added.Is(hash, key): if (!value.TryGetInt64(out entriesAdded)) return false; break; // 8.6 - case CommonRepliesHash.idmp_duration.Hash when CommonRepliesHash.idmp_duration.Is(hash, key): + case Literals.idmp_duration.Hash when Literals.idmp_duration.Is(hash, key): if (!value.TryGetInt64(out idmpDuration)) return false; break; - case CommonRepliesHash.idmp_maxsize.Hash when CommonRepliesHash.idmp_maxsize.Is(hash, key): + case Literals.idmp_maxsize.Hash when Literals.idmp_maxsize.Is(hash, key): if (!value.TryGetInt64(out idmpMaxsize)) return false; break; - case CommonRepliesHash.pids_tracked.Hash when CommonRepliesHash.pids_tracked.Is(hash, key): + case Literals.pids_tracked.Hash when Literals.pids_tracked.Is(hash, key): if (!value.TryGetInt64(out pidsTracked)) return false; break; - case CommonRepliesHash.iids_tracked.Hash when CommonRepliesHash.iids_tracked.Is(hash, key): + case Literals.iids_tracked.Hash when Literals.iids_tracked.Is(hash, key): if (!value.TryGetInt64(out iidsTracked)) return false; break; - case CommonRepliesHash.iids_added.Hash when CommonRepliesHash.iids_added.Is(hash, key): + case Literals.iids_added.Hash when Literals.iids_added.Is(hash, key): if (!value.TryGetInt64(out iidsAdded)) return false; break; - case CommonRepliesHash.iids_duplicates.Hash when CommonRepliesHash.iids_duplicates.Is(hash, key): + case Literals.iids_duplicates.Hash when Literals.iids_duplicates.Is(hash, key): if (!value.TryGetInt64(out iidsDuplicates)) return false; break; } @@ -2632,7 +2652,7 @@ protected override bool SetResultCore(PhysicalConnection connection, Message mes internal sealed class StreamPendingInfoProcessor : ResultProcessor { - protected override bool SetResultCore(PhysicalConnection connection, Message message, in RawResult result) + protected override bool SetResultCore(PhysicalConnection connection, Message message, RawResult result) { // Example: // > XPENDING mystream mygroup @@ -2684,7 +2704,7 @@ protected override bool SetResultCore(PhysicalConnection connection, Message mes internal sealed class StreamPendingMessagesProcessor : ResultProcessor { - protected override bool SetResultCore(PhysicalConnection connection, Message message, in RawResult result) + protected override bool SetResultCore(PhysicalConnection connection, Message message, RawResult result) { if (result.Resp2TypeArray != ResultType.Array) { @@ -2788,7 +2808,7 @@ protected override KeyValuePair Parse(in RawResult first, in Raw private sealed class StringProcessor : ResultProcessor { - protected override bool SetResultCore(PhysicalConnection connection, Message message, in RawResult result) + protected override bool SetResultCore(PhysicalConnection connection, Message message, RawResult result) { switch (result.Resp2TypeBulkString) { @@ -2812,7 +2832,7 @@ protected override bool SetResultCore(PhysicalConnection connection, Message mes private sealed class TieBreakerProcessor : ResultProcessor { - protected override bool SetResultCore(PhysicalConnection connection, Message message, in RawResult result) + protected override bool SetResultCore(PhysicalConnection connection, Message message, RawResult result) { switch (result.Resp2TypeBulkString) { @@ -2845,23 +2865,29 @@ public TracerProcessor(bool establishConnection) this.establishConnection = establishConnection; } - public override bool SetResult(PhysicalConnection connection, Message message, in RawResult result) + public override bool SetResult(PhysicalConnection connection, Message message, ref RespReader reader) { - connection.BridgeCouldBeNull?.Multiplexer.OnInfoMessage($"got '{result}' for '{message.CommandAndKey}' on '{connection}'"); - var final = base.SetResult(connection, message, result); - if (result.IsError) + reader.MovePastBof(); + bool isError = reader.IsError; + var copy = reader; + + connection.BridgeCouldBeNull?.Multiplexer.OnInfoMessage($"got '{reader.Prefix}' for '{message.CommandAndKey}' on '{connection}'"); + var final = base.SetResult(connection, message, ref reader); + + if (isError) { - if (result.StartsWith(CommonReplies.authFail_trimmed) || result.StartsWith(CommonReplies.NOAUTH)) + reader = copy; // rewind and re-parse + if (reader.StartsWith(Literals.ERR_not_permitted.U8) || reader.StartsWith(Literals.NOAUTH.U8)) { - connection.RecordConnectionFailed(ConnectionFailureType.AuthenticationFailure, new Exception(result.ToString() + " Verify if the Redis password provided is correct. Attempted command: " + message.Command)); + connection.RecordConnectionFailed(ConnectionFailureType.AuthenticationFailure, new Exception(reader.OverviewString() + " Verify if the Redis password provided is correct. Attempted command: " + message.Command)); } - else if (result.StartsWith(CommonReplies.loading)) + else if (reader.StartsWith(Literals.LOADING.U8)) { connection.RecordConnectionFailed(ConnectionFailureType.Loading); } else { - connection.RecordConnectionFailed(ConnectionFailureType.ProtocolFailure, new RedisServerException(result.ToString())); + connection.RecordConnectionFailed(ConnectionFailureType.ProtocolFailure, new RedisServerException(reader.OverviewString())); } } @@ -2875,7 +2901,7 @@ public override bool SetResult(PhysicalConnection connection, Message message, i } [System.Diagnostics.CodeAnalysis.SuppressMessage("Style", "IDE0071:Simplify interpolation", Justification = "Allocations (string.Concat vs. string.Format)")] - protected override bool SetResultCore(PhysicalConnection connection, Message message, in RawResult result) + protected override bool SetResultCore(PhysicalConnection connection, Message message, RawResult result) { bool happy; switch (message.Command) @@ -2932,7 +2958,7 @@ protected override bool SetResultCore(PhysicalConnection connection, Message mes private sealed class SentinelGetPrimaryAddressByNameProcessor : ResultProcessor { - protected override bool SetResultCore(PhysicalConnection connection, Message message, in RawResult result) + protected override bool SetResultCore(PhysicalConnection connection, Message message, RawResult result) { switch (result.Resp2TypeArray) { @@ -2960,7 +2986,7 @@ protected override bool SetResultCore(PhysicalConnection connection, Message mes private sealed class SentinelGetSentinelAddressesProcessor : ResultProcessor { - protected override bool SetResultCore(PhysicalConnection connection, Message message, in RawResult result) + protected override bool SetResultCore(PhysicalConnection connection, Message message, RawResult result) { List endPoints = []; @@ -2994,7 +3020,7 @@ protected override bool SetResultCore(PhysicalConnection connection, Message mes private sealed class SentinelGetReplicaAddressesProcessor : ResultProcessor { - protected override bool SetResultCore(PhysicalConnection connection, Message message, in RawResult result) + protected override bool SetResultCore(PhysicalConnection connection, Message message, RawResult result) { List endPoints = []; @@ -3033,7 +3059,7 @@ protected override bool SetResultCore(PhysicalConnection connection, Message mes private sealed class SentinelArrayOfArraysProcessor : ResultProcessor[][]> { - protected override bool SetResultCore(PhysicalConnection connection, Message message, in RawResult result) + protected override bool SetResultCore(PhysicalConnection connection, Message message, RawResult result) { if (StringPairInterleaved is not StringPairInterleavedProcessor innerProcessor) { @@ -3081,7 +3107,7 @@ protected static void SetResult(Message? message, T value) internal abstract class ArrayResultProcessor : ResultProcessor { - protected override bool SetResultCore(PhysicalConnection connection, Message message, in RawResult result) + protected override bool SetResultCore(PhysicalConnection connection, Message message, RawResult result) { switch (result.Resp2TypeArray) { diff --git a/src/StackExchange.Redis/VectorSetSimilaritySearchMessage.cs b/src/StackExchange.Redis/VectorSetSimilaritySearchMessage.cs index 1bbc418d5..a492be4d8 100644 --- a/src/StackExchange.Redis/VectorSetSimilaritySearchMessage.cs +++ b/src/StackExchange.Redis/VectorSetSimilaritySearchMessage.cs @@ -87,7 +87,7 @@ private sealed class VectorSetSimilaritySearchProcessor : ResultProcessor Date: Sun, 15 Feb 2026 11:10:50 +0000 Subject: [PATCH 04/11] getting there --- .../RespReader.AggregateEnumerator.cs | 8 ++ .../PhysicalConnection.Read.cs | 2 +- src/StackExchange.Redis/RedisTransaction.cs | 92 ++++++++----------- .../RespReaderExtensions.cs | 2 +- src/StackExchange.Redis/ResultProcessor.cs | 24 ++--- tests/RESPite.Tests/RespReaderTests.cs | 2 + .../ResultProcessorUnitTests.cs | 69 ++++++++++++++ 7 files changed, 131 insertions(+), 68 deletions(-) create mode 100644 tests/StackExchange.Redis.Tests/ResultProcessorUnitTests.cs diff --git a/src/RESPite/Messages/RespReader.AggregateEnumerator.cs b/src/RESPite/Messages/RespReader.AggregateEnumerator.cs index 1853d2ee6..be10cd5cb 100644 --- a/src/RESPite/Messages/RespReader.AggregateEnumerator.cs +++ b/src/RESPite/Messages/RespReader.AggregateEnumerator.cs @@ -1,5 +1,6 @@ using System.Collections; using System.ComponentModel; +using System.Diagnostics.CodeAnalysis; #pragma warning disable IDE0079 // Remove unnecessary suppression #pragma warning disable CS0282 // There is no defined ordering between fields in multiple declarations of partial struct @@ -41,6 +42,13 @@ public AggregateEnumerator(scoped in RespReader reader) /// [Browsable(false), EditorBrowsable(EditorBrowsableState.Never)] +#if DEBUG +#if NET6_0 || NET8_0 + [Experimental("SERDBG")] +#else + [Experimental("SERDBG", Message = $"Prefer {nameof(Value)}")] +#endif +#endif public RespReader Current => Value; /// diff --git a/src/StackExchange.Redis/PhysicalConnection.Read.cs b/src/StackExchange.Redis/PhysicalConnection.Read.cs index b2d790f1b..830096189 100644 --- a/src/StackExchange.Redis/PhysicalConnection.Read.cs +++ b/src/StackExchange.Redis/PhysicalConnection.Read.cs @@ -463,7 +463,7 @@ private void OnMessage( var iter = reader.AggregateChildren(); while (iter.MoveNext()) { - muxer.OnMessage(subscriptionChannel, messageChannel, iter.Current.ReadRedisValue()); + muxer.OnMessage(subscriptionChannel, messageChannel, iter.Value.ReadRedisValue()); } break; diff --git a/src/StackExchange.Redis/RedisTransaction.cs b/src/StackExchange.Redis/RedisTransaction.cs index 5943238a1..deeee46b4 100644 --- a/src/StackExchange.Redis/RedisTransaction.cs +++ b/src/StackExchange.Redis/RedisTransaction.cs @@ -1,5 +1,6 @@ using System; using System.Collections.Generic; +using System.Diagnostics; using System.Text; using System.Threading; using System.Threading.Tasks; @@ -487,72 +488,55 @@ public override bool SetResult(PhysicalConnection connection, Message message, r return base.SetResult(connection, message, ref copy); } - protected override bool SetResultCore(PhysicalConnection connection, Message message, RawResult result) + protected override bool SetResultCore(PhysicalConnection connection, Message message, ref RespReader reader) { var muxer = connection.BridgeCouldBeNull?.Multiplexer; - muxer?.OnTransactionLog($"got {result} for {message.CommandAndKey}"); + muxer?.OnTransactionLog($"got {reader.GetOverview()} for {message.CommandAndKey}"); if (message is TransactionMessage tran) { var wrapped = tran.InnerOperations; - switch (result.Resp2TypeArray) + + if (reader.IsNull & !tran.IsAborted) // EXEC returned with a NULL { - case ResultType.SimpleString: - if (tran.IsAborted && result.IsEqual(CommonReplies.OK)) - { - connection.Trace("Acknowledging UNWATCH (aborted electively)"); - SetResult(message, false); - return true; - } - // EXEC returned with a NULL - if (!tran.IsAborted && result.IsNull) - { - connection.Trace("Server aborted due to failed EXEC"); - // cancel the commands in the transaction and mark them as complete with the completion manager - foreach (var op in wrapped) - { - var inner = op.Wrapped; - inner.Cancel(); - inner.Complete(); - } - SetResult(message, false); - return true; - } - break; - case ResultType.Array: - if (!tran.IsAborted) + muxer?.OnTransactionLog("Aborting wrapped messages (failed watch)"); + connection.Trace("Server aborted due to failed WATCH"); + foreach (var op in wrapped) + { + var inner = op.Wrapped; + inner.Cancel(); + inner.Complete(); + } + SetResult(message, false); + return true; + } + + switch (reader.Resp2PrefixArray) + { + case RespPrefix.SimpleString when tran.IsAborted & reader.IsOK(): + connection.Trace("Acknowledging UNWATCH (aborted electively)"); + SetResult(message, false); + return true; + case RespPrefix.Array when !tran.IsAborted: + var len = reader.AggregateLength(); + if (len == wrapped.Length) { - var arr = result.GetItems(); - if (result.IsNull) + connection.Trace("Server committed; processing nested replies"); + muxer?.OnTransactionLog($"Processing {len} wrapped messages"); + + var iter = reader.AggregateChildren(); + int i = 0; + while (iter.MoveNext()) { - muxer?.OnTransactionLog("Aborting wrapped messages (failed watch)"); - connection.Trace("Server aborted due to failed WATCH"); - foreach (var op in wrapped) + var inner = wrapped[i++].Wrapped; + muxer?.OnTransactionLog($"> got {iter.Value.GetOverview()} for {inner.CommandAndKey}"); + if (inner.ComputeResult(connection, ref iter.Value)) { - var inner = op.Wrapped; - inner.Cancel(); inner.Complete(); } - SetResult(message, false); - return true; - } - else if (wrapped.Length == arr.Length) - { - connection.Trace("Server committed; processing nested replies"); - muxer?.OnTransactionLog($"Processing {arr.Length} wrapped messages"); - - int i = 0; - foreach (ref RawResult item in arr) - { - var inner = wrapped[i++].Wrapped; - muxer?.OnTransactionLog($"> got {item} for {inner.CommandAndKey}"); - if (inner.ComputeResult(connection, in item)) - { - inner.Complete(); - } - } - SetResult(message, true); - return true; } + Debug.Assert(i == len, "we pre-checked the lengths"); + SetResult(message, true); + return true; } break; } diff --git a/src/StackExchange.Redis/RespReaderExtensions.cs b/src/StackExchange.Redis/RespReaderExtensions.cs index 6a0f93be1..179bc986c 100644 --- a/src/StackExchange.Redis/RespReaderExtensions.cs +++ b/src/StackExchange.Redis/RespReaderExtensions.cs @@ -28,7 +28,7 @@ public RedisValue ReadRedisValue() }; } - public string OverviewString() + public string GetOverview() { if (reader.IsNull) return "(null)"; diff --git a/src/StackExchange.Redis/ResultProcessor.cs b/src/StackExchange.Redis/ResultProcessor.cs index 2d18d7ebe..422fe8041 100644 --- a/src/StackExchange.Redis/ResultProcessor.cs +++ b/src/StackExchange.Redis/ResultProcessor.cs @@ -231,7 +231,7 @@ public virtual bool SetResult(PhysicalConnection connection, Message message, re { try { - logging.Log?.LogInformationResponse(bridge?.Name, message.CommandAndKey, reader.OverviewString()); + logging.Log?.LogInformationResponse(bridge?.Name, message.CommandAndKey, reader.GetOverview()); } catch (Exception ex) { @@ -246,7 +246,7 @@ public virtual bool SetResult(PhysicalConnection connection, Message message, re var copy = reader; if (SetResultCore(connection, message, ref reader)) { - bridge?.Multiplexer.Trace("Completed with success: " + copy.OverviewString() + " (" + GetType().Name + ")", ToString()); + bridge?.Multiplexer.Trace("Completed with success: " + copy.GetOverview() + " (" + GetType().Name + ")", ToString()); } else { @@ -263,7 +263,7 @@ private bool HandleCommonError(Message message, RespReader reader, PhysicalBridg } else if (reader.StartsWith(Literals.WRONGPASS.U8)) { - bridge?.Multiplexer.SetAuthSuspect(new RedisServerException(reader.OverviewString())); + bridge?.Multiplexer.SetAuthSuspect(new RedisServerException(reader.GetOverview())); } var server = bridge?.ServerEndPoint; @@ -359,7 +359,7 @@ protected virtual bool SetResultCore(PhysicalConnection connection, Message mess private void UnexpectedResponse(Message message, in RespReader reader) { ConnectionMultiplexer.TraceWithoutContext("From " + GetType().Name, "Unexpected Response"); - ConnectionFail(message, ConnectionFailureType.ProtocolFailure, "Unexpected response to " + (message?.CommandString ?? "n/a") + ": " + reader.OverviewString()); + ConnectionFail(message, ConnectionFailureType.ProtocolFailure, "Unexpected response to " + (message?.CommandString ?? "n/a") + ": " + reader.GetOverview()); } public sealed class TimeSpanProcessor : ResultProcessor @@ -1411,14 +1411,14 @@ protected override bool SetResultCore(PhysicalConnection connection, Message mes private class Int32Processor : ResultProcessor { - protected override bool SetResultCore(PhysicalConnection connection, Message message, RawResult result) + protected override bool SetResultCore(PhysicalConnection connection, Message message, ref RespReader reader) { - switch (result.Resp2TypeBulkString) + switch (reader.Resp2PrefixBulkString) { - case ResultType.Integer: - case ResultType.SimpleString: - case ResultType.BulkString: - if (result.TryGetInt64(out long i64)) + case RespPrefix.Integer: + case RespPrefix.SimpleString: + case RespPrefix.BulkString: + if (reader.TryReadInt64(out long i64)) { SetResult(message, checked((int)i64)); return true; @@ -2879,7 +2879,7 @@ public override bool SetResult(PhysicalConnection connection, Message message, r reader = copy; // rewind and re-parse if (reader.StartsWith(Literals.ERR_not_permitted.U8) || reader.StartsWith(Literals.NOAUTH.U8)) { - connection.RecordConnectionFailed(ConnectionFailureType.AuthenticationFailure, new Exception(reader.OverviewString() + " Verify if the Redis password provided is correct. Attempted command: " + message.Command)); + connection.RecordConnectionFailed(ConnectionFailureType.AuthenticationFailure, new Exception(reader.GetOverview() + " Verify if the Redis password provided is correct. Attempted command: " + message.Command)); } else if (reader.StartsWith(Literals.LOADING.U8)) { @@ -2887,7 +2887,7 @@ public override bool SetResult(PhysicalConnection connection, Message message, r } else { - connection.RecordConnectionFailed(ConnectionFailureType.ProtocolFailure, new RedisServerException(reader.OverviewString())); + connection.RecordConnectionFailed(ConnectionFailureType.ProtocolFailure, new RedisServerException(reader.GetOverview())); } } diff --git a/tests/RESPite.Tests/RespReaderTests.cs b/tests/RESPite.Tests/RespReaderTests.cs index 4b250e7ec..690235795 100644 --- a/tests/RESPite.Tests/RespReaderTests.cs +++ b/tests/RESPite.Tests/RespReaderTests.cs @@ -462,7 +462,9 @@ public void Array(RespPayload payload) reader.MoveNext(RespPrefix.Array); int[] arr = new int[reader.AggregateLength()]; int i = 0; +#pragma warning disable SERDBG // warning about .Current vs .Value foreach (var sub in reader.AggregateChildren()) +#pragma warning restore SERDBG { sub.MoveNext(RespPrefix.Integer); arr[i++] = sub.ReadInt32(); diff --git a/tests/StackExchange.Redis.Tests/ResultProcessorUnitTests.cs b/tests/StackExchange.Redis.Tests/ResultProcessorUnitTests.cs new file mode 100644 index 000000000..755492ab3 --- /dev/null +++ b/tests/StackExchange.Redis.Tests/ResultProcessorUnitTests.cs @@ -0,0 +1,69 @@ +using System; +using System.Buffers; +using System.Text; +using RESPite.Messages; +using Xunit; + +namespace StackExchange.Redis.Tests; + +public class ResultProcessorUnitTests(ITestOutputHelper log) +{ + public void Log(string message) => log?.WriteLine(message); + + private protected static Message DummyMessage() + => Message.Create(0, default, RedisCommand.UNKNOWN); + + [Theory] + [InlineData(":1\r\n", 1)] + [InlineData("+1\r\n", 1)] + [InlineData("$1\r\n1\r\n", 1)] + public void Int32(string resp, int value) + { + var result = Execute(resp, ResultProcessor.Int32); + Assert.Equal(value, result); + } + + [Theory] + [InlineData(":1\r\n", 1)] + [InlineData("+1\r\n", 1)] + [InlineData("$1\r\n1\r\n", 1)] + public void Int64(string resp, int value) + { + var result = Execute(resp, ResultProcessor.Int32); + Assert.Equal(value, result); + } + + private protected static T? Execute(string resp, ResultProcessor processor) + { + Assert.True(TryExecute(resp, processor, out var value, out var ex)); + Assert.Null(ex); + return value; + } + + private protected static bool TryExecute(string resp, ResultProcessor processor, out T? value, out Exception? exception) + { + byte[]? lease = null; + try + { + var maxLen = Encoding.UTF8.GetMaxByteCount(resp.Length); + const int MAX_STACK = 128; + Span oversized = maxLen <= MAX_STACK + ? stackalloc byte[MAX_STACK] + : (lease = ArrayPool.Shared.Rent(maxLen)); + + var msg = DummyMessage(); + var box = SimpleResultBox.Get(); + msg.SetSource(processor, box); + + var reader = new RespReader(oversized.Slice(0, Encoding.UTF8.GetBytes(resp, oversized))); + bool success = processor.SetResult(null!, msg, ref reader); + exception = null; + value = success ? box.GetResult(out exception, canRecycle: true) : default; + return success; + } + finally + { + if (lease is not null) ArrayPool.Shared.Return(lease); + } + } +} From 70125100dd65ab89f3cc1aec3c8db3c39241682a Mon Sep 17 00:00:00 2001 From: Marc Gravell Date: Sun, 15 Feb 2026 14:18:31 +0000 Subject: [PATCH 05/11] working unit tests and API shim --- .../PhysicalConnection.Read.cs | 6 +- src/StackExchange.Redis/PhysicalConnection.cs | 17 +++- src/StackExchange.Redis/ResultProcessor.cs | 54 +++++++++++- .../Helpers/SharedConnectionFixture.cs | 30 +++---- .../ResultProcessorUnitTests.cs | 87 ++++++++++++++----- 5 files changed, 148 insertions(+), 46 deletions(-) diff --git a/src/StackExchange.Redis/PhysicalConnection.Read.cs b/src/StackExchange.Redis/PhysicalConnection.Read.cs index 830096189..3e0e0772e 100644 --- a/src/StackExchange.Redis/PhysicalConnection.Read.cs +++ b/src/StackExchange.Redis/PhysicalConnection.Read.cs @@ -4,11 +4,11 @@ using System.Diagnostics; using System.Diagnostics.CodeAnalysis; using System.IO; +using System.IO.Pipelines; using System.Net; -using System.Net.Sockets; +using System.Runtime.CompilerServices; using System.Threading; using System.Threading.Tasks; -using Pipelines.Sockets.Unofficial; using RESPite.Buffers; using RESPite.Internal; using RESPite.Messages; @@ -17,6 +17,8 @@ namespace StackExchange.Redis; internal sealed partial class PhysicalConnection { + internal static PhysicalConnection Dummy() => new(null!); + private volatile ReadStatus _readStatus = ReadStatus.NotStarted; internal ReadStatus GetReadStatus() => _readStatus; diff --git a/src/StackExchange.Redis/PhysicalConnection.cs b/src/StackExchange.Redis/PhysicalConnection.cs index 30cd4d27b..10dc9250b 100644 --- a/src/StackExchange.Redis/PhysicalConnection.cs +++ b/src/StackExchange.Redis/PhysicalConnection.cs @@ -85,6 +85,21 @@ internal void GetBytes(out long sent, out long received) private Socket? _socket; internal Socket? VolatileSocket => Volatile.Read(ref _socket); + // used for dummy test connections + public PhysicalConnection( + ConnectionType connectionType = ConnectionType.Interactive, + RedisProtocol protocol = RedisProtocol.Resp2, + [CallerMemberName] string name = "") + { + lastWriteTickCount = lastReadTickCount = Environment.TickCount; + lastBeatTickCount = 0; + this.connectionType = connectionType; + _protocol = protocol; + _bridge = new WeakReference(null); + _physicalName = name; + + OnCreateEcho(); + } public PhysicalConnection(PhysicalBridge bridge) { lastWriteTickCount = lastReadTickCount = Environment.TickCount; @@ -275,7 +290,7 @@ private enum ReadMode : byte private RedisProtocol _protocol; // note starts at **zero**, not RESP2 public RedisProtocol? Protocol => _protocol == 0 ? null : _protocol; - internal void SetProtocol(RedisProtocol value) + public void SetProtocol(RedisProtocol value) { _protocol = value; BridgeCouldBeNull?.SetProtocol(value); diff --git a/src/StackExchange.Redis/ResultProcessor.cs b/src/StackExchange.Redis/ResultProcessor.cs index 422fe8041..a9263dc56 100644 --- a/src/StackExchange.Redis/ResultProcessor.cs +++ b/src/StackExchange.Redis/ResultProcessor.cs @@ -348,8 +348,50 @@ private bool HandleCommonError(Message message, RespReader reader, PhysicalBridg protected virtual bool SetResultCore(PhysicalConnection connection, Message message, ref RespReader reader) { - // temp hack so we can compile; this should be abstract - return false; + // spoof the old API from the new API; this is a transitional step only, and is inefficient + var rawResult = AsRaw(ref reader, connection.Protocol is RedisProtocol.Resp3); + return SetResultCore(connection, message, rawResult); + } + + private static RawResult AsRaw(ref RespReader reader, bool resp3) + { + var flags = RawResult.ResultFlags.HasValue; + if (!reader.IsNull) flags |= RawResult.ResultFlags.NonNull; + if (resp3) flags |= RawResult.ResultFlags.Resp3; + var type = Type(reader.Prefix); + if (reader.IsAggregate) + { + var inner = reader.ReadPastArray((ref value) => AsRaw(ref value, resp3), false) ?? []; + return new RawResult(type, new Sequence(inner), flags); + } + + if (reader.IsScalar) + { + ReadOnlySequence blob = new(reader.ReadByteArray() ?? []); + return new RawResult(type, blob, flags); + } + + return default; + + static ResultType Type(RespPrefix prefix) => prefix switch + { + RespPrefix.Array => ResultType.Array, + RespPrefix.Attribute => ResultType.Attribute, + RespPrefix.BigInteger => ResultType.BigInteger, + RespPrefix.Boolean => ResultType.Boolean, + RespPrefix.BulkError => ResultType.BlobError, + RespPrefix.BulkString => ResultType.BulkString, + RespPrefix.SimpleString => ResultType.SimpleString, + RespPrefix.Map => ResultType.Map, + RespPrefix.Set => ResultType.Set, + RespPrefix.Double => ResultType.Double, + RespPrefix.Integer => ResultType.Integer, + RespPrefix.SimpleError => ResultType.Error, + RespPrefix.Null => ResultType.Null, + RespPrefix.VerbatimString => ResultType.VerbatimString, + RespPrefix.Push=> ResultType.Push, + _ => throw new ArgumentOutOfRangeException(nameof(prefix), prefix, null), + }; } // temp hack so we can compile; this should be removed @@ -1733,7 +1775,13 @@ private sealed class Int64ArrayProcessor : ResultProcessor { protected override bool SetResultCore(PhysicalConnection connection, Message message, RawResult result) { - if (result.Resp2TypeArray == ResultType.Array && !result.IsNull) + if (result.IsNull) + { + SetResult(message, null!); + return true; + } + + if (result.Resp2TypeArray == ResultType.Array) { var arr = result.ToArray((in RawResult x) => (long)x.AsRedisValue())!; SetResult(message, arr); diff --git a/tests/StackExchange.Redis.Tests/Helpers/SharedConnectionFixture.cs b/tests/StackExchange.Redis.Tests/Helpers/SharedConnectionFixture.cs index 9656ee45b..74dbd7bbb 100644 --- a/tests/StackExchange.Redis.Tests/Helpers/SharedConnectionFixture.cs +++ b/tests/StackExchange.Redis.Tests/Helpers/SharedConnectionFixture.cs @@ -17,23 +17,8 @@ namespace StackExchange.Redis.Tests; public class SharedConnectionFixture : IDisposable { - public bool IsEnabled { get; } - - private readonly ConnectionMultiplexer _actualConnection; - public string Configuration { get; } - - public SharedConnectionFixture() - { - IsEnabled = TestConfig.Current.UseSharedConnection; - Configuration = TestBase.GetDefaultConfiguration(); - _actualConnection = TestBase.CreateDefault( - output: null, - clientName: nameof(SharedConnectionFixture), - configuration: Configuration, - allowAdmin: true); - _actualConnection.InternalError += OnInternalError; - _actualConnection.ConnectionFailed += OnConnectionFailed; - } + public bool IsEnabled { get; } = TestConfig.Current.UseSharedConnection; + public string Configuration { get; } = TestBase.GetDefaultConfiguration(); private NonDisposingConnection? resp2, resp3; internal IInternalConnectionMultiplexer GetConnection(TestBase obj, RedisProtocol protocol, [CallerMemberName] string caller = "") @@ -273,11 +258,16 @@ public void Teardown(TextWriter output) } // Assert.True(false, $"There were {privateFailCount} private ambient exceptions."); } + TearDown(resp2, output); + TearDown(resp3, output); + } - if (_actualConnection != null) + private void TearDown(IInternalConnectionMultiplexer? connection, TextWriter output) + { + if (connection is { } conn) { - TestBase.Log(output, "Connection Counts: " + _actualConnection.GetCounters().ToString()); - foreach (var ep in _actualConnection.GetServerSnapshot()) + TestBase.Log(output, "Connection Counts: " + conn.GetCounters().ToString()); + foreach (var ep in conn.GetServerSnapshot()) { var interactive = ep.GetBridge(ConnectionType.Interactive); TestBase.Log(output, $" {Format.ToString(interactive)}: {interactive?.GetStatus()}"); diff --git a/tests/StackExchange.Redis.Tests/ResultProcessorUnitTests.cs b/tests/StackExchange.Redis.Tests/ResultProcessorUnitTests.cs index 755492ab3..41560290e 100644 --- a/tests/StackExchange.Redis.Tests/ResultProcessorUnitTests.cs +++ b/tests/StackExchange.Redis.Tests/ResultProcessorUnitTests.cs @@ -1,5 +1,7 @@ using System; using System.Buffers; +using System.Diagnostics.CodeAnalysis; +using System.Runtime.CompilerServices; using System.Text; using RESPite.Messages; using Xunit; @@ -8,39 +10,84 @@ namespace StackExchange.Redis.Tests; public class ResultProcessorUnitTests(ITestOutputHelper log) { - public void Log(string message) => log?.WriteLine(message); - - private protected static Message DummyMessage() - => Message.Create(0, default, RedisCommand.UNKNOWN); - [Theory] [InlineData(":1\r\n", 1)] [InlineData("+1\r\n", 1)] [InlineData("$1\r\n1\r\n", 1)] - public void Int32(string resp, int value) - { - var result = Execute(resp, ResultProcessor.Int32); - Assert.Equal(value, result); - } + [InlineData(":-42\r\n", -42)] + [InlineData("+-42\r\n", -42)] + [InlineData("$3\r\n-42\r\n", -42)] + public void Int32(string resp, int value) => Assert.Equal(value, Execute(resp, ResultProcessor.Int32)); + + [Theory] + [InlineData("+OK\r\n")] + [InlineData("$4\r\nPONG\r\n")] + public void FailingInt32(string resp) => ExecuteUnexpected(resp, ResultProcessor.Int32); [Theory] [InlineData(":1\r\n", 1)] [InlineData("+1\r\n", 1)] [InlineData("$1\r\n1\r\n", 1)] - public void Int64(string resp, int value) + [InlineData(":-42\r\n", -42)] + [InlineData("+-42\r\n", -42)] + [InlineData("$3\r\n-42\r\n", -42)] + public void Int64(string resp, long value) => Assert.Equal(value, Execute(resp, ResultProcessor.Int64)); + + [Theory] + [InlineData("+OK\r\n")] + [InlineData("$4\r\nPONG\r\n")] + public void FailingInt64(string resp) => ExecuteUnexpected(resp, ResultProcessor.Int64); + + [Theory] + [InlineData("*-1\r\n", null)] + [InlineData("*0\r\n", "")] + [InlineData("*1\r\n+42\r\n", "42")] + [InlineData("*2\r\n+42\r\n:78\r\n", "42,78")] + public void Int64Array(string resp, string? value) => Assert.Equal(value, Join(Execute(resp, ResultProcessor.Int64Array))); + + [return: NotNullIfNotNull(nameof(array))] + protected static string? Join(T[]? array, string separator = ",") { - var result = Execute(resp, ResultProcessor.Int32); - Assert.Equal(value, result); + if (array is null) return null; + return string.Join(separator, array); } - private protected static T? Execute(string resp, ResultProcessor processor) + public void Log(string message) => log?.WriteLine(message); + + private protected static Message DummyMessage() + => Message.Create(0, default, RedisCommand.UNKNOWN); + + private protected void ExecuteUnexpected( + string resp, + ResultProcessor processor, + ConnectionType connectionType = ConnectionType.Interactive, + RedisProtocol protocol = RedisProtocol.Resp2, + [CallerMemberName] string caller = "") + { + Assert.False(TryExecute(resp, processor, out _, out var ex)); + if (ex is not null) Log(ex.Message); + Assert.StartsWith("Unexpected response to UNKNOWN:", Assert.IsType(ex).Message); + } + private protected static T? Execute( + string resp, + ResultProcessor processor, + ConnectionType connectionType = ConnectionType.Interactive, + RedisProtocol protocol = RedisProtocol.Resp2, + [CallerMemberName] string caller = "") { Assert.True(TryExecute(resp, processor, out var value, out var ex)); Assert.Null(ex); return value; } - private protected static bool TryExecute(string resp, ResultProcessor processor, out T? value, out Exception? exception) + private protected static bool TryExecute( + string resp, + ResultProcessor processor, + out T? value, + out Exception? exception, + ConnectionType connectionType = ConnectionType.Interactive, + RedisProtocol protocol = RedisProtocol.Resp2, + [CallerMemberName] string caller = "") { byte[]? lease = null; try @@ -51,15 +98,15 @@ private protected static bool TryExecute(string resp, ResultProcessor proc ? stackalloc byte[MAX_STACK] : (lease = ArrayPool.Shared.Rent(maxLen)); - var msg = DummyMessage(); + var msg = DummyMessage(); var box = SimpleResultBox.Get(); msg.SetSource(processor, box); var reader = new RespReader(oversized.Slice(0, Encoding.UTF8.GetBytes(resp, oversized))); - bool success = processor.SetResult(null!, msg, ref reader); - exception = null; - value = success ? box.GetResult(out exception, canRecycle: true) : default; - return success; + PhysicalConnection connection = new(connectionType, protocol, caller); + Assert.True(processor.SetResult(connection, msg, ref reader)); + value = box.GetResult(out exception, canRecycle: true); + return exception is null; } finally { From befe691eeb52e41900228facc077f5b890fb6eb6 Mon Sep 17 00:00:00 2001 From: Marc Gravell Date: Mon, 16 Feb 2026 08:08:19 +0000 Subject: [PATCH 06/11] WIP --- src/RESPite/Buffers/CycleBuffer.cs | 6 +- .../ChannelMessageQueue.cs | 2 +- .../PhysicalConnection.Read.cs | 51 ++- src/StackExchange.Redis/PhysicalConnection.cs | 328 ++---------------- .../RespReaderExtensions.cs | 8 + .../ResultProcessor.Literals.cs | 1 - src/StackExchange.Redis/ServerEndPoint.cs | 2 +- 7 files changed, 89 insertions(+), 309 deletions(-) diff --git a/src/RESPite/Buffers/CycleBuffer.cs b/src/RESPite/Buffers/CycleBuffer.cs index 6ab982776..6d9f8ee12 100644 --- a/src/RESPite/Buffers/CycleBuffer.cs +++ b/src/RESPite/Buffers/CycleBuffer.cs @@ -149,6 +149,7 @@ public void DiscardCommitted(long count) private void DiscardCommittedSlow(long count) { DebugCounters.OnDiscardPartial(count); + DebugAssertValid(); #if DEBUG var originalLength = GetCommittedLength(); var originalCount = count; @@ -262,7 +263,6 @@ private void DebugAssertValid() public long GetCommittedLength() { - DebugAssertValid(); if (ReferenceEquals(startSegment, endSegment)) { return endSegmentCommitted; @@ -427,6 +427,10 @@ public Memory GetUncommittedMemory(int hint = 0) return MemoryMarshal.AsMemory(GetNextSegment().Memory); } + /// + /// This is the available unused buffer space, commonly used as the IO read-buffer to avoid + /// additional buffer-copy operations. + /// public int UncommittedAvailable { get diff --git a/src/StackExchange.Redis/ChannelMessageQueue.cs b/src/StackExchange.Redis/ChannelMessageQueue.cs index 9f962e52a..9263a0199 100644 --- a/src/StackExchange.Redis/ChannelMessageQueue.cs +++ b/src/StackExchange.Redis/ChannelMessageQueue.cs @@ -279,7 +279,7 @@ private async Task OnMessageAsyncImpl() try { var task = handler?.Invoke(next); - if (task != null && task.Status != TaskStatus.RanToCompletion) await task.ForAwait(); + if (task != null && !task.IsCompletedSuccessfully) await task.ForAwait(); } catch { } // matches MessageCompletable } diff --git a/src/StackExchange.Redis/PhysicalConnection.Read.cs b/src/StackExchange.Redis/PhysicalConnection.Read.cs index 3e0e0772e..75f1ff379 100644 --- a/src/StackExchange.Redis/PhysicalConnection.Read.cs +++ b/src/StackExchange.Redis/PhysicalConnection.Read.cs @@ -22,20 +22,21 @@ internal sealed partial class PhysicalConnection private volatile ReadStatus _readStatus = ReadStatus.NotStarted; internal ReadStatus GetReadStatus() => _readStatus; - internal void StartReading() => ReadAllAsync(Stream.Null).RedisFireAndForget(); + internal void StartReading() => ReadAllAsync().RedisFireAndForget(); - private async Task ReadAllAsync(Stream tail) + private async Task ReadAllAsync() { + var tail = _ioStream ?? Stream.Null; _readStatus = ReadStatus.Init; RespScanState state = default; - var readBuffer = CycleBuffer.Create(); + _readBuffer = CycleBuffer.Create(); try { int read; do { _readStatus = ReadStatus.ReadAsync; - var buffer = readBuffer.GetUncommittedMemory(); + var buffer = _readBuffer.GetUncommittedMemory(); var pending = tail.ReadAsync(buffer, CancellationToken.None); #if DEBUG bool inline = pending.IsCompleted; @@ -49,11 +50,12 @@ private async Task ReadAllAsync(Stream tail) _readStatus = ReadStatus.TryParseResult; } // another formatter glitch - while (CommitAndParseFrames(ref state, ref readBuffer, read)); + while (CommitAndParseFrames(ref state, read)); + _readStatus = ReadStatus.ProcessBufferComplete; // Volatile.Write(ref _readStatus, ReaderCompleted); - readBuffer.Release(); // clean exit, we can recycle + _readBuffer.Release(); // clean exit, we can recycle _readStatus = ReadStatus.RanToCompletion; RecordConnectionFailed(ConnectionFailureType.SocketClosed); } @@ -67,37 +69,56 @@ private async Task ReadAllAsync(Stream tail) _readStatus = ReadStatus.Faulted; RecordConnectionFailed(ConnectionFailureType.InternalFailure, ex); } + finally + { + _readBuffer = default; // wipe, however we exited + } } private static byte[]? SharedNoLease; - private bool CommitAndParseFrames(ref RespScanState state, ref CycleBuffer readBuffer, int bytesRead) + private CycleBuffer _readBuffer; + private long GetReadCommittedLength() + { + try + { + var len = _readBuffer.GetCommittedLength(); + return len < 0 ? -1 : len; + } + catch + { + return -1; + } + } + + private bool CommitAndParseFrames(ref RespScanState state, int bytesRead) { if (bytesRead <= 0) { return false; } + totalBytesReceived += bytesRead; #if PARSE_DETAIL string src = $"parse {bytesRead}"; try #endif { - Debug.Assert(readBuffer.GetCommittedLength() >= 0, "multi-segment running-indices are corrupt"); + Debug.Assert(_readBuffer.GetCommittedLength() >= 0, "multi-segment running-indices are corrupt"); #if PARSE_DETAIL src += $" ({readBuffer.GetCommittedLength()}+{bytesRead}-{state.TotalBytes})"; #endif Debug.Assert( - bytesRead <= readBuffer.UncommittedAvailable, - $"Insufficient bytes in {nameof(CommitAndParseFrames)}; got {bytesRead}, Available={readBuffer.UncommittedAvailable}"); - readBuffer.Commit(bytesRead); + bytesRead <= _readBuffer.UncommittedAvailable, + $"Insufficient bytes in {nameof(CommitAndParseFrames)}; got {bytesRead}, Available={_readBuffer.UncommittedAvailable}"); + _readBuffer.Commit(bytesRead); #if PARSE_DETAIL src += $",total {readBuffer.GetCommittedLength()}"; #endif var scanner = RespFrameScanner.Default; OperationStatus status = OperationStatus.NeedMoreData; - if (readBuffer.TryGetCommitted(out var fullSpan)) + if (_readBuffer.TryGetCommitted(out var fullSpan)) { int fullyConsumed = 0; var toParse = fullSpan.Slice((int)state.TotalBytes); // skip what we've already parsed @@ -138,11 +159,11 @@ state is status = OperationStatus.NeedMoreData; } - readBuffer.DiscardCommitted(fullyConsumed); + _readBuffer.DiscardCommitted(fullyConsumed); } else // the same thing again, but this time with multi-segment sequence { - var fullSequence = readBuffer.GetAllCommitted(); + var fullSequence = _readBuffer.GetAllCommitted(); Debug.Assert( fullSequence is { IsEmpty: false, IsSingleSegment: false }, "non-trivial sequence expected"); @@ -184,7 +205,7 @@ state is status = OperationStatus.NeedMoreData; } - readBuffer.DiscardCommitted(fullyConsumed); + _readBuffer.DiscardCommitted(fullyConsumed); } if (status != OperationStatus.NeedMoreData) diff --git a/src/StackExchange.Redis/PhysicalConnection.cs b/src/StackExchange.Redis/PhysicalConnection.cs index 10dc9250b..29991d7ca 100644 --- a/src/StackExchange.Redis/PhysicalConnection.cs +++ b/src/StackExchange.Redis/PhysicalConnection.cs @@ -1,6 +1,5 @@ using System; using System.Buffers; -using System.Buffers.Binary; using System.Collections.Generic; using System.Diagnostics; using System.Diagnostics.CodeAnalysis; @@ -29,6 +28,7 @@ internal sealed partial class PhysicalConnection : IDisposable internal readonly byte[]? ChannelPrefix; private const int DefaultRedisDatabaseCount = 16; + private long totalBytesSent, totalBytesReceived; private static readonly Message[] ReusableChangeDatabaseCommands = Enumerable.Range(0, DefaultRedisDatabaseCount).Select( i => Message.Create(i, CommandFlags.FireAndForget, RedisCommand.SELECT)).ToArray(); @@ -64,23 +64,16 @@ private static readonly Message internal void GetBytes(out long sent, out long received) { - if (_ioPipe is IMeasuredDuplexPipe sc) - { - sent = sc.TotalBytesSent; - received = sc.TotalBytesReceived; - } - else - { - sent = received = -1; - } + sent = totalBytesSent; + received = totalBytesReceived; } /// /// Nullable because during simulation of failure, we'll null out. /// ...but in those cases, we'll accept any null ref in a race - it's fine. /// - private IDuplexPipe? _ioPipe; - internal bool HasOutputPipe => _ioPipe?.Output != null; + private Stream? _ioStream; + internal bool HasOutputPipe => _ioStream is not null; private Socket? _socket; internal Socket? VolatileSocket => Volatile.Read(ref _socket); @@ -299,18 +292,15 @@ public void SetProtocol(RedisProtocol value) [System.Diagnostics.CodeAnalysis.SuppressMessage("Microsoft.Usage", "CA2202:Do not dispose objects multiple times", Justification = "Trust me yo")] internal void Shutdown() { - var ioPipe = Interlocked.Exchange(ref _ioPipe, null); // compare to the critical read + var ioStream = Interlocked.Exchange(ref _ioStream, null); // compare to the critical read var socket = Interlocked.Exchange(ref _socket, null); - if (ioPipe != null) + if (ioStream != null) { Trace("Disconnecting..."); try { BridgeCouldBeNull?.OnDisconnected(ConnectionFailureType.ConnectionDisposed, this, out _, out _); } catch { } - try { ioPipe.Input?.CancelPendingRead(); } catch { } - try { ioPipe.Input?.Complete(); } catch { } - try { ioPipe.Output?.CancelPendingFlush(); } catch { } - try { ioPipe.Output?.Complete(); } catch { } - try { using (ioPipe as IDisposable) { } } catch { } + try { ioStream.Close(); } catch { } + try { ioStream.Dispose(); } catch { } } if (socket != null) @@ -336,7 +326,7 @@ public void Dispose() GC.SuppressFinalize(this); } - private async Task AwaitedFlush(ValueTask flush) + private async Task AwaitedFlush(Task flush) { await flush.ForAwait(); _writeStatus = WriteStatus.Flushed; @@ -345,7 +335,7 @@ private async Task AwaitedFlush(ValueTask flush) internal void UpdateLastWriteTime() => Interlocked.Exchange(ref lastWriteTickCount, Environment.TickCount); public Task FlushAsync() { - var tmp = _ioPipe?.Output; + var tmp = _ioStream; if (tmp != null) { _writeStatus = WriteStatus.Flushing; @@ -402,13 +392,12 @@ public void RecordConnectionFailed( bool isInitialConnect = false, IDuplexPipe? connectingPipe = null) { - bool weAskedForThis; Exception? outerException = innerException; IdentifyFailureType(innerException, ref failureType); var bridge = BridgeCouldBeNull; Message? nextMessage; - if (_ioPipe != null || isInitialConnect) // if *we* didn't burn the pipe: flag it + if (_ioStream is not null || isInitialConnect) // if *we* didn't burn the pipe: flag it { if (failureType == ConnectionFailureType.InternalFailure && innerException is not null) { @@ -448,7 +437,7 @@ public void RecordConnectionFailed( var exMessage = new StringBuilder(failureType.ToString()); // If the reason for the shutdown was we asked for the socket to die, don't log it as an error (only informational) - weAskedForThis = Volatile.Read(ref clientSentQuit) != 0; + var weAskedForThis = Volatile.Read(ref clientSentQuit) != 0; var pipe = connectingPipe ?? _ioPipe; if (pipe is SocketConnection sc) @@ -1112,14 +1101,15 @@ internal ValueTask FlushAsync(bool throwOnFailure, CancellationToke } } - private static readonly ReadOnlyMemory NullBulkString = Encoding.ASCII.GetBytes("$-1\r\n"), EmptyBulkString = Encoding.ASCII.GetBytes("$0\r\n\r\n"); + private static ReadOnlySpan NullBulkString => "$-1\r\n"u8; + private static ReadOnlySpan EmptyBulkString => "$0\r\n\r\n"u8; private static void WriteUnifiedBlob(PipeWriter writer, byte[]? value) { - if (value == null) + if (value is null) { // special case: - writer.Write(NullBulkString.Span); + writer.Write(NullBulkString); } else { @@ -1135,7 +1125,7 @@ private static void WriteUnifiedSpan(PipeWriter writer, ReadOnlySpan value if (value.Length == 0) { // special case: - writer.Write(EmptyBulkString.Span); + writer.Write(EmptyBulkString); } else if (value.Length <= MaxQuickSpanSize) { @@ -1176,16 +1166,16 @@ private static int AppendToSpan(Span span, ReadOnlySpan value, int o return WriteCrlf(span, offset); } - internal void WriteSha1AsHex(byte[] value) + internal void WriteSha1AsHex(byte[]? value) { if (_ioPipe?.Output is not PipeWriter writer) { return; // Prevent null refs during disposal } - if (value == null) + if (value is null) { - writer.Write(NullBulkString.Span); + writer.Write(NullBulkString); } else if (value.Length == ResultProcessor.ScriptLoadProcessor.Sha1HashLength) { @@ -1221,9 +1211,9 @@ internal static byte ToHexNibble(int value) return value < 10 ? (byte)('0' + value) : (byte)('a' - 10 + value); } - internal static void WriteUnifiedPrefixedString(PipeWriter? maybeNullWriter, byte[]? prefix, string? value) + internal static void WriteUnifiedPrefixedString(PipeWriter? writer, byte[]? prefix, string? value) { - if (maybeNullWriter is not PipeWriter writer) + if (writer is null) { return; // Prevent null refs during disposal } @@ -1231,7 +1221,7 @@ internal static void WriteUnifiedPrefixedString(PipeWriter? maybeNullWriter, byt if (value == null) { // special case - writer.Write(NullBulkString.Span); + writer.Write(NullBulkString); } else { @@ -1244,7 +1234,7 @@ internal static void WriteUnifiedPrefixedString(PipeWriter? maybeNullWriter, byt if (totalLength == 0) { // special-case - writer.Write(EmptyBulkString.Span); + writer.Write(EmptyBulkString); } else { @@ -1496,38 +1486,22 @@ public override string ToString() => public ConnectionStatus GetStatus() { - if (_ioPipe is SocketConnection conn) - { - var counters = conn.GetCounters(); - return new ConnectionStatus() - { - MessagesSentAwaitingResponse = GetSentAwaitingResponseCount(), - BytesAvailableOnSocket = counters.BytesAvailableOnSocket, - BytesInReadPipe = counters.BytesWaitingToBeRead, - BytesInWritePipe = counters.BytesWaitingToBeSent, - ReadStatus = _readStatus, - WriteStatus = _writeStatus, - BytesLastResult = bytesLastResult, - BytesInBuffer = bytesInBuffer, - }; - } - // Fall back to bytes waiting on the socket if we can - int fallbackBytesAvailable; + int socketBytes; try { - fallbackBytesAvailable = VolatileSocket?.Available ?? -1; + socketBytes = VolatileSocket?.Available ?? -1; } catch { // If this fails, we're likely in a race disposal situation and do not want to blow sky high here. - fallbackBytesAvailable = -1; + socketBytes = -1; } return new ConnectionStatus() { - BytesAvailableOnSocket = fallbackBytesAvailable, - BytesInReadPipe = -1, + BytesAvailableOnSocket = socketBytes, + BytesInReadPipe = GetReadCommittedLength(), BytesInWritePipe = -1, ReadStatus = _readStatus, WriteStatus = _writeStatus, @@ -1588,7 +1562,6 @@ internal async ValueTask ConnectedAsync(Socket? socket, ILogger? log, Sock var bridge = BridgeCouldBeNull; if (bridge == null) return false; - IDuplexPipe? pipe = null; try { // disallow connection in some cases @@ -1606,6 +1579,9 @@ internal async ValueTask ConnectedAsync(Socket? socket, ILogger? log, Sock stream = await tunnel.BeforeAuthenticateAsync(bridge.ServerEndPoint.EndPoint, bridge.ConnectionType, socket, CancellationToken.None).ForAwait(); } + static Stream DemandSocketStream(Socket? socket) + => new NetworkStream(socket ?? throw new InvalidOperationException("No socket or stream available - possibly a tunnel error")); + if (config.Ssl) { log?.LogInformationConfiguringTLS(); @@ -1615,7 +1591,7 @@ internal async ValueTask ConnectedAsync(Socket? socket, ILogger? log, Sock host = Format.ToStringHostOnly(bridge.ServerEndPoint.EndPoint); } - stream ??= new NetworkStream(socket ?? throw new InvalidOperationException("No socket or stream available - possibly a tunnel error")); + stream ??= DemandSocketStream(socket); var ssl = new SslStream( innerStream: stream, leaveInnerStreamOpen: false, @@ -1658,17 +1634,10 @@ internal async ValueTask ConnectedAsync(Socket? socket, ILogger? log, Sock stream = ssl; } - if (stream is not null) - { - pipe = StreamConnection.GetDuplex(stream, manager.SendPipeOptions, manager.ReceivePipeOptions, name: bridge.Name); - } - else - { - pipe = SocketConnection.Create(socket, manager.SendPipeOptions, manager.ReceivePipeOptions, name: bridge.Name); - } - OnWrapForLogging(ref pipe, _physicalName, manager); + stream ??= DemandSocketStream(socket); + OnWrapForLogging(ref stream, _physicalName, manager); - _ioPipe = pipe; + _ioStream = stream; log?.LogInformationConnected(bridge.Name); @@ -1721,119 +1690,10 @@ private void OnDebugAbort() } } - partial void OnWrapForLogging(ref IDuplexPipe pipe, string name, SocketManager mgr); + partial void OnWrapForLogging(ref Stream stream, string name, SocketManager mgr); internal void UpdateLastReadTime() => Interlocked.Exchange(ref lastReadTickCount, Environment.TickCount); - private static RawResult.ResultFlags AsNull(RawResult.ResultFlags flags) => flags & ~RawResult.ResultFlags.NonNull; - - private static RawResult ReadArray(ResultType resultType, RawResult.ResultFlags flags, Arena arena, in ReadOnlySequence buffer, ref BufferReader reader, bool includeDetailInExceptions, ServerEndPoint? server) - { - var itemCount = ReadLineTerminatedString(ResultType.Integer, flags, ref reader); - if (itemCount.HasValue) - { - if (!itemCount.TryGetInt64(out long i64)) - { - throw ExceptionFactory.ConnectionFailure( - includeDetailInExceptions, - ConnectionFailureType.ProtocolFailure, - itemCount.Is('?') ? "Streamed aggregate types not yet implemented" : "Invalid array length", - server); - } - - int itemCountActual = checked((int)i64); - - if (itemCountActual < 0) - { - // for null response by command like EXEC, RESP array: *-1\r\n - return new RawResult(resultType, items: default, AsNull(flags)); - } - else if (itemCountActual == 0) - { - // for zero array response by command like SCAN, Resp array: *0\r\n - return new RawResult(resultType, items: default, flags); - } - - if (resultType == ResultType.Map) itemCountActual <<= 1; // if it says "3", it means 3 pairs, i.e. 6 values - - var oversized = arena.Allocate(itemCountActual); - var result = new RawResult(resultType, oversized, flags); - - if (oversized.IsSingleSegment) - { - var span = oversized.FirstSpan; - for (int i = 0; i < span.Length; i++) - { - if (!(span[i] = TryParseResult(flags, arena, in buffer, ref reader, includeDetailInExceptions, server)).HasValue) - { - return RawResult.Nil; - } - } - } - else - { - foreach (var span in oversized.Spans) - { - for (int i = 0; i < span.Length; i++) - { - if (!(span[i] = TryParseResult(flags, arena, in buffer, ref reader, includeDetailInExceptions, server)).HasValue) - { - return RawResult.Nil; - } - } - } - } - return result; - } - return RawResult.Nil; - } - - private static RawResult ReadBulkString(ResultType type, RawResult.ResultFlags flags, ref BufferReader reader, bool includeDetailInExceptions, ServerEndPoint? server) - { - var prefix = ReadLineTerminatedString(ResultType.Integer, flags, ref reader); - if (prefix.HasValue) - { - if (!prefix.TryGetInt64(out long i64)) - { - throw ExceptionFactory.ConnectionFailure( - includeDetailInExceptions, - ConnectionFailureType.ProtocolFailure, - prefix.Is('?') ? "Streamed strings not yet implemented" : "Invalid bulk string length", - server); - } - int bodySize = checked((int)i64); - if (bodySize < 0) - { - return new RawResult(type, ReadOnlySequence.Empty, AsNull(flags)); - } - - if (reader.TryConsumeAsBuffer(bodySize, out var payload)) - { - switch (reader.TryConsumeCRLF()) - { - case ConsumeResult.NeedMoreData: - break; // see NilResult below - case ConsumeResult.Success: - return new RawResult(type, payload, flags); - default: - throw ExceptionFactory.ConnectionFailure(includeDetailInExceptions, ConnectionFailureType.ProtocolFailure, "Invalid bulk string terminator", server); - } - } - } - return RawResult.Nil; - } - - private static RawResult ReadLineTerminatedString(ResultType type, RawResult.ResultFlags flags, ref BufferReader reader) - { - int crlfOffsetFromCurrent = BufferReader.FindNextCrLf(reader); - if (crlfOffsetFromCurrent < 0) return RawResult.Nil; - - var payload = reader.ConsumeAsBuffer(crlfOffsetFromCurrent); - reader.Consume(2); - - return new RawResult(type, payload, flags); - } - internal enum ReadStatus { NotStarted, @@ -1864,118 +1724,6 @@ internal enum ReadStatus NA = -1, } - internal static RawResult TryParseResult( - bool isResp3, - Arena arena, - in ReadOnlySequence buffer, - ref BufferReader reader, - bool includeDetilInExceptions, - PhysicalConnection? connection, - bool allowInlineProtocol = false) - { - return TryParseResult( - isResp3 ? (RawResult.ResultFlags.Resp3 | RawResult.ResultFlags.NonNull) : RawResult.ResultFlags.NonNull, - arena, - buffer, - ref reader, - includeDetilInExceptions, - connection?.BridgeCouldBeNull?.ServerEndPoint, - allowInlineProtocol); - } - - private static RawResult TryParseResult( - RawResult.ResultFlags flags, - Arena arena, - in ReadOnlySequence buffer, - ref BufferReader reader, - bool includeDetilInExceptions, - ServerEndPoint? server, - bool allowInlineProtocol = false) - { - int prefix; - do // this loop is just to allow us to parse (skip) attributes without doing a stack-dive - { - prefix = reader.PeekByte(); - if (prefix < 0) return RawResult.Nil; // EOF - switch (prefix) - { - // RESP2 - case '+': // simple string - reader.Consume(1); - return ReadLineTerminatedString(ResultType.SimpleString, flags, ref reader); - case '-': // error - reader.Consume(1); - return ReadLineTerminatedString(ResultType.Error, flags, ref reader); - case ':': // integer - reader.Consume(1); - return ReadLineTerminatedString(ResultType.Integer, flags, ref reader); - case '$': // bulk string - reader.Consume(1); - return ReadBulkString(ResultType.BulkString, flags, ref reader, includeDetilInExceptions, server); - case '*': // array - reader.Consume(1); - return ReadArray(ResultType.Array, flags, arena, in buffer, ref reader, includeDetilInExceptions, server); - // RESP3 - case '_': // null - reader.Consume(1); - return ReadLineTerminatedString(ResultType.Null, flags, ref reader); - case ',': // double - reader.Consume(1); - return ReadLineTerminatedString(ResultType.Double, flags, ref reader); - case '#': // boolean - reader.Consume(1); - return ReadLineTerminatedString(ResultType.Boolean, flags, ref reader); - case '!': // blob error - reader.Consume(1); - return ReadBulkString(ResultType.BlobError, flags, ref reader, includeDetilInExceptions, server); - case '=': // verbatim string - reader.Consume(1); - return ReadBulkString(ResultType.VerbatimString, flags, ref reader, includeDetilInExceptions, server); - case '(': // big number - reader.Consume(1); - return ReadLineTerminatedString(ResultType.BigInteger, flags, ref reader); - case '%': // map - reader.Consume(1); - return ReadArray(ResultType.Map, flags, arena, in buffer, ref reader, includeDetilInExceptions, server); - case '~': // set - reader.Consume(1); - return ReadArray(ResultType.Set, flags, arena, in buffer, ref reader, includeDetilInExceptions, server); - case '|': // attribute - reader.Consume(1); - var arr = ReadArray(ResultType.Attribute, flags, arena, in buffer, ref reader, includeDetilInExceptions, server); - if (!arr.HasValue) return RawResult.Nil; // failed to parse attribute data - - // for now, we want to just skip attribute data; so - // drop whatever we parsed on the floor and keep looking - break; // exits the SWITCH, not the DO/WHILE - case '>': // push - reader.Consume(1); - return ReadArray(ResultType.Push, flags, arena, in buffer, ref reader, includeDetilInExceptions, server); - } - } - while (prefix == '|'); - - if (allowInlineProtocol) return ParseInlineProtocol(flags, arena, ReadLineTerminatedString(ResultType.SimpleString, flags, ref reader)); - throw new InvalidOperationException("Unexpected response prefix: " + (char)prefix); - } - - private static RawResult ParseInlineProtocol(RawResult.ResultFlags flags, Arena arena, in RawResult line) - { - if (!line.HasValue) return RawResult.Nil; // incomplete line - - int count = 0; - foreach (var _ in line.GetInlineTokenizer()) count++; - var block = arena.Allocate(count); - - var iter = block.GetEnumerator(); - foreach (var token in line.GetInlineTokenizer()) - { - // this assigns *via a reference*, returned via the iterator; just... sweet - iter.GetNext() = new RawResult(line.Resp3Type, token, flags); // spoof RESP2 from RESP1 - } - return new RawResult(ResultType.Array, block, flags); // spoof RESP2 from RESP1 - } - internal bool HasPendingCallerFacingItems() { bool lockTaken = false; diff --git a/src/StackExchange.Redis/RespReaderExtensions.cs b/src/StackExchange.Redis/RespReaderExtensions.cs index 179bc986c..caff882d0 100644 --- a/src/StackExchange.Redis/RespReaderExtensions.cs +++ b/src/StackExchange.Redis/RespReaderExtensions.cs @@ -1,4 +1,5 @@ using System; +using System.Threading.Tasks; using RESPite.Messages; namespace StackExchange.Redis; @@ -107,4 +108,11 @@ internal bool AnyNull() return false; } } + +#if !(NET || NETSTANDARD2_1_OR_GREATER) + extension(Task task) + { + public bool IsCompletedSuccessfully => task.Status is TaskStatus.RanToCompletion; + } +#endif } diff --git a/src/StackExchange.Redis/ResultProcessor.Literals.cs b/src/StackExchange.Redis/ResultProcessor.Literals.cs index 79ea83150..3c61e0930 100644 --- a/src/StackExchange.Redis/ResultProcessor.Literals.cs +++ b/src/StackExchange.Redis/ResultProcessor.Literals.cs @@ -32,7 +32,6 @@ [FastHash] internal static partial class groups { } [FastHash] internal static partial class iids_tracked { } [FastHash] internal static partial class iids_added { } [FastHash] internal static partial class iids_duplicates { } - // ReSharper restore InconsistentNaming #pragma warning restore CS8981, SA1300, SA1134 // forgive naming etc } diff --git a/src/StackExchange.Redis/ServerEndPoint.cs b/src/StackExchange.Redis/ServerEndPoint.cs index abe8d8afb..c8fe42e28 100644 --- a/src/StackExchange.Redis/ServerEndPoint.cs +++ b/src/StackExchange.Redis/ServerEndPoint.cs @@ -673,7 +673,7 @@ static async Task OnEstablishingAsyncAwaited(PhysicalConnection connection, Task var handshake = HandshakeAsync(connection, log); - if (handshake.Status != TaskStatus.RanToCompletion) + if (!handshake.IsCompletedSuccessfully) { return OnEstablishingAsyncAwaited(connection, handshake); } From b9b72fe27b9fde8181b0dcf368b66920a1a9d584 Mon Sep 17 00:00:00 2001 From: Marc Gravell Date: Tue, 17 Feb 2026 10:35:16 +0000 Subject: [PATCH 07/11] deal with LoggingTunnel, untested --- .../Configuration/LoggingTunnel.cs | 254 ++++++++++-------- .../PhysicalConnection.Read.cs | 22 +- src/StackExchange.Redis/RedisResult.cs | 87 +++--- .../RespReaderExtensions.cs | 20 ++ src/StackExchange.Redis/ResultProcessor.cs | 38 +-- 5 files changed, 231 insertions(+), 190 deletions(-) diff --git a/src/StackExchange.Redis/Configuration/LoggingTunnel.cs b/src/StackExchange.Redis/Configuration/LoggingTunnel.cs index ccfa4ee63..3a05fab11 100644 --- a/src/StackExchange.Redis/Configuration/LoggingTunnel.cs +++ b/src/StackExchange.Redis/Configuration/LoggingTunnel.cs @@ -12,6 +12,9 @@ using System.Threading.Tasks; using Pipelines.Sockets.Unofficial; using Pipelines.Sockets.Unofficial.Arenas; +using RESPite.Buffers; +using RESPite.Internal; +using RESPite.Messages; using static StackExchange.Redis.PhysicalConnection; namespace StackExchange.Redis.Configuration; @@ -27,25 +30,152 @@ public abstract class LoggingTunnel : Tunnel private readonly bool _ssl; private readonly Tunnel? _tail; + private sealed class StreamRespReader(Stream source, bool isInbound) : IDisposable + { + private CycleBuffer _readBuffer = CycleBuffer.Create(); + private RespScanState _state; + private bool _reading, _disposed; // we need to track the state of the reader to avoid releasing the buffer while it's in use + + public bool TryTakeOne(out ContextualRedisResult result, bool withData = true) + { + var fullBuffer = _readBuffer.GetAllCommitted(); + var newData = fullBuffer.Slice(_state.TotalBytes); + var status = RespFrameScanner.Default.TryRead(ref _state, newData); + switch (status) + { + case OperationStatus.Done: + var frame = fullBuffer.Slice(0, _state.TotalBytes); + var reader = new RespReader(frame); + reader.MovePastBof(); + bool isOutOfBand = reader.Prefix is RespPrefix.Push + || (isInbound && reader.IsAggregate && + !IsArrayOutOfBand(in reader)); + + RedisResult? parsed; + if (withData) + { + if (!RedisResult.TryCreate(null, ref reader, out parsed)) + { + ThrowInvalidReadStatus(OperationStatus.InvalidData); + } + } + else + { + parsed = null; + } + result = new(parsed, isOutOfBand); + return true; + case OperationStatus.NeedMoreData: + result = default; + return false; + default: + ThrowInvalidReadStatus(status); + goto case OperationStatus.NeedMoreData; // never reached + } + } + + private static bool IsArrayOutOfBand(in RespReader source) + { + var reader = source; + int len; + if (!reader.IsStreaming + && (len = reader.AggregateLength()) >= 2 + && (reader.SafeTryMoveNext() & reader.IsInlineScalar & !reader.IsError)) + { + const int MAX_TYPE_LEN = 16; + var span = reader.TryGetSpan(out var tmp) + ? tmp + : StackCopyLengthChecked(in reader, stackalloc byte[MAX_TYPE_LEN]); + + var hash = span.Hash64(); + switch (hash) + { + case PushMessage.Hash when PushMessage.Is(hash, span) & len >= 3: + case PushPMessage.Hash when PushPMessage.Is(hash, span) & len >= 4: + case PushSMessage.Hash when PushSMessage.Is(hash, span) & len >= 3: + return true; + } + } + + return false; + } + + public ValueTask ReadOneAsync(CancellationToken cancellationToken = default) + => TryTakeOne(out var result) ? new(result) : ReadMoreAsync(cancellationToken); + + [DoesNotReturn] + private static void ThrowInvalidReadStatus(OperationStatus status) + => throw new InvalidOperationException($"Unexpected read status: {status}"); + + private async ValueTask ReadMoreAsync(CancellationToken cancellationToken) + { + while (true) + { + var buffer = _readBuffer.GetUncommittedMemory(); + Debug.Assert(!buffer.IsEmpty, "rule out zero-length reads"); + _reading = true; + var read = await source.ReadAsync(buffer, cancellationToken).ForAwait(); + _reading = false; + if (read <= 0) + { + // EOF + return default; + } + _readBuffer.Commit(read); + + if (TryTakeOne(out var result)) return result; + } + } + + public void Dispose() + { + bool disposed = _disposed; + _disposed = true; + _state = default; + + if (!(_reading | disposed)) _readBuffer.Release(); + _readBuffer = default; + if (!disposed) source.Dispose(); + } + + public async ValueTask ValidateAsync(CancellationToken cancellationToken = default) + { + long count = 0; + while (true) + { + var buffer = _readBuffer.GetUncommittedMemory(); + Debug.Assert(!buffer.IsEmpty, "rule out zero-length reads"); + _reading = true; + var read = await source.ReadAsync(buffer, cancellationToken).ForAwait(); + _reading = false; + if (read <= 0) + { + // EOF + return count; + } + _readBuffer.Commit(read); + while (TryTakeOne(out _, withData: false)) count++; + } + } + } + /// /// Replay the RESP messages for a pair of streams, invoking a callback per operation. /// public static async Task ReplayAsync(Stream @out, Stream @in, Action pair) { - using Arena arena = new(); - var outPipe = StreamConnection.GetReader(@out); - var inPipe = StreamConnection.GetReader(@in); - long count = 0; + using var outReader = new StreamRespReader(@out, isInbound: false); + using var inReader = new StreamRespReader(@in, isInbound: true); while (true) { - var sent = await ReadOneAsync(outPipe, arena, isInbound: false).ForAwait(); + if (!outReader.TryTakeOne(out var sent)) sent = await outReader.ReadOneAsync().ForAwait(); ContextualRedisResult received; try { do { - received = await ReadOneAsync(inPipe, arena, isInbound: true).ForAwait(); + if (!inReader.TryTakeOne(out received)) received = await inReader.ReadOneAsync().ForAwait(); if (received.IsOutOfBand && received.Result is not null) { // spoof an empty request for OOB messages @@ -93,26 +223,6 @@ public static async Task ReplayAsync(string path, Action ReadOneAsync(PipeReader input, Arena arena, bool isInbound) - { - while (true) - { - var readResult = await input.ReadAsync().ForAwait(); - var buffer = readResult.Buffer; - int handled = 0; - var result = buffer.IsEmpty ? default : ProcessBuffer(arena, ref buffer, isInbound); - input.AdvanceTo(buffer.Start, buffer.End); - - if (result.Result is not null) return result; - - if (handled == 0 && readResult.IsCompleted) - { - break; // no more data, or trailing incomplete messages - } - } - return default; - } - /// /// Validate a RESP stream and return the number of top-level RESP fragments. /// @@ -152,60 +262,8 @@ public static async Task ValidateAsync(string path) /// public static async Task ValidateAsync(Stream stream) { - using var arena = new Arena(); - var input = StreamConnection.GetReader(stream); - long total = 0, position = 0; - while (true) - { - var readResult = await input.ReadAsync().ForAwait(); - var buffer = readResult.Buffer; - int handled = 0; - if (!buffer.IsEmpty) - { - try - { - ProcessBuffer(arena, ref buffer, ref position, ref handled); // updates buffer.Start - } - catch (Exception ex) - { - throw new InvalidOperationException($"Invalid fragment starting at {position} (fragment {total + handled})", ex); - } - total += handled; - } - - input.AdvanceTo(buffer.Start, buffer.End); - - if (handled == 0 && readResult.IsCompleted) - { - break; // no more data, or trailing incomplete messages - } - } - return total; - } - private static void ProcessBuffer(Arena arena, ref ReadOnlySequence buffer, ref long position, ref int messageCount) - { - while (!buffer.IsEmpty) - { - var reader = new BufferReader(buffer); - try - { - var result = TryParseResult(true, arena, in buffer, ref reader, true, null); - if (result.HasValue) - { - buffer = reader.SliceFromCurrent(); - position += reader.TotalConsumed; - messageCount++; - } - else - { - break; // remaining buffer isn't enough; give up - } - } - finally - { - arena.Reset(); - } - } + using var reader = new StreamRespReader(stream, isInbound: false); + return await reader.ValidateAsync(); } private readonly struct ContextualRedisResult @@ -219,42 +277,6 @@ public ContextualRedisResult(RedisResult? result, bool isOutOfBand) } } - private static ContextualRedisResult ProcessBuffer(Arena arena, ref ReadOnlySequence buffer, bool isInbound) - { - if (!buffer.IsEmpty) - { - var reader = new BufferReader(buffer); - try - { - var result = TryParseResult(true, arena, in buffer, ref reader, true, null); - bool isOutOfBand = result.Resp3Type == ResultType.Push - || (isInbound && result.Resp2TypeArray == ResultType.Array && IsArrayOutOfBand(result)); - if (result.HasValue) - { - buffer = reader.SliceFromCurrent(); - if (!RedisResult.TryCreate(null, result, out var parsed)) - { - throw new InvalidOperationException("Unable to parse raw result to RedisResult"); - } - return new(parsed, isOutOfBand); - } - } - finally - { - arena.Reset(); - } - } - return default; - - static bool IsArrayOutOfBand(in RawResult result) - { - var items = result.GetItems(); - return (items.Length >= 3 && (items[0].IsEqual(message) || items[0].IsEqual(smessage))) - || (items.Length >= 4 && items[0].IsEqual(pmessage)); - } - } - private static readonly CommandBytes message = "message", pmessage = "pmessage", smessage = "smessage"; - /// /// Create a new instance of a . /// diff --git a/src/StackExchange.Redis/PhysicalConnection.Read.cs b/src/StackExchange.Redis/PhysicalConnection.Read.cs index 75f1ff379..b26bfe601 100644 --- a/src/StackExchange.Redis/PhysicalConnection.Read.cs +++ b/src/StackExchange.Redis/PhysicalConnection.Read.cs @@ -326,15 +326,15 @@ private enum PushKind SUnsubscribe, } - private bool OnOutOfBand(ReadOnlySpan payload, ref byte[]? lease) + internal static ReadOnlySpan StackCopyLengthChecked(scoped in RespReader reader, Span buffer) { - static ReadOnlySpan StackCopyLenChecked(scoped in RespReader reader, Span buffer) - { - var len = reader.CopyTo(buffer); - if (len == buffer.Length && reader.ScalarLength() > len) return default; // too small - return buffer.Slice(0, len); - } + var len = reader.CopyTo(buffer); + if (len == buffer.Length && reader.ScalarLength() > len) return default; // too small + return buffer.Slice(0, len); + } + private bool OnOutOfBand(ReadOnlySpan payload, ref byte[]? lease) + { var muxer = BridgeCouldBeNull?.Multiplexer; if (muxer is null) return true; // consume it blindly @@ -348,7 +348,7 @@ static ReadOnlySpan StackCopyLenChecked(scoped in RespReader reader, Span< { const int MAX_TYPE_LEN = 16; var span = reader.TryGetSpan(out var tmp) - ? tmp : StackCopyLenChecked(in reader, stackalloc byte[MAX_TYPE_LEN]); + ? tmp : StackCopyLengthChecked(in reader, stackalloc byte[MAX_TYPE_LEN]); var hash = span.Hash64(); RedisChannel.RedisChannelOptions channelOptions = RedisChannel.RedisChannelOptions.None; @@ -636,13 +636,13 @@ internal RedisChannel AsRedisChannel(in RespReader reader, RedisChannel.RedisCha } [FastHash("message")] - private static partial class PushMessage { } + internal static partial class PushMessage { } [FastHash("pmessage")] - private static partial class PushPMessage { } + internal static partial class PushPMessage { } [FastHash("smessage")] - private static partial class PushSMessage { } + internal static partial class PushSMessage { } [FastHash("subscribe")] private static partial class PushSubscribe { } diff --git a/src/StackExchange.Redis/RedisResult.cs b/src/StackExchange.Redis/RedisResult.cs index 4a1644c36..d9efd69c1 100644 --- a/src/StackExchange.Redis/RedisResult.cs +++ b/src/StackExchange.Redis/RedisResult.cs @@ -3,6 +3,7 @@ using System.ComponentModel; using System.Diagnostics.CodeAnalysis; using System.Runtime.CompilerServices; +using RESPite.Messages; namespace StackExchange.Redis { @@ -105,52 +106,62 @@ public static RedisResult Create(RedisResult[] values, ResultType resultType) /// Internally, this is very similar to RawResult, except it is designed to be usable, /// outside of the IO-processing pipeline: the buffers are standalone, etc. /// - internal static bool TryCreate(PhysicalConnection? connection, in RawResult result, [NotNullWhen(true)] out RedisResult? redisResult) + internal static bool TryCreate(PhysicalConnection? connection, ref RespReader reader, [NotNullWhen(true)] out RedisResult? redisResult) { try { - switch (result.Resp2TypeBulkString) + var type = reader.Prefix.ToResultType(); + if (reader.Prefix is RespPrefix.Null) { - case ResultType.Integer: - case ResultType.SimpleString: - case ResultType.BulkString: - redisResult = new SingleRedisResult(result.AsRedisValue(), result.Resp3Type); + redisResult = NullSingle; + return true; + } + + if (reader.IsError) + { + redisResult = new ErrorRedisResult(reader.ReadString(), type); + return true; + } + + if (reader.IsScalar) + { + redisResult = new SingleRedisResult(reader.ReadRedisValue(), type); + return true; + } + + if (reader.IsAggregate) + { + if (reader.IsNull) + { + redisResult = new ArrayRedisResult(null, type); return true; - case ResultType.Array: - if (result.IsNull) - { - redisResult = NullArray; - return true; - } - var items = result.GetItems(); - if (items.Length == 0) - { - redisResult = EmptyArray(result.Resp3Type); - return true; - } - var arr = new RedisResult[items.Length]; - int i = 0; - foreach (ref RawResult item in items) + } + var len = reader.AggregateLength(); + if (len == 0) + { + redisResult = EmptyArray(type); + return true; + } + + var arr = new RedisResult[len]; + var iter = reader.AggregateChildren(); + int i = 0; + while (iter.MoveNext()) // avoiding ReadPastArray here as we can't make it static in this case + { + if (!TryCreate(connection, ref iter.Value, out var next)) { - if (TryCreate(connection, in item, out var next)) - { - arr[i++] = next; - } - else - { - redisResult = null; - return false; - } + redisResult = null; + return false; } - redisResult = new ArrayRedisResult(arr, result.Resp3Type); - return true; - case ResultType.Error: - redisResult = new ErrorRedisResult(result.GetString(), result.Resp3Type); - return true; - default: - redisResult = null; - return false; + arr[i++] = next; + } + iter.MovePast(out reader); + redisResult = new ArrayRedisResult(arr, type); + return true; } + + redisResult = null; + return false; } catch (Exception ex) { diff --git a/src/StackExchange.Redis/RespReaderExtensions.cs b/src/StackExchange.Redis/RespReaderExtensions.cs index caff882d0..6714fa5a0 100644 --- a/src/StackExchange.Redis/RespReaderExtensions.cs +++ b/src/StackExchange.Redis/RespReaderExtensions.cs @@ -94,6 +94,26 @@ public RespPrefix ToResp2(RespPrefix nullValue) _ => prefix, }; } + + public ResultType ToResultType() => prefix switch + { + RespPrefix.Array => ResultType.Array, + RespPrefix.Attribute => ResultType.Attribute, + RespPrefix.BigInteger => ResultType.BigInteger, + RespPrefix.Boolean => ResultType.Boolean, + RespPrefix.BulkError => ResultType.BlobError, + RespPrefix.BulkString => ResultType.BulkString, + RespPrefix.SimpleString => ResultType.SimpleString, + RespPrefix.Map => ResultType.Map, + RespPrefix.Set => ResultType.Set, + RespPrefix.Double => ResultType.Double, + RespPrefix.Integer => ResultType.Integer, + RespPrefix.SimpleError => ResultType.Error, + RespPrefix.Null => ResultType.Null, + RespPrefix.VerbatimString => ResultType.VerbatimString, + RespPrefix.Push=> ResultType.Push, + _ => throw new ArgumentOutOfRangeException(nameof(prefix), prefix, null), + }; } extension(T?[] array) where T : class diff --git a/src/StackExchange.Redis/ResultProcessor.cs b/src/StackExchange.Redis/ResultProcessor.cs index a9263dc56..4ca08e663 100644 --- a/src/StackExchange.Redis/ResultProcessor.cs +++ b/src/StackExchange.Redis/ResultProcessor.cs @@ -358,11 +358,19 @@ private static RawResult AsRaw(ref RespReader reader, bool resp3) var flags = RawResult.ResultFlags.HasValue; if (!reader.IsNull) flags |= RawResult.ResultFlags.NonNull; if (resp3) flags |= RawResult.ResultFlags.Resp3; - var type = Type(reader.Prefix); + var type = reader.Prefix.ToResultType(); if (reader.IsAggregate) { - var inner = reader.ReadPastArray((ref value) => AsRaw(ref value, resp3), false) ?? []; - return new RawResult(type, new Sequence(inner), flags); + var len = reader.AggregateLength(); + var arr = len == 0 ? [] : new RawResult[len]; + int i = 0; + var iter = reader.AggregateChildren(); + while (iter.MoveNext()) + { + arr[i++] = AsRaw(ref iter.Value, resp3); + } + iter.MovePast(out reader); + return new RawResult(type, new Sequence(arr), flags); } if (reader.IsScalar) @@ -372,26 +380,6 @@ private static RawResult AsRaw(ref RespReader reader, bool resp3) } return default; - - static ResultType Type(RespPrefix prefix) => prefix switch - { - RespPrefix.Array => ResultType.Array, - RespPrefix.Attribute => ResultType.Attribute, - RespPrefix.BigInteger => ResultType.BigInteger, - RespPrefix.Boolean => ResultType.Boolean, - RespPrefix.BulkError => ResultType.BlobError, - RespPrefix.BulkString => ResultType.BulkString, - RespPrefix.SimpleString => ResultType.SimpleString, - RespPrefix.Map => ResultType.Map, - RespPrefix.Set => ResultType.Set, - RespPrefix.Double => ResultType.Double, - RespPrefix.Integer => ResultType.Integer, - RespPrefix.SimpleError => ResultType.Error, - RespPrefix.Null => ResultType.Null, - RespPrefix.VerbatimString => ResultType.VerbatimString, - RespPrefix.Push=> ResultType.Push, - _ => throw new ArgumentOutOfRangeException(nameof(prefix), prefix, null), - }; } // temp hack so we can compile; this should be removed @@ -2168,9 +2156,9 @@ public override bool SetResult(PhysicalConnection connection, Message message, r // note that top-level error messages still get handled by SetResult, but nested errors // (is that a thing?) will be wrapped in the RedisResult - protected override bool SetResultCore(PhysicalConnection connection, Message message, RawResult result) + protected override bool SetResultCore(PhysicalConnection connection, Message message, ref RespReader reader) { - if (RedisResult.TryCreate(connection, result, out var value)) + if (RedisResult.TryCreate(connection, ref reader, out var value)) { SetResult(message, value); return true; From ab40ca41475bd0180a9fd87e52bdd590fda88e22 Mon Sep 17 00:00:00 2001 From: Marc Gravell Date: Tue, 17 Feb 2026 14:09:01 +0000 Subject: [PATCH 08/11] core write with basic buffer --- src/RESPite/Buffers/CycleBuffer.cs | 1 + src/RESPite/Internal/BlockBuffer.cs | 341 +++++++++++ src/RESPite/Internal/BlockBufferSerializer.cs | 97 +++ src/RESPite/Internal/DebugCounters.cs | 86 ++- .../SynchronizedBlockBufferSerializer.cs | 123 ++++ .../ThreadLocalBlockBufferSerializer.cs | 21 + src/RESPite/PublicAPI/PublicAPI.Unshipped.txt | 44 +- src/RESPite/RESPite.csproj | 9 + src/StackExchange.Redis/CommandBytes.cs | 2 +- src/StackExchange.Redis/Condition.cs | 20 +- .../Configuration/LoggingTunnel.cs | 10 +- .../ConnectionMultiplexer.cs | 2 +- src/StackExchange.Redis/ExceptionFactory.cs | 2 +- src/StackExchange.Redis/Expiration.cs | 10 +- .../HotKeys.StartMessage.cs | 32 +- .../Message.ValueCondition.cs | 20 +- src/StackExchange.Redis/Message.cs | 465 ++++++++------- src/StackExchange.Redis/MessageWriter.cs | 558 ++++++++++++++++++ src/StackExchange.Redis/PhysicalConnection.cs | 558 +----------------- src/StackExchange.Redis/RedisDatabase.cs | 156 ++--- src/StackExchange.Redis/RedisSubscriber.cs | 1 - src/StackExchange.Redis/RedisTransaction.cs | 6 +- src/StackExchange.Redis/RedisValue.cs | 2 +- .../ResultProcessor.VectorSets.cs | 5 +- src/StackExchange.Redis/ResultProcessor.cs | 8 +- src/StackExchange.Redis/ValueCondition.cs | 22 +- .../VectorSetAddMessage.cs | 52 +- .../VectorSetSimilaritySearchMessage.cs | 56 +- tests/StackExchange.Redis.Tests/ParseTests.cs | 52 +- toys/StackExchange.Redis.Server/RespServer.cs | 23 +- 30 files changed, 1789 insertions(+), 995 deletions(-) create mode 100644 src/RESPite/Internal/BlockBuffer.cs create mode 100644 src/RESPite/Internal/BlockBufferSerializer.cs create mode 100644 src/RESPite/Internal/SynchronizedBlockBufferSerializer.cs create mode 100644 src/RESPite/Internal/ThreadLocalBlockBufferSerializer.cs create mode 100644 src/StackExchange.Redis/MessageWriter.cs diff --git a/src/RESPite/Buffers/CycleBuffer.cs b/src/RESPite/Buffers/CycleBuffer.cs index 6d9f8ee12..f9016414c 100644 --- a/src/RESPite/Buffers/CycleBuffer.cs +++ b/src/RESPite/Buffers/CycleBuffer.cs @@ -27,6 +27,7 @@ namespace RESPite.Buffers; /// /// There is a *lot* of validation in debug mode; we want to be super sure that we don't corrupt buffer state. /// +[Experimental(Experiments.Respite, UrlFormat = Experiments.UrlFormat)] public partial struct CycleBuffer { // note: if someone uses an uninitialized CycleBuffer (via default): that's a skills issue; git gud diff --git a/src/RESPite/Internal/BlockBuffer.cs b/src/RESPite/Internal/BlockBuffer.cs new file mode 100644 index 000000000..752d74c8d --- /dev/null +++ b/src/RESPite/Internal/BlockBuffer.cs @@ -0,0 +1,341 @@ +using System.Buffers; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; + +namespace RESPite.Internal; + +internal abstract partial class BlockBufferSerializer +{ + internal sealed class BlockBuffer : MemoryManager + { + private BlockBuffer(BlockBufferSerializer parent, int minCapacity) + { + _arrayPool = parent._arrayPool; + _array = _arrayPool.Rent(minCapacity); + DebugCounters.OnBufferCapacity(_array.Length); +#if DEBUG + _parent = parent; + parent.DebugBufferCreated(); +#endif + } + + private int _refCount = 1; + private int _finalizedOffset, _writeOffset; + private readonly ArrayPool _arrayPool; + private byte[] _array; +#if DEBUG + private int _finalizedCount; + private BlockBufferSerializer _parent; +#endif + + public override string ToString() => +#if DEBUG + $"{_finalizedCount} messages; " + +#endif + $"{_finalizedOffset} finalized bytes; writing: {NonFinalizedData.Length} bytes, {Available} available; observers: {_refCount}"; + + // only used when filling; _buffer should be non-null + private int Available => _array.Length - _writeOffset; + public Memory UncommittedMemory => _array.AsMemory(_writeOffset); + public Span UncommittedSpan => _array.AsSpan(_writeOffset); + + // decrease ref-count; dispose if necessary + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public void Release() + { + if (Interlocked.Decrement(ref _refCount) <= 0) Recycle(); + } + + public void AddRef() + { + if (!TryAddRef()) Throw(); + static void Throw() => throw new ObjectDisposedException(nameof(BlockBuffer)); + } + + public bool TryAddRef() + { + int count; + do + { + count = Volatile.Read(ref _refCount); + if (count <= 0) return false; + } + // repeat until we can successfully swap/incr + while (Interlocked.CompareExchange(ref _refCount, count + 1, count) != count); + + return true; + } + + [MethodImpl(MethodImplOptions.NoInlining)] // called rarely vs Dispose + private void Recycle() + { + var count = Volatile.Read(ref _refCount); + if (count == 0) + { + _array.DebugScramble(); +#if DEBUG + GC.SuppressFinalize(this); // only have a finalizer in debug + _parent.DebugBufferRecycled(_array.Length); +#endif + _arrayPool.Return(_array); + _array = []; + } + + Debug.Assert(count == 0, $"over-disposal? count={count}"); + } + +#if DEBUG +#pragma warning disable CA2015 // Adding a finalizer to a type derived from MemoryManager may permit memory to be freed while it is still in use by a Span + // (the above is fine because we don't actually release anything - just a counter) + ~BlockBuffer() + { + _parent.DebugBufferLeaked(); + DebugCounters.OnBufferLeaked(); + } +#pragma warning restore CA2015 +#endif + + public static BlockBuffer GetBuffer(BlockBufferSerializer parent, int sizeHint) + { + // note this isn't an actual "max", just a max of what we guarantee; we give the caller + // whatever is left in the buffer; the clamped hint just decides whether we need a *new* buffer + const int MinSize = 16, MaxSize = 128; + sizeHint = Math.Min(Math.Max(sizeHint, MinSize), MaxSize); + + var buffer = parent.Buffer; // most common path is "exists, with enough data" + return buffer is not null && buffer.AvailableWithResetIfUseful() >= sizeHint + ? buffer + : GetBufferSlow(parent, sizeHint); + } + + // would it be useful and possible to reset? i.e. if all finalized chunks have been returned, + private int AvailableWithResetIfUseful() + { + if (_finalizedOffset != 0 // at least some chunks have been finalized + && Volatile.Read(ref _refCount) == 1 // all finalized chunks returned + & _writeOffset == _finalizedOffset) // we're not in the middle of serializing something new + { + _writeOffset = _finalizedOffset = 0; // swipe left + } + + return _array.Length - _writeOffset; + } + + private static BlockBuffer GetBufferSlow(BlockBufferSerializer parent, int minBytes) + { + // note clamp on size hint has already been applied + const int DefaultBufferSize = 2048; + var buffer = parent.Buffer; + if (buffer is null) + { + // first buffer + return parent.Buffer = new BlockBuffer(parent, DefaultBufferSize); + } + + Debug.Assert(minBytes > buffer.Available, "existing buffer has capacity - why are we here?"); + + if (buffer.TryResizeFor(minBytes)) + { + Debug.Assert(buffer.Available >= minBytes); + return buffer; + } + + // We've tried reset and resize - no more tricks; we need to move to a new buffer, starting with a + // capacity for any existing data in this message, plus the new chunk we're adding. + var nonFinalizedBytes = buffer.NonFinalizedData; + var newBuffer = new BlockBuffer(parent, Math.Max(nonFinalizedBytes.Length + minBytes, DefaultBufferSize)); + + // copy the existing message data, if any (the previous message might have finished near the + // boundary, in which case we might not have written anything yet) + newBuffer.CopyFrom(nonFinalizedBytes); + Debug.Assert(newBuffer.Available >= minBytes, "should have requested extra capacity"); + + // the ~emperor~ buffer is dead; long live the ~emperor~ buffer + parent.Buffer = newBuffer; + buffer.MarkComplete(parent); + return newBuffer; + } + + // used for elective reset (rather than "because we ran out of space") + public static void Clear(BlockBufferSerializer parent) + { + if (parent.Buffer is { } buffer) + { + parent.Buffer = null; + buffer.MarkComplete(parent); + } + } + + public static ReadOnlyMemory RetainCurrent(BlockBufferSerializer parent) + { + if (parent.Buffer is { } buffer && buffer._finalizedOffset != 0) + { + parent.Buffer = null; + buffer.AddRef(); + return buffer.CreateMemory(0, buffer._finalizedOffset); + } + // nothing useful to detach! + return default; + } + + private void MarkComplete(BlockBufferSerializer parent) + { + // record that the old buffer no longer logically has any non-committed bytes (mostly just for ToString()) + _writeOffset = _finalizedOffset; + Debug.Assert(IsNonCommittedEmpty); + + // see if the caller wants to take ownership of the segment + if (_finalizedOffset != 0 && !parent.ClaimSegment(CreateMemory(0, _finalizedOffset))) + { + Release(); // decrement the observer + } +#if DEBUG + DebugCounters.OnBufferCompleted(_finalizedCount, _finalizedOffset); +#endif + } + + private void CopyFrom(Span source) + { + source.CopyTo(UncommittedSpan); + _writeOffset += source.Length; + } + + private Span NonFinalizedData => _array.AsSpan( + _finalizedOffset, _writeOffset - _finalizedOffset); + + private bool TryResizeFor(int extraBytes) + { + if (_finalizedOffset == 0 & // we can only do this if there are no other messages in the buffer + Volatile.Read(ref _refCount) == 1) // and no-one else is looking (we already tried reset) + { + // we're already on the boundary - don't scrimp; just do the math from the end of the buffer + byte[] newArray = _arrayPool.Rent(_array.Length + extraBytes); + DebugCounters.OnBufferCapacity(newArray.Length - _array.Length); // account for extra only + + // copy the existing data (we always expect some, since we've clamped extraBytes to be + // much smaller than the default buffer size) + NonFinalizedData.CopyTo(newArray); + _array.DebugScramble(); + _arrayPool.Return(_array); + _array = newArray; + return true; + } + + return false; + } + + public static void Advance(BlockBufferSerializer parent, int count) + { + if (count == 0) return; + if (count < 0) ThrowOutOfRange(); + var buffer = parent.Buffer; + if (buffer is null || buffer.Available < count) ThrowOutOfRange(); + buffer._writeOffset += count; + + [DoesNotReturn] + static void ThrowOutOfRange() => throw new ArgumentOutOfRangeException(nameof(count)); + } + + public void RevertUnfinalized(BlockBufferSerializer parent) + { + // undo any writes (something went wrong during serialize) + _finalizedOffset = _writeOffset; + } + + private ReadOnlyMemory FinalizeBlock() + { + var length = _writeOffset - _finalizedOffset; + Debug.Assert(length > 0, "already checked this in FinalizeMessage!"); + var chunk = CreateMemory(_finalizedOffset, length); + _finalizedOffset = _writeOffset; // move the write head +#if DEBUG + _finalizedCount++; + _parent.DebugMessageFinalized(length); +#endif + Interlocked.Increment(ref _refCount); // add an observer + return chunk; + } + + private bool IsNonCommittedEmpty => _finalizedOffset == _writeOffset; + + public static ReadOnlyMemory FinalizeMessage(BlockBufferSerializer parent) + { + var buffer = parent.Buffer; + if (buffer is null || buffer.IsNonCommittedEmpty) + { +#if DEBUG // still count it for logging purposes + if (buffer is not null) buffer._finalizedCount++; + parent.DebugMessageFinalized(0); +#endif + return default; + } + + return buffer.FinalizeBlock(); + } + + // MemoryManager pieces + protected override void Dispose(bool disposing) + { + if (disposing) Release(); + } + + public override Span GetSpan() => _array; + public int Length => _array.Length; + + // base version is CreateMemory(GetSpan().Length); avoid that GetSpan() + public override Memory Memory => CreateMemory(_array.Length); + + public override unsafe MemoryHandle Pin(int elementIndex = 0) + { + // We *could* be cute and use a shared pin - but that's a *lot* + // of work (synchronization), requires extra storage, and for an + // API that is very unlikely; hence: we'll use per-call GC pins. + GCHandle handle = GCHandle.Alloc(_array, GCHandleType.Pinned); + DebugCounters.OnBufferPinned(); // prove how unlikely this is + byte* ptr = (byte*)handle.AddrOfPinnedObject(); + // note no IPinnable in the MemoryHandle; + return new MemoryHandle(ptr + elementIndex, handle); + } + + // This would only be called if we passed out a MemoryHandle with ourselves + // as IPinnable (in Pin), which: we don't. + public override void Unpin() => throw new NotSupportedException(); + + protected override bool TryGetArray(out ArraySegment segment) + { + segment = new ArraySegment(_array); + return true; + } + + internal static void Release(in ReadOnlySequence request) + { + if (request.IsSingleSegment) + { + if (MemoryMarshal.TryGetMemoryManager( + request.First, out var block)) + { + block.Release(); + } + } + else + { + ReleaseMultiBlock(in request); + } + + [MethodImpl(MethodImplOptions.NoInlining)] + static void ReleaseMultiBlock(in ReadOnlySequence request) + { + foreach (var segment in request) + { + if (MemoryMarshal.TryGetMemoryManager( + segment, out var block)) + { + block.Release(); + } + } + } + } + } +} diff --git a/src/RESPite/Internal/BlockBufferSerializer.cs b/src/RESPite/Internal/BlockBufferSerializer.cs new file mode 100644 index 000000000..a74b96472 --- /dev/null +++ b/src/RESPite/Internal/BlockBufferSerializer.cs @@ -0,0 +1,97 @@ +using System.Buffers; +using System.Diagnostics; +using RESPite.Messages; + +namespace RESPite.Internal; + +/// +/// Provides abstracted access to a buffer-writing API. Conveniently, we only give the caller +/// RespWriter - which they cannot export (ref-type), thus we never actually give the +/// public caller our IBufferWriter{byte}. Likewise, note that serialization is synchronous, +/// i.e. never switches thread during an operation. This gives us quite a bit of flexibility. +/// There are two main uses of BlockBufferSerializer: +/// 1. thread-local: ambient, used for random messages so that each thread is quietly packing +/// a thread-specific buffer; zero concurrency because of [ThreadStatic] hackery. +/// 2. batching: RespBatch hosts a serializer that reflects the batch we're building; successive +/// commands in the same batch are written adjacently in a shared buffer - we explicitly +/// detect and reject concurrency attempts in a batch (which is fair: a batch has order). +/// +internal abstract partial class BlockBufferSerializer(ArrayPool? arrayPool = null) : IBufferWriter +{ + private readonly ArrayPool _arrayPool = arrayPool ?? ArrayPool.Shared; + private protected abstract BlockBuffer? Buffer { get; set; } + + Memory IBufferWriter.GetMemory(int sizeHint) => BlockBuffer.GetBuffer(this, sizeHint).UncommittedMemory; + + Span IBufferWriter.GetSpan(int sizeHint) => BlockBuffer.GetBuffer(this, sizeHint).UncommittedSpan; + + void IBufferWriter.Advance(int count) => BlockBuffer.Advance(this, count); + + public virtual void Clear() => BlockBuffer.Clear(this); + + internal virtual ReadOnlySequence Flush() => throw new NotSupportedException(); + + /* + public virtual ReadOnlyMemory Serialize( + RespCommandMap? commandMap, + ReadOnlySpan command, + in TRequest request, + IRespFormatter formatter) +#if NET9_0_OR_GREATER + where TRequest : allows ref struct +#endif + { + try + { + var writer = new RespWriter(this); + writer.CommandMap = commandMap; + formatter.Format(command, ref writer, request); + writer.Flush(); + return BlockBuffer.FinalizeMessage(this); + } + catch + { + Buffer?.RevertUnfinalized(this); + throw; + } + } + */ + + internal void Revert() => Buffer?.RevertUnfinalized(this); + + protected virtual bool ClaimSegment(ReadOnlyMemory segment) => false; + +#if DEBUG + private int _countAdded, _countRecycled, _countLeaked, _countMessages; + private long _countMessageBytes; + public int CountLeaked => Volatile.Read(ref _countLeaked); + public int CountRecycled => Volatile.Read(ref _countRecycled); + public int CountAdded => Volatile.Read(ref _countAdded); + public int CountMessages => Volatile.Read(ref _countMessages); + public long CountMessageBytes => Volatile.Read(ref _countMessageBytes); + + [Conditional("DEBUG")] + private void DebugBufferLeaked() => Interlocked.Increment(ref _countLeaked); + + [Conditional("DEBUG")] + private void DebugBufferRecycled(int length) + { + Interlocked.Increment(ref _countRecycled); + DebugCounters.OnBufferRecycled(length); + } + + [Conditional("DEBUG")] + private void DebugBufferCreated() + { + Interlocked.Increment(ref _countAdded); + DebugCounters.OnBufferCreated(); + } + + [Conditional("DEBUG")] + private void DebugMessageFinalized(int bytes) + { + Interlocked.Increment(ref _countMessages); + Interlocked.Add(ref _countMessageBytes, bytes); + } +#endif +} diff --git a/src/RESPite/Internal/DebugCounters.cs b/src/RESPite/Internal/DebugCounters.cs index 6b0d0866d..d6f3da37a 100644 --- a/src/RESPite/Internal/DebugCounters.cs +++ b/src/RESPite/Internal/DebugCounters.cs @@ -9,11 +9,20 @@ private static int _tallyAsyncReadCount, _tallyAsyncReadInlineCount, _tallyDiscardFullCount, - _tallyDiscardPartialCount; + _tallyDiscardPartialCount, + _tallyBufferCreatedCount, + _tallyBufferRecycledCount, + _tallyBufferMessageCount, + _tallyBufferPinCount, + _tallyBufferLeakCount; private static long _tallyReadBytes, - _tallyDiscardAverage; + _tallyDiscardAverage, + _tallyBufferMessageBytes, + _tallyBufferRecycledBytes, + _tallyBufferMaxOutstandingBytes, + _tallyBufferTotalBytes; #endif [Conditional("DEBUG")] @@ -49,6 +58,69 @@ internal static void OnAsyncRead(int bytes, bool inline) #endif } + [Conditional("DEBUG")] + public static void OnBufferCreated() + { +#if DEBUG + Interlocked.Increment(ref _tallyBufferCreatedCount); +#endif + } + + [Conditional("DEBUG")] + public static void OnBufferRecycled(int messageBytes) + { +#if DEBUG + Interlocked.Increment(ref _tallyBufferRecycledCount); + var now = Interlocked.Add(ref _tallyBufferRecycledBytes, messageBytes); + var outstanding = Volatile.Read(ref _tallyBufferMessageBytes) - now; + + while (true) + { + var oldOutstanding = Volatile.Read(ref _tallyBufferMaxOutstandingBytes); + // loop until either it isn't an increase, or we successfully perform + // the swap + if (outstanding <= oldOutstanding + || Interlocked.CompareExchange( + ref _tallyBufferMaxOutstandingBytes, + outstanding, + oldOutstanding) == oldOutstanding) break; + } +#endif + } + + [Conditional("DEBUG")] + public static void OnBufferCompleted(int messageCount, int messageBytes) + { +#if DEBUG + Interlocked.Add(ref _tallyBufferMessageCount, messageCount); + Interlocked.Add(ref _tallyBufferMessageBytes, messageBytes); +#endif + } + + [Conditional("DEBUG")] + public static void OnBufferCapacity(int bytes) + { +#if DEBUG + Interlocked.Add(ref _tallyBufferTotalBytes, bytes); +#endif + } + + [Conditional("DEBUG")] + public static void OnBufferPinned() + { +#if DEBUG + Interlocked.Increment(ref _tallyBufferPinCount); +#endif + } + + [Conditional("DEBUG")] + public static void OnBufferLeaked() + { +#if DEBUG + Interlocked.Increment(ref _tallyBufferLeakCount); +#endif + } + #if DEBUG private static void EstimatedMovingRangeAverage(ref long field, long value) { @@ -66,5 +138,15 @@ private static void EstimatedMovingRangeAverage(ref long field, long value) public long DiscardAverage { get; } = Interlocked.Exchange(ref _tallyDiscardAverage, 32); public int DiscardFullCount { get; } = Interlocked.Exchange(ref _tallyDiscardFullCount, 0); public int DiscardPartialCount { get; } = Interlocked.Exchange(ref _tallyDiscardPartialCount, 0); + + public int BufferCreatedCount { get; } = Interlocked.Exchange(ref _tallyBufferCreatedCount, 0); + public int BufferRecycledCount { get; } = Interlocked.Exchange(ref _tallyBufferRecycledCount, 0); + public long BufferRecycledBytes { get; } = Interlocked.Exchange(ref _tallyBufferRecycledBytes, 0); + public long BufferMaxOutstandingBytes { get; } = Interlocked.Exchange(ref _tallyBufferMaxOutstandingBytes, 0); + public int BufferMessageCount { get; } = Interlocked.Exchange(ref _tallyBufferMessageCount, 0); + public long BufferMessageBytes { get; } = Interlocked.Exchange(ref _tallyBufferMessageBytes, 0); + public long BufferTotalBytes { get; } = Interlocked.Exchange(ref _tallyBufferTotalBytes, 0); + public int BufferPinCount { get; } = Interlocked.Exchange(ref _tallyBufferPinCount, 0); + public int BufferLeakCount { get; } = Interlocked.Exchange(ref _tallyBufferLeakCount, 0); #endif } diff --git a/src/RESPite/Internal/SynchronizedBlockBufferSerializer.cs b/src/RESPite/Internal/SynchronizedBlockBufferSerializer.cs new file mode 100644 index 000000000..1b121fd8b --- /dev/null +++ b/src/RESPite/Internal/SynchronizedBlockBufferSerializer.cs @@ -0,0 +1,123 @@ +using System.Buffers; +using RESPite.Messages; + +namespace RESPite.Internal; + +internal partial class BlockBufferSerializer +{ + internal static BlockBufferSerializer Create(bool retainChain = false) => + new SynchronizedBlockBufferSerializer(retainChain); + + /// + /// Used for things like . + /// + private sealed class SynchronizedBlockBufferSerializer(bool retainChain) : BlockBufferSerializer + { + private bool _discardDuringClear; + + private protected override BlockBuffer? Buffer { get; set; } // simple per-instance auto-prop + + /* + // use lock-based synchronization + public override ReadOnlyMemory Serialize( + RespCommandMap? commandMap, + ReadOnlySpan command, + in TRequest request, + IRespFormatter formatter) + { + bool haveLock = false; + try // note that "lock" unrolls to something very similar; we're not adding anything unusual here + { + // in reality, we *expect* people to not attempt to use batches concurrently, *and* + // we expect serialization to be very fast, but: out of an abundance of caution, + // add a timeout - just to avoid surprises (since people can write their own formatters) + Monitor.TryEnter(this, LockTimeout, ref haveLock); + if (!haveLock) ThrowTimeout(); + return base.Serialize(commandMap, command, in request, formatter); + } + finally + { + if (haveLock) Monitor.Exit(this); + } + + static void ThrowTimeout() => throw new TimeoutException( + "It took a long time to get access to the serialization-buffer. This is very odd - please " + + "ask on GitHub, but *as a guess*, you have a custom RESP formatter that is really slow *and* " + + "you are using concurrent access to a RESP batch / transaction."); + } + */ + + private static readonly TimeSpan LockTimeout = TimeSpan.FromSeconds(5); + + private Segment? _head, _tail; + + protected override bool ClaimSegment(ReadOnlyMemory segment) + { + if (retainChain & !_discardDuringClear) + { + if (_head is null) + { + _head = _tail = new Segment(segment); + } + else + { + _tail = new Segment(segment, _tail); + } + + // note we don't need to increment the ref-count; because of this "true" + return true; + } + + return false; + } + + internal override ReadOnlySequence Flush() + { + if (_head is null) + { + // at worst, single-segment - we can skip the alloc + return new(BlockBuffer.RetainCurrent(this)); + } + + // otherwise, flush everything *keeping the chain* + ClearWithDiscard(discard: false); + ReadOnlySequence seq = new(_head, 0, _tail!, _tail!.Length); + _head = _tail = null; + return seq; + } + + public override void Clear() + { + ClearWithDiscard(discard: true); + _head = _tail = null; + } + + private void ClearWithDiscard(bool discard) + { + try + { + _discardDuringClear = discard; + base.Clear(); + } + finally + { + _discardDuringClear = false; + } + } + + private sealed class Segment : ReadOnlySequenceSegment + { + public Segment(ReadOnlyMemory memory, Segment? previous = null) + { + Memory = memory; + if (previous is not null) + { + previous.Next = this; + RunningIndex = previous.RunningIndex + previous.Length; + } + } + + public int Length => Memory.Length; + } + } +} diff --git a/src/RESPite/Internal/ThreadLocalBlockBufferSerializer.cs b/src/RESPite/Internal/ThreadLocalBlockBufferSerializer.cs new file mode 100644 index 000000000..1c1895ff4 --- /dev/null +++ b/src/RESPite/Internal/ThreadLocalBlockBufferSerializer.cs @@ -0,0 +1,21 @@ +namespace RESPite.Internal; + +internal partial class BlockBufferSerializer +{ + internal static BlockBufferSerializer Shared => ThreadLocalBlockBufferSerializer.Instance; + private sealed class ThreadLocalBlockBufferSerializer : BlockBufferSerializer + { + private ThreadLocalBlockBufferSerializer() { } + public static readonly ThreadLocalBlockBufferSerializer Instance = new(); + + [ThreadStatic] + // side-step concurrency using per-thread semantics + private static BlockBuffer? _perTreadBuffer; + + private protected override BlockBuffer? Buffer + { + get => _perTreadBuffer; + set => _perTreadBuffer = value; + } + } +} diff --git a/src/RESPite/PublicAPI/PublicAPI.Unshipped.txt b/src/RESPite/PublicAPI/PublicAPI.Unshipped.txt index d3340f383..123e6c86e 100644 --- a/src/RESPite/PublicAPI/PublicAPI.Unshipped.txt +++ b/src/RESPite/PublicAPI/PublicAPI.Unshipped.txt @@ -1,26 +1,26 @@ #nullable enable -const RESPite.Buffers.CycleBuffer.GetAnything = 0 -> int -const RESPite.Buffers.CycleBuffer.GetFullPagesOnly = -1 -> int -RESPite.Buffers.CycleBuffer -RESPite.Buffers.CycleBuffer.Commit(int count) -> void -RESPite.Buffers.CycleBuffer.CommittedIsEmpty.get -> bool -RESPite.Buffers.CycleBuffer.CycleBuffer() -> void -RESPite.Buffers.CycleBuffer.DiscardCommitted(int count) -> void -RESPite.Buffers.CycleBuffer.DiscardCommitted(long count) -> void -RESPite.Buffers.CycleBuffer.GetAllCommitted() -> System.Buffers.ReadOnlySequence -RESPite.Buffers.CycleBuffer.GetCommittedLength() -> long -RESPite.Buffers.CycleBuffer.GetUncommittedMemory(int hint = 0) -> System.Memory -RESPite.Buffers.CycleBuffer.GetUncommittedSpan(int hint = 0) -> System.Span -RESPite.Buffers.CycleBuffer.PageSize.get -> int -RESPite.Buffers.CycleBuffer.Pool.get -> System.Buffers.MemoryPool! -RESPite.Buffers.CycleBuffer.Release() -> void -RESPite.Buffers.CycleBuffer.TryGetCommitted(out System.ReadOnlySpan span) -> bool -RESPite.Buffers.CycleBuffer.TryGetFirstCommittedMemory(int minBytes, out System.ReadOnlyMemory memory) -> bool -RESPite.Buffers.CycleBuffer.TryGetFirstCommittedSpan(int minBytes, out System.ReadOnlySpan span) -> bool -RESPite.Buffers.CycleBuffer.UncommittedAvailable.get -> int -RESPite.Buffers.CycleBuffer.Write(in System.Buffers.ReadOnlySequence value) -> void -RESPite.Buffers.CycleBuffer.Write(System.ReadOnlySpan value) -> void -static RESPite.Buffers.CycleBuffer.Create(System.Buffers.MemoryPool? pool = null, int pageSize = 8192) -> RESPite.Buffers.CycleBuffer +[SER004]const RESPite.Buffers.CycleBuffer.GetAnything = 0 -> int +[SER004]const RESPite.Buffers.CycleBuffer.GetFullPagesOnly = -1 -> int +[SER004]RESPite.Buffers.CycleBuffer +[SER004]RESPite.Buffers.CycleBuffer.Commit(int count) -> void +[SER004]RESPite.Buffers.CycleBuffer.CommittedIsEmpty.get -> bool +[SER004]RESPite.Buffers.CycleBuffer.CycleBuffer() -> void +[SER004]RESPite.Buffers.CycleBuffer.DiscardCommitted(int count) -> void +[SER004]RESPite.Buffers.CycleBuffer.DiscardCommitted(long count) -> void +[SER004]RESPite.Buffers.CycleBuffer.GetAllCommitted() -> System.Buffers.ReadOnlySequence +[SER004]RESPite.Buffers.CycleBuffer.GetCommittedLength() -> long +[SER004]RESPite.Buffers.CycleBuffer.GetUncommittedMemory(int hint = 0) -> System.Memory +[SER004]RESPite.Buffers.CycleBuffer.GetUncommittedSpan(int hint = 0) -> System.Span +[SER004]RESPite.Buffers.CycleBuffer.PageSize.get -> int +[SER004]RESPite.Buffers.CycleBuffer.Pool.get -> System.Buffers.MemoryPool! +[SER004]RESPite.Buffers.CycleBuffer.Release() -> void +[SER004]RESPite.Buffers.CycleBuffer.TryGetCommitted(out System.ReadOnlySpan span) -> bool +[SER004]RESPite.Buffers.CycleBuffer.TryGetFirstCommittedMemory(int minBytes, out System.ReadOnlyMemory memory) -> bool +[SER004]RESPite.Buffers.CycleBuffer.TryGetFirstCommittedSpan(int minBytes, out System.ReadOnlySpan span) -> bool +[SER004]RESPite.Buffers.CycleBuffer.UncommittedAvailable.get -> int +[SER004]RESPite.Buffers.CycleBuffer.Write(in System.Buffers.ReadOnlySequence value) -> void +[SER004]RESPite.Buffers.CycleBuffer.Write(System.ReadOnlySpan value) -> void +[SER004]static RESPite.Buffers.CycleBuffer.Create(System.Buffers.MemoryPool? pool = null, int pageSize = 8192) -> RESPite.Buffers.CycleBuffer [SER004]const RESPite.Messages.RespScanState.MinBytes = 3 -> int [SER004]override RESPite.Messages.RespScanState.Equals(object? obj) -> bool [SER004]override RESPite.Messages.RespScanState.GetHashCode() -> int diff --git a/src/RESPite/RESPite.csproj b/src/RESPite/RESPite.csproj index abde624b2..7e93abb6f 100644 --- a/src/RESPite/RESPite.csproj +++ b/src/RESPite/RESPite.csproj @@ -27,6 +27,15 @@ RespReader.cs + + BlockBufferSerializer.cs + + + BlockBufferSerializer.cs + + + BlockBufferSerializer.cs + diff --git a/src/StackExchange.Redis/CommandBytes.cs b/src/StackExchange.Redis/CommandBytes.cs index 19a69549b..da6e3df6a 100644 --- a/src/StackExchange.Redis/CommandBytes.cs +++ b/src/StackExchange.Redis/CommandBytes.cs @@ -18,7 +18,7 @@ internal static unsafe CommandBytes TrimToFit(string value) fixed (char* c = value) { byte* b = stackalloc byte[ChunkLength * sizeof(ulong)]; - var encoder = PhysicalConnection.GetPerThreadEncoder(); + var encoder = MessageWriter.GetPerThreadEncoder(); encoder.Convert(c, value.Length, b, MaxLength, true, out var maxLen, out _, out var isComplete); if (!isComplete) maxLen--; return new CommandBytes(value.Substring(0, maxLen)); diff --git a/src/StackExchange.Redis/Condition.cs b/src/StackExchange.Redis/Condition.cs index 861abbbd2..1b5bcced4 100644 --- a/src/StackExchange.Redis/Condition.cs +++ b/src/StackExchange.Redis/Condition.cs @@ -432,26 +432,26 @@ public ConditionMessage(Condition condition, int db, CommandFlags flags, RedisCo this.value4 = value4; // note no assert here } - protected override void WriteImpl(PhysicalConnection physical) + protected override void WriteImpl(in MessageWriter writer) { if (value.IsNull) { - physical.WriteHeader(command, 1); - physical.Write(Key); + writer.WriteHeader(command, 1); + writer.Write(Key); } else { - physical.WriteHeader(command, value1.IsNull ? 2 : value2.IsNull ? 3 : value3.IsNull ? 4 : value4.IsNull ? 5 : 6); - physical.Write(Key); - physical.WriteBulkString(value); + writer.WriteHeader(command, value1.IsNull ? 2 : value2.IsNull ? 3 : value3.IsNull ? 4 : value4.IsNull ? 5 : 6); + writer.Write(Key); + writer.WriteBulkString(value); if (!value1.IsNull) - physical.WriteBulkString(value1); + writer.WriteBulkString(value1); if (!value2.IsNull) - physical.WriteBulkString(value2); + writer.WriteBulkString(value2); if (!value3.IsNull) - physical.WriteBulkString(value3); + writer.WriteBulkString(value3); if (!value4.IsNull) - physical.WriteBulkString(value4); + writer.WriteBulkString(value4); } } public override int ArgCount => value.IsNull ? 1 : value1.IsNull ? 2 : value2.IsNull ? 3 : value3.IsNull ? 4 : value4.IsNull ? 5 : 6; diff --git a/src/StackExchange.Redis/Configuration/LoggingTunnel.cs b/src/StackExchange.Redis/Configuration/LoggingTunnel.cs index 3a05fab11..2cdf1f418 100644 --- a/src/StackExchange.Redis/Configuration/LoggingTunnel.cs +++ b/src/StackExchange.Redis/Configuration/LoggingTunnel.cs @@ -3,17 +3,13 @@ using System.Diagnostics; using System.Diagnostics.CodeAnalysis; using System.IO; -using System.IO.Pipelines; using System.Net; using System.Net.Security; using System.Net.Sockets; using System.Text; using System.Threading; using System.Threading.Tasks; -using Pipelines.Sockets.Unofficial; -using Pipelines.Sockets.Unofficial.Arenas; using RESPite.Buffers; -using RESPite.Internal; using RESPite.Messages; using static StackExchange.Redis.PhysicalConnection; @@ -30,13 +26,13 @@ public abstract class LoggingTunnel : Tunnel private readonly bool _ssl; private readonly Tunnel? _tail; - private sealed class StreamRespReader(Stream source, bool isInbound) : IDisposable + internal sealed class StreamRespReader(Stream source, bool isInbound) : IDisposable { private CycleBuffer _readBuffer = CycleBuffer.Create(); private RespScanState _state; private bool _reading, _disposed; // we need to track the state of the reader to avoid releasing the buffer while it's in use - public bool TryTakeOne(out ContextualRedisResult result, bool withData = true) + internal bool TryTakeOne(out ContextualRedisResult result, bool withData = true) { var fullBuffer = _readBuffer.GetAllCommitted(); var newData = fullBuffer.Slice(_state.TotalBytes); @@ -266,7 +262,7 @@ public static async Task ValidateAsync(Stream stream) return await reader.ValidateAsync(); } - private readonly struct ContextualRedisResult + internal readonly struct ContextualRedisResult { public readonly RedisResult? Result; public readonly bool IsOutOfBand; diff --git a/src/StackExchange.Redis/ConnectionMultiplexer.cs b/src/StackExchange.Redis/ConnectionMultiplexer.cs index 0c6148923..bc46c1520 100644 --- a/src/StackExchange.Redis/ConnectionMultiplexer.cs +++ b/src/StackExchange.Redis/ConnectionMultiplexer.cs @@ -350,7 +350,7 @@ internal void CheckMessage(Message message) } // using >= here because we will be adding 1 for the command itself (which is an argument for the purposes of the multi-bulk protocol) - if (message.ArgCount >= PhysicalConnection.REDIS_MAX_ARGS) + if (message.ArgCount >= MessageWriter.REDIS_MAX_ARGS) { throw ExceptionFactory.TooManyArgs(message.CommandAndKey, message.ArgCount); } diff --git a/src/StackExchange.Redis/ExceptionFactory.cs b/src/StackExchange.Redis/ExceptionFactory.cs index 3cfb0268c..3e31933fe 100644 --- a/src/StackExchange.Redis/ExceptionFactory.cs +++ b/src/StackExchange.Redis/ExceptionFactory.cs @@ -28,7 +28,7 @@ internal static Exception CommandDisabled(string command) => new RedisCommandException("This operation has been disabled in the command-map and cannot be used: " + command); internal static Exception TooManyArgs(string command, int argCount) - => new RedisCommandException($"This operation would involve too many arguments ({argCount + 1} vs the redis limit of {PhysicalConnection.REDIS_MAX_ARGS}): {command}"); + => new RedisCommandException($"This operation would involve too many arguments ({argCount + 1} vs the redis limit of {MessageWriter.REDIS_MAX_ARGS}): {command}"); internal static Exception ConnectionFailure(bool includeDetail, ConnectionFailureType failureType, string message, ServerEndPoint? server) { diff --git a/src/StackExchange.Redis/Expiration.cs b/src/StackExchange.Redis/Expiration.cs index e04094358..738b0b111 100644 --- a/src/StackExchange.Redis/Expiration.cs +++ b/src/StackExchange.Redis/Expiration.cs @@ -244,7 +244,7 @@ private static void ThrowMode(ExpirationMode mode) => _ => 2, }; - internal void WriteTo(PhysicalConnection physical) + internal void WriteTo(in MessageWriter writer) { var mode = Mode; switch (Mode) @@ -252,13 +252,13 @@ internal void WriteTo(PhysicalConnection physical) case ExpirationMode.Default or ExpirationMode.NotUsed: break; case ExpirationMode.KeepTtl: - physical.WriteBulkString("KEEPTTL"u8); + writer.WriteBulkString("KEEPTTL"u8); break; case ExpirationMode.Persist: - physical.WriteBulkString("PERSIST"u8); + writer.WriteBulkString("PERSIST"u8); break; default: - physical.WriteBulkString(mode switch + writer.WriteBulkString(mode switch { ExpirationMode.RelativeSeconds => "EX"u8, ExpirationMode.RelativeMilliseconds => "PX"u8, @@ -266,7 +266,7 @@ internal void WriteTo(PhysicalConnection physical) ExpirationMode.AbsoluteMilliseconds => "PXAT"u8, _ => default, }); - physical.WriteBulkString(Value); + writer.WriteBulkString(Value); break; } } diff --git a/src/StackExchange.Redis/HotKeys.StartMessage.cs b/src/StackExchange.Redis/HotKeys.StartMessage.cs index c9f0bc371..0265f34ab 100644 --- a/src/StackExchange.Redis/HotKeys.StartMessage.cs +++ b/src/StackExchange.Redis/HotKeys.StartMessage.cs @@ -13,7 +13,7 @@ internal sealed class HotKeysStartMessage( long sampleRatio, int[]? slots) : Message(-1, flags, RedisCommand.HOTKEYS) { - protected override void WriteImpl(PhysicalConnection physical) + protected override void WriteImpl(in MessageWriter writer) { /* HOTKEYS START @@ -23,41 +23,41 @@ [DURATION duration] [SAMPLE ratio] [SLOTS count slot…] */ - physical.WriteHeader(Command, ArgCount); - physical.WriteBulkString("START"u8); - physical.WriteBulkString("METRICS"u8); + writer.WriteHeader(Command, ArgCount); + writer.WriteBulkString("START"u8); + writer.WriteBulkString("METRICS"u8); var metricCount = 0; if ((metrics & HotKeysMetrics.Cpu) != 0) metricCount++; if ((metrics & HotKeysMetrics.Network) != 0) metricCount++; - physical.WriteBulkString(metricCount); - if ((metrics & HotKeysMetrics.Cpu) != 0) physical.WriteBulkString("CPU"u8); - if ((metrics & HotKeysMetrics.Network) != 0) physical.WriteBulkString("NET"u8); + writer.WriteBulkString(metricCount); + if ((metrics & HotKeysMetrics.Cpu) != 0) writer.WriteBulkString("CPU"u8); + if ((metrics & HotKeysMetrics.Network) != 0) writer.WriteBulkString("NET"u8); if (count != 0) { - physical.WriteBulkString("COUNT"u8); - physical.WriteBulkString(count); + writer.WriteBulkString("COUNT"u8); + writer.WriteBulkString(count); } if (duration != TimeSpan.Zero) { - physical.WriteBulkString("DURATION"u8); - physical.WriteBulkString(Math.Ceiling(duration.TotalSeconds)); + writer.WriteBulkString("DURATION"u8); + writer.WriteBulkString(Math.Ceiling(duration.TotalSeconds)); } if (sampleRatio != 1) { - physical.WriteBulkString("SAMPLE"u8); - physical.WriteBulkString(sampleRatio); + writer.WriteBulkString("SAMPLE"u8); + writer.WriteBulkString(sampleRatio); } if (slots is { Length: > 0 }) { - physical.WriteBulkString("SLOTS"u8); - physical.WriteBulkString(slots.Length); + writer.WriteBulkString("SLOTS"u8); + writer.WriteBulkString(slots.Length); foreach (var slot in slots) { - physical.WriteBulkString(slot); + writer.WriteBulkString(slot); } } } diff --git a/src/StackExchange.Redis/Message.ValueCondition.cs b/src/StackExchange.Redis/Message.ValueCondition.cs index 53ddc651b..09e1c504e 100644 --- a/src/StackExchange.Redis/Message.ValueCondition.cs +++ b/src/StackExchange.Redis/Message.ValueCondition.cs @@ -22,11 +22,11 @@ private sealed class KeyConditionMessage( public override int ArgCount => 1 + _when.TokenCount; - protected override void WriteImpl(PhysicalConnection physical) + protected override void WriteImpl(in MessageWriter writer) { - physical.WriteHeader(Command, ArgCount); - physical.Write(Key); - _when.WriteTo(physical); + writer.WriteHeader(Command, ArgCount); + writer.Write(Key); + _when.WriteTo(writer); } } @@ -46,13 +46,13 @@ private sealed class KeyValueExpiryConditionMessage( public override int ArgCount => 2 + _expiry.TokenCount + _when.TokenCount; - protected override void WriteImpl(PhysicalConnection physical) + protected override void WriteImpl(in MessageWriter writer) { - physical.WriteHeader(Command, ArgCount); - physical.Write(Key); - physical.WriteBulkString(_value); - _expiry.WriteTo(physical); - _when.WriteTo(physical); + writer.WriteHeader(Command, ArgCount); + writer.Write(Key); + writer.WriteBulkString(_value); + _expiry.WriteTo(writer); + _when.WriteTo(writer); } } } diff --git a/src/StackExchange.Redis/Message.cs b/src/StackExchange.Redis/Message.cs index 5d5a6f050..b4fa5c81f 100644 --- a/src/StackExchange.Redis/Message.cs +++ b/src/StackExchange.Redis/Message.cs @@ -1,12 +1,16 @@ using System; +using System.Buffers; using System.Buffers.Binary; using System.Collections.Generic; using System.Diagnostics; using System.Linq; using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; using System.Text; using System.Threading; using Microsoft.Extensions.Logging; +using RESPite.Buffers; +using RESPite.Internal; using RESPite.Messages; using StackExchange.Redis.Profiling; @@ -35,15 +39,15 @@ private LoggingMessage(ILogger log, Message tail) : base(tail.Db, tail.Flags, ta public override int GetHashSlot(ServerSelectionStrategy serverSelectionStrategy) => tail.GetHashSlot(serverSelectionStrategy); - protected override void WriteImpl(PhysicalConnection physical) + protected override void WriteImpl(in MessageWriter writer) { try { - var bridge = physical.BridgeCouldBeNull; + var bridge = writer.BridgeCouldBeNull; log?.LogTrace($"{bridge?.Name}: Writing: {tail.CommandAndKey}"); } catch { } - tail.WriteTo(physical); + tail.WriteTo(writer); } public override int ArgCount => tail.ArgCount; @@ -782,19 +786,45 @@ internal void SetSource(IResultBox resultBox, ResultProcessor? resultPr this.resultProcessor = resultProcessor; } - protected abstract void WriteImpl(PhysicalConnection physical); + internal void WriteTo(in MessageWriter writer) => WriteImpl(in writer); + protected abstract void WriteImpl(in MessageWriter writer); + + internal string GetRespString(PhysicalConnection connection) + { + MessageWriter writer = new MessageWriter(connection); + try + { + WriteImpl(in writer); + var bytes = writer.Flush(); + string s = Encoding.UTF8.GetString(bytes.Span); + MessageWriter.Release(bytes); + return s; + } + finally + { + writer.Revert(); + } + } internal void WriteTo(PhysicalConnection physical) { + MessageWriter writer = new MessageWriter(physical); try { - WriteImpl(physical); + WriteImpl(in writer); + var bytes = writer.Flush(); + physical.WriteDirect(bytes); + MessageWriter.Release(bytes); } catch (Exception ex) when (ex is not RedisCommandException) // these have specific meaning; don't wrap { physical?.OnInternalError(ex); Fail(ConnectionFailureType.InternalFailure, ex, null, physical?.BridgeCouldBeNull?.Multiplexer); } + finally + { + writer.Revert(); + } } private static ReadOnlySpan ChecksumTemplate => "$4\r\nXXXX\r\n"u8; @@ -802,21 +832,30 @@ internal void WriteTo(PhysicalConnection physical) internal void WriteHighIntegrityChecksumRequest(PhysicalConnection physical) { Debug.Assert(IsHighIntegrity, "should only be used for high-integrity"); + var writer = new MessageWriter(physical); try { - physical.WriteHeader(RedisCommand.ECHO, 1); // use WriteHeader to allow command-rewrite + writer.WriteHeader(RedisCommand.ECHO, 1); // use WriteHeader to allow command-rewrite Span chk = stackalloc byte[10]; Debug.Assert(ChecksumTemplate.Length == chk.Length, "checksum template length error"); ChecksumTemplate.CopyTo(chk); BinaryPrimitives.WriteUInt32LittleEndian(chk.Slice(4, 4), _highIntegrityToken); - physical.WriteRaw(chk); + writer.WriteRaw(chk); + + var memory = writer.Flush(); + physical.WriteDirect(memory); + MessageWriter.Release(memory); } catch (Exception ex) { physical?.OnInternalError(ex); Fail(ConnectionFailureType.InternalFailure, ex, null, physical?.BridgeCouldBeNull?.Multiplexer); } + finally + { + writer.Revert(); + } } internal static Message CreateHello(int protocolVersion, string? username, string? password, string? clientName, CommandFlags flags) @@ -848,20 +887,20 @@ public override int ArgCount return count; } } - protected override void WriteImpl(PhysicalConnection physical) + protected override void WriteImpl(in MessageWriter writer) { - physical.WriteHeader(Command, ArgCount); - physical.WriteBulkString(_protocolVersion); + writer.WriteHeader(Command, ArgCount); + writer.WriteBulkString(_protocolVersion); if (!string.IsNullOrWhiteSpace(_password)) { - physical.WriteBulkString("AUTH"u8); - physical.WriteBulkString(string.IsNullOrWhiteSpace(_username) ? RedisLiterals.@default : _username); - physical.WriteBulkString(_password); + writer.WriteBulkString("AUTH"u8); + writer.WriteBulkString(string.IsNullOrWhiteSpace(_username) ? RedisLiterals.@default : _username); + writer.WriteBulkString(_password); } if (!string.IsNullOrWhiteSpace(_clientName)) { - physical.WriteBulkString("SETNAME"u8); - physical.WriteBulkString(_clientName); + writer.WriteBulkString("SETNAME"u8); + writer.WriteBulkString(_clientName); } } } @@ -902,10 +941,10 @@ private sealed class CommandChannelMessage : CommandChannelBase public CommandChannelMessage(int db, CommandFlags flags, RedisCommand command, in RedisChannel channel) : base(db, flags, command, channel) { } - protected override void WriteImpl(PhysicalConnection physical) + protected override void WriteImpl(in MessageWriter writer) { - physical.WriteHeader(Command, 1); - physical.Write(Channel); + writer.WriteHeader(Command, 1); + writer.Write(Channel); } public override int ArgCount => 1; } @@ -920,11 +959,11 @@ public CommandChannelValueMessage(int db, CommandFlags flags, RedisCommand comma this.value = value; } - protected override void WriteImpl(PhysicalConnection physical) + protected override void WriteImpl(in MessageWriter writer) { - physical.WriteHeader(Command, 2); - physical.Write(Channel); - physical.WriteBulkString(value); + writer.WriteHeader(Command, 2); + writer.Write(Channel); + writer.WriteBulkString(value); } public override int ArgCount => 2; } @@ -947,12 +986,12 @@ public override int GetHashSlot(ServerSelectionStrategy serverSelectionStrategy) return serverSelectionStrategy.CombineSlot(slot, key2); } - protected override void WriteImpl(PhysicalConnection physical) + protected override void WriteImpl(in MessageWriter writer) { - physical.WriteHeader(Command, 3); - physical.Write(Key); - physical.Write(key1); - physical.Write(key2); + writer.WriteHeader(Command, 3); + writer.Write(Key); + writer.Write(key1); + writer.Write(key2); } public override int ArgCount => 3; } @@ -972,11 +1011,11 @@ public override int GetHashSlot(ServerSelectionStrategy serverSelectionStrategy) return serverSelectionStrategy.CombineSlot(slot, key1); } - protected override void WriteImpl(PhysicalConnection physical) + protected override void WriteImpl(in MessageWriter writer) { - physical.WriteHeader(Command, 2); - physical.Write(Key); - physical.Write(key1); + writer.WriteHeader(Command, 2); + writer.Write(Key); + writer.Write(key1); } public override int ArgCount => 2; } @@ -1003,13 +1042,13 @@ public override int GetHashSlot(ServerSelectionStrategy serverSelectionStrategy) return slot; } - protected override void WriteImpl(PhysicalConnection physical) + protected override void WriteImpl(in MessageWriter writer) { - physical.WriteHeader(command, keys.Length + 1); - physical.Write(Key); + writer.WriteHeader(command, keys.Length + 1); + writer.Write(Key); for (int i = 0; i < keys.Length; i++) { - physical.Write(keys[i]); + writer.Write(keys[i]); } } public override int ArgCount => keys.Length + 1; @@ -1024,12 +1063,12 @@ public CommandKeyKeyValueMessage(int db, CommandFlags flags, RedisCommand comman this.value = value; } - protected override void WriteImpl(PhysicalConnection physical) + protected override void WriteImpl(in MessageWriter writer) { - physical.WriteHeader(Command, 3); - physical.Write(Key); - physical.Write(key1); - physical.WriteBulkString(value); + writer.WriteHeader(Command, 3); + writer.Write(Key); + writer.Write(key1); + writer.WriteBulkString(value); } public override int ArgCount => 3; @@ -1039,10 +1078,10 @@ private sealed class CommandKeyMessage : CommandKeyBase { public CommandKeyMessage(int db, CommandFlags flags, RedisCommand command, in RedisKey key) : base(db, flags, command, key) { } - protected override void WriteImpl(PhysicalConnection physical) + protected override void WriteImpl(in MessageWriter writer) { - physical.WriteHeader(Command, 1); - physical.Write(Key); + writer.WriteHeader(Command, 1); + writer.Write(Key); } public override int ArgCount => 1; } @@ -1059,12 +1098,12 @@ public CommandValuesMessage(int db, CommandFlags flags, RedisCommand command, Re this.values = values; } - protected override void WriteImpl(PhysicalConnection physical) + protected override void WriteImpl(in MessageWriter writer) { - physical.WriteHeader(command, values.Length); + writer.WriteHeader(command, values.Length); for (int i = 0; i < values.Length; i++) { - physical.WriteBulkString(values[i]); + writer.WriteBulkString(values[i]); } } public override int ArgCount => values.Length; @@ -1092,12 +1131,12 @@ public override int GetHashSlot(ServerSelectionStrategy serverSelectionStrategy) return slot; } - protected override void WriteImpl(PhysicalConnection physical) + protected override void WriteImpl(in MessageWriter writer) { - physical.WriteHeader(command, keys.Length); + writer.WriteHeader(command, keys.Length); for (int i = 0; i < keys.Length; i++) { - physical.Write(keys[i]); + writer.Write(keys[i]); } } public override int ArgCount => keys.Length; @@ -1112,11 +1151,11 @@ public CommandKeyValueMessage(int db, CommandFlags flags, RedisCommand command, this.value = value; } - protected override void WriteImpl(PhysicalConnection physical) + protected override void WriteImpl(in MessageWriter writer) { - physical.WriteHeader(Command, 2); - physical.Write(Key); - physical.WriteBulkString(value); + writer.WriteHeader(Command, 2); + writer.Write(Key); + writer.WriteBulkString(value); } public override int ArgCount => 2; } @@ -1142,12 +1181,12 @@ public override int GetHashSlot(ServerSelectionStrategy serverSelectionStrategy) return serverSelectionStrategy.CombineSlot(slot, key1); } - protected override void WriteImpl(PhysicalConnection physical) + protected override void WriteImpl(in MessageWriter writer) { - physical.WriteHeader(Command, values.Length + 2); - physical.Write(Key); - for (int i = 0; i < values.Length; i++) physical.WriteBulkString(values[i]); - physical.Write(key1); + writer.WriteHeader(Command, values.Length + 2); + writer.Write(Key); + for (int i = 0; i < values.Length; i++) writer.WriteBulkString(values[i]); + writer.Write(key1); } public override int ArgCount => values.Length + 2; } @@ -1164,11 +1203,11 @@ public CommandKeyValuesMessage(int db, CommandFlags flags, RedisCommand command, this.values = values; } - protected override void WriteImpl(PhysicalConnection physical) + protected override void WriteImpl(in MessageWriter writer) { - physical.WriteHeader(Command, values.Length + 1); - physical.Write(Key); - for (int i = 0; i < values.Length; i++) physical.WriteBulkString(values[i]); + writer.WriteHeader(Command, values.Length + 1); + writer.Write(Key); + for (int i = 0; i < values.Length; i++) writer.WriteBulkString(values[i]); } public override int ArgCount => values.Length + 1; } @@ -1189,12 +1228,12 @@ public CommandKeyKeyValuesMessage(int db, CommandFlags flags, RedisCommand comma this.values = values; } - protected override void WriteImpl(PhysicalConnection physical) + protected override void WriteImpl(in MessageWriter writer) { - physical.WriteHeader(Command, values.Length + 2); - physical.Write(Key); - physical.Write(key1); - for (int i = 0; i < values.Length; i++) physical.WriteBulkString(values[i]); + writer.WriteHeader(Command, values.Length + 2); + writer.Write(Key); + writer.Write(key1); + for (int i = 0; i < values.Length; i++) writer.WriteBulkString(values[i]); } public override int ArgCount => values.Length + 1; } @@ -1218,13 +1257,13 @@ public CommandKeyValueValueValuesMessage(int db, CommandFlags flags, RedisComman this.values = values; } - protected override void WriteImpl(PhysicalConnection physical) + protected override void WriteImpl(in MessageWriter writer) { - physical.WriteHeader(Command, values.Length + 3); - physical.Write(Key); - physical.WriteBulkString(value0); - physical.WriteBulkString(value1); - for (int i = 0; i < values.Length; i++) physical.WriteBulkString(values[i]); + writer.WriteHeader(Command, values.Length + 3); + writer.Write(Key); + writer.WriteBulkString(value0); + writer.WriteBulkString(value1); + for (int i = 0; i < values.Length; i++) writer.WriteBulkString(values[i]); } public override int ArgCount => values.Length + 3; } @@ -1240,12 +1279,12 @@ public CommandKeyValueValueMessage(int db, CommandFlags flags, RedisCommand comm this.value1 = value1; } - protected override void WriteImpl(PhysicalConnection physical) + protected override void WriteImpl(in MessageWriter writer) { - physical.WriteHeader(Command, 3); - physical.Write(Key); - physical.WriteBulkString(value0); - physical.WriteBulkString(value1); + writer.WriteHeader(Command, 3); + writer.Write(Key); + writer.WriteBulkString(value0); + writer.WriteBulkString(value1); } public override int ArgCount => 3; } @@ -1263,13 +1302,13 @@ public CommandKeyValueValueValueMessage(int db, CommandFlags flags, RedisCommand this.value2 = value2; } - protected override void WriteImpl(PhysicalConnection physical) + protected override void WriteImpl(in MessageWriter writer) { - physical.WriteHeader(Command, 4); - physical.Write(Key); - physical.WriteBulkString(value0); - physical.WriteBulkString(value1); - physical.WriteBulkString(value2); + writer.WriteHeader(Command, 4); + writer.Write(Key); + writer.WriteBulkString(value0); + writer.WriteBulkString(value1); + writer.WriteBulkString(value2); } public override int ArgCount => 4; } @@ -1289,14 +1328,14 @@ public CommandKeyValueValueValueValueMessage(int db, CommandFlags flags, RedisCo this.value3 = value3; } - protected override void WriteImpl(PhysicalConnection physical) + protected override void WriteImpl(in MessageWriter writer) { - physical.WriteHeader(Command, 5); - physical.Write(Key); - physical.WriteBulkString(value0); - physical.WriteBulkString(value1); - physical.WriteBulkString(value2); - physical.WriteBulkString(value3); + writer.WriteHeader(Command, 5); + writer.Write(Key); + writer.WriteBulkString(value0); + writer.WriteBulkString(value1); + writer.WriteBulkString(value2); + writer.WriteBulkString(value3); } public override int ArgCount => 5; } @@ -1318,15 +1357,15 @@ public CommandKeyValueValueValueValueValueMessage(int db, CommandFlags flags, Re this.value4 = value4; } - protected override void WriteImpl(PhysicalConnection physical) + protected override void WriteImpl(in MessageWriter writer) { - physical.WriteHeader(Command, 6); - physical.Write(Key); - physical.WriteBulkString(value0); - physical.WriteBulkString(value1); - physical.WriteBulkString(value2); - physical.WriteBulkString(value3); - physical.WriteBulkString(value4); + writer.WriteHeader(Command, 6); + writer.Write(Key); + writer.WriteBulkString(value0); + writer.WriteBulkString(value1); + writer.WriteBulkString(value2); + writer.WriteBulkString(value3); + writer.WriteBulkString(value4); } public override int ArgCount => 6; } @@ -1351,16 +1390,16 @@ public CommandKeyValueValueValueValueValueValueMessage(int db, CommandFlags flag this.value5 = value5; } - protected override void WriteImpl(PhysicalConnection physical) + protected override void WriteImpl(in MessageWriter writer) { - physical.WriteHeader(Command, ArgCount); - physical.Write(Key); - physical.WriteBulkString(value0); - physical.WriteBulkString(value1); - physical.WriteBulkString(value2); - physical.WriteBulkString(value3); - physical.WriteBulkString(value4); - physical.WriteBulkString(value5); + writer.WriteHeader(Command, ArgCount); + writer.Write(Key); + writer.WriteBulkString(value0); + writer.WriteBulkString(value1); + writer.WriteBulkString(value2); + writer.WriteBulkString(value3); + writer.WriteBulkString(value4); + writer.WriteBulkString(value5); } public override int ArgCount => 7; } @@ -1387,17 +1426,17 @@ public CommandKeyValueValueValueValueValueValueValueMessage(int db, CommandFlags this.value6 = value6; } - protected override void WriteImpl(PhysicalConnection physical) + protected override void WriteImpl(in MessageWriter writer) { - physical.WriteHeader(Command, ArgCount); - physical.Write(Key); - physical.WriteBulkString(value0); - physical.WriteBulkString(value1); - physical.WriteBulkString(value2); - physical.WriteBulkString(value3); - physical.WriteBulkString(value4); - physical.WriteBulkString(value5); - physical.WriteBulkString(value6); + writer.WriteHeader(Command, ArgCount); + writer.Write(Key); + writer.WriteBulkString(value0); + writer.WriteBulkString(value1); + writer.WriteBulkString(value2); + writer.WriteBulkString(value3); + writer.WriteBulkString(value4); + writer.WriteBulkString(value5); + writer.WriteBulkString(value6); } public override int ArgCount => 8; } @@ -1424,13 +1463,13 @@ public CommandKeyKeyValueValueMessage( this.value1 = value1; } - protected override void WriteImpl(PhysicalConnection physical) + protected override void WriteImpl(in MessageWriter writer) { - physical.WriteHeader(Command, ArgCount); - physical.Write(Key); - physical.Write(key1); - physical.WriteBulkString(value0); - physical.WriteBulkString(value1); + writer.WriteHeader(Command, ArgCount); + writer.Write(Key); + writer.Write(key1); + writer.WriteBulkString(value0); + writer.WriteBulkString(value1); } public override int ArgCount => 4; @@ -1461,14 +1500,14 @@ public CommandKeyKeyValueValueValueMessage( this.value2 = value2; } - protected override void WriteImpl(PhysicalConnection physical) + protected override void WriteImpl(in MessageWriter writer) { - physical.WriteHeader(Command, ArgCount); - physical.Write(Key); - physical.Write(key1); - physical.WriteBulkString(value0); - physical.WriteBulkString(value1); - physical.WriteBulkString(value2); + writer.WriteHeader(Command, ArgCount); + writer.Write(Key); + writer.Write(key1); + writer.WriteBulkString(value0); + writer.WriteBulkString(value1); + writer.WriteBulkString(value2); } public override int ArgCount => 5; @@ -1502,15 +1541,15 @@ public CommandKeyKeyValueValueValueValueMessage( this.value3 = value3; } - protected override void WriteImpl(PhysicalConnection physical) + protected override void WriteImpl(in MessageWriter writer) { - physical.WriteHeader(Command, ArgCount); - physical.Write(Key); - physical.Write(key1); - physical.WriteBulkString(value0); - physical.WriteBulkString(value1); - physical.WriteBulkString(value2); - physical.WriteBulkString(value3); + writer.WriteHeader(Command, ArgCount); + writer.Write(Key); + writer.Write(key1); + writer.WriteBulkString(value0); + writer.WriteBulkString(value1); + writer.WriteBulkString(value2); + writer.WriteBulkString(value3); } public override int ArgCount => 6; @@ -1547,16 +1586,16 @@ public CommandKeyKeyValueValueValueValueValueMessage( this.value4 = value4; } - protected override void WriteImpl(PhysicalConnection physical) + protected override void WriteImpl(in MessageWriter writer) { - physical.WriteHeader(Command, ArgCount); - physical.Write(Key); - physical.Write(key1); - physical.WriteBulkString(value0); - physical.WriteBulkString(value1); - physical.WriteBulkString(value2); - physical.WriteBulkString(value3); - physical.WriteBulkString(value4); + writer.WriteHeader(Command, ArgCount); + writer.Write(Key); + writer.Write(key1); + writer.WriteBulkString(value0); + writer.WriteBulkString(value1); + writer.WriteBulkString(value2); + writer.WriteBulkString(value3); + writer.WriteBulkString(value4); } public override int ArgCount => 7; @@ -1596,17 +1635,17 @@ public CommandKeyKeyValueValueValueValueValueValueMessage( this.value5 = value5; } - protected override void WriteImpl(PhysicalConnection physical) + protected override void WriteImpl(in MessageWriter writer) { - physical.WriteHeader(Command, ArgCount); - physical.Write(Key); - physical.Write(key1); - physical.WriteBulkString(value0); - physical.WriteBulkString(value1); - physical.WriteBulkString(value2); - physical.WriteBulkString(value3); - physical.WriteBulkString(value4); - physical.WriteBulkString(value5); + writer.WriteHeader(Command, ArgCount); + writer.Write(Key); + writer.Write(key1); + writer.WriteBulkString(value0); + writer.WriteBulkString(value1); + writer.WriteBulkString(value2); + writer.WriteBulkString(value3); + writer.WriteBulkString(value4); + writer.WriteBulkString(value5); } public override int ArgCount => 8; @@ -1649,18 +1688,18 @@ public CommandKeyKeyValueValueValueValueValueValueValueMessage( this.value6 = value6; } - protected override void WriteImpl(PhysicalConnection physical) + protected override void WriteImpl(in MessageWriter writer) { - physical.WriteHeader(Command, ArgCount); - physical.Write(Key); - physical.Write(key1); - physical.WriteBulkString(value0); - physical.WriteBulkString(value1); - physical.WriteBulkString(value2); - physical.WriteBulkString(value3); - physical.WriteBulkString(value4); - physical.WriteBulkString(value5); - physical.WriteBulkString(value6); + writer.WriteHeader(Command, ArgCount); + writer.Write(Key); + writer.Write(key1); + writer.WriteBulkString(value0); + writer.WriteBulkString(value1); + writer.WriteBulkString(value2); + writer.WriteBulkString(value3); + writer.WriteBulkString(value4); + writer.WriteBulkString(value5); + writer.WriteBulkString(value6); } public override int ArgCount => 9; @@ -1669,9 +1708,9 @@ protected override void WriteImpl(PhysicalConnection physical) private sealed class CommandMessage : Message { public CommandMessage(int db, CommandFlags flags, RedisCommand command) : base(db, flags, command) { } - protected override void WriteImpl(PhysicalConnection physical) + protected override void WriteImpl(in MessageWriter writer) { - physical.WriteHeader(Command, 0); + writer.WriteHeader(Command, 0); } public override int ArgCount => 0; } @@ -1694,12 +1733,12 @@ public CommandSlotValuesMessage(int db, int slot, CommandFlags flags, RedisComma public override int GetHashSlot(ServerSelectionStrategy serverSelectionStrategy) => slot; - protected override void WriteImpl(PhysicalConnection physical) + protected override void WriteImpl(in MessageWriter writer) { - physical.WriteHeader(command, values.Length); + writer.WriteHeader(command, values.Length); for (int i = 0; i < values.Length; i++) { - physical.WriteBulkString(values[i]); + writer.WriteBulkString(values[i]); } } public override int ArgCount => values.Length; @@ -1725,29 +1764,29 @@ public override int GetHashSlot(ServerSelectionStrategy serverSelectionStrategy) ? (1 + (2 * values.Length) + expiry.TokenCount + (when is When.Exists or When.NotExists ? 1 : 0)) : (2 * values.Length); // MSET/MSETNX only support simple syntax - protected override void WriteImpl(PhysicalConnection physical) + protected override void WriteImpl(in MessageWriter writer) { var cmd = Command; - physical.WriteHeader(cmd, ArgCount); + writer.WriteHeader(cmd, ArgCount); if (cmd == RedisCommand.MSETEX) // need count prefix { - physical.WriteBulkString(values.Length); + writer.WriteBulkString(values.Length); } for (int i = 0; i < values.Length; i++) { - physical.Write(values[i].Key); - physical.WriteBulkString(values[i].Value); + writer.Write(values[i].Key); + writer.WriteBulkString(values[i].Value); } if (cmd == RedisCommand.MSETEX) // allow expiry/mode tokens { - expiry.WriteTo(physical); + expiry.WriteTo(writer); switch (when) { case When.Exists: - physical.WriteBulkString("XX"u8); + writer.WriteBulkString("XX"u8); break; case When.NotExists: - physical.WriteBulkString("NX"u8); + writer.WriteBulkString("NX"u8); break; } } @@ -1764,11 +1803,11 @@ public CommandValueChannelMessage(int db, CommandFlags flags, RedisCommand comma this.value = value; } - protected override void WriteImpl(PhysicalConnection physical) + protected override void WriteImpl(in MessageWriter writer) { - physical.WriteHeader(Command, 2); - physical.WriteBulkString(value); - physical.Write(Channel); + writer.WriteHeader(Command, 2); + writer.WriteBulkString(value); + writer.Write(Channel); } public override int ArgCount => 2; } @@ -1789,11 +1828,11 @@ public override void AppendStormLog(StringBuilder sb) sb.Append(" (").Append((string?)value).Append(')'); } - protected override void WriteImpl(PhysicalConnection physical) + protected override void WriteImpl(in MessageWriter writer) { - physical.WriteHeader(Command, 2); - physical.WriteBulkString(value); - physical.Write(Key); + writer.WriteHeader(Command, 2); + writer.WriteBulkString(value); + writer.Write(Key); } public override int ArgCount => 2; } @@ -1807,10 +1846,10 @@ public CommandValueMessage(int db, CommandFlags flags, RedisCommand command, in this.value = value; } - protected override void WriteImpl(PhysicalConnection physical) + protected override void WriteImpl(in MessageWriter writer) { - physical.WriteHeader(Command, 1); - physical.WriteBulkString(value); + writer.WriteHeader(Command, 1); + writer.WriteBulkString(value); } public override int ArgCount => 1; } @@ -1826,11 +1865,11 @@ public CommandValueValueMessage(int db, CommandFlags flags, RedisCommand command this.value1 = value1; } - protected override void WriteImpl(PhysicalConnection physical) + protected override void WriteImpl(in MessageWriter writer) { - physical.WriteHeader(Command, 2); - physical.WriteBulkString(value0); - physical.WriteBulkString(value1); + writer.WriteHeader(Command, 2); + writer.WriteBulkString(value0); + writer.WriteBulkString(value1); } public override int ArgCount => 2; } @@ -1848,12 +1887,12 @@ public CommandValueValueValueMessage(int db, CommandFlags flags, RedisCommand co this.value2 = value2; } - protected override void WriteImpl(PhysicalConnection physical) + protected override void WriteImpl(in MessageWriter writer) { - physical.WriteHeader(Command, 3); - physical.WriteBulkString(value0); - physical.WriteBulkString(value1); - physical.WriteBulkString(value2); + writer.WriteHeader(Command, 3); + writer.WriteBulkString(value0); + writer.WriteBulkString(value1); + writer.WriteBulkString(value2); } public override int ArgCount => 3; } @@ -1875,14 +1914,14 @@ public CommandValueValueValueValueValueMessage(int db, CommandFlags flags, Redis this.value4 = value4; } - protected override void WriteImpl(PhysicalConnection physical) + protected override void WriteImpl(in MessageWriter writer) { - physical.WriteHeader(Command, 5); - physical.WriteBulkString(value0); - physical.WriteBulkString(value1); - physical.WriteBulkString(value2); - physical.WriteBulkString(value3); - physical.WriteBulkString(value4); + writer.WriteHeader(Command, 5); + writer.WriteBulkString(value0); + writer.WriteBulkString(value1); + writer.WriteBulkString(value2); + writer.WriteBulkString(value3); + writer.WriteBulkString(value4); } public override int ArgCount => 5; } @@ -1893,10 +1932,10 @@ public SelectMessage(int db, CommandFlags flags) : base(db, flags, RedisCommand. { } - protected override void WriteImpl(PhysicalConnection physical) + protected override void WriteImpl(in MessageWriter writer) { - physical.WriteHeader(Command, 1); - physical.WriteBulkString(Db); + writer.WriteHeader(Command, 1); + writer.WriteBulkString(Db); } public override int ArgCount => 1; } @@ -1908,7 +1947,7 @@ internal sealed class UnknownMessage : Message public static UnknownMessage Instance { get; } = new(); private UnknownMessage() : base(0, CommandFlags.None, RedisCommand.UNKNOWN) { } public override int ArgCount => 0; - protected override void WriteImpl(PhysicalConnection physical) => throw new InvalidOperationException("This message cannot be written"); + protected override void WriteImpl(in MessageWriter writer) => throw new InvalidOperationException("This message cannot be written"); } } } diff --git a/src/StackExchange.Redis/MessageWriter.cs b/src/StackExchange.Redis/MessageWriter.cs new file mode 100644 index 000000000..5c45d2abc --- /dev/null +++ b/src/StackExchange.Redis/MessageWriter.cs @@ -0,0 +1,558 @@ +using System; +using System.Buffers; +using System.Diagnostics; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; +using System.Text; +using RESPite.Internal; + +namespace StackExchange.Redis; + +internal readonly ref struct MessageWriter(PhysicalConnection connection) +{ + public PhysicalBridge? BridgeCouldBeNull => connection.BridgeCouldBeNull; + private readonly IBufferWriter _writer = BlockBufferSerializer.Shared; + + public ReadOnlyMemory Flush() => + BlockBufferSerializer.BlockBuffer.FinalizeMessage(BlockBufferSerializer.Shared); + + public void Revert() => BlockBufferSerializer.Shared.Revert(); + + public static void Release(ReadOnlyMemory memory) + { + if (MemoryMarshal.TryGetMemoryManager( + memory, out var block)) + { + block.Release(); + } + } + + public static void Release(in ReadOnlySequence request) => + BlockBufferSerializer.BlockBuffer.Release(in request); + + public void Write(in RedisKey key) + { + var val = key.KeyValue; + if (val is string s) + { + WriteUnifiedPrefixedString(_writer, key.KeyPrefix, s); + } + else + { + WriteUnifiedPrefixedBlob(_writer, key.KeyPrefix, (byte[]?)val); + } + } + + internal void Write(in RedisChannel channel) + => WriteUnifiedPrefixedBlob(_writer, channel.IgnoreChannelPrefix ? null : connection.ChannelPrefix, channel.Value); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal void WriteBulkString(in RedisValue value) + => WriteBulkString(value, _writer); + + internal static void WriteBulkString(in RedisValue value, IBufferWriter writer) + { + switch (value.Type) + { + case RedisValue.StorageType.Null: + WriteUnifiedBlob(writer, (byte[]?)null); + break; + case RedisValue.StorageType.Int64: + WriteUnifiedInt64(writer, value.OverlappedValueInt64); + break; + case RedisValue.StorageType.UInt64: + WriteUnifiedUInt64(writer, value.OverlappedValueUInt64); + break; + case RedisValue.StorageType.Double: + WriteUnifiedDouble(writer, value.OverlappedValueDouble); + break; + case RedisValue.StorageType.String: + WriteUnifiedPrefixedString(writer, null, (string?)value); + break; + case RedisValue.StorageType.Raw: + WriteUnifiedSpan(writer, ((ReadOnlyMemory)value).Span); + break; + default: + throw new InvalidOperationException($"Unexpected {value.Type} value: '{value}'"); + } + } + + internal void WriteBulkString(ReadOnlySpan value) => WriteUnifiedSpan(_writer, value); + + internal const int + REDIS_MAX_ARGS = + 1024 * 1024; // there is a <= 1024*1024 max constraint inside redis itself: https://github.com/antirez/redis/blob/6c60526db91e23fb2d666fc52facc9a11780a2a3/src/networking.c#L1024 + + internal void WriteHeader(RedisCommand command, int arguments, CommandBytes commandBytes = default) + { + var bridge = connection.BridgeCouldBeNull ?? throw new ObjectDisposedException(connection.ToString()); + + if (command == RedisCommand.UNKNOWN) + { + // using >= here because we will be adding 1 for the command itself (which is an arg for the purposes of the multi-bulk protocol) + if (arguments >= REDIS_MAX_ARGS) throw ExceptionFactory.TooManyArgs(commandBytes.ToString(), arguments); + } + else + { + // using >= here because we will be adding 1 for the command itself (which is an arg for the purposes of the multi-bulk protocol) + if (arguments >= REDIS_MAX_ARGS) throw ExceptionFactory.TooManyArgs(command.ToString(), arguments); + + // for everything that isn't custom commands: ask the muxer for the actual bytes + commandBytes = bridge.Multiplexer.CommandMap.GetBytes(command); + } + + // in theory we should never see this; CheckMessage dealt with "regular" messages, and + // ExecuteMessage should have dealt with everything else + if (commandBytes.IsEmpty) throw ExceptionFactory.CommandDisabled(command); + + // *{argCount}\r\n = 3 + MaxInt32TextLen + // ${cmd-len}\r\n = 3 + MaxInt32TextLen + // {cmd}\r\n = 2 + commandBytes.Length + var span = _writer.GetSpan(commandBytes.Length + 8 + Format.MaxInt32TextLen + Format.MaxInt32TextLen); + span[0] = (byte)'*'; + + int offset = WriteRaw(span, arguments + 1, offset: 1); + + offset = AppendToSpanCommand(span, commandBytes, offset: offset); + + _writer.Advance(offset); + } + + internal static void WriteMultiBulkHeader(IBufferWriter writer, long count) + { + // *{count}\r\n = 3 + MaxInt32TextLen + var span = writer.GetSpan(3 + Format.MaxInt32TextLen); + span[0] = (byte)'*'; + int offset = WriteRaw(span, count, offset: 1); + writer.Advance(offset); + } + + private static ReadOnlySpan NullBulkString => "$-1\r\n"u8; + private static ReadOnlySpan EmptyBulkString => "$0\r\n\r\n"u8; + + internal static void WriteUnifiedPrefixedString(IBufferWriter writer, byte[]? prefix, string? value) + { + if (value == null) + { + // special case + writer.Write(NullBulkString); + } + else + { + // ${total-len}\r\n 3 + MaxInt32TextLen + // {prefix}{value}\r\n + int encodedLength = Encoding.UTF8.GetByteCount(value), + prefixLength = prefix?.Length ?? 0, + totalLength = prefixLength + encodedLength; + + if (totalLength == 0) + { + // special-case + writer.Write(EmptyBulkString); + } + else + { + var span = writer.GetSpan(3 + Format.MaxInt32TextLen); + span[0] = (byte)'$'; + int bytes = WriteRaw(span, totalLength, offset: 1); + writer.Advance(bytes); + + if (prefixLength != 0) writer.Write(prefix); + if (encodedLength != 0) WriteRaw(writer, value, encodedLength); + WriteCrlf(writer); + } + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal static int WriteCrlf(Span span, int offset) + { + span[offset++] = (byte)'\r'; + span[offset++] = (byte)'\n'; + return offset; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal static void WriteCrlf(IBufferWriter writer) + { + var span = writer.GetSpan(2); + span[0] = (byte)'\r'; + span[1] = (byte)'\n'; + writer.Advance(2); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal void WriteRaw(ReadOnlySpan value) => _writer.Write(value); + + internal static int WriteRaw(Span span, long value, bool withLengthPrefix = false, int offset = 0) + { + if (value >= 0 && value <= 9) + { + if (withLengthPrefix) + { + span[offset++] = (byte)'1'; + offset = WriteCrlf(span, offset); + } + + span[offset++] = (byte)((int)'0' + (int)value); + } + else if (value >= 10 && value < 100) + { + if (withLengthPrefix) + { + span[offset++] = (byte)'2'; + offset = WriteCrlf(span, offset); + } + + span[offset++] = (byte)((int)'0' + ((int)value / 10)); + span[offset++] = (byte)((int)'0' + ((int)value % 10)); + } + else if (value >= 100 && value < 1000) + { + int v = (int)value; + int units = v % 10; + v /= 10; + int tens = v % 10, hundreds = v / 10; + if (withLengthPrefix) + { + span[offset++] = (byte)'3'; + offset = WriteCrlf(span, offset); + } + + span[offset++] = (byte)((int)'0' + hundreds); + span[offset++] = (byte)((int)'0' + tens); + span[offset++] = (byte)((int)'0' + units); + } + else if (value < 0 && value >= -9) + { + if (withLengthPrefix) + { + span[offset++] = (byte)'2'; + offset = WriteCrlf(span, offset); + } + + span[offset++] = (byte)'-'; + span[offset++] = (byte)((int)'0' - (int)value); + } + else if (value <= -10 && value > -100) + { + if (withLengthPrefix) + { + span[offset++] = (byte)'3'; + offset = WriteCrlf(span, offset); + } + + value = -value; + span[offset++] = (byte)'-'; + span[offset++] = (byte)((int)'0' + ((int)value / 10)); + span[offset++] = (byte)((int)'0' + ((int)value % 10)); + } + else + { + // we're going to write it, but *to the wrong place* + var availableChunk = span.Slice(offset); + var formattedLength = Format.FormatInt64(value, availableChunk); + if (withLengthPrefix) + { + // now we know how large the prefix is: write the prefix, then write the value + var prefixLength = Format.FormatInt32(formattedLength, availableChunk); + offset += prefixLength; + offset = WriteCrlf(span, offset); + + availableChunk = span.Slice(offset); + var finalLength = Format.FormatInt64(value, availableChunk); + offset += finalLength; + Debug.Assert(finalLength == formattedLength); + } + else + { + offset += formattedLength; + } + } + + return WriteCrlf(span, offset); + } + + [ThreadStatic] + private static Encoder? s_PerThreadEncoder; + + internal static Encoder GetPerThreadEncoder() + { + var encoder = s_PerThreadEncoder; + if (encoder == null) + { + s_PerThreadEncoder = encoder = Encoding.UTF8.GetEncoder(); + } + else + { + encoder.Reset(); + } + + return encoder; + } + + internal static unsafe void WriteRaw(IBufferWriter writer, string value, int expectedLength) + { + const int MaxQuickEncodeSize = 512; + + fixed (char* cPtr = value) + { + int totalBytes; + if (expectedLength <= MaxQuickEncodeSize) + { + // encode directly in one hit + var span = writer.GetSpan(expectedLength); + fixed (byte* bPtr = span) + { + totalBytes = Encoding.UTF8.GetBytes( + cPtr, + value.Length, + bPtr, + expectedLength); + } + + writer.Advance(expectedLength); + } + else + { + // use an encoder in a loop + var encoder = GetPerThreadEncoder(); + int charsRemaining = value.Length, charOffset = 0; + totalBytes = 0; + + bool final = false; + while (true) + { + var span = writer + .GetSpan(5); // get *some* memory - at least enough for 1 character (but hopefully lots more) + + int charsUsed, bytesUsed; + bool completed; + fixed (byte* bPtr = span) + { + encoder.Convert( + cPtr + charOffset, + charsRemaining, + bPtr, + span.Length, + final, + out charsUsed, + out bytesUsed, + out completed); + } + + writer.Advance(bytesUsed); + totalBytes += bytesUsed; + charOffset += charsUsed; + charsRemaining -= charsUsed; + + if (charsRemaining <= 0) + { + if (charsRemaining < 0) throw new InvalidOperationException("String encode went negative"); + if (completed) break; // fine + if (final) throw new InvalidOperationException("String encode failed to complete"); + final = true; // flush the encoder to one more span, then exit + } + } + } + + if (totalBytes != expectedLength) throw new InvalidOperationException("String encode length check failure"); + } + } + + private static void WriteUnifiedPrefixedBlob(IBufferWriter writer, byte[]? prefix, byte[]? value) + { + // ${total-len}\r\n + // {prefix}{value}\r\n + if (prefix == null || prefix.Length == 0 || value == null) + { + // if no prefix, just use the non-prefixed version; + // even if prefixed, a null value writes as null, so can use the non-prefixed version + WriteUnifiedBlob(writer, value); + } + else + { + var span = writer.GetSpan(3 + + Format + .MaxInt32TextLen); // note even with 2 max-len, we're still in same text range + span[0] = (byte)'$'; + int bytes = WriteRaw(span, prefix.LongLength + value.LongLength, offset: 1); + writer.Advance(bytes); + + writer.Write(prefix); + writer.Write(value); + + span = writer.GetSpan(2); + WriteCrlf(span, 0); + writer.Advance(2); + } + } + + private static void WriteUnifiedInt64(IBufferWriter writer, long value) + { + // note from specification: A client sends to the Redis server a RESP Array consisting of just Bulk Strings. + // (i.e. we can't just send ":123\r\n", we need to send "$3\r\n123\r\n" + + // ${asc-len}\r\n = 4/5 (asc-len at most 2 digits) + // {asc}\r\n = MaxInt64TextLen + 2 + var span = writer.GetSpan(7 + Format.MaxInt64TextLen); + + span[0] = (byte)'$'; + var bytes = WriteRaw(span, value, withLengthPrefix: true, offset: 1); + writer.Advance(bytes); + } + + private static void WriteUnifiedUInt64(IBufferWriter writer, ulong value) + { + // note from specification: A client sends to the Redis server a RESP Array consisting of just Bulk Strings. + // (i.e. we can't just send ":123\r\n", we need to send "$3\r\n123\r\n" + Span valueSpan = stackalloc byte[Format.MaxInt64TextLen]; + + var len = Format.FormatUInt64(value, valueSpan); + // ${asc-len}\r\n = 4/5 (asc-len at most 2 digits) + // {asc}\r\n = {len} + 2 + var span = writer.GetSpan(7 + len); + span[0] = (byte)'$'; + int offset = WriteRaw(span, len, withLengthPrefix: false, offset: 1); + valueSpan.Slice(0, len).CopyTo(span.Slice(offset)); + offset += len; + offset = WriteCrlf(span, offset); + writer.Advance(offset); + } + + private static void WriteUnifiedDouble(IBufferWriter writer, double value) + { +#if NET8_0_OR_GREATER + Span valueSpan = stackalloc byte[Format.MaxDoubleTextLen]; + var len = Format.FormatDouble(value, valueSpan); + + // ${asc-len}\r\n = 4/5 (asc-len at most 2 digits) + // {asc}\r\n = {len} + 2 + var span = writer.GetSpan(7 + len); + span[0] = (byte)'$'; + int offset = WriteRaw(span, len, withLengthPrefix: false, offset: 1); + valueSpan.Slice(0, len).CopyTo(span.Slice(offset)); + offset += len; + offset = WriteCrlf(span, offset); + writer.Advance(offset); +#else + // fallback: drop to string + WriteUnifiedPrefixedString(writer, null, Format.ToString(value)); +#endif + } + + internal static void WriteInteger(IBufferWriter writer, long value) + { + // note: client should never write integer; only server does this + // :{asc}\r\n = MaxInt64TextLen + 3 + var span = writer.GetSpan(3 + Format.MaxInt64TextLen); + + span[0] = (byte)':'; + var bytes = WriteRaw(span, value, withLengthPrefix: false, offset: 1); + writer.Advance(bytes); + } + + private static void WriteUnifiedBlob(IBufferWriter writer, byte[]? value) + { + if (value is null) + { + // special case: + writer.Write(NullBulkString); + } + else + { + WriteUnifiedSpan(writer, new ReadOnlySpan(value)); + } + } + + private static void WriteUnifiedSpan(IBufferWriter writer, ReadOnlySpan value) + { + // ${len}\r\n = 3 + MaxInt32TextLen + // {value}\r\n = 2 + value.Length + const int MaxQuickSpanSize = 512; + if (value.Length == 0) + { + // special case: + writer.Write(EmptyBulkString); + } + else if (value.Length <= MaxQuickSpanSize) + { + var span = writer.GetSpan(5 + Format.MaxInt32TextLen + value.Length); + span[0] = (byte)'$'; + int bytes = AppendToSpan(span, value, 1); + writer.Advance(bytes); + } + else + { + // too big to guarantee can do in a single span + var span = writer.GetSpan(3 + Format.MaxInt32TextLen); + span[0] = (byte)'$'; + int bytes = WriteRaw(span, value.Length, offset: 1); + writer.Advance(bytes); + + writer.Write(value); + + WriteCrlf(writer); + } + } + + private static int AppendToSpanCommand(Span span, in CommandBytes value, int offset = 0) + { + span[offset++] = (byte)'$'; + int len = value.Length; + offset = WriteRaw(span, len, offset: offset); + value.CopyTo(span.Slice(offset, len)); + offset += value.Length; + return WriteCrlf(span, offset); + } + + private static int AppendToSpan(Span span, ReadOnlySpan value, int offset = 0) + { + offset = WriteRaw(span, value.Length, offset: offset); + value.CopyTo(span.Slice(offset, value.Length)); + offset += value.Length; + return WriteCrlf(span, offset); + } + + internal void WriteSha1AsHex(byte[]? value) + { + var writer = _writer; + if (value is null) + { + writer.Write(NullBulkString); + } + else if (value.Length == ResultProcessor.ScriptLoadProcessor.Sha1HashLength) + { + // $40\r\n = 5 + // {40 bytes}\r\n = 42 + var span = writer.GetSpan(47); + span[0] = (byte)'$'; + span[1] = (byte)'4'; + span[2] = (byte)'0'; + span[3] = (byte)'\r'; + span[4] = (byte)'\n'; + + int offset = 5; + for (int i = 0; i < value.Length; i++) + { + var b = value[i]; + span[offset++] = ToHexNibble(b >> 4); + span[offset++] = ToHexNibble(b & 15); + } + + span[offset++] = (byte)'\r'; + span[offset++] = (byte)'\n'; + + writer.Advance(offset); + } + else + { + throw new InvalidOperationException("Invalid SHA1 length: " + value.Length); + } + } + + internal static byte ToHexNibble(int value) + { + return value < 10 ? (byte)('0' + value) : (byte)('a' - 10 + value); + } +} diff --git a/src/StackExchange.Redis/PhysicalConnection.cs b/src/StackExchange.Redis/PhysicalConnection.cs index 29991d7ca..a6e1f4437 100644 --- a/src/StackExchange.Redis/PhysicalConnection.cs +++ b/src/StackExchange.Redis/PhysicalConnection.cs @@ -10,6 +10,7 @@ using System.Net.Security; using System.Net.Sockets; using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; using System.Security.Authentication; using System.Security.Cryptography.X509Certificates; using System.Text; @@ -17,7 +18,6 @@ using System.Threading.Tasks; using Microsoft.Extensions.Logging; using Pipelines.Sockets.Unofficial; -using Pipelines.Sockets.Unofficial.Arenas; using static StackExchange.Redis.Message; @@ -352,6 +352,8 @@ public Task FlushAsync() internal void SimulateConnectionFailure(SimulatedFailureType failureType) { + throw new NotImplementedException(nameof(SimulateConnectionFailure)); + /* var raiseFailed = false; if (connectionType == ConnectionType.Interactive) { @@ -383,6 +385,7 @@ internal void SimulateConnectionFailure(SimulatedFailureType failureType) { RecordConnectionFailed(ConnectionFailureType.SocketFailure); } + */ } public void RecordConnectionFailed( @@ -390,7 +393,7 @@ public void RecordConnectionFailed( Exception? innerException = null, [CallerMemberName] string? origin = null, bool isInitialConnect = false, - IDuplexPipe? connectingPipe = null) + Stream? connectingStream = null) { Exception? outerException = innerException; IdentifyFailureType(innerException, ref failureType); @@ -439,7 +442,8 @@ public void RecordConnectionFailed( // If the reason for the shutdown was we asked for the socket to die, don't log it as an error (only informational) var weAskedForThis = Volatile.Read(ref clientSentQuit) != 0; - var pipe = connectingPipe ?? _ioPipe; + /* + var pipe = connectingStream ?? _ioStream; if (pipe is SocketConnection sc) { exMessage.Append(" (").Append(sc.ShutdownKind); @@ -458,9 +462,14 @@ public void RecordConnectionFailed( if (sent == 0) { exMessage.Append(recd == 0 ? " (0-read, 0-sent)" : " (0-sent)"); } else if (recd == 0) { exMessage.Append(" (0-read)"); } } + */ + + long sent = totalBytesSent, recd = totalBytesReceived; + if (sent == 0) { exMessage.Append(recd == 0 ? " (0-read, 0-sent)" : " (0-sent)"); } + else if (recd == 0) { exMessage.Append(" (0-read)"); } var data = new List>(); - void AddData(string lk, string sk, string? v) + void AddData(string? lk, string? sk, string? v) { if (lk != null) data.Add(Tuple.Create(lk, v)); if (sk != null) exMessage.Append(", ").Append(sk).Append(": ").Append(v); @@ -811,228 +820,37 @@ internal void SetUnknownDatabase() currentDatabase = -1; } - internal void Write(in RedisKey key) - { - var val = key.KeyValue; - if (val is string s) - { - WriteUnifiedPrefixedString(_ioPipe?.Output, key.KeyPrefix, s); - } - else - { - WriteUnifiedPrefixedBlob(_ioPipe?.Output, key.KeyPrefix, (byte[]?)val); - } - } - - internal void Write(in RedisChannel channel) - => WriteUnifiedPrefixedBlob(_ioPipe?.Output, channel.IgnoreChannelPrefix ? null : ChannelPrefix, channel.Value); - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - internal void WriteBulkString(in RedisValue value) - => WriteBulkString(value, _ioPipe?.Output); - internal static void WriteBulkString(in RedisValue value, PipeWriter? maybeNullWriter) - { - if (maybeNullWriter is not PipeWriter writer) - { - return; // Prevent null refs during disposal - } - - switch (value.Type) - { - case RedisValue.StorageType.Null: - WriteUnifiedBlob(writer, (byte[]?)null); - break; - case RedisValue.StorageType.Int64: - WriteUnifiedInt64(writer, value.OverlappedValueInt64); - break; - case RedisValue.StorageType.UInt64: - WriteUnifiedUInt64(writer, value.OverlappedValueUInt64); - break; - case RedisValue.StorageType.Double: - WriteUnifiedDouble(writer, value.OverlappedValueDouble); - break; - case RedisValue.StorageType.String: - WriteUnifiedPrefixedString(writer, null, (string?)value); - break; - case RedisValue.StorageType.Raw: - WriteUnifiedSpan(writer, ((ReadOnlyMemory)value).Span); - break; - default: - throw new InvalidOperationException($"Unexpected {value.Type} value: '{value}'"); - } - } - - internal void WriteBulkString(ReadOnlySpan value) + internal void WriteDirect(ReadOnlyMemory bytes) { - if (_ioPipe?.Output is { } writer) - { - WriteUnifiedSpan(writer, value); - } - } - - internal const int REDIS_MAX_ARGS = 1024 * 1024; // there is a <= 1024*1024 max constraint inside redis itself: https://github.com/antirez/redis/blob/6c60526db91e23fb2d666fc52facc9a11780a2a3/src/networking.c#L1024 - - internal void WriteHeader(RedisCommand command, int arguments, CommandBytes commandBytes = default) - { - if (_ioPipe?.Output is not PipeWriter writer) - { - return; // Prevent null refs during disposal - } + if (_ioStream is not { } output) return; + totalBytesSent += bytes.Length; - var bridge = BridgeCouldBeNull ?? throw new ObjectDisposedException(ToString()); - - if (command == RedisCommand.UNKNOWN) +#if NET || NETSTANDARD2_1_OR_GREATER + output.Write(bytes.Span); +#else + if (MemoryMarshal.TryGetArray(bytes, out var segment)) { - // using >= here because we will be adding 1 for the command itself (which is an arg for the purposes of the multi-bulk protocol) - if (arguments >= REDIS_MAX_ARGS) throw ExceptionFactory.TooManyArgs(commandBytes.ToString(), arguments); + output.Write(segment.Array!, segment.Offset, segment.Count); } else { - // using >= here because we will be adding 1 for the command itself (which is an arg for the purposes of the multi-bulk protocol) - if (arguments >= REDIS_MAX_ARGS) throw ExceptionFactory.TooManyArgs(command.ToString(), arguments); - - // for everything that isn't custom commands: ask the muxer for the actual bytes - commandBytes = bridge.Multiplexer.CommandMap.GetBytes(command); + var oversized = ArrayPool.Shared.Rent(bytes.Length); + bytes.CopyTo(oversized); + output.Write(oversized, 0, bytes.Length); + ArrayPool.Shared.Return(oversized); } - - // in theory we should never see this; CheckMessage dealt with "regular" messages, and - // ExecuteMessage should have dealt with everything else - if (commandBytes.IsEmpty) throw ExceptionFactory.CommandDisabled(command); - - // *{argCount}\r\n = 3 + MaxInt32TextLen - // ${cmd-len}\r\n = 3 + MaxInt32TextLen - // {cmd}\r\n = 2 + commandBytes.Length - var span = writer.GetSpan(commandBytes.Length + 8 + Format.MaxInt32TextLen + Format.MaxInt32TextLen); - span[0] = (byte)'*'; - - int offset = WriteRaw(span, arguments + 1, offset: 1); - - offset = AppendToSpanCommand(span, commandBytes, offset: offset); - - writer.Advance(offset); +#endif } - internal void WriteRaw(ReadOnlySpan bytes) => _ioPipe?.Output?.Write(bytes); - internal void RecordQuit() { // don't blame redis if we fired the first shot Volatile.Write(ref clientSentQuit, 1); - (_ioPipe as SocketConnection)?.TrySetProtocolShutdown(PipeShutdownKind.ProtocolExitClient); - } - - internal static void WriteMultiBulkHeader(PipeWriter output, long count) - { - // *{count}\r\n = 3 + MaxInt32TextLen - var span = output.GetSpan(3 + Format.MaxInt32TextLen); - span[0] = (byte)'*'; - int offset = WriteRaw(span, count, offset: 1); - output.Advance(offset); - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - internal static int WriteCrlf(Span span, int offset) - { - span[offset++] = (byte)'\r'; - span[offset++] = (byte)'\n'; - return offset; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - internal static void WriteCrlf(PipeWriter writer) - { - var span = writer.GetSpan(2); - span[0] = (byte)'\r'; - span[1] = (byte)'\n'; - writer.Advance(2); - } - - internal static int WriteRaw(Span span, long value, bool withLengthPrefix = false, int offset = 0) - { - if (value >= 0 && value <= 9) - { - if (withLengthPrefix) - { - span[offset++] = (byte)'1'; - offset = WriteCrlf(span, offset); - } - span[offset++] = (byte)((int)'0' + (int)value); - } - else if (value >= 10 && value < 100) - { - if (withLengthPrefix) - { - span[offset++] = (byte)'2'; - offset = WriteCrlf(span, offset); - } - span[offset++] = (byte)((int)'0' + ((int)value / 10)); - span[offset++] = (byte)((int)'0' + ((int)value % 10)); - } - else if (value >= 100 && value < 1000) - { - int v = (int)value; - int units = v % 10; - v /= 10; - int tens = v % 10, hundreds = v / 10; - if (withLengthPrefix) - { - span[offset++] = (byte)'3'; - offset = WriteCrlf(span, offset); - } - span[offset++] = (byte)((int)'0' + hundreds); - span[offset++] = (byte)((int)'0' + tens); - span[offset++] = (byte)((int)'0' + units); - } - else if (value < 0 && value >= -9) - { - if (withLengthPrefix) - { - span[offset++] = (byte)'2'; - offset = WriteCrlf(span, offset); - } - span[offset++] = (byte)'-'; - span[offset++] = (byte)((int)'0' - (int)value); - } - else if (value <= -10 && value > -100) - { - if (withLengthPrefix) - { - span[offset++] = (byte)'3'; - offset = WriteCrlf(span, offset); - } - value = -value; - span[offset++] = (byte)'-'; - span[offset++] = (byte)((int)'0' + ((int)value / 10)); - span[offset++] = (byte)((int)'0' + ((int)value % 10)); - } - else - { - // we're going to write it, but *to the wrong place* - var availableChunk = span.Slice(offset); - var formattedLength = Format.FormatInt64(value, availableChunk); - if (withLengthPrefix) - { - // now we know how large the prefix is: write the prefix, then write the value - var prefixLength = Format.FormatInt32(formattedLength, availableChunk); - offset += prefixLength; - offset = WriteCrlf(span, offset); - - availableChunk = span.Slice(offset); - var finalLength = Format.FormatInt64(value, availableChunk); - offset += finalLength; - Debug.Assert(finalLength == formattedLength); - } - else - { - offset += formattedLength; - } - } - - return WriteCrlf(span, offset); + // (_ioPipe as SocketConnection)?.TrySetProtocolShutdown(PipeShutdownKind.ProtocolExitClient); } [System.Diagnostics.CodeAnalysis.SuppressMessage("Performance", "CA1822:Mark members as static", Justification = "DEBUG uses instance data")] - private async ValueTask FlushAsync_Awaited(PhysicalConnection connection, ValueTask flush, bool throwOnFailure) + private async ValueTask FlushAsync_Awaited(PhysicalConnection connection, Task flush, bool throwOnFailure) { try { @@ -1083,11 +901,12 @@ void ThrowTimeout() } internal ValueTask FlushAsync(bool throwOnFailure, CancellationToken cancellationToken = default) { - var tmp = _ioPipe?.Output; + var tmp = _ioStream; if (tmp == null) return new ValueTask(WriteResult.NoConnectionAvailable); try { _writeStatus = WriteStatus.Flushing; + tmp.FlushAsync(cancellationToken); var flush = tmp.FlushAsync(cancellationToken); if (!flush.IsCompletedSuccessfully) return FlushAsync_Awaited(this, flush, throwOnFailure); _writeStatus = WriteStatus.Flushed; @@ -1101,319 +920,6 @@ internal ValueTask FlushAsync(bool throwOnFailure, CancellationToke } } - private static ReadOnlySpan NullBulkString => "$-1\r\n"u8; - private static ReadOnlySpan EmptyBulkString => "$0\r\n\r\n"u8; - - private static void WriteUnifiedBlob(PipeWriter writer, byte[]? value) - { - if (value is null) - { - // special case: - writer.Write(NullBulkString); - } - else - { - WriteUnifiedSpan(writer, new ReadOnlySpan(value)); - } - } - - private static void WriteUnifiedSpan(PipeWriter writer, ReadOnlySpan value) - { - // ${len}\r\n = 3 + MaxInt32TextLen - // {value}\r\n = 2 + value.Length - const int MaxQuickSpanSize = 512; - if (value.Length == 0) - { - // special case: - writer.Write(EmptyBulkString); - } - else if (value.Length <= MaxQuickSpanSize) - { - var span = writer.GetSpan(5 + Format.MaxInt32TextLen + value.Length); - span[0] = (byte)'$'; - int bytes = AppendToSpan(span, value, 1); - writer.Advance(bytes); - } - else - { - // too big to guarantee can do in a single span - var span = writer.GetSpan(3 + Format.MaxInt32TextLen); - span[0] = (byte)'$'; - int bytes = WriteRaw(span, value.Length, offset: 1); - writer.Advance(bytes); - - writer.Write(value); - - WriteCrlf(writer); - } - } - - private static int AppendToSpanCommand(Span span, in CommandBytes value, int offset = 0) - { - span[offset++] = (byte)'$'; - int len = value.Length; - offset = WriteRaw(span, len, offset: offset); - value.CopyTo(span.Slice(offset, len)); - offset += value.Length; - return WriteCrlf(span, offset); - } - - private static int AppendToSpan(Span span, ReadOnlySpan value, int offset = 0) - { - offset = WriteRaw(span, value.Length, offset: offset); - value.CopyTo(span.Slice(offset, value.Length)); - offset += value.Length; - return WriteCrlf(span, offset); - } - - internal void WriteSha1AsHex(byte[]? value) - { - if (_ioPipe?.Output is not PipeWriter writer) - { - return; // Prevent null refs during disposal - } - - if (value is null) - { - writer.Write(NullBulkString); - } - else if (value.Length == ResultProcessor.ScriptLoadProcessor.Sha1HashLength) - { - // $40\r\n = 5 - // {40 bytes}\r\n = 42 - var span = writer.GetSpan(47); - span[0] = (byte)'$'; - span[1] = (byte)'4'; - span[2] = (byte)'0'; - span[3] = (byte)'\r'; - span[4] = (byte)'\n'; - - int offset = 5; - for (int i = 0; i < value.Length; i++) - { - var b = value[i]; - span[offset++] = ToHexNibble(b >> 4); - span[offset++] = ToHexNibble(b & 15); - } - span[offset++] = (byte)'\r'; - span[offset++] = (byte)'\n'; - - writer.Advance(offset); - } - else - { - throw new InvalidOperationException("Invalid SHA1 length: " + value.Length); - } - } - - internal static byte ToHexNibble(int value) - { - return value < 10 ? (byte)('0' + value) : (byte)('a' - 10 + value); - } - - internal static void WriteUnifiedPrefixedString(PipeWriter? writer, byte[]? prefix, string? value) - { - if (writer is null) - { - return; // Prevent null refs during disposal - } - - if (value == null) - { - // special case - writer.Write(NullBulkString); - } - else - { - // ${total-len}\r\n 3 + MaxInt32TextLen - // {prefix}{value}\r\n - int encodedLength = Encoding.UTF8.GetByteCount(value), - prefixLength = prefix?.Length ?? 0, - totalLength = prefixLength + encodedLength; - - if (totalLength == 0) - { - // special-case - writer.Write(EmptyBulkString); - } - else - { - var span = writer.GetSpan(3 + Format.MaxInt32TextLen); - span[0] = (byte)'$'; - int bytes = WriteRaw(span, totalLength, offset: 1); - writer.Advance(bytes); - - if (prefixLength != 0) writer.Write(prefix); - if (encodedLength != 0) WriteRaw(writer, value, encodedLength); - WriteCrlf(writer); - } - } - } - - [ThreadStatic] - private static Encoder? s_PerThreadEncoder; - internal static Encoder GetPerThreadEncoder() - { - var encoder = s_PerThreadEncoder; - if (encoder == null) - { - s_PerThreadEncoder = encoder = Encoding.UTF8.GetEncoder(); - } - else - { - encoder.Reset(); - } - return encoder; - } - - internal static unsafe void WriteRaw(PipeWriter writer, string value, int expectedLength) - { - const int MaxQuickEncodeSize = 512; - - fixed (char* cPtr = value) - { - int totalBytes; - if (expectedLength <= MaxQuickEncodeSize) - { - // encode directly in one hit - var span = writer.GetSpan(expectedLength); - fixed (byte* bPtr = span) - { - totalBytes = Encoding.UTF8.GetBytes(cPtr, value.Length, bPtr, expectedLength); - } - writer.Advance(expectedLength); - } - else - { - // use an encoder in a loop - var encoder = GetPerThreadEncoder(); - int charsRemaining = value.Length, charOffset = 0; - totalBytes = 0; - - bool final = false; - while (true) - { - var span = writer.GetSpan(5); // get *some* memory - at least enough for 1 character (but hopefully lots more) - - int charsUsed, bytesUsed; - bool completed; - fixed (byte* bPtr = span) - { - encoder.Convert(cPtr + charOffset, charsRemaining, bPtr, span.Length, final, out charsUsed, out bytesUsed, out completed); - } - writer.Advance(bytesUsed); - totalBytes += bytesUsed; - charOffset += charsUsed; - charsRemaining -= charsUsed; - - if (charsRemaining <= 0) - { - if (charsRemaining < 0) throw new InvalidOperationException("String encode went negative"); - if (completed) break; // fine - if (final) throw new InvalidOperationException("String encode failed to complete"); - final = true; // flush the encoder to one more span, then exit - } - } - } - if (totalBytes != expectedLength) throw new InvalidOperationException("String encode length check failure"); - } - } - - private static void WriteUnifiedPrefixedBlob(PipeWriter? maybeNullWriter, byte[]? prefix, byte[]? value) - { - if (maybeNullWriter is not PipeWriter writer) - { - return; // Prevent null refs during disposal - } - - // ${total-len}\r\n - // {prefix}{value}\r\n - if (prefix == null || prefix.Length == 0 || value == null) - { - // if no prefix, just use the non-prefixed version; - // even if prefixed, a null value writes as null, so can use the non-prefixed version - WriteUnifiedBlob(writer, value); - } - else - { - var span = writer.GetSpan(3 + Format.MaxInt32TextLen); // note even with 2 max-len, we're still in same text range - span[0] = (byte)'$'; - int bytes = WriteRaw(span, prefix.LongLength + value.LongLength, offset: 1); - writer.Advance(bytes); - - writer.Write(prefix); - writer.Write(value); - - span = writer.GetSpan(2); - WriteCrlf(span, 0); - writer.Advance(2); - } - } - - private static void WriteUnifiedInt64(PipeWriter writer, long value) - { - // note from specification: A client sends to the Redis server a RESP Array consisting of just Bulk Strings. - // (i.e. we can't just send ":123\r\n", we need to send "$3\r\n123\r\n" - - // ${asc-len}\r\n = 4/5 (asc-len at most 2 digits) - // {asc}\r\n = MaxInt64TextLen + 2 - var span = writer.GetSpan(7 + Format.MaxInt64TextLen); - - span[0] = (byte)'$'; - var bytes = WriteRaw(span, value, withLengthPrefix: true, offset: 1); - writer.Advance(bytes); - } - - private static void WriteUnifiedUInt64(PipeWriter writer, ulong value) - { - // note from specification: A client sends to the Redis server a RESP Array consisting of just Bulk Strings. - // (i.e. we can't just send ":123\r\n", we need to send "$3\r\n123\r\n" - Span valueSpan = stackalloc byte[Format.MaxInt64TextLen]; - - var len = Format.FormatUInt64(value, valueSpan); - // ${asc-len}\r\n = 4/5 (asc-len at most 2 digits) - // {asc}\r\n = {len} + 2 - var span = writer.GetSpan(7 + len); - span[0] = (byte)'$'; - int offset = WriteRaw(span, len, withLengthPrefix: false, offset: 1); - valueSpan.Slice(0, len).CopyTo(span.Slice(offset)); - offset += len; - offset = WriteCrlf(span, offset); - writer.Advance(offset); - } - - private static void WriteUnifiedDouble(PipeWriter writer, double value) - { -#if NET8_0_OR_GREATER - Span valueSpan = stackalloc byte[Format.MaxDoubleTextLen]; - var len = Format.FormatDouble(value, valueSpan); - - // ${asc-len}\r\n = 4/5 (asc-len at most 2 digits) - // {asc}\r\n = {len} + 2 - var span = writer.GetSpan(7 + len); - span[0] = (byte)'$'; - int offset = WriteRaw(span, len, withLengthPrefix: false, offset: 1); - valueSpan.Slice(0, len).CopyTo(span.Slice(offset)); - offset += len; - offset = WriteCrlf(span, offset); - writer.Advance(offset); -#else - // fallback: drop to string - WriteUnifiedPrefixedString(writer, null, Format.ToString(value)); -#endif - } - - internal static void WriteInteger(PipeWriter writer, long value) - { - // note: client should never write integer; only server does this - // :{asc}\r\n = MaxInt64TextLen + 3 - var span = writer.GetSpan(3 + Format.MaxInt64TextLen); - - span[0] = (byte)':'; - var bytes = WriteRaw(span, value, withLengthPrefix: false, offset: 1); - writer.Advance(bytes); - } - internal readonly struct ConnectionStatus { /// @@ -1562,6 +1068,7 @@ internal async ValueTask ConnectedAsync(Socket? socket, ILogger? log, Sock var bridge = BridgeCouldBeNull; if (bridge == null) return false; + Stream? stream = null; try { // disallow connection in some cases @@ -1573,7 +1080,6 @@ internal async ValueTask ConnectedAsync(Socket? socket, ILogger? log, Sock var config = bridge.Multiplexer.RawConfig; var tunnel = config.Tunnel; - Stream? stream = null; if (tunnel is not null) { stream = await tunnel.BeforeAuthenticateAsync(bridge.ServerEndPoint.EndPoint, bridge.ConnectionType, socket, CancellationToken.None).ForAwait(); @@ -1646,7 +1152,7 @@ static Stream DemandSocketStream(Socket? socket) } catch (Exception ex) { - RecordConnectionFailed(ConnectionFailureType.InternalFailure, ex, isInitialConnect: true, connectingPipe: pipe); // includes a bridge.OnDisconnected + RecordConnectionFailed(ConnectionFailureType.InternalFailure, ex, isInitialConnect: true, connectingStream: stream); // includes a bridge.OnDisconnected bridge.Multiplexer.Trace("Could not connect: " + ex.Message, ToString()); return false; } diff --git a/src/StackExchange.Redis/RedisDatabase.cs b/src/StackExchange.Redis/RedisDatabase.cs index 8c0027d13..bed35029d 100644 --- a/src/StackExchange.Redis/RedisDatabase.cs +++ b/src/StackExchange.Redis/RedisDatabase.cs @@ -1323,18 +1323,18 @@ public KeyMigrateCommandMessage(int db, RedisKey key, EndPoint toServer, int toD this.migrateOptions = migrateOptions; } - protected override void WriteImpl(PhysicalConnection physical) + protected override void WriteImpl(in MessageWriter writer) { bool isCopy = (migrateOptions & MigrateOptions.Copy) != 0; bool isReplace = (migrateOptions & MigrateOptions.Replace) != 0; - physical.WriteHeader(Command, 5 + (isCopy ? 1 : 0) + (isReplace ? 1 : 0)); - physical.WriteBulkString(toHost); - physical.WriteBulkString(toPort); - physical.Write(Key); - physical.WriteBulkString(toDatabase); - physical.WriteBulkString(timeoutMilliseconds); - if (isCopy) physical.WriteBulkString("COPY"u8); - if (isReplace) physical.WriteBulkString("REPLACE"u8); + writer.WriteHeader(Command, 5 + (isCopy ? 1 : 0) + (isReplace ? 1 : 0)); + writer.WriteBulkString(toHost); + writer.WriteBulkString(toPort); + writer.Write(Key); + writer.WriteBulkString(toDatabase); + writer.WriteBulkString(timeoutMilliseconds); + if (isCopy) writer.WriteBulkString("COPY"u8); + if (isReplace) writer.WriteBulkString("REPLACE"u8); } public override int ArgCount @@ -4255,38 +4255,38 @@ public override int GetHashSlot(ServerSelectionStrategy serverSelectionStrategy) return slot; } - protected override void WriteImpl(PhysicalConnection physical) + protected override void WriteImpl(in MessageWriter writer) { - physical.WriteHeader(Command, argCount); - physical.WriteBulkString("GROUP"u8); - physical.WriteBulkString(groupName); - physical.WriteBulkString(consumerName); + writer.WriteHeader(Command, argCount); + writer.WriteBulkString("GROUP"u8); + writer.WriteBulkString(groupName); + writer.WriteBulkString(consumerName); if (countPerStream.HasValue) { - physical.WriteBulkString("COUNT"u8); - physical.WriteBulkString(countPerStream.Value); + writer.WriteBulkString("COUNT"u8); + writer.WriteBulkString(countPerStream.Value); } if (noAck) { - physical.WriteBulkString("NOACK"u8); + writer.WriteBulkString("NOACK"u8); } if (claimMinIdleTime.HasValue) { - physical.WriteBulkString("CLAIM"u8); - physical.WriteBulkString(claimMinIdleTime.Value.TotalMilliseconds); + writer.WriteBulkString("CLAIM"u8); + writer.WriteBulkString(claimMinIdleTime.Value.TotalMilliseconds); } - physical.WriteBulkString("STREAMS"u8); + writer.WriteBulkString("STREAMS"u8); for (int i = 0; i < streamPositions.Length; i++) { - physical.Write(streamPositions[i].Key); + writer.Write(streamPositions[i].Key); } for (int i = 0; i < streamPositions.Length; i++) { - physical.WriteBulkString(StreamPosition.Resolve(streamPositions[i].Position, RedisCommand.XREADGROUP)); + writer.WriteBulkString(StreamPosition.Resolve(streamPositions[i].Position, RedisCommand.XREADGROUP)); } } @@ -4335,24 +4335,24 @@ public override int GetHashSlot(ServerSelectionStrategy serverSelectionStrategy) return slot; } - protected override void WriteImpl(PhysicalConnection physical) + protected override void WriteImpl(in MessageWriter writer) { - physical.WriteHeader(Command, argCount); + writer.WriteHeader(Command, argCount); if (countPerStream.HasValue) { - physical.WriteBulkString("COUNT"u8); - physical.WriteBulkString(countPerStream.Value); + writer.WriteBulkString("COUNT"u8); + writer.WriteBulkString(countPerStream.Value); } - physical.WriteBulkString("STREAMS"u8); + writer.WriteBulkString("STREAMS"u8); for (int i = 0; i < streamPositions.Length; i++) { - physical.Write(streamPositions[i].Key); + writer.Write(streamPositions[i].Key); } for (int i = 0; i < streamPositions.Length; i++) { - physical.WriteBulkString(StreamPosition.Resolve(streamPositions[i].Position, RedisCommand.XREADGROUP)); + writer.WriteBulkString(StreamPosition.Resolve(streamPositions[i].Position, RedisCommand.XREADGROUP)); } } @@ -5058,33 +5058,33 @@ public SingleStreamReadGroupCommandMessage(int db, CommandFlags flags, RedisKey argCount = 6 + (count.HasValue ? 2 : 0) + (noAck ? 1 : 0) + (claimMinIdleTime.HasValue ? 2 : 0); } - protected override void WriteImpl(PhysicalConnection physical) + protected override void WriteImpl(in MessageWriter writer) { - physical.WriteHeader(Command, argCount); - physical.WriteBulkString("GROUP"u8); - physical.WriteBulkString(groupName); - physical.WriteBulkString(consumerName); + writer.WriteHeader(Command, argCount); + writer.WriteBulkString("GROUP"u8); + writer.WriteBulkString(groupName); + writer.WriteBulkString(consumerName); if (count.HasValue) { - physical.WriteBulkString("COUNT"u8); - physical.WriteBulkString(count.Value); + writer.WriteBulkString("COUNT"u8); + writer.WriteBulkString(count.Value); } if (noAck) { - physical.WriteBulkString("NOACK"u8); + writer.WriteBulkString("NOACK"u8); } if (claimMinIdleTime.HasValue) { - physical.WriteBulkString("CLAIM"u8); - physical.WriteBulkString(claimMinIdleTime.Value.TotalMilliseconds); + writer.WriteBulkString("CLAIM"u8); + writer.WriteBulkString(claimMinIdleTime.Value.TotalMilliseconds); } - physical.WriteBulkString("STREAMS"u8); - physical.Write(Key); - physical.WriteBulkString(afterId); + writer.WriteBulkString("STREAMS"u8); + writer.Write(Key); + writer.WriteBulkString(afterId); } public override int ArgCount => argCount; @@ -5114,19 +5114,19 @@ public SingleStreamReadCommandMessage(int db, CommandFlags flags, RedisKey key, argCount = count.HasValue ? 5 : 3; } - protected override void WriteImpl(PhysicalConnection physical) + protected override void WriteImpl(in MessageWriter writer) { - physical.WriteHeader(Command, argCount); + writer.WriteHeader(Command, argCount); if (count.HasValue) { - physical.WriteBulkString("COUNT"u8); - physical.WriteBulkString(count.Value); + writer.WriteBulkString("COUNT"u8); + writer.WriteBulkString(count.Value); } - physical.WriteBulkString("STREAMS"u8); - physical.Write(Key); - physical.WriteBulkString(afterId); + writer.WriteBulkString("STREAMS"u8); + writer.Write(Key); + writer.WriteBulkString(afterId); } public override int ArgCount => argCount; @@ -5559,11 +5559,11 @@ public ScriptLoadMessage(CommandFlags flags, string script) Script = script ?? throw new ArgumentNullException(nameof(script)); } - protected override void WriteImpl(PhysicalConnection physical) + protected override void WriteImpl(in MessageWriter writer) { - physical.WriteHeader(Command, 2); - physical.WriteBulkString("LOAD"u8); - physical.WriteBulkString((RedisValue)Script); + writer.WriteHeader(Command, 2); + writer.WriteBulkString("LOAD"u8); + writer.WriteBulkString((RedisValue)Script); } public override int ArgCount => 2; } @@ -5610,7 +5610,7 @@ internal sealed class ExecuteMessage : Message public ExecuteMessage(CommandMap? map, int db, CommandFlags flags, string command, ICollection? args) : base(db, flags, RedisCommand.UNKNOWN) { - if (args != null && args.Count >= PhysicalConnection.REDIS_MAX_ARGS) // using >= here because we will be adding 1 for the command itself (which is an arg for the purposes of the multi-bulk protocol) + if (args != null && args.Count >= MessageWriter.REDIS_MAX_ARGS) // using >= here because we will be adding 1 for the command itself (which is an arg for the purposes of the multi-bulk protocol) { throw ExceptionFactory.TooManyArgs(command, args.Count); } @@ -5619,24 +5619,24 @@ public ExecuteMessage(CommandMap? map, int db, CommandFlags flags, string comman _args = args ?? Array.Empty(); } - protected override void WriteImpl(PhysicalConnection physical) + protected override void WriteImpl(in MessageWriter writer) { - physical.WriteHeader(RedisCommand.UNKNOWN, _args.Count, Command); + writer.WriteHeader(RedisCommand.UNKNOWN, _args.Count, Command); foreach (object arg in _args) { if (arg is RedisKey key) { - physical.Write(key); + writer.Write(key); } else if (arg is RedisChannel channel) { - physical.Write(channel); + writer.Write(channel); } else { // recognises well-known types var val = RedisValue.TryParse(arg, out var valid); if (!valid) throw new InvalidCastException($"Unable to parse value: '{arg}'"); - physical.WriteBulkString(val); + writer.WriteBulkString(val); } } } @@ -5725,28 +5725,28 @@ public IEnumerable GetMessages(PhysicalConnection connection) yield return this; } - protected override void WriteImpl(PhysicalConnection physical) + protected override void WriteImpl(in MessageWriter writer) { if (hexHash != null) { - physical.WriteHeader(RedisCommand.EVALSHA, 2 + keys.Length + values.Length); - physical.WriteSha1AsHex(hexHash); + writer.WriteHeader(RedisCommand.EVALSHA, 2 + keys.Length + values.Length); + writer.WriteSha1AsHex(hexHash); } else if (asciiHash != null) { - physical.WriteHeader(RedisCommand.EVALSHA, 2 + keys.Length + values.Length); - physical.WriteBulkString((RedisValue)asciiHash); + writer.WriteHeader(RedisCommand.EVALSHA, 2 + keys.Length + values.Length); + writer.WriteBulkString((RedisValue)asciiHash); } else { - physical.WriteHeader(RedisCommand.EVAL, 2 + keys.Length + values.Length); - physical.WriteBulkString((RedisValue)script); + writer.WriteHeader(RedisCommand.EVAL, 2 + keys.Length + values.Length); + writer.WriteBulkString((RedisValue)script); } - physical.WriteBulkString(keys.Length); + writer.WriteBulkString(keys.Length); for (int i = 0; i < keys.Length; i++) - physical.Write(keys[i]); + writer.Write(keys[i]); for (int i = 0; i < values.Length; i++) - physical.WriteBulkString(values[i]); + writer.WriteBulkString(values[i]); } public override int ArgCount => 2 + keys.Length + values.Length; } @@ -5858,15 +5858,15 @@ public override int GetHashSlot(ServerSelectionStrategy serverSelectionStrategy) return slot; } - protected override void WriteImpl(PhysicalConnection physical) + protected override void WriteImpl(in MessageWriter writer) { - physical.WriteHeader(Command, 2 + keys.Length + values.Length); - physical.Write(Key); - physical.WriteBulkString(keys.Length); + writer.WriteHeader(Command, 2 + keys.Length + values.Length); + writer.Write(Key); + writer.WriteBulkString(keys.Length); for (int i = 0; i < keys.Length; i++) - physical.Write(keys[i]); + writer.Write(keys[i]); for (int i = 0; i < values.Length; i++) - physical.WriteBulkString(values[i]); + writer.WriteBulkString(values[i]); } public override int ArgCount => 2 + keys.Length + values.Length; } @@ -5915,10 +5915,10 @@ public bool UnwrapValue(out TimeSpan? value, out Exception? ex) return false; } - protected override void WriteImpl(PhysicalConnection physical) + protected override void WriteImpl(in MessageWriter writer) { - physical.WriteHeader(command, 1); - physical.Write(Key); + writer.WriteHeader(command, 1); + writer.Write(Key); } public override int ArgCount => 1; } diff --git a/src/StackExchange.Redis/RedisSubscriber.cs b/src/StackExchange.Redis/RedisSubscriber.cs index 425fb7534..ef9d42d22 100644 --- a/src/StackExchange.Redis/RedisSubscriber.cs +++ b/src/StackExchange.Redis/RedisSubscriber.cs @@ -3,7 +3,6 @@ using System.Diagnostics.CodeAnalysis; using System.Net; using System.Threading.Tasks; -using Pipelines.Sockets.Unofficial.Arenas; using static StackExchange.Redis.ConnectionMultiplexer; namespace StackExchange.Redis diff --git a/src/StackExchange.Redis/RedisTransaction.cs b/src/StackExchange.Redis/RedisTransaction.cs index deeee46b4..f890c035c 100644 --- a/src/StackExchange.Redis/RedisTransaction.cs +++ b/src/StackExchange.Redis/RedisTransaction.cs @@ -188,9 +188,9 @@ public bool WasQueued set => wasQueued = value; } - protected override void WriteImpl(PhysicalConnection physical) + protected override void WriteImpl(in MessageWriter writer) { - Wrapped.WriteTo(physical); + Wrapped.WriteTo(writer); Wrapped.SetRequestSent(); } public override int ArgCount => Wrapped.ArgCount; @@ -443,7 +443,7 @@ public IEnumerable GetMessages(PhysicalConnection connection) } } - protected override void WriteImpl(PhysicalConnection physical) => physical.WriteHeader(Command, 0); + protected override void WriteImpl(in MessageWriter writer) => writer.WriteHeader(Command, 0); public override int ArgCount => 0; diff --git a/src/StackExchange.Redis/RedisValue.cs b/src/StackExchange.Redis/RedisValue.cs index 46228a912..f9efc50ce 100644 --- a/src/StackExchange.Redis/RedisValue.cs +++ b/src/StackExchange.Redis/RedisValue.cs @@ -1277,7 +1277,7 @@ private ReadOnlyMemory AsMemory(out byte[]? leased) goto HaveString; case StorageType.Int64: leased = ArrayPool.Shared.Rent(Format.MaxInt64TextLen + 2); // reused code has CRLF terminator - len = PhysicalConnection.WriteRaw(leased, OverlappedValueInt64) - 2; // drop the CRLF + len = MessageWriter.WriteRaw(leased, OverlappedValueInt64) - 2; // drop the CRLF return new ReadOnlyMemory(leased, 0, len); case StorageType.UInt64: leased = ArrayPool.Shared.Rent(Format.MaxInt64TextLen); // reused code has CRLF terminator diff --git a/src/StackExchange.Redis/ResultProcessor.VectorSets.cs b/src/StackExchange.Redis/ResultProcessor.VectorSets.cs index 70f548264..7b4618f9d 100644 --- a/src/StackExchange.Redis/ResultProcessor.VectorSets.cs +++ b/src/StackExchange.Redis/ResultProcessor.VectorSets.cs @@ -1,6 +1,7 @@ -using Pipelines.Sockets.Unofficial.Arenas; +// ReSharper disable once CheckNamespace + +using Pipelines.Sockets.Unofficial.Arenas; -// ReSharper disable once CheckNamespace namespace StackExchange.Redis; internal abstract partial class ResultProcessor diff --git a/src/StackExchange.Redis/ResultProcessor.cs b/src/StackExchange.Redis/ResultProcessor.cs index 4ca08e663..b7d795187 100644 --- a/src/StackExchange.Redis/ResultProcessor.cs +++ b/src/StackExchange.Redis/ResultProcessor.cs @@ -485,17 +485,17 @@ public TimerMessage(int db, CommandFlags flags, RedisCommand command, RedisValue this.value = value; } - protected override void WriteImpl(PhysicalConnection physical) + protected override void WriteImpl(in MessageWriter writer) { StartedWritingTimestamp = Stopwatch.GetTimestamp(); if (value.IsNull) { - physical.WriteHeader(command, 0); + writer.WriteHeader(command, 0); } else { - physical.WriteHeader(command, 1); - physical.WriteBulkString(value); + writer.WriteHeader(command, 1); + writer.WriteBulkString(value); } } public override int ArgCount => value.IsNull ? 0 : 1; diff --git a/src/StackExchange.Redis/ValueCondition.cs b/src/StackExchange.Redis/ValueCondition.cs index c5cf4bd5a..f9883b8b9 100644 --- a/src/StackExchange.Redis/ValueCondition.cs +++ b/src/StackExchange.Redis/ValueCondition.cs @@ -249,33 +249,33 @@ static byte ToNibble(int b) _ => 0, }; - internal void WriteTo(PhysicalConnection physical) + internal void WriteTo(in MessageWriter writer) { switch (_kind) { case ConditionKind.Exists: - physical.WriteBulkString("XX"u8); + writer.WriteBulkString("XX"u8); break; case ConditionKind.NotExists: - physical.WriteBulkString("NX"u8); + writer.WriteBulkString("NX"u8); break; case ConditionKind.ValueEquals: - physical.WriteBulkString("IFEQ"u8); - physical.WriteBulkString(_value); + writer.WriteBulkString("IFEQ"u8); + writer.WriteBulkString(_value); break; case ConditionKind.ValueNotEquals: - physical.WriteBulkString("IFNE"u8); - physical.WriteBulkString(_value); + writer.WriteBulkString("IFNE"u8); + writer.WriteBulkString(_value); break; case ConditionKind.DigestEquals: - physical.WriteBulkString("IFDEQ"u8); + writer.WriteBulkString("IFDEQ"u8); var written = WriteHex(_value.DirectOverlappedBits64, stackalloc byte[2 * DigestBytes]); - physical.WriteBulkString(written); + writer.WriteBulkString(written); break; case ConditionKind.DigestNotEquals: - physical.WriteBulkString("IFDNE"u8); + writer.WriteBulkString("IFDNE"u8); written = WriteHex(_value.DirectOverlappedBits64, stackalloc byte[2 * DigestBytes]); - physical.WriteBulkString(written); + writer.WriteBulkString(written); break; } } diff --git a/src/StackExchange.Redis/VectorSetAddMessage.cs b/src/StackExchange.Redis/VectorSetAddMessage.cs index 0beb65205..75dbbd630 100644 --- a/src/StackExchange.Redis/VectorSetAddMessage.cs +++ b/src/StackExchange.Redis/VectorSetAddMessage.cs @@ -60,31 +60,31 @@ internal static void SuppressFp32() { } internal static void RestoreFp32() { } #endif - protected abstract void WriteElement(bool packed, PhysicalConnection physical); + protected abstract void WriteElement(bool packed, in MessageWriter writer); - protected override void WriteImpl(PhysicalConnection physical) + protected override void WriteImpl(in MessageWriter writer) { bool packed = UseFp32; // snapshot to avoid race in debug scenarios - physical.WriteHeader(Command, GetArgCount(packed)); - physical.Write(key); + writer.WriteHeader(Command, GetArgCount(packed)); + writer.Write(key); if (reducedDimensions.HasValue) { - physical.WriteBulkString("REDUCE"u8); - physical.WriteBulkString(reducedDimensions.GetValueOrDefault()); + writer.WriteBulkString("REDUCE"u8); + writer.WriteBulkString(reducedDimensions.GetValueOrDefault()); } - WriteElement(packed, physical); - if (useCheckAndSet) physical.WriteBulkString("CAS"u8); + WriteElement(packed, writer); + if (useCheckAndSet) writer.WriteBulkString("CAS"u8); switch (quantization) { case VectorSetQuantization.Int8: break; case VectorSetQuantization.None: - physical.WriteBulkString("NOQUANT"u8); + writer.WriteBulkString("NOQUANT"u8); break; case VectorSetQuantization.Binary: - physical.WriteBulkString("BIN"u8); + writer.WriteBulkString("BIN"u8); break; default: throw new ArgumentOutOfRangeException(nameof(quantization)); @@ -92,20 +92,20 @@ protected override void WriteImpl(PhysicalConnection physical) if (buildExplorationFactor.HasValue) { - physical.WriteBulkString("EF"u8); - physical.WriteBulkString(buildExplorationFactor.GetValueOrDefault()); + writer.WriteBulkString("EF"u8); + writer.WriteBulkString(buildExplorationFactor.GetValueOrDefault()); } - WriteAttributes(physical); + WriteAttributes(writer); if (maxConnections.HasValue) { - physical.WriteBulkString("M"u8); - physical.WriteBulkString(maxConnections.GetValueOrDefault()); + writer.WriteBulkString("M"u8); + writer.WriteBulkString(maxConnections.GetValueOrDefault()); } } - protected abstract void WriteAttributes(PhysicalConnection physical); + protected abstract void WriteAttributes(in MessageWriter writer); internal sealed class VectorSetAddMemberMessage( int db, @@ -136,32 +136,32 @@ public override int GetElementArgCount(bool packed) public override int GetAttributeArgCount() => _attributesJson is null ? 0 : 2; // [SETATTR {attributes}] - protected override void WriteElement(bool packed, PhysicalConnection physical) + protected override void WriteElement(bool packed, in MessageWriter writer) { if (packed) { - physical.WriteBulkString("FP32"u8); - physical.WriteBulkString(MemoryMarshal.AsBytes(values.Span)); + writer.WriteBulkString("FP32"u8); + writer.WriteBulkString(MemoryMarshal.AsBytes(values.Span)); } else { - physical.WriteBulkString("VALUES"u8); - physical.WriteBulkString(values.Length); + writer.WriteBulkString("VALUES"u8); + writer.WriteBulkString(values.Length); foreach (var val in values.Span) { - physical.WriteBulkString(val); + writer.WriteBulkString(val); } } - physical.WriteBulkString(element); + writer.WriteBulkString(element); } - protected override void WriteAttributes(PhysicalConnection physical) + protected override void WriteAttributes(in MessageWriter writer) { if (_attributesJson is not null) { - physical.WriteBulkString("SETATTR"u8); - physical.WriteBulkString(_attributesJson); + writer.WriteBulkString("SETATTR"u8); + writer.WriteBulkString(_attributesJson); } } } diff --git a/src/StackExchange.Redis/VectorSetSimilaritySearchMessage.cs b/src/StackExchange.Redis/VectorSetSimilaritySearchMessage.cs index a492be4d8..b3a113530 100644 --- a/src/StackExchange.Redis/VectorSetSimilaritySearchMessage.cs +++ b/src/StackExchange.Redis/VectorSetSimilaritySearchMessage.cs @@ -33,20 +33,20 @@ internal sealed class VectorSetSimilaritySearchBySingleVectorMessage( internal override int GetSearchTargetArgCount(bool packed) => packed ? 2 : 2 + vector.Length; // FP32 {vector} or VALUES {num} {vector} - internal override void WriteSearchTarget(bool packed, PhysicalConnection physical) + internal override void WriteSearchTarget(bool packed, in MessageWriter writer) { if (packed) { - physical.WriteBulkString("FP32"u8); - physical.WriteBulkString(System.Runtime.InteropServices.MemoryMarshal.AsBytes(vector.Span)); + writer.WriteBulkString("FP32"u8); + writer.WriteBulkString(System.Runtime.InteropServices.MemoryMarshal.AsBytes(vector.Span)); } else { - physical.WriteBulkString("VALUES"u8); - physical.WriteBulkString(vector.Length); + writer.WriteBulkString("VALUES"u8); + writer.WriteBulkString(vector.Length); foreach (var val in vector.Span) { - physical.WriteBulkString(val); + writer.WriteBulkString(val); } } } @@ -68,15 +68,15 @@ internal sealed class VectorSetSimilaritySearchByMemberMessage( { internal override int GetSearchTargetArgCount(bool packed) => 2; // ELE {member} - internal override void WriteSearchTarget(bool packed, PhysicalConnection physical) + internal override void WriteSearchTarget(bool packed, in MessageWriter writer) { - physical.WriteBulkString("ELE"u8); - physical.WriteBulkString(member); + writer.WriteBulkString("ELE"u8); + writer.WriteBulkString(member); } } internal abstract int GetSearchTargetArgCount(bool packed); - internal abstract void WriteSearchTarget(bool packed, PhysicalConnection physical); + internal abstract void WriteSearchTarget(bool packed, in MessageWriter writer); public ResultProcessor?> GetResultProcessor() => VectorSetSimilaritySearchProcessor.Instance; @@ -194,67 +194,67 @@ private int GetArgCount(bool packed) return argCount; } - protected override void WriteImpl(PhysicalConnection physical) + protected override void WriteImpl(in MessageWriter writer) { // snapshot to avoid race in debug scenarios bool packed = VectorSetAddMessage.UseFp32; - physical.WriteHeader(Command, GetArgCount(packed)); + writer.WriteHeader(Command, GetArgCount(packed)); // Write key - physical.Write(key); + writer.Write(key); // Write search target: either "ELE {member}" or vector data - WriteSearchTarget(packed, physical); + WriteSearchTarget(packed, writer); if (HasFlag(VsimFlags.WithScores)) { - physical.WriteBulkString("WITHSCORES"u8); + writer.WriteBulkString("WITHSCORES"u8); } if (HasFlag(VsimFlags.WithAttributes)) { - physical.WriteBulkString("WITHATTRIBS"u8); + writer.WriteBulkString("WITHATTRIBS"u8); } // Write optional parameters if (HasFlag(VsimFlags.Count)) { - physical.WriteBulkString("COUNT"u8); - physical.WriteBulkString(count); + writer.WriteBulkString("COUNT"u8); + writer.WriteBulkString(count); } if (HasFlag(VsimFlags.Epsilon)) { - physical.WriteBulkString("EPSILON"u8); - physical.WriteBulkString(epsilon); + writer.WriteBulkString("EPSILON"u8); + writer.WriteBulkString(epsilon); } if (HasFlag(VsimFlags.SearchExplorationFactor)) { - physical.WriteBulkString("EF"u8); - physical.WriteBulkString(searchExplorationFactor); + writer.WriteBulkString("EF"u8); + writer.WriteBulkString(searchExplorationFactor); } if (HasFlag(VsimFlags.FilterExpression)) { - physical.WriteBulkString("FILTER"u8); - physical.WriteBulkString(filterExpression); + writer.WriteBulkString("FILTER"u8); + writer.WriteBulkString(filterExpression); } if (HasFlag(VsimFlags.MaxFilteringEffort)) { - physical.WriteBulkString("FILTER-EF"u8); - physical.WriteBulkString(maxFilteringEffort); + writer.WriteBulkString("FILTER-EF"u8); + writer.WriteBulkString(maxFilteringEffort); } if (HasFlag(VsimFlags.UseExactSearch)) { - physical.WriteBulkString("TRUTH"u8); + writer.WriteBulkString("TRUTH"u8); } if (HasFlag(VsimFlags.DisableThreading)) { - physical.WriteBulkString("NOTHREAD"u8); + writer.WriteBulkString("NOTHREAD"u8); } } diff --git a/tests/StackExchange.Redis.Tests/ParseTests.cs b/tests/StackExchange.Redis.Tests/ParseTests.cs index 2621ddab9..9542b6c32 100644 --- a/tests/StackExchange.Redis.Tests/ParseTests.cs +++ b/tests/StackExchange.Redis.Tests/ParseTests.cs @@ -1,8 +1,12 @@ using System; using System.Buffers; using System.Collections.Generic; +using System.IO; +using System.Runtime.InteropServices; using System.Text; +using System.Threading.Tasks; using Pipelines.Sockets.Unofficial.Arenas; +using StackExchange.Redis.Configuration; using Xunit; namespace StackExchange.Redis.Tests; @@ -29,18 +33,15 @@ public static IEnumerable GetTestData() [Theory] [MemberData(nameof(GetTestData))] - public void ParseAsSingleChunk(string ascii, int expected) + public Task ParseAsSingleChunk(string ascii, int expected) { var buffer = new ReadOnlySequence(Encoding.ASCII.GetBytes(ascii)); - using (var arena = new Arena()) - { - ProcessMessages(arena, buffer, expected); - } + return ProcessMessagesAsync(buffer, expected); } [Theory] [MemberData(nameof(GetTestData))] - public void ParseAsLotsOfChunks(string ascii, int expected) + public Task ParseAsLotsOfChunks(string ascii, int expected) { var bytes = Encoding.ASCII.GetBytes(ascii); FragmentedSegment? chain = null, tail = null; @@ -59,21 +60,42 @@ public void ParseAsLotsOfChunks(string ascii, int expected) } var buffer = new ReadOnlySequence(chain!, 0, tail!, 1); Assert.Equal(bytes.Length, buffer.Length); - using (var arena = new Arena()) - { - ProcessMessages(arena, buffer, expected); - } + return ProcessMessagesAsync(buffer, expected); } - private void ProcessMessages(Arena arena, ReadOnlySequence buffer, int expected) + private async Task ProcessMessagesAsync(ReadOnlySequence buffer, int expected, bool isInbound = false) { Log($"chain: {buffer.Length}"); - var reader = new BufferReader(buffer); - RawResult result; + MemoryStream ms; + if (buffer.IsSingleSegment && MemoryMarshal.TryGetArray(buffer.First, out var segment)) + { + // use existing buffer + ms = new MemoryStream(segment.Array!, segment.Offset, (int)buffer.Length, false, true); + } + else + { + ms = new MemoryStream(checked((int)buffer.Length)); + foreach (var chunk in buffer) + { +#if NETCOREAPP || NETSTANDARD2_1_OR_GREATER + ms.Write(chunk.Span); +#else + ms.Write(chunk); +#endif + } + + ms.Position = 0; + } + +#pragma warning disable CS0618 // Type or member is obsolete + var reader = new LoggingTunnel.StreamRespReader(ms, isInbound: isInbound); +#pragma warning restore CS0618 // Type or member is obsolete int found = 0; - while (!(result = PhysicalConnection.TryParseResult(false, arena, buffer, ref reader, false, null, false)).IsNull) + while (true) { - Log($"{result} - {result.GetString()}"); + var result = await reader.ReadOneAsync().ForAwait(); + if (result.Result is null) break; + Log($"{result} - {result.Result}"); found++; } Assert.Equal(expected, found); diff --git a/toys/StackExchange.Redis.Server/RespServer.cs b/toys/StackExchange.Redis.Server/RespServer.cs index 75a0273ea..bab38c0ba 100644 --- a/toys/StackExchange.Redis.Server/RespServer.cs +++ b/toys/StackExchange.Redis.Server/RespServer.cs @@ -309,20 +309,21 @@ public void Log(string message) public static async ValueTask WriteResponseAsync(RedisClient client, PipeWriter output, TypedRedisValue value) { - static void WritePrefix(PipeWriter ooutput, char pprefix) + static void WritePrefix(IBufferWriter output, char prefix) { - var span = ooutput.GetSpan(1); - span[0] = (byte)pprefix; - ooutput.Advance(1); + var span = output.GetSpan(1); + span[0] = (byte)prefix; + output.Advance(1); } if (value.IsNil) return; // not actually a request (i.e. empty/whitespace request) if (client != null && client.ShouldSkipResponse()) return; // intentionally skipping the result char prefix; + switch (value.Type.ToResp2()) { case ResultType.Integer: - PhysicalConnection.WriteInteger(output, (long)value.AsRedisValue()); + MessageWriter.WriteInteger(output, (long)value.AsRedisValue()); break; case ResultType.Error: prefix = '-'; @@ -333,21 +334,21 @@ static void WritePrefix(PipeWriter ooutput, char pprefix) WritePrefix(output, prefix); var val = (string)value.AsRedisValue(); var expectedLength = Encoding.UTF8.GetByteCount(val); - PhysicalConnection.WriteRaw(output, val, expectedLength); - PhysicalConnection.WriteCrlf(output); + MessageWriter.WriteRaw(output, val, expectedLength); + MessageWriter.WriteCrlf(output); break; case ResultType.BulkString: - PhysicalConnection.WriteBulkString(value.AsRedisValue(), output); + MessageWriter.WriteBulkString(value.AsRedisValue(), output); break; case ResultType.Array: if (value.IsNullArray) { - PhysicalConnection.WriteMultiBulkHeader(output, -1); + MessageWriter.WriteMultiBulkHeader(output, -1); } else { var segment = value.Segment; - PhysicalConnection.WriteMultiBulkHeader(output, segment.Count); + MessageWriter.WriteMultiBulkHeader(output, segment.Count); var arr = segment.Array; int offset = segment.Offset; for (int i = 0; i < segment.Count; i++) @@ -383,8 +384,6 @@ private static bool TryParseRequest(Arena arena, ref ReadOnlySequence return false; } - private readonly Arena _arena = new Arena(); - public ValueTask TryProcessRequestAsync(ref ReadOnlySequence buffer, RedisClient client, PipeWriter output) { static async ValueTask Awaited(ValueTask wwrite, TypedRedisValue rresponse) From d09b8f66fbc1df51355faf36f6746d98cd201618 Mon Sep 17 00:00:00 2001 From: Marc Gravell Date: Tue, 17 Feb 2026 14:36:24 +0000 Subject: [PATCH 09/11] fix integer unit tests --- src/RESPite/Messages/RespAttributeReader.cs | 6 +-- .../RespReader.AggregateEnumerator.cs | 52 +++++++++++++------ src/RESPite/PublicAPI/PublicAPI.Unshipped.txt | 3 +- tests/RESPite.Tests/RespReaderTests.cs | 2 +- 4 files changed, 42 insertions(+), 21 deletions(-) diff --git a/src/RESPite/Messages/RespAttributeReader.cs b/src/RESPite/Messages/RespAttributeReader.cs index bfeaede79..9d61802c0 100644 --- a/src/RESPite/Messages/RespAttributeReader.cs +++ b/src/RESPite/Messages/RespAttributeReader.cs @@ -29,13 +29,13 @@ protected virtual int ReadKeyValuePairs(ref RespReader reader, ref T value) byte[] pooledBuffer = []; Span localBuffer = stackalloc byte[128]; int count = 0; - while (iterator.MoveNext() && iterator.Value.TryReadNext()) + while (iterator.MoveNext()) { if (iterator.Value.IsScalar) { var key = iterator.Value.Buffer(ref pooledBuffer, localBuffer); - if (iterator.MoveNext() && iterator.Value.TryReadNext()) + if (iterator.MoveNext()) { if (ReadKeyValuePair(key, ref iterator.Value, ref value)) { @@ -49,7 +49,7 @@ protected virtual int ReadKeyValuePairs(ref RespReader reader, ref T value) } else { - if (iterator.MoveNext() && iterator.Value.TryReadNext()) + if (iterator.MoveNext()) { // we won't try to handle aggregate keys; skip the value } diff --git a/src/RESPite/Messages/RespReader.AggregateEnumerator.cs b/src/RESPite/Messages/RespReader.AggregateEnumerator.cs index be10cd5cb..cd9892b68 100644 --- a/src/RESPite/Messages/RespReader.AggregateEnumerator.cs +++ b/src/RESPite/Messages/RespReader.AggregateEnumerator.cs @@ -61,7 +61,7 @@ public AggregateEnumerator(scoped in RespReader reader) /// public bool MoveNext(RespPrefix prefix) { - bool result = MoveNext(); + bool result = MoveNextRaw(); if (result) { Value.MoveNext(prefix); @@ -75,7 +75,7 @@ public bool MoveNext(RespPrefix prefix) /// The type of data represented by this reader. public bool MoveNext(RespPrefix prefix, RespAttributeReader respAttributeReader, ref T attributes) { - bool result = MoveNext(respAttributeReader, ref attributes); + bool result = MoveNextRaw(respAttributeReader, ref attributes); if (result) { Value.MoveNext(prefix); @@ -83,16 +83,38 @@ public bool MoveNext(RespPrefix prefix, RespAttributeReader respAttributeR return result; } - /// > - public bool MoveNext() + /// + /// Move to the next child and leave the reader *ahead of* the first element, + /// allowing us to read attribute data. + /// + /// If you are not consuming attribute data, is preferred. + public bool MoveNextRaw() { object? attributes = null; return MoveNextCore(null, ref attributes); } - /// > - /// The type of data represented by this reader. - public bool MoveNext(RespAttributeReader respAttributeReader, ref T attributes) + /// + /// Move to the next child and move into the first element (skipping attributes etc), leaving it ready to consume. + /// + public bool MoveNext() + { + object? attributes = null; + if (MoveNextCore(null, ref attributes)) + { + Value.MoveNext(); + return true; + } + return false; + } + + /// + /// Move to the next child (capturing attribute data) and leave the reader *ahead of* the first element, + /// allowing us to also read attribute data of the child. + /// + /// The type of attribute data represented by this reader. + /// If you are not consuming attribute data, is preferred. + public bool MoveNextRaw(RespAttributeReader respAttributeReader, ref T attributes) => MoveNextCore(respAttributeReader, ref attributes); /// > @@ -146,14 +168,16 @@ private bool MoveNextCore(RespAttributeReader? attributeReader, ref T attr /// used to update a tree reader, to get to the next data after the aggregate. public void MovePast(out RespReader reader) { - while (MoveNext()) { } + while (MoveNextRaw()) { } reader = _reader; } + /// + /// Moves to the next element, and moves into that element (skipping attributes etc), leaving it ready to consume. + /// public void DemandNext() { if (!MoveNext()) ThrowEof(); - Value.MoveNext(); // skip any attributes etc } public T ReadOne(Projection projection) @@ -166,9 +190,7 @@ public void FillAll(scoped Span target, Projection projection) { for (int i = 0; i < target.Length; i++) { - if (!MoveNext()) ThrowEof(); - - Value.MoveNext(); // skip any attributes etc + DemandNext(); target[i] = projection(ref Value); } } @@ -181,14 +203,12 @@ public void FillAll( { for (int i = 0; i < target.Length; i++) { - if (!MoveNext()) ThrowEof(); + DemandNext(); - Value.MoveNext(); // skip any attributes etc var x = first(ref Value); - if (!MoveNext()) ThrowEof(); + DemandNext(); - Value.MoveNext(); // skip any attributes etc var y = second(ref Value); target[i] = combine(x, y); } diff --git a/src/RESPite/PublicAPI/PublicAPI.Unshipped.txt b/src/RESPite/PublicAPI/PublicAPI.Unshipped.txt index 123e6c86e..d81dc0388 100644 --- a/src/RESPite/PublicAPI/PublicAPI.Unshipped.txt +++ b/src/RESPite/PublicAPI/PublicAPI.Unshipped.txt @@ -20,6 +20,8 @@ [SER004]RESPite.Buffers.CycleBuffer.UncommittedAvailable.get -> int [SER004]RESPite.Buffers.CycleBuffer.Write(in System.Buffers.ReadOnlySequence value) -> void [SER004]RESPite.Buffers.CycleBuffer.Write(System.ReadOnlySpan value) -> void +[SER004]RESPite.Messages.RespReader.AggregateEnumerator.MoveNextRaw() -> bool +[SER004]RESPite.Messages.RespReader.AggregateEnumerator.MoveNextRaw(RESPite.Messages.RespAttributeReader! respAttributeReader, ref T attributes) -> bool [SER004]static RESPite.Buffers.CycleBuffer.Create(System.Buffers.MemoryPool? pool = null, int pageSize = 8192) -> RESPite.Buffers.CycleBuffer [SER004]const RESPite.Messages.RespScanState.MinBytes = 3 -> int [SER004]override RESPite.Messages.RespScanState.Equals(object? obj) -> bool @@ -62,7 +64,6 @@ [SER004]RESPite.Messages.RespReader.AggregateEnumerator.GetEnumerator() -> RESPite.Messages.RespReader.AggregateEnumerator [SER004]RESPite.Messages.RespReader.AggregateEnumerator.MoveNext() -> bool [SER004]RESPite.Messages.RespReader.AggregateEnumerator.MoveNext(RESPite.Messages.RespPrefix prefix) -> bool -[SER004]RESPite.Messages.RespReader.AggregateEnumerator.MoveNext(RESPite.Messages.RespAttributeReader! respAttributeReader, ref T attributes) -> bool [SER004]RESPite.Messages.RespReader.AggregateEnumerator.MoveNext(RESPite.Messages.RespPrefix prefix, RESPite.Messages.RespAttributeReader! respAttributeReader, ref T attributes) -> bool [SER004]RESPite.Messages.RespReader.AggregateEnumerator.MovePast(out RESPite.Messages.RespReader reader) -> void [SER004]RESPite.Messages.RespReader.AggregateEnumerator.ReadOne(RESPite.Messages.RespReader.Projection! projection) -> T diff --git a/tests/RESPite.Tests/RespReaderTests.cs b/tests/RESPite.Tests/RespReaderTests.cs index 690235795..bbc602a2c 100644 --- a/tests/RESPite.Tests/RespReaderTests.cs +++ b/tests/RESPite.Tests/RespReaderTests.cs @@ -729,7 +729,7 @@ public void AttributeInner(RespPayload payload) Assert.Equal(3, iterator.Value.ReadInt32()); iterator.Value.DemandEnd(); - Assert.False(iterator.MoveNext(TestAttributeReader.Instance, ref state)); + Assert.False(iterator.MoveNextRaw(TestAttributeReader.Instance, ref state)); Assert.Equal(0, state.Count); iterator.MovePast(out reader); reader.DemandEnd(); From c64f0d90d96fc9c58cfef3eafe613ca86917fe3e Mon Sep 17 00:00:00 2001 From: Marc Gravell Date: Tue, 17 Feb 2026 15:24:23 +0000 Subject: [PATCH 10/11] start migrating processors --- .../RespReaderExtensions.cs | 20 +++ src/StackExchange.Redis/ResultProcessor.cs | 120 +++++++++--------- .../ResultProcessorUnitTests.cs | 54 ++++++++ 3 files changed, 131 insertions(+), 63 deletions(-) diff --git a/src/StackExchange.Redis/RespReaderExtensions.cs b/src/StackExchange.Redis/RespReaderExtensions.cs index 6714fa5a0..52c1812b1 100644 --- a/src/StackExchange.Redis/RespReaderExtensions.cs +++ b/src/StackExchange.Redis/RespReaderExtensions.cs @@ -53,6 +53,26 @@ public RespPrefix GetFirstPrefix() } return prefix; } + + public bool AggregateHasAtLeast(int count) + { + reader.DemandAggregate(); + if (reader.IsNull) return false; + if (reader.IsStreaming) return CheckStreamingAggregateAtLeast(in reader, count); + return reader.AggregateLength() >= count; + + static bool CheckStreamingAggregateAtLeast(in RespReader reader, int count) + { + var iter = reader.AggregateChildren(); + object? attributes = null; + while (count > 0 && iter.MoveNextRaw(null!, ref attributes)) + { + count--; + } + + return count == 0; + } + } } extension(ref RespReader reader) diff --git a/src/StackExchange.Redis/ResultProcessor.cs b/src/StackExchange.Redis/ResultProcessor.cs index b7d795187..eafcc29f6 100644 --- a/src/StackExchange.Redis/ResultProcessor.cs +++ b/src/StackExchange.Redis/ResultProcessor.cs @@ -1312,20 +1312,20 @@ protected override bool SetResultCore(PhysicalConnection connection, Message mes private sealed class DoubleProcessor : ResultProcessor { - protected override bool SetResultCore(PhysicalConnection connection, Message message, RawResult result) + protected override bool SetResultCore(PhysicalConnection connection, Message message, ref RespReader reader) { - switch (result.Resp2TypeBulkString) + switch (reader.Resp2PrefixBulkString) { - case ResultType.Integer: - if (result.TryGetInt64(out long i64)) + case RespPrefix.Integer: + if (reader.TryReadInt64(out long i64)) { SetResult(message, i64); return true; } break; - case ResultType.SimpleString: - case ResultType.BulkString: - if (result.TryGetDouble(out double val)) + case RespPrefix.SimpleString: + case RespPrefix.BulkString: + if (reader.TryReadDouble(out double val)) { SetResult(message, val); return true; @@ -1403,14 +1403,14 @@ private sealed class Int64DefaultValueProcessor : ResultProcessor public Int64DefaultValueProcessor(long defaultValue) => _defaultValue = defaultValue; - protected override bool SetResultCore(PhysicalConnection connection, Message message, RawResult result) + protected override bool SetResultCore(PhysicalConnection connection, Message message, ref RespReader reader) { - if (result.IsNull) + if (reader.IsNull) { SetResult(message, _defaultValue); return true; } - if (result.Resp2TypeBulkString == ResultType.Integer && result.TryGetInt64(out var i64)) + if (reader.Resp2PrefixBulkString == RespPrefix.Integer && reader.TryReadInt64(out var i64)) { SetResult(message, i64); return true; @@ -1421,14 +1421,14 @@ protected override bool SetResultCore(PhysicalConnection connection, Message mes private class Int64Processor : ResultProcessor { - protected override bool SetResultCore(PhysicalConnection connection, Message message, RawResult result) + protected override bool SetResultCore(PhysicalConnection connection, Message message, ref RespReader reader) { - switch (result.Resp2TypeBulkString) + switch (reader.Resp2PrefixBulkString) { - case ResultType.Integer: - case ResultType.SimpleString: - case ResultType.BulkString: - if (result.TryGetInt64(out long i64)) + case RespPrefix.Integer: + case RespPrefix.SimpleString: + case RespPrefix.BulkString: + if (reader.TryReadInt64(out long i64)) { SetResult(message, i64); return true; @@ -1532,18 +1532,19 @@ protected override bool SetResultCore(PhysicalConnection connection, Message mes private sealed class PubSubNumSubProcessor : Int64Processor { - protected override bool SetResultCore(PhysicalConnection connection, Message message, RawResult result) + protected override bool SetResultCore(PhysicalConnection connection, Message message, ref RespReader reader) { - if (result.Resp2TypeArray == ResultType.Array) + var snapshot = reader; + if (reader.Resp2PrefixArray == RespPrefix.Array && reader.AggregateLength() == 2) { - var arr = result.GetItems(); - if (arr.Length == 2 && arr[1].TryGetInt64(out long val)) + var agg = reader.AggregateChildren(); + if (agg.MoveNext() && agg.MoveNext() && agg.Value.TryReadInt64(out long val)) { SetResult(message, val); return true; } } - return base.SetResultCore(connection, message, result); + return base.SetResultCore(connection, message, ref snapshot); } } @@ -1563,19 +1564,19 @@ protected override bool SetResultCore(PhysicalConnection connection, Message mes private sealed class NullableDoubleProcessor : ResultProcessor { - protected override bool SetResultCore(PhysicalConnection connection, Message message, RawResult result) + protected override bool SetResultCore(PhysicalConnection connection, Message message, ref RespReader reader) { - switch (result.Resp2TypeBulkString) + switch (reader.Resp2PrefixBulkString) { - case ResultType.Integer: - case ResultType.SimpleString: - case ResultType.BulkString: - if (result.IsNull) + case RespPrefix.Integer: + case RespPrefix.SimpleString: + case RespPrefix.BulkString: + if (reader.IsNull) { SetResult(message, null); return true; } - if (result.TryGetDouble(out double val)) + if (reader.TryReadDouble(out double val)) { SetResult(message, val); return true; @@ -1588,35 +1589,28 @@ protected override bool SetResultCore(PhysicalConnection connection, Message mes private sealed class NullableInt64Processor : ResultProcessor { - protected override bool SetResultCore(PhysicalConnection connection, Message message, RawResult result) + protected override bool SetResultCore(PhysicalConnection connection, Message message, ref RespReader reader) { - switch (result.Resp2TypeBulkString) + switch (reader.Resp2PrefixBulkString) { - case ResultType.Integer: - case ResultType.SimpleString: - case ResultType.BulkString: - if (result.IsNull) + case RespPrefix.Integer: + case RespPrefix.SimpleString: + case RespPrefix.BulkString: + if (reader.IsNull) { SetResult(message, null); return true; } - if (result.TryGetInt64(out long i64)) + if (reader.TryReadInt64(out long i64)) { SetResult(message, i64); return true; } break; - case ResultType.Array: - var items = result.GetItems(); - if (items.Length == 1) - { // treat an array of 1 like a single reply - if (items[0].TryGetInt64(out long value)) - { - SetResult(message, value); - return true; - } - } - break; + case RespPrefix.Array when reader.TryReadNext() && reader.IsScalar && reader.TryReadInt64(out long value) && !reader.TryReadNext(): + // treat an array of 1 like a single reply + SetResult(message, value); + return true; } return false; } @@ -1703,14 +1697,14 @@ protected override bool SetResultCore(PhysicalConnection connection, Message mes private sealed class RedisKeyProcessor : ResultProcessor { - protected override bool SetResultCore(PhysicalConnection connection, Message message, RawResult result) + protected override bool SetResultCore(PhysicalConnection connection, Message message, ref RespReader reader) { - switch (result.Resp2TypeBulkString) + switch (reader.Resp2PrefixBulkString) { - case ResultType.Integer: - case ResultType.SimpleString: - case ResultType.BulkString: - SetResult(message, result.AsRedisKey()); + case RespPrefix.Integer: + case RespPrefix.SimpleString: + case RespPrefix.BulkString: + SetResult(message, reader.ReadByteArray()); return true; } return false; @@ -1719,13 +1713,13 @@ protected override bool SetResultCore(PhysicalConnection connection, Message mes private sealed class RedisTypeProcessor : ResultProcessor { - protected override bool SetResultCore(PhysicalConnection connection, Message message, RawResult result) + protected override bool SetResultCore(PhysicalConnection connection, Message message, ref RespReader reader) { - switch (result.Resp2TypeBulkString) + switch (reader.Resp2PrefixBulkString) { - case ResultType.SimpleString: - case ResultType.BulkString: - string s = result.GetString()!; + case RespPrefix.SimpleString: + case RespPrefix.BulkString: + string s = reader.ReadString()!; RedisType value; if (string.Equals(s, "zset", StringComparison.OrdinalIgnoreCase)) value = Redis.RedisType.SortedSet; else if (!Enum.TryParse(s, true, out value)) value = global::StackExchange.Redis.RedisType.Unknown; @@ -1977,14 +1971,14 @@ private static LCSMatchResult Parse(in RawResult result) private sealed class RedisValueProcessor : ResultProcessor { - protected override bool SetResultCore(PhysicalConnection connection, Message message, RawResult result) + protected override bool SetResultCore(PhysicalConnection connection, Message message, ref RespReader reader) { - switch (result.Resp2TypeBulkString) + switch (reader.Resp2PrefixBulkString) { - case ResultType.Integer: - case ResultType.SimpleString: - case ResultType.BulkString: - SetResult(message, result.AsRedisValue()); + case RespPrefix.Integer: + case RespPrefix.SimpleString: + case RespPrefix.BulkString: + SetResult(message, reader.ReadRedisValue()); return true; } return false; diff --git a/tests/StackExchange.Redis.Tests/ResultProcessorUnitTests.cs b/tests/StackExchange.Redis.Tests/ResultProcessorUnitTests.cs index 41560290e..cca08e057 100644 --- a/tests/StackExchange.Redis.Tests/ResultProcessorUnitTests.cs +++ b/tests/StackExchange.Redis.Tests/ResultProcessorUnitTests.cs @@ -10,13 +10,18 @@ namespace StackExchange.Redis.Tests; public class ResultProcessorUnitTests(ITestOutputHelper log) { + private const string ATTRIB_FOO_BAR = "|1\r\n+foo\r\n+bar\r\n"; + [Theory] [InlineData(":1\r\n", 1)] [InlineData("+1\r\n", 1)] [InlineData("$1\r\n1\r\n", 1)] + [InlineData(",1\r\n", 1)] + [InlineData(ATTRIB_FOO_BAR + ":1\r\n", 1)] [InlineData(":-42\r\n", -42)] [InlineData("+-42\r\n", -42)] [InlineData("$3\r\n-42\r\n", -42)] + [InlineData(",-42\r\n", -42)] public void Int32(string resp, int value) => Assert.Equal(value, Execute(resp, ResultProcessor.Int32)); [Theory] @@ -28,9 +33,12 @@ public class ResultProcessorUnitTests(ITestOutputHelper log) [InlineData(":1\r\n", 1)] [InlineData("+1\r\n", 1)] [InlineData("$1\r\n1\r\n", 1)] + [InlineData(",1\r\n", 1)] + [InlineData(ATTRIB_FOO_BAR + ":1\r\n", 1)] [InlineData(":-42\r\n", -42)] [InlineData("+-42\r\n", -42)] [InlineData("$3\r\n-42\r\n", -42)] + [InlineData(",-42\r\n", -42)] public void Int64(string resp, long value) => Assert.Equal(value, Execute(resp, ResultProcessor.Int64)); [Theory] @@ -43,8 +51,54 @@ public class ResultProcessorUnitTests(ITestOutputHelper log) [InlineData("*0\r\n", "")] [InlineData("*1\r\n+42\r\n", "42")] [InlineData("*2\r\n+42\r\n:78\r\n", "42,78")] + [InlineData(ATTRIB_FOO_BAR + "*1\r\n+42\r\n", "42")] public void Int64Array(string resp, string? value) => Assert.Equal(value, Join(Execute(resp, ResultProcessor.Int64Array))); + [Theory] + [InlineData(":42\r\n", 42.0)] + [InlineData("+3.14\r\n", 3.14)] + [InlineData("$4\r\n3.14\r\n", 3.14)] + [InlineData(",3.14\r\n", 3.14)] + [InlineData(ATTRIB_FOO_BAR + ",3.14\r\n", 3.14)] + [InlineData(":-1\r\n", -1.0)] + [InlineData("+inf\r\n", double.PositiveInfinity)] + [InlineData(",inf\r\n", double.PositiveInfinity)] + [InlineData("$4\r\n-inf\r\n", double.NegativeInfinity)] + [InlineData(",-inf\r\n", double.NegativeInfinity)] + [InlineData(",nan\r\n", double.NaN)] + public void Double(string resp, double value) => Assert.Equal(value, Execute(resp, ResultProcessor.Double)); + + [Theory] + [InlineData("_\r\n", null)] + [InlineData("$-1\r\n", null)] + [InlineData(":42\r\n", 42L)] + [InlineData("+42\r\n", 42L)] + [InlineData("$2\r\n42\r\n", 42L)] + [InlineData(",42\r\n", 42L)] + [InlineData(ATTRIB_FOO_BAR + ":42\r\n", 42L)] + public void NullableInt64(string resp, long? value) => Assert.Equal(value, Execute(resp, ResultProcessor.NullableInt64)); + + [Theory] + [InlineData("*1\r\n:99\r\n", 99L)] + [InlineData(ATTRIB_FOO_BAR + "*1\r\n:99\r\n", 99L)] + public void NullableInt64ArrayOfOne(string resp, long? value) => Assert.Equal(value, Execute(resp, ResultProcessor.NullableInt64)); + + [Theory] + [InlineData("*-1\r\n")] // null array + [InlineData("*0\r\n")] // empty array + [InlineData("*2\r\n:1\r\n:2\r\n")] // two elements + public void FailingNullableInt64ArrayOfNonOne(string resp) => ExecuteUnexpected(resp, ResultProcessor.NullableInt64); + + [Theory] + [InlineData("_\r\n", null)] + [InlineData("$-1\r\n", null)] + [InlineData(":42\r\n", 42.0)] + [InlineData("+3.14\r\n", 3.14)] + [InlineData("$4\r\n3.14\r\n", 3.14)] + [InlineData(",3.14\r\n", 3.14)] + [InlineData(ATTRIB_FOO_BAR + ",3.14\r\n", 3.14)] + public void NullableDouble(string resp, double? value) => Assert.Equal(value, Execute(resp, ResultProcessor.NullableDouble)); + [return: NotNullIfNotNull(nameof(array))] protected static string? Join(T[]? array, string separator = ",") { From a970fce3b9ffff6c7414052eb6024fca158677f4 Mon Sep 17 00:00:00 2001 From: Marc Gravell Date: Tue, 17 Feb 2026 15:46:10 +0000 Subject: [PATCH 11/11] more processors --- src/StackExchange.Redis/ResultProcessor.cs | 59 +++++++++---------- .../ResultProcessorUnitTests.cs | 42 +++++++++++++ 2 files changed, 70 insertions(+), 31 deletions(-) diff --git a/src/StackExchange.Redis/ResultProcessor.cs b/src/StackExchange.Redis/ResultProcessor.cs index eafcc29f6..ddc8e535c 100644 --- a/src/StackExchange.Redis/ResultProcessor.cs +++ b/src/StackExchange.Redis/ResultProcessor.cs @@ -1120,37 +1120,33 @@ private static bool TryParseRole(string? val, out bool isReplica) private sealed class BooleanProcessor : ResultProcessor { - protected override bool SetResultCore(PhysicalConnection connection, Message message, RawResult result) + protected override bool SetResultCore(PhysicalConnection connection, Message message, ref RespReader reader) { - if (result.IsNull) + if (reader.IsNull) { SetResult(message, false); // lots of ops return (nil) when they mean "no" return true; } - switch (result.Resp2TypeBulkString) + switch (reader.Resp2PrefixBulkString) { - case ResultType.SimpleString: - if (result.IsEqual(CommonReplies.OK)) + case RespPrefix.SimpleString: + if (reader.IsOK()) { SetResult(message, true); } else { - SetResult(message, result.GetBoolean()); + SetResult(message, reader.ReadBoolean()); } return true; - case ResultType.Integer: - case ResultType.BulkString: - SetResult(message, result.GetBoolean()); + case RespPrefix.Integer: + case RespPrefix.BulkString: + SetResult(message, reader.ReadBoolean()); + return true; + case RespPrefix.Array when reader.TryReadNext() && reader.IsScalar && reader.ReadBoolean() is var value && !reader.TryReadNext(): + // treat an array of 1 like a single reply (for example, SCRIPT EXISTS) + SetResult(message, value); return true; - case ResultType.Array: - var items = result.GetItems(); - if (items.Length == 1) - { // treat an array of 1 like a single reply (for example, SCRIPT EXISTS) - SetResult(message, items[0].GetBoolean()); - return true; - } - break; } return false; } @@ -1158,12 +1154,12 @@ protected override bool SetResultCore(PhysicalConnection connection, Message mes private sealed class ByteArrayProcessor : ResultProcessor { - protected override bool SetResultCore(PhysicalConnection connection, Message message, RawResult result) + protected override bool SetResultCore(PhysicalConnection connection, Message message, ref RespReader reader) { - switch (result.Resp2TypeBulkString) + switch (reader.Resp2PrefixBulkString) { - case ResultType.BulkString: - SetResult(message, result.GetBlob()); + case RespPrefix.BulkString: + SetResult(message, reader.ReadByteArray()); return true; } return false; @@ -2838,20 +2834,21 @@ protected override KeyValuePair Parse(in RawResult first, in Raw private sealed class StringProcessor : ResultProcessor { - protected override bool SetResultCore(PhysicalConnection connection, Message message, RawResult result) + protected override bool SetResultCore(PhysicalConnection connection, Message message, ref RespReader reader) { - switch (result.Resp2TypeBulkString) + switch (reader.Resp2PrefixBulkString) { - case ResultType.Integer: - case ResultType.SimpleString: - case ResultType.BulkString: - SetResult(message, result.GetString()); + case RespPrefix.Integer: + case RespPrefix.SimpleString: + case RespPrefix.BulkString: + SetResult(message, reader.ReadString()); return true; - case ResultType.Array: - var arr = result.GetItems(); - if (arr.Length == 1) + case RespPrefix.Array when reader.TryReadNext() && reader.IsScalar: + // treat an array of 1 like a single reply + var value = reader.ReadString(); + if (!reader.TryReadNext()) { - SetResult(message, arr[0].GetString()); + SetResult(message, value); return true; } break; diff --git a/tests/StackExchange.Redis.Tests/ResultProcessorUnitTests.cs b/tests/StackExchange.Redis.Tests/ResultProcessorUnitTests.cs index cca08e057..22f37ce0b 100644 --- a/tests/StackExchange.Redis.Tests/ResultProcessorUnitTests.cs +++ b/tests/StackExchange.Redis.Tests/ResultProcessorUnitTests.cs @@ -99,6 +99,48 @@ public class ResultProcessorUnitTests(ITestOutputHelper log) [InlineData(ATTRIB_FOO_BAR + ",3.14\r\n", 3.14)] public void NullableDouble(string resp, double? value) => Assert.Equal(value, Execute(resp, ResultProcessor.NullableDouble)); + [Theory] + [InlineData("_\r\n", false)] // null = false + [InlineData(":0\r\n", false)] + [InlineData(":1\r\n", true)] + [InlineData("#f\r\n", false)] + [InlineData("#t\r\n", true)] + [InlineData("+OK\r\n", true)] + [InlineData(ATTRIB_FOO_BAR + ":1\r\n", true)] + public void Boolean(string resp, bool value) => Assert.Equal(value, Execute(resp, ResultProcessor.Boolean)); + + [Theory] + [InlineData("*1\r\n:1\r\n", true)] // SCRIPT EXISTS returns array + [InlineData("*1\r\n:0\r\n", false)] + [InlineData(ATTRIB_FOO_BAR + "*1\r\n:1\r\n", true)] + public void BooleanArrayOfOne(string resp, bool value) => Assert.Equal(value, Execute(resp, ResultProcessor.Boolean)); + + [Theory] + [InlineData("*0\r\n")] // empty array + [InlineData("*2\r\n:1\r\n:0\r\n")] // two elements + [InlineData("*1\r\n*1\r\n:1\r\n")] // nested array (not scalar) + public void FailingBooleanArrayOfNonOne(string resp) => ExecuteUnexpected(resp, ResultProcessor.Boolean); + + [Theory] + [InlineData("$5\r\nhello\r\n", "hello")] + [InlineData("+world\r\n", "world")] + [InlineData(":42\r\n", "42")] + [InlineData("$-1\r\n", null)] + [InlineData(ATTRIB_FOO_BAR + "$3\r\nfoo\r\n", "foo")] + public void String(string resp, string? value) => Assert.Equal(value, Execute(resp, ResultProcessor.String)); + + [Theory] + [InlineData("*1\r\n$3\r\nbar\r\n", "bar")] + [InlineData(ATTRIB_FOO_BAR + "*1\r\n$3\r\nbar\r\n", "bar")] + public void StringArrayOfOne(string resp, string? value) => Assert.Equal(value, Execute(resp, ResultProcessor.String)); + + [Theory] + [InlineData("*-1\r\n")] // null array + [InlineData("*0\r\n")] // empty array + [InlineData("*2\r\n$3\r\nfoo\r\n$3\r\nbar\r\n")] // two elements + [InlineData("*1\r\n*1\r\n$3\r\nfoo\r\n")] // nested array (not scalar) + public void FailingStringArrayOfNonOne(string resp) => ExecuteUnexpected(resp, ResultProcessor.String); + [return: NotNullIfNotNull(nameof(array))] protected static string? Join(T[]? array, string separator = ",") {