diff --git a/.gitignore b/.gitignore index 6a2470a2..9fe843ca 100644 --- a/.gitignore +++ b/.gitignore @@ -29,3 +29,11 @@ jepsen/.lein-* jepsen/.nrepl-port .m2/ jepsen/store/ + +# Jepsen local SSH keys (generated locally; never commit) +jepsen/docker/id_rsa +jepsen/.ssh/ + +# Build and lint cache directories +.cache/ +.golangci-cache/ diff --git a/Dockerfile b/Dockerfile index 429364b9..836fa997 100644 --- a/Dockerfile +++ b/Dockerfile @@ -3,7 +3,7 @@ FROM golang:latest AS build WORKDIR $GOPATH/src/app COPY . . -RUN CGO_ENABLED=0 go build -o /app main.go +RUN CGO_ENABLED=0 go build -o /app . FROM gcr.io/distroless/static:latest COPY --from=build /app /app diff --git a/adapter/dynamodb.go b/adapter/dynamodb.go index 857701ab..7154ba31 100644 --- a/adapter/dynamodb.go +++ b/adapter/dynamodb.go @@ -32,7 +32,7 @@ type DynamoDBServer struct { httpServer *http.Server } -func NewDynamoDBServer(listen net.Listener, st store.MVCCStore, coordinate *kv.Coordinate) *DynamoDBServer { +func NewDynamoDBServer(listen net.Listener, st store.MVCCStore, coordinate kv.Coordinator) *DynamoDBServer { d := &DynamoDBServer{ listen: listen, store: st, diff --git a/adapter/grpc.go b/adapter/grpc.go index b4cbc7cb..efd93b58 100644 --- a/adapter/grpc.go +++ b/adapter/grpc.go @@ -28,7 +28,7 @@ type GRPCServer struct { pb.UnimplementedTransactionalKVServer } -func NewGRPCServer(store store.MVCCStore, coordinate *kv.Coordinate) *GRPCServer { +func NewGRPCServer(store store.MVCCStore, coordinate kv.Coordinator) *GRPCServer { return &GRPCServer{ log: slog.New(slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{ Level: slog.LevelWarn, @@ -45,7 +45,7 @@ func (r GRPCServer) RawGet(ctx context.Context, req *pb.RawGetRequest) (*pb.RawG readTS = snapshotTS(r.coordinator.Clock(), r.store) } - if r.coordinator.IsLeader() { + if r.coordinator.IsLeaderForKey(req.Key) { v, err := r.store.GetAt(ctx, req.Key, readTS) if err != nil { switch { @@ -83,7 +83,7 @@ func (r GRPCServer) RawGet(ctx context.Context, req *pb.RawGetRequest) (*pb.RawG } func (r GRPCServer) tryLeaderGet(key []byte) ([]byte, error) { - addr := r.coordinator.RaftLeader() + addr := r.coordinator.RaftLeaderForKey(key) if addr == "" { return nil, ErrLeaderNotFound } diff --git a/adapter/redis.go b/adapter/redis.go index ff233e83..9d28f83b 100644 --- a/adapter/redis.go +++ b/adapter/redis.go @@ -21,19 +21,34 @@ import ( "google.golang.org/grpc/credentials/insecure" ) +const ( + cmdGet = "GET" + cmdSet = "SET" + cmdDel = "DEL" + cmdExists = "EXISTS" + cmdPing = "PING" + cmdKeys = "KEYS" + cmdMulti = "MULTI" + cmdExec = "EXEC" + cmdDiscard = "DISCARD" + cmdLRange = "LRANGE" + cmdRPush = "RPUSH" + minKeyedArgs = 2 +) + //nolint:mnd var argsLen = map[string]int{ - "GET": 2, - "SET": 3, - "DEL": 2, - "EXISTS": 2, - "PING": 1, - "KEYS": 2, - "MULTI": 1, - "EXEC": 1, - "DISCARD": 1, - "LRANGE": 4, - "RPUSH": -3, // negative means minimum number of args + cmdGet: 2, + cmdSet: 3, + cmdDel: 2, + cmdExists: 2, + cmdPing: 1, + cmdKeys: 2, + cmdMulti: 1, + cmdExec: 1, + cmdDiscard: 1, + cmdLRange: 4, + cmdRPush: -3, // negative means minimum number of args } type RedisServer struct { @@ -72,7 +87,7 @@ type redisResult struct { err error } -func NewRedisServer(listen net.Listener, store store.MVCCStore, coordinate *kv.Coordinate, leaderRedis map[raft.ServerAddress]string) *RedisServer { +func NewRedisServer(listen net.Listener, store store.MVCCStore, coordinate kv.Coordinator, leaderRedis map[raft.ServerAddress]string) *RedisServer { r := &RedisServer{ listen: listen, store: store, @@ 
-82,17 +97,17 @@ func NewRedisServer(listen net.Listener, store store.MVCCStore, coordinate *kv.C } r.route = map[string]func(conn redcon.Conn, cmd redcon.Command){ - "PING": r.ping, - "SET": r.set, - "GET": r.get, - "DEL": r.del, - "EXISTS": r.exists, - "KEYS": r.keys, - "MULTI": r.multi, - "EXEC": r.exec, - "DISCARD": r.discard, - "RPUSH": r.rpush, - "LRANGE": r.lrange, + cmdPing: r.ping, + cmdSet: r.set, + cmdGet: r.get, + cmdDel: r.del, + cmdExists: r.exists, + cmdKeys: r.keys, + cmdMulti: r.multi, + cmdExec: r.exec, + cmdDiscard: r.discard, + cmdRPush: r.rpush, + cmdLRange: r.lrange, } return r @@ -129,7 +144,7 @@ func (r *RedisServer) Run() error { } name := strings.ToUpper(string(cmd.Args[0])) - if state.inTxn && name != "EXEC" && name != "DISCARD" && name != "MULTI" { + if state.inTxn && name != cmdExec && name != cmdDiscard && name != cmdMulti { state.queue = append(state.queue, cmd) conn.WriteString("QUEUED") return @@ -255,7 +270,7 @@ func (r *RedisServer) del(conn redcon.Conn, cmd redcon.Command) { } func (r *RedisServer) exists(conn redcon.Conn, cmd redcon.Command) { - if !r.coordinator.IsLeader() { + if !r.coordinator.IsLeaderForKey(cmd.Args[1]) { res, err := r.proxyExists(cmd.Args[1]) if err != nil { conn.WriteError(err.Error()) @@ -265,7 +280,7 @@ func (r *RedisServer) exists(conn redcon.Conn, cmd redcon.Command) { return } - if err := r.coordinator.VerifyLeader(); err != nil { + if err := r.coordinator.VerifyLeaderForKey(cmd.Args[1]); err != nil { conn.WriteError(err.Error()) return } @@ -517,17 +532,17 @@ func (t *txnContext) listLength(st *listTxnState) int64 { func (t *txnContext) apply(cmd redcon.Command) (redisResult, error) { switch strings.ToUpper(string(cmd.Args[0])) { - case "SET": + case cmdSet: return t.applySet(cmd) - case "DEL": + case cmdDel: return t.applyDel(cmd) - case "GET": + case cmdGet: return t.applyGet(cmd) - case "EXISTS": + case cmdExists: return t.applyExists(cmd) - case "RPUSH": + case cmdRPush: return t.applyRPush(cmd) - case "LRANGE": + case cmdLRange: return t.applyLRange(cmd) default: return redisResult{}, errors.WithStack(errors.Newf("ERR unsupported command '%s'", cmd.Args[0])) @@ -770,14 +785,7 @@ func (t *txnContext) buildListElems() ([]*kv.Elem[kv.OP], error) { } func (r *RedisServer) runTransaction(queue []redcon.Command) ([]redisResult, error) { - if err := r.coordinator.VerifyLeader(); err != nil { - return nil, errors.WithStack(err) - } - - startTS := r.coordinator.Clock().Next() - if last := r.store.LastCommitTS(); last > startTS { - startTS = last - } + startTS := r.txnStartTS(queue) ctx := &txnContext{ server: r, @@ -802,6 +810,57 @@ func (r *RedisServer) runTransaction(queue []redcon.Command) ([]redisResult, err return results, nil } +func (r *RedisServer) txnStartTS(queue []redcon.Command) uint64 { + maxTS := r.maxLatestCommitTS(queue) + if r.coordinator != nil && r.coordinator.Clock() != nil && maxTS > 0 { + r.coordinator.Clock().Observe(maxTS) + } + if r.coordinator == nil || r.coordinator.Clock() == nil { + return maxTS + } + return r.coordinator.Clock().Next() +} + +func (r *RedisServer) maxLatestCommitTS(queue []redcon.Command) uint64 { + var maxTS uint64 + if r.store == nil { + return maxTS + } + seen := make(map[string]struct{}) + for _, cmd := range queue { + if len(cmd.Args) < minKeyedArgs { + continue + } + name := strings.ToUpper(string(cmd.Args[0])) + switch name { + case cmdSet, cmdGet, cmdDel, cmdExists, cmdRPush, cmdLRange: + key := cmd.Args[1] + r.bumpLatestCommitTS(&maxTS, key, seen) + // Also account for list 
metadata keys to avoid stale typing decisions. + r.bumpLatestCommitTS(&maxTS, listMetaKey(key), seen) + } + } + return maxTS +} + +func (r *RedisServer) bumpLatestCommitTS(maxTS *uint64, key []byte, seen map[string]struct{}) { + if len(key) == 0 { + return + } + k := string(key) + if _, ok := seen[k]; ok { + return + } + seen[k] = struct{}{} + latest, exists, err := r.store.LatestCommitTS(context.Background(), key) + if err != nil || !exists { + return + } + if latest > *maxTS { + *maxTS = latest + } +} + func (r *RedisServer) proxyExec(conn redcon.Conn, queue []redcon.Command) error { leader := r.coordinator.RaftLeader() if leader == "" { @@ -909,13 +968,13 @@ func newProxyCmd(name string, args []string, ctx context.Context) redis.Cmder { } switch name { - case "SET": + case cmdSet: return redis.NewStatusCmd(ctx, argv...) - case "DEL", "EXISTS", "RPUSH": + case cmdDel, cmdExists, cmdRPush: return redis.NewIntCmd(ctx, argv...) - case "GET": + case cmdGet: return redis.NewStringCmd(ctx, argv...) - case "LRANGE": + case cmdLRange: return redis.NewStringSliceCmd(ctx, argv...) default: return redis.NewCmd(ctx, argv...) @@ -1082,11 +1141,11 @@ func (r *RedisServer) fetchListRange(ctx context.Context, key []byte, meta store func (r *RedisServer) rangeList(key []byte, startRaw, endRaw []byte) ([]string, error) { readTS := r.readTS() - if !r.coordinator.IsLeader() { + if !r.coordinator.IsLeaderForKey(key) { return r.proxyLRange(key, startRaw, endRaw) } - if err := r.coordinator.VerifyLeader(); err != nil { + if err := r.coordinator.VerifyLeaderForKey(key); err != nil { return nil, errors.WithStack(err) } @@ -1116,7 +1175,7 @@ func (r *RedisServer) rangeList(key []byte, startRaw, endRaw []byte) ([]string, } func (r *RedisServer) proxyLRange(key []byte, startRaw, endRaw []byte) ([]string, error) { - leader := r.coordinator.RaftLeader() + leader := r.coordinator.RaftLeaderForKey(key) if leader == "" { return nil, ErrLeaderNotFound } @@ -1142,7 +1201,7 @@ func (r *RedisServer) proxyLRange(key []byte, startRaw, endRaw []byte) ([]string } func (r *RedisServer) proxyRPush(key []byte, values [][]byte) (int64, error) { - leader := r.coordinator.RaftLeader() + leader := r.coordinator.RaftLeaderForKey(key) if leader == "" { return 0, ErrLeaderNotFound } @@ -1171,7 +1230,7 @@ func parseInt(b []byte) (int, error) { // tryLeaderGet proxies a GET to the current Raft leader, returning the value and // whether the proxy succeeded. 
func (r *RedisServer) tryLeaderGetAt(key []byte, ts uint64) ([]byte, error) { - addr := r.coordinator.RaftLeader() + addr := r.coordinator.RaftLeaderForKey(key) if addr == "" { return nil, ErrLeaderNotFound } @@ -1195,8 +1254,8 @@ func (r *RedisServer) tryLeaderGetAt(key []byte, ts uint64) ([]byte, error) { } func (r *RedisServer) readValueAt(key []byte, readTS uint64) ([]byte, error) { - if r.coordinator.IsLeader() { - if err := r.coordinator.VerifyLeader(); err != nil { + if r.coordinator.IsLeaderForKey(key) { + if err := r.coordinator.VerifyLeaderForKey(key); err != nil { return nil, errors.WithStack(err) } v, err := r.store.GetAt(context.Background(), key, readTS) @@ -1210,7 +1269,7 @@ func (r *RedisServer) rpush(conn redcon.Conn, cmd redcon.Command) { var length int64 var err error - if r.coordinator.IsLeader() { + if r.coordinator.IsLeaderForKey(cmd.Args[1]) { length, err = r.listRPush(ctx, cmd.Args[1], cmd.Args[2:]) } else { length, err = r.proxyRPush(cmd.Args[1], cmd.Args[2:]) diff --git a/adapter/redis_proxy.go b/adapter/redis_proxy.go index dfed0125..a5df7d4c 100644 --- a/adapter/redis_proxy.go +++ b/adapter/redis_proxy.go @@ -8,7 +8,7 @@ import ( ) func (r *RedisServer) proxyExists(key []byte) (int, error) { - leader := r.coordinator.RaftLeader() + leader := r.coordinator.RaftLeaderForKey(key) if leader == "" { return 0, ErrLeaderNotFound } diff --git a/distribution/engine.go b/distribution/engine.go index bff6abe7..986f2fa3 100644 --- a/distribution/engine.go +++ b/distribution/engine.go @@ -121,6 +121,36 @@ func (e *Engine) Stats() []Route { return stats } +// GetIntersectingRoutes returns all routes whose key ranges intersect with [start, end). +// A route [rStart, rEnd) intersects with [start, end) if: +// - rStart < end (or end is nil, meaning unbounded scan) +// - start < rEnd (or rEnd is nil, meaning unbounded route) +func (e *Engine) GetIntersectingRoutes(start, end []byte) []Route { + e.mu.RLock() + defer e.mu.RUnlock() + + var result []Route + for _, r := range e.routes { + // Check if route intersects with [start, end) + // Route ends before scan starts: rEnd != nil && rEnd <= start + if r.End != nil && bytes.Compare(r.End, start) <= 0 { + continue + } + // Route starts at or after scan ends: end != nil && rStart >= end + if end != nil && bytes.Compare(r.Start, end) >= 0 { + continue + } + // Route intersects with scan range + result = append(result, Route{ + Start: cloneBytes(r.Start), + End: cloneBytes(r.End), + GroupID: r.GroupID, + Load: atomic.LoadUint64(&r.Load), + }) + } + return result +} + func (e *Engine) routeIndex(key []byte) int { if len(e.routes) == 0 { return -1 diff --git a/distribution/engine_test.go b/distribution/engine_test.go index 69ee0373..8025de3d 100644 --- a/distribution/engine_test.go +++ b/distribution/engine_test.go @@ -160,3 +160,74 @@ func assertRange(t *testing.T, r Route, start, end []byte) { t.Errorf("expected range [%q, %q), got [%q, %q]", start, end, r.Start, r.End) } } + +func TestEngineGetIntersectingRoutes(t *testing.T) { + e := NewEngine() + e.UpdateRoute([]byte("a"), []byte("m"), 1) + e.UpdateRoute([]byte("m"), []byte("z"), 2) + e.UpdateRoute([]byte("z"), nil, 3) + + cases := []struct { + name string + start []byte + end []byte + groups []uint64 + }{ + { + name: "scan in first range", + start: []byte("b"), + end: []byte("d"), + groups: []uint64{1}, + }, + { + name: "scan across first two ranges", + start: []byte("k"), + end: []byte("p"), + groups: []uint64{1, 2}, + }, + { + name: "scan across all ranges", + start: []byte("a"), 
+ end: nil, + groups: []uint64{1, 2, 3}, + }, + { + name: "scan in last unbounded range", + start: []byte("za"), + end: nil, + groups: []uint64{3}, + }, + { + name: "scan before first range", + start: []byte("0"), + end: []byte("9"), + groups: []uint64{}, + }, + { + name: "scan at boundary", + start: []byte("m"), + end: []byte("n"), + groups: []uint64{2}, + }, + { + name: "scan ending at boundary", + start: []byte("k"), + end: []byte("m"), + groups: []uint64{1}, + }, + } + + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + routes := e.GetIntersectingRoutes(c.start, c.end) + if len(routes) != len(c.groups) { + t.Fatalf("expected %d routes, got %d", len(c.groups), len(routes)) + } + for i, expectedGroup := range c.groups { + if routes[i].GroupID != expectedGroup { + t.Errorf("route %d: expected group %d, got %d", i, expectedGroup, routes[i].GroupID) + } + } + }) + } +} diff --git a/jepsen/Vagrantfile b/jepsen/Vagrantfile index d96d3bf5..57fa855f 100644 --- a/jepsen/Vagrantfile +++ b/jepsen/Vagrantfile @@ -1,3 +1,5 @@ +require "fileutils" + NODES = { ctrl: "192.168.56.10", n1: "192.168.56.11", @@ -7,6 +9,19 @@ NODES = { n5: "192.168.56.15" }.freeze +KEY_DIR = File.join(__dir__, ".ssh") +CTRL_KEY = File.join(KEY_DIR, "ctrl_id_rsa") +CTRL_PUB = "#{CTRL_KEY}.pub" + +unless File.exist?(CTRL_KEY) + FileUtils.mkdir_p(KEY_DIR) + unless system("ssh-keygen", "-t", "rsa", "-b", "2048", "-N", "", "-f", CTRL_KEY) + raise "failed to generate Jepsen SSH key at #{CTRL_KEY}" + end +end + +CTRL_PUB_KEY = File.read(CTRL_PUB).strip + Vagrant.configure("2") do |config| config.ssh.insert_key = false #config.vm.box = "debian/bookworm64" @@ -37,7 +52,7 @@ Vagrant.configure("2") do |config| node.vm.synced_folder ".", "/vagrant", disabled: true end - node.vm.provision "shell", path: "provision/base.sh", args: name.to_s + node.vm.provision "shell", path: "provision/base.sh", args: [name.to_s, CTRL_PUB_KEY] end end end diff --git a/jepsen/docker/run-in-docker.sh b/jepsen/docker/run-in-docker.sh new file mode 100644 index 00000000..730c09c8 --- /dev/null +++ b/jepsen/docker/run-in-docker.sh @@ -0,0 +1,60 @@ +#!/usr/bin/env bash +set -euo pipefail + +# Copy source to writable area +mkdir -p /root/elastickv +rsync -a /jepsen-ro/ /root/elastickv/ --exclude .git --exclude jepsen/target --exclude jepsen/tmp-home + +cd /root/elastickv/jepsen + +# Install Go +if ! command -v go >/dev/null 2>&1; then + GO_VERSION=1.25.5 + ARCH="amd64" # Assuming amd64 for now, or detect + if [ "$(uname -m)" = "aarch64" ]; then ARCH="arm64"; fi + + curl -fsSL "https://go.dev/dl/go${GO_VERSION}.linux-${ARCH}.tar.gz" -o go.tar.gz + tar -C /usr/local -xzf go.tar.gz + export PATH=$PATH:/usr/local/go/bin +fi + +# Install Leiningen +if ! command -v lein >/dev/null 2>&1; then + curl -L https://raw.githubusercontent.com/technomancy/leiningen/stable/bin/lein > /usr/local/bin/lein + chmod +x /usr/local/bin/lein +fi + +# Generate or install SSH key for control node to connect to others +if [ ! -f /root/.ssh/id_rsa ]; then + mkdir -p /root/.ssh + if [ -n "${JEPSEN_SSH_PRIVATE_KEY:-}" ]; then + printf "%s" "${JEPSEN_SSH_PRIVATE_KEY}" > /root/.ssh/id_rsa + elif [ -n "${JEPSEN_SSH_PRIVATE_KEY_PATH:-}" ] && [ -f "${JEPSEN_SSH_PRIVATE_KEY_PATH}" ]; then + cp "${JEPSEN_SSH_PRIVATE_KEY_PATH}" /root/.ssh/id_rsa + elif [ -f /jepsen-ro/jepsen/docker/id_rsa ]; then + # Backward-compatible path (local, uncommitted key file) + cp /jepsen-ro/jepsen/docker/id_rsa /root/.ssh/id_rsa + else + if ! 
command -v ssh-keygen >/dev/null 2>&1; then + apt-get update -y + apt-get install -y --no-install-recommends openssh-client + fi + ssh-keygen -t rsa -b 2048 -N "" -f /root/.ssh/id_rsa + fi + chmod 600 /root/.ssh/id_rsa + # Disable strict host checking + echo "Host *" > /root/.ssh/config + echo " StrictHostKeyChecking no" >> /root/.ssh/config + echo " UserKnownHostsFile /dev/null" >> /root/.ssh/config + echo " User vagrant" >> /root/.ssh/config +fi + +# Run test +# Nodes are reachable by hostname (n1, n2...) in docker network +export LEIN_ROOT=true +lein run -m elastickv.redis-workload \ + --nodes n1,n2,n3,n4,n5 \ + --time-limit 60 \ + --rate 10 \ + --faults partition,kill,clock \ + --concurrency 10 diff --git a/jepsen/docker/ssh_config b/jepsen/docker/ssh_config new file mode 100644 index 00000000..388f5211 --- /dev/null +++ b/jepsen/docker/ssh_config @@ -0,0 +1,44 @@ +Host n1 + HostName 127.0.0.1 + User vagrant + Port 2221 + IdentityFile ~/.ssh/id_rsa + StrictHostKeyChecking no + UserKnownHostsFile /dev/null + LogLevel ERROR + +Host n2 + HostName 127.0.0.1 + User vagrant + Port 2222 + IdentityFile ~/.ssh/id_rsa + StrictHostKeyChecking no + UserKnownHostsFile /dev/null + LogLevel ERROR + +Host n3 + HostName 127.0.0.1 + User vagrant + Port 2223 + IdentityFile ~/.ssh/id_rsa + StrictHostKeyChecking no + UserKnownHostsFile /dev/null + LogLevel ERROR + +Host n4 + HostName 127.0.0.1 + User vagrant + Port 2224 + IdentityFile ~/.ssh/id_rsa + StrictHostKeyChecking no + UserKnownHostsFile /dev/null + LogLevel ERROR + +Host n5 + HostName 127.0.0.1 + User vagrant + Port 2225 + IdentityFile ~/.ssh/id_rsa + StrictHostKeyChecking no + UserKnownHostsFile /dev/null + LogLevel ERROR diff --git a/jepsen/provision/base.sh b/jepsen/provision/base.sh index 567d562f..f7f44313 100755 --- a/jepsen/provision/base.sh +++ b/jepsen/provision/base.sh @@ -2,6 +2,7 @@ set -euo pipefail ROLE="${1:-db}" +PUBKEY="${2:-}" echo "[jepsen] provisioning role=${ROLE}" sudo apt-get update -y @@ -44,38 +45,20 @@ if [ "$ROLE" = "ctrl" ]; then echo 'export PATH=$PATH:/usr/local/go/bin:$HOME/go/bin' | sudo tee /etc/profile.d/go.sh >/dev/null if [ ! 
-f /home/vagrant/.ssh/id_rsa ]; then - cat <<'KEY' > /home/vagrant/.ssh/id_rsa ------BEGIN RSA PRIVATE KEY----- -MIIEogIBAAKCAQEA6NF8iallvQVp22WDkTkyrtvp9eWW6A8YVr+kz4TjGYe7gHzI -w+niNltGEFHzD8+v1I2YJ6oXevct1YeS0o9HZyN1Q9qgCgzUFtdOKLv6IedplqoP -kcmF0aYet2PkEDo3MlTBckFXPITAMzF8dJSIFo9D8HfdOV0IAdx4O7PtixWKn5y2 -hMNG0zQPyUecp4pzC6kivAIhyfHilFR61RGL+GPXQ2MWZWFYbAGjyiYJnAmCP3NO -Td0jMZEnDkbUvxhMmBYSdETk1rRgm+R4LOzFUGaHqHDLKLX+FIPKcF96hrucXzcW -yLbIbEgE98OHlnVYCzRdK8jlqm8tehUc9c9WhQIBIwKCAQEA4iqWPJXtzZA68mKd -ELs4jJsdyky+ewdZeNds5tjcnHU5zUYE25K+ffJED9qUWICcLZDc81TGWjHyAqD1 -Bw7XpgUwFgeUJwUlzQurAv+/ySnxiwuaGJfhFM1CaQHzfXphgVml+fZUvnJUTvzf -TK2Lg6EdbUE9TarUlBf/xPfuEhMSlIE5keb/Zz3/LUlRg8yDqz5w+QWVJ4utnKnK -iqwZN0mwpwU7YSyJhlT4YV1F3n4YjLswM5wJs2oqm0jssQu/BT0tyEXNDYBLEF4A -sClaWuSJ2kjq7KhrrYXzagqhnSei9ODYFShJu8UWVec3Ihb5ZXlzO6vdNQ1J9Xsf -4m+2ywKBgQD6qFxx/Rv9CNN96l/4rb14HKirC2o/orApiHmHDsURs5rUKDx0f9iP -cXN7S1uePXuJRK/5hsubaOCx3Owd2u9gD6Oq0CsMkE4CUSiJcYrMANtx54cGH7Rk -EjFZxK8xAv1ldELEyxrFqkbE4BKd8QOt414qjvTGyAK+OLD3M2QdCQKBgQDtx8pN -CAxR7yhHbIWT1AH66+XWN8bXq7l3RO/ukeaci98JfkbkxURZhtxV/HHuvUhnPLdX -3TwygPBYZFNo4pzVEhzWoTtnEtrFueKxyc3+LjZpuo+mBlQ6ORtfgkr9gBVphXZG -YEzkCD3lVdl8L4cw9BVpKrJCs1c5taGjDgdInQKBgHm/fVvv96bJxc9x1tffXAcj -3OVdUN0UgXNCSaf/3A/phbeBQe9xS+3mpc4r6qvx+iy69mNBeNZ0xOitIjpjBo2+ -dBEjSBwLk5q5tJqHmy/jKMJL4n9ROlx93XS+njxgibTvU6Fp9w+NOFD/HvxB3Tcz -6+jJF85D5BNAG3DBMKBjAoGBAOAxZvgsKN+JuENXsST7F89Tck2iTcQIT8g5rwWC -P9Vt74yboe2kDT531w8+egz7nAmRBKNM751U/95P9t88EDacDI/Z2OwnuFQHCPDF -llYOUI+SpLJ6/vURRbHSnnn8a/XG+nzedGH5JGqEJNQsz+xT2axM0/W/CRknmGaJ -kda/AoGANWrLCz708y7VYgAtW2Uf1DPOIYMdvo6fxIB5i9ZfISgcJ/bbCUkFrhoH -+vq/5CIWxCPp0f85R4qxxQ5ihxJ0YDQT9Jpx4TMss4PSavPaBH3RXow5Ohe+bYoQ -NE5OgEXk2wVfZczCZpigBKbKZHNYcelXtTt/nP3rsCuGcM4h53s= ------END RSA PRIVATE KEY----- -KEY + if [ -f /home/vagrant/elastickv/jepsen/.ssh/ctrl_id_rsa ]; then + cp /home/vagrant/elastickv/jepsen/.ssh/ctrl_id_rsa /home/vagrant/.ssh/id_rsa + else + if ! command -v ssh-keygen >/dev/null 2>&1; then + sudo apt-get install -y --no-install-recommends openssh-client + fi + ssh-keygen -t rsa -b 2048 -N "" -f /home/vagrant/.ssh/id_rsa + fi chmod 600 /home/vagrant/.ssh/id_rsa chown vagrant:vagrant /home/vagrant/.ssh/id_rsa fi + if [ -z "${PUBKEY}" ] && [ -f /home/vagrant/.ssh/id_rsa.pub ]; then + PUBKEY="$(cat /home/vagrant/.ssh/id_rsa.pub)" + fi cat <<'EOF' > /home/vagrant/.ssh/config Host n1 n2 n3 n4 n5 User vagrant @@ -87,10 +70,12 @@ EOF chown vagrant:vagrant /home/vagrant/.ssh/config fi -# authorize the same key on all nodes -cat <<'PUB' >> /home/vagrant/.ssh/authorized_keys -ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQDo0XyJqWW9BWnbZYOROSyu2+n15ZbrgPGFa/pM+E4xmHu4B8yMPp4jbWRhR8w/Pr9SNmCeqF3r3LdWHktKPR2cjduPaoAoM1BbXTii7+iHnaZaqD5HJhXQhr3Y+QQOjcYVMFyQU8hMAzMF8dJSIFo9D8HfdOV0IAdx4O7PtixWKn5y2hMNG0zQPyUecp4pzC6kivAIhyfHilFR61RGL+GPXQ2MWZWFYbAGjyiYJnAmCP3NOTd0jMZEnDkbUvxhMmBYSdETk1rRgm+R4LOzFUGaHqHDLKLX+FIPKcF96hrucXzcWyLbIbEgE98OHlnVYCzRdK8jlqm8tehUc9c9WhQ== vagrant insecure public key -PUB +touch /home/vagrant/.ssh/authorized_keys +if [ -n "${PUBKEY}" ]; then + if ! 
grep -Fq "${PUBKEY}" /home/vagrant/.ssh/authorized_keys; then + echo "${PUBKEY}" >> /home/vagrant/.ssh/authorized_keys + fi +fi chown vagrant:vagrant /home/vagrant/.ssh/authorized_keys chmod 600 /home/vagrant/.ssh/authorized_keys diff --git a/jepsen/src/elastickv/db.clj b/jepsen/src/elastickv/db.clj index 63917a47..6248bd3c 100644 --- a/jepsen/src/elastickv/db.clj +++ b/jepsen/src/elastickv/db.clj @@ -71,19 +71,42 @@ (get port-spec node) port-spec)) -(defn- build-raft-redis-map [nodes grpc-port redis-port] - (->> nodes - (map (fn [n] - (let [g (node-addr n (port-for grpc-port n)) - r (node-addr n (port-for redis-port n))] - (str g "=" r)))) +(defn- group-ids [raft-groups] + (->> (keys raft-groups) + (sort))) + +(defn- group-addr [node raft-groups group-id] + (node-addr node (port-for (get raft-groups group-id) node))) + +(defn- build-raft-groups-arg [node raft-groups] + (->> (group-ids raft-groups) + (map (fn [gid] + (str gid "=" (group-addr node raft-groups gid)))) (clojure.string/join ","))) +(defn- build-raft-redis-map [nodes grpc-port redis-port raft-groups] + (let [groups (when (seq raft-groups) (group-ids raft-groups))] + (->> nodes + (mapcat (fn [n] + (let [redis (node-addr n (port-for redis-port n))] + (if (seq groups) + (map (fn [gid] + (str (group-addr n raft-groups gid) "=" redis)) + groups) + [(str (node-addr n (port-for grpc-port n)) "=" redis)])))) + (clojure.string/join ",")))) + (defn- start-node! - [test node {:keys [bootstrap-node grpc-port redis-port data-dir]}] - (let [grpc (node-addr node (port-for grpc-port node)) + [test node {:keys [bootstrap-node grpc-port redis-port data-dir raft-groups shard-ranges]}] + (when (and (seq raft-groups) + (> (count raft-groups) 1) + (nil? shard-ranges)) + (throw (ex-info "shard-ranges is required when raft-groups has multiple entries" {}))) + (let [grpc (if (seq raft-groups) + (group-addr node raft-groups (first (group-ids raft-groups))) + (node-addr node (port-for grpc-port node))) redis (node-addr node (port-for redis-port node)) - raft-redis-map (build-raft-redis-map (:nodes test) grpc-port redis-port) + raft-redis-map (build-raft-redis-map (:nodes test) grpc-port redis-port raft-groups) bootstrap? (= node bootstrap-node) args (cond-> [server-bin "--address" grpc @@ -91,6 +114,8 @@ "--raftId" (name node) "--raftDataDir" data-dir "--raftRedisMap" raft-redis-map] + (seq raft-groups) (conj "--raftGroups" (build-raft-groups-arg node raft-groups)) + (seq shard-ranges) (conj "--shardRanges" shard-ranges) bootstrap? (conj "--raftBootstrap"))] (c/on node (c/su @@ -111,10 +136,12 @@ (defn- wait-for-grpc! "Wait until the given node listens on grpc port." [node grpc-port] - (c/on node - (c/exec :bash "-c" - (format "for i in $(seq 1 60); do if nc -z -w 1 %s %s; then exit 0; fi; sleep 1; done; echo 'Timed out waiting for %s:%s'; exit 1" - (name node) grpc-port (name node) grpc-port)))) + (let [ports (if (sequential? grpc-port) grpc-port [grpc-port])] + (doseq [p ports] + (c/on node + (c/exec :bash "-c" + "for i in $(seq 1 60); do if nc -z -w 1 $1 $2; then exit 0; fi; sleep 1; done; echo \\\"Timed out waiting for $1:$2\\\"; exit 1" + "--" (name node) (str p)))))) (defn- join-node! "Join peer into cluster via raftadmin, executed on bootstrap node." 
@@ -138,14 +165,26 @@ :bootstrap-node (first (:nodes test))} opts)) (when (= node (first (:nodes test))) - (let [leader (node-addr node (or (:grpc-port opts) 50051))] + (let [raft-groups (:raft-groups opts) + grpc-port (or (:grpc-port opts) 50051) + group-ids (when (seq raft-groups) (group-ids raft-groups))] (doseq [peer (rest (:nodes test))] (util/await-fn (fn [] (try - (wait-for-grpc! peer (or (:grpc-port opts) 50051)) - (join-node! node leader (name peer) - (node-addr peer (or (:grpc-port opts) 50051))) + (if (seq raft-groups) + (doseq [gid group-ids] + (wait-for-grpc! peer (port-for (get raft-groups gid) peer)) + (join-node! node + (group-addr node raft-groups gid) + (name peer) + (group-addr peer raft-groups gid))) + (do + (wait-for-grpc! peer grpc-port) + (join-node! node + (node-addr node grpc-port) + (name peer) + (node-addr peer grpc-port)))) true (catch Throwable t (warn t "retrying join for" peer) @@ -171,7 +210,10 @@ :redis-port (or (:redis-port opts) 6379) :bootstrap-node (first (:nodes test))} opts)) - (wait-for-grpc! node (or (:grpc-port opts) 50051)) + (if-let [raft-groups (:raft-groups opts)] + (wait-for-grpc! node (map (fn [gid] (port-for (get raft-groups gid) node)) + (group-ids raft-groups))) + (wait-for-grpc! node (or (:grpc-port opts) 50051))) (info "node started" node) this) (kill! [this _test node] @@ -194,6 +236,8 @@ (defn db "Constructs an ElastickvDB with optional opts. - opts: {:grpc-port 50051 :redis-port 6379}" + opts: {:grpc-port 50051 :redis-port 6379 + :raft-groups {1 50051 2 50052} + :shard-ranges \":m=1,m:=2\"}" ([] (->ElastickvDB {})) ([opts] (->ElastickvDB opts))) diff --git a/jepsen/src/elastickv/redis_workload.clj b/jepsen/src/elastickv/redis_workload.clj index f844c1bf..4a5730c6 100644 --- a/jepsen/src/elastickv/redis_workload.clj +++ b/jepsen/src/elastickv/redis_workload.clj @@ -3,11 +3,14 @@ (:require [clojure.string :as str] [clojure.tools.cli :as tools.cli] [elastickv.db :as ekdb] + [jepsen.db :as jdb] [jepsen [client :as client] [core :as jepsen] [generator :as gen] [net :as net]] [jepsen.control :as control] + [jepsen.os :as os] + [jepsen.nemesis :as nemesis] [jepsen.nemesis.combined :as combined] [jepsen.os.debian :as debian] [jepsen.redis.client :as rc] @@ -82,44 +85,79 @@ (let [nodes (or (:nodes opts) default-nodes) redis-ports (or (:redis-ports opts) (repeat (count nodes) (or (:redis-port opts) 6379))) node->port (or (:node->port opts) (ports->node-map redis-ports nodes)) - db (ekdb/db {:grpc-port (or (:grpc-port opts) 50051) - :redis-port node->port}) + local? (:local opts) + db (if local? + jdb/noop + (ekdb/db {:grpc-port (or (:grpc-port opts) 50051) + :redis-port node->port + :raft-groups (:raft-groups opts) + :shard-ranges (:shard-ranges opts)})) rate (double (or (:rate opts) 5)) time-limit (or (:time-limit opts) 30) - faults (normalize-faults (or (:faults opts) [:partition :kill])) - nemesis-p (combined/nemesis-package {:db db - :faults faults - :interval (or (:fault-interval opts) 40)}) + faults (if local? + [] + (normalize-faults (or (:faults opts) [:partition :kill]))) + nemesis-p (when-not local? + (combined/nemesis-package {:db db + :faults faults + :interval (or (:fault-interval opts) 40)})) + nemesis-gen (if nemesis-p + (:generator nemesis-p) + (gen/once {:type :info :f :noop})) workload (elastickv-append-workload (assoc opts :node->port node->port))] (merge workload {:name (or (:name opts) "elastickv-redis-append") :nodes nodes :db db - :os debian/os - :net net/iptables + :redis-host (:redis-host opts) + :os (if local? 
os/noop debian/os) + :net (if local? net/noop net/iptables) :ssh (merge {:username "vagrant" :private-key-path "/home/vagrant/.ssh/id_rsa" :strict-host-key-checking false} + (when local? {:dummy true}) (:ssh opts)) :remote control/ssh - :nemesis (:nemesis nemesis-p) + :nemesis (if nemesis-p + (:nemesis nemesis-p) + nemesis/noop) ; Jepsen 0.3.x can't fressian-serialize some combined final gens; skip. :final-generator nil :concurrency (or (:concurrency opts) 5) :generator (->> (:generator workload) - (gen/nemesis (:generator nemesis-p)) + (gen/nemesis nemesis-gen) (gen/stagger (/ rate)) (gen/time-limit time-limit))})))) (def cli-opts [[nil "--nodes NODES" "Comma separated node names." :default "n1,n2,n3,n4,n5"] + [nil "--local" "Run locally without SSH or nemesis." + :default false] + [nil "--host HOST" "Redis host override for clients." + :default nil] + [nil "--ports PORTS" "Comma separated Redis ports (per node)." + :default nil + :parse-fn (fn [s] + (->> (str/split s #",") + (remove str/blank?) + (mapv #(Integer/parseInt %))))] [nil "--redis-port PORT" "Redis port (applied to all nodes)." :default 6379 :parse-fn #(Integer/parseInt %)] [nil "--grpc-port PORT" "gRPC/Raft port." :default 50051 :parse-fn #(Integer/parseInt %)] + [nil "--raft-groups GROUPS" "Comma separated raft groups (groupID=port,...)" + :parse-fn (fn [s] + (->> (str/split s #",") + (remove str/blank?) + (map (fn [part] + (let [[gid port] (str/split part #"=" 2)] + [(Long/parseLong gid) (Integer/parseInt port)]))) + (into {})))] + [nil "--shard-ranges RANGES" "Shard ranges (start:end=groupID,...)" + :default nil] [nil "--faults LIST" "Comma separated faults (partition,kill,clock)." :default "partition,kill,clock"] [nil "--ssh-key PATH" "SSH private key path." @@ -140,7 +178,13 @@ (defn -main [& args] (let [{:keys [options errors summary]} (tools.cli/parse-opts args cli-opts) - node-list (-> (:nodes options) + default-nodes "n1,n2,n3,n4,n5" + ports (:ports options) + local? (or (:local options) (and (:host options) (seq ports))) + nodes-raw (if (and ports (= (:nodes options) default-nodes)) + (str/join "," (map (fn [i] (str "n" i)) (range 1 (inc (count ports))))) + (:nodes options)) + node-list (-> nodes-raw (str/split #",") (->> (remove str/blank?) vec)) @@ -152,8 +196,13 @@ options (assoc options :nodes node-list :faults faults + :local local? + :redis-host (:host options) + :redis-ports ports :redis-port (:redis-port options) :grpc-port (:grpc-port options) + :raft-groups (:raft-groups options) + :shard-ranges (:shard-ranges options) :ssh {:username (:ssh-user options) :private-key-path (:ssh-key options) :strict-host-key-checking false})] @@ -161,4 +210,6 @@ (:help options) (println summary) (seq errors) (binding [*out* *err*] (println "Error parsing options:" (str/join "; " errors))) + (:local options) (binding [control/*dummy* true] + (jepsen/run! (elastickv-redis-test options))) :else (jepsen/run! 
(elastickv-redis-test options))))) diff --git a/kv/coordinator.go b/kv/coordinator.go index 8340105c..666866bf 100644 --- a/kv/coordinator.go +++ b/kv/coordinator.go @@ -35,6 +35,9 @@ type Coordinator interface { IsLeader() bool VerifyLeader() error RaftLeader() raft.ServerAddress + IsLeaderForKey(key []byte) bool + VerifyLeaderForKey(key []byte) error + RaftLeaderForKey(key []byte) raft.ServerAddress Clock() *HLC } @@ -73,15 +76,40 @@ func (c *Coordinate) Clock() *HLC { return c.clock } +func (c *Coordinate) IsLeaderForKey(_ []byte) bool { + return c.IsLeader() +} + +func (c *Coordinate) VerifyLeaderForKey(_ []byte) error { + return c.VerifyLeader() +} + +func (c *Coordinate) RaftLeaderForKey(_ []byte) raft.ServerAddress { + return c.RaftLeader() +} + func (c *Coordinate) nextStartTS() uint64 { return c.clock.Next() } func (c *Coordinate) dispatchTxn(reqs []*Elem[OP], startTS uint64) (*CoordinateResponse, error) { - var logs []*pb.Request + muts := make([]*pb.Mutation, 0, len(reqs)) for _, req := range reqs { - m := c.toTxnRequests(req, startTS) - logs = append(logs, m...) + muts = append(muts, elemToMutation(req)) + } + logs := []*pb.Request{ + { + IsTxn: true, + Phase: pb.Phase_PREPARE, + Ts: startTS, + Mutations: muts, + }, + { + IsTxn: true, + Phase: pb.Phase_COMMIT, + Ts: startTS, + Mutations: muts, + }, } r, err := c.transactionManager.Commit(logs) @@ -144,63 +172,8 @@ func (c *Coordinate) toRawRequest(req *Elem[OP]) *pb.Request { panic("unreachable") } -func (c *Coordinate) toTxnRequests(req *Elem[OP], startTS uint64) []*pb.Request { - switch req.Op { - case Put: - return []*pb.Request{ - { - IsTxn: true, - Phase: pb.Phase_PREPARE, - Ts: startTS, - Mutations: []*pb.Mutation{ - { - Key: req.Key, - Value: req.Value, - }, - }, - }, - { - IsTxn: true, - Phase: pb.Phase_COMMIT, - Ts: startTS, - Mutations: []*pb.Mutation{ - { - Key: req.Key, - Value: req.Value, - }, - }, - }, - } - - case Del: - return []*pb.Request{ - { - IsTxn: true, - Phase: pb.Phase_PREPARE, - Ts: startTS, - Mutations: []*pb.Mutation{ - { - Key: req.Key, - }, - }, - }, - { - IsTxn: true, - Phase: pb.Phase_COMMIT, - Ts: startTS, - Mutations: []*pb.Mutation{ - { - Key: req.Key, - }, - }, - }, - } - } - - panic("unreachable") -} - var ErrInvalidRequest = errors.New("invalid request") +var ErrLeaderNotFound = errors.New("leader not found") func (c *Coordinate) redirect(reqs *OperationGroup[OP]) (*CoordinateResponse, error) { ctx := context.Background() @@ -224,9 +197,14 @@ func (c *Coordinate) redirect(reqs *OperationGroup[OP]) (*CoordinateResponse, er var requests []*pb.Request if reqs.IsTxn { + muts := make([]*pb.Mutation, 0, len(reqs.Elems)) for _, req := range reqs.Elems { - requests = append(requests, c.toTxnRequests(req, reqs.StartTS)...) 
+ muts = append(muts, elemToMutation(req)) } + requests = append(requests, + &pb.Request{IsTxn: true, Phase: pb.Phase_PREPARE, Ts: reqs.StartTS, Mutations: muts}, + &pb.Request{IsTxn: true, Phase: pb.Phase_COMMIT, Ts: reqs.StartTS, Mutations: muts}, + ) } else { for _, req := range reqs.Elems { requests = append(requests, c.toRawRequest(req)) @@ -260,3 +238,20 @@ func (c *Coordinate) toForwardRequest(reqs []*pb.Request) *pb.ForwardRequest { return out } + +func elemToMutation(req *Elem[OP]) *pb.Mutation { + switch req.Op { + case Put: + return &pb.Mutation{ + Op: pb.Op_PUT, + Key: req.Key, + Value: req.Value, + } + case Del: + return &pb.Mutation{ + Op: pb.Op_DEL, + Key: req.Key, + } + } + panic("unreachable") +} diff --git a/kv/leader_proxy.go b/kv/leader_proxy.go new file mode 100644 index 00000000..04639d96 --- /dev/null +++ b/kv/leader_proxy.go @@ -0,0 +1,74 @@ +package kv + +import ( + "context" + + pb "github.com/bootjp/elastickv/proto" + "github.com/cockroachdb/errors" + "github.com/hashicorp/raft" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" +) + +// LeaderProxy forwards transactional requests to the current raft leader when +// the local node is not the leader. +type LeaderProxy struct { + raft *raft.Raft + tm *TransactionManager +} + +// NewLeaderProxy creates a leader-aware transactional proxy for a raft group. +func NewLeaderProxy(r *raft.Raft) *LeaderProxy { + return &LeaderProxy{ + raft: r, + tm: NewTransaction(r), + } +} + +func (p *LeaderProxy) Commit(reqs []*pb.Request) (*TransactionResponse, error) { + if p.raft.State() == raft.Leader { + return p.tm.Commit(reqs) + } + return p.forward(reqs) +} + +func (p *LeaderProxy) Abort(reqs []*pb.Request) (*TransactionResponse, error) { + if p.raft.State() == raft.Leader { + return p.tm.Abort(reqs) + } + return p.forward(reqs) +} + +func (p *LeaderProxy) forward(reqs []*pb.Request) (*TransactionResponse, error) { + if len(reqs) == 0 { + return &TransactionResponse{}, nil + } + addr, _ := p.raft.LeaderWithID() + if addr == "" { + return nil, errors.WithStack(ErrLeaderNotFound) + } + + conn, err := grpc.NewClient(string(addr), + grpc.WithTransportCredentials(insecure.NewCredentials()), + grpc.WithDefaultCallOptions(grpc.WaitForReady(true)), + ) + if err != nil { + return nil, errors.WithStack(err) + } + defer conn.Close() + + cli := pb.NewInternalClient(conn) + resp, err := cli.Forward(context.Background(), &pb.ForwardRequest{ + IsTxn: reqs[0].IsTxn, + Requests: reqs, + }) + if err != nil { + return nil, errors.WithStack(err) + } + if !resp.Success { + return nil, ErrInvalidRequest + } + return &TransactionResponse{CommitIndex: resp.CommitIndex}, nil +} + +var _ Transactional = (*LeaderProxy)(nil) diff --git a/kv/shard_key.go b/kv/shard_key.go new file mode 100644 index 00000000..f1a06444 --- /dev/null +++ b/kv/shard_key.go @@ -0,0 +1,15 @@ +package kv + +import "github.com/bootjp/elastickv/store" + +// routeKey normalizes internal keys (e.g., list metadata/items) to the logical +// user key used for shard routing. 
+func routeKey(key []byte) []byte { + if key == nil { + return nil + } + if user := store.ExtractListUserKey(key); user != nil { + return user + } + return key +} diff --git a/kv/shard_router.go b/kv/shard_router.go index 34038474..5f055c1a 100644 --- a/kv/shard_router.go +++ b/kv/shard_router.go @@ -95,7 +95,7 @@ func (s *ShardRouter) groupRequests(reqs []*pb.Request) (map[uint64][]*pb.Reques if len(r.Mutations) == 0 { return nil, ErrInvalidRequest } - key := r.Mutations[0].Key + key := routeKey(r.Mutations[0].Key) route, ok := s.engine.GetRoute(key) if !ok { return nil, errors.Wrapf(ErrInvalidRequest, "no route for key %q", key) @@ -107,7 +107,7 @@ func (s *ShardRouter) groupRequests(reqs []*pb.Request) (map[uint64][]*pb.Reques // Get retrieves a key routed to the correct shard. func (s *ShardRouter) Get(ctx context.Context, key []byte) ([]byte, error) { - route, ok := s.engine.GetRoute(key) + route, ok := s.engine.GetRoute(routeKey(key)) if !ok { return nil, errors.Wrapf(ErrInvalidRequest, "no route for key %q", key) } diff --git a/kv/shard_router_test.go b/kv/shard_router_test.go index 7016e4df..a4e855eb 100644 --- a/kv/shard_router_test.go +++ b/kv/shard_router_test.go @@ -274,3 +274,31 @@ func TestShardRouterCommitFailure(t *testing.T) { t.Fatalf("unexpected abort on successful group") } } + +func TestShardRouterRoutesListKeys(t *testing.T) { + e := distribution.NewEngine() + e.UpdateRoute([]byte("a"), []byte("m"), 1) + e.UpdateRoute([]byte("m"), nil, 2) + + router := NewShardRouter(e) + + ok := &fakeTM{} + fail := &fakeTM{} + router.Register(1, ok, nil) + router.Register(2, fail, nil) + + listMetaKey := store.ListMetaKey([]byte("b")) + reqs := []*pb.Request{ + {IsTxn: false, Phase: pb.Phase_NONE, Mutations: []*pb.Mutation{{Op: pb.Op_PUT, Key: listMetaKey, Value: []byte("v")}}}, + } + + if _, err := router.Commit(reqs); err != nil { + t.Fatalf("commit: %v", err) + } + if ok.commitCalls != 1 { + t.Fatalf("expected commit routed to group1") + } + if fail.commitCalls != 0 { + t.Fatalf("unexpected commit on group2") + } +} diff --git a/kv/shard_store.go b/kv/shard_store.go new file mode 100644 index 00000000..6d0da333 --- /dev/null +++ b/kv/shard_store.go @@ -0,0 +1,238 @@ +package kv + +import ( + "bytes" + "context" + "io" + "sort" + + "github.com/bootjp/elastickv/distribution" + pb "github.com/bootjp/elastickv/proto" + "github.com/bootjp/elastickv/store" + "github.com/cockroachdb/errors" + "github.com/hashicorp/raft" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" +) + +// ShardStore routes MVCC reads to shard-specific stores and proxies to leaders when needed. +type ShardStore struct { + engine *distribution.Engine + groups map[uint64]*ShardGroup +} + +// NewShardStore creates a sharded MVCC store wrapper. 
+func NewShardStore(engine *distribution.Engine, groups map[uint64]*ShardGroup) *ShardStore { + return &ShardStore{ + engine: engine, + groups: groups, + } +} + +func (s *ShardStore) GetAt(ctx context.Context, key []byte, ts uint64) ([]byte, error) { + g, ok := s.groupForKey(key) + if !ok || g.Store == nil { + return nil, store.ErrKeyNotFound + } + if g.Raft != nil && g.Raft.State() == raft.Leader { + val, err := g.Store.GetAt(ctx, key, ts) + if err != nil { + return nil, errors.WithStack(err) + } + return val, nil + } + return s.proxyRawGet(ctx, g, key, ts) +} + +func (s *ShardStore) ExistsAt(ctx context.Context, key []byte, ts uint64) (bool, error) { + v, err := s.GetAt(ctx, key, ts) + if err != nil { + if errors.Is(err, store.ErrKeyNotFound) { + return false, nil + } + return false, err + } + return v != nil, nil +} + +func (s *ShardStore) ScanAt(ctx context.Context, start []byte, end []byte, limit int, ts uint64) ([]*store.KVPair, error) { + if limit <= 0 { + return []*store.KVPair{}, nil + } + + // Get only the routes whose ranges intersect with [start, end) + intersectingRoutes := s.engine.GetIntersectingRoutes(start, end) + + var out []*store.KVPair + for _, route := range intersectingRoutes { + g, ok := s.groups[route.GroupID] + if !ok || g == nil || g.Store == nil { + continue + } + kvs, err := g.Store.ScanAt(ctx, start, end, limit, ts) + if err != nil { + return nil, errors.WithStack(err) + } + out = append(out, kvs...) + } + sort.Slice(out, func(i, j int) bool { + return bytes.Compare(out[i].Key, out[j].Key) < 0 + }) + if len(out) > limit { + out = out[:limit] + } + return out, nil +} + +func (s *ShardStore) PutAt(ctx context.Context, key []byte, value []byte, commitTS uint64, expireAt uint64) error { + g, ok := s.groupForKey(key) + if !ok || g.Store == nil { + return store.ErrNotSupported + } + return errors.WithStack(g.Store.PutAt(ctx, key, value, commitTS, expireAt)) +} + +func (s *ShardStore) DeleteAt(ctx context.Context, key []byte, commitTS uint64) error { + g, ok := s.groupForKey(key) + if !ok || g.Store == nil { + return store.ErrNotSupported + } + return errors.WithStack(g.Store.DeleteAt(ctx, key, commitTS)) +} + +func (s *ShardStore) PutWithTTLAt(ctx context.Context, key []byte, value []byte, commitTS uint64, expireAt uint64) error { + g, ok := s.groupForKey(key) + if !ok || g.Store == nil { + return store.ErrNotSupported + } + return errors.WithStack(g.Store.PutWithTTLAt(ctx, key, value, commitTS, expireAt)) +} + +func (s *ShardStore) ExpireAt(ctx context.Context, key []byte, expireAt uint64, commitTS uint64) error { + g, ok := s.groupForKey(key) + if !ok || g.Store == nil { + return store.ErrNotSupported + } + return errors.WithStack(g.Store.ExpireAt(ctx, key, expireAt, commitTS)) +} + +func (s *ShardStore) LatestCommitTS(ctx context.Context, key []byte) (uint64, bool, error) { + g, ok := s.groupForKey(key) + if !ok || g.Store == nil { + return 0, false, nil + } + ts, exists, err := g.Store.LatestCommitTS(ctx, key) + if err != nil { + return 0, false, errors.WithStack(err) + } + return ts, exists, nil +} + +func (s *ShardStore) ApplyMutations(ctx context.Context, mutations []*store.KVPairMutation, startTS, commitTS uint64) error { + if len(mutations) == 0 { + return nil + } + // Determine the shard group for the first mutation. + firstGroup, ok := s.groupForKey(mutations[0].Key) + if !ok || firstGroup == nil || firstGroup.Store == nil { + return store.ErrNotSupported + } + // Ensure that all mutations in the batch belong to the same shard. 
+ for i := 1; i < len(mutations); i++ { + g, ok := s.groupForKey(mutations[i].Key) + if !ok || g == nil || g.Store == nil { + return store.ErrNotSupported + } + if g != firstGroup { + // Mixed-shard mutation batches are not supported. + return store.ErrNotSupported + } + } + return errors.WithStack(firstGroup.Store.ApplyMutations(ctx, mutations, startTS, commitTS)) +} + +func (s *ShardStore) LastCommitTS() uint64 { + var max uint64 + for _, g := range s.groups { + if g == nil || g.Store == nil { + continue + } + if ts := g.Store.LastCommitTS(); ts > max { + max = ts + } + } + return max +} + +func (s *ShardStore) Compact(ctx context.Context, minTS uint64) error { + for _, g := range s.groups { + if g == nil || g.Store == nil { + continue + } + if err := g.Store.Compact(ctx, minTS); err != nil { + return errors.WithStack(err) + } + } + return nil +} + +func (s *ShardStore) Snapshot() (io.ReadWriter, error) { + return nil, store.ErrNotSupported +} + +func (s *ShardStore) Restore(_ io.Reader) error { + return store.ErrNotSupported +} + +func (s *ShardStore) Close() error { + var first error + for _, g := range s.groups { + if g == nil || g.Store == nil { + continue + } + if err := g.Store.Close(); err != nil && first == nil { + first = errors.WithStack(err) + } + } + return first +} + +func (s *ShardStore) groupForKey(key []byte) (*ShardGroup, bool) { + route, ok := s.engine.GetRoute(routeKey(key)) + if !ok { + return nil, false + } + g, ok := s.groups[route.GroupID] + return g, ok +} + +func (s *ShardStore) proxyRawGet(ctx context.Context, g *ShardGroup, key []byte, ts uint64) ([]byte, error) { + if g == nil || g.Raft == nil { + return nil, store.ErrKeyNotFound + } + addr, _ := g.Raft.LeaderWithID() + if addr == "" { + return nil, errors.WithStack(ErrLeaderNotFound) + } + + conn, err := grpc.NewClient(string(addr), + grpc.WithTransportCredentials(insecure.NewCredentials()), + grpc.WithDefaultCallOptions(grpc.WaitForReady(true)), + ) + if err != nil { + return nil, errors.WithStack(err) + } + defer conn.Close() + + cli := pb.NewRawKVClient(conn) + resp, err := cli.RawGet(ctx, &pb.RawGetRequest{Key: key, Ts: ts}) + if err != nil { + return nil, errors.WithStack(err) + } + if resp.Value == nil { + return nil, store.ErrKeyNotFound + } + return resp.Value, nil +} + +var _ store.MVCCStore = (*ShardStore)(nil) diff --git a/kv/sharded_coordinator.go b/kv/sharded_coordinator.go new file mode 100644 index 00000000..a7b795d1 --- /dev/null +++ b/kv/sharded_coordinator.go @@ -0,0 +1,222 @@ +package kv + +import ( + "sort" + + "github.com/bootjp/elastickv/distribution" + pb "github.com/bootjp/elastickv/proto" + "github.com/bootjp/elastickv/store" + "github.com/cockroachdb/errors" + "github.com/hashicorp/raft" +) + +type ShardGroup struct { + Raft *raft.Raft + Store store.MVCCStore + Txn Transactional +} + +const txnPhaseCount = 2 + +// ShardedCoordinator routes operations to shard-specific raft groups. +// It issues timestamps via a shared HLC and uses ShardRouter to dispatch. +type ShardedCoordinator struct { + engine *distribution.Engine + router *ShardRouter + groups map[uint64]*ShardGroup + defaultGroup uint64 + clock *HLC +} + +// NewShardedCoordinator builds a coordinator for the provided shard groups. +// The defaultGroup is used for non-keyed leader checks. 
+func NewShardedCoordinator(engine *distribution.Engine, groups map[uint64]*ShardGroup, defaultGroup uint64, clock *HLC) *ShardedCoordinator { + router := NewShardRouter(engine) + for gid, g := range groups { + router.Register(gid, g.Txn, g.Store) + } + return &ShardedCoordinator{ + engine: engine, + router: router, + groups: groups, + defaultGroup: defaultGroup, + clock: clock, + } +} + +func (c *ShardedCoordinator) Dispatch(reqs *OperationGroup[OP]) (*CoordinateResponse, error) { + if err := validateOperationGroup(reqs); err != nil { + return nil, err + } + + if reqs.IsTxn && reqs.StartTS == 0 { + reqs.StartTS = c.clock.Next() + } + + logs, err := c.requestLogs(reqs) + if err != nil { + return nil, err + } + + r, err := c.router.Commit(logs) + if err != nil { + return nil, errors.WithStack(err) + } + return &CoordinateResponse{CommitIndex: r.CommitIndex}, nil +} + +func (c *ShardedCoordinator) IsLeader() bool { + g, ok := c.groups[c.defaultGroup] + if !ok || g.Raft == nil { + return false + } + return g.Raft.State() == raft.Leader +} + +func (c *ShardedCoordinator) VerifyLeader() error { + g, ok := c.groups[c.defaultGroup] + if !ok || g.Raft == nil { + return errors.WithStack(ErrLeaderNotFound) + } + return errors.WithStack(g.Raft.VerifyLeader().Error()) +} + +func (c *ShardedCoordinator) RaftLeader() raft.ServerAddress { + g, ok := c.groups[c.defaultGroup] + if !ok || g.Raft == nil { + return "" + } + addr, _ := g.Raft.LeaderWithID() + return addr +} + +func (c *ShardedCoordinator) IsLeaderForKey(key []byte) bool { + g, ok := c.groupForKey(key) + if !ok || g.Raft == nil { + return false + } + return g.Raft.State() == raft.Leader +} + +func (c *ShardedCoordinator) VerifyLeaderForKey(key []byte) error { + g, ok := c.groupForKey(key) + if !ok || g.Raft == nil { + return errors.WithStack(ErrLeaderNotFound) + } + return errors.WithStack(g.Raft.VerifyLeader().Error()) +} + +func (c *ShardedCoordinator) RaftLeaderForKey(key []byte) raft.ServerAddress { + g, ok := c.groupForKey(key) + if !ok || g.Raft == nil { + return "" + } + addr, _ := g.Raft.LeaderWithID() + return addr +} + +func (c *ShardedCoordinator) Clock() *HLC { + return c.clock +} + +func (c *ShardedCoordinator) groupForKey(key []byte) (*ShardGroup, bool) { + route, ok := c.engine.GetRoute(routeKey(key)) + if !ok { + return nil, false + } + g, ok := c.groups[route.GroupID] + return g, ok +} + +func (c *ShardedCoordinator) toRawRequest(req *Elem[OP]) *pb.Request { + switch req.Op { + case Put: + return &pb.Request{ + IsTxn: false, + Phase: pb.Phase_NONE, + Ts: c.clock.Next(), + Mutations: []*pb.Mutation{ + { + Op: pb.Op_PUT, + Key: req.Key, + Value: req.Value, + }, + }, + } + case Del: + return &pb.Request{ + IsTxn: false, + Phase: pb.Phase_NONE, + Ts: c.clock.Next(), + Mutations: []*pb.Mutation{ + { + Op: pb.Op_DEL, + Key: req.Key, + }, + }, + } + } + panic("unreachable") +} + +var _ Coordinator = (*ShardedCoordinator)(nil) + +func validateOperationGroup(reqs *OperationGroup[OP]) error { + if reqs == nil || len(reqs.Elems) == 0 { + return ErrInvalidRequest + } + return nil +} + +func (c *ShardedCoordinator) requestLogs(reqs *OperationGroup[OP]) ([]*pb.Request, error) { + if reqs.IsTxn { + return c.txnLogs(reqs) + } + return c.rawLogs(reqs), nil +} + +func (c *ShardedCoordinator) rawLogs(reqs *OperationGroup[OP]) []*pb.Request { + logs := make([]*pb.Request, 0, len(reqs.Elems)) + for _, req := range reqs.Elems { + logs = append(logs, c.toRawRequest(req)) + } + return logs +} + +func (c *ShardedCoordinator) txnLogs(reqs 
*OperationGroup[OP]) ([]*pb.Request, error) { + grouped, gids, err := c.groupMutations(reqs.Elems) + if err != nil { + return nil, err + } + return buildTxnLogs(reqs.StartTS, grouped, gids), nil +} + +func (c *ShardedCoordinator) groupMutations(reqs []*Elem[OP]) (map[uint64][]*pb.Mutation, []uint64, error) { + grouped := make(map[uint64][]*pb.Mutation) + for _, req := range reqs { + mut := elemToMutation(req) + route, ok := c.engine.GetRoute(routeKey(mut.Key)) + if !ok { + return nil, nil, errors.Wrapf(ErrInvalidRequest, "no route for key %q", mut.Key) + } + grouped[route.GroupID] = append(grouped[route.GroupID], mut) + } + gids := make([]uint64, 0, len(grouped)) + for gid := range grouped { + gids = append(gids, gid) + } + sort.Slice(gids, func(i, j int) bool { return gids[i] < gids[j] }) + return grouped, gids, nil +} + +func buildTxnLogs(startTS uint64, grouped map[uint64][]*pb.Mutation, gids []uint64) []*pb.Request { + logs := make([]*pb.Request, 0, len(gids)*txnPhaseCount) + for _, gid := range gids { + muts := grouped[gid] + logs = append(logs, + &pb.Request{IsTxn: true, Phase: pb.Phase_PREPARE, Ts: startTS, Mutations: muts}, + &pb.Request{IsTxn: true, Phase: pb.Phase_COMMIT, Ts: startTS, Mutations: muts}, + ) + } + return logs +} diff --git a/kv/sharded_integration_test.go b/kv/sharded_integration_test.go new file mode 100644 index 00000000..84b58520 --- /dev/null +++ b/kv/sharded_integration_test.go @@ -0,0 +1,107 @@ +package kv + +import ( + "context" + "testing" + "time" + + "github.com/bootjp/elastickv/distribution" + "github.com/bootjp/elastickv/store" + "github.com/cockroachdb/errors" + "github.com/hashicorp/raft" +) + +func newSingleRaft(t *testing.T, id string, fsm raft.FSM) (*raft.Raft, func()) { + t.Helper() + + addr, trans := raft.NewInmemTransport(raft.ServerAddress(id)) + c := raft.DefaultConfig() + c.LocalID = raft.ServerID(id) + c.HeartbeatTimeout = 50 * time.Millisecond + c.ElectionTimeout = 100 * time.Millisecond + c.LeaderLeaseTimeout = 50 * time.Millisecond + + ldb := raft.NewInmemStore() + sdb := raft.NewInmemStore() + fss := raft.NewInmemSnapshotStore() + r, err := raft.NewRaft(c, fsm, ldb, sdb, fss, trans) + if err != nil { + t.Fatalf("new raft: %v", err) + } + cfg := raft.Configuration{ + Servers: []raft.Server{ + { + Suffrage: raft.Voter, + ID: raft.ServerID(id), + Address: addr, + }, + }, + } + if err := r.BootstrapCluster(cfg).Error(); err != nil { + t.Fatalf("bootstrap: %v", err) + } + + for i := 0; i < 100; i++ { + if r.State() == raft.Leader { + break + } + time.Sleep(10 * time.Millisecond) + } + if r.State() != raft.Leader { + t.Fatalf("node %s is not leader", id) + } + + return r, func() { r.Shutdown() } +} + +func TestShardedCoordinatorDispatch(t *testing.T) { + ctx := context.Background() + + engine := distribution.NewEngine() + engine.UpdateRoute([]byte("a"), []byte("m"), 1) + engine.UpdateRoute([]byte("m"), nil, 2) + + s1 := store.NewMVCCStore() + r1, stop1 := newSingleRaft(t, "g1", NewKvFSM(s1)) + defer stop1() + + s2 := store.NewMVCCStore() + r2, stop2 := newSingleRaft(t, "g2", NewKvFSM(s2)) + defer stop2() + + groups := map[uint64]*ShardGroup{ + 1: {Raft: r1, Store: s1, Txn: NewLeaderProxy(r1)}, + 2: {Raft: r2, Store: s2, Txn: NewLeaderProxy(r2)}, + } + + coord := NewShardedCoordinator(engine, groups, 1, NewHLC()) + shardStore := NewShardStore(engine, groups) + + ops := &OperationGroup[OP]{ + IsTxn: false, + Elems: []*Elem[OP]{ + {Op: Put, Key: []byte("b"), Value: []byte("v1")}, + {Op: Put, Key: []byte("x"), Value: []byte("v2")}, + }, + } + if 
_, err := coord.Dispatch(ops); err != nil { + t.Fatalf("dispatch: %v", err) + } + + readTS := shardStore.LastCommitTS() + v, err := shardStore.GetAt(ctx, []byte("b"), readTS) + if err != nil || string(v) != "v1" { + t.Fatalf("get b: %v %v", v, err) + } + v, err = shardStore.GetAt(ctx, []byte("x"), readTS) + if err != nil || string(v) != "v2" { + t.Fatalf("get x: %v %v", v, err) + } + + if _, err := s1.GetAt(ctx, []byte("x"), readTS); !errors.Is(err, store.ErrKeyNotFound) { + t.Fatalf("expected key x missing in group1, got %v", err) + } + if _, err := s2.GetAt(ctx, []byte("b"), readTS); !errors.Is(err, store.ErrKeyNotFound) { + t.Fatalf("expected key b missing in group2, got %v", err) + } +} diff --git a/main.go b/main.go index 0013cb13..0cfe373b 100644 --- a/main.go +++ b/main.go @@ -3,26 +3,21 @@ package main import ( "context" "flag" - "fmt" "log" "net" - "os" - "path/filepath" "time" "github.com/Jille/raft-grpc-leader-rpc/leaderhealth" - transport "github.com/Jille/raft-grpc-transport" "github.com/Jille/raftadmin" "github.com/bootjp/elastickv/adapter" + "github.com/bootjp/elastickv/distribution" "github.com/bootjp/elastickv/kv" pb "github.com/bootjp/elastickv/proto" "github.com/bootjp/elastickv/store" "github.com/cockroachdb/errors" "github.com/hashicorp/raft" - boltdb "github.com/hashicorp/raft-boltdb/v2" "golang.org/x/sync/errgroup" "google.golang.org/grpc" - "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/reflection" ) @@ -38,121 +33,145 @@ var ( raftId = flag.String("raftId", "", "Node id used by Raft") raftDir = flag.String("raftDataDir", "data/", "Raft data dir") raftBootstrap = flag.Bool("raftBootstrap", false, "Whether to bootstrap the Raft cluster") + raftGroups = flag.String("raftGroups", "", "Comma-separated raft groups (groupID=host:port,...)") + shardRanges = flag.String("shardRanges", "", "Comma-separated shard ranges (start:end=groupID,...)") + raftRedisMap = flag.String("raftRedisMap", "", "Map of Raft address to Redis address (raftAddr=redisAddr,...)") ) func main() { flag.Parse() + if err := run(); err != nil { + log.Fatalf("%v", err) + } +} + +func run() error { if *raftId == "" { - log.Fatalf("flag --raftId is required") + return errors.New("flag --raftId is required") } ctx := context.Background() var lc net.ListenConfig - _, port, err := net.SplitHostPort(*myAddr) + groups, err := parseRaftGroups(*raftGroups, *myAddr) if err != nil { - log.Fatalf("failed to parse local address (%q): %v", *myAddr, err) + return errors.Wrapf(err, "failed to parse raft groups") } - - grpcSock, err := lc.Listen(ctx, "tcp", fmt.Sprintf(":%s", port)) + defaultGroup := defaultGroupID(groups) + ranges, err := parseShardRanges(*shardRanges, defaultGroup) if err != nil { - log.Fatalf("failed to listen: %v", err) + return errors.Wrapf(err, "failed to parse shard ranges") } - - s := store.NewMVCCStore() - kvFSM := kv.NewKvFSM(s) - - r, tm, err := NewRaft(ctx, *raftId, *myAddr, kvFSM) - if err != nil { - log.Fatalf("failed to start raft: %v", err) + if err := validateShardRanges(ranges, groups); err != nil { + return errors.Wrapf(err, "invalid shard ranges") } - gs := grpc.NewServer() - trx := kv.NewTransaction(r) - coordinate := kv.NewCoordinator(trx, r) - pb.RegisterRawKVServer(gs, adapter.NewGRPCServer(s, coordinate)) - pb.RegisterTransactionalKVServer(gs, adapter.NewGRPCServer(s, coordinate)) - pb.RegisterInternalServer(gs, adapter.NewInternal(trx, r, coordinate.Clock())) - tm.Register(gs) - - leaderhealth.Setup(r, gs, []string{"RawKV", "Example"}) - 
raftadmin.Register(gs, r) - reflection.Register(gs) + engine := buildEngine(ranges) + leaderRedis := buildLeaderRedis(groups, *redisAddr, *raftRedisMap) - redisL, err := lc.Listen(ctx, "tcp", *redisAddr) + multi := len(groups) > 1 || *raftGroups != "" + runtimes, shardGroups, err := buildShardGroups(*raftId, *raftDir, groups, multi, *raftBootstrap) if err != nil { - log.Fatalf("failed to listen: %v", err) + return err } - leaderRedis := map[raft.ServerAddress]string{ - raft.ServerAddress(*myAddr): *redisAddr, - } + clock := kv.NewHLC() + coordinate := kv.NewShardedCoordinator(engine, shardGroups, defaultGroup, clock) + shardStore := kv.NewShardStore(engine, shardGroups) + distServer := adapter.NewDistributionServer(engine) eg := errgroup.Group{} - eg.Go(func() error { - return errors.WithStack(gs.Serve(grpcSock)) - }) - eg.Go(func() error { - return errors.WithStack(adapter.NewRedisServer(redisL, s, coordinate, leaderRedis).Run()) - }) + if err := startRaftServers(ctx, &lc, &eg, runtimes, shardStore, coordinate, distServer); err != nil { + return err + } + if err := startRedisServer(ctx, &lc, &eg, *redisAddr, shardStore, coordinate, leaderRedis); err != nil { + return err + } - err = eg.Wait() - if err != nil { - log.Fatalf("failed to serve: %v", err) + if err := eg.Wait(); err != nil { + return errors.Wrapf(err, "failed to serve") } + return nil } const snapshotRetainCount = 3 -func NewRaft(_ context.Context, myID, myAddress string, fsm raft.FSM) (*raft.Raft, *transport.Manager, error) { - c := raft.DefaultConfig() - c.LocalID = raft.ServerID(myID) - c.HeartbeatTimeout = heartbeatTimeout - c.ElectionTimeout = electionTimeout - c.LeaderLeaseTimeout = leaderLease - - baseDir := filepath.Join(*raftDir, myID) - - ldb, err := boltdb.NewBoltStore(filepath.Join(baseDir, "logs.dat")) - if err != nil { - return nil, nil, errors.WithStack(err) - } - - sdb, err := boltdb.NewBoltStore(filepath.Join(baseDir, "stable.dat")) - if err != nil { - return nil, nil, errors.WithStack(err) +func buildEngine(ranges []rangeSpec) *distribution.Engine { + engine := distribution.NewEngine() + for _, r := range ranges { + engine.UpdateRoute(r.start, r.end, r.groupID) } + return engine +} - fss, err := raft.NewFileSnapshotStore(baseDir, snapshotRetainCount, os.Stderr) - if err != nil { - return nil, nil, errors.WithStack(err) +func buildLeaderRedis(groups []groupSpec, redisAddr string, raftRedisMap string) map[raft.ServerAddress]string { + leaderRedis := parseRaftRedisMap(raftRedisMap) + for _, g := range groups { + leaderRedis[raft.ServerAddress(g.address)] = redisAddr } + return leaderRedis +} - tm := transport.New(raft.ServerAddress(myAddress), []grpc.DialOption{ - grpc.WithTransportCredentials(insecure.NewCredentials()), - }) - - r, err := raft.NewRaft(c, fsm, ldb, sdb, fss, tm.Transport()) - if err != nil { - return nil, nil, errors.WithStack(err) +func buildShardGroups(raftID string, raftDir string, groups []groupSpec, multi bool, bootstrap bool) ([]*raftGroupRuntime, map[uint64]*kv.ShardGroup, error) { + runtimes := make([]*raftGroupRuntime, 0, len(groups)) + shardGroups := make(map[uint64]*kv.ShardGroup, len(groups)) + for _, g := range groups { + st := store.NewMVCCStore() + fsm := kv.NewKvFSM(st) + r, tm, err := newRaftGroup(raftID, g, raftDir, multi, bootstrap, fsm) + if err != nil { + return nil, nil, errors.Wrapf(err, "failed to start raft group %d", g.id) + } + runtimes = append(runtimes, &raftGroupRuntime{ + spec: g, + raft: r, + tm: tm, + store: st, + }) + shardGroups[g.id] = &kv.ShardGroup{ + Raft: 
r, + Store: st, + Txn: kv.NewLeaderProxy(r), + } } + return runtimes, shardGroups, nil +} - if *raftBootstrap { - cfg := raft.Configuration{ - Servers: []raft.Server{ - { - Suffrage: raft.Voter, - ID: raft.ServerID(myID), - Address: raft.ServerAddress(myAddress), - }, - }, - } - f := r.BootstrapCluster(cfg) - if err := f.Error(); err != nil { - return nil, nil, errors.WithStack(err) +func startRaftServers(ctx context.Context, lc *net.ListenConfig, eg *errgroup.Group, runtimes []*raftGroupRuntime, shardStore *kv.ShardStore, coordinate kv.Coordinator, distServer *adapter.DistributionServer) error { + for _, rt := range runtimes { + gs := grpc.NewServer() + trx := kv.NewTransaction(rt.raft) + grpcServer := adapter.NewGRPCServer(shardStore, coordinate) + pb.RegisterRawKVServer(gs, grpcServer) + pb.RegisterTransactionalKVServer(gs, grpcServer) + pb.RegisterInternalServer(gs, adapter.NewInternal(trx, rt.raft, coordinate.Clock())) + pb.RegisterDistributionServer(gs, distServer) + rt.tm.Register(gs) + leaderhealth.Setup(rt.raft, gs, []string{"RawKV"}) + raftadmin.Register(gs, rt.raft) + reflection.Register(gs) + + grpcSock, err := lc.Listen(ctx, "tcp", rt.spec.address) + if err != nil { + return errors.Wrapf(err, "failed to listen on %s", rt.spec.address) } + srv := gs + lis := grpcSock + eg.Go(func() error { + return errors.WithStack(srv.Serve(lis)) + }) } + return nil +} - return r, tm, nil +func startRedisServer(ctx context.Context, lc *net.ListenConfig, eg *errgroup.Group, redisAddr string, shardStore *kv.ShardStore, coordinate kv.Coordinator, leaderRedis map[raft.ServerAddress]string) error { + redisL, err := lc.Listen(ctx, "tcp", redisAddr) + if err != nil { + return errors.Wrapf(err, "failed to listen on %s", redisAddr) + } + eg.Go(func() error { + return errors.WithStack(adapter.NewRedisServer(redisL, shardStore, coordinate, leaderRedis).Run()) + }) + return nil } diff --git a/multiraft_runtime.go b/multiraft_runtime.go new file mode 100644 index 00000000..be75cfe3 --- /dev/null +++ b/multiraft_runtime.go @@ -0,0 +1,106 @@ +package main + +import ( + "fmt" + "os" + "path/filepath" + + transport "github.com/Jille/raft-grpc-transport" + "github.com/bootjp/elastickv/store" + "github.com/cockroachdb/errors" + "github.com/hashicorp/raft" + boltdb "github.com/hashicorp/raft-boltdb/v2" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" +) + +type raftGroupRuntime struct { + spec groupSpec + raft *raft.Raft + tm *transport.Manager + store store.MVCCStore +} + +const raftDirPerm = 0o755 + +func groupDataDir(baseDir, raftID string, groupID uint64, multi bool) string { + if !multi { + return filepath.Join(baseDir, raftID) + } + return filepath.Join(baseDir, raftID, fmt.Sprintf("group-%d", groupID)) +} + +func newRaftGroup(raftID string, group groupSpec, baseDir string, multi bool, bootstrap bool, fsm raft.FSM) (*raft.Raft, *transport.Manager, error) { + c := raft.DefaultConfig() + c.LocalID = raft.ServerID(raftID) + c.HeartbeatTimeout = heartbeatTimeout + c.ElectionTimeout = electionTimeout + c.LeaderLeaseTimeout = leaderLease + + dir := groupDataDir(baseDir, raftID, group.id, multi) + if err := os.MkdirAll(dir, raftDirPerm); err != nil { + return nil, nil, errors.WithStack(err) + } + + ldb, err := boltdb.NewBoltStore(filepath.Join(dir, "logs.dat")) + if err != nil { + // No cleanup needed here - ldb creation failed, so no resource was allocated + return nil, nil, errors.WithStack(err) + } + + // Define cleanup function immediately after ldb is created to ensure + // proper 
cleanup in all error paths. If we return before this point, + // no cleanup is needed since ldb creation failed. + var sdb *boltdb.BoltStore + var r *raft.Raft + cleanup := func() { + if ldb != nil { + _ = ldb.Close() + } + if sdb != nil { + _ = sdb.Close() + } + } + + sdb, err = boltdb.NewBoltStore(filepath.Join(dir, "stable.dat")) + if err != nil { + cleanup() + return nil, nil, errors.WithStack(err) + } + + fss, err := raft.NewFileSnapshotStore(dir, snapshotRetainCount, os.Stderr) + if err != nil { + cleanup() + return nil, nil, errors.WithStack(err) + } + + tm := transport.New(raft.ServerAddress(group.address), []grpc.DialOption{ + grpc.WithTransportCredentials(insecure.NewCredentials()), + }) + + r, err = raft.NewRaft(c, fsm, ldb, sdb, fss, tm.Transport()) + if err != nil { + cleanup() + return nil, nil, errors.WithStack(err) + } + + if bootstrap { + cfg := raft.Configuration{ + Servers: []raft.Server{ + { + Suffrage: raft.Voter, + ID: raft.ServerID(raftID), + Address: raft.ServerAddress(group.address), + }, + }, + } + f := r.BootstrapCluster(cfg) + if err := f.Error(); err != nil { + _ = r.Shutdown().Error() + cleanup() + return nil, nil, errors.WithStack(err) + } + } + + return r, tm, nil +} diff --git a/shard_config.go b/shard_config.go new file mode 100644 index 00000000..f7bf97b3 --- /dev/null +++ b/shard_config.go @@ -0,0 +1,151 @@ +package main + +import ( + "fmt" + "strconv" + "strings" + + "github.com/cockroachdb/errors" + "github.com/hashicorp/raft" +) + +type groupSpec struct { + id uint64 + address string +} + +type rangeSpec struct { + start []byte + end []byte + groupID uint64 +} + +const splitParts = 2 + +var ( + ErrAddressRequired = errors.New("address is required") + ErrNoRaftGroupsConfigured = errors.New("no raft groups configured") + ErrNoShardRangesConfigured = errors.New("no shard ranges configured") +) + +func parseRaftGroups(raw, defaultAddr string) ([]groupSpec, error) { + if raw == "" { + if defaultAddr == "" { + return nil, ErrAddressRequired + } + return []groupSpec{{id: 1, address: defaultAddr}}, nil + } + parts := strings.Split(raw, ",") + groups := make([]groupSpec, 0, len(parts)) + seen := map[uint64]struct{}{} + for _, part := range parts { + part = strings.TrimSpace(part) + if part == "" { + continue + } + kv := strings.SplitN(part, "=", splitParts) + if len(kv) != splitParts { + return nil, errors.WithStack(errors.Newf("invalid raftGroups entry: %q", part)) + } + id, err := strconv.ParseUint(kv[0], 10, 64) + if err != nil { + return nil, errors.Wrapf(err, "invalid group id %q", kv[0]) + } + addr := strings.TrimSpace(kv[1]) + if addr == "" { + return nil, errors.WithStack(errors.Newf("empty address for group %d", id)) + } + if _, ok := seen[id]; ok { + return nil, errors.WithStack(errors.Newf("duplicate group id %d", id)) + } + seen[id] = struct{}{} + groups = append(groups, groupSpec{id: id, address: addr}) + } + if len(groups) == 0 { + return nil, ErrNoRaftGroupsConfigured + } + return groups, nil +} + +func parseShardRanges(raw string, defaultGroup uint64) ([]rangeSpec, error) { + if raw == "" { + return []rangeSpec{{start: []byte(""), end: nil, groupID: defaultGroup}}, nil + } + parts := strings.Split(raw, ",") + ranges := make([]rangeSpec, 0, len(parts)) + for _, part := range parts { + part = strings.TrimSpace(part) + if part == "" { + continue + } + kv := strings.SplitN(part, "=", splitParts) + if len(kv) != splitParts { + return nil, errors.WithStack(errors.Newf("invalid shardRanges entry: %q", part)) + } + groupID, err := 
strconv.ParseUint(kv[1], 10, 64) + if err != nil { + return nil, errors.Wrapf(err, "invalid group id in %q", part) + } + rangePart := kv[0] + bounds := strings.SplitN(rangePart, ":", splitParts) + if len(bounds) != splitParts { + return nil, errors.WithStack(errors.Newf("invalid range %q (expected start:end)", rangePart)) + } + start := []byte(bounds[0]) + var end []byte + if bounds[1] != "" { + end = []byte(bounds[1]) + } + ranges = append(ranges, rangeSpec{start: start, end: end, groupID: groupID}) + } + if len(ranges) == 0 { + return nil, ErrNoShardRangesConfigured + } + return ranges, nil +} + +func parseRaftRedisMap(raw string) map[raft.ServerAddress]string { + out := make(map[raft.ServerAddress]string) + if raw == "" { + return out + } + parts := strings.Split(raw, ",") + for _, part := range parts { + part = strings.TrimSpace(part) + if part == "" { + continue + } + kv := strings.SplitN(part, "=", splitParts) + if len(kv) != splitParts { + continue + } + out[raft.ServerAddress(kv[0])] = kv[1] + } + return out +} + +func defaultGroupID(groups []groupSpec) uint64 { + min := uint64(0) + for _, g := range groups { + if min == 0 || g.id < min { + min = g.id + } + } + if min == 0 { + return 1 + } + return min +} + +func validateShardRanges(ranges []rangeSpec, groups []groupSpec) error { + ids := map[uint64]struct{}{} + for _, g := range groups { + ids[g.id] = struct{}{} + } + for _, r := range ranges { + if _, ok := ids[r.groupID]; !ok { + return fmt.Errorf("shard range references unknown group %d", r.groupID) + } + } + return nil +}
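Configuration notes: the new --raftGroups, --shardRanges, and --raftRedisMap flags are plain comma-separated strings ("groupID=host:port,...", "start:end=groupID,..." with an empty bound meaning open-ended, and "raftAddr=redisAddr,..."). The sketch below is a minimal, hypothetical test exercising the unexported parsers added in shard_config.go; it assumes it sits next to that file in package main, and all addresses and key bounds are illustrative only.

```go
package main

import "testing"

// Sketch only: exercises the parsers added in shard_config.go with
// illustrative values; nothing here is a required deployment layout.
func TestShardFlagParsingSketch(t *testing.T) {
	// Two raft groups served by one process, each on its own gRPC address.
	groups, err := parseRaftGroups("1=127.0.0.1:50051,2=127.0.0.1:50052", "")
	if err != nil {
		t.Fatalf("parseRaftGroups: %v", err)
	}
	def := defaultGroupID(groups)
	if def != 1 {
		t.Fatalf("defaultGroupID = %d, want 1", def)
	}

	// Keys below "m" go to group 1; keys from "m" onward (empty end = open-ended) to group 2.
	ranges, err := parseShardRanges(":m=1,m:=2", def)
	if err != nil {
		t.Fatalf("parseShardRanges: %v", err)
	}
	if err := validateShardRanges(ranges, groups); err != nil {
		t.Fatalf("validateShardRanges: %v", err)
	}
	if ranges[1].end != nil {
		t.Fatalf("open-ended range should have nil end, got %q", ranges[1].end)
	}

	// Raft address -> Redis address pairs merged into the leaderRedis map
	// that startRedisServer hands to the Redis adapter.
	redis := parseRaftRedisMap("127.0.0.1:50051=127.0.0.1:63791,127.0.0.1:50052=127.0.0.1:63792")
	if len(redis) != 2 {
		t.Fatalf("expected 2 redis mappings, got %d", len(redis))
	}
}
```

With values like these, a node would be started with something along the lines of --raftGroups 1=127.0.0.1:50051,2=127.0.0.1:50052 --shardRanges :m=1,m:=2; run() then feeds the parsed ranges into buildEngine and the group specs into buildShardGroups.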
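Data layout: each raft group gets its own BoltDB log/stable stores and snapshot store. In multi-group mode, groupDataDir appends a group-&lt;id&gt; component so two groups on the same node never share files, while a single-group node keeps the previous data/&lt;raftId&gt; layout. A small hypothetical test in package main (alongside multiraft_runtime.go) that pins that behaviour:

```go
package main

import (
	"path/filepath"
	"testing"
)

// Sketch only: documents the directory layout produced by groupDataDir.
func TestGroupDataDirLayoutSketch(t *testing.T) {
	// Single-group mode keeps the pre-sharding layout: <raftDataDir>/<raftId>.
	if got, want := groupDataDir("data", "node1", 1, false), filepath.Join("data", "node1"); got != want {
		t.Fatalf("single-group dir = %q, want %q", got, want)
	}
	// Multi-group mode isolates each group: <raftDataDir>/<raftId>/group-<id>.
	if got, want := groupDataDir("data", "node1", 2, true), filepath.Join("data", "node1", "group-2"); got != want {
		t.Fatalf("multi-group dir = %q, want %q", got, want)
	}
}
```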
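Cross-shard transactions: ShardedCoordinator buckets a transaction's mutations by the group that owns each key, then builds one PREPARE and one COMMIT request per group, walking group IDs in ascending order so the resulting log sequence is deterministic. Below is a simplified, standalone sketch of that shape using throwaway stand-in types; it is not the kv package's API, and routeGroup is a hypothetical routing function standing in for the distribution engine.

```go
package main

import (
	"fmt"
	"sort"
)

// Stand-ins for pb.Mutation / pb.Request, only to show the log shape.
type mutation struct{ key, value string }

type request struct {
	phase     string // "PREPARE" or "COMMIT"
	ts        uint64
	mutations []mutation
}

// routeGroup is a hypothetical stand-in for the distribution engine lookup:
// keys below "m" belong to group 1, everything else to group 2.
func routeGroup(key string) uint64 {
	if key < "m" {
		return 1
	}
	return 2
}

// buildTxnLogsSketch mirrors the groupMutations + buildTxnLogs flow:
// bucket by group, sort group IDs, emit a PREPARE/COMMIT pair per group.
func buildTxnLogsSketch(startTS uint64, muts []mutation) []request {
	grouped := make(map[uint64][]mutation)
	for _, m := range muts {
		gid := routeGroup(m.key)
		grouped[gid] = append(grouped[gid], m)
	}
	gids := make([]uint64, 0, len(grouped))
	for gid := range grouped {
		gids = append(gids, gid)
	}
	sort.Slice(gids, func(i, j int) bool { return gids[i] < gids[j] })

	logs := make([]request, 0, len(gids)*2) // two phases per group
	for _, gid := range gids {
		logs = append(logs,
			request{phase: "PREPARE", ts: startTS, mutations: grouped[gid]},
			request{phase: "COMMIT", ts: startTS, mutations: grouped[gid]},
		)
	}
	return logs
}

func main() {
	// A write to "b" (group 1) and "x" (group 2) at start timestamp 42
	// yields four requests: PREPARE/COMMIT for group 1, then for group 2.
	for _, l := range buildTxnLogsSketch(42, []mutation{{"b", "v1"}, {"x", "v2"}}) {
		fmt.Println(l.phase, l.ts, l.mutations)
	}
}
```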