From 417139480cc2aa53ea3ae47141f2bda205feb0fd Mon Sep 17 00:00:00 2001 From: george-dorin Date: Mon, 29 Dec 2025 15:19:04 +0200 Subject: [PATCH 1/8] wip --- core/services/arbiter/arbiter.go | 163 ++++++++++++++++++++++ core/services/arbiter/grpc_server.go | 110 +++++++++++++++ core/services/arbiter/metrics.go | 74 ++++++++++ core/services/arbiter/proto/arbiter.go | 184 +++++++++++++++++++++++++ core/services/arbiter/shardconfig.go | 91 ++++++++++++ core/services/arbiter/state.go | 79 +++++++++++ core/services/arbiter/types.go | 64 +++++++++ 7 files changed, 765 insertions(+) create mode 100644 core/services/arbiter/arbiter.go create mode 100644 core/services/arbiter/grpc_server.go create mode 100644 core/services/arbiter/metrics.go create mode 100644 core/services/arbiter/proto/arbiter.go create mode 100644 core/services/arbiter/shardconfig.go create mode 100644 core/services/arbiter/state.go create mode 100644 core/services/arbiter/types.go diff --git a/core/services/arbiter/arbiter.go b/core/services/arbiter/arbiter.go new file mode 100644 index 00000000000..1f30c406c58 --- /dev/null +++ b/core/services/arbiter/arbiter.go @@ -0,0 +1,163 @@ +package arbiter + +import ( + "context" + "net" + "sync" + + "google.golang.org/grpc" + + "github.com/smartcontractkit/chainlink-common/pkg/services" + "github.com/smartcontractkit/chainlink-common/pkg/types" + + // TODO: Update this import path once proto is generated + pb "github.com/smartcontractkit/chainlink/v2/core/services/arbiter/proto" + + "github.com/smartcontractkit/chainlink/v2/core/logger" +) + +const ( + // DefaultGRPCPort is the default port for the gRPC server. + DefaultGRPCPort = ":9090" +) + +// Arbiter is the main service interface. +type Arbiter interface { + services.Service + HealthReport() map[string]error +} + +type arbiter struct { + services.StateMachine + + grpcServer *grpc.Server + grpcHandler *GRPCServer + state *State + decision DecisionEngine + shardConfig ShardConfigReader + lggr logger.Logger + + grpcAddr string + stopCh services.StopChan + wg sync.WaitGroup +} + +var _ Arbiter = (*arbiter)(nil) + +// New creates a new Arbiter service. +func New( + lggr logger.Logger, + contractReader types.ContractReader, + shardConfigAddr string, +) (Arbiter, error) { + lggr = lggr.Named("Arbiter") + + // Create state + state := NewState() + + // Create ShardConfig reader + shardConfig := NewShardConfigReader(contractReader, shardConfigAddr, lggr) + + // Create decision engine + decision := NewDecisionEngine(shardConfig, lggr) + + // Create gRPC handler + grpcHandler := NewGRPCServer(state, decision, lggr) + + // Create gRPC server + grpcServer := grpc.NewServer() + pb.RegisterArbiterServiceServer(grpcServer, grpcHandler) + + return &arbiter{ + grpcServer: grpcServer, + grpcHandler: grpcHandler, + state: state, + decision: decision, + shardConfig: shardConfig, + lggr: lggr, + grpcAddr: DefaultGRPCPort, + stopCh: make(services.StopChan), + }, nil +} + +// Start starts the Arbiter service. +func (a *arbiter) Start(ctx context.Context) error { + return a.StartOnce("Arbiter", func() error { + a.lggr.Info("Starting Arbiter service") + + // Start gRPC server in a goroutine + a.wg.Add(1) + go func() { + defer a.wg.Done() + a.runGRPCServer() + }() + + a.lggr.Infow("Arbiter service started", + "grpcAddr", a.grpcAddr, + ) + + return nil + }) +} + +// runGRPCServer starts the gRPC server and blocks until stopped. 
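+// It is meant to be run in its own goroutine (Start does this); Serve only
+// returns after GracefulStop is called from Close, and stopCh is used to tell
+// a normal shutdown apart from a listener failure.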
+func (a *arbiter) runGRPCServer() { + lis, err := net.Listen("tcp", a.grpcAddr) + if err != nil { + a.lggr.Errorw("Failed to listen for gRPC", + "addr", a.grpcAddr, + "error", err, + ) + return + } + + a.lggr.Infow("gRPC server listening", + "addr", a.grpcAddr, + ) + + if err := a.grpcServer.Serve(lis); err != nil { + // Check if this is a normal shutdown + select { + case <-a.stopCh: + // Normal shutdown, don't log as error + a.lggr.Debug("gRPC server stopped") + default: + a.lggr.Errorw("gRPC server error", + "error", err, + ) + } + } +} + +// Close stops the Arbiter service. +func (a *arbiter) Close() error { + return a.StopOnce("Arbiter", func() (err error) { + a.lggr.Info("Stopping Arbiter service") + + // Signal stop + close(a.stopCh) + + // Graceful shutdown of gRPC server + a.grpcServer.GracefulStop() + a.lggr.Debug("gRPC server stopped gracefully") + + // Wait for goroutines + a.wg.Wait() + + a.lggr.Info("Arbiter service stopped") + + return nil + }) +} + +// HealthReport returns the health status of the service. +func (a *arbiter) HealthReport() map[string]error { + return map[string]error{ + a.Name(): a.Ready(), + } +} + +// Name returns the service name. +func (a *arbiter) Name() string { + return a.lggr.Name() +} diff --git a/core/services/arbiter/grpc_server.go b/core/services/arbiter/grpc_server.go new file mode 100644 index 00000000000..ea5f5a54c6e --- /dev/null +++ b/core/services/arbiter/grpc_server.go @@ -0,0 +1,110 @@ +package arbiter + +import ( + "context" + "strings" + + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + + // TODO: Update this import path once proto is generated + pb "github.com/smartcontractkit/chainlink/v2/core/services/arbiter/proto" + + "github.com/smartcontractkit/chainlink/v2/core/logger" +) + +// GRPCServer implements the ArbiterService gRPC interface. +type GRPCServer struct { + pb.UnimplementedArbiterServiceServer + state *State + decision DecisionEngine + lggr logger.Logger +} + +// NewGRPCServer creates a new gRPC server instance. +func NewGRPCServer(state *State, decision DecisionEngine, lggr logger.Logger) *GRPCServer { + return &GRPCServer{ + state: state, + decision: decision, + lggr: lggr.Named("GRPCServer"), + } +} + +// SubmitScaleIntent handles the SubmitScaleIntent RPC. 
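+// It validates the request, converts the protobuf replicas to internal types,
+// records the reported state, asks the DecisionEngine for an approved count,
+// and stores that count before acknowledging. A minimal caller sketch,
+// mirroring grpc_server_test.go (names are illustrative):
+//
+//	resp, err := srv.SubmitScaleIntent(ctx, &pb.ScaleIntentRequest{
+//		CurrentReplicas:     map[string]*pb.ShardReplica{"shard-0": {Status: pb.ReleaseStatus_RELEASE_STATUS_READY}},
+//		DesiredReplicaCount: 5,
+//		Reason:              "high CPU utilization",
+//	})
+//
+// resp.Status is "ok" on success; the approved count is surfaced via GetScalingSpec.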
+func (s *GRPCServer) SubmitScaleIntent(ctx context.Context, req *pb.ScaleIntentRequest) (*pb.ScaleIntentResponse, error) { + // Validate request + if req.DesiredReplicaCount < 1 { + RecordRequest("SubmitScaleIntent", "INVALID_ARGUMENT") + return nil, status.Error(codes.InvalidArgument, "DesiredReplicaCount must be at least 1") + } + + // Convert protobuf types to internal types + currentReplicas := make(map[string]ShardReplica) + for name, replica := range req.CurrentReplicas { + currentReplicas[name] = ShardReplica{ + Status: protoStatusToString(replica.Status), + Message: replica.Message, + Metrics: replica.Metrics, + } + } + + // Update state + s.state.Update(currentReplicas, int(req.DesiredReplicaCount), req.Reason) + + // Update metrics + SetCurrentReplicas(len(currentReplicas)) + SetDesiredReplicas(int(req.DesiredReplicaCount)) + + // Compute approved count using decision engine + approved, err := s.decision.ComputeApprovedCount(ctx, int(req.DesiredReplicaCount)) + if err != nil { + s.lggr.Errorw("failed to compute approved count", + "error", err, + "desiredCount", req.DesiredReplicaCount, + ) + RecordRequest("SubmitScaleIntent", "INTERNAL") + return nil, status.Error(codes.Internal, "failed to compute approved count") + } + + // Update state and metrics with approved count + s.state.SetApprovedCount(approved) + SetApprovedReplicas(approved) + + RecordRequest("SubmitScaleIntent", "OK") + + s.lggr.Infow("Processed scale intent", + "currentReplicasCount", len(currentReplicas), + "desiredReplicasCount", req.DesiredReplicaCount, + "approvedReplicasCount", approved, + "reason", req.Reason, + ) + + return &pb.ScaleIntentResponse{Status: "ok"}, nil +} + +// GetScalingSpec handles the GetScalingSpec RPC. +func (s *GRPCServer) GetScalingSpec(ctx context.Context, req *pb.GetScalingSpecRequest) (*pb.ScalingSpecResponse, error) { + spec := s.state.GetScalingSpec() + + RecordRequest("GetScalingSpec", "OK") + + return &pb.ScalingSpecResponse{ + CurrentReplicaCount: int32(spec.CurrentReplicaCount), + DesiredReplicaCount: int32(spec.DesiredReplicaCount), + ApprovedReplicaCount: int32(spec.ApprovedReplicaCount), + LastScalingReason: spec.LastScalingReason, + }, nil +} + +// HealthCheck handles the HealthCheck RPC. +func (s *GRPCServer) HealthCheck(ctx context.Context, req *pb.HealthCheckRequest) (*pb.HealthCheckResponse, error) { + RecordRequest("HealthCheck", "OK") + return &pb.HealthCheckResponse{Status: "ok"}, nil +} + +// protoStatusToString converts protobuf ReleaseStatus to string. +func protoStatusToString(status pb.ReleaseStatus) string { + // Convert RELEASE_STATUS_INSTALLING -> INSTALLING + name := status.String() + return strings.TrimPrefix(name, "RELEASE_STATUS_") +} diff --git a/core/services/arbiter/metrics.go b/core/services/arbiter/metrics.go new file mode 100644 index 00000000000..f7b330e54bf --- /dev/null +++ b/core/services/arbiter/metrics.go @@ -0,0 +1,74 @@ +package arbiter + +import ( + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promauto" +) + +var ( + // requestsTotal counts all gRPC requests by endpoint and status. + requestsTotal = promauto.NewCounterVec( + prometheus.CounterOpts{ + Name: "arbiter_requests_total", + Help: "Total number of requests by endpoint and status", + }, + []string{"endpoint", "status"}, + ) + + // currentReplicas tracks the current number of replicas observed. 
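+	// It is set from the size of the CurrentReplicas map in the most recent
+	// ScaleIntentRequest (see SetCurrentReplicas).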
+ currentReplicas = promauto.NewGauge( + prometheus.GaugeOpts{ + Name: "arbiter_current_replicas", + Help: "Current number of replicas", + }, + ) + + // desiredReplicas tracks the number of replicas KEDA wants. + desiredReplicas = promauto.NewGauge( + prometheus.GaugeOpts{ + Name: "arbiter_desired_replicas", + Help: "Desired number of replicas", + }, + ) + + // approvedReplicas tracks the number of replicas the Arbiter approved. + approvedReplicas = promauto.NewGauge( + prometheus.GaugeOpts{ + Name: "arbiter_approved_replicas", + Help: "Approved number of replicas", + }, + ) + + // onChainMaxReplicas tracks the on-chain governance limit. + onChainMaxReplicas = promauto.NewGauge( + prometheus.GaugeOpts{ + Name: "arbiter_onchain_max_replicas", + Help: "On-chain maximum replicas from ShardConfig contract", + }, + ) +) + +// RecordRequest increments the request counter for the given endpoint and status. +func RecordRequest(endpoint, status string) { + requestsTotal.WithLabelValues(endpoint, status).Inc() +} + +// SetCurrentReplicas sets the current replica count gauge. +func SetCurrentReplicas(count int) { + currentReplicas.Set(float64(count)) +} + +// SetDesiredReplicas sets the desired replica count gauge. +func SetDesiredReplicas(count int) { + desiredReplicas.Set(float64(count)) +} + +// SetApprovedReplicas sets the approved replica count gauge. +func SetApprovedReplicas(count int) { + approvedReplicas.Set(float64(count)) +} + +// SetOnChainMaxReplicas sets the on-chain max replica count gauge. +func SetOnChainMaxReplicas(count uint64) { + onChainMaxReplicas.Set(float64(count)) +} diff --git a/core/services/arbiter/proto/arbiter.go b/core/services/arbiter/proto/arbiter.go new file mode 100644 index 00000000000..6995bbaabb5 --- /dev/null +++ b/core/services/arbiter/proto/arbiter.go @@ -0,0 +1,184 @@ +// Package proto contains placeholder types for the Arbiter gRPC service. +// TODO: Replace this file with the actual generated proto types. +package proto + +import ( + "context" + + "google.golang.org/grpc" +) + +// ReleaseStatus represents the status of a Helm release. +type ReleaseStatus int32 + +const ( + ReleaseStatus_RELEASE_STATUS_UNSPECIFIED ReleaseStatus = 0 + ReleaseStatus_RELEASE_STATUS_INSTALLING ReleaseStatus = 1 + ReleaseStatus_RELEASE_STATUS_INSTALL_SUCCEEDED ReleaseStatus = 2 + ReleaseStatus_RELEASE_STATUS_INSTALL_FAILED ReleaseStatus = 3 + ReleaseStatus_RELEASE_STATUS_READY ReleaseStatus = 4 + ReleaseStatus_RELEASE_STATUS_DEGRADED ReleaseStatus = 5 + ReleaseStatus_RELEASE_STATUS_UNKNOWN ReleaseStatus = 6 +) + +func (s ReleaseStatus) String() string { + switch s { + case ReleaseStatus_RELEASE_STATUS_INSTALLING: + return "RELEASE_STATUS_INSTALLING" + case ReleaseStatus_RELEASE_STATUS_INSTALL_SUCCEEDED: + return "RELEASE_STATUS_INSTALL_SUCCEEDED" + case ReleaseStatus_RELEASE_STATUS_INSTALL_FAILED: + return "RELEASE_STATUS_INSTALL_FAILED" + case ReleaseStatus_RELEASE_STATUS_READY: + return "RELEASE_STATUS_READY" + case ReleaseStatus_RELEASE_STATUS_DEGRADED: + return "RELEASE_STATUS_DEGRADED" + case ReleaseStatus_RELEASE_STATUS_UNKNOWN: + return "RELEASE_STATUS_UNKNOWN" + default: + return "RELEASE_STATUS_UNSPECIFIED" + } +} + +// ShardReplica contains the status and metadata of a single shard replica. +type ShardReplica struct { + Status ReleaseStatus + Message string + Metrics map[string]float64 +} + +// ScaleIntentRequest is sent by the scaler to report current state and desired replicas. 
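+// CurrentReplicas is keyed by replica name (the tests use names such as
+// "shard-0"), DesiredReplicaCount is the count the scaler is asking for, and
+// Reason is free-form text explaining the intent.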
+type ScaleIntentRequest struct { + CurrentReplicas map[string]*ShardReplica + DesiredReplicaCount int32 + Reason string +} + +// ScaleIntentResponse acknowledges receipt of the scale intent. +type ScaleIntentResponse struct { + Status string +} + +// GetScalingSpecRequest is used to query the current scaling specification. +type GetScalingSpecRequest struct { + ScalableUnitName string +} + +// ScalingSpecResponse contains the arbiter's view of the scaling state. +type ScalingSpecResponse struct { + CurrentReplicaCount int32 + DesiredReplicaCount int32 + ApprovedReplicaCount int32 + LastScalingReason string +} + +// HealthCheckRequest is used to verify service health. +type HealthCheckRequest struct{} + +// HealthCheckResponse indicates service health status. +type HealthCheckResponse struct { + Status string +} + +// ArbiterServiceServer is the server interface. +type ArbiterServiceServer interface { + SubmitScaleIntent(context.Context, *ScaleIntentRequest) (*ScaleIntentResponse, error) + GetScalingSpec(context.Context, *GetScalingSpecRequest) (*ScalingSpecResponse, error) + HealthCheck(context.Context, *HealthCheckRequest) (*HealthCheckResponse, error) +} + +// UnimplementedArbiterServiceServer provides default implementations. +type UnimplementedArbiterServiceServer struct{} + +func (UnimplementedArbiterServiceServer) SubmitScaleIntent(context.Context, *ScaleIntentRequest) (*ScaleIntentResponse, error) { + return nil, nil +} + +func (UnimplementedArbiterServiceServer) GetScalingSpec(context.Context, *GetScalingSpecRequest) (*ScalingSpecResponse, error) { + return nil, nil +} + +func (UnimplementedArbiterServiceServer) HealthCheck(context.Context, *HealthCheckRequest) (*HealthCheckResponse, error) { + return nil, nil +} + +// RegisterArbiterServiceServer registers the server with grpc. +func RegisterArbiterServiceServer(s *grpc.Server, srv ArbiterServiceServer) { + s.RegisterService(&ArbiterService_ServiceDesc, srv) +} + +// ArbiterService_ServiceDesc is the service descriptor. 
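+// It hand-mirrors the descriptor that protoc-gen-go-grpc would emit for
+// arbiter.proto and, like the rest of this placeholder package, should be
+// replaced once the real generated code is available.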
+var ArbiterService_ServiceDesc = grpc.ServiceDesc{ + ServiceName: "arbiter.v1.ArbiterService", + HandlerType: (*ArbiterServiceServer)(nil), + Methods: []grpc.MethodDesc{ + { + MethodName: "SubmitScaleIntent", + Handler: _ArbiterService_SubmitScaleIntent_Handler, + }, + { + MethodName: "GetScalingSpec", + Handler: _ArbiterService_GetScalingSpec_Handler, + }, + { + MethodName: "HealthCheck", + Handler: _ArbiterService_HealthCheck_Handler, + }, + }, + Streams: []grpc.StreamDesc{}, + Metadata: "arbiter.proto", +} + +func _ArbiterService_SubmitScaleIntent_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(ScaleIntentRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(ArbiterServiceServer).SubmitScaleIntent(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/arbiter.v1.ArbiterService/SubmitScaleIntent", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(ArbiterServiceServer).SubmitScaleIntent(ctx, req.(*ScaleIntentRequest)) + } + return interceptor(ctx, in, info, handler) +} + +func _ArbiterService_GetScalingSpec_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(GetScalingSpecRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(ArbiterServiceServer).GetScalingSpec(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/arbiter.v1.ArbiterService/GetScalingSpec", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(ArbiterServiceServer).GetScalingSpec(ctx, req.(*GetScalingSpecRequest)) + } + return interceptor(ctx, in, info, handler) +} + +func _ArbiterService_HealthCheck_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(HealthCheckRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(ArbiterServiceServer).HealthCheck(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/arbiter.v1.ArbiterService/HealthCheck", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(ArbiterServiceServer).HealthCheck(ctx, req.(*HealthCheckRequest)) + } + return interceptor(ctx, in, info, handler) +} diff --git a/core/services/arbiter/shardconfig.go b/core/services/arbiter/shardconfig.go new file mode 100644 index 00000000000..0ed9fe92558 --- /dev/null +++ b/core/services/arbiter/shardconfig.go @@ -0,0 +1,91 @@ +package arbiter + +import ( + "context" + "math/big" + + "github.com/smartcontractkit/chainlink-common/pkg/types" + "github.com/smartcontractkit/chainlink-common/pkg/types/query/primitives" + + "github.com/smartcontractkit/chainlink/v2/core/logger" +) + +const ( + // ShardConfigContractName is the name used to identify the ShardConfig contract. + ShardConfigContractName = "ShardConfig" + + // GetDesiredShardCountMethod is the method name for reading the desired shard count. + GetDesiredShardCountMethod = "getDesiredShardCount" +) + +// ShardConfigABI is the ABI for the ShardConfig contract. +// Only includes the getDesiredShardCount function we need. 
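+// The JSON below corresponds to a Solidity function of the form
+// `function getDesiredShardCount() view returns (uint256)`.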
+const ShardConfigABI = `[ + { + "inputs": [], + "name": "getDesiredShardCount", + "outputs": [{"internalType": "uint256", "name": "", "type": "uint256"}], + "stateMutability": "view", + "type": "function" + } +]` + +// ShardConfigReader reads the desired shard count from the ShardConfig contract. +type ShardConfigReader interface { + // GetDesiredShardCount retrieves the current desired shard count from on-chain. + GetDesiredShardCount(ctx context.Context) (uint64, error) +} + +type shardConfigReader struct { + reader types.ContractReader + contract types.BoundContract + lggr logger.Logger +} + +// NewShardConfigReader creates a new ShardConfigReader. +func NewShardConfigReader( + reader types.ContractReader, + contractAddress string, + lggr logger.Logger, +) ShardConfigReader { + return &shardConfigReader{ + reader: reader, + contract: types.BoundContract{ + Address: contractAddress, + Name: ShardConfigContractName, + }, + lggr: lggr.Named("ShardConfigReader"), + } +} + +// GetDesiredShardCount retrieves the current desired shard count from on-chain. +func (s *shardConfigReader) GetDesiredShardCount(ctx context.Context) (uint64, error) { + var result *big.Int + + err := s.reader.GetLatestValue( + ctx, + s.contract.ReadIdentifier(GetDesiredShardCountMethod), + primitives.Finalized, // Use finalized for governance data + nil, // No input params + &result, + ) + if err != nil { + s.lggr.Errorw("failed to get desired shard count from on-chain", + "error", err, + "contract", s.contract.Address, + ) + return 0, err + } + + count := result.Uint64() + + // Update metrics + SetOnChainMaxReplicas(count) + + s.lggr.Debugw("read desired shard count from on-chain", + "count", count, + "contract", s.contract.Address, + ) + + return count, nil +} diff --git a/core/services/arbiter/state.go b/core/services/arbiter/state.go new file mode 100644 index 00000000000..6a83dfed80e --- /dev/null +++ b/core/services/arbiter/state.go @@ -0,0 +1,79 @@ +package arbiter + +import ( + "sync" +) + +// State holds the current scaling state. +type State struct { + currentReplicas map[string]ShardReplica + lastScalingReason string + desiredReplicasCount int + approvedReplicasCount int + mu sync.RWMutex +} + +// NewState creates a new State with default values. +func NewState() *State { + return &State{ + currentReplicas: make(map[string]ShardReplica), + desiredReplicasCount: 1, + approvedReplicasCount: 1, + lastScalingReason: "Initial state", + } +} + +// Update updates the state with new scale intent data. +func (s *State) Update(currentReplicas map[string]ShardReplica, desiredCount int, reason string) { + s.mu.Lock() + defer s.mu.Unlock() + + s.currentReplicas = currentReplicas + s.desiredReplicasCount = desiredCount + s.lastScalingReason = reason +} + +// SetApprovedCount sets the approved replica count. +func (s *State) SetApprovedCount(count int) { + s.mu.Lock() + defer s.mu.Unlock() + + s.approvedReplicasCount = count +} + +// GetScalingSpec returns the current scaling specification. +func (s *State) GetScalingSpec() ScalingSpecResponse { + s.mu.RLock() + defer s.mu.RUnlock() + + return ScalingSpecResponse{ + CurrentReplicaCount: len(s.currentReplicas), + DesiredReplicaCount: s.desiredReplicasCount, + ApprovedReplicaCount: s.approvedReplicasCount, + LastScalingReason: s.lastScalingReason, + } +} + +// GetCurrentReplicaCount returns the current number of replicas. 
+func (s *State) GetCurrentReplicaCount() int {
+	s.mu.RLock()
+	defer s.mu.RUnlock()
+
+	return len(s.currentReplicas)
+}
+
+// GetDesiredReplicaCount returns the desired number of replicas.
+func (s *State) GetDesiredReplicaCount() int {
+	s.mu.RLock()
+	defer s.mu.RUnlock()
+
+	return s.desiredReplicasCount
+}
+
+// GetApprovedReplicaCount returns the approved number of replicas.
+func (s *State) GetApprovedReplicaCount() int {
+	s.mu.RLock()
+	defer s.mu.RUnlock()
+
+	return s.approvedReplicasCount
+}
diff --git a/core/services/arbiter/types.go b/core/services/arbiter/types.go
new file mode 100644
index 00000000000..59b129212de
--- /dev/null
+++ b/core/services/arbiter/types.go
@@ -0,0 +1,64 @@
+package arbiter
+
+// ReleaseStatus represents the status of a Helm release.
+type ReleaseStatus int
+
+const (
+	// StatusInstalling indicates that the helm release is being installed.
+	StatusInstalling ReleaseStatus = iota
+	// StatusInstallSucceeded indicates that the helm release installation succeeded.
+	StatusInstallSucceeded
+	// StatusInstallFailed indicates that the helm release installation failed.
+	StatusInstallFailed
+	// StatusReady indicates that the deployed service is ready and operational.
+	StatusReady
+	// StatusDegraded indicates that the deployed service is degraded.
+	StatusDegraded
+	// StatusUnknown indicates that the status of the helm release is unknown.
+	StatusUnknown
+)
+
+// String returns the string representation of ReleaseStatus.
+func (s ReleaseStatus) String() string {
+	switch s {
+	case StatusInstalling:
+		return "INSTALLING"
+	case StatusInstallSucceeded:
+		return "INSTALL_SUCCEEDED"
+	case StatusInstallFailed:
+		return "INSTALL_FAILED"
+	case StatusReady:
+		return "READY"
+	case StatusDegraded:
+		return "DEGRADED"
+	default:
+		return "UNKNOWN"
+	}
+}
+
+// ShardReplica represents the status and message of a single shard replica.
+type ShardReplica struct {
+	Status  string             `json:"status"`
+	Message string             `json:"message"`
+	Metrics map[string]float64 `json:"metrics,omitempty"`
+}
+
+// ScaleIntentRequest represents the request body for the scale-intents endpoint.
+type ScaleIntentRequest struct {
+	CurrentReplicas     map[string]ShardReplica `json:"currentReplicas"`
+	Reason              string                  `json:"reason"`
+	DesiredReplicaCount int                     `json:"desiredReplicaCount"`
+}
+
+// ScalingSpecResponse represents the response body for the scaling-spec endpoint.
+type ScalingSpecResponse struct {
+	LastScalingReason    string `json:"lastScalingReason"`
+	CurrentReplicaCount  int    `json:"currentReplicaCount"`
+	DesiredReplicaCount  int    `json:"desiredReplicaCount"`
+	ApprovedReplicaCount int    `json:"approvedReplicaCount"`
+}
+
+// StatusResponse represents a simple status response.
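+// It serializes as, e.g., {"status": "ok"}, mirroring the "ok"
+// acknowledgements returned by the gRPC handlers.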
+type StatusResponse struct { + Status string `json:"status"` +} From 2c78324da5717ffcf600c6cc1133def0d3ea0ff9 Mon Sep 17 00:00:00 2001 From: george-dorin Date: Tue, 30 Dec 2025 12:57:24 +0200 Subject: [PATCH 2/8] wip --- core/services/arbiter/arbiter_test.go | 234 ++++++++++++++++++++ core/services/arbiter/decision.go | 68 ++++++ core/services/arbiter/decision_test.go | 173 +++++++++++++++ core/services/arbiter/grpc_server_test.go | 255 ++++++++++++++++++++++ core/services/arbiter/state_test.go | 144 ++++++++++++ 5 files changed, 874 insertions(+) create mode 100644 core/services/arbiter/arbiter_test.go create mode 100644 core/services/arbiter/decision.go create mode 100644 core/services/arbiter/decision_test.go create mode 100644 core/services/arbiter/grpc_server_test.go create mode 100644 core/services/arbiter/state_test.go diff --git a/core/services/arbiter/arbiter_test.go b/core/services/arbiter/arbiter_test.go new file mode 100644 index 00000000000..5d51ca2e709 --- /dev/null +++ b/core/services/arbiter/arbiter_test.go @@ -0,0 +1,234 @@ +package arbiter + +import ( + "context" + "iter" + "math/big" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/smartcontractkit/chainlink-common/pkg/services/servicetest" + "github.com/smartcontractkit/chainlink-common/pkg/types" + "github.com/smartcontractkit/chainlink-common/pkg/types/query" + "github.com/smartcontractkit/chainlink-common/pkg/types/query/primitives" + + "github.com/smartcontractkit/chainlink/v2/core/logger" +) + +// mockContractReader is a mock implementation of types.ContractReader for testing. +type mockContractReader struct { + types.UnimplementedContractReader + desiredShardCount uint64 + err error +} + +func (m *mockContractReader) Name() string { + return "mockContractReader" +} + +func (m *mockContractReader) Start(ctx context.Context) error { + return nil +} + +func (m *mockContractReader) Close() error { + return nil +} + +func (m *mockContractReader) Ready() error { + return nil +} + +func (m *mockContractReader) HealthReport() map[string]error { + return nil +} + +func (m *mockContractReader) Bind(ctx context.Context, bindings []types.BoundContract) error { + return nil +} + +func (m *mockContractReader) Unbind(ctx context.Context, bindings []types.BoundContract) error { + return nil +} + +func (m *mockContractReader) GetLatestValue(ctx context.Context, readIdentifier string, confidenceLevel primitives.ConfidenceLevel, params any, returnVal any) error { + if m.err != nil { + return m.err + } + // Set the result to our mock value + if ptr, ok := returnVal.(**big.Int); ok { + *ptr = big.NewInt(int64(m.desiredShardCount)) + } + return nil +} + +func (m *mockContractReader) GetLatestValueWithHeadData(ctx context.Context, readIdentifier string, confidenceLevel primitives.ConfidenceLevel, params any, returnVal any) (head *types.Head, err error) { + err = m.GetLatestValue(ctx, readIdentifier, confidenceLevel, params, returnVal) + return nil, err +} + +func (m *mockContractReader) BatchGetLatestValues(ctx context.Context, request types.BatchGetLatestValuesRequest) (types.BatchGetLatestValuesResult, error) { + return nil, nil +} + +func (m *mockContractReader) QueryKey(ctx context.Context, contract types.BoundContract, filter query.KeyFilter, limitAndSort query.LimitAndSort, sequenceDataType any) ([]types.Sequence, error) { + return nil, nil +} + +func (m *mockContractReader) QueryKeys(ctx context.Context, filters []types.ContractKeyFilter, limitAndSort 
query.LimitAndSort) (iter.Seq2[string, types.Sequence], error) { + return nil, nil +} + +func TestArbiter_New(t *testing.T) { + lggr := logger.TestLogger(t) + mockReader := &mockContractReader{desiredShardCount: 10} + + arb, err := New(lggr, mockReader, "0x1234567890abcdef") + + require.NoError(t, err) + require.NotNil(t, arb) + assert.Equal(t, "Arbiter", arb.Name()) +} + +func TestArbiter_StartClose(t *testing.T) { + lggr := logger.TestLogger(t) + mockReader := &mockContractReader{desiredShardCount: 10} + + arb, err := New(lggr, mockReader, "0x1234567890abcdef") + require.NoError(t, err) + + // Test start + err = arb.Start(context.Background()) + require.NoError(t, err) + + // Give gRPC server a moment to start + time.Sleep(50 * time.Millisecond) + + // Test health after start + healthReport := arb.HealthReport() + require.Contains(t, healthReport, arb.Name()) + assert.NoError(t, healthReport[arb.Name()]) + + // Test close + err = arb.Close() + require.NoError(t, err) +} + +func TestArbiter_ServiceTestRun(t *testing.T) { + lggr := logger.TestLogger(t) + mockReader := &mockContractReader{desiredShardCount: 10} + + arb, err := New(lggr, mockReader, "0x1234567890abcdef") + require.NoError(t, err) + + // Use servicetest.Run to handle lifecycle + // This starts the service and registers cleanup to stop it + servicetest.Run(t, arb) + + // Service should be running after servicetest.Run + err = arb.Ready() + require.NoError(t, err) +} + +func TestArbiter_HealthReport(t *testing.T) { + lggr := logger.TestLogger(t) + mockReader := &mockContractReader{desiredShardCount: 10} + + arb, err := New(lggr, mockReader, "0x1234567890abcdef") + require.NoError(t, err) + + t.Run("before start - not ready", func(t *testing.T) { + healthReport := arb.HealthReport() + require.Contains(t, healthReport, arb.Name()) + // Before start, Ready() should return an error + assert.Error(t, healthReport[arb.Name()]) + }) + + t.Run("after start - ready", func(t *testing.T) { + err := arb.Start(context.Background()) + require.NoError(t, err) + + healthReport := arb.HealthReport() + require.Contains(t, healthReport, arb.Name()) + assert.NoError(t, healthReport[arb.Name()]) + + err = arb.Close() + require.NoError(t, err) + }) +} + +func TestArbiter_DoubleStart(t *testing.T) { + lggr := logger.TestLogger(t) + mockReader := &mockContractReader{desiredShardCount: 10} + + arb, err := New(lggr, mockReader, "0x1234567890abcdef") + require.NoError(t, err) + + // First start should succeed + err = arb.Start(context.Background()) + require.NoError(t, err) + + // Second start should return error (StartOnce) + err = arb.Start(context.Background()) + assert.Error(t, err) + + err = arb.Close() + require.NoError(t, err) +} + +func TestArbiter_DoubleClose(t *testing.T) { + lggr := logger.TestLogger(t) + mockReader := &mockContractReader{desiredShardCount: 10} + + arb, err := New(lggr, mockReader, "0x1234567890abcdef") + require.NoError(t, err) + + err = arb.Start(context.Background()) + require.NoError(t, err) + + // First close should succeed + err = arb.Close() + require.NoError(t, err) + + // Second close should return error (StopOnce) + err = arb.Close() + assert.Error(t, err) +} + +func TestArbiter_Name(t *testing.T) { + lggr := logger.TestLogger(t) + mockReader := &mockContractReader{desiredShardCount: 10} + + arb, err := New(lggr, mockReader, "0x1234567890abcdef") + require.NoError(t, err) + + assert.Equal(t, "Arbiter", arb.Name()) +} + +func TestArbiter_Ready(t *testing.T) { + lggr := logger.TestLogger(t) + mockReader := 
&mockContractReader{desiredShardCount: 10} + + arb, err := New(lggr, mockReader, "0x1234567890abcdef") + require.NoError(t, err) + + // Before start, Ready should return error + err = arb.Ready() + assert.Error(t, err) + + // After start, Ready should return nil + err = arb.Start(context.Background()) + require.NoError(t, err) + + err = arb.Ready() + assert.NoError(t, err) + + // After close, Ready should return error + err = arb.Close() + require.NoError(t, err) + + err = arb.Ready() + assert.Error(t, err) +} diff --git a/core/services/arbiter/decision.go b/core/services/arbiter/decision.go new file mode 100644 index 00000000000..d69c39dddba --- /dev/null +++ b/core/services/arbiter/decision.go @@ -0,0 +1,68 @@ +package arbiter + +import ( + "context" + "fmt" + + "github.com/smartcontractkit/chainlink/v2/core/logger" +) + +// DecisionEngine computes the approved replica count based on inputs. +type DecisionEngine interface { + // ComputeApprovedCount takes KEDA's desired count and returns approved count. + ComputeApprovedCount(ctx context.Context, desiredCount int) (int, error) +} + +type decisionEngine struct { + shardConfig ShardConfigReader + lggr logger.Logger +} + +// NewDecisionEngine creates a new DecisionEngine. +func NewDecisionEngine(shardConfig ShardConfigReader, lggr logger.Logger) DecisionEngine { + return &decisionEngine{ + shardConfig: shardConfig, + lggr: lggr.Named("DecisionEngine"), + } +} + +// ComputeApprovedCount applies the decision logic: +// approved = min(desired, onChainMax) +// with a minimum of 1 shard always. +func (d *decisionEngine) ComputeApprovedCount(ctx context.Context, desiredCount int) (int, error) { + // Get on-chain limit from ShardConfig contract + maxAllowed, err := d.shardConfig.GetDesiredShardCount(ctx) + if err != nil { + return 0, fmt.Errorf("failed to get on-chain shard limit: %w", err) + } + + // Apply constraint: approved = min(desired, max) + approved := desiredCount + maxAllowedInt := int(maxAllowed) + + if approved > maxAllowedInt { + d.lggr.Infow("Capping desired count to on-chain limit", + "desired", desiredCount, + "maxAllowed", maxAllowedInt, + "approved", maxAllowedInt, + ) + approved = maxAllowedInt + } + + // Ensure minimum of 1 shard + if approved < 1 { + d.lggr.Warnw("Desired count is less than 1, setting to minimum", + "desired", desiredCount, + "approved", 1, + ) + approved = 1 + } + + d.lggr.Debugw("Computed approved replica count", + "desired", desiredCount, + "maxAllowed", maxAllowedInt, + "approved", approved, + ) + + return approved, nil +} diff --git a/core/services/arbiter/decision_test.go b/core/services/arbiter/decision_test.go new file mode 100644 index 00000000000..a976b61a560 --- /dev/null +++ b/core/services/arbiter/decision_test.go @@ -0,0 +1,173 @@ +package arbiter + +import ( + "context" + "errors" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/smartcontractkit/chainlink/v2/core/logger" +) + +// mockShardConfigReader is a mock implementation of ShardConfigReader for testing. 
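+// It returns the configured desiredCount/err pair and never touches the chain.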
+type mockShardConfigReader struct { + desiredCount uint64 + err error +} + +func (m *mockShardConfigReader) GetDesiredShardCount(ctx context.Context) (uint64, error) { + return m.desiredCount, m.err +} + +func TestDecisionEngine_ComputeApprovedCount(t *testing.T) { + tests := []struct { + name string + desiredCount int + onChainMax uint64 + shardConfigErr error + expectedResult int + expectError bool + }{ + { + name: "desired under limit", + desiredCount: 5, + onChainMax: 10, + expectedResult: 5, + expectError: false, + }, + { + name: "desired equals limit", + desiredCount: 10, + onChainMax: 10, + expectedResult: 10, + expectError: false, + }, + { + name: "desired exceeds limit - capped", + desiredCount: 15, + onChainMax: 10, + expectedResult: 10, + expectError: false, + }, + { + name: "desired zero - minimum 1", + desiredCount: 0, + onChainMax: 10, + expectedResult: 1, + expectError: false, + }, + { + name: "negative desired - minimum 1", + desiredCount: -5, + onChainMax: 10, + expectedResult: 1, + expectError: false, + }, + { + name: "small on-chain limit caps result", + desiredCount: 5, + onChainMax: 3, + expectedResult: 3, + expectError: false, + }, + { + name: "on-chain limit of 1", + desiredCount: 100, + onChainMax: 1, + expectedResult: 1, + expectError: false, + }, + { + name: "shard config error", + desiredCount: 5, + onChainMax: 10, + shardConfigErr: errors.New("contract read failed"), + expectedResult: 0, + expectError: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + lggr := logger.TestLogger(t) + + mockReader := &mockShardConfigReader{ + desiredCount: tc.onChainMax, + err: tc.shardConfigErr, + } + + engine := NewDecisionEngine(mockReader, lggr) + + result, err := engine.ComputeApprovedCount(context.Background(), tc.desiredCount) + + if tc.expectError { + require.Error(t, err) + assert.Equal(t, 0, result) + } else { + require.NoError(t, err) + assert.Equal(t, tc.expectedResult, result) + } + }) + } +} + +func TestDecisionEngine_ComputeApprovedCount_EdgeCases(t *testing.T) { + lggr := logger.TestLogger(t) + + t.Run("large desired count capped to large on-chain limit", func(t *testing.T) { + mockReader := &mockShardConfigReader{ + desiredCount: 1000, + } + engine := NewDecisionEngine(mockReader, lggr) + + result, err := engine.ComputeApprovedCount(context.Background(), 500) + + require.NoError(t, err) + assert.Equal(t, 500, result) + }) + + t.Run("exactly at on-chain limit", func(t *testing.T) { + mockReader := &mockShardConfigReader{ + desiredCount: 7, + } + engine := NewDecisionEngine(mockReader, lggr) + + result, err := engine.ComputeApprovedCount(context.Background(), 7) + + require.NoError(t, err) + assert.Equal(t, 7, result) + }) + + t.Run("on-chain limit is zero - minimum 1 applied", func(t *testing.T) { + mockReader := &mockShardConfigReader{ + desiredCount: 0, + } + engine := NewDecisionEngine(mockReader, lggr) + + result, err := engine.ComputeApprovedCount(context.Background(), 5) + + require.NoError(t, err) + // approved = min(5, 0) = 0, but minimum is 1 + assert.Equal(t, 1, result) + }) +} + +func TestDecisionEngine_ContextCancellation(t *testing.T) { + lggr := logger.TestLogger(t) + + t.Run("context cancellation propagated", func(t *testing.T) { + mockReader := &mockShardConfigReader{ + err: context.Canceled, + } + engine := NewDecisionEngine(mockReader, lggr) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + _, err := engine.ComputeApprovedCount(ctx, 5) + + require.Error(t, err) + }) +} diff --git 
a/core/services/arbiter/grpc_server_test.go b/core/services/arbiter/grpc_server_test.go new file mode 100644 index 00000000000..288d939da25 --- /dev/null +++ b/core/services/arbiter/grpc_server_test.go @@ -0,0 +1,255 @@ +package arbiter + +import ( + "context" + "errors" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + + pb "github.com/smartcontractkit/chainlink/v2/core/services/arbiter/proto" + + "github.com/smartcontractkit/chainlink/v2/core/logger" +) + +// mockDecisionEngine is a mock implementation of DecisionEngine for testing. +type mockDecisionEngine struct { + approvedCount int + err error +} + +func (m *mockDecisionEngine) ComputeApprovedCount(ctx context.Context, desiredCount int) (int, error) { + return m.approvedCount, m.err +} + +func TestGRPCServer_SubmitScaleIntent_Success(t *testing.T) { + lggr := logger.TestLogger(t) + state := NewState() + mockDecision := &mockDecisionEngine{approvedCount: 5} + + server := NewGRPCServer(state, mockDecision, lggr) + + req := &pb.ScaleIntentRequest{ + CurrentReplicas: map[string]*pb.ShardReplica{ + "shard-0": {Status: pb.ReleaseStatus_RELEASE_STATUS_READY, Message: "Running"}, + "shard-1": {Status: pb.ReleaseStatus_RELEASE_STATUS_READY, Message: "Running"}, + }, + DesiredReplicaCount: 5, + Reason: "high CPU utilization", + } + + resp, err := server.SubmitScaleIntent(context.Background(), req) + + require.NoError(t, err) + require.NotNil(t, resp) + assert.Equal(t, "ok", resp.Status) + + // Verify state was updated + spec := state.GetScalingSpec() + assert.Equal(t, 2, spec.CurrentReplicaCount) + assert.Equal(t, 5, spec.DesiredReplicaCount) + assert.Equal(t, 5, spec.ApprovedReplicaCount) + assert.Equal(t, "high CPU utilization", spec.LastScalingReason) +} + +func TestGRPCServer_SubmitScaleIntent_InvalidArgument(t *testing.T) { + lggr := logger.TestLogger(t) + state := NewState() + mockDecision := &mockDecisionEngine{approvedCount: 1} + + server := NewGRPCServer(state, mockDecision, lggr) + + tests := []struct { + name string + request *pb.ScaleIntentRequest + }{ + { + name: "desired count is zero", + request: &pb.ScaleIntentRequest{ + CurrentReplicas: map[string]*pb.ShardReplica{}, + DesiredReplicaCount: 0, + Reason: "test", + }, + }, + { + name: "desired count is negative", + request: &pb.ScaleIntentRequest{ + CurrentReplicas: map[string]*pb.ShardReplica{}, + DesiredReplicaCount: -1, + Reason: "test", + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + resp, err := server.SubmitScaleIntent(context.Background(), tc.request) + + require.Error(t, err) + assert.Nil(t, resp) + + st, ok := status.FromError(err) + require.True(t, ok) + assert.Equal(t, codes.InvalidArgument, st.Code()) + }) + } +} + +func TestGRPCServer_SubmitScaleIntent_DecisionEngineError(t *testing.T) { + lggr := logger.TestLogger(t) + state := NewState() + mockDecision := &mockDecisionEngine{ + err: errors.New("contract read failed"), + } + + server := NewGRPCServer(state, mockDecision, lggr) + + req := &pb.ScaleIntentRequest{ + CurrentReplicas: map[string]*pb.ShardReplica{}, + DesiredReplicaCount: 5, + Reason: "test", + } + + resp, err := server.SubmitScaleIntent(context.Background(), req) + + require.Error(t, err) + assert.Nil(t, resp) + + st, ok := status.FromError(err) + require.True(t, ok) + assert.Equal(t, codes.Internal, st.Code()) +} + +func TestGRPCServer_GetScalingSpec(t *testing.T) { + lggr := logger.TestLogger(t) + state := 
NewState() + mockDecision := &mockDecisionEngine{approvedCount: 5} + + // Set up some state + replicas := map[string]ShardReplica{ + "shard-0": {Status: "READY", Message: "Running"}, + "shard-1": {Status: "READY", Message: "Running"}, + "shard-2": {Status: "INSTALLING", Message: "In progress"}, + } + state.Update(replicas, 10, "scale-up request") + state.SetApprovedCount(8) + + server := NewGRPCServer(state, mockDecision, lggr) + + resp, err := server.GetScalingSpec(context.Background(), &pb.GetScalingSpecRequest{}) + + require.NoError(t, err) + require.NotNil(t, resp) + + assert.Equal(t, int32(3), resp.CurrentReplicaCount) + assert.Equal(t, int32(10), resp.DesiredReplicaCount) + assert.Equal(t, int32(8), resp.ApprovedReplicaCount) + assert.Equal(t, "scale-up request", resp.LastScalingReason) +} + +func TestGRPCServer_GetScalingSpec_InitialState(t *testing.T) { + lggr := logger.TestLogger(t) + state := NewState() + mockDecision := &mockDecisionEngine{} + + server := NewGRPCServer(state, mockDecision, lggr) + + resp, err := server.GetScalingSpec(context.Background(), &pb.GetScalingSpecRequest{}) + + require.NoError(t, err) + require.NotNil(t, resp) + + assert.Equal(t, int32(0), resp.CurrentReplicaCount) + assert.Equal(t, int32(1), resp.DesiredReplicaCount) + assert.Equal(t, int32(1), resp.ApprovedReplicaCount) + assert.Equal(t, "Initial state", resp.LastScalingReason) +} + +func TestGRPCServer_HealthCheck(t *testing.T) { + lggr := logger.TestLogger(t) + state := NewState() + mockDecision := &mockDecisionEngine{} + + server := NewGRPCServer(state, mockDecision, lggr) + + resp, err := server.HealthCheck(context.Background(), &pb.HealthCheckRequest{}) + + require.NoError(t, err) + require.NotNil(t, resp) + assert.Equal(t, "ok", resp.Status) +} + +func TestGRPCServer_SubmitScaleIntent_WithMetrics(t *testing.T) { + lggr := logger.TestLogger(t) + state := NewState() + mockDecision := &mockDecisionEngine{approvedCount: 3} + + server := NewGRPCServer(state, mockDecision, lggr) + + req := &pb.ScaleIntentRequest{ + CurrentReplicas: map[string]*pb.ShardReplica{ + "shard-0": { + Status: pb.ReleaseStatus_RELEASE_STATUS_READY, + Message: "Running", + Metrics: map[string]float64{ + "cpu_usage": 0.75, + "memory_usage": 0.60, + }, + }, + }, + DesiredReplicaCount: 3, + Reason: "metrics-based scaling", + } + + resp, err := server.SubmitScaleIntent(context.Background(), req) + + require.NoError(t, err) + require.NotNil(t, resp) + assert.Equal(t, "ok", resp.Status) + + spec := state.GetScalingSpec() + assert.Equal(t, 1, spec.CurrentReplicaCount) + assert.Equal(t, 3, spec.DesiredReplicaCount) + assert.Equal(t, 3, spec.ApprovedReplicaCount) +} + +func TestGRPCServer_SubmitScaleIntent_AllReleaseStatuses(t *testing.T) { + lggr := logger.TestLogger(t) + + statuses := []pb.ReleaseStatus{ + pb.ReleaseStatus_RELEASE_STATUS_INSTALLING, + pb.ReleaseStatus_RELEASE_STATUS_INSTALL_SUCCEEDED, + pb.ReleaseStatus_RELEASE_STATUS_INSTALL_FAILED, + pb.ReleaseStatus_RELEASE_STATUS_READY, + pb.ReleaseStatus_RELEASE_STATUS_DEGRADED, + pb.ReleaseStatus_RELEASE_STATUS_UNKNOWN, + } + + for _, s := range statuses { + t.Run(s.String(), func(t *testing.T) { + state := NewState() + mockDecision := &mockDecisionEngine{approvedCount: 1} + server := NewGRPCServer(state, mockDecision, lggr) + + req := &pb.ScaleIntentRequest{ + CurrentReplicas: map[string]*pb.ShardReplica{ + "shard-0": { + Status: s, + Message: "test", + }, + }, + DesiredReplicaCount: 1, + Reason: "status test", + } + + resp, err := 
server.SubmitScaleIntent(context.Background(), req) + + require.NoError(t, err) + require.NotNil(t, resp) + assert.Equal(t, "ok", resp.Status) + }) + } +} diff --git a/core/services/arbiter/state_test.go b/core/services/arbiter/state_test.go new file mode 100644 index 00000000000..6ee761f76ad --- /dev/null +++ b/core/services/arbiter/state_test.go @@ -0,0 +1,144 @@ +package arbiter + +import ( + "sync" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestState_NewState(t *testing.T) { + state := NewState() + + require.NotNil(t, state) + assert.Equal(t, 1, state.GetDesiredReplicaCount()) + assert.Equal(t, 1, state.GetApprovedReplicaCount()) + assert.Equal(t, 0, state.GetCurrentReplicaCount()) + + spec := state.GetScalingSpec() + assert.Equal(t, "Initial state", spec.LastScalingReason) +} + +func TestState_Update(t *testing.T) { + state := NewState() + + replicas := map[string]ShardReplica{ + "shard-0": {Status: "READY", Message: "Running"}, + "shard-1": {Status: "READY", Message: "Running"}, + "shard-2": {Status: "INSTALLING", Message: "In progress"}, + } + + state.Update(replicas, 5, "high CPU utilization") + + assert.Equal(t, 3, state.GetCurrentReplicaCount()) + assert.Equal(t, 5, state.GetDesiredReplicaCount()) + + spec := state.GetScalingSpec() + assert.Equal(t, "high CPU utilization", spec.LastScalingReason) + assert.Equal(t, 3, spec.CurrentReplicaCount) + assert.Equal(t, 5, spec.DesiredReplicaCount) +} + +func TestState_SetApprovedCount(t *testing.T) { + state := NewState() + + state.SetApprovedCount(7) + + assert.Equal(t, 7, state.GetApprovedReplicaCount()) + + spec := state.GetScalingSpec() + assert.Equal(t, 7, spec.ApprovedReplicaCount) +} + +func TestState_GetScalingSpec(t *testing.T) { + state := NewState() + + replicas := map[string]ShardReplica{ + "shard-0": {Status: "READY", Message: "Running"}, + "shard-1": {Status: "READY", Message: "Running"}, + } + + state.Update(replicas, 10, "scale-up request") + state.SetApprovedCount(8) + + spec := state.GetScalingSpec() + + assert.Equal(t, 2, spec.CurrentReplicaCount) + assert.Equal(t, 10, spec.DesiredReplicaCount) + assert.Equal(t, 8, spec.ApprovedReplicaCount) + assert.Equal(t, "scale-up request", spec.LastScalingReason) +} + +func TestState_Concurrency(t *testing.T) { + state := NewState() + var wg sync.WaitGroup + + // Run concurrent writes and reads + for i := 0; i < 100; i++ { + wg.Add(3) + + // Writer goroutine - Update + go func(i int) { + defer wg.Done() + replicas := map[string]ShardReplica{ + "shard-0": {Status: "READY", Message: "Running"}, + } + state.Update(replicas, i, "concurrent update") + }(i) + + // Writer goroutine - SetApprovedCount + go func(i int) { + defer wg.Done() + state.SetApprovedCount(i) + }(i) + + // Reader goroutine - GetScalingSpec + go func() { + defer wg.Done() + _ = state.GetScalingSpec() + }() + } + + wg.Wait() + + // If we got here without data race, the test passes + // The actual values don't matter, we're testing thread safety + spec := state.GetScalingSpec() + assert.NotNil(t, spec) +} + +func TestState_UpdateWithEmptyReplicas(t *testing.T) { + state := NewState() + + // Start with some replicas + replicas := map[string]ShardReplica{ + "shard-0": {Status: "READY", Message: "Running"}, + } + state.Update(replicas, 3, "initial") + assert.Equal(t, 1, state.GetCurrentReplicaCount()) + + // Update with empty replicas + state.Update(map[string]ShardReplica{}, 1, "scale-down") + assert.Equal(t, 0, state.GetCurrentReplicaCount()) + assert.Equal(t, 
1, state.GetDesiredReplicaCount()) +} + +func TestState_UpdateWithMetrics(t *testing.T) { + state := NewState() + + replicas := map[string]ShardReplica{ + "shard-0": { + Status: "READY", + Message: "Running", + Metrics: map[string]float64{ + "cpu_usage": 0.75, + "memory_usage": 0.60, + }, + }, + } + + state.Update(replicas, 2, "metrics present") + + assert.Equal(t, 1, state.GetCurrentReplicaCount()) +} From d25f8c7e49db0578c4a63282c43f551d8e3d4d36 Mon Sep 17 00:00:00 2001 From: george-dorin Date: Tue, 30 Dec 2025 16:48:19 +0200 Subject: [PATCH 3/8] Fix logger --- core/services/arbiter/decision.go | 24 +++++------------------- 1 file changed, 5 insertions(+), 19 deletions(-) diff --git a/core/services/arbiter/decision.go b/core/services/arbiter/decision.go index d69c39dddba..4999297edd8 100644 --- a/core/services/arbiter/decision.go +++ b/core/services/arbiter/decision.go @@ -15,14 +15,14 @@ type DecisionEngine interface { type decisionEngine struct { shardConfig ShardConfigReader - lggr logger.Logger + lggr logger.SugaredLogger } // NewDecisionEngine creates a new DecisionEngine. -func NewDecisionEngine(shardConfig ShardConfigReader, lggr logger.Logger) DecisionEngine { +func NewDecisionEngine(shardConfig ShardConfigReader, lggr logger.SugaredLogger) DecisionEngine { return &decisionEngine{ shardConfig: shardConfig, - lggr: lggr.Named("DecisionEngine"), + lggr: lggr, } } @@ -31,24 +31,11 @@ func NewDecisionEngine(shardConfig ShardConfigReader, lggr logger.Logger) Decisi // with a minimum of 1 shard always. func (d *decisionEngine) ComputeApprovedCount(ctx context.Context, desiredCount int) (int, error) { // Get on-chain limit from ShardConfig contract - maxAllowed, err := d.shardConfig.GetDesiredShardCount(ctx) + approved, err := d.shardConfig.GetDesiredShardCount(ctx) if err != nil { return 0, fmt.Errorf("failed to get on-chain shard limit: %w", err) } - // Apply constraint: approved = min(desired, max) - approved := desiredCount - maxAllowedInt := int(maxAllowed) - - if approved > maxAllowedInt { - d.lggr.Infow("Capping desired count to on-chain limit", - "desired", desiredCount, - "maxAllowed", maxAllowedInt, - "approved", maxAllowedInt, - ) - approved = maxAllowedInt - } - // Ensure minimum of 1 shard if approved < 1 { d.lggr.Warnw("Desired count is less than 1, setting to minimum", @@ -60,9 +47,8 @@ func (d *decisionEngine) ComputeApprovedCount(ctx context.Context, desiredCount d.lggr.Debugw("Computed approved replica count", "desired", desiredCount, - "maxAllowed", maxAllowedInt, "approved", approved, ) - return approved, nil + return int(approved), nil } From 8808f37ac896014ab45d183ad5dc0b394c76b307 Mon Sep 17 00:00:00 2001 From: george-dorin Date: Tue, 6 Jan 2026 16:26:25 +0200 Subject: [PATCH 4/8] Add shard-config contract reader --- core/services/arbiter/arbiter.go | 26 ++- core/services/arbiter/arbiter_test.go | 35 +++- core/services/arbiter/decision_test.go | 60 +++--- core/services/arbiter/shardconfig.go | 275 ++++++++++++++++++++++--- 4 files changed, 329 insertions(+), 67 deletions(-) diff --git a/core/services/arbiter/arbiter.go b/core/services/arbiter/arbiter.go index 1f30c406c58..0bab70e19be 100644 --- a/core/services/arbiter/arbiter.go +++ b/core/services/arbiter/arbiter.go @@ -2,13 +2,13 @@ package arbiter import ( "context" + "fmt" "net" "sync" "google.golang.org/grpc" "github.com/smartcontractkit/chainlink-common/pkg/services" - "github.com/smartcontractkit/chainlink-common/pkg/types" // TODO: Update this import path once proto is generated pb 
"github.com/smartcontractkit/chainlink/v2/core/services/arbiter/proto" @@ -45,9 +45,11 @@ type arbiter struct { var _ Arbiter = (*arbiter)(nil) // New creates a new Arbiter service. +// contractReaderFactory is used to create the contract reader for querying the ShardConfig contract. +// This follows the same pattern as the workflow registry syncer and capability registry syncer. func New( lggr logger.Logger, - contractReader types.ContractReader, + contractReaderFactory ContractReaderFactory, shardConfigAddr string, ) (Arbiter, error) { lggr = lggr.Named("Arbiter") @@ -55,11 +57,11 @@ func New( // Create state state := NewState() - // Create ShardConfig reader - shardConfig := NewShardConfigReader(contractReader, shardConfigAddr, lggr) + // Create ShardConfig syncer (implements services.Service) + shardConfig := NewShardConfigSyncer(contractReaderFactory, shardConfigAddr, lggr) - // Create decision engine - decision := NewDecisionEngine(shardConfig, lggr) + // Create decision engine with sugared logger + decision := NewDecisionEngine(shardConfig, logger.Sugared(lggr)) // Create gRPC handler grpcHandler := NewGRPCServer(state, decision, lggr) @@ -85,6 +87,11 @@ func (a *arbiter) Start(ctx context.Context) error { return a.StartOnce("Arbiter", func() error { a.lggr.Info("Starting Arbiter service") + // Start ShardConfig syncer first + if err := a.shardConfig.Start(ctx); err != nil { + return fmt.Errorf("failed to start shard config syncer: %w", err) + } + // Start gRPC server in a goroutine a.wg.Add(1) go func() { @@ -141,9 +148,14 @@ func (a *arbiter) Close() error { a.grpcServer.GracefulStop() a.lggr.Debug("gRPC server stopped gracefully") - // Wait for goroutines + // Wait for gRPC goroutine a.wg.Wait() + // Close ShardConfig syncer + if err := a.shardConfig.Close(); err != nil { + a.lggr.Errorw("Failed to close shard config syncer", "error", err) + } + a.lggr.Info("Arbiter service stopped") return nil diff --git a/core/services/arbiter/arbiter_test.go b/core/services/arbiter/arbiter_test.go index 5d51ca2e709..88add77ff43 100644 --- a/core/services/arbiter/arbiter_test.go +++ b/core/services/arbiter/arbiter_test.go @@ -81,11 +81,19 @@ func (m *mockContractReader) QueryKeys(ctx context.Context, filters []types.Cont return nil, nil } +// mockContractReaderFactory creates a ContractReaderFactory that returns the mock reader. 
diff --git a/core/services/arbiter/arbiter_test.go b/core/services/arbiter/arbiter_test.go
index 5d51ca2e709..88add77ff43 100644
--- a/core/services/arbiter/arbiter_test.go
+++ b/core/services/arbiter/arbiter_test.go
@@ -81,11 +81,19 @@ func (m *mockContractReader) QueryKeys(ctx context.Context, filters []types.Cont
 	return nil, nil
 }
 
+// mockContractReaderFactory creates a ContractReaderFactory that returns the mock reader.
+func mockContractReaderFactory(mockReader *mockContractReader) ContractReaderFactory {
+	return func(ctx context.Context, cfg []byte) (types.ContractReader, error) {
+		return mockReader, nil
+	}
+}
+
 func TestArbiter_New(t *testing.T) {
 	lggr := logger.TestLogger(t)
 	mockReader := &mockContractReader{desiredShardCount: 10}
+	factory := mockContractReaderFactory(mockReader)
 
-	arb, err := New(lggr, mockReader, "0x1234567890abcdef")
+	arb, err := New(lggr, factory, "0x1234567890abcdef")
 	require.NoError(t, err)
 	require.NotNil(t, arb)
 
@@ -95,16 +103,17 @@ func TestArbiter_New(t *testing.T) {
 func TestArbiter_StartClose(t *testing.T) {
 	lggr := logger.TestLogger(t)
 	mockReader := &mockContractReader{desiredShardCount: 10}
+	factory := mockContractReaderFactory(mockReader)
 
-	arb, err := New(lggr, mockReader, "0x1234567890abcdef")
+	arb, err := New(lggr, factory, "0x1234567890abcdef")
 	require.NoError(t, err)
 
 	// Test start
 	err = arb.Start(context.Background())
 	require.NoError(t, err)
 
-	// Give gRPC server a moment to start
-	time.Sleep(50 * time.Millisecond)
+	// Give gRPC server and syncer a moment to start
+	time.Sleep(100 * time.Millisecond)
 
 	// Test health after start
 	healthReport := arb.HealthReport()
@@ -119,8 +128,9 @@ func TestArbiter_StartClose(t *testing.T) {
 func TestArbiter_ServiceTestRun(t *testing.T) {
 	lggr := logger.TestLogger(t)
 	mockReader := &mockContractReader{desiredShardCount: 10}
+	factory := mockContractReaderFactory(mockReader)
 
-	arb, err := New(lggr, mockReader, "0x1234567890abcdef")
+	arb, err := New(lggr, factory, "0x1234567890abcdef")
 	require.NoError(t, err)
 
 	// Use servicetest.Run to handle lifecycle
@@ -135,8 +145,9 @@ func TestArbiter_ServiceTestRun(t *testing.T) {
 func TestArbiter_HealthReport(t *testing.T) {
 	lggr := logger.TestLogger(t)
 	mockReader := &mockContractReader{desiredShardCount: 10}
+	factory := mockContractReaderFactory(mockReader)
 
-	arb, err := New(lggr, mockReader, "0x1234567890abcdef")
+	arb, err := New(lggr, factory, "0x1234567890abcdef")
 	require.NoError(t, err)
 
 	t.Run("before start - not ready", func(t *testing.T) {
@@ -162,8 +173,9 @@ func TestArbiter_HealthReport(t *testing.T) {
 func TestArbiter_DoubleStart(t *testing.T) {
 	lggr := logger.TestLogger(t)
 	mockReader := &mockContractReader{desiredShardCount: 10}
+	factory := mockContractReaderFactory(mockReader)
 
-	arb, err := New(lggr, mockReader, "0x1234567890abcdef")
+	arb, err := New(lggr, factory, "0x1234567890abcdef")
 	require.NoError(t, err)
 
 	// First start should succeed
@@ -181,8 +193,9 @@ func TestArbiter_DoubleStart(t *testing.T) {
 func TestArbiter_DoubleClose(t *testing.T) {
 	lggr := logger.TestLogger(t)
 	mockReader := &mockContractReader{desiredShardCount: 10}
+	factory := mockContractReaderFactory(mockReader)
 
-	arb, err := New(lggr, mockReader, "0x1234567890abcdef")
+	arb, err := New(lggr, factory, "0x1234567890abcdef")
 	require.NoError(t, err)
 
 	err = arb.Start(context.Background())
@@ -200,8 +213,9 @@ func TestArbiter_DoubleClose(t *testing.T) {
 func TestArbiter_Name(t *testing.T) {
 	lggr := logger.TestLogger(t)
 	mockReader := &mockContractReader{desiredShardCount: 10}
+	factory := mockContractReaderFactory(mockReader)
 
-	arb, err := New(lggr, mockReader, "0x1234567890abcdef")
+	arb, err := New(lggr, factory, "0x1234567890abcdef")
 	require.NoError(t, err)
 
 	assert.Equal(t, "Arbiter", arb.Name())
@@ -210,8 +224,9 @@ func TestArbiter_Name(t *testing.T) {
 func TestArbiter_Ready(t *testing.T) {
 	lggr := logger.TestLogger(t)
 	mockReader := &mockContractReader{desiredShardCount: 10}
+	factory := mockContractReaderFactory(mockReader)
 
-	arb, err := New(lggr, mockReader, "0x1234567890abcdef")
+	arb, err := New(lggr, factory, "0x1234567890abcdef")
 	require.NoError(t, err)
 
 	// Before start, Ready should return error
diff --git a/core/services/arbiter/decision_test.go b/core/services/arbiter/decision_test.go
index a976b61a560..e712ea61a68 100644
--- a/core/services/arbiter/decision_test.go
+++ b/core/services/arbiter/decision_test.go
@@ -12,11 +12,32 @@ import (
 )
 
 // mockShardConfigReader is a mock implementation of ShardConfigReader for testing.
+// It implements the services.Service interface as required by the updated ShardConfigReader.
 type mockShardConfigReader struct {
 	desiredCount uint64
 	err          error
 }
 
+func (m *mockShardConfigReader) Start(ctx context.Context) error {
+	return nil
+}
+
+func (m *mockShardConfigReader) Close() error {
+	return nil
+}
+
+func (m *mockShardConfigReader) Ready() error {
+	return nil
+}
+
+func (m *mockShardConfigReader) HealthReport() map[string]error {
+	return nil
+}
+
+func (m *mockShardConfigReader) Name() string {
+	return "mockShardConfigReader"
+}
+
 func (m *mockShardConfigReader) GetDesiredShardCount(ctx context.Context) (uint64, error) {
 	return m.desiredCount, m.err
 }
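
Because ShardConfigReader now embeds services.Service, the mock has to track the full interface; a one-line compile-time assertion in the test file would catch future drift early (a suggestion, not part of this patch):

	var _ ShardConfigReader = (*mockShardConfigReader)(nil)
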
@@ -34,7 +55,7 @@ func TestDecisionEngine_ComputeApprovedCount(t *testing.T) {
 			name:           "desired under limit",
 			desiredCount:   5,
 			onChainMax:     10,
-			expectedResult: 5,
+			expectedResult: 10, // approved = onChainMax (since we just return on-chain value)
 			expectError:    false,
 		},
 		{
@@ -45,28 +66,21 @@
 			expectError:    false,
 		},
 		{
-			name:           "desired exceeds limit - capped",
+			name:           "desired exceeds limit",
 			desiredCount:   15,
 			onChainMax:     10,
-			expectedResult: 10,
+			expectedResult: 10, // approved = onChainMax
 			expectError:    false,
 		},
 		{
-			name:           "desired zero - minimum 1",
-			desiredCount:   0,
-			onChainMax:     10,
-			expectedResult: 1,
-			expectError:    false,
-		},
-		{
-			name:           "negative desired - minimum 1",
-			desiredCount:   -5,
-			onChainMax:     10,
-			expectedResult: 1,
+			name:           "on-chain limit zero - minimum 1 applied",
+			desiredCount:   5,
+			onChainMax:     0,
+			expectedResult: 1, // minimum of 1 shard
 			expectError:    false,
 		},
 		{
-			name:           "small on-chain limit caps result",
+			name:           "small on-chain limit",
 			desiredCount:   5,
 			onChainMax:     3,
 			expectedResult: 3,
@@ -98,7 +112,7 @@
 				err:          tc.shardConfigErr,
 			}
 
-			engine := NewDecisionEngine(mockReader, lggr)
+			engine := NewDecisionEngine(mockReader, logger.Sugared(lggr))
 
 			result, err := engine.ComputeApprovedCount(context.Background(), tc.desiredCount)
 
@@ -116,23 +130,23 @@
 func TestDecisionEngine_ComputeApprovedCount_EdgeCases(t *testing.T) {
 	lggr := logger.TestLogger(t)
 
-	t.Run("large desired count capped to large on-chain limit", func(t *testing.T) {
+	t.Run("large on-chain limit", func(t *testing.T) {
 		mockReader := &mockShardConfigReader{
 			desiredCount: 1000,
 		}
-		engine := NewDecisionEngine(mockReader, lggr)
+		engine := NewDecisionEngine(mockReader, logger.Sugared(lggr))
 
 		result, err := engine.ComputeApprovedCount(context.Background(), 500)
 		require.NoError(t, err)
-		assert.Equal(t, 500, result)
+		assert.Equal(t, 1000, result) // returns on-chain value
 	})
 
 	t.Run("exactly at on-chain limit", func(t *testing.T) {
 		mockReader := &mockShardConfigReader{
 			desiredCount: 7,
 		}
-		engine := NewDecisionEngine(mockReader, lggr)
+		engine := NewDecisionEngine(mockReader, logger.Sugared(lggr))
 
 		result, err := engine.ComputeApprovedCount(context.Background(), 7)
@@ -144,12 +158,12 @@ func TestDecisionEngine_ComputeApprovedCount_EdgeCases(t *testing.T) {
 		mockReader := &mockShardConfigReader{
 			desiredCount: 0,
 		}
-		engine := NewDecisionEngine(mockReader, lggr)
+		engine := NewDecisionEngine(mockReader, logger.Sugared(lggr))
 
 		result, err := engine.ComputeApprovedCount(context.Background(), 5)
 		require.NoError(t, err)
-		// approved = min(5, 0) = 0, but minimum is 1
+		// on-chain returns 0, but minimum is 1
 		assert.Equal(t, 1, result)
 	})
 }
@@ -161,7 +175,7 @@ func TestDecisionEngine_ContextCancellation(t *testing.T) {
 	mockReader := &mockShardConfigReader{
 		err: context.Canceled,
 	}
-	engine := NewDecisionEngine(mockReader, lggr)
+	engine := NewDecisionEngine(mockReader, logger.Sugared(lggr))
 
 	ctx, cancel := context.WithCancel(context.Background())
 	cancel()
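
decision.go itself is not part of this diff, but the updated tests pin the new semantics down: the approved count is simply the on-chain value floored at one, and the caller's desired count no longer caps it. A minimal sketch consistent with those tests — the decisionEngine struct and its field names are assumptions inferred from the constructor and mock signatures:

	func (e *decisionEngine) ComputeApprovedCount(ctx context.Context, desired int) (int, error) {
		// Read the (cached) on-chain value; errors such as context cancellation propagate.
		onChain, err := e.shardConfig.GetDesiredShardCount(ctx)
		if err != nil {
			return 0, err
		}
		approved := int(onChain)
		if approved < 1 {
			approved = 1 // never approve fewer than one shard
		}
		e.lggr.Debugw("computed approved shard count", "desired", desired, "approved", approved)
		return approved, nil
	}
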
diff --git a/core/services/arbiter/shardconfig.go b/core/services/arbiter/shardconfig.go
index 0ed9fe92558..71bff6fb44c 100644
--- a/core/services/arbiter/shardconfig.go
+++ b/core/services/arbiter/shardconfig.go
@@ -2,10 +2,16 @@ package arbiter
 
 import (
 	"context"
+	"encoding/json"
+	"fmt"
 	"math/big"
+	"sync"
+	"time"
 
+	"github.com/smartcontractkit/chainlink-common/pkg/services"
 	"github.com/smartcontractkit/chainlink-common/pkg/types"
 	"github.com/smartcontractkit/chainlink-common/pkg/types/query/primitives"
+	"github.com/smartcontractkit/chainlink-evm/pkg/config"
 
 	"github.com/smartcontractkit/chainlink/v2/core/logger"
 )
@@ -16,10 +22,15 @@ const (
 
 	// GetDesiredShardCountMethod is the method name for reading the desired shard count.
 	GetDesiredShardCountMethod = "getDesiredShardCount"
+
+	// Polling intervals (matching workflow registry syncer pattern)
+	defaultPollInterval  = 12 * time.Second
+	defaultRetryInterval = 12 * time.Second
 )
 
 // ShardConfigABI is the ABI for the ShardConfig contract.
-// Only includes the getDesiredShardCount function we need.
+// TODO: Once the chainlink-evm/contracts/cre/gobindings module is published,
+// import shard_config.ShardConfigABI instead of using this inline definition.
 const ShardConfigABI = `[
 	{
 		"inputs": [],
@@ -30,62 +41,272 @@ const ShardConfigABI = `[
 	}
 ]`
 
+// ContractReaderFactory creates a ContractReader from a config.
+// This matches the pattern used in workflow registry syncer.
+type ContractReaderFactory func(context.Context, []byte) (types.ContractReader, error)
+
 // ShardConfigReader reads the desired shard count from the ShardConfig contract.
 type ShardConfigReader interface {
-	// GetDesiredShardCount retrieves the current desired shard count from on-chain.
+	services.Service
+	// GetDesiredShardCount retrieves the current desired shard count from cache.
+	// Returns an error if the value hasn't been fetched yet.
 	GetDesiredShardCount(ctx context.Context) (uint64, error)
 }
 
-type shardConfigReader struct {
-	reader   types.ContractReader
-	contract types.BoundContract
-	lggr     logger.Logger
+// shardConfigSyncer implements ShardConfigReader with periodic polling and caching.
+type shardConfigSyncer struct {
+	services.StateMachine
+	stopCh services.StopChan
+	wg     sync.WaitGroup
+
+	lggr                  logger.Logger
+	shardConfigAddress    string
+	contractReaderFactory ContractReaderFactory
+	contractReader        types.ContractReader
+
+	// Cached value with mutex protection
+	cachedShardCount uint64
+	cachedMu         sync.RWMutex
+
+	// Polling configuration
+	pollInterval time.Duration
 }
 
-// NewShardConfigReader creates a new ShardConfigReader.
-func NewShardConfigReader(
-	reader types.ContractReader,
-	contractAddress string,
+var _ ShardConfigReader = (*shardConfigSyncer)(nil)
+
+// NewShardConfigSyncer creates a new ShardConfigReader that polls the contract periodically.
+func NewShardConfigSyncer(
+	contractReaderFactory ContractReaderFactory,
+	shardConfigAddress string,
 	lggr logger.Logger,
 ) ShardConfigReader {
-	return &shardConfigReader{
-		reader: reader,
-		contract: types.BoundContract{
-			Address: contractAddress,
-			Name:    ShardConfigContractName,
+	return &shardConfigSyncer{
+		lggr:                  lggr.Named("ShardConfigSyncer"),
+		shardConfigAddress:    shardConfigAddress,
+		contractReaderFactory: contractReaderFactory,
+		pollInterval:          defaultPollInterval,
+		stopCh:                make(services.StopChan),
+	}
+}
+
+// Start starts the ShardConfig syncer service.
+func (s *shardConfigSyncer) Start(ctx context.Context) error {
+	return s.StartOnce("ShardConfigSyncer", func() error {
+		s.lggr.Info("Starting ShardConfig syncer")
+
+		// Start async initialization and polling
+		s.wg.Add(1)
+		go func() {
+			defer s.wg.Done()
+			s.run(ctx)
+		}()
+
+		return nil
+	})
+}
+
+// Close stops the ShardConfig syncer service.
+func (s *shardConfigSyncer) Close() error {
+	return s.StopOnce("ShardConfigSyncer", func() error {
+		s.lggr.Info("Stopping ShardConfig syncer")
+
+		// Signal stop
+		close(s.stopCh)
+
+		// Wait for goroutines
+		s.wg.Wait()
+
+		// Close contract reader if initialized
+		if s.contractReader != nil {
+			if err := s.contractReader.Close(); err != nil {
+				s.lggr.Errorw("Failed to close contract reader", "error", err)
+			}
+		}
+
+		s.lggr.Info("ShardConfig syncer stopped")
+		return nil
+	})
+}
+
+// Name returns the service name.
+func (s *shardConfigSyncer) Name() string {
+	return s.lggr.Name()
+}
+
+// HealthReport returns the health status of the service.
+func (s *shardConfigSyncer) HealthReport() map[string]error {
+	return map[string]error{
+		s.Name(): s.Ready(),
+	}
+}
+
+// GetDesiredShardCount retrieves the cached desired shard count.
+func (s *shardConfigSyncer) GetDesiredShardCount(ctx context.Context) (uint64, error) {
+	s.cachedMu.RLock()
+	defer s.cachedMu.RUnlock()
+
+	if s.cachedShardCount == 0 {
+		return 0, fmt.Errorf("shard count not yet available")
+	}
+
+	return s.cachedShardCount, nil
+}
+
+// run handles initialization and polling loop.
+func (s *shardConfigSyncer) run(ctx context.Context) {
+	// Phase 1: Initialize contract reader with retries
+	if err := s.initContractReader(ctx); err != nil {
+		s.lggr.Errorw("Failed to initialize contract reader", "error", err)
+		return
+	}
+
+	// Phase 2: Start polling loop
+	s.pollLoop(ctx)
+}
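
Since the syncer is itself a services.Service, the arbiter's own HealthReport could also fold the syncer's report into its map, for example with chainlink-common's services.CopyHealth. A sketch of that variant (the patch keeps the arbiter's single-entry report):

	func (a *arbiter) HealthReport() map[string]error {
		report := map[string]error{a.Name(): a.Ready()}
		// Merge the syncer's health entries into the arbiter's report.
		services.CopyHealth(report, a.shardConfig.HealthReport())
		return report
	}
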
+// initContractReader initializes the contract reader with retry logic.
+// This follows the lazy initialization pattern used by workflow registry syncer.
+func (s *shardConfigSyncer) initContractReader(ctx context.Context) error {
+	ticker := time.NewTicker(defaultRetryInterval)
+	defer ticker.Stop()
+
+	// Try immediately first
+	reader, err := s.newContractReader(ctx)
+	if err == nil {
+		s.contractReader = reader
+		s.lggr.Info("Contract reader initialized successfully")
+		return nil
+	}
+	s.lggr.Infow("Contract reader unavailable, will retry", "error", err)
+
+	for {
+		select {
+		case <-ctx.Done():
+			return ctx.Err()
+		case <-s.stopCh:
+			return nil
+		case <-ticker.C:
+			reader, err := s.newContractReader(ctx)
+			if err != nil {
+				s.lggr.Infow("Contract reader unavailable, retrying", "error", err)
+				continue
+			}
+			s.contractReader = reader
+			s.lggr.Info("Contract reader initialized successfully")
+			return nil
+		}
+	}
+}
+
+// newContractReader creates and configures a new contract reader.
+func (s *shardConfigSyncer) newContractReader(ctx context.Context) (types.ContractReader, error) {
+	cfg := buildShardConfigReaderConfig()
+
+	cfgBytes, err := json.Marshal(cfg)
+	if err != nil {
+		return nil, fmt.Errorf("failed to marshal config: %w", err)
+	}
+
+	reader, err := s.contractReaderFactory(ctx, cfgBytes)
+	if err != nil {
+		return nil, fmt.Errorf("failed to create contract reader: %w", err)
+	}
+	if reader == nil {
+		return nil, fmt.Errorf("contract reader factory returned nil")
+	}
+
+	// Bind the contract address
+	bc := types.BoundContract{
+		Address: s.shardConfigAddress,
+		Name:    ShardConfigContractName,
+	}
+
+	if err := reader.Bind(ctx, []types.BoundContract{bc}); err != nil {
+		return nil, fmt.Errorf("failed to bind contract: %w", err)
+	}
+
+	if err := reader.Start(ctx); err != nil {
+		return nil, fmt.Errorf("failed to start contract reader: %w", err)
+	}
+
+	return reader, nil
+}
+
+// buildShardConfigReaderConfig creates the ChainReaderConfig for the ShardConfig contract.
+func buildShardConfigReaderConfig() config.ChainReaderConfig {
+	return config.ChainReaderConfig{
+		Contracts: map[string]config.ChainContractReader{
+			ShardConfigContractName: {
+				ContractABI: ShardConfigABI,
+				Configs: map[string]*config.ChainReaderDefinition{
+					GetDesiredShardCountMethod: {
+						ChainSpecificName: GetDesiredShardCountMethod,
+						ReadType:          config.Method,
+					},
+				},
+			},
 		},
-		lggr: lggr.Named("ShardConfigReader"),
 	}
 }
 
-// GetDesiredShardCount retrieves the current desired shard count from on-chain.
-func (s *shardConfigReader) GetDesiredShardCount(ctx context.Context) (uint64, error) {
+// pollLoop periodically fetches the shard count from the contract.
+func (s *shardConfigSyncer) pollLoop(ctx context.Context) {
+	ticker := time.NewTicker(s.pollInterval)
+	defer ticker.Stop()
+
+	// Initial fetch
+	s.fetchAndCache(ctx)
+
+	for {
+		select {
+		case <-ctx.Done():
+			return
+		case <-s.stopCh:
+			return
+		case <-ticker.C:
+			s.fetchAndCache(ctx)
+		}
+	}
+}
+
+// fetchAndCache fetches the shard count from the contract and updates the cache.
+func (s *shardConfigSyncer) fetchAndCache(ctx context.Context) {
+	if s.contractReader == nil {
+		return
+	}
+
 	var result *big.Int
+	bc := types.BoundContract{
+		Address: s.shardConfigAddress,
+		Name:    ShardConfigContractName,
+	}
 
-	err := s.reader.GetLatestValue(
+	err := s.contractReader.GetLatestValue(
 		ctx,
-		s.contract.ReadIdentifier(GetDesiredShardCountMethod),
+		bc.ReadIdentifier(GetDesiredShardCountMethod),
 		primitives.Finalized, // Use finalized for governance data
 		nil,                  // No input params
 		&result,
 	)
 	if err != nil {
-		s.lggr.Errorw("failed to get desired shard count from on-chain",
+		s.lggr.Errorw("Failed to fetch shard count from on-chain",
 			"error", err,
-			"contract", s.contract.Address,
+			"contract", s.shardConfigAddress,
 		)
-		return 0, err
+		return
 	}
 
 	count := result.Uint64()
+	s.cachedMu.Lock()
+	s.cachedShardCount = count
+	s.cachedMu.Unlock()
 
 	// Update metrics
 	SetOnChainMaxReplicas(count)
 
-	s.lggr.Debugw("read desired shard count from on-chain",
+	s.lggr.Debugw("Fetched shard count from on-chain",
 		"count", count,
-		"contract", s.contract.Address,
+		"contract", s.shardConfigAddress,
 	)
-
-	return count, nil
 }
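
End to end, the syncer is consumed like any other chainlink service. A usage sketch against the API introduced above (at this point in the series the constructor still takes three arguments; the "not yet available" error is expected until the first successful poll):

	syncer := NewShardConfigSyncer(factory, shardConfigAddr, lggr)
	if err := syncer.Start(ctx); err != nil {
		return err
	}
	defer syncer.Close()

	count, err := syncer.GetDesiredShardCount(ctx)
	if err != nil {
		// Not fatal at startup: the cache fills once the first poll succeeds.
	}
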
From 70f4ea5471d9835f5cdc1238fd0511c946a10cbb Mon Sep 17 00:00:00 2001
From: george-dorin
Date: Wed, 7 Jan 2026 14:12:31 +0200
Subject: [PATCH 5/8] Add sharding config

---
 core/config/sharding_config.go             |  9 +++++
 core/config/toml/types.go                  | 23 +++++++++++++
 core/services/chainlink/config_general.go  |  4 +++
 core/services/chainlink/config_sharding.go | 39 ++++++++++++++++++++
 4 files changed, 75 insertions(+)
 create mode 100644 core/config/sharding_config.go
 create mode 100644 core/services/chainlink/config_sharding.go

diff --git a/core/config/sharding_config.go b/core/config/sharding_config.go
new file mode 100644
index 00000000000..693166bf69f
--- /dev/null
+++ b/core/config/sharding_config.go
@@ -0,0 +1,9 @@
+package config
+
+import "time"
+
+type Sharding interface {
+	ArbiterPort() uint16
+	ArbiterPollInterval() time.Duration
+	ArbiterRetryInterval() time.Duration
+}
diff --git a/core/config/toml/types.go b/core/config/toml/types.go
index 74800fcb5a5..afb769a6f3e 100644
--- a/core/config/toml/types.go
+++ b/core/config/toml/types.go
@@ -64,6 +64,7 @@ type Core struct {
 	CRE                  CreConfig            `toml:",omitempty"`
 	Billing              Billing              `toml:",omitempty"`
 	BridgeStatusReporter BridgeStatusReporter `toml:",omitempty"`
+	Sharding             Sharding             `toml:",omitempty"`
 }
 
 // SetFrom updates c with any non-nil values from f. (currently TOML field only!)
@@ -109,6 +110,8 @@ func (c *Core) SetFrom(f *Core) {
 	c.CRE.setFrom(&f.CRE)
 	c.Billing.setFrom(&f.Billing)
 	c.BridgeStatusReporter.setFrom(&f.BridgeStatusReporter)
+
+	c.Sharding.setFrom(&f.Sharding)
 }
 
 func (c *Core) ValidateConfig() (err error) {
@@ -2752,3 +2755,23 @@ func (jd *JobDistributor) setFrom(f *JobDistributor) {
 		jd.DisplayName = f.DisplayName
 	}
 }
+
+type Sharding struct {
+	ArbiterPort          *uint16
+	ArbiterPollInterval  *commonconfig.Duration
+	ArbiterRetryInterval *commonconfig.Duration
+}
+
+func (s *Sharding) setFrom(f *Sharding) {
+	if f.ArbiterPort != nil {
+		s.ArbiterPort = f.ArbiterPort
+	}
+
+	if f.ArbiterPollInterval != nil {
+		s.ArbiterPollInterval = f.ArbiterPollInterval
+	}
+
+	if f.ArbiterRetryInterval != nil {
+		s.ArbiterRetryInterval = f.ArbiterRetryInterval
+	}
+}
diff --git a/core/services/chainlink/config_general.go b/core/services/chainlink/config_general.go
index 3e6a0a8b8f7..741b5547423 100644
--- a/core/services/chainlink/config_general.go
+++ b/core/services/chainlink/config_general.go
@@ -596,4 +596,8 @@ func (g *generalConfig) BridgeStatusReporter() coreconfig.BridgeStatusReporter {
 	return &bridgeStatusReporterConfig{c: g.c.BridgeStatusReporter}
 }
 
+func (g *generalConfig) Sharding() coreconfig.Sharding {
+	return &shardingConfig{s: g.c.Sharding}
+}
+
 var zeroSha256Hash = models.Sha256Hash{}
diff --git a/core/services/chainlink/config_sharding.go b/core/services/chainlink/config_sharding.go
new file mode 100644
index 00000000000..561c190146c
--- /dev/null
+++ b/core/services/chainlink/config_sharding.go
@@ -0,0 +1,39 @@
+package chainlink
+
+import (
+	"time"
+
+	"github.com/smartcontractkit/chainlink/v2/core/config"
+	"github.com/smartcontractkit/chainlink/v2/core/config/toml"
+)
+
+const defaultArbiterPort = 9876
+const defaultArbiterPollInterval = time.Second * 12
+const defaultArbiterRetryInterval = time.Second * 12
+
+var _ config.Sharding = (*shardingConfig)(nil)
+
+type shardingConfig struct {
+	s toml.Sharding
+}
+
+func (s *shardingConfig) ArbiterPort() uint16 {
+	if s.s.ArbiterPort == nil {
+		return defaultArbiterPort
+	}
+	return *s.s.ArbiterPort
+}
+
+func (s *shardingConfig) ArbiterPollInterval() time.Duration {
+	if s.s.ArbiterPollInterval == nil || s.s.ArbiterPollInterval.Duration() <= 0 {
+		return defaultArbiterPollInterval
+	}
+	return s.s.ArbiterPollInterval.Duration()
+}
+
+func (s *shardingConfig) ArbiterRetryInterval() time.Duration {
+	if s.s.ArbiterRetryInterval == nil || s.s.ArbiterRetryInterval.Duration() <= 0 {
+		return defaultArbiterRetryInterval
+	}
+	return s.s.ArbiterRetryInterval.Duration()
+}
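
The new fields surface in node TOML as a [Sharding] table. A hedged example — field names come from the struct above, the values are illustrative, and any omitted field falls back to the defaults in config_sharding.go (port 9876, 12s intervals):

	[Sharding]
	ArbiterPort = 9100
	ArbiterPollInterval = '30s'
	ArbiterRetryInterval = '10s'
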
From f3ef44be893cd93c2d5f4182a70e284c15323468 Mon Sep 17 00:00:00 2001
From: george-dorin
Date: Wed, 7 Jan 2026 15:40:25 +0200
Subject: [PATCH 6/8] Update arbiter to use config values

---
 core/config/app_config.go             |  1 +
 core/services/arbiter/arbiter.go      | 14 +++++++-------
 core/services/arbiter/arbiter_test.go | 24 ++++++++++++++++--------
 core/services/arbiter/shardconfig.go  | 12 ++++++------
 4 files changed, 30 insertions(+), 21 deletions(-)

diff --git a/core/config/app_config.go b/core/config/app_config.go
index a5a0e0a1ee9..fbf97dba8bd 100644
--- a/core/config/app_config.go
+++ b/core/config/app_config.go
@@ -65,6 +65,7 @@ type AppConfig interface {
 	CCV() CCV
 	Billing() Billing
 	BridgeStatusReporter() BridgeStatusReporter
+	Sharding() Sharding
 }
 
 type DatabaseBackupMode string
diff --git a/core/services/arbiter/arbiter.go b/core/services/arbiter/arbiter.go
index 0bab70e19be..b50b8f154e4 100644
--- a/core/services/arbiter/arbiter.go
+++ b/core/services/arbiter/arbiter.go
@@ -4,7 +4,9 @@ import (
 	"context"
 	"fmt"
 	"net"
+	"strconv"
 	"sync"
+	"time"
 
 	"google.golang.org/grpc"
 
@@ -16,11 +18,6 @@ import (
 	"github.com/smartcontractkit/chainlink/v2/core/logger"
 )
 
-const (
-	// DefaultGRPCPort is the default port for the gRPC server.
-	DefaultGRPCPort = ":9090"
-)
-
 // Arbiter is the main service interface.
 type Arbiter interface {
 	services.Service
@@ -51,6 +48,9 @@ func New(
 	lggr logger.Logger,
 	contractReaderFactory ContractReaderFactory,
 	shardConfigAddr string,
+	port uint16,
+	pollInterval time.Duration,
+	retryInterval time.Duration,
 ) (Arbiter, error) {
 	lggr = lggr.Named("Arbiter")
 
@@ -58,7 +58,7 @@ func New(
 	state := NewState()
 
 	// Create ShardConfig syncer (implements services.Service)
-	shardConfig := NewShardConfigSyncer(contractReaderFactory, shardConfigAddr, lggr)
+	shardConfig := NewShardConfigSyncer(contractReaderFactory, shardConfigAddr, pollInterval, retryInterval, lggr)
 
 	// Create decision engine with sugared logger
 	decision := NewDecisionEngine(shardConfig, logger.Sugared(lggr))
@@ -77,7 +77,7 @@ func New(
 		decision:    decision,
 		shardConfig: shardConfig,
 		lggr:        lggr,
-		grpcAddr:    DefaultGRPCPort,
+		grpcAddr:    strconv.Itoa(int(port)),
 		stopCh:      make(services.StopChan),
 	}, nil
 }
diff --git a/core/services/arbiter/arbiter_test.go b/core/services/arbiter/arbiter_test.go
index 88add77ff43..22402e97556 100644
--- a/core/services/arbiter/arbiter_test.go
+++ b/core/services/arbiter/arbiter_test.go
@@ -14,6 +14,7 @@ import (
 	"github.com/smartcontractkit/chainlink-common/pkg/types"
 	"github.com/smartcontractkit/chainlink-common/pkg/types/query"
 	"github.com/smartcontractkit/chainlink-common/pkg/types/query/primitives"
+	"github.com/smartcontractkit/freeport"
 
 	"github.com/smartcontractkit/chainlink/v2/core/logger"
 )
@@ -88,12 +89,19 @@ func mockContractReaderFactory(mockReader *mockContractReader) ContractReaderFac
 	}
 }
 
+// Test configuration defaults
+const (
+	testPollInterval  time.Duration = 12 * time.Second
+	testRetryInterval time.Duration = 12 * time.Second
+	testShardConfigAddr             = "0x1234567890abcdef"
+)
+
 func TestArbiter_New(t *testing.T) {
 	lggr := logger.TestLogger(t)
 	mockReader := &mockContractReader{desiredShardCount: 10}
 	factory := mockContractReaderFactory(mockReader)
 
-	arb, err := New(lggr, factory, "0x1234567890abcdef")
+	arb, err := New(lggr, factory, testShardConfigAddr, uint16(freeport.GetOne(t)), testPollInterval, testRetryInterval)
 	require.NoError(t, err)
 	require.NotNil(t, arb)
 
@@ -105,7 +113,7 @@ func TestArbiter_StartClose(t *testing.T) {
 	mockReader := &mockContractReader{desiredShardCount: 10}
 	factory := mockContractReaderFactory(mockReader)
 
-	arb, err := New(lggr, factory, "0x1234567890abcdef")
+	arb, err := New(lggr, factory, testShardConfigAddr, uint16(freeport.GetOne(t)), testPollInterval, testRetryInterval)
 	require.NoError(t, err)
 
 	// Test start
@@ -130,7 +138,7 @@ func TestArbiter_ServiceTestRun(t *testing.T) {
 	mockReader := &mockContractReader{desiredShardCount: 10}
 	factory := mockContractReaderFactory(mockReader)
 
-	arb, err := New(lggr, factory, "0x1234567890abcdef")
+	arb, err := New(lggr, factory, testShardConfigAddr, uint16(freeport.GetOne(t)), testPollInterval, testRetryInterval)
 	require.NoError(t, err)
 
 	// Use servicetest.Run to handle lifecycle
@@ -147,7 +155,7 @@ func TestArbiter_HealthReport(t *testing.T) {
 	mockReader := &mockContractReader{desiredShardCount: 10}
 	factory := mockContractReaderFactory(mockReader)
 
-	arb, err := New(lggr, factory, "0x1234567890abcdef")
+	arb, err := New(lggr, factory, testShardConfigAddr, uint16(freeport.GetOne(t)), testPollInterval, testRetryInterval)
 	require.NoError(t, err)
 
 	t.Run("before start - not ready", func(t *testing.T) {
@@ -175,7 +183,7 @@ func TestArbiter_DoubleStart(t *testing.T) {
 	mockReader := &mockContractReader{desiredShardCount: 10}
 	factory := mockContractReaderFactory(mockReader)
 
-	arb, err := New(lggr, factory, "0x1234567890abcdef")
+	arb, err := New(lggr, factory, testShardConfigAddr, uint16(freeport.GetOne(t)), testPollInterval, testRetryInterval)
 	require.NoError(t, err)
 
 	// First start should succeed
@@ -195,7 +203,7 @@ func TestArbiter_DoubleClose(t *testing.T) {
 	mockReader := &mockContractReader{desiredShardCount: 10}
 	factory := mockContractReaderFactory(mockReader)
 
-	arb, err := New(lggr, factory, "0x1234567890abcdef")
+	arb, err := New(lggr, factory, testShardConfigAddr, uint16(freeport.GetOne(t)), testPollInterval, testRetryInterval)
 	require.NoError(t, err)
 
 	err = arb.Start(context.Background())
@@ -215,7 +223,7 @@ func TestArbiter_Name(t *testing.T) {
 	mockReader := &mockContractReader{desiredShardCount: 10}
 	factory := mockContractReaderFactory(mockReader)
 
-	arb, err := New(lggr, factory, "0x1234567890abcdef")
+	arb, err := New(lggr, factory, testShardConfigAddr, uint16(freeport.GetOne(t)), testPollInterval, testRetryInterval)
 	require.NoError(t, err)
 
 	assert.Equal(t, "Arbiter", arb.Name())
@@ -226,7 +234,7 @@ func TestArbiter_Ready(t *testing.T) {
 	mockReader := &mockContractReader{desiredShardCount: 10}
 	factory := mockContractReaderFactory(mockReader)
 
-	arb, err := New(lggr, factory, "0x1234567890abcdef")
+	arb, err := New(lggr, factory, testShardConfigAddr, uint16(freeport.GetOne(t)), testPollInterval, testRetryInterval)
 	require.NoError(t, err)
 
 	// Before start, Ready should return error
diff --git a/core/services/arbiter/shardconfig.go b/core/services/arbiter/shardconfig.go
index 71bff6fb44c..09d793078be 100644
--- a/core/services/arbiter/shardconfig.go
+++ b/core/services/arbiter/shardconfig.go
@@ -22,10 +22,6 @@ const (
 
 	// GetDesiredShardCountMethod is the method name for reading the desired shard count.
 	GetDesiredShardCountMethod = "getDesiredShardCount"
-
-	// Polling intervals (matching workflow registry syncer pattern)
-	defaultPollInterval  = 12 * time.Second
-	defaultRetryInterval = 12 * time.Second
 )
 
 // ShardConfigABI is the ABI for the ShardConfig contract.
@@ -70,6 +66,7 @@ type shardConfigSyncer struct {
 
 	// Polling configuration
 	pollInterval time.Duration
+	retryTimeout time.Duration
 }
 
 var _ ShardConfigReader = (*shardConfigSyncer)(nil)
@@ -78,13 +75,16 @@ var _ ShardConfigReader = (*shardConfigSyncer)(nil)
 
 // NewShardConfigSyncer creates a new ShardConfigReader that polls the contract periodically.
 func NewShardConfigSyncer(
 	contractReaderFactory ContractReaderFactory,
 	shardConfigAddress string,
+	pollInterval time.Duration,
+	retryTimeout time.Duration,
 	lggr logger.Logger,
 ) ShardConfigReader {
 	return &shardConfigSyncer{
 		lggr:                  lggr.Named("ShardConfigSyncer"),
 		shardConfigAddress:    shardConfigAddress,
 		contractReaderFactory: contractReaderFactory,
-		pollInterval:          defaultPollInterval,
+		pollInterval:          pollInterval,
+		retryTimeout:          retryTimeout,
 		stopCh:                make(services.StopChan),
 	}
 }
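
With the AppConfig plumbing in place, wiring the arbiter from node config might look like the following sketch; where the factory and the ShardConfig address come from is outside this diff, so both are illustrative:

	arb, err := arbiter.New(
		lggr,
		factory,         // ContractReaderFactory built from the chain relayer
		shardConfigAddr, // ShardConfig contract address (source not shown in this patch series)
		cfg.Sharding().ArbiterPort(),
		cfg.Sharding().ArbiterPollInterval(),
		cfg.Sharding().ArbiterRetryInterval(),
	)
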
@@ -167,7 +167,7 @@ func (s *shardConfigSyncer) run(ctx context.Context) {
 
 // initContractReader initializes the contract reader with retry logic.
 // This follows the lazy initialization pattern used by workflow registry syncer.
 func (s *shardConfigSyncer) initContractReader(ctx context.Context) error {
-	ticker := time.NewTicker(defaultRetryInterval)
+	ticker := time.NewTicker(s.retryTimeout)
 	defer ticker.Stop()
 
 	// Try immediately first
From c2d0c85f91a297bda839e5cc486c6613c0a41078 Mon Sep 17 00:00:00 2001
From: george-dorin
Date: Wed, 7 Jan 2026 15:46:09 +0200
Subject: [PATCH 7/8] Add shardconfig tests and fix port format error

---
 core/services/arbiter/arbiter.go          |   3 +-
 core/services/arbiter/arbiter_test.go     |  29 +++
 core/services/arbiter/shardconfig_test.go | 224 ++++++++++++++++++++++
 3 files changed, 254 insertions(+), 2 deletions(-)
 create mode 100644 core/services/arbiter/shardconfig_test.go

diff --git a/core/services/arbiter/arbiter.go b/core/services/arbiter/arbiter.go
index b50b8f154e4..6fb7f6c47fd 100644
--- a/core/services/arbiter/arbiter.go
+++ b/core/services/arbiter/arbiter.go
@@ -4,7 +4,6 @@ import (
 	"context"
 	"fmt"
 	"net"
-	"strconv"
 	"sync"
 	"time"
 
@@ -77,7 +76,7 @@ func New(
 		decision:    decision,
 		shardConfig: shardConfig,
 		lggr:        lggr,
-		grpcAddr:    strconv.Itoa(int(port)),
+		grpcAddr:    fmt.Sprintf(":%d", port),
 		stopCh:      make(services.StopChan),
 	}, nil
 }
diff --git a/core/services/arbiter/arbiter_test.go b/core/services/arbiter/arbiter_test.go
index 22402e97556..f26503cce12 100644
--- a/core/services/arbiter/arbiter_test.go
+++ b/core/services/arbiter/arbiter_test.go
@@ -2,8 +2,10 @@ package arbiter
 
 import (
 	"context"
+	"fmt"
 	"iter"
 	"math/big"
+	"net"
 	"testing"
 	"time"
 
@@ -255,3 +257,30 @@ func TestArbiter_Ready(t *testing.T) {
 	err = arb.Ready()
 	assert.Error(t, err)
 }
+
+func TestArbiter_GRPCServerListening(t *testing.T) {
+	lggr := logger.TestLogger(t)
+	mockReader := &mockContractReader{desiredShardCount: 10}
+	factory := mockContractReaderFactory(mockReader)
+
+	port := freeport.GetOne(t)
+	arb, err := New(lggr, factory, testShardConfigAddr, uint16(port), testPollInterval, testRetryInterval)
+	require.NoError(t, err)
+
+	// Start the arbiter
+	err = arb.Start(context.Background())
+	require.NoError(t, err)
+
+	// Give gRPC server a moment to start
+	time.Sleep(100 * time.Millisecond)
+
+	// Verify gRPC server is actually listening by attempting to connect
+	addr := fmt.Sprintf("localhost:%d", port)
+	conn, err := net.DialTimeout("tcp", addr, 1*time.Second)
+	require.NoError(t, err, "gRPC server should be listening on port %d", port)
+	conn.Close()
+
+	// Cleanup
+	err = arb.Close()
+	require.NoError(t, err)
+}
diff --git a/core/services/arbiter/shardconfig_test.go b/core/services/arbiter/shardconfig_test.go
new file mode 100644
index 00000000000..d8c3a27b6f3
--- /dev/null
+++ b/core/services/arbiter/shardconfig_test.go
@@ -0,0 +1,224 @@
+package arbiter
+
+import (
+	"context"
+	"iter"
+	"math/big"
+	"testing"
+	"time"
+
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+
+	"github.com/smartcontractkit/chainlink-common/pkg/types"
+	"github.com/smartcontractkit/chainlink-common/pkg/types/query"
+	"github.com/smartcontractkit/chainlink-common/pkg/types/query/primitives"
+
+	"github.com/smartcontractkit/chainlink/v2/core/logger"
+)
+
+// mockShardConfigContractReader is a mock implementation for testing ShardConfigSyncer.
+type mockShardConfigContractReader struct {
+	types.UnimplementedContractReader
+	shardCount uint64
+	err        error
+	started    bool
+	bound      bool
+}
+
+func (m *mockShardConfigContractReader) Name() string {
+	return "mockShardConfigContractReader"
+}
+
+func (m *mockShardConfigContractReader) Start(ctx context.Context) error {
+	m.started = true
+	return nil
+}
+
+func (m *mockShardConfigContractReader) Close() error {
+	m.started = false
+	return nil
+}
+
+func (m *mockShardConfigContractReader) Ready() error {
+	return nil
+}
+
+func (m *mockShardConfigContractReader) HealthReport() map[string]error {
+	return nil
+}
+
+func (m *mockShardConfigContractReader) Bind(ctx context.Context, bindings []types.BoundContract) error {
+	m.bound = true
+	return nil
+}
+
+func (m *mockShardConfigContractReader) Unbind(ctx context.Context, bindings []types.BoundContract) error {
+	return nil
+}
+
+func (m *mockShardConfigContractReader) GetLatestValue(ctx context.Context, readIdentifier string, confidenceLevel primitives.ConfidenceLevel, params any, returnVal any) error {
+	if m.err != nil {
+		return m.err
+	}
+	if ptr, ok := returnVal.(**big.Int); ok {
+		*ptr = big.NewInt(int64(m.shardCount))
+	}
+	return nil
+}
+
+func (m *mockShardConfigContractReader) GetLatestValueWithHeadData(ctx context.Context, readIdentifier string, confidenceLevel primitives.ConfidenceLevel, params any, returnVal any) (head *types.Head, err error) {
+	err = m.GetLatestValue(ctx, readIdentifier, confidenceLevel, params, returnVal)
+	return nil, err
+}
+
+func (m *mockShardConfigContractReader) BatchGetLatestValues(ctx context.Context, request types.BatchGetLatestValuesRequest) (types.BatchGetLatestValuesResult, error) {
+	return nil, nil
+}
+
+func (m *mockShardConfigContractReader) QueryKey(ctx context.Context, contract types.BoundContract, filter query.KeyFilter, limitAndSort query.LimitAndSort, sequenceDataType any) ([]types.Sequence, error) {
+	return nil, nil
+}
+
+func (m *mockShardConfigContractReader) QueryKeys(ctx context.Context, filters []types.ContractKeyFilter, limitAndSort query.LimitAndSort) (iter.Seq2[string, types.Sequence], error) {
+	return nil, nil
+}
+
+func mockShardConfigReaderFactory(reader *mockShardConfigContractReader) ContractReaderFactory {
+	return func(ctx context.Context, cfg []byte) (types.ContractReader, error) {
+		return reader, nil
+	}
+}
+
+func TestShardConfigSyncer_New(t *testing.T) {
+	lggr := logger.TestLogger(t)
+	mockReader := &mockShardConfigContractReader{shardCount: 10}
+	factory := mockShardConfigReaderFactory(mockReader)
+
+	syncer := NewShardConfigSyncer(factory, "0x1234", 12*time.Second, 12*time.Second, lggr)
+
+	require.NotNil(t, syncer)
+	assert.Contains(t, syncer.Name(), "ShardConfigSyncer")
+}
+
+func TestShardConfigSyncer_GetDesiredShardCount_BeforeFetch(t *testing.T) {
+	lggr := logger.TestLogger(t)
+	mockReader := &mockShardConfigContractReader{shardCount: 10}
+	factory := mockShardConfigReaderFactory(mockReader)
+
+	syncer := NewShardConfigSyncer(factory, "0x1234", 12*time.Second, 12*time.Second, lggr)
+
+	// Before start, GetDesiredShardCount should return error
+	count, err := syncer.GetDesiredShardCount(context.Background())
+	assert.Error(t, err)
+	assert.Equal(t, uint64(0), count)
+	assert.Contains(t, err.Error(), "not yet available")
+}
+
+func TestShardConfigSyncer_GetDesiredShardCount_AfterFetch(t *testing.T) {
+	lggr := logger.TestLogger(t)
+	mockReader := &mockShardConfigContractReader{shardCount: 42}
+	factory := mockShardConfigReaderFactory(mockReader)
+
+	syncer := NewShardConfigSyncer(factory, "0x1234", 100*time.Millisecond, 100*time.Millisecond, lggr)
+
+	// Start the syncer
+	err := syncer.Start(context.Background())
+	require.NoError(t, err)
+
+	// Wait for initial fetch (contract reader init + first poll)
+	time.Sleep(200 * time.Millisecond)
+
+	// After fetch, GetDesiredShardCount should return the cached value
+	count, err := syncer.GetDesiredShardCount(context.Background())
+	assert.NoError(t, err)
+	assert.Equal(t, uint64(42), count)
+
+	// Cleanup
+	err = syncer.Close()
+	require.NoError(t, err)
+}
+
+func TestShardConfigSyncer_StartClose(t *testing.T) {
+	lggr := logger.TestLogger(t)
+	mockReader := &mockShardConfigContractReader{shardCount: 10}
+	factory := mockShardConfigReaderFactory(mockReader)
+
+	syncer := NewShardConfigSyncer(factory, "0x1234", 12*time.Second, 12*time.Second, lggr)
+
+	// Start
+	err := syncer.Start(context.Background())
+	require.NoError(t, err)
+
+	// Give it time to initialize
+	time.Sleep(100 * time.Millisecond)
+
+	// Close
+	err = syncer.Close()
+	require.NoError(t, err)
+}
+
+func TestShardConfigSyncer_HealthReport(t *testing.T) {
+	lggr := logger.TestLogger(t)
+	mockReader := &mockShardConfigContractReader{shardCount: 10}
+	factory := mockShardConfigReaderFactory(mockReader)
+
+	syncer := NewShardConfigSyncer(factory, "0x1234", 12*time.Second, 12*time.Second, lggr)
+
+	// Before start
+	healthReport := syncer.HealthReport()
+	require.Contains(t, healthReport, syncer.Name())
+	// Before start, Ready() should return an error
+	assert.Error(t, healthReport[syncer.Name()])
+
+	// Start
+	err := syncer.Start(context.Background())
+	require.NoError(t, err)
+
+	// After start
+	healthReport = syncer.HealthReport()
+	require.Contains(t, healthReport, syncer.Name())
+	assert.NoError(t, healthReport[syncer.Name()])
+
+	// Cleanup
+	err = syncer.Close()
+	require.NoError(t, err)
+}
+
+func TestShardConfigSyncer_DoubleStart(t *testing.T) {
+	lggr := logger.TestLogger(t)
+	mockReader := &mockShardConfigContractReader{shardCount: 10}
+	factory := mockShardConfigReaderFactory(mockReader)
+
+	syncer := NewShardConfigSyncer(factory, "0x1234", 12*time.Second, 12*time.Second, lggr)
+
+	// First start should succeed
+	err := syncer.Start(context.Background())
+	require.NoError(t, err)
+
+	// Second start should return error (StartOnce)
+	err = syncer.Start(context.Background())
+	assert.Error(t, err)
+
+	err = syncer.Close()
+	require.NoError(t, err)
+}
+
+func TestShardConfigSyncer_DoubleClose(t *testing.T) {
+	lggr := logger.TestLogger(t)
+	mockReader := &mockShardConfigContractReader{shardCount: 10}
+	factory := mockShardConfigReaderFactory(mockReader)
+
+	syncer := NewShardConfigSyncer(factory, "0x1234", 12*time.Second, 12*time.Second, lggr)
+
+	err := syncer.Start(context.Background())
+	require.NoError(t, err)
+
+	// First close should succeed
+	err = syncer.Close()
+	require.NoError(t, err)
+
+	// Second close should return error (StopOnce)
+	err = syncer.Close()
+	assert.Error(t, err)
+}
From 156ee5096b4c93393e4adbf1626bf21656889a59 Mon Sep 17 00:00:00 2001
From: george-dorin
Date: Wed, 7 Jan 2026 17:30:03 +0200
Subject: [PATCH 8/8] Remove redundant type definition

---
 core/services/arbiter/arbiter_test.go | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/core/services/arbiter/arbiter_test.go b/core/services/arbiter/arbiter_test.go
index f26503cce12..4529e2b8601 100644
--- a/core/services/arbiter/arbiter_test.go
+++ b/core/services/arbiter/arbiter_test.go
@@ -93,9 +93,9 @@ func mockContractReaderFactory(mockReader *mockContractReader) ContractReaderFac
 
 // Test configuration defaults
 const (
-	testPollInterval  time.Duration = 12 * time.Second
-	testRetryInterval time.Duration = 12 * time.Second
-	testShardConfigAddr             = "0x1234567890abcdef"
+	testPollInterval    = 12 * time.Second
+	testRetryInterval   = 12 * time.Second
+	testShardConfigAddr = "0x1234567890abcdef"
 )
 
 func TestArbiter_New(t *testing.T) {