diff --git a/config/config.go b/config/config.go index 938a8e4..158b2fa 100644 --- a/config/config.go +++ b/config/config.go @@ -3,6 +3,7 @@ package config import ( "encoding/json" "fmt" + "github.com/sei-protocol/sei-load/utils" "math/big" "time" ) @@ -11,12 +12,19 @@ import ( type LoadConfig struct { ChainID int64 `json:"chainId,omitempty"` // SeiChainID is the textual chain ID used for tagging metric collection. - SeiChainID string `json:"seiChainID,omitempty"` - Endpoints []string `json:"endpoints"` - Accounts *AccountConfig `json:"accounts,omitempty"` - Scenarios []Scenario `json:"scenarios,omitempty"` - MockDeploy bool `json:"mockDeploy,omitempty"` - Settings *Settings `json:"settings,omitempty"` + SeiChainID string `json:"seiChainID,omitempty"` + Endpoints []string `json:"endpoints"` + // Number of shards to divide the senders into. + // Txs within each shard are sent sequentially. + // Defaults to Endpoints * Settings.TasksPerEndpoint. + // WARNING: this is unrelated to the server-side autobahn sharding + // (which assigns tx sender addrs to lanes). It is solely used to maximize + // txs/s throughput of the load generator. + NumShards utils.Option[int] `json:"numShards,omitzero"` + Accounts *AccountConfig `json:"accounts,omitempty"` + Scenarios []Scenario `json:"scenarios,omitempty"` + MockDeploy bool `json:"mockDeploy,omitempty"` + Settings *Settings `json:"settings,omitempty"` // Funding, when set, funds the generated account pool from a root key at // startup so the run works against a real chain. See funding.go. Funding *FundingConfig `json:"funding,omitempty"` @@ -33,6 +41,15 @@ type LoadConfig struct { Seed *uint64 `json:"seed,omitempty"` } +func (c *LoadConfig) GetNumShards() int { + return c.NumShards.Or(len(c.Endpoints) * c.Settings.TasksPerEndpoint) +} + +func (c *LoadConfig) TotalQueueSize() int { + // Backward compatible formula, consider making it a config value. 
+ return len(c.Endpoints) * c.Settings.BufferSize +} + // Duration wraps time.Duration to provide JSON unmarshaling support type Duration time.Duration diff --git a/go.mod b/go.mod index 1678f0c..8aad25a 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,7 @@ go 1.25.1 require ( github.com/ethereum/go-ethereum v1.16.1 + github.com/gogo/protobuf v1.3.2 github.com/google/go-cmp v0.7.0 github.com/pkg/errors v0.9.1 github.com/prometheus/client_golang v1.22.0 diff --git a/go.sum b/go.sum index 528de49..6457faa 100644 --- a/go.sum +++ b/go.sum @@ -119,6 +119,8 @@ github.com/influxdata/line-protocol v0.0.0-20200327222509-2487e7298839 h1:W9WBk7 github.com/influxdata/line-protocol v0.0.0-20200327222509-2487e7298839/go.mod h1:xaLFMmpvUxqXtVkUJfg9QmT88cDaCJ3ZKgdZ78oO8Qo= github.com/jackpal/go-nat-pmp v1.0.2 h1:KzKSgb7qkJvOUTqYl9/Hg/me3pWgBmERKrTGD7BdWus= github.com/jackpal/go-nat-pmp v1.0.2/go.mod h1:QPH045xvCAeXUZOxsnwmrtiCoxIr9eob+4orBN1SBKc= +github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8= +github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo= github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ= github.com/klauspost/cpuid/v2 v2.0.9 h1:lgaqFMSdTdQYdZ04uHyN2d/eKdOMyi2YLSvlQIBFYa4= @@ -220,6 +222,8 @@ github.com/urfave/cli/v2 v2.27.5 h1:WoHEJLdsXr6dDWoJgMq/CboDmyY/8HMMH1fTECbih+w= github.com/urfave/cli/v2 v2.27.5/go.mod h1:3Sevf16NykTbInEnD0yKkjDAeZDS0A6bzhBH5hrMvTQ= github.com/xrash/smetrics v0.0.0-20240521201337-686a1a2994c1 h1:gEOO8jv9F4OT7lGCjxCBTO/36wtF6j2nSip77qHd4x4= github.com/xrash/smetrics v0.0.0-20240521201337-686a1a2994c1/go.mod h1:Ohn+xnUBiLI6FVj/9LpzZWtj1/D6lUovWYBkxHVV3aM= +github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= +github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= 
go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64= go.opentelemetry.io/auto/sdk v1.2.1/go.mod h1:KRTj+aOaElaLi+wW1kO/DZRXwkF4C5xPbEe3ZiIhN7Y= go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.68.0 h1:CqXxU8VOmDefoh0+ztfGaymYbhdB/tT3zs79QaZTNGY= @@ -250,24 +254,49 @@ go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= go.uber.org/multierr v1.9.0 h1:7fIwc/ZtS0q++VgcfqFDxSBZVv/Xo49/SYnDFupUwlI= go.uber.org/multierr v1.9.0/go.mod h1:X2jQV1h+kxSjClGpnseKVIxpmcjrj7MNnI0bnlfKTVQ= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.49.0 h1:+Ng2ULVvLHnJ/ZFEq4KdcDd/cfjrrjjNSXNzxg0Y4U4= golang.org/x/crypto v0.49.0/go.mod h1:ErX4dUh2UM+CFYiXZRTcMpEcN8b/1gxEuv3nODoYtCA= golang.org/x/exp v0.0.0-20230626212559-97b1e661b5df h1:UA2aFVmmsIlefxMk29Dp2juaUSth8Pyn3Tq5Y5mJGME= golang.org/x/exp v0.0.0-20230626212559-97b1e661b5df/go.mod h1:FXUEEKJgO7OQYeo8N01OfiKP8RXMtf6e8aTskBGqWdc= +golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= +golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.52.0 h1:He/TN1l0e4mmR3QqHMT2Xab3Aj3L9qjbhRm78/6jrW0= 
golang.org/x/net v0.52.0/go.mod h1:R1MAz7uMZxVMualyPXb+VaqGSa3LIaUqk0eEt3w36Sw= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4= golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.42.0 h1:omrd2nAlyT5ESRdCLYdm3+fMfNFE/+Rf4bDIQImRJeo= golang.org/x/sys v0.42.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.35.0 h1:JOVx6vVDFokkpaq1AEptVzLTpDe9KGpj5tR4/X+ybL8= golang.org/x/text v0.35.0/go.mod h1:khi/HExzZJ2pGnjenulevKNX1W67CUy0AsXcNubPGCA= golang.org/x/time v0.9.0 h1:EsRrnYcQiGH+5FfbgvV4AP7qEZstoyrHB0DzarOQ4ZY= golang.org/x/time v0.9.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools 
v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= +golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= gonum.org/v1/gonum v0.17.0 h1:VbpOemQlsSMrYmn7T2OUvQ4dqxQXU+ouZFQsZOx50z4= gonum.org/v1/gonum v0.17.0/go.mod h1:El3tOrEuMpv2UdMrbNlKEh9vd86bmQ6vqIcDwxEOc1E= google.golang.org/genproto/googleapis/api v0.0.0-20260401024825-9d38bb4040a9 h1:VPWxll4HlMw1Vs/qXtN7BvhZqsS9cdAittCNvVENElA= diff --git a/main.go b/main.go index 4b0101c..db3eee4 100644 --- a/main.go +++ b/main.go @@ -28,7 +28,7 @@ import ( "github.com/sei-protocol/sei-load/sender" "github.com/sei-protocol/sei-load/stats" "github.com/sei-protocol/sei-load/utils" - "github.com/sei-protocol/sei-load/utils/service" + "github.com/sei-protocol/sei-load/utils/scope" ) var ( @@ -212,21 +212,18 @@ func runLoadTest(ctx context.Context, cmd *cobra.Command) error { var dispatcher *sender.Dispatcher var inclusionTracker *stats.InclusionTracker - err = service.Run(ctx, func(ctx context.Context, s service.Scope) error { + err = scope.Run(ctx, func(ctx context.Context, s scope.Scope) error { // Create the generator from the config struct gen, err := generator.NewConfigBasedGenerator(cfg) if err != nil { return fmt.Errorf("failed to create generator: %w", err) } - // Create shared rate limiter for all workers if TPS is specified - var sharedLimiter *rate.Limiter + // Create the 
shared rate authority for the whole run. + sharedLimiter := rate.NewLimiter(rate.Inf, 1) if cfg.Settings.TPS > 0 { sharedLimiter = rate.NewLimiter(rate.Limit(cfg.Settings.TPS), 1) log.Printf("πŸ“ˆ Rate limiting enabled: %.2f TPS shared across all workers", cfg.Settings.TPS) - } else { - // No rate limiting - sharedLimiter = rate.NewLimiter(rate.Inf, 1) } // Create and start block collector if endpoints are available @@ -280,8 +277,16 @@ func runLoadTest(ctx context.Context, cmd *cobra.Command) error { }) } + // Open-loop owns the arrival clock in the scheduler, so the sender must + // not add a second finite gate. Prewarm and the scheduler still use the + // real shared limiter. + senderLimiter := sharedLimiter + if cfg.Settings.ArrivalModel == config.ArrivalModelOpenLoop && cfg.Settings.TxsDir == "" { + senderLimiter = rate.NewLimiter(rate.Inf, 1) + } + // Create the sender from the config struct - snd, err := sender.NewShardedSender(cfg, sharedLimiter, collector, inclusion) + snd, err := sender.NewShardedSender(cfg, senderLimiter, collector, inclusion) if err != nil { return fmt.Errorf("failed to create sender: %w", err) } @@ -344,7 +349,7 @@ func runLoadTest(ctx context.Context, cmd *cobra.Command) error { if cfg.Settings.TxsDir == "" { // Start the sender (starts all workers) s.SpawnBgNamed("sender", func() error { return snd.Run(ctx) }) - log.Printf("βœ… Connected to %d endpoints", snd.NumShards()) + log.Printf("βœ… Connected to %d endpoints", len(cfg.Endpoints)) } // Perform prewarming if enabled (before starting logger to avoid logging prewarm transactions) if cfg.Settings.Prewarm { diff --git a/sender/dispatcher.go b/sender/dispatcher.go index 69a6d4d..f21ae14 100644 --- a/sender/dispatcher.go +++ b/sender/dispatcher.go @@ -35,8 +35,8 @@ type Dispatcher struct { prewarmGen utils.Option[generator.Generator] // Optional prewarm generator sender TxSender - // Open-loop arrival configuration. 
arrivalModel defaults to closed-loop; - // limiter and maxInFlight are only consulted in open-loop mode. + // Open-loop arrival configuration. arrivalModel defaults to closed-loop. + // limiter is always present; open-loop additionally consults maxInFlight. arrivalModel ArrivalModel limiter *rate.Limiter maxInFlight int @@ -56,6 +56,7 @@ func NewDispatcher(gen generator.Generator, sender TxSender) *Dispatcher { generator: gen, sender: sender, arrivalModel: ArrivalClosedLoop, + limiter: rate.NewLimiter(rate.Inf, 1), } } @@ -99,9 +100,8 @@ func (d *Dispatcher) SetPrewarmGenerator(prewarmGen generator.Generator) { func (d *Dispatcher) Prewarm(ctx context.Context) error { d.mu.RLock() prewarmGen := d.prewarmGen - // Prewarm runs over the workers before the scheduler paces anything, so in - // open-loop (ungated workers) it must self-pace off the shared limiter or it - // floods the SUT. Nil in closed-loop, where the worker gates instead. + // Prewarm runs before the scheduler paces anything, so it must self-pace off + // the shared limiter or it floods the SUT. 
limiter := d.limiter d.mu.RUnlock() @@ -116,10 +116,8 @@ func (d *Dispatcher) Prewarm(ctx context.Context) error { // Run prewarm generator until completion for ctx.Err() == nil { - if limiter != nil { - if err := limiter.Wait(ctx); err != nil { - return err - } + if err := limiter.Wait(ctx); err != nil { + return err } tx, ok := gen.Generate() diff --git a/sender/doc.go b/sender/doc.go index 2d3f259..7bcc187 100644 --- a/sender/doc.go +++ b/sender/doc.go @@ -2,8 +2,8 @@ // // The send path is a pipeline: a [generator.Generator] produces transactions; // the [Dispatcher] times their arrival and hands each off to a [TxSender]; the -// [ShardedSender] routes each tx to one of N per-endpoint [Worker]s by shard; -// the worker's send loop stamps the attempt and calls the go-ethereum client +// [ShardedSender] routes each tx to one of N per-endpoint [ethClient]s by shard; +// the sender loop stamps the attempt and calls the go-ethereum client // (eth_sendRawTransaction). Inclusion, when tracked, is observed by the // block-indexed [stats.InclusionTracker] (see Inclusion stage below), not by // per-tx receipt polling. A shared [golang.org/x/time/rate.Limiter] is @@ -39,9 +39,9 @@ // scheduled instant the scheduler tries to acquire a permit without blocking: if // the senders are saturated the tick is dropped and counted, and the clock moves // on. The permit is not released at enqueue. The scheduler installs a release -// callback on tx.OnComplete, and the worker invokes it only after the -// synchronous send returns — note the two phases of the worker path: the enqueue -// into the worker's channel ([TxSender.Send]) is asynchronous and returns at +// callback on tx.OnComplete, and the sender invokes it only after the +// synchronous send returns — note the two phases of the send path: the enqueue +// into the sender's request channel ([TxSender.Send]) is asynchronous and returns at // once, but the RPC send itself is synchronous.
So the permit is held for the // full unacked-in-flight window (enqueue plus RPC round-trip), and maxInFlight // bounds real in-flight work while the drop count measures genuine load shed, @@ -71,7 +71,7 @@ // // Shutdown boundary (accepted, not drift). admitted == succeeded + failed holds // on a clean drain (generator exhaustion). On ctx cancel (SIGTERM/duration), -// admitted txs still buffered for a worker exit uncounted; the undercount is +// admitted txs still buffered for a sender exit uncounted; the undercount is // bounded by the channel backlog and never affects a cleanly completed run. // // LoadTx lifecycle and scheduling. The scheduling-relevant fields of [types.LoadTx] @@ -79,7 +79,7 @@ // goroutine that solely owns the tx at that stage, then is immutable as ownership // transfers with the pointer across channels. The scheduler stamps IntendedSendTime // (the true scheduled instant tβ‚€ + i/Ξ») and SequenceIndex (the arrival index i) -// before hand-off; the worker stamps AttemptedSendTime at the real send. A tx +// before hand-off; the sender stamps AttemptedSendTime at the real send. A tx // cannot self-describe which model produced it β€” an open-loop and a closed-loop // tx are byte-identical β€” so coordinated-omission safety is a property of the // run's arrival model, not of any per-tx field. Latency and schedule-lag consumers @@ -87,7 +87,7 @@ // // # Inclusion stage // -// When enabled (--track-receipts), the worker hands each successful send to the +// When enabled (--track-receipts), the sender hands each successful send to the // [stats.InclusionTracker] at send-completion (after OnComplete, only on a nil // send error). The tracker subscribes to new heads, fetches each arriving // block's body once (O(blocks), not O(txs)), and stamps InclusionTime on every @@ -100,7 +100,7 @@ // registered βŠ† succeeded (only successful sends are registered). 
The inclusion // denominator is succeeded (txs_accepted), never a minted "registered" series; // dropped_at_cap txs are excluded from it. inflight_at_shutdown is read only -// after both the workers and the tracker have joined. +// after both the senders and the tracker have joined. // // Accepted boundaries. (1) WS gaps degrade conservatively: a missed head is // counted (block_gaps) but never backfilled, so its txs reap as expired β€” diff --git a/sender/eth_client.go b/sender/eth_client.go new file mode 100644 index 0000000..796b071 --- /dev/null +++ b/sender/eth_client.go @@ -0,0 +1,182 @@ +package sender + +import ( + "context" + "fmt" + "net" + "net/http" + "net/url" + "time" + + "github.com/ethereum/go-ethereum/ethclient" + "github.com/ethereum/go-ethereum/rpc" + "github.com/sei-protocol/sei-load/stats" + "github.com/sei-protocol/sei-load/types" + "github.com/sei-protocol/sei-load/utils" + "github.com/sei-protocol/sei-load/utils/scope" + "go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp" + "go.opentelemetry.io/otel" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/metric" + "go.opentelemetry.io/otel/trace" +) + +var tracer = otel.Tracer("github.com/sei-protocol/sei-load/sender") + +type sendReq struct { + tx *types.LoadTx + done chan error +} + +type ethClientConfig struct { + ChainID string + ID int + Endpoint string + Tasks int + Debug bool + DryRun bool + Collector *stats.Collector + Inclusion utils.Option[*stats.InclusionTracker] +} + +type ethClient struct { + cfg *ethClientConfig + reqs chan sendReq +} + +func (c *ethClient) Run(ctx context.Context) error { + u, err := url.Parse(c.cfg.Endpoint) + if err != nil { + return fmt.Errorf("parse endpoint %q: %w", c.cfg.Endpoint, err) + } + var opts []rpc.ClientOption + switch u.Scheme { + case "http", "https": + opts = append(opts, rpc.WithHTTPClient(newHttpClient())) + } + rpcClient, err := rpc.DialOptions(ctx, c.cfg.Endpoint, opts...) 
+ if err != nil { + return fmt.Errorf("rpc.Dial(%q): %w", c.cfg.Endpoint, err) + } + client := ethclient.NewClient(rpcClient) + defer client.Close() + return scope.Run(ctx, func(ctx context.Context, s scope.Scope) error { + for range c.cfg.Tasks { + s.Spawn(func() error { return c.runSender(ctx, client) }) + } + return nil + }) +} + +// newHttpClient returns an otelhttp-wrapped client: injects traceparent on +// outbound, emits http.client.* metrics. Requires observability.Setup to have +// installed the global TextMapPropagator. +func newHttpClient() *http.Client { + t := &http.Transport{ + DialContext: (&net.Dialer{ + Timeout: 10 * time.Second, + KeepAlive: 30 * time.Second, + }).DialContext, + MaxIdleConns: 500, + MaxIdleConnsPerHost: 50, + IdleConnTimeout: 90 * time.Second, + TLSHandshakeTimeout: 10 * time.Second, + ExpectContinueTimeout: 1 * time.Second, + DisableKeepAlives: false, + } + return &http.Client{ + Timeout: 30 * time.Second, + Transport: otelhttp.NewTransport(t), + } +} + +func newEthClient(cfg *ethClientConfig) *ethClient { + return ðClient{ + cfg: cfg, + reqs: make(chan sendReq), + } +} + +// Send queues a transaction for this endpoint client to process. +func (c *ethClient) Send(ctx context.Context, tx *types.LoadTx) error { + done := make(chan error, 1) + if err := utils.Send(ctx, c.reqs, sendReq{tx, done}); err != nil { + return err + } + err, recvErr := utils.Recv(ctx, done) + if recvErr != nil { + return recvErr + } + return err +} + +// runSender handles the tx send requests. +func (c *ethClient) runSender(ctx context.Context, client *ethclient.Client) error { + for ctx.Err() == nil { + req, err := utils.Recv(ctx, c.reqs) + if err != nil { + return err + } + + startTime := time.Now() + // This goroutine solely owns tx between dequeue and the sentTxs hand-off, + // so stamping the actual send-attempt time here is race-free (see LoadTx). 
+ req.tx.AttemptedSendTime = startTime + err = c.sendTx(ctx, client, req.tx) + if req.tx.OnComplete != nil { + req.tx.OnComplete(err) + } + req.done <- err + c.cfg.Collector.RecordTransaction(req.tx.Scenario.Name, c.cfg.Endpoint, time.Since(startTime), err == nil) + if err == nil { + if t, ok := c.cfg.Inclusion.Get(); ok { + t.Register(req.tx) + } + } + } + return ctx.Err() +} + +func (c *ethClient) sendTx(ctx context.Context, eth *ethclient.Client, tx *types.LoadTx) (_err error) { + ctx, span := tracer.Start(ctx, "sender.send_tx", trace.WithAttributes( + attribute.String("seiload.scenario", tx.Scenario.Name), + attribute.String("seiload.endpoint", c.cfg.Endpoint), + attribute.Int("seiload.worker_id", c.cfg.ID), + attribute.String("seiload.chain_id", c.cfg.ChainID), + )) + defer func(start time.Time) { + if _err != nil { + span.RecordError(_err) + } + span.End() + // Record inside the span ctx so exemplars link to the trace. + sendLatency.Record(ctx, time.Since(start).Seconds(), + metric.WithAttributes( + attribute.String("scenario", tx.Scenario.Name), + attribute.String("endpoint", c.cfg.Endpoint), + attribute.String("chain_id", c.cfg.ChainID), + statusAttrFromError(_err)), + ) + }(time.Now()) + if c.cfg.DryRun { + // In dry-run mode, simulate processing time and mark as successful + // Use very minimal delay to avoid channel overflow + return utils.Sleep(ctx, 10*time.Microsecond) // Much faster simulation + } + + // Send through go-ethereum so the same code path supports both HTTP(S) and WS(S) RPC. 
+ if err := eth.SendTransaction(ctx, tx.EthTx); err != nil { + txsRejected.Add(ctx, 1, metric.WithAttributes( + attribute.String("endpoint", c.cfg.Endpoint), + attribute.String("scenario", tx.Scenario.Name), + attribute.String("reason", "rpc"), + )) + return fmt.Errorf("eth.SendTransaction(): %w", err) + } + + txsAccepted.Add(ctx, 1, metric.WithAttributes( + attribute.String("endpoint", c.cfg.Endpoint), + attribute.String("scenario", tx.Scenario.Name), + )) + return nil +} diff --git a/sender/eth_client_test.go b/sender/eth_client_test.go new file mode 100644 index 0000000..cad33ec --- /dev/null +++ b/sender/eth_client_test.go @@ -0,0 +1,169 @@ +package sender + +import ( + "context" + "crypto/sha256" + "math/big" + "net/http" + "net/http/httptest" + "slices" + "strings" + "sync/atomic" + "testing" + "time" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/common/hexutil" + ethtypes "github.com/ethereum/go-ethereum/core/types" + "github.com/ethereum/go-ethereum/rpc" + "github.com/sei-protocol/sei-load/stats" + "github.com/sei-protocol/sei-load/types" + "github.com/sei-protocol/sei-load/utils" + "github.com/sei-protocol/sei-load/utils/scope" + "github.com/stretchr/testify/require" + "go.opentelemetry.io/otel" + "go.opentelemetry.io/otel/propagation" + sdktrace "go.opentelemetry.io/otel/sdk/trace" +) + +func TestEthClientSendTx_HTTP(t *testing.T) { + api := newMockEthAPI() + srv := rpc.NewServer() + require.NoError(t, srv.RegisterName("eth", api)) + + // We check the TraceID as a proof that otel Transport was used. 
+ var traceparent string + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + traceparent = r.Header.Get("traceparent") + srv.ServeHTTP(w, r) + })) + defer ts.Close() + otel.SetTextMapPropagator(propagation.NewCompositeTextMapPropagator( + propagation.TraceContext{}, + propagation.Baggage{}, + )) + otel.SetTracerProvider(sdktrace.NewTracerProvider()) + ctx, span := otel.Tracer("sender-test").Start(t.Context(), "parent") + defer span.End() + + tx := testLoadTx(t) + client := newEthClient(ðClientConfig{ + ChainID: "test-chain", + ID: 7, + Endpoint: ts.URL, + Tasks: 1, + Collector: stats.NewCollector(), + }) + + err := scope.Run(ctx, func(ctx context.Context, s scope.Scope) error { + s.SpawnBg(func() error { return utils.IgnoreCancel(client.Run(ctx)) }) + return client.Send(ctx, tx) + }) + require.NoError(t, err) + require.Equal(t, [][]byte{tx.Payload}, api.RawTransactions()) + require.Contains(t, traceparent, span.SpanContext().TraceID().String()) +} + +func TestEthClientSendTx_WS(t *testing.T) { + api := newMockEthAPI() + srv := rpc.NewServer() + require.NoError(t, srv.RegisterName("eth", api)) + + ts := httptest.NewServer(srv.WebsocketHandler([]string{"*"})) + defer ts.Close() + + wsURL := "ws" + strings.TrimPrefix(ts.URL, "http") + tx := testLoadTx(t) + client := newEthClient(ðClientConfig{ + ChainID: "test-chain", + ID: 8, + Endpoint: wsURL, + Tasks: 1, + Collector: stats.NewCollector(), + }) + + err := scope.Run(t.Context(), func(ctx context.Context, s scope.Scope) error { + s.SpawnBg(func() error { return utils.IgnoreCancel(client.Run(ctx)) }) + return client.Send(ctx, tx) + }) + require.NoError(t, err) + require.Equal(t, [][]byte{tx.Payload}, api.RawTransactions()) +} + +func TestEthClientRunSender_RegistersSuccessfulSendAfterOnComplete(t *testing.T) { + tracker := stats.NewInclusionTracker("test-chain", time.Hour, 100, true) + client := newEthClient(ðClientConfig{ + ChainID: "test-chain", + ID: 9, + Endpoint: "dryrun", 
+ Tasks: 1, + DryRun: true, + Collector: stats.NewCollector(), + Inclusion: utils.Some(tracker), + }) + + tx := testLoadTx(t) + var inflightAtComplete atomic.Uint64 + tx.OnComplete = func(error) { + inflightAtComplete.Store(tracker.Summary().InflightAtShutdown) + } + + ctx, cancel := context.WithTimeout(t.Context(), time.Second) + defer cancel() + + errCh := make(chan error, 1) + go func() { errCh <- client.runSender(ctx, nil) }() + + require.NoError(t, client.Send(ctx, tx)) + cancel() + require.ErrorIs(t, <-errCh, context.Canceled) + require.Zero(t, inflightAtComplete.Load(), "inclusion must register after OnComplete") + require.Equal(t, uint64(1), tracker.Summary().InflightAtShutdown, "successful send must register exactly once") +} + +type mockEthAPI struct { + rawTxs utils.Mutex[*[][]byte] +} + +func newMockEthAPI() *mockEthAPI { + rawTxs := [][]byte{} + return &mockEthAPI{rawTxs: utils.NewMutex(&rawTxs)} +} + +func (m *mockEthAPI) SendRawTransaction(_ context.Context, rawTx hexutil.Bytes) (common.Hash, error) { + for rawTxs := range m.rawTxs.Lock() { + *rawTxs = append(*rawTxs, slices.Clone(rawTx)) + } + sum := sha256.Sum256(rawTx) + return common.BytesToHash(sum[:]), nil +} + +func (m *mockEthAPI) RawTransactions() [][]byte { + for rawTxs := range m.rawTxs.Lock() { + return slices.Clone(*rawTxs) + } + panic("unreachable") +} + +func testLoadTx(t *testing.T) *types.LoadTx { + t.Helper() + + account, err := types.NewAccount() + require.NoError(t, err) + + to := common.HexToAddress("0x0000000000000000000000000000000000000001") + tx := ethtypes.NewTx(ðtypes.LegacyTx{ + Nonce: 1, + To: &to, + Value: big.NewInt(1), + Gas: 21_000, + GasPrice: big.NewInt(1), + }) + signedTx, err := ethtypes.SignTx(tx, ethtypes.LatestSignerForChainID(big.NewInt(1)), account.PrivKey) + require.NoError(t, err) + + return types.CreateTxFromEthTx(signedTx, &types.TxScenario{ + Name: "test-scenario", + Sender: account, + }) +} diff --git a/sender/metrics.go b/sender/metrics.go index 
0fe8fd5..e433290 100644 --- a/sender/metrics.go +++ b/sender/metrics.go @@ -2,8 +2,8 @@ package sender import ( "context" - "sync" + "github.com/sei-protocol/sei-load/utils" "go.opentelemetry.io/otel" "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/metric" @@ -17,18 +17,18 @@ var meter = otel.Meter("github.com/sei-protocol/sei-load/sender") // Synchronous instruments β€” read by Record/Add call sites. var ( - sendLatency = must(meter.Float64Histogram( + sendLatency = utils.OrPanic1(meter.Float64Histogram( "send_latency", metric.WithDescription("Latency of sending transactions in seconds"), metric.WithUnit("s"), metric.WithExplicitBucketBoundaries(0.1, 0.2, 0.3, 0.5, 1.0, 2.0, 3.0, 5.0, 10.0, 20.0))) - txsAccepted = must(meter.Int64Counter( + txsAccepted = utils.OrPanic1(meter.Int64Counter( "txs_accepted", metric.WithDescription("Transactions successfully submitted to an endpoint"), metric.WithUnit("{transactions}"))) - txsRejected = must(meter.Int64Counter( + txsRejected = utils.OrPanic1(meter.Int64Counter( "txs_rejected", metric.WithDescription("Transactions rejected by the target or local client, by reason"), metric.WithUnit("{transactions}"))) @@ -38,62 +38,67 @@ var ( // Return values are discarded because OTel invokes the callbacks on each // collection; we never read the instrument handles. 
func init() { - must(meter.Int64ObservableGauge( + utils.OrPanic1(meter.Int64ObservableGauge( "worker_queue_length", metric.WithDescription("Length of the worker's queue"), metric.WithUnit("{count}"), metric.WithInt64Callback(func(ctx context.Context, observer metric.Int64Observer) error { - meteredChainWorkers.lock.RLock() - defer meteredChainWorkers.lock.RUnlock() - for _, worker := range meteredChainWorkers.workers { - observer.Observe(int64(worker.ChannelLength()), metric.WithAttributes( - attribute.String("endpoint", worker.Endpoint()), - attribute.Int("worker_id", worker.cfg.ID), - attribute.String("chain_id", worker.cfg.SeiChainID), - )) + for _, ss := range meteredSenders.Get() { + for _, stats := range ss.ShardStats() { + observer.Observe(int64(stats.TxsQueued), metric.WithAttributes( + attribute.String("endpoint", stats.Endpoint), + attribute.Int("worker_id", stats.ID), + attribute.String("chain_id", stats.ChainID), + )) + } } return nil }))) - must(meter.Float64ObservableGauge( + utils.OrPanic1(meter.Float64ObservableGauge( "tps_achieved", metric.WithDescription("Most recent TPS sample observed by the sender, per endpoint/scenario"), metric.WithUnit("{transactions}/s"), metric.WithFloat64Callback(observeTPS))) } -// meteredChainWorkers is the registry the worker_queue_length callback reads. 
-var meteredChainWorkers = &chainWorkerObserver{ - workers: make(map[chainWorkerID]*Worker), +type Registry[T comparable] struct { + r utils.RWMutex[map[T]struct{}] } -type chainWorkerObserver struct { - lock sync.RWMutex - workers map[chainWorkerID]*Worker +func (r *Registry[T]) Get() []T { + for r := range r.r.RLock() { + var vs []T + for v := range r { + vs = append(vs, v) + } + return vs + } + panic("unreachable") } -type chainWorkerID struct { - workerID int - chainID string + +func NewRegistry[T comparable]() *Registry[T] { + return &Registry[T]{r: utils.NewRWMutex(map[T]struct{}{})} } -func meterWorkerQueueLength(worker *Worker) { - meteredChainWorkers.lock.Lock() - defer meteredChainWorkers.lock.Unlock() - id := chainWorkerID{ - workerID: worker.cfg.ID, - chainID: worker.cfg.SeiChainID, +func (r *Registry[T]) MustRegister(val T) (cancel func()) { + for r := range r.r.Lock() { + if _, ok := r[val]; ok { + panic("already registered") + } + r[val] = struct{}{} } - if _, exists := meteredChainWorkers.workers[id]; !exists { - meteredChainWorkers.workers[id] = worker + return func() { + for r := range r.r.Lock() { + delete(r, val) + } } } -var tpsObserverRegistry = struct { - lock sync.RWMutex - samples map[tpsSampleKey]float64 -}{ - samples: make(map[tpsSampleKey]float64), -} +// meteredSenders is the registry the worker_queue_length callback reads. +var meteredSenders = NewRegistry[*ShardedSender]() + +var tpsObserverRegistry = utils.NewRWMutex(map[tpsSampleKey]float64{}) type tpsSampleKey struct { endpoint string @@ -103,20 +108,20 @@ type tpsSampleKey struct { // RecordTPSSample publishes the latest TPS sample read by the tps_achieved gauge.
func RecordTPSSample(endpoint, chainID, scenario string, tps float64) { - tpsObserverRegistry.lock.Lock() - defer tpsObserverRegistry.lock.Unlock() - tpsObserverRegistry.samples[tpsSampleKey{endpoint, chainID, scenario}] = tps + for r := range tpsObserverRegistry.Lock() { + r[tpsSampleKey{endpoint, chainID, scenario}] = tps + } } func observeTPS(_ context.Context, observer metric.Float64Observer) error { - tpsObserverRegistry.lock.RLock() - defer tpsObserverRegistry.lock.RUnlock() - for k, v := range tpsObserverRegistry.samples { - observer.Observe(v, metric.WithAttributes( - attribute.String("endpoint", k.endpoint), - attribute.String("chain_id", k.chainID), - attribute.String("scenario", k.scenario), - )) + for r := range tpsObserverRegistry.RLock() { + for k, v := range r { + observer.Observe(v, metric.WithAttributes( + attribute.String("endpoint", k.endpoint), + attribute.String("chain_id", k.chainID), + attribute.String("scenario", k.scenario), + )) + } } return nil } @@ -128,10 +133,3 @@ func statusAttrFromError(err error) attribute.KeyValue { } return attribute.String(key, "failure") } - -func must[V any](v V, err error) V { - if err != nil { - panic(err) - } - return v -} diff --git a/sender/queue.go b/sender/queue.go new file mode 100644 index 0000000..33c1f2f --- /dev/null +++ b/sender/queue.go @@ -0,0 +1,88 @@ +package sender + +import ( + "context" + + "github.com/sei-protocol/sei-load/utils" +) + +type queueSlot struct { + id queueID + slot int +} + +type queueID int + +type queueState struct { + first int + next int +} + +type queuePoolState[T any] struct { + mem map[queueSlot]T + queues []queueState +} + +type QueuePool[T any] struct { + state utils.Mutex[*queuePoolState[T]] + size chan struct{} +} + +type Queue[T any] struct { + id queueID + pool *QueuePool[T] + size chan struct{} +} + +func (q *Queue[T]) Len() int { return len(q.size) } + +func NewQueuePool[T any](capacity int) *QueuePool[T] { + return &QueuePool[T]{ + state: 
utils.NewMutex(&queuePoolState[T]{ + mem: make(map[queueSlot]T, capacity), + }), + size: make(chan struct{}, capacity), + } +} + +func (p *QueuePool[T]) NewQueue() *Queue[T] { + for state := range p.state.Lock() { + id := queueID(len(state.queues)) + state.queues = append(state.queues, queueState{}) + return &Queue[T]{ + id: id, + pool: p, + size: make(chan struct{}, cap(p.size)), + } + } + panic("unreachable") +} + +func (q *Queue[T]) Send(ctx context.Context, v T) error { + if err := utils.Send(ctx, q.pool.size, struct{}{}); err != nil { + return err + } + for state := range q.pool.state.Lock() { + s := &state.queues[q.id] + state.mem[queueSlot{q.id, s.next}] = v + s.next += 1 + } + q.size <- struct{}{} + return nil +} + +func (q *Queue[T]) Recv(ctx context.Context) (T, error) { + if _, err := utils.Recv(ctx, q.size); err != nil { + return utils.Zero[T](), err + } + var res T + for state := range q.pool.state.Lock() { + s := &state.queues[q.id] + slot := queueSlot{q.id, s.first} + s.first += 1 + res = state.mem[slot] + delete(state.mem, slot) + } + <-q.pool.size + return res, nil +} diff --git a/sender/ramper.go b/sender/ramper.go index 209384c..9bf10fc 100644 --- a/sender/ramper.go +++ b/sender/ramper.go @@ -8,7 +8,7 @@ import ( "time" "github.com/sei-protocol/sei-load/stats" - "github.com/sei-protocol/sei-load/utils/service" + "github.com/sei-protocol/sei-load/utils/scope" "golang.org/x/time/rate" ) @@ -114,7 +114,7 @@ func (r *Ramper) WatchSLO(ctx context.Context) <-chan struct{} { // Start initializes and starts all workers func (r *Ramper) Run(ctx context.Context) error { - return service.Run(ctx, func(ctx context.Context, s service.Scope) error { + return scope.Run(ctx, func(ctx context.Context, s scope.Scope) error { // TODO: Implement ramping logic r.startTime = time.Now() sloChan := r.WatchSLO(ctx) diff --git a/sender/scheduler.go b/sender/scheduler.go index f4c1ade..0527c26 100644 --- a/sender/scheduler.go +++ b/sender/scheduler.go @@ -7,6 +7,7 @@ 
import ( "sync/atomic" "time" + "golang.org/x/sync/semaphore" "golang.org/x/time/rate" "github.com/sei-protocol/sei-load/generator" @@ -24,7 +25,7 @@ type openLoopScheduler struct { generator generator.Generator sender TxSender limiter *rate.Limiter - inflight *utils.Semaphore + inflight *semaphore.Weighted onSent func(tx *types.LoadTx, err error) maxInFlight int @@ -54,11 +55,14 @@ func newOpenLoopScheduler( maxInFlight int, onSent func(tx *types.LoadTx, err error), ) *openLoopScheduler { + if maxInFlight < 1 { + maxInFlight = 1 + } return &openLoopScheduler{ generator: gen, sender: snd, limiter: limiter, - inflight: utils.NewSemaphore(maxInFlight), + inflight: semaphore.NewWeighted(int64(maxInFlight)), onSent: onSent, maxInFlight: maxInFlight, } @@ -102,7 +106,7 @@ func (s *openLoopScheduler) Run(ctx context.Context, scope service.Scope) error // Admit before generating: a dropped tick must not consume a seeded // generator draw (determinism). TryAcquire is non-blocking. - release, ok := s.inflight.TryAcquire() + ok := s.inflight.TryAcquire(1) if !ok { s.dropped.Add(1) nextSend = nextSend.Add(gap) @@ -113,7 +117,7 @@ func (s *openLoopScheduler) Run(ctx context.Context, scope service.Scope) error tx, ok := s.generator.Generate() if !ok { // Generator drained: not an arrival — release the permit and stop. - release() + s.inflight.Release(1) log.Print("Scheduler: generator returned no more transactions") return nil } @@ -131,7 +135,7 @@ func (s *openLoopScheduler) Run(ctx context.Context, scope service.Scope) error var once sync.Once complete := func(err error) { once.Do(func() { - release() + s.inflight.Release(1) if s.onSent != nil { s.onSent(tx, err) } diff --git a/sender/scheduler_realworker_test.go b/sender/scheduler_realworker_test.go index 63ac9a2..7df3519 100644 --- a/sender/scheduler_realworker_test.go +++ b/sender/scheduler_realworker_test.go @@ -24,13 +24,13 @@ import ( // This file is the production-path safety net for the open-loop in-flight bound. 
// // Every other scheduler test drives a FAKE TxSender that invokes tx.OnComplete -// itself, so the suite would stay green even if the real Worker forgot the -// `if tx.OnComplete != nil { tx.OnComplete(err) }` line in runTxSender — the one +// itself, so the suite would stay green even if the real ethClient forgot the +// `if tx.OnComplete != nil { tx.OnComplete(err) }` line in runSender — the one // load-bearing line that makes the maxInFlight semaphore bound true unacked // sends rather than nothing (permits would never be released → leak/meaningless -// bound). The tests here wire the REAL Worker (runTxSender → sendTransaction → +// bound). The tests here wire the REAL ethClient (runSender → sendTx → // the real ethclient → OnComplete) behind the scheduler and assert the permit -// is genuinely released by the worker on send completion. +// is genuinely released by the sender on send completion. // // Harness: an httptest.Server speaking the minimal JSON-RPC the ethclient send // path touches. SendTransaction issues exactly one eth_sendRawTransaction call @@ -215,34 +215,31 @@ func (g *signedTxGenerator) issuedCount() int { return g.issued } -// newRealWorker builds the production Worker against the given endpoint, in the -// open-loop configuration (SkipRateLimit=true so the scheduler owns the clock, -// no inclusion tracker so we exercise only the send path). It is the real +// newRealSender builds the production ethClient against the given endpoint. It +// is the real // TxSender the scheduler drives. 
-func newRealWorker(endpoint string, tasks, buffer int) *Worker { - return NewWorker(&WorkerConfig{ - ID: 0, - SeiChainID: "test", - Endpoint: endpoint, - BufferSize: buffer, - Tasks: tasks, - DryRun: false, - Debug: false, - Collector: stats.NewCollector(), - SkipRateLimit: true, +func newRealSender(endpoint string, tasks int) *ethClient { + return newEthClient(&ethClientConfig{ + ChainID: "test", + ID: 0, + Endpoint: endpoint, + Tasks: tasks, + DryRun: false, + Debug: false, + Collector: stats.NewCollector(), }) } -// TestRealWorker_Conservation_OnRealSendPath asserts conservation +// TestRealSender_Conservation_OnRealSendPath asserts conservation // (issued == completed + dropped) where `completed` is driven exclusively by the -// REAL worker invoking tx.OnComplete after sendTransaction returns — not by a -// fake. If runTxSender stopped calling OnComplete, completed would stall and +// REAL sender invoking tx.OnComplete after sendTx returns — not by a fake. If +// runSender stopped calling OnComplete, completed would stall and // this would fail. -func TestRealWorker_Conservation_OnRealSendPath(t *testing.T) { +func TestRealSender_Conservation_OnRealSendPath(t *testing.T) { const txCount = 200 srv := newRPCServer(t) gen := newSignedTxGenerator(t, txCount) - worker := newRealWorker(srv.url(), 8, 256) + client := newRealSender(srv.url(), 8) var completed, succeeded atomic.Uint64 onSent := func(_ *types.LoadTx, err error) { @@ -253,26 +250,26 @@ func TestRealWorker_Conservation_OnRealSendPath(t *testing.T) { } limiter := rate.NewLimiter(rate.Limit(2000), 1) - sched := newOpenLoopScheduler(gen, worker, limiter, 256, onSent) + sched := newOpenLoopScheduler(gen, client, limiter, 256, onSent) ctx, cancel := context.WithCancel(t.Context()) defer cancel() - // Run the worker and scheduler in a scope whose teardown WE control via + // Run the sender and scheduler in a scope whose teardown WE control via // runCancel — not the scheduler's return. 
// // service.Run cancels the scope's context as soon as every MAIN task returns. // If the scheduler were a main task, the instant it exhausts the generator and - // returns, service.Run would cancel the worker's context — aborting any send + // returns, service.Run would cancel the sender's context — aborting any send // still in flight. A send whose 200 OK the server already counted (handled++) // but whose client.SendTransaction had not yet returned would then fail with // context-canceled: OnComplete fires with err != nil, so completed++ but NOT // succeeded++. That is exactly the observed flake (handled=200, succeeded=199): // not a sampling artifact but a teardown that races the last in-flight send. // - // So the scheduler and worker are BACKGROUND tasks, and the lone MAIN task is a + // So the scheduler and sender are BACKGROUND tasks, and the lone MAIN task is a // gate that blocks until the test calls runCancel(). The scope therefore stays - // alive — the worker keeps draining txChan and firing OnComplete — until the + // alive — the sender keeps draining reqs and firing OnComplete — until the // test has observed quiescence and torn down deliberately. runCtx, runCancel := context.WithCancel(ctx) var wg sync.WaitGroup @@ -280,7 +277,7 @@ func TestRealWorker_Conservation_OnRealSendPath(t *testing.T) { go func() { defer wg.Done() _ = service.Run(runCtx, func(ctx context.Context, scope service.Scope) error { - scope.SpawnBg(func() error { return worker.Run(ctx) }) + scope.SpawnBg(func() error { return client.Run(ctx) }) scope.SpawnBg(func() error { return sched.Run(ctx, scope) }) // Main task: hold the scope open until the test signals teardown. 
<-ctx.Done() @@ -297,12 +294,12 @@ func TestRealWorker_Conservation_OnRealSendPath(t *testing.T) { // generator is drained and dropped ticks consumed no draw; // the precondition that makes the fixpoint below stable) // conservation: completed == Admitted() (every admitted tx reached a - // terminal state via the real worker's OnComplete) + // terminal state via the real sender's OnComplete) // equality: succeeded == handled (every server-handled send // produced exactly one successful worker-driven completion) // // conservation and equality are transiently off WHILE a send is in flight: the - // server bumps `handled` when it RECEIVES eth_sendRawTransaction, but the worker + // server bumps `handled` when it RECEIVES eth_sendRawTransaction, but the sender // bumps `succeeded` only AFTER SendTransaction returns and OnComplete fires — the // instants differ by the server→worker-return window. Sampling any of them alone, // or at different instants, can catch that window. Requiring all three together, @@ -312,7 +309,7 @@ func TestRealWorker_Conservation_OnRealSendPath(t *testing.T) { // point. (The deeper hazard the gate above fixes is teardown racing that same // window; here we additionally refuse to read until the window is empty.) // - // Driven by the real worker's OnComplete — a missing invoke leaves completed + // Driven by the real sender's OnComplete — a missing invoke leaves completed // (and succeeded) short forever, so convergence never happens and the test // fails on the Eventually deadline. CI is slow, so the window is generous; // correctness depends on convergence, not on the deadline firing. 
@@ -336,39 +333,39 @@ func TestRealWorker_Conservation_OnRealSendPath(t *testing.T) { require.Equal(t, total, uint64(gen.issuedCount()), "the generator must be fully drained") require.Positive(t, srv.handled.Load(), "the real RPC server must have handled sends") require.Equal(t, sched.Admitted(), completed.Load(), - "every admitted tx must reach a terminal state via the worker's OnComplete") + "every admitted tx must reach a terminal state via the sender's OnComplete") require.Equal(t, succeeded.Load(), srv.handled.Load(), "each successful completion must correspond to one eth_sendRawTransaction") } -// TestRealWorker_PermitReleasedByWorker is the teeth: with maxInFlight=1 and a -// single worker task, the RPC server blocks the first send. The real worker is -// parked inside sendTransaction, so it has NOT yet called tx.OnComplete and the +// TestRealSender_PermitReleasedBySender is the teeth: with maxInFlight=1 and a +// single sender task, the RPC server blocks the first send. The real sender is +// parked inside sendTx, so it has NOT yet called tx.OnComplete and the // single permit stays held — every subsequent arrival must drop. Releasing the -// blocked send lets the worker return from sendTransaction, fire OnComplete, and +// blocked send lets the sender return from sendTx, fire OnComplete, and // free the permit, so flow resumes. // // If someone deletes the `if tx.OnComplete != nil { tx.OnComplete(err) }` invoke -// in runTxSender, the permit is never released even after the send completes: -// the worker would never accept a second tx, so handled stays at 1 and the +// in runSender, the permit is never released even after the send completes: the +// sender would never accept a second tx, so handled stays at 1 and the // resume assertion fails. That is the falsification this test exists for. 
-func TestRealWorker_PermitReleasedByWorker(t *testing.T) { +func TestRealSender_PermitReleasedBySender(t *testing.T) { srv := newRPCServer(t) srv.setBlocking(true) // Plenty of arrivals so the scheduler keeps offering txs while the first is // parked; the surplus must drop because the lone permit is held. gen := newSignedTxGenerator(t, 1000) - // One task: a single runTxSender owns the only permit's lifecycle, so the - // permit can only be freed by that worker calling OnComplete. - worker := newRealWorker(srv.url(), 1, 1) + // One task: a single runSender owns the only permit's lifecycle, so the + // permit can only be freed by that sender calling OnComplete. + client := newRealSender(srv.url(), 1) var completed atomic.Uint64 onSent := func(_ *types.LoadTx, _ error) { completed.Add(1) } // Fast arrival clock so many txs are offered during the blocked window. limiter := rate.NewLimiter(rate.Limit(5000), 1) // 0.2ms gap - sched := newOpenLoopScheduler(gen, worker, limiter, 1, onSent) + sched := newOpenLoopScheduler(gen, client, limiter, 1, onSent) ctx, cancel := context.WithTimeout(t.Context(), 5*time.Second) defer cancel() @@ -378,7 +375,7 @@ func TestRealWorker_PermitReleasedByWorker(t *testing.T) { go func() { defer wg.Done() _ = service.Run(ctx, func(ctx context.Context, scope service.Scope) error { - scope.SpawnBg(func() error { return worker.Run(ctx) }) + scope.SpawnBg(func() error { return client.Run(ctx) }) return sched.Run(ctx, scope) }) }() @@ -386,7 +383,7 @@ func TestRealWorker_PermitReleasedByWorker(t *testing.T) { // Wait until exactly one send is genuinely in flight (parked in the handler). <-srv.arrived - // While that send is parked, the worker has not fired OnComplete, so the lone + // While that send is parked, the sender has not fired OnComplete, so the lone // permit is held. 
Give the fast scheduler time to offer (and drop) a slew of // arrivals, then assert the bound held: exactly one send in flight, none // completed yet, and the rest dropped. @@ -402,7 +399,7 @@ func TestRealWorker_PermitReleasedByWorker(t *testing.T) { require.Equal(t, uint64(0), srv.handled.Load(), "the parked send has not returned a result yet, so the permit is still held") - // Release the blocked send. The real worker now returns from sendTransaction + // Release the blocked send. The real sender now returns from sendTx // and MUST invoke tx.OnComplete to free the permit. If it does not (the bug), // no further send is ever admitted and handled stays at 1 forever. require.True(t, srv.releaseOne(), "one send must be parked and releasable") @@ -427,7 +424,7 @@ func TestRealWorker_PermitReleasedByWorker(t *testing.T) { require.Eventually(t, func() bool { return srv.handled.Load() > 1 && completed.Load() > 1 }, 3*time.Second, 2*time.Millisecond, - "flow must resume after the worker releases the permit via OnComplete "+ + "flow must resume after the sender releases the permit via OnComplete "+ "(handled=%d completed=%d)", srv.handled.Load(), completed.Load()) // Flow resumed: at least one further send completed after the release, so the @@ -447,7 +444,7 @@ func TestRealWorker_PermitReleasedByWorker(t *testing.T) { require.Equal(t, admitted, uint64(gen.issuedCount()), "every generator draw must be an admitted tx; dropped ticks consume no draw") - // Each admitted tx completes exactly once via the worker's OnComplete, so + // Each admitted tx completes exactly once via the sender's OnComplete, so // completed must equal Admitted() — minus at most the single tx left mid-flight // when cancel raced its in-flight send. 
require.LessOrEqual(t, completed.Load(), admitted, "no admitted tx may complete more than once") @@ -457,9 +454,9 @@ func TestRealWorker_PermitReleasedByWorker(t *testing.T) { } // TestDispatcher_PrewarmRateLimitedInOpenLoop guards the prewarm-flood -// regression: in open-loop the workers are constructed SkipRateLimit=true, but +// regression: in open-loop the sender loop is ungated, but // the scheduler paces only the MAIN load. Prewarm runs first over those same -// ungated workers, so it must pace itself off the shared limiter or it floods +// ungated senders, so it must pace itself off the shared limiter or it floods // the SUT. With workers wired exactly as in open-loop, a low limit, and many // more prewarm txs than the worker pool could absorb instantly, an unpaced // prewarm would drain in well under the limiter's minimum span. We assert the @@ -470,7 +467,7 @@ func TestDispatcher_PrewarmRateLimitedInOpenLoop(t *testing.T) { const prewarmTxs = 40 const rps = 200.0 // limiter: 200 tx/s → unpaced 40 txs is near-instant - worker := newRealWorker(srv.url(), 8, 256) // SkipRateLimit=true, open-loop shape + client := newRealSender(srv.url(), 8) ctx, cancel := context.WithTimeout(t.Context(), 5*time.Second) defer cancel() @@ -479,14 +476,14 @@ func TestDispatcher_PrewarmRateLimitedInOpenLoop(t *testing.T) { go func() { defer wg.Done() _ = service.Run(ctx, func(ctx context.Context, scope service.Scope) error { - scope.SpawnBg(func() error { return worker.Run(ctx) }) + scope.SpawnBg(func() error { return client.Run(ctx) }) <-ctx.Done() return nil }) }() limiter := rate.NewLimiter(rate.Limit(rps), 1) - d := NewDispatcher(newSignedTxGenerator(t, 0), worker) + d := NewDispatcher(newSignedTxGenerator(t, 0), client) d.SetOpenLoop(limiter, 256) // sets d.limiter so Prewarm self-paces d.SetPrewarmGenerator(newSignedTxGenerator(t, prewarmTxs)) diff --git a/sender/sharded_sender.go b/sender/sharded_sender.go index dafdc97..60f903c 100644 --- a/sender/sharded_sender.go +++ 
b/sender/sharded_sender.go @@ -3,6 +3,7 @@ package sender import ( "context" "fmt" + "log" "golang.org/x/time/rate" @@ -10,83 +11,108 @@ import ( "github.com/sei-protocol/sei-load/stats" "github.com/sei-protocol/sei-load/types" "github.com/sei-protocol/sei-load/utils" - "github.com/sei-protocol/sei-load/utils/service" + "github.com/sei-protocol/sei-load/utils/scope" ) // ShardedSender implements TxSender with multiple workers, one per endpoint type ShardedSender struct { - workers []*Worker + cfg *config.LoadConfig + limiter *rate.Limiter // Shared rate limiter for transaction sending + clients []*ethClient + shards []*Queue[*types.LoadTx] } -// NewShardedSender creates a new sharded sender with workers for each endpoint. -// inclusion, when present, is shared across all workers so each routes its -// successful sends to the one tracker. +// NewShardedSender creates a new sharded sender. +// Txs of each shard are sent sequentially, using a single eth client. func NewShardedSender(cfg *config.LoadConfig, limiter *rate.Limiter, collector *stats.Collector, inclusion utils.Option[*stats.InclusionTracker]) (*ShardedSender, error) { if len(cfg.Endpoints) == 0 { return nil, fmt.Errorf("no endpoints configured") } - - // Open-loop lets the scheduler own the arrival clock (see doc.go), so the - // worker skips gating to avoid double-throttling; closed-loop keeps it. 
- skipRateLimit := cfg.Settings.ArrivalModel == config.ArrivalModelOpenLoop - - workers := make([]*Worker, len(cfg.Endpoints)) - for i, endpoint := range cfg.Endpoints { - workers[i] = NewWorker(&WorkerConfig{ - ID: i, - SeiChainID: cfg.SeiChainID, - Endpoint: endpoint, - BufferSize: cfg.Settings.BufferSize, - Tasks: cfg.Settings.TasksPerEndpoint, - DryRun: cfg.Settings.DryRun, - Debug: cfg.Settings.Debug, - Collector: collector, - Limiter: limiter, - SkipRateLimit: skipRateLimit, - Inclusion: inclusion, - }) + numShards := cfg.GetNumShards() + if numShards <= 0 { + return nil, fmt.Errorf("no shards configured") + } + totalQueueSize := cfg.TotalQueueSize() + if totalQueueSize <= 0 { + return nil, fmt.Errorf("queue size has to be positive") + } + var clients []*ethClient + for id, endpoint := range cfg.Endpoints { + clients = append(clients, newEthClient(&ethClientConfig{ + ChainID: cfg.SeiChainID, + ID: id, + Endpoint: endpoint, + Tasks: cfg.Settings.TasksPerEndpoint, + DryRun: cfg.Settings.DryRun, + Debug: cfg.Settings.Debug, + Collector: collector, + Inclusion: inclusion, + })) } + pool := NewQueuePool[*types.LoadTx](totalQueueSize) + var shards []*Queue[*types.LoadTx] + for range numShards { + shards = append(shards, pool.NewQueue()) + } + return &ShardedSender{ + cfg: cfg, + limiter: limiter, + clients: clients, + shards: shards, + }, nil +} - return &ShardedSender{workers: workers}, nil +// Send implements TxSender interface - calculates shard ID and routes to the appropriate shard queue +func (s *ShardedSender) Send(ctx context.Context, tx *types.LoadTx) error { + return s.shards[tx.ShardID(len(s.shards))].Send(ctx, tx) } // Start initializes and starts all workers -func (s *ShardedSender) Run(ctx context.Context) error { - return service.Run(ctx, func(ctx context.Context, scope service.Scope) error { - for _, worker := range s.workers { - scope.Spawn(func() error { return worker.Run(ctx) }) +func (ss *ShardedSender) Run(ctx context.Context) error { + cancel := 
meteredSenders.MustRegister(ss) + defer cancel() + return scope.Run(ctx, func(ctx context.Context, s scope.Scope) error { + for _, client := range ss.clients { + s.Spawn(func() error { return client.Run(ctx) }) + } + for i, shard := range ss.shards { + s.Spawn(func() error { + client := ss.clients[i%len(ss.clients)] + for ctx.Err() == nil { + tx, err := shard.Recv(ctx) + if err != nil { + return err + } + if err := ss.limiter.Wait(ctx); err != nil { + return err + } + if err := client.Send(ctx, tx); err != nil { + log.Printf("%v", err) + } + } + return ctx.Err() + }) } return nil }) } -// Send implements TxSender interface - calculates shard ID and routes to appropriate worker -func (s *ShardedSender) Send(ctx context.Context, tx *types.LoadTx) error { - // Calculate shard ID based on the transaction - shardID := tx.ShardID(len(s.workers)) - // Send to the appropriate worker - return s.workers[shardID].Send(ctx, tx) +type ShardStats struct { + ChainID string + ID int + Endpoint string + TxsQueued int } -// GetWorkerStats returns statistics for all workers -func (s *ShardedSender) GetWorkerStats() []WorkerStats { - stats := make([]WorkerStats, len(s.workers)) - for i, worker := range s.workers { - stats[i] = WorkerStats{ - WorkerID: i, - Endpoint: worker.Endpoint(), - ChannelLength: worker.ChannelLength(), - } +func (ss *ShardedSender) ShardStats() []ShardStats { + var stats []ShardStats + for i, shard := range ss.shards { + stats = append(stats, ShardStats{ + ChainID: ss.cfg.SeiChainID, + ID: i, + Endpoint: ss.clients[i%len(ss.clients)].cfg.Endpoint, + TxsQueued: shard.Len(), + }) } return stats } - -// WorkerStats contains statistics for a single worker -type WorkerStats struct { - WorkerID int - Endpoint string - ChannelLength int -} - -// NumShards returns the number of shards (workers) -func (s *ShardedSender) NumShards() int { return len(s.workers) } diff --git a/sender/worker.go b/sender/worker.go deleted file mode 100644 index 21d33e4..0000000 --- 
a/sender/worker.go +++ /dev/null @@ -1,242 +0,0 @@ -package sender - -import ( - "context" - "fmt" - "log" - "net" - "net/http" - "net/url" - "time" - - "github.com/ethereum/go-ethereum/ethclient" - "github.com/ethereum/go-ethereum/rpc" - "go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp" - "go.opentelemetry.io/otel" - "go.opentelemetry.io/otel/attribute" - "go.opentelemetry.io/otel/metric" - "go.opentelemetry.io/otel/trace" - "golang.org/x/time/rate" - - "github.com/sei-protocol/sei-load/stats" - "github.com/sei-protocol/sei-load/types" - "github.com/sei-protocol/sei-load/utils" - "github.com/sei-protocol/sei-load/utils/service" -) - -var tracer = otel.Tracer("github.com/sei-protocol/sei-load/sender") - -type WorkerConfig struct { - ID int - SeiChainID string - Endpoint string - BufferSize int - Tasks int - DryRun bool - Debug bool - Collector *stats.Collector - Limiter *rate.Limiter // Shared rate authority; nil disables gating. - // SkipRateLimit opts a worker out of limiter gating. Zero value (false) is the - // safe default (gate when Limiter is set); set true only in open-loop, where - // the scheduler owns the clock (see doc.go). - SkipRateLimit bool - // Inclusion, when present, receives each successful send at send-completion so - // the tracker can stamp InclusionTime (see doc.go). None disables tracking. - Inclusion utils.Option[*stats.InclusionTracker] -} - -// Worker handles sending transactions to a specific endpoint -type Worker struct { - cfg *WorkerConfig - txChan chan *types.LoadTx -} - -// HttpClientOption configures the Transport used by newHttpClient. -type HttpClientOption func(*http.Transport) - -// WithMaxIdleConns overrides the global idle-connection pool size. -func WithMaxIdleConns(n int) HttpClientOption { - return func(t *http.Transport) { t.MaxIdleConns = n } -} - -// WithMaxIdleConnsPerHost overrides the per-host idle-connection pool size. -// Scale with goroutine count to avoid TCP re-dial on each completion. 
-func WithMaxIdleConnsPerHost(n int) HttpClientOption { - return func(t *http.Transport) { t.MaxIdleConnsPerHost = n } -} - -// newHttpTransport is the base transport factory. Exists separately so tests -// can inspect the unwrapped *http.Transport; newHttpClient returns it wrapped -// in otelhttp, whose inner transport isn't publicly accessible. -func newHttpTransport(opts ...HttpClientOption) *http.Transport { - t := &http.Transport{ - DialContext: (&net.Dialer{ - Timeout: 10 * time.Second, - KeepAlive: 30 * time.Second, - }).DialContext, - MaxIdleConns: 500, - MaxIdleConnsPerHost: 50, - IdleConnTimeout: 90 * time.Second, - TLSHandshakeTimeout: 10 * time.Second, - ExpectContinueTimeout: 1 * time.Second, - DisableKeepAlives: false, - } - for _, opt := range opts { - opt(t) - } - return t -} - -// newHttpClient returns an otelhttp-wrapped client: injects traceparent on -// outbound, emits http.client.* metrics. Requires observability.Setup to have -// installed the global TextMapPropagator. -func newHttpClient(opts ...HttpClientOption) *http.Client { - return &http.Client{ - Timeout: 30 * time.Second, - Transport: otelhttp.NewTransport(newHttpTransport(opts...)), - } -} - -// newRPCClient returns a go-ethereum client configured for the endpoint scheme. -// HTTP(S) endpoints reuse the tuned otelhttp-backed transport; WS(S) endpoints -// use the default go-ethereum WebSocket transport. 
-func newRPCClient(ctx context.Context, endpoint string, opts ...HttpClientOption) (*ethclient.Client, error) { - u, err := url.Parse(endpoint) - if err != nil { - return nil, fmt.Errorf("parse endpoint %q: %w", endpoint, err) - } - - switch u.Scheme { - case "http", "https": - rpcClient, err := rpc.DialOptions(ctx, endpoint, rpc.WithHTTPClient(newHttpClient(opts...))) - if err != nil { - return nil, err - } - return ethclient.NewClient(rpcClient), nil - case "ws", "wss", "": - return ethclient.DialContext(ctx, endpoint) - default: - return nil, fmt.Errorf("unsupported RPC scheme %q for endpoint %s", u.Scheme, endpoint) - } -} - -// NewWorker creates a new worker for a specific endpoint -func NewWorker(cfg *WorkerConfig) *Worker { - w := &Worker{ - cfg: cfg, - txChan: make(chan *types.LoadTx, cfg.BufferSize), - } - meterWorkerQueueLength(w) - return w -} - -// Start begins the worker's processing loop -func (w *Worker) Run(ctx context.Context) error { - client, err := newRPCClient(ctx, w.cfg.Endpoint) - if err != nil { - return fmt.Errorf("dial %s: %w", w.cfg.Endpoint, err) - } - defer client.Close() - return service.Run(ctx, func(ctx context.Context, s service.Scope) error { - // Start multiple goroutines that share the same channel and RPC client. - for range w.cfg.Tasks { - s.Spawn(func() error { return w.runTxSender(ctx, client) }) - } - return nil - }) -} - -// Send queues a transaction for this worker to process -func (w *Worker) Send(ctx context.Context, tx *types.LoadTx) error { - return utils.Send(ctx, w.txChan, tx) -} - -// runTxSender is the main worker loop that processes transactions -func (w *Worker) runTxSender(ctx context.Context, client *ethclient.Client) error { - for ctx.Err() == nil { - // Closed-loop gates on the limiter before dequeue; open-loop skips it. 
- if !w.cfg.SkipRateLimit && w.cfg.Limiter != nil { - if err := w.cfg.Limiter.Wait(ctx); err != nil { - return err - } - } - - tx, err := utils.Recv(ctx, w.txChan) - if err != nil { - return err - } - - startTime := time.Now() - // Sole owner between dequeue and hand-off: stamp is race-free (see LoadTx). - tx.AttemptedSendTime = startTime - err = w.sendTransaction(ctx, client, tx) - // OnComplete must fire only after the real send returns — that is what - // bounds true unacked in-flight (see doc.go). Nil on closed-loop/batch. - if tx.OnComplete != nil { - tx.OnComplete(err) - } - w.cfg.Collector.RecordTransaction(tx.Scenario.Name, w.cfg.Endpoint, time.Since(startTime), err == nil) - // Register at send-completion, only on success: registered ⊆ succeeded. - // (The tracker is wired only for live runs — see main.go; DryRun never - // gets a tracker, so simulated sends are not inclusion-tracked.) - if err == nil { - if t, ok := w.cfg.Inclusion.Get(); ok { - t.Register(tx) - } - } - if err != nil { - log.Printf("%v", err) - } - } - return ctx.Err() -} - -// sendTransaction sends a single transaction to the endpoint -func (w *Worker) sendTransaction(ctx context.Context, client *ethclient.Client, tx *types.LoadTx) (_err error) { - ctx, span := tracer.Start(ctx, "sender.send_tx", trace.WithAttributes( - attribute.String("seiload.scenario", tx.Scenario.Name), - attribute.String("seiload.endpoint", w.cfg.Endpoint), - attribute.Int("seiload.worker_id", w.cfg.ID), - attribute.String("seiload.chain_id", w.cfg.SeiChainID), - )) - defer func(start time.Time) { - if _err != nil { - span.RecordError(_err) - } - span.End() - // Record inside the span ctx so exemplars link to the trace; worker_id - // stays off the histogram (cardinality), available via the span. 
- sendLatency.Record(ctx, time.Since(start).Seconds(), - metric.WithAttributes( - attribute.String("scenario", tx.Scenario.Name), - attribute.String("endpoint", w.cfg.Endpoint), - attribute.String("chain_id", w.cfg.SeiChainID), - statusAttrFromError(_err)), - ) - }(time.Now()) - if w.cfg.DryRun { - return utils.Sleep(ctx, 10*time.Microsecond) // minimal delay, no RPC - } - - if err := client.SendTransaction(ctx, tx.EthTx); err != nil { - txsRejected.Add(ctx, 1, metric.WithAttributes( - attribute.String("endpoint", w.cfg.Endpoint), - attribute.String("scenario", tx.Scenario.Name), - attribute.String("reason", "rpc"), - )) - return fmt.Errorf("Worker %d: Failed to send transaction: %w", w.cfg.ID, err) - } - - txsAccepted.Add(ctx, 1, metric.WithAttributes( - attribute.String("endpoint", w.cfg.Endpoint), - attribute.String("scenario", tx.Scenario.Name), - )) - return nil -} - -// ChannelLength returns the current length of the worker's channel (for monitoring). -// This function is safe for concurrent calls. -func (w *Worker) ChannelLength() int { return len(w.txChan) } - -// Endpoint returns the worker's endpoint -func (w *Worker) Endpoint() string { return w.cfg.Endpoint } diff --git a/sender/worker_test.go b/sender/worker_test.go deleted file mode 100644 index f280772..0000000 --- a/sender/worker_test.go +++ /dev/null @@ -1,223 +0,0 @@ -package sender - -import ( - "context" - "math/big" - "net/http" - "net/http/httptest" - "strings" - "sync/atomic" - "testing" - "time" - - "github.com/ethereum/go-ethereum/common" - ethtypes "github.com/ethereum/go-ethereum/core/types" - "github.com/ethereum/go-ethereum/rpc" - "github.com/stretchr/testify/require" - "golang.org/x/time/rate" - - "github.com/sei-protocol/sei-load/stats" - "github.com/sei-protocol/sei-load/types" - "github.com/sei-protocol/sei-load/utils" -) - -// drainWorkerWithLimiter runs runTxSender (DryRun: no RPC) over txCount queued -// txs gated by a tight limiter and returns how long the drain took. 
cancel fires -// once all txs are recorded, so the elapsed time reflects limiter pacing alone. -func drainWorkerWithLimiter(t *testing.T, skipRateLimit bool, txCount int, rps float64) time.Duration { - t.Helper() - collector := stats.NewCollector() - w := NewWorker(&WorkerConfig{ - ID: 0, - Endpoint: "dryrun", - BufferSize: txCount, - Tasks: 1, - DryRun: true, - Collector: collector, - Limiter: rate.NewLimiter(rate.Limit(rps), 1), - SkipRateLimit: skipRateLimit, - }) - for range txCount { - w.txChan <- &types.LoadTx{Scenario: &types.TxScenario{Name: "gate"}} - } - - ctx, cancel := context.WithTimeout(t.Context(), 5*time.Second) - defer cancel() - go func() { - for collector.GetStats().TotalTxs < uint64(txCount) { - time.Sleep(time.Millisecond) - } - cancel() - }() - - start := time.Now() - _ = w.runTxSender(ctx, nil) // DryRun never touches the client - return time.Since(start) -} - -// TestRunTxSender_RateLimitedByDefault is the SkipRateLimit-flip guard: the -// zero-value config (SkipRateLimit=false) with a non-nil Limiter must gate, so -// an omitted flag can never silently drop rate limiting. With burst=1 at `rps`, -// draining txCount txs cannot finish faster than (txCount-1)/rps. -func TestRunTxSender_RateLimitedByDefault(t *testing.T) { - const txCount = 10 - const rps = 50.0 // floor: (10-1)/50 = 180ms - elapsed := drainWorkerWithLimiter(t, false, txCount, rps) - require.GreaterOrEqual(t, elapsed, 150*time.Millisecond, - "default config must rate-limit (safe zero value)") -} - -// TestRunTxSender_SkipRateLimitBypassesLimiter confirms the open-loop opt-out: -// SkipRateLimit=true ignores the limiter entirely, so the same drain finishes -// far under the gated floor. 
-func TestRunTxSender_SkipRateLimitBypassesLimiter(t *testing.T) { - const txCount = 10 - const rps = 50.0 - elapsed := drainWorkerWithLimiter(t, true, txCount, rps) - require.Less(t, elapsed, 100*time.Millisecond, - "SkipRateLimit must bypass the limiter") -} - -// dryRunTx builds a minimal LoadTx with a real eth tx so EthTx.Hash() works. -func dryRunTx(nonce uint64) *types.LoadTx { - eth := ethtypes.NewTx(ðtypes.LegacyTx{ - Nonce: nonce, GasPrice: big.NewInt(1), Gas: 21000, - To: &common.Address{}, Value: big.NewInt(0), - }) - return &types.LoadTx{EthTx: eth, Scenario: &types.TxScenario{Name: "incl"}} -} - -// inflightCount reads the tracker's registry size via its Summary (read after a -// drain, so inflight is the registered-minus-terminal count). -func inflightCount(tr *stats.InclusionTracker) uint64 { - return tr.Summary().InflightAtShutdown -} - -// TestRunTxSender_RegistersSuccessfulSend asserts the inclusion hand-off: -// a successful (DryRun) send registers the tx with the tracker, and Register -// runs strictly AFTER OnComplete (the permit-release ordering in doc.go). -func TestRunTxSender_RegistersSuccessfulSend(t *testing.T) { - tracker := stats.NewInclusionTracker("test-chain", time.Hour, 100, true /* openLoop */) - collector := stats.NewCollector() - w := NewWorker(&WorkerConfig{ - ID: 0, Endpoint: "dryrun", BufferSize: 4, Tasks: 1, DryRun: true, - Collector: collector, SkipRateLimit: true, - Inclusion: utils.Some(tracker), - }) - - // Single tx so the registry starts empty: at OnComplete time inflight must - // still be 0, proving Register runs strictly after OnComplete. 
- var inflightAtComplete atomic.Int64 - tx := dryRunTx(0) - tx.OnComplete = func(error) { - inflightAtComplete.Store(int64(inflightCount(tracker))) - } - w.txChan <- tx - - ctx, cancel := context.WithTimeout(t.Context(), 5*time.Second) - defer cancel() - go func() { - for collector.GetStats().TotalTxs < 1 { - time.Sleep(time.Millisecond) - } - cancel() - }() - _ = w.runTxSender(ctx, nil) - - require.Equal(t, int64(0), inflightAtComplete.Load(), - "Register must fire after OnComplete (registry empty at OnComplete time)") - require.Equal(t, uint64(1), inflightCount(tracker), - "a successful send registers exactly once") -} - -// TestRunTxSender_NoInclusionTracker confirms a None tracker is a safe no-op. -func TestRunTxSender_NoInclusionTracker(t *testing.T) { - collector := stats.NewCollector() - w := NewWorker(&WorkerConfig{ - ID: 0, Endpoint: "dryrun", BufferSize: 2, Tasks: 1, DryRun: true, - Collector: collector, SkipRateLimit: true, - Inclusion: utils.None[*stats.InclusionTracker](), - }) - w.txChan <- dryRunTx(0) - - ctx, cancel := context.WithTimeout(t.Context(), 5*time.Second) - defer cancel() - go func() { - for collector.GetStats().TotalTxs < 1 { - time.Sleep(time.Millisecond) - } - cancel() - }() - require.NotPanics(t, func() { _ = w.runTxSender(ctx, nil) }) -} - -func TestNewHttpTransport_Defaults(t *testing.T) { - tr := newHttpTransport() - - require.Equal(t, 500, tr.MaxIdleConns) - require.Equal(t, 50, tr.MaxIdleConnsPerHost) - require.Equal(t, 90*time.Second, tr.IdleConnTimeout) - require.False(t, tr.DisableKeepAlives) -} - -func TestNewHttpTransport_WithMaxIdleConns(t *testing.T) { - tr := newHttpTransport(WithMaxIdleConns(2048)) - - require.Equal(t, 2048, tr.MaxIdleConns) - require.Equal(t, 50, tr.MaxIdleConnsPerHost, "per-host default preserved") -} - -func TestNewHttpTransport_WithMaxIdleConnsPerHost(t *testing.T) { - tr := newHttpTransport(WithMaxIdleConnsPerHost(1024)) - - require.Equal(t, 1024, tr.MaxIdleConnsPerHost) - require.Equal(t, 500, 
tr.MaxIdleConns, "global default preserved") -} - -func TestNewHttpTransport_MultipleOptions(t *testing.T) { - tr := newHttpTransport( - WithMaxIdleConns(4096), - WithMaxIdleConnsPerHost(1024), - ) - - require.Equal(t, 4096, tr.MaxIdleConns) - require.Equal(t, 1024, tr.MaxIdleConnsPerHost) -} - -func TestNewHttpClient_Smoke(t *testing.T) { - c := newHttpClient() - require.Equal(t, 30*time.Second, c.Timeout) - require.NotNil(t, c.Transport, "Transport must be set") - _, isBareTransport := c.Transport.(*http.Transport) - require.False(t, isBareTransport, "Transport should be wrapped by otelhttp, not bare *http.Transport") -} - -func TestNewRPCClient_HTTP(t *testing.T) { - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) - })) - defer srv.Close() - - client, err := newRPCClient(context.Background(), srv.URL) - require.NoError(t, err) - require.NotNil(t, client) - client.Close() -} - -func TestNewRPCClient_WS(t *testing.T) { - srv := rpc.NewServer() - ts := httptest.NewServer(srv.WebsocketHandler([]string{"*"})) - defer ts.Close() - - wsURL := "ws" + strings.TrimPrefix(ts.URL, "http") - client, err := newRPCClient(context.Background(), wsURL) - require.NoError(t, err) - require.NotNil(t, client) - client.Close() -} - -func TestNewRPCClient_UnsupportedScheme(t *testing.T) { - client, err := newRPCClient(context.Background(), "ftp://example.com") - require.Error(t, err) - require.Nil(t, client) -} diff --git a/stats/block_collector.go b/stats/block_collector.go index 6a1794c..82d22fc 100644 --- a/stats/block_collector.go +++ b/stats/block_collector.go @@ -10,7 +10,7 @@ import ( "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/ethclient" "github.com/sei-protocol/sei-load/utils" - "github.com/sei-protocol/sei-load/utils/service" + "github.com/sei-protocol/sei-load/utils/scope" "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/metric" ) @@ -57,7 +57,7 @@ func 
NewBlockCollector(seiChainID string) *BlockCollector { // Start begins block subscription and data collection func (bc *BlockCollector) Run(ctx context.Context, firstEndpoint string) error { wsEndpoint := utils.GetWSEndpoint(firstEndpoint) - return service.Run(ctx, func(ctx context.Context, s service.Scope) error { + return scope.Run(ctx, func(ctx context.Context, s scope.Scope) error { // Connect to WebSocket endpoint client, err := ethclient.Dial(wsEndpoint) if err != nil { diff --git a/types/scenario.go b/types/scenario.go index af2f852..a70b8a4 100644 --- a/types/scenario.go +++ b/types/scenario.go @@ -18,7 +18,7 @@ import ( // that stage, and is immutable thereafter; ownership transfers with the pointer // across the channels, so the writes need no locking. The open-loop scheduler // writes IntendedSendTime and SequenceIndex while it solely owns the tx (before -// the worker hand-off); the worker writes AttemptedSendTime; the inclusion +// the sender hand-off); the sender writes AttemptedSendTime; the inclusion // tracker writes InclusionTime. A zero timestamp means "not recorded" (e.g. // prewarm txs, or a stage not yet reached) β€” consumers must treat it as // untracked, never as the zero epoch. @@ -46,16 +46,16 @@ type LoadTx struct { // model (see IntendedSendTime); the run's arrival model is authoritative. SequenceIndex uint64 // AttemptedSendTime is when the send was actually attempted, written by the - // worker goroutine that owns the tx between dequeue and send completion. + // sender goroutine that owns the tx between dequeue and send completion. AttemptedSendTime time.Time // OnComplete, if set, is invoked exactly once when the network send attempt // for this tx finishes (after sendTransaction returns), with the send error // or nil. The open-loop scheduler sets it to release the in-flight permit so // the bound covers true unacked sends (enqueue + send), not just queue depth; - // see the open-loop scheduler. 
The worker invokes it after sendTransaction + // see the open-loop scheduler. The sender invokes it after send completion // and is the sole invoker on the happy path. Nil in the closed-loop and batch - // paths, where the worker simply skips it. The callback must be cheap and - // non-blocking β€” the worker holds the tx and calls it inline. Written by the + // paths, where the sender simply skips it. The callback must be cheap and + // non-blocking β€” the sender holds the tx and calls it inline. Written by the // owning goroutine before hand-off, per the lifecycle concurrency contract. OnComplete func(err error) // InclusionTime is when the tx was observed included on-chain, written only diff --git a/utils/channels.go b/utils/channels.go index 1e11b90..a8bafee 100644 --- a/utils/channels.go +++ b/utils/channels.go @@ -2,8 +2,6 @@ package utils import ( "context" - - "github.com/pkg/errors" ) // Recv receives a value from a channel or returns an error if the context is canceled. @@ -51,21 +49,3 @@ func SendOrDrop[T any](ch chan<- T, v T) { default: // drop the item } } - -// ForEach is a helper function that reads from a channel and calls a handler for each item. -// this avoids needing a lot of for/select boilerplate everywhere. -func ForEach[T any](ctx context.Context, ch <-chan T, handler func(T) error) error { - for { - select { - case <-ctx.Done(): - return errors.WithStack(ctx.Err()) - case item, ok := <-ch: - if !ok { - return nil // Channel closed - } - if err := handler(item); err != nil { - return err // Stop on error - } - } - } -} diff --git a/utils/mutex.go b/utils/mutex.go index beb1154..ffd5952 100644 --- a/utils/mutex.go +++ b/utils/mutex.go @@ -33,6 +33,41 @@ func (m *Mutex[T]) Lock() iter.Seq[T] { } } +// Mutex guards access to object of type T. +type RWMutex[T any] struct { + mu sync.RWMutex + value T +} + +// NewMutex creates a new Mutex with given object. 
+func NewRWMutex[T any](value T) (m RWMutex[T]) { + m.value = value + // nolint:nakedret + return +} + +// Lock returns an iterator which locks the mutex and yields the guarded object. +// The mutex is unlocked when the iterator is done. +// If the mutex is nil, the iterator is a no-op. +func (m *RWMutex[T]) Lock() iter.Seq[T] { + return func(yield func(val T) bool) { + m.mu.Lock() + defer m.mu.Unlock() + _ = yield(m.value) + } +} + +// RLock returns an iterator which locks the mutex FOR READ and yields the guarded object. +// The mutex is unlocked when the iterator is done. +// If the mutex is nil, the iterator is a no-op. +func (m *RWMutex[T]) RLock() iter.Seq[T] { + return func(yield func(val T) bool) { + m.mu.RLock() + defer m.mu.RUnlock() + _ = yield(m.value) + } +} + // version of the value stored in an atomic watch. type version[T any] struct { updated chan struct{} @@ -48,80 +83,31 @@ type atomicWatch[T any] struct { ptr atomic.Pointer[version[T]] } -type AtomicSend[T any] struct { - atomicWatch[T] -} - -// Store updates the value of the atomic watch. -func (w *AtomicSend[T]) Send(value T) { - close(w.ptr.Swap(newVersion(value)).updated) -} +type AtomicSend[T any] struct{ atomicWatch[T] } -// Update conditionally updates the value of the atomic watch. -func (w *AtomicSend[T]) Update(f func(T) (T, bool)) { - old := w.ptr.Load() - if value, ok := f(old.value); ok { - w.ptr.Store(newVersion(value)) - close(old.updated) - } +func (w *AtomicSend[T]) Subscribe() AtomicRecv[T] { + return AtomicRecv[T]{&w.atomicWatch} } +// NewAtomicWatch creates a new AtomicWatch with the given initial value. func NewAtomicSend[T any](value T) (w AtomicSend[T]) { w.ptr.Store(newVersion(value)) // nolint:nakedret return } -func (w *AtomicSend[T]) Subscribe() AtomicRecv[T] { - return AtomicRecv[T]{&w.atomicWatch} -} - -// AtomicWatch stores a pointer to an IMMUTABLE value. -// Loading and waiting for updates do NOT require locking. 
-// TODO(gprusak): remove mutex and rename to AtomicSend, -// this will allow for sharing a mutex across multiple AtomicSenders. -type AtomicWatch[T any] struct { - atomicWatch[T] - mu sync.Mutex +// Store updates the value of the atomic watch. +func (w *AtomicSend[T]) Store(value T) { + close(w.ptr.Swap(newVersion(value)).updated) } // AtomicRecv is a read-only reference to AtomicWatch. type AtomicRecv[T any] struct{ *atomicWatch[T] } -// NewAtomicWatch creates a new AtomicWatch with the given initial value. -func NewAtomicWatch[T any](value T) (w AtomicWatch[T]) { - w.ptr.Store(newVersion(value)) - // nolint:nakedret - return -} - -// Subscribe returns a view-only API of the atomic watch. -func (w *AtomicWatch[T]) Subscribe() AtomicRecv[T] { - return AtomicRecv[T]{&w.atomicWatch} -} - // Load returns the current value of the atomic watch. // Does not do any locking. func (w *atomicWatch[T]) Load() T { return w.ptr.Load().value } -// Store updates the value of the atomic watch. -func (w *AtomicWatch[T]) Store(value T) { - w.mu.Lock() - defer w.mu.Unlock() - close(w.ptr.Swap(newVersion(value)).updated) -} - -// Update conditionally updates the value of the atomic watch. -func (w *AtomicWatch[T]) Update(f func(T) (T, bool)) { - w.mu.Lock() - defer w.mu.Unlock() - old := w.ptr.Load() - if value, ok := f(old.value); ok { - w.ptr.Store(newVersion(value)) - close(old.updated) - } -} - // Wait waits for the value of the atomic watch to satisfy the predicate. // Does not do any locking. func (w *atomicWatch[T]) Wait(ctx context.Context, pred func(T) bool) (T, error) { @@ -232,3 +218,17 @@ func (w *Watch[T]) Lock() iter.Seq2[T, *WatchCtrl] { _ = yield(w.val, &w.ctrl) } } + +// MonitorWatchUpdates calls f and checks if it has updated the watch. 
+func MonitorWatchUpdates[T any](w *Watch[T], f func()) bool { + w.ctrl.mu.Lock() + updated := w.ctrl.updated + w.ctrl.mu.Unlock() + f() + select { + case <-updated: + return true + default: + return false + } +} diff --git a/utils/option.go b/utils/option.go index 85fd6a4..2dd26f2 100644 --- a/utils/option.go +++ b/utils/option.go @@ -33,13 +33,20 @@ func (o Option[T]) IsPresent() bool { } // Or returns the value if present, otherwise returns the default value. -func (o *Option[T]) Or(def T) T { +func (o Option[T]) Or(def T) T { if o.isPresent { return o.value } return def } +func (o Option[T]) OrPanic(msg string) T { + if o.isPresent { + return o.value + } + panic(msg) +} + // MapOpt applies a function to the value if present, returning a new Option. func MapOpt[T, R any](o Option[T], f func(T) R) Option[R] { if o.isPresent { diff --git a/utils/option_test.go b/utils/option_test.go new file mode 100644 index 0000000..03fb54f --- /dev/null +++ b/utils/option_test.go @@ -0,0 +1,32 @@ +package utils + +import ( + "encoding/json" + "testing" + + "github.com/sei-protocol/sei-load/utils/require" +) + +func testJSON[T any](t *testing.T, want T) { + enc, err := json.Marshal(want) + require.NoError(t, err) + t.Logf("%s", enc) + var got T + require.NoError(t, json.Unmarshal(enc, &got)) + require.NoError(t, TestDiff(want, got)) +} + +func TestOptionJSON(t *testing.T) { + type a struct { + X Option[int] + Y Option[string] + } + type b struct { + X Option[int] `json:"X,omitzero"` + Y Option[string] `json:"Y,omitzero"` + } + testJSON(t, &a{}) + testJSON(t, &a{Some(1), Some("a")}) + testJSON(t, &b{}) + testJSON(t, &b{Some(1), Some("a")}) +} diff --git a/utils/panic.go b/utils/panic.go new file mode 100644 index 0000000..960f787 --- /dev/null +++ b/utils/panic.go @@ -0,0 +1,17 @@ +package utils + +// OrPanic panics if err is non-nil. Use for initialization-time or otherwise +// unrecoverable failures where returning an error is not an option (e.g. 
var +// initializers, metric instrument creation). +func OrPanic(err error) { + if err != nil { + panic(err) + } +} + +// OrPanic1 returns v, panicking if err is non-nil. Convenience for wrapping a +// (value, error) call in a var initializer that cannot fail at runtime. +func OrPanic1[T any](v T, err error) T { + OrPanic(err) + return v +} diff --git a/utils/proto.go b/utils/proto.go index 5f5ad7a..0842c98 100644 --- a/utils/proto.go +++ b/utils/proto.go @@ -1,28 +1,19 @@ package utils import ( - "crypto/sha256" "errors" "fmt" "sync" - "google.golang.org/protobuf/proto" + "github.com/gogo/protobuf/proto" ) -// Hash is a SHA-256 hash. -type Hash [sha256.Size]byte - -// GetHash computes a hash of the given data. -func GetHash(data []byte) Hash { - return sha256.Sum256(data) -} - -// ParseHash parses a Hash from bytes. -func ParseHash(raw []byte) (Hash, error) { - if got, want := len(raw), sha256.Size; got != want { - return Hash{}, fmt.Errorf("hash size = %v, want %v", got, want) +func ErrorAs[T error](err error) Option[T] { + var target T + if errors.As(err, &target) { + return Some(target) } - return Hash(raw), nil + return None[T]() } // ProtoClone clones a proto.Message object. @@ -35,16 +26,6 @@ func ProtoEqual[T proto.Message](a, b T) bool { return proto.Equal(a, b) } -// ProtoHash hashes a proto.Message object. -// TODO(gprusak): make it deterministic. -func ProtoHash(a proto.Message) Hash { - raw, err := proto.Marshal(a) - if err != nil { - panic(err) - } - return sha256.Sum256(raw) -} - // ProtoMessage is comparable proto.Message. type ProtoMessage interface { comparable diff --git a/utils/require/require.go b/utils/require/require.go new file mode 100644 index 0000000..14d011d --- /dev/null +++ b/utils/require/require.go @@ -0,0 +1,113 @@ +// Package require reexports strongly typed `testify/require` API. +// We don't reexport `New`, because methods cannot be generic. 
+package require + +import ( + "cmp" + + "github.com/stretchr/testify/require" +) + +// TestingT . +type TestingT = require.TestingT + +// False . +var False = require.False + +// True . +var True = require.True + +// Zero . +var Zero = require.Zero + +// NotZero . +var NotZero = require.NotZero + +// Contains . +var Contains = require.Contains + +func ElementsMatch[T any](t TestingT, a []T, b []T, msgAndArgs ...any) { + require.ElementsMatch(t, a, b, msgAndArgs...) +} + +// Eventually . +var Eventually = require.Eventually + +// EqualError . +// TODO: get rid of comparing errors by strings, +// use concrete error types instead. +var EqualError = require.EqualError + +// Error . +var Error = require.Error + +// ErrorIs . +var ErrorIs = require.ErrorIs + +// NoError . +var NoError = require.NoError + +// Empty . +var Empty = require.Empty + +// NotEmpty . +var NotEmpty = require.NotEmpty + +// Len . +var Len = require.Len + +// Nil . +var Nil = require.Nil + +// NotNil . +var NotNil = require.NotNil + +// Panics . +var Panics = require.Panics + +// Fail . +var Fail = require.Fail + +// FailNow . +var FailNow = require.FailNow + +// NoFileExists . +var NoFileExists = require.NoFileExists + +// FileExists . +var FileExists = require.FileExists + +// Positive . +func Positive[T cmp.Ordered](t TestingT, e T, msgAndArgs ...any) { + require.Positive(t, e, msgAndArgs...) +} + +// Less . +func Less[T cmp.Ordered](t TestingT, e1, e2 T, msgAndArgs ...any) { + require.Less(t, e1, e2, msgAndArgs...) +} + +// LessOrEqual . +func LessOrEqual[T cmp.Ordered](t TestingT, e1, e2 T, msgAndArgs ...any) { + require.LessOrEqual(t, e1, e2, msgAndArgs...) +} + +// Greater . +func Greater[T cmp.Ordered](t TestingT, e1, e2 T, msgAndArgs ...any) { + require.Greater(t, e1, e2, msgAndArgs...) +} + +// GreaterOrEqual . +func GreaterOrEqual[T cmp.Ordered](t TestingT, e1, e2 T, msgAndArgs ...any) { + require.GreaterOrEqual(t, e1, e2, msgAndArgs...) +} + +// Equal . 
+func Equal[T any](t TestingT, expected, actual T, msgAndArgs ...any) { + require.Equal(t, expected, actual, msgAndArgs...) +} + +// NotEqual . +func NotEqual[T any](t TestingT, expected, actual T, msgAndArgs ...any) { + require.NotEqual(t, expected, actual, msgAndArgs...) +} diff --git a/utils/scope/global.go b/utils/scope/global.go new file mode 100644 index 0000000..91758a0 --- /dev/null +++ b/utils/scope/global.go @@ -0,0 +1,80 @@ +package scope + +import ( + "context" + + "github.com/sei-protocol/sei-load/utils" +) + +// GlobalHandle is a handle to a task spawned via SpawnGlobal. +type GlobalHandle[T any] struct { + cancel context.CancelFunc + done chan struct{} + res T +} + +// SpawnGlobal spawns a task in a global context. +// Use with care, as it is not tied to any scope and must be terminated manually by calling Terminate(). +// The task does not return an error, because there is no canonical way to handle it. +// Can be used as an intermediate step when migrating code to use scopes. +func SpawnGlobal[T any](task func(ctx context.Context) T) *GlobalHandle[T] { + ctx, cancel := context.WithCancel(context.Background()) + h := &GlobalHandle[T]{ + cancel: cancel, + done: make(chan struct{}), + } + go func() { + h.res = task(ctx) + close(h.done) + }() + return h +} + +// WhileRunning restricts ctx to the lifetime of the task. +// WARNING: If the task is already finished, it SKIPs running f and returns context.Canceled. +func (h *GlobalHandle[T]) WhileRunning(ctx context.Context, f func(ctx context.Context) error) error { + select { + case <-h.done: + return context.Canceled + default: + } + ctx, cancel := context.WithCancel(ctx) + defer cancel() + go func() { + select { + case <-ctx.Done(): + case <-h.done: + cancel() + } + }() + return f(ctx) +} + +// WhileRunning1 is like WhileRunning but for functions returning a value. 
+func WhileRunning1[R any, T any](ctx context.Context, h *GlobalHandle[T], f func(ctx context.Context) (R, error)) (res R, err error) { + // We need to set the error outside the closure, because + // h.WhileRunning() may return context.Canceled if the task is already finished. + err = h.WhileRunning(ctx, func(ctx context.Context) error { + res, err = f(ctx) + return err + }) + return +} + +// Join awaits tasks completion. +func (h *GlobalHandle[T]) Join(ctx context.Context) (T, error) { + select { + case <-ctx.Done(): + return utils.Zero[T](), ctx.Err() + case <-h.done: + return h.res, nil + } +} + +// Terminate cancels the task and waits for it to finish. +// Returns the task's result. +func (h *GlobalHandle[T]) Terminate() T { + h.cancel() + <-h.done + return h.res +} diff --git a/utils/scope/parallel.go b/utils/scope/parallel.go new file mode 100644 index 0000000..1377184 --- /dev/null +++ b/utils/scope/parallel.go @@ -0,0 +1,41 @@ +package scope + +import ( + "sync" + "sync/atomic" +) + +type parallelScope struct { + wg sync.WaitGroup + err atomic.Pointer[error] +} + +// ParallelScope is a scope which doesn't require cancellation token, +// just parallelization. +type ParallelScope struct{ *parallelScope } + +// Spawn spawns a new task in the scope. +func (s *parallelScope) Spawn(t func() error) { + s.wg.Add(1) + go func() { + if err := t(); err != nil { + s.err.CompareAndSwap(nil, &err) + } + s.wg.Done() + }() +} + +// Parallel executes a function in parallel scope. +// Compared to Run, it does not allow for early cancellation, +// therefore is suitable for non-blocking computations. +// Returns the first error returned by any of the spawned tasks. +// Waits until all the tasks complete, before returning. 
+func Parallel(main func(ParallelScope) error) error { + var s parallelScope + s.Spawn(func() error { return main(ParallelScope{&s}) }) + s.wg.Wait() + if perr := s.err.Load(); perr != nil { + return *perr + } + return nil +} diff --git a/utils/service/parallel_test.go b/utils/scope/parallel_test.go similarity index 62% rename from utils/service/parallel_test.go rename to utils/scope/parallel_test.go index 57f7ba2..7f98872 100644 --- a/utils/service/parallel_test.go +++ b/utils/scope/parallel_test.go @@ -1,15 +1,13 @@ -package service +package scope import ( "errors" "testing" - - "github.com/stretchr/testify/require" ) func TestParallelOk(t *testing.T) { x := [10]int{} - err := Parallel(func(s ParallelScope) error { + if err := Parallel(func(s ParallelScope) error { for i := range x { s.Spawn(func() error { x[i] = i @@ -17,10 +15,13 @@ func TestParallelOk(t *testing.T) { }) } return nil - }) - require.NoError(t, err) + }); err != nil { + t.Fatal(err) + } for want, got := range x { - require.Equal(t, want, got, "x[%d] = %d, want %d", want, got, want) + if want != got { + t.Fatalf("x[%d] = %d, want %d", want, got, want) + } } } @@ -39,11 +40,15 @@ func TestParallelFail(t *testing.T) { } return nil }) - require.ErrorIs(t, wantErr, err, "err = %v, want %v", err, wantErr) + if !errors.Is(err, wantErr) { + t.Fatalf("err = %v, want %v", err, wantErr) + } for want, got := range x { if want%2 == 0 { want = 0 } - require.Equal(t, want, got, "x[%d] = %d, want %d", want, got, want) + if want != got { + t.Fatalf("x[%d] = %d, want %d", want, got, want) + } } } diff --git a/utils/scope/start.go b/utils/scope/start.go new file mode 100644 index 0000000..e6e80f3 --- /dev/null +++ b/utils/scope/start.go @@ -0,0 +1,166 @@ +package scope + +import ( + "context" + "fmt" + "log" + "sync" + "time" + + "github.com/sei-protocol/sei-load/utils" +) + +type scope struct { + // scope is a concurrecy primitive, so no-ctx-in-struct rule does not apply + // nolint:containedctx + ctx 
context.Context + cancel context.CancelFunc + all sync.WaitGroup + main sync.WaitGroup + errOnce sync.Once + err error +} + +// Scope of concurrenct tasks. +type Scope struct{ *scope } + +// SpawnBg spawns a background task. +// Background tasks get canceled when all the main tasks return. +func (s Scope) SpawnBg(t func() error) { + s.all.Add(1) + go func() { + defer s.all.Done() + if err := t(); err != nil { + s.Cancel(err) + } + }() +} + +// Spawn spawns a main task. +// Scope gets automatically canceled when all the main tasks return. +func (s Scope) Spawn(t func() error) { + s.main.Add(1) + s.SpawnBg(func() error { + defer s.main.Done() + return t() + }) +} + +// Cancels the scope. +// If err is not nil and no error was set before, +// sets err as the scope error. +func (s Scope) Cancel(err error) { + if err != nil { + s.errOnce.Do(func() { + s.err = err + }) + } + s.cancel() +} + +// JoinHandle is a handle to an awaitable task. +type JoinHandle[R any] struct { + result utils.AtomicRecv[*R] +} + +// Spawn1 is the same as Scope.Spawn, but allows awaiting completion of a task and getting its result. +func Spawn1[R any](s Scope, t func() (R, error)) JoinHandle[R] { + result := utils.NewAtomicSend[*R](nil) + s.Spawn(func() error { + v, err := t() + if err != nil { + return err + } + result.Store(&v) + return nil + }) + return JoinHandle[R]{result.Subscribe()} +} + +// Join awaits completion of a task and returns its result. +// WARNING: it does NOT return the error of the task - error is returned from the Run() command. +// Join() can only fail when context is canceled. +func (h JoinHandle[R]) Join(ctx context.Context) (R, error) { + res, err := h.result.Wait(ctx, func(v *R) bool { return v != nil }) + if err != nil { + return utils.Zero[R](), err + } + return *res, nil +} + +// If true, tasks that do not respect context cancellation will be logged. +// This is useful for debugging, but causes unnecessary overhead. 
+// Since this is a constant, debug guard should be optimized out by the compiler. +const enableDebugGuard = false + +func (s Scope) debugGuard(name string, done chan struct{}) { + select { + case <-done: + return + case <-s.ctx.Done(): + } + for { + select { + case <-done: + return + case <-time.After(10 * time.Second): + } + log.Printf("task %q still running", name) + } +} + +// SpawnNamed spawns a named main task. +func (s Scope) SpawnNamed(name string, t func() error) { + done := make(chan struct{}) + s.Spawn(func() error { + defer close(done) + if err := t(); err != nil { + return fmt.Errorf("%s: %w", name, err) + } + return nil + }) + if enableDebugGuard { + go s.debugGuard(name, done) + } +} + +// SpawnBgNamed spawns a named background task. +func (s Scope) SpawnBgNamed(name string, t func() error) { + done := make(chan struct{}) + s.SpawnBg(func() error { + defer close(done) + if err := t(); err != nil { + return fmt.Errorf("%s: %w", name, err) + } + return nil + }) + if enableDebugGuard { + go s.debugGuard(name, done) + } +} + +// Run runs a scope capable of spawning tasks. +// It is guaranteed that all the spawned tasks will be executed (even if spawned after the context is cancelled), +// and that `Run` will return only after all the tasks have completed. +// Context of the tasks will be automatically cancelled as soon as ANY task returns an error. +// Returns the first error returned by any task (main or background). +func Run(ctx context.Context, main func(context.Context, Scope) error) error { + ctx, cancel := context.WithCancel(ctx) + s := Scope{&scope{ctx: ctx, cancel: cancel}} + s.Spawn(func() error { return main(ctx, s) }) + s.main.Wait() + s.cancel() + s.all.Wait() + return s.err +} + +// Run1 is the same as Run, but returns the result of the main task. 
+func Run1[R any](ctx context.Context, main func(context.Context, Scope) (R, error)) (res R, err error) { + err = Run(ctx, func(ctx context.Context, s Scope) error { + var err error + res, err = main(ctx, s) + return err + }) + //nolint:nakedret + return +} diff --git a/utils/semaphore.go b/utils/semaphore.go deleted file mode 100644 index 941c670..0000000 --- a/utils/semaphore.go +++ /dev/null @@ -1,37 +0,0 @@ -package utils - -import ( - "context" -) - -// Semaphore provides a way to bound concurrenct access to a resource. -type Semaphore struct { - ch chan struct{} -} - -// NewSemaphore constructs a new semaphore with n permits. -func NewSemaphore(n int) *Semaphore { - return &Semaphore{ch: make(chan struct{}, n)} -} - -// Acquire acquires a permit from the semaphore. -// Blocks until a permit is available. -func (s *Semaphore) Acquire(ctx context.Context) (release func(), err error) { - if err := Send(ctx, s.ch, struct{}{}); err != nil { - return nil, err - } - return func() { <-s.ch }, nil -} - -// TryAcquire acquires a permit without blocking. It returns the release func -// and true if a permit was available, or nil and false if all permits are held. -// Used by callers that must never block waiting for capacity (e.g. an open-loop -// scheduler that drops rather than throttling its clock). -func (s *Semaphore) TryAcquire() (release func(), ok bool) { - select { - case s.ch <- struct{}{}: - return func() { <-s.ch }, true - default: - return nil, false - } -} diff --git a/utils/service/parallel.go b/utils/service/parallel.go index f2bcea9..14a7fcd 100644 --- a/utils/service/parallel.go +++ b/utils/service/parallel.go @@ -1,41 +1,9 @@ package service -import ( - "sync" - "sync/atomic" -) +import "github.com/sei-protocol/sei-load/utils/scope" -type parallelScope struct { - wg sync.WaitGroup - err atomic.Pointer[error] -} - -// ParallelScope is a scope which doesn't require cancellation token, -// just parallelization. 
-type ParallelScope struct{ *parallelScope } - -// Spawn spawns a new task in the scope. -func (s *parallelScope) Spawn(t func() error) { - s.wg.Add(1) - go func() { - if err := t(); err != nil { - s.err.CompareAndSwap(nil, &err) - } - s.wg.Done() - }() -} +type ParallelScope = scope.ParallelScope -// Parallel executes a function in parallel scope. -// Compared to Run, it does not allow for early cancellation, -// therefore is suitable for non-blocking computations. -// Returns the first error returned by any of the spawned tasks. -// Waits until all the tasks complete, before returning. func Parallel(main func(ParallelScope) error) error { - var s parallelScope - s.Spawn(func() error { return main(ParallelScope{&s}) }) - s.wg.Wait() - if perr := s.err.Load(); perr != nil { - return *perr - } - return nil + return scope.Parallel(main) } diff --git a/utils/service/start.go b/utils/service/start.go index a077f95..04312c3 100644 --- a/utils/service/start.go +++ b/utils/service/start.go @@ -2,142 +2,22 @@ package service import ( "context" - "fmt" - "log" - "sync" - "time" - "golang.org/x/sync/errgroup" - - "github.com/sei-protocol/sei-load/utils" + "github.com/sei-protocol/sei-load/utils/scope" ) -// Scope of concurrenct tasks. -type Scope struct { - // scope is a concurrecy primitive, so no-ctx-in-struct rule does not apply - // nolint:containedctx - ctx context.Context - all *errgroup.Group - main *sync.WaitGroup -} - -// Spawn spawns a main task. -// Scope gets automatically canceled when all the main tasks return. -func (s Scope) Spawn(t func() error) { - s.main.Add(1) - s.all.Go(func() error { - defer s.main.Done() - return t() - }) -} +type Scope = scope.Scope -// JoinHandle is a handle to an awaitable task. -type JoinHandle[R any] struct { - result utils.AtomicRecv[*R] -} +type JoinHandle[R any] = scope.JoinHandle[R] -// Spawn1 is the same as Scope.Spawn, but allows awaiting completion of a task and getting its result. 
func Spawn1[R any](s Scope, t func() (R, error)) JoinHandle[R] { - send := utils.NewAtomicSend[*R](nil) - s.Spawn(func() error { - v, err := t() - if err != nil { - return err - } - send.Send(&v) - return nil - }) - return JoinHandle[R]{send.Subscribe()} -} - -// Join awaits completion of a task and returns its result. -// WARNING: it does NOT return the error of the task - error is returned from the Run() command. -// Join() can only fail when context is canceled. -func (h JoinHandle[R]) Join(ctx context.Context) (R, error) { - res, err := h.result.Wait(ctx, func(v *R) bool { return v != nil }) - if err != nil { - return utils.Zero[R](), err - } - return *res, nil + return scope.Spawn1(s, t) } -// If true, tasks that do not respect context cancellation will be logged. -// This is useful for debugging, but causes unnecessary overhead. -// Since this is a constant, debug guard should be optimized out by the compiler. -const enableDebugGuard = false - -func (s Scope) debugGuard(name string, done chan struct{}) { - select { - case <-done: - return - case <-s.ctx.Done(): - } - for { - select { - case <-done: - return - case <-time.After(10 * time.Second): - } - log.Printf("task %q still running", name) - } -} - -// SpawnNamed spawns a named main task. -func (s Scope) SpawnNamed(name string, t func() error) { - done := make(chan struct{}) - s.Spawn(func() error { - defer close(done) - if err := t(); err != nil { - return fmt.Errorf("%s: %w", name, err) - } - return nil - }) - if enableDebugGuard { - go s.debugGuard(name, done) - } -} - -// SpawnBgNamed spawns a named background task. -func (s Scope) SpawnBgNamed(name string, t func() error) { - done := make(chan struct{}) - s.SpawnBg(func() error { - defer close(done) - if err := t(); err != nil { - return fmt.Errorf("%s: %w", name, err) - } - return nil - }) - if enableDebugGuard { - go s.debugGuard(name, done) - } -} - -// SpawnBg spawns a background task. 
-// Background tasks get canceled when all the main tasks return. -func (s Scope) SpawnBg(t func() error) { s.all.Go(t) } - -// Run runs a scope capable of spawning tasks. -// It is guaranteed that all the spawned tasks will be executed (even if spawned after the context is cancelled), -// and that `Run` will return only after all the tasks have completed. -// Context of the tasks will be automatically cancelled as soon as ANY task returns an error. -// Returns the first error returned by any task (main or background). func Run(ctx context.Context, main func(context.Context, Scope) error) error { - ctx, cancel := context.WithCancel(ctx) - all, ctx := errgroup.WithContext(ctx) - s := Scope{ctx, all, &sync.WaitGroup{}} - s.Spawn(func() error { return main(ctx, s) }) - s.main.Wait() - cancel() - return s.all.Wait() + return scope.Run(ctx, main) } -// Run1 is the same as Run, but returns the result of the main task. -func Run1[R any](ctx context.Context, main func(context.Context, Scope) (R, error)) (res R, err error) { - err = Run(ctx, func(ctx context.Context, s Scope) error { - var err error - res, err = main(ctx, s) - return err - }) - //nolint:nakedret - return +func Run1[R any](ctx context.Context, main func(context.Context, Scope) (R, error)) (R, error) { + return scope.Run1(ctx, main) } diff --git a/utils/testonly.go b/utils/testonly.go index 7c0ddf0..56d3ac6 100644 --- a/utils/testonly.go +++ b/utils/testonly.go @@ -1,15 +1,17 @@ package utils import ( + "bytes" + "context" "fmt" "math/big" "math/rand" "reflect" "time" + "github.com/gogo/protobuf/proto" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" - "google.golang.org/protobuf/proto" "google.golang.org/protobuf/testing/protocmp" ) @@ -19,7 +21,7 @@ type ReadOnly struct{} // isReadOnly returns true if t embeds ReadOnly. 
func isReadOnly(t reflect.Type) bool { - want := reflect.TypeOf(ReadOnly{}) + want := reflect.TypeFor[ReadOnly]() if t.Kind() != reflect.Struct { return false } @@ -45,6 +47,9 @@ var cmpOpts = []cmp.Option{ protocmp.Transform(), cmp.Exporter(isReadOnly), cmpopts.EquateEmpty(), + // Optimization for comparing slices of bytes. + // Applies iff any of the slices is non-empty to avoid collision with EquateEmpty. + cmp.FilterValues(func(x, y []byte) bool { return len(x) > 0 || len(y) > 0 }, cmp.Comparer(bytes.Equal)), cmp.Comparer(cmpComparer[big.Int]), } @@ -61,46 +66,119 @@ func TestEqual[T any](a, b T) bool { return cmp.Equal(a, b, cmpOpts...) } -// TestRngSplit returns a new random number splitted from the given one. +// Thread-safe wrapper of rand.Rand. +type Rng struct{ inner *Mutex[*rand.Rand] } + +func (rng Rng) Read(p []byte) (int, error) { + for inner := range rng.inner.Lock() { + return inner.Read(p) + } + panic("unreachable") +} + +func (rng Rng) Int63() int64 { + for inner := range rng.inner.Lock() { + return inner.Int63() + } + panic("unreachable") +} + +func (rng Rng) Uint64() uint64 { + for inner := range rng.inner.Lock() { + return inner.Uint64() + } + panic("unreachable") +} + +func (rng Rng) Int() int { + for inner := range rng.inner.Lock() { + return inner.Int() + } + panic("unreachable") +} + +func (rng Rng) Intn(n int) int { + for inner := range rng.inner.Lock() { + return inner.Intn(n) + } + panic("unreachable") +} + +func (rng Rng) Int63n(n int64) int64 { + for inner := range rng.inner.Lock() { + return inner.Int63n(n) + } + panic("unreachable") +} + +func (rng Rng) Shuffle(n int, swap func(i, j int)) { + for inner := range rng.inner.Lock() { + inner.Shuffle(n, swap) + } +} + +// Split returns a new random number splitted from the given one. +// It should be used to provide deterministic rngs to independent goroutines. // This is a very primitive splitting, known to result with dependent randomness. 
// If that ever causes a problem, we can switch to SplitMix. -func TestRngSplit(rng *rand.Rand) *rand.Rand { - return rand.New(rand.NewSource(rng.Int63())) +func (rng Rng) Split() Rng { + for inner := range rng.inner.Lock() { + return TestRngFromSeed(inner.Int63()) + } + panic("unreachable") } // TestRng returns a deterministic random number generator. -func TestRng() *rand.Rand { - return rand.New(rand.NewSource(789345342)) +func TestRng() Rng { + return TestRngFromSeed(789345342) +} + +func TestRngFromSeed(seed int64) Rng { + return Rng{Alloc(NewMutex(rand.New(rand.NewSource(seed))))} +} + +func GenBool(rng Rng) bool { + return rng.Intn(2) == 0 } var alphanum = []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789") // GenString generates a random string of length n. -func GenString(rng *rand.Rand, n int) string { +func GenString(rng Rng, n int) string { s := make([]rune, n) for i := range n { - s[i] = alphanum[rand.Intn(len(alphanum))] + s[i] = alphanum[rng.Intn(len(alphanum))] } return string(s) } +// Shuffle reorders the elements of s uniformly at random. +func Shuffle[T any](rng Rng, s []T) { + for i := 1; i < len(s); i += 1 { + j := rng.Intn(i + 1) + s[i], s[j] = s[j], s[i] + } +} + // GenBytes generates a random byte slice. -func GenBytes(rng *rand.Rand, n int) []byte { +func GenBytes(rng Rng, n int) []byte { s := make([]byte, n) - _, _ = rng.Read(s) + for inner := range rng.inner.Lock() { + _, _ = inner.Read(s) + } return s } // GenF is a function which generates T. -type GenF[T any] = func(rng *rand.Rand) T +type GenF[T any] = func(rng Rng) T // GenSlice generates a slice of small random length. -func GenSlice[T any](rng *rand.Rand, gen GenF[T]) []T { +func GenSlice[T any](rng Rng, gen GenF[T]) []T { return GenSliceN(rng, 2+rng.Intn(3), gen) } // GenSliceN generates a slice of n elements. 
-func GenSliceN[T any](rng *rand.Rand, n int, gen GenF[T]) []T { +func GenSliceN[T any](rng Rng, n int, gen GenF[T]) []T { s := make([]T, n) for i := range s { s[i] = gen(rng) @@ -109,12 +187,12 @@ func GenSliceN[T any](rng *rand.Rand, n int, gen GenF[T]) []T { } // GenMap generates a map of small random length. -func GenMap[K comparable, V any](rng *rand.Rand, genK GenF[K], genV GenF[V]) map[K]V { +func GenMap[K comparable, V any](rng Rng, genK GenF[K], genV GenF[V]) map[K]V { return GenMapN(rng, 2+rng.Intn(3), genK, genV) } // GenMapN generates a map of n elements. -func GenMapN[K comparable, V any](rng *rand.Rand, n int, genK GenF[K], genV GenF[V]) map[K]V { +func GenMapN[K comparable, V any](rng Rng, n int, genK GenF[K], genV GenF[V]) map[K]V { m := make(map[K]V, n) for len(m) < n { m[genK(rng)] = genV(rng) @@ -123,19 +201,12 @@ func GenMapN[K comparable, V any](rng *rand.Rand, n int, genK GenF[K], genV GenF } // GenTimestamp generates a random timestamp. -func GenTimestamp(rng *rand.Rand) time.Time { +func GenTimestamp(rng Rng) time.Time { return time.Unix(0, rng.Int63()) } -// GenHash generates a random Hash. -func GenHash(rng *rand.Rand) Hash { - var h Hash - _, _ = rng.Read(h[:]) - return h -} - // Test tests whether reencoding a value is an identity operation. -func (c ProtoConv[T, P]) Test(want T) error { +func (c *ProtoConv[T, P]) Test(want T) error { p := c.Encode(want) raw, err := proto.Marshal(p) if err != nil { @@ -150,3 +221,14 @@ func (c ProtoConv[T, P]) Test(want T) error { } return TestDiff(want, got) } + +// IgnoreAfterCancel silently drops the error if the context is already canceled. +// Should be used for background tasks in tests, which cannot be guaranteed to exit gracefully. +// For example - if you have a tcp connection, then during cleanup one end will disconnect faster than the other, +// causing a race condition between context cancellation and disconnection error. 
+func IgnoreAfterCancel(ctx context.Context, err error) error { + if ctx.Err() != nil { + return nil + } + return err +} diff --git a/utils/wait.go b/utils/wait.go index 4c8c663..d1875be 100644 --- a/utils/wait.go +++ b/utils/wait.go @@ -4,6 +4,7 @@ import ( "context" "encoding" "errors" + "sync/atomic" "time" ) @@ -15,6 +16,18 @@ func IgnoreCancel(err error) error { return err } +// WithDeadline executes a function with a deadline. +// If deadline is none, it executes the function without a deadline. +func WithDeadline(ctx context.Context, md Option[time.Time], f func(ctx context.Context) error) error { + d, ok := md.Get() + if !ok { + return f(ctx) + } + ctx, cancel := context.WithDeadline(ctx, d) + defer cancel() + return f(ctx) +} + // WithTimeout executes a function with a timeout. func WithTimeout(ctx context.Context, d time.Duration, f func(ctx context.Context) error) error { ctx, cancel := context.WithTimeout(ctx, d) @@ -22,6 +35,14 @@ func WithTimeout(ctx context.Context, d time.Duration, f func(ctx context.Contex return f(ctx) } +// WithOptTimeout executes a function with a timeout. +func WithOptTimeout(ctx context.Context, d Option[time.Duration], f func(ctx context.Context) error) error { + if d, ok := d.Get(); ok { + return WithTimeout(ctx, d, f) + } + return f(ctx) +} + // WithTimeout1 executes a function with a timeout. func WithTimeout1[R any](ctx context.Context, d time.Duration, f func(ctx context.Context) (R, error)) (R, error) { ctx, cancel := context.WithTimeout(ctx, d) @@ -29,6 +50,14 @@ func WithTimeout1[R any](ctx context.Context, d time.Duration, f func(ctx contex return f(ctx) } +// WithOptTimeout1 executes a function with a timeout. +func WithOptTimeout1[R any](ctx context.Context, d Option[time.Duration], f func(ctx context.Context) (R, error)) (R, error) { + if d, ok := d.Get(); ok { + return WithTimeout1(ctx, d, f) + } + return f(ctx) +} + // Sleep sleeps for a duration or until the context is canceled. 
func Sleep(ctx context.Context, d time.Duration) error { select { @@ -117,3 +146,27 @@ func (d Duration) Duration() time.Duration { func (d Duration) Seconds() float64 { return time.Duration(d).Seconds() } + +// Once is an idempotent signal. +type Once struct { + _ NoCopy + ch chan struct{} + done atomic.Bool +} + +func NewOnce() (o Once) { + o.ch = make(chan struct{}) + return +} + +func (o *Once) Send() { + if o.done.Swap(true) { + return + } + close(o.ch) +} + +func (o *Once) Recv(ctx context.Context) error { + _, _, err := RecvOrClosed(ctx, o.ch) + return err +} diff --git a/utils/wait_test.go b/utils/wait_test.go index 0331ae9..91edc12 100644 --- a/utils/wait_test.go +++ b/utils/wait_test.go @@ -4,16 +4,20 @@ import ( "encoding/json" "testing" "time" - - "github.com/stretchr/testify/require" ) func TestJSON(t *testing.T) { var got, want struct{ X Duration } want.X = Duration(100 * time.Millisecond) j, err := json.Marshal(want) - require.NoError(t, err) + if err != nil { + t.Fatal(err) + } t.Logf("%s", j) - require.NoError(t, json.Unmarshal(j, &got)) - require.NoError(t, TestDiff(want, got)) + if err := json.Unmarshal(j, &got); err != nil { + t.Fatal(err) + } + if err := TestDiff(want, got); err != nil { + t.Fatal(err) + } }