From 893034ba2806dce99b4a6da17170da7247902bf1 Mon Sep 17 00:00:00 2001 From: Mariem Baccari Date: Fri, 22 May 2026 09:23:39 +0200 Subject: [PATCH 1/8] [raft] change transact signature --- cmds/db-manager/cleanup/evict.go | 5 +++-- pkg/raftstore/store.go | 5 +++-- pkg/rid/application/application_test.go | 5 +++-- pkg/rid/application/isa.go | 7 ++++--- pkg/rid/application/subscription.go | 7 ++++--- pkg/rid/store/sqlstore/store_test.go | 9 +++++---- pkg/scd/constraints_handler.go | 9 +++++---- pkg/scd/operational_intents_handler.go | 9 +++++---- pkg/scd/subscriptions_handler.go | 9 +++++---- pkg/scd/uss_availability_handler.go | 5 +++-- pkg/sqlstore/store.go | 6 +++--- pkg/store/store.go | 10 +++++++++- 12 files changed, 52 insertions(+), 34 deletions(-) diff --git a/cmds/db-manager/cleanup/evict.go b/cmds/db-manager/cleanup/evict.go index 572b90f5b..f8c144db9 100644 --- a/cmds/db-manager/cleanup/evict.go +++ b/cmds/db-manager/cleanup/evict.go @@ -14,6 +14,7 @@ import ( scdmodels "github.com/interuss/dss/pkg/scd/models" scdrepos "github.com/interuss/dss/pkg/scd/repos" scds "github.com/interuss/dss/pkg/scd/store" + "github.com/interuss/dss/pkg/store" "github.com/interuss/stacktrace" "github.com/spf13/cobra" "github.com/spf13/pflag" @@ -100,7 +101,7 @@ func evict(cmd *cobra.Command, _ []string) error { } return nil } - if err = scdStore.Transact(ctx, scdAction); err != nil { + if _, err = scdStore.Transact(ctx, store.Request{}, scdAction); err != nil { return fmt.Errorf("failed to execute SCD transaction: %w", err) } @@ -145,7 +146,7 @@ func evict(cmd *cobra.Command, _ []string) error { return nil } - if err = ridStore.Transact(ctx, ridAction); err != nil { + if _, err = ridStore.Transact(ctx, store.Request{}, ridAction); err != nil { return fmt.Errorf("failed to execute RID transaction: %w", err) } diff --git a/pkg/raftstore/store.go b/pkg/raftstore/store.go index ee2b40b51..7f0b8b2af 100644 --- a/pkg/raftstore/store.go +++ b/pkg/raftstore/store.go @@ -5,6 +5,7 @@ import ( "github.com/interuss/dss/pkg/raftstore/consensus" raftparams "github.com/interuss/dss/pkg/raftstore/params" + "github.com/interuss/dss/pkg/store" "github.com/interuss/stacktrace" "go.uber.org/zap" ) @@ -29,9 +30,9 @@ func Init[R any](ctx context.Context, logger *zap.Logger, params raftparams.Conn } // Transact proposes the entry to Raft and blocks until it is committed and applied. -func (s *Store[R]) Transact(ctx context.Context, f func(context.Context, R) error) error { +func (s *Store[R]) Transact(ctx context.Context, _ store.Request, f func(context.Context, R) error) (any, error) { // TODO: implement - return nil + return nil, nil } // Interact returns a repository that can be used to query the store without proposing a Raft entry. diff --git a/pkg/rid/application/application_test.go b/pkg/rid/application/application_test.go index 0cc3fd885..f4e099495 100644 --- a/pkg/rid/application/application_test.go +++ b/pkg/rid/application/application_test.go @@ -13,6 +13,7 @@ import ( dssql "github.com/interuss/dss/pkg/sql" "github.com/interuss/dss/pkg/sqlstore" "github.com/interuss/dss/pkg/sqlstore/params" + dssstore "github.com/interuss/dss/pkg/store" "github.com/jonboulle/clockwork" "github.com/stretchr/testify/require" @@ -35,8 +36,8 @@ func (s *mockRepo) Interact(ctx context.Context) (repos.Repository, error) { return s, nil } -func (s *mockRepo) Transact(ctx context.Context, f func(ctx context.Context, repo repos.Repository) error) error { - return f(ctx, s) +func (s *mockRepo) Transact(ctx context.Context, _ dssstore.Request, f func(ctx context.Context, repo repos.Repository) error) (any, error) { + return nil, f(ctx, s) } func (s *mockRepo) Close() error { diff --git a/pkg/rid/application/isa.go b/pkg/rid/application/isa.go index ea402af60..8978e720b 100644 --- a/pkg/rid/application/isa.go +++ b/pkg/rid/application/isa.go @@ -10,6 +10,7 @@ import ( dssmodels "github.com/interuss/dss/pkg/models" ridmodels "github.com/interuss/dss/pkg/rid/models" "github.com/interuss/dss/pkg/rid/repos" + "github.com/interuss/dss/pkg/store" "github.com/interuss/stacktrace" ) @@ -63,7 +64,7 @@ func (a *app) DeleteISA(ctx context.Context, id dssmodels.ID, owner dssmodels.Ow subs []*ridmodels.Subscription ) // The following will automatically retry TXN retry errors. - err := a.store.Transact(ctx, func(ctx context.Context, repo repos.Repository) error { + _, err := a.store.Transact(ctx, store.Request{}, func(ctx context.Context, repo repos.Repository) error { old, err := repo.GetISA(ctx, id, true) switch { case err != nil: @@ -104,7 +105,7 @@ func (a *app) InsertISA(ctx context.Context, isa *ridmodels.IdentificationServic subs []*ridmodels.Subscription ) // The following will automatically retry TXN retry errors. - err := a.store.Transact(ctx, func(ctx context.Context, repo repos.Repository) error { + _, err := a.store.Transact(ctx, store.Request{}, func(ctx context.Context, repo repos.Repository) error { // ensure it doesn't exist yet old, err := repo.GetISA(ctx, isa.ID, false) if err != nil { @@ -138,7 +139,7 @@ func (a *app) UpdateISA(ctx context.Context, isa *ridmodels.IdentificationServic subs []*ridmodels.Subscription ) // The following will automatically retry TXN retry errors. - err := a.store.Transact(ctx, func(ctx context.Context, repo repos.Repository) error { + _, err := a.store.Transact(ctx, store.Request{}, func(ctx context.Context, repo repos.Repository) error { var err error old, err := repo.GetISA(ctx, isa.ID, true) diff --git a/pkg/rid/application/subscription.go b/pkg/rid/application/subscription.go index 33ba2af36..a85ba25ca 100644 --- a/pkg/rid/application/subscription.go +++ b/pkg/rid/application/subscription.go @@ -8,6 +8,7 @@ import ( dssmodels "github.com/interuss/dss/pkg/models" ridmodels "github.com/interuss/dss/pkg/rid/models" "github.com/interuss/dss/pkg/rid/repos" + "github.com/interuss/dss/pkg/store" "github.com/interuss/stacktrace" "go.uber.org/zap" ) @@ -60,7 +61,7 @@ func (a *app) InsertSubscription(ctx context.Context, s *ridmodels.Subscription) return nil, stacktrace.Propagate(err, "Unable to adjust time range") } var sub *ridmodels.Subscription - err := a.store.Transact(ctx, func(ctx context.Context, repo repos.Repository) error { + _, err := a.store.Transact(ctx, store.Request{}, func(ctx context.Context, repo repos.Repository) error { // ensure it doesn't exist yet old, err := repo.GetSubscription(ctx, s.ID) @@ -98,7 +99,7 @@ func (a *app) InsertSubscription(ctx context.Context, s *ridmodels.Subscription) func (a *app) UpdateSubscription(ctx context.Context, s *ridmodels.Subscription) (*ridmodels.Subscription, error) { var sub *ridmodels.Subscription - err := a.store.Transact(ctx, func(ctx context.Context, repo repos.Repository) error { + _, err := a.store.Transact(ctx, store.Request{}, func(ctx context.Context, repo repos.Repository) error { old, err := repo.GetSubscription(ctx, s.ID) switch { case err != nil: @@ -145,7 +146,7 @@ func (a *app) UpdateSubscription(ctx context.Context, s *ridmodels.Subscription) // DeleteSubscription deletes the Subscription identified by "id" and owned by "owner". func (a *app) DeleteSubscription(ctx context.Context, id dssmodels.ID, owner dssmodels.Owner, version *dssmodels.Version) (*ridmodels.Subscription, error) { var ret *ridmodels.Subscription - err := a.store.Transact(ctx, func(ctx context.Context, repo repos.Repository) error { + _, err := a.store.Transact(ctx, store.Request{}, func(ctx context.Context, repo repos.Repository) error { var err error old, err := repo.GetSubscription(ctx, id) switch { diff --git a/pkg/rid/store/sqlstore/store_test.go b/pkg/rid/store/sqlstore/store_test.go index 275a1a332..bd19dfd29 100644 --- a/pkg/rid/store/sqlstore/store_test.go +++ b/pkg/rid/store/sqlstore/store_test.go @@ -13,6 +13,7 @@ import ( "github.com/interuss/dss/pkg/rid/repos" "github.com/interuss/dss/pkg/sqlstore" "github.com/interuss/dss/pkg/sqlstore/params" + dssstore "github.com/interuss/dss/pkg/store" "github.com/jackc/pgx/v5/pgconn" "github.com/jonboulle/clockwork" "github.com/stretchr/testify/require" @@ -97,7 +98,7 @@ func TestTxnRetrier(t *testing.T) { require.NotNil(t, store) defer tearDownStore() - err := store.Transact(ctx, func(ctx context.Context, repo repos.Repository) error { + _, err := store.Transact(ctx, dssstore.Request{}, func(ctx context.Context, repo repos.Repository) error { // can query within this isa, err := repo.InsertISA(ctx, serviceArea) require.NotNil(t, isa) @@ -118,7 +119,7 @@ func TestTxnRetrier(t *testing.T) { ctx, cancel := context.WithTimeout(ctx, 20*time.Millisecond) defer cancel() count := 0 - err = store.Transact(ctx, func(ctx context.Context, repo repos.Repository) error { + _, err = store.Transact(ctx, dssstore.Request{}, func(ctx context.Context, repo repos.Repository) error { // can query within this count++ // Postgre retryable error @@ -141,13 +142,13 @@ func TestTransactor(t *testing.T) { subscription2 := subscriptionsPool[1].input txnCount := 0 - err := store.Transact(ctx, func(ctx context.Context, s1 repos.Repository) error { + _, err := store.Transact(ctx, dssstore.Request{}, func(ctx context.Context, s1 repos.Repository) error { // We should get to this retry, then return nothing. if txnCount > 0 { return errors.New("already failed") } txnCount++ - err := store.Transact(ctx, func(ctx context.Context, s2 repos.Repository) error { + _, err := store.Transact(ctx, dssstore.Request{}, func(ctx context.Context, s2 repos.Repository) error { subs, err := s1.SearchSubscriptions(ctx, subscription1.Cells) require.NoError(t, err) require.Len(t, subs, 0) diff --git a/pkg/scd/constraints_handler.go b/pkg/scd/constraints_handler.go index 0688176c7..fd4953c96 100644 --- a/pkg/scd/constraints_handler.go +++ b/pkg/scd/constraints_handler.go @@ -11,6 +11,7 @@ import ( dssmodels "github.com/interuss/dss/pkg/models" scdmodels "github.com/interuss/dss/pkg/scd/models" "github.com/interuss/dss/pkg/scd/repos" + "github.com/interuss/dss/pkg/store" "github.com/interuss/stacktrace" "github.com/jackc/pgx/v5" ) @@ -88,7 +89,7 @@ func (a *Server) DeleteConstraintReference(ctx context.Context, req *restapi.Del return nil } - err = a.Store.Transact(ctx, action) + _, err = a.Store.Transact(ctx, store.Request{}, action) if err != nil { err = stacktrace.Propagate(err, "Could not delete constraint") errResp := &restapi.ErrorResponse{Message: dsserr.Handle(ctx, err)} @@ -147,7 +148,7 @@ func (a *Server) GetConstraintReference(ctx context.Context, req *restapi.GetCon return nil } - err = a.Store.Transact(ctx, action) + _, err = a.Store.Transact(ctx, store.Request{}, action) if err != nil { err = stacktrace.Propagate(err, "Could not get constraint") if stacktrace.GetCode(err) == dsserr.NotFound { @@ -306,7 +307,7 @@ func (a *Server) PutConstraintReference(ctx context.Context, manager string, ent return nil } - err = a.Store.Transact(ctx, action) + _, err = a.Store.Transact(ctx, store.Request{}, action) if err != nil { return nil, err // No need to Propagate this error as this is not a useful stacktrace line } @@ -434,7 +435,7 @@ func (a *Server) QueryConstraintReferences(ctx context.Context, req *restapi.Que return nil } - err = a.Store.Transact(ctx, action) + _, err = a.Store.Transact(ctx, store.Request{}, action) if err != nil { return restapi.QueryConstraintReferencesResponseSet{Response500: &api.InternalServerErrorBody{ ErrorMessage: *dsserr.Handle(ctx, stacktrace.Propagate(err, "Got an unexpected error"))}} diff --git a/pkg/scd/operational_intents_handler.go b/pkg/scd/operational_intents_handler.go index f611797d0..9b58b4192 100644 --- a/pkg/scd/operational_intents_handler.go +++ b/pkg/scd/operational_intents_handler.go @@ -13,6 +13,7 @@ import ( dssmodels "github.com/interuss/dss/pkg/models" scdmodels "github.com/interuss/dss/pkg/scd/models" "github.com/interuss/dss/pkg/scd/repos" + "github.com/interuss/dss/pkg/store" "github.com/interuss/stacktrace" ) @@ -165,7 +166,7 @@ func (a *Server) DeleteOperationalIntentReference(ctx context.Context, req *rest return nil } - err = a.Store.Transact(ctx, action) + _, err = a.Store.Transact(ctx, store.Request{}, action) if err != nil { err = stacktrace.Propagate(err, "Could not delete operational intent") errResp := &restapi.ErrorResponse{Message: dsserr.Handle(ctx, err)} @@ -221,7 +222,7 @@ func (a *Server) GetOperationalIntentReference(ctx context.Context, req *restapi return nil } - err = a.Store.Transact(ctx, action) + _, err = a.Store.Transact(ctx, store.Request{}, action) if err != nil { err = stacktrace.Propagate(err, "Could not get operational intent") if stacktrace.GetCode(err) == dsserr.NotFound { @@ -288,7 +289,7 @@ func (a *Server) QueryOperationalIntentReferences(ctx context.Context, req *rest return nil } - err = a.Store.Transact(ctx, action) + _, err = a.Store.Transact(ctx, store.Request{}, action) if err != nil { err = stacktrace.Propagate(err, "Could not query operational intent") if stacktrace.GetCode(err) == dsserr.BadRequest { @@ -934,7 +935,7 @@ func (a *Server) upsertOperationalIntentReference(ctx context.Context, now time. return nil } - err = a.Store.Transact(ctx, action) + _, err = a.Store.Transact(ctx, store.Request{}, action) if err != nil { return nil, responseConflict, err // No need to Propagate this error as this is not a useful stacktrace line } diff --git a/pkg/scd/subscriptions_handler.go b/pkg/scd/subscriptions_handler.go index 98ab9fed2..60a83c498 100644 --- a/pkg/scd/subscriptions_handler.go +++ b/pkg/scd/subscriptions_handler.go @@ -12,6 +12,7 @@ import ( dssmodels "github.com/interuss/dss/pkg/models" scdmodels "github.com/interuss/dss/pkg/scd/models" "github.com/interuss/dss/pkg/scd/repos" + "github.com/interuss/dss/pkg/store" "github.com/interuss/stacktrace" "github.com/jonboulle/clockwork" ) @@ -276,7 +277,7 @@ func (a *Server) PutSubscription(ctx context.Context, manager string, subscripti return nil } - err = a.Store.Transact(ctx, action) + _, err = a.Store.Transact(ctx, store.Request{}, action) if err != nil { return nil, err // No need to Propagate this error as this is not a useful stacktrace line } @@ -340,7 +341,7 @@ func (a *Server) GetSubscription(ctx context.Context, req *restapi.GetSubscripti return nil } - err = a.Store.Transact(ctx, action) + _, err = a.Store.Transact(ctx, store.Request{}, action) if err != nil { err = stacktrace.Propagate(err, "Could not get subscription") errResp := &restapi.ErrorResponse{Message: dsserr.Handle(ctx, err)} @@ -424,7 +425,7 @@ func (a *Server) QuerySubscriptions(ctx context.Context, req *restapi.QuerySubsc return nil } - err = a.Store.Transact(ctx, action) + _, err = a.Store.Transact(ctx, store.Request{}, action) if err != nil { errResp := &restapi.ErrorResponse{Message: dsserr.Handle(ctx, err)} @@ -517,7 +518,7 @@ func (a *Server) DeleteSubscription(ctx context.Context, req *restapi.DeleteSubs return nil } - err = a.Store.Transact(ctx, action) + _, err = a.Store.Transact(ctx, store.Request{}, action) if err != nil { err = stacktrace.Propagate(err, "Could not delete subscription") errResp := &restapi.ErrorResponse{Message: dsserr.Handle(ctx, err)} diff --git a/pkg/scd/uss_availability_handler.go b/pkg/scd/uss_availability_handler.go index 517a9859a..ee6bd48ee 100644 --- a/pkg/scd/uss_availability_handler.go +++ b/pkg/scd/uss_availability_handler.go @@ -10,6 +10,7 @@ import ( dssmodels "github.com/interuss/dss/pkg/models" scdmodels "github.com/interuss/dss/pkg/scd/models" "github.com/interuss/dss/pkg/scd/repos" + "github.com/interuss/dss/pkg/store" "github.com/interuss/stacktrace" "github.com/jackc/pgx/v5" ) @@ -51,7 +52,7 @@ func (a *Server) GetUssAvailability(ctx context.Context, req *restapi.GetUssAvai return nil } - err := a.Store.Transact(ctx, action) + _, err := a.Store.Transact(ctx, store.Request{}, action) if err != nil { // In case of older DB versions where availability table doesn't exist if strings.Contains(err.Error(), "does not exist") { @@ -121,7 +122,7 @@ func (a *Server) SetUssAvailability(ctx context.Context, req *restapi.SetUssAvai } return nil } - err = a.Store.Transact(ctx, action) + _, err = a.Store.Transact(ctx, store.Request{}, action) if err != nil { // In case of older DB versions where availability table doesn't exist if strings.Contains(err.Error(), "does not exist") { diff --git a/pkg/sqlstore/store.go b/pkg/sqlstore/store.go index 92437c190..a1e4ddac5 100644 --- a/pkg/sqlstore/store.go +++ b/pkg/sqlstore/store.go @@ -176,13 +176,13 @@ func (s *Store[R]) Interact(_ context.Context) (R, error) { return s.newRepo(s.Pool), nil } -func (s *Store[R]) Transact(ctx context.Context, f func(context.Context, R) error) error { +func (s *Store[R]) Transact(ctx context.Context, _ store.Request, f func(context.Context, R) error) (any, error) { if s.Version.Type == Yugabyte { - return s.transactYugabyte(ctx, f) + return nil, s.transactYugabyte(ctx, f) } ctx = crdb.WithMaxRetries(ctx, s.maxRetries) - return crdbpgx.ExecuteTx(ctx, s.Pool, pgx.TxOptions{IsoLevel: pgx.Serializable}, func(tx pgx.Tx) error { + return nil, crdbpgx.ExecuteTx(ctx, s.Pool, pgx.TxOptions{IsoLevel: pgx.Serializable}, func(tx pgx.Tx) error { return f(ctx, s.newRepo(tx)) }) } diff --git a/pkg/store/store.go b/pkg/store/store.go index e5e95c4b0..79aa1451f 100644 --- a/pkg/store/store.go +++ b/pkg/store/store.go @@ -7,6 +7,12 @@ import ( "github.com/interuss/stacktrace" ) +// Request carries the information needed by the Raftstore to handle a request. +type Request struct { + RequestType string + Payload any +} + // store.Store is the generic means to access and interact with any type of data backing the DSS // may ever use, by obtaining a means to perform R-specific (repo type) operations. type Store[R any] interface { @@ -16,7 +22,9 @@ type Store[R any] interface { Interact(context.Context) (R, error) // Attempt to apply the operations in f to the R Repo it is supplied. All operations performed // on the R Repo by f will be applied or rejected atomically. - Transact(ctx context.Context, f func(context.Context, R) error) error + // request is used by the Raftstore to build the proposal. + // The returned any is the proposal result. + Transact(ctx context.Context, request Request, f func(context.Context, R) error) (any, error) } const ( From 16c56e94c7a8d2d93604b4d92091ca74bcb82fa1 Mon Sep 17 00:00:00 2001 From: Mariem Baccari Date: Fri, 26 Jun 2026 15:19:51 +0200 Subject: [PATCH 2/8] [raft/store] add timestamp middleware --- cmds/core-service/main.go | 2 ++ pkg/timestamp/timestamp.go | 43 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 45 insertions(+) create mode 100644 pkg/timestamp/timestamp.go diff --git a/cmds/core-service/main.go b/cmds/core-service/main.go index 5585346bb..2f4337b14 100644 --- a/cmds/core-service/main.go +++ b/cmds/core-service/main.go @@ -32,6 +32,7 @@ import ( scds "github.com/interuss/dss/pkg/scd/store" "github.com/interuss/dss/pkg/store" "github.com/interuss/dss/pkg/store/params" + "github.com/interuss/dss/pkg/timestamp" "github.com/interuss/dss/pkg/version" "github.com/interuss/dss/pkg/versioning" "github.com/interuss/stacktrace" @@ -341,6 +342,7 @@ func RunHTTPServer(ctx context.Context, ctxCanceler func(), address, locality st handler = http.TimeoutHandler(handler, *timeout, "request timeout") handler = logging.HTTPMiddleware(logger, *dumpRequests, handler) handler = maxBodySizeMiddleware(1<<22, handler) // 4 MB + handler = timestamp.RequestTimestampMiddleware(handler) if *enableMetrics || *enableTracing { // We use the default settings; the APIRouter handler will override the span value accordingly, as it has more information. diff --git a/pkg/timestamp/timestamp.go b/pkg/timestamp/timestamp.go new file mode 100644 index 000000000..cdf047009 --- /dev/null +++ b/pkg/timestamp/timestamp.go @@ -0,0 +1,43 @@ +package timestamp + +import ( + "context" + "net/http" + "time" + + "github.com/interuss/stacktrace" +) + +type timestampKey struct{} + +// RequestTimestampFromContext returns the request timestamp from the context, or an error if not present or zero. +// The timestamp is set by the Middleware when a query is received then (on the receiver side) by the Raftstore when the query is applied. +// It is then used for deterministic execution of time-dependent queries. +func RequestTimestampFromContext(ctx context.Context) (time.Time, error) { + timestamp, ok := ctx.Value(timestampKey{}).(time.Time) + if !ok { + return time.Time{}, stacktrace.NewError("timestamp not found in context") + } + + if timestamp.IsZero() { + return time.Time{}, stacktrace.NewError("timestamp is zero") + } + + return timestamp, nil +} + +// WithRequestTimestamp returns a new context with the given timestamp. +func WithRequestTimestamp(ctx context.Context, timestamp time.Time) context.Context { + return context.WithValue(ctx, timestampKey{}, timestamp) +} + +// RequestTimestampMiddleware is an HTTP middleware that stamps each incoming +// request with its received time. This timestamp is later used as the +// timestamp of the Raft proposal, so that time-dependent queries +// execute deterministically across nodes and contexts (catchup / restart etc.). +func RequestTimestampMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx := WithRequestTimestamp(r.Context(), time.Now()) + next.ServeHTTP(w, r.WithContext(ctx)) + }) +} From c2c055fb37e8847ad9f0607101dedcf5fd72410c Mon Sep 17 00:00:00 2001 From: Mariem Baccari Date: Fri, 22 May 2026 09:23:39 +0200 Subject: [PATCH 3/8] [raft] add generic store --- pkg/aux_/store/raftstore/store.go | 20 ++++++- pkg/raftstore/consensus/proposal.go | 9 +++- pkg/raftstore/store.go | 84 +++++++++++++++++++++++------ pkg/rid/store/raftstore/store.go | 20 ++++++- pkg/scd/store/raftstore/store.go | 20 ++++++- 5 files changed, 132 insertions(+), 21 deletions(-) diff --git a/pkg/aux_/store/raftstore/store.go b/pkg/aux_/store/raftstore/store.go index d09162d87..2a96db742 100644 --- a/pkg/aux_/store/raftstore/store.go +++ b/pkg/aux_/store/raftstore/store.go @@ -5,7 +5,9 @@ import ( "github.com/interuss/dss/pkg/aux_/repos" auxraftparams "github.com/interuss/dss/pkg/aux_/store/raftstore/params" + dsserr "github.com/interuss/dss/pkg/errors" "github.com/interuss/dss/pkg/raftstore" + "github.com/interuss/dss/pkg/raftstore/consensus" "github.com/interuss/stacktrace" "go.uber.org/zap" ) @@ -18,5 +20,21 @@ func Init(ctx context.Context, logger *zap.Logger) (*raftstore.Store[repos.Repos if err != nil { return nil, stacktrace.Propagate(err, "failed to get aux raft parameters") } - return raftstore.Init(ctx, logger.With(zap.String("service", "aux_")), params, func() repos.Repository { return &repo{} }) + return raftstore.Init(ctx, logger.With(zap.String("service", "aux_")), params, &repo{}) +} + +func (r *repo) GetRepo() repos.Repository { return r } + +func (r *repo) IsReadOnly(_ raftstore.RequestType) bool { return false } + +func (r *repo) GetSnapshot() ([]byte, error) { + return nil, stacktrace.NewErrorWithCode(dsserr.NotImplemented, "not implemented yet") +} + +func (r *repo) RestoreFromSnapshot([]byte) error { + return stacktrace.NewErrorWithCode(dsserr.NotImplemented, "not implemented yet") +} + +func (r *repo) Apply(_ context.Context, _ consensus.Proposal) (any, error) { + return nil, stacktrace.NewErrorWithCode(dsserr.NotImplemented, "not implemented yet") } diff --git a/pkg/raftstore/consensus/proposal.go b/pkg/raftstore/consensus/proposal.go index a5e837d82..a70ba5b03 100644 --- a/pkg/raftstore/consensus/proposal.go +++ b/pkg/raftstore/consensus/proposal.go @@ -7,6 +7,7 @@ import ( "time" "github.com/google/uuid" + "github.com/interuss/dss/pkg/timestamp" "github.com/interuss/stacktrace" ) @@ -31,7 +32,11 @@ type Proposal struct { } func (c *Consensus) newProposal(ctx context.Context, requestType string, payload any, readOnly bool) (Proposal, error) { - // TODO - Fetch timestamp from context + timestamp, err := timestamp.RequestTimestampFromContext(ctx) + if err != nil || timestamp.IsZero() { + return Proposal{}, stacktrace.Propagate(err, "failed to get timestamp from context") + } + value, err := json.Marshal(payload) if err != nil { return Proposal{}, stacktrace.Propagate(err, "failed to serialize proposal payload") @@ -40,7 +45,7 @@ func (c *Consensus) newProposal(ctx context.Context, requestType string, payload return Proposal{ ID: uuid.NewString(), NodeID: c.nodeID, - Timestamp: time.Now().UTC(), + Timestamp: timestamp, RequestType: requestType, Value: value, ReadOnly: readOnly, diff --git a/pkg/raftstore/store.go b/pkg/raftstore/store.go index 7f0b8b2af..17166f4b6 100644 --- a/pkg/raftstore/store.go +++ b/pkg/raftstore/store.go @@ -3,45 +3,97 @@ package raftstore import ( "context" + "github.com/interuss/dss/pkg/logging" "github.com/interuss/dss/pkg/raftstore/consensus" raftparams "github.com/interuss/dss/pkg/raftstore/params" "github.com/interuss/dss/pkg/store" + "github.com/interuss/dss/pkg/timestamp" "github.com/interuss/stacktrace" "go.uber.org/zap" ) +type RequestType = string + +type RaftRepo[R any] interface { + GetRepo() R + // Apply is called on every committed entry. The proposal must be applied atomically. + Apply(ctx context.Context, proposal consensus.Proposal) (any, error) + GetSnapshot() ([]byte, error) + RestoreFromSnapshot(data []byte) error + IsReadOnly(requestType RequestType) bool +} + type Store[R any] struct { - newRepo func() R - consensus *consensus.Consensus + logger *zap.Logger + + name string + raftRepo RaftRepo[R] + cancel context.CancelFunc + + Consensus *consensus.Consensus } -func Init[R any](ctx context.Context, logger *zap.Logger, params raftparams.ConnectParameters, newRepo func() R) (*Store[R], error) { +func Init[R any](ctx context.Context, logger *zap.Logger, params raftparams.ConnectParameters, r RaftRepo[R]) (*Store[R], error) { + ctx, cancel := context.WithCancel(ctx) + + store := &Store[R]{ + raftRepo: r, + logger: logging.WithValuesFromContext(ctx, logger), + cancel: cancel, + } commitC := make(chan consensus.EntryCommit) + go store.processCommits(ctx, commitC) + consensusInstance, err := consensus.NewConsensus(ctx, logger, params, func() ([]byte, error) { return nil, nil }, commitC) if err != nil { return nil, stacktrace.Propagate(err, "failed to initialize consensus") } - // TODO: start consumer goroutine reading from commitC - return &Store[R]{ - newRepo: newRepo, - consensus: consensusInstance, - }, nil + store.Consensus = consensusInstance + + return store, nil } -// Transact proposes the entry to Raft and blocks until it is committed and applied. -func (s *Store[R]) Transact(ctx context.Context, _ store.Request, f func(context.Context, R) error) (any, error) { - // TODO: implement - return nil, nil +// Transact passes the request to the consensus layer and blocks until it is committed and applied. +// The processCommits loop will call Apply on the proposal when it is committed. +func (s *Store[R]) Transact(ctx context.Context, request store.Request, _ func(context.Context, R) error) (any, error) { + return s.Consensus.HandleClientRequest(ctx, request.RequestType, request.Payload, s.raftRepo.IsReadOnly(request.RequestType)) } -// Interact returns a repository that can be used to query the store without proposing a Raft entry. func (s *Store[R]) Interact(_ context.Context) (R, error) { - return s.newRepo(), nil + return s.raftRepo.GetRepo(), nil } -// Close shuts down the consensus instance. +// Close shuts down the consensus instance and processCommits loop. func (s *Store[R]) Close() error { - // TODO: implement + s.Consensus.Stop(context.Background()) + s.cancel() return nil } + +// processCommits reads committed entries from the consensus layer and applies them via Apply. +func (s *Store[R]) processCommits(ctx context.Context, commitCh <-chan consensus.EntryCommit) { + for { + select { + case <-ctx.Done(): + s.logger.Info("stopping commit processing loop") + return + case commit, ok := <-commitCh: + if !ok { + s.logger.Info("commit channel closed, stopping commit processing loop") + return + } + + if commit.SnapshotData != nil { + if err := s.raftRepo.RestoreFromSnapshot(commit.SnapshotData); err != nil { + s.logger.Error("failed to restore from snapshot", zap.Error(err)) + } + continue + } + + ctx = timestamp.WithRequestTimestamp(ctx, commit.Prop.Timestamp) + result, err := s.raftRepo.Apply(ctx, commit.Prop) + commit.Done <- consensus.ProposalResult{Result: result, Error: err} + } + } +} diff --git a/pkg/rid/store/raftstore/store.go b/pkg/rid/store/raftstore/store.go index 578aabeac..ba6889504 100644 --- a/pkg/rid/store/raftstore/store.go +++ b/pkg/rid/store/raftstore/store.go @@ -3,7 +3,9 @@ package raftstore import ( "context" + dsserr "github.com/interuss/dss/pkg/errors" "github.com/interuss/dss/pkg/raftstore" + "github.com/interuss/dss/pkg/raftstore/consensus" "github.com/interuss/dss/pkg/rid/repos" ridraftparams "github.com/interuss/dss/pkg/rid/store/raftstore/params" "github.com/interuss/stacktrace" @@ -18,5 +20,21 @@ func Init(ctx context.Context, logger *zap.Logger) (*raftstore.Store[repos.Repos if err != nil { return nil, stacktrace.Propagate(err, "failed to get rid raft parameters") } - return raftstore.Init(ctx, logger.With(zap.String("service", "rid")), params, func() repos.Repository { return &repo{} }) + return raftstore.Init(ctx, logger.With(zap.String("service", "rid")), params, &repo{}) +} + +func (r *repo) GetRepo() repos.Repository { return r } + +func (r *repo) IsReadOnly(_ raftstore.RequestType) bool { return false } + +func (r *repo) GetSnapshot() ([]byte, error) { + return nil, stacktrace.NewErrorWithCode(dsserr.NotImplemented, "not implemented yet") +} + +func (r *repo) RestoreFromSnapshot([]byte) error { + return stacktrace.NewErrorWithCode(dsserr.NotImplemented, "not implemented yet") +} + +func (r *repo) Apply(_ context.Context, _ consensus.Proposal) (any, error) { + return nil, stacktrace.NewErrorWithCode(dsserr.NotImplemented, "not implemented yet") } diff --git a/pkg/scd/store/raftstore/store.go b/pkg/scd/store/raftstore/store.go index a7eeb663a..48c59a886 100644 --- a/pkg/scd/store/raftstore/store.go +++ b/pkg/scd/store/raftstore/store.go @@ -3,7 +3,9 @@ package raftstore import ( "context" + dsserr "github.com/interuss/dss/pkg/errors" "github.com/interuss/dss/pkg/raftstore" + "github.com/interuss/dss/pkg/raftstore/consensus" "github.com/interuss/dss/pkg/scd/repos" scdraftparams "github.com/interuss/dss/pkg/scd/store/raftstore/params" "github.com/interuss/stacktrace" @@ -18,5 +20,21 @@ func Init(ctx context.Context, logger *zap.Logger) (*raftstore.Store[repos.Repos if err != nil { return nil, stacktrace.Propagate(err, "failed to get scd raft parameters") } - return raftstore.Init(ctx, logger.With(zap.String("service", "scd")), params, func() repos.Repository { return &repo{} }) + return raftstore.Init(ctx, logger.With(zap.String("service", "scd")), params, &repo{}) +} + +func (r *repo) GetRepo() repos.Repository { return r } + +func (r *repo) IsReadOnly(_ raftstore.RequestType) bool { return false } + +func (r *repo) GetSnapshot() ([]byte, error) { + return nil, stacktrace.NewErrorWithCode(dsserr.NotImplemented, "not implemented yet") +} + +func (r *repo) RestoreFromSnapshot([]byte) error { + return stacktrace.NewErrorWithCode(dsserr.NotImplemented, "not implemented yet") +} + +func (r *repo) Apply(_ context.Context, _ consensus.Proposal) (any, error) { + return nil, stacktrace.NewErrorWithCode(dsserr.NotImplemented, "not implemented yet") } From 0bbc7e5b9eb4719399c3cb021f0bfdb875b35753 Mon Sep 17 00:00:00 2001 From: Maximilien Cuony Date: Mon, 29 Jun 2026 09:03:13 +0200 Subject: [PATCH 4/8] [raft] Add base memstore --- pkg/aux_/store/memstore/doc.go | 3 + pkg/aux_/store/memstore/dss.go | 25 ++++++++ pkg/aux_/store/memstore/store.go | 18 ++++++ pkg/aux_/store/store.go | 3 + pkg/memstore/store.go | 64 +++++++++++++++++++ pkg/rid/store/memstore/doc.go | 3 + .../memstore/identification_service_area.go | 40 ++++++++++++ pkg/rid/store/memstore/store.go | 18 ++++++ pkg/rid/store/memstore/subscriptions.go | 52 +++++++++++++++ pkg/rid/store/store.go | 3 + pkg/scd/store/memstore/availability.go | 18 ++++++ pkg/scd/store/memstore/constraints.go | 30 +++++++++ pkg/scd/store/memstore/doc.go | 3 + pkg/scd/store/memstore/operational_intents.go | 39 +++++++++++ pkg/scd/store/memstore/store.go | 18 ++++++ pkg/scd/store/memstore/subscriptions.go | 48 ++++++++++++++ pkg/scd/store/store.go | 3 + pkg/store/params/params.go | 2 + 18 files changed, 390 insertions(+) create mode 100644 pkg/aux_/store/memstore/doc.go create mode 100644 pkg/aux_/store/memstore/dss.go create mode 100644 pkg/aux_/store/memstore/store.go create mode 100644 pkg/memstore/store.go create mode 100644 pkg/rid/store/memstore/doc.go create mode 100644 pkg/rid/store/memstore/identification_service_area.go create mode 100644 pkg/rid/store/memstore/store.go create mode 100644 pkg/rid/store/memstore/subscriptions.go create mode 100644 pkg/scd/store/memstore/availability.go create mode 100644 pkg/scd/store/memstore/constraints.go create mode 100644 pkg/scd/store/memstore/doc.go create mode 100644 pkg/scd/store/memstore/operational_intents.go create mode 100644 pkg/scd/store/memstore/store.go create mode 100644 pkg/scd/store/memstore/subscriptions.go diff --git a/pkg/aux_/store/memstore/doc.go b/pkg/aux_/store/memstore/doc.go new file mode 100644 index 000000000..0fe2e87de --- /dev/null +++ b/pkg/aux_/store/memstore/doc.go @@ -0,0 +1,3 @@ +// Package aux_.store.memstore provides a full implementation of store.Store[aux_.repos.Repository] +// storing data in memory. It is meant to be used by raftstore. +package memstore diff --git a/pkg/aux_/store/memstore/dss.go b/pkg/aux_/store/memstore/dss.go new file mode 100644 index 000000000..38fa4d5bf --- /dev/null +++ b/pkg/aux_/store/memstore/dss.go @@ -0,0 +1,25 @@ +package memstore + +import ( + "context" + + auxmodels "github.com/interuss/dss/pkg/aux_/models" + dsserr "github.com/interuss/dss/pkg/errors" + "github.com/interuss/stacktrace" +) + +func (r *repo) SaveOwnMetadata(_ context.Context, locality string, publicEndpoint string) error { + return stacktrace.NewErrorWithCode(dsserr.NotImplemented, "SaveOwnMetadata not implemented for memstore") +} + +func (r *repo) GetDSSMetadata(_ context.Context) ([]*auxmodels.DSSMetadata, error) { + return nil, stacktrace.NewErrorWithCode(dsserr.NotImplemented, "GetDSSMetadata not implemented for memstore") +} + +func (r *repo) RecordHeartbeat(_ context.Context, heartbeat auxmodels.Heartbeat) error { + return stacktrace.NewErrorWithCode(dsserr.NotImplemented, "RecordHeartbeat not implemented for memstore") +} + +func (r *repo) GetDSSAirspaceRepresentationID(_ context.Context) (string, error) { + return "", stacktrace.NewErrorWithCode(dsserr.NotImplemented, "GetDSSAirspaceRepresentationID not implemented for memstore") +} diff --git a/pkg/aux_/store/memstore/store.go b/pkg/aux_/store/memstore/store.go new file mode 100644 index 000000000..c2b875f9b --- /dev/null +++ b/pkg/aux_/store/memstore/store.go @@ -0,0 +1,18 @@ +package memstore + +import ( + "context" + + "github.com/interuss/dss/pkg/aux_/repos" + "github.com/interuss/dss/pkg/memstore" + "go.uber.org/zap" +) + +// repo is a full implementation of aux_.repos.Repository for memory-based storage. +type repo struct{} + +func Init(ctx context.Context, logger *zap.Logger) (*memstore.Store[repos.Repository], error) { + return memstore.Init(ctx, logger, "aux_", &repo{}) +} + +func (r *repo) GetRepo() repos.Repository { return r } diff --git a/pkg/aux_/store/store.go b/pkg/aux_/store/store.go index 4a9049327..2532842a7 100644 --- a/pkg/aux_/store/store.go +++ b/pkg/aux_/store/store.go @@ -4,6 +4,7 @@ import ( "context" "github.com/interuss/dss/pkg/aux_/repos" + auxmemstore "github.com/interuss/dss/pkg/aux_/store/memstore" auxraftstore "github.com/interuss/dss/pkg/aux_/store/raftstore" auxsqlstore "github.com/interuss/dss/pkg/aux_/store/sqlstore" dssstore "github.com/interuss/dss/pkg/store" @@ -24,6 +25,8 @@ func Init(ctx context.Context, logger *zap.Logger, withCheckCron bool) (Store, e return auxsqlstore.Init(ctx, logger, withCheckCron) case params.RaftStoreType: return auxraftstore.Init(ctx, logger) + case params.MemStoreType: + return auxmemstore.Init(ctx, logger) default: return nil, stacktrace.NewError("Unsupported store type %q for aux", storeType) } diff --git a/pkg/memstore/store.go b/pkg/memstore/store.go new file mode 100644 index 000000000..9087b49de --- /dev/null +++ b/pkg/memstore/store.go @@ -0,0 +1,64 @@ +package memstore + +// Memstore is a special kind of store: +// Store instances store data in memory. There is no persistent storage. +// Store instances are a singleton. +// Repository usage is not thread-safe. +// +// As of now, it is made to be used by raftstorage. + +import ( + "context" + "sync" + + dsserr "github.com/interuss/dss/pkg/errors" + "github.com/interuss/dss/pkg/logging" + "github.com/interuss/stacktrace" + "go.uber.org/zap" +) + +type MemRepo[R any] interface { + GetRepo() R +} + +type Store[R any] struct { + logger *zap.Logger + + name string + memRepo MemRepo[R] +} + +var ( + stores = map[string]any{} + storesMu sync.Mutex +) + +func Init[R any](ctx context.Context, logger *zap.Logger, name string, r MemRepo[R]) (*Store[R], error) { + + storesMu.Lock() + defer storesMu.Unlock() + if s, ok := stores[name]; ok { + return s.(*Store[R]), nil + } + + store := &Store[R]{ + name: name, + logger: logging.WithValuesFromContext(ctx, logger), + memRepo: r, + } + + stores[name] = store + return store, nil +} + +func (s *Store[R]) Transact(ctx context.Context, requestType string, payload any, _ func(context.Context, R) error) (any, error) { + return nil, stacktrace.NewErrorWithCode(dsserr.NotImplemented, "Transact not implemented for memstore") +} + +func (s *Store[R]) Interact(_ context.Context) (R, error) { + return s.memRepo.GetRepo(), nil +} + +func (s *Store[R]) Close() error { + return nil +} diff --git a/pkg/rid/store/memstore/doc.go b/pkg/rid/store/memstore/doc.go new file mode 100644 index 000000000..e42a19ad4 --- /dev/null +++ b/pkg/rid/store/memstore/doc.go @@ -0,0 +1,3 @@ +// Package rid.store.memstore provides a full implementation of store.Store[rid.repos.Repository] +// storing data in memory. It as meant to be used by raftstore. +package memstore diff --git a/pkg/rid/store/memstore/identification_service_area.go b/pkg/rid/store/memstore/identification_service_area.go new file mode 100644 index 000000000..3bd315fb7 --- /dev/null +++ b/pkg/rid/store/memstore/identification_service_area.go @@ -0,0 +1,40 @@ +package memstore + +import ( + "context" + "time" + + "github.com/golang/geo/s2" + dsserr "github.com/interuss/dss/pkg/errors" + dssmodels "github.com/interuss/dss/pkg/models" + ridmodels "github.com/interuss/dss/pkg/rid/models" + "github.com/interuss/stacktrace" +) + +func (r *repo) GetISA(_ context.Context, id dssmodels.ID, forUpdate bool) (*ridmodels.IdentificationServiceArea, error) { + return nil, stacktrace.NewErrorWithCode(dsserr.NotImplemented, "GetISA not implemented for memstore") +} + +func (r *repo) DeleteISA(_ context.Context, isa *ridmodels.IdentificationServiceArea) (*ridmodels.IdentificationServiceArea, error) { + return nil, stacktrace.NewErrorWithCode(dsserr.NotImplemented, "DeleteISA not implemented for memstore") +} + +func (r *repo) InsertISA(_ context.Context, isa *ridmodels.IdentificationServiceArea) (*ridmodels.IdentificationServiceArea, error) { + return nil, stacktrace.NewErrorWithCode(dsserr.NotImplemented, "InsertISA not implemented for memstore") +} + +func (r *repo) UpdateISA(_ context.Context, isa *ridmodels.IdentificationServiceArea) (*ridmodels.IdentificationServiceArea, error) { + return nil, stacktrace.NewErrorWithCode(dsserr.NotImplemented, "UpdateISA not implemented for memstore") +} + +func (r *repo) SearchISAs(_ context.Context, cells s2.CellUnion, earliest *time.Time, latest *time.Time) ([]*ridmodels.IdentificationServiceArea, error) { + return nil, stacktrace.NewErrorWithCode(dsserr.NotImplemented, "SearchISAs not implemented for memstore") +} + +func (r *repo) ListExpiredISAs(_ context.Context, writer string, threshold time.Time) ([]*ridmodels.IdentificationServiceArea, error) { + return nil, stacktrace.NewErrorWithCode(dsserr.NotImplemented, "ListExpiredISAs not implemented for memstore") +} + +func (r *repo) CountISAs(_ context.Context) (int64, error) { + return 0, stacktrace.NewErrorWithCode(dsserr.NotImplemented, "CountISAs not implemented for memstore") +} diff --git a/pkg/rid/store/memstore/store.go b/pkg/rid/store/memstore/store.go new file mode 100644 index 000000000..b50c2b169 --- /dev/null +++ b/pkg/rid/store/memstore/store.go @@ -0,0 +1,18 @@ +package memstore + +import ( + "context" + + "github.com/interuss/dss/pkg/memstore" + "github.com/interuss/dss/pkg/rid/repos" + "go.uber.org/zap" +) + +// repo is a full implementation of rid.repos.Repository for memory-based storage. +type repo struct{} + +func Init(ctx context.Context, logger *zap.Logger) (*memstore.Store[repos.Repository], error) { + return memstore.Init(ctx, logger, "rid", &repo{}) +} + +func (r *repo) GetRepo() repos.Repository { return r } diff --git a/pkg/rid/store/memstore/subscriptions.go b/pkg/rid/store/memstore/subscriptions.go new file mode 100644 index 000000000..ac744ab10 --- /dev/null +++ b/pkg/rid/store/memstore/subscriptions.go @@ -0,0 +1,52 @@ +package memstore + +import ( + "context" + "time" + + "github.com/golang/geo/s2" + dsserr "github.com/interuss/dss/pkg/errors" + dssmodels "github.com/interuss/dss/pkg/models" + ridmodels "github.com/interuss/dss/pkg/rid/models" + "github.com/interuss/stacktrace" +) + +func (r *repo) GetSubscription(_ context.Context, id dssmodels.ID) (*ridmodels.Subscription, error) { + return nil, stacktrace.NewErrorWithCode(dsserr.NotImplemented, "GetSubscription not implemented for memstore") +} + +func (r *repo) DeleteSubscription(_ context.Context, sub *ridmodels.Subscription) (*ridmodels.Subscription, error) { + return nil, stacktrace.NewErrorWithCode(dsserr.NotImplemented, "DeleteSubscription not implemented for memstore") +} + +func (r *repo) InsertSubscription(_ context.Context, sub *ridmodels.Subscription) (*ridmodels.Subscription, error) { + return nil, stacktrace.NewErrorWithCode(dsserr.NotImplemented, "InsertSubscription not implemented for memstore") +} + +func (r *repo) UpdateSubscription(_ context.Context, sub *ridmodels.Subscription) (*ridmodels.Subscription, error) { + return nil, stacktrace.NewErrorWithCode(dsserr.NotImplemented, "UpdateSubscription not implemented for memstore") +} + +func (r *repo) SearchSubscriptions(_ context.Context, cells s2.CellUnion) ([]*ridmodels.Subscription, error) { + return nil, stacktrace.NewErrorWithCode(dsserr.NotImplemented, "SearchSubscriptions not implemented for memstore") +} + +func (r *repo) SearchSubscriptionsByOwner(_ context.Context, cells s2.CellUnion, owner dssmodels.Owner) ([]*ridmodels.Subscription, error) { + return nil, stacktrace.NewErrorWithCode(dsserr.NotImplemented, "SearchSubscriptionsByOwner not implemented for memstore") +} + +func (r *repo) UpdateNotificationIdxsInCells(_ context.Context, cells s2.CellUnion) ([]*ridmodels.Subscription, error) { + return nil, stacktrace.NewErrorWithCode(dsserr.NotImplemented, "UpdateNotificationIdxsInCells not implemented for memstore") +} + +func (r *repo) MaxSubscriptionCountInCellsByOwner(_ context.Context, cells s2.CellUnion, owner dssmodels.Owner) (int, error) { + return 0, stacktrace.NewErrorWithCode(dsserr.NotImplemented, "MaxSubscriptionCountInCellsByOwner not implemented for memstore") +} + +func (r *repo) ListExpiredSubscriptions(_ context.Context, writer string, threshold time.Time) ([]*ridmodels.Subscription, error) { + return nil, stacktrace.NewErrorWithCode(dsserr.NotImplemented, "ListExpiredSubscriptions not implemented for memstore") +} + +func (r *repo) CountSubscriptions(_ context.Context) (int64, error) { + return 0, stacktrace.NewErrorWithCode(dsserr.NotImplemented, "CountSubscriptions not implemented for memstore") +} diff --git a/pkg/rid/store/store.go b/pkg/rid/store/store.go index aa561a092..7f2dc6c20 100644 --- a/pkg/rid/store/store.go +++ b/pkg/rid/store/store.go @@ -4,6 +4,7 @@ import ( "context" "github.com/interuss/dss/pkg/rid/repos" + ridmemstore "github.com/interuss/dss/pkg/rid/store/memstore" ridraftstore "github.com/interuss/dss/pkg/rid/store/raftstore" ridsqlstore "github.com/interuss/dss/pkg/rid/store/sqlstore" dssstore "github.com/interuss/dss/pkg/store" @@ -24,6 +25,8 @@ func Init(ctx context.Context, logger *zap.Logger, withCheckCron bool) (Store, e return ridsqlstore.Init(ctx, logger, withCheckCron) case params.RaftStoreType: return ridraftstore.Init(ctx, logger) + case params.MemStoreType: + return ridmemstore.Init(ctx, logger) default: return nil, stacktrace.NewError("Unsupported store type %q for rid", storeType) } diff --git a/pkg/scd/store/memstore/availability.go b/pkg/scd/store/memstore/availability.go new file mode 100644 index 000000000..3579672e1 --- /dev/null +++ b/pkg/scd/store/memstore/availability.go @@ -0,0 +1,18 @@ +package memstore + +import ( + "context" + + dsserr "github.com/interuss/dss/pkg/errors" + dssmodels "github.com/interuss/dss/pkg/models" + scdmodels "github.com/interuss/dss/pkg/scd/models" + "github.com/interuss/stacktrace" +) + +func (r *repo) GetUssAvailability(_ context.Context, id dssmodels.Manager) (*scdmodels.UssAvailabilityStatus, error) { + return nil, stacktrace.NewErrorWithCode(dsserr.NotImplemented, "GetUssAvailability not implemented for memstore") +} + +func (r *repo) UpsertUssAvailability(_ context.Context, ussa *scdmodels.UssAvailabilityStatus) (*scdmodels.UssAvailabilityStatus, error) { + return nil, stacktrace.NewErrorWithCode(dsserr.NotImplemented, "UpsertUssAvailability not implemented for memstore") +} diff --git a/pkg/scd/store/memstore/constraints.go b/pkg/scd/store/memstore/constraints.go new file mode 100644 index 000000000..be3e46eee --- /dev/null +++ b/pkg/scd/store/memstore/constraints.go @@ -0,0 +1,30 @@ +package memstore + +import ( + "context" + + dsserr "github.com/interuss/dss/pkg/errors" + dssmodels "github.com/interuss/dss/pkg/models" + scdmodels "github.com/interuss/dss/pkg/scd/models" + "github.com/interuss/stacktrace" +) + +func (r *repo) SearchConstraints(_ context.Context, v4d *dssmodels.Volume4D) ([]*scdmodels.Constraint, error) { + return nil, stacktrace.NewErrorWithCode(dsserr.NotImplemented, "SearchConstraints not implemented for memstore") +} + +func (r *repo) GetConstraint(_ context.Context, id dssmodels.ID) (*scdmodels.Constraint, error) { + return nil, stacktrace.NewErrorWithCode(dsserr.NotImplemented, "GetConstraint not implemented for memstore") +} + +func (r *repo) UpsertConstraint(_ context.Context, constraint *scdmodels.Constraint) (*scdmodels.Constraint, error) { + return nil, stacktrace.NewErrorWithCode(dsserr.NotImplemented, "UpsertConstraint not implemented for memstore") +} + +func (r *repo) DeleteConstraint(_ context.Context, id dssmodels.ID) error { + return stacktrace.NewErrorWithCode(dsserr.NotImplemented, "DeleteConstraint not implemented for memstore") +} + +func (r *repo) CountConstraints(_ context.Context) (int64, error) { + return 0, stacktrace.NewErrorWithCode(dsserr.NotImplemented, "CountConstraint not implemented for memstore") +} diff --git a/pkg/scd/store/memstore/doc.go b/pkg/scd/store/memstore/doc.go new file mode 100644 index 000000000..f86b85987 --- /dev/null +++ b/pkg/scd/store/memstore/doc.go @@ -0,0 +1,3 @@ +// Package scd.store.memstore provides a full implementation of store.Store[scd.repos.Repository] +// storing data in memory. It is meant to be used by raftstore. +package memstore diff --git a/pkg/scd/store/memstore/operational_intents.go b/pkg/scd/store/memstore/operational_intents.go new file mode 100644 index 000000000..b39f81793 --- /dev/null +++ b/pkg/scd/store/memstore/operational_intents.go @@ -0,0 +1,39 @@ +package memstore + +import ( + "context" + "time" + + dsserr "github.com/interuss/dss/pkg/errors" + dssmodels "github.com/interuss/dss/pkg/models" + scdmodels "github.com/interuss/dss/pkg/scd/models" + "github.com/interuss/stacktrace" +) + +func (r *repo) GetOperationalIntent(_ context.Context, id dssmodels.ID) (*scdmodels.OperationalIntent, error) { + return nil, stacktrace.NewErrorWithCode(dsserr.NotImplemented, "GetOperationalIntent not implemented for memstore") +} + +func (r *repo) DeleteOperationalIntent(_ context.Context, id dssmodels.ID) error { + return stacktrace.NewErrorWithCode(dsserr.NotImplemented, "DeleteOperationalIntent not implemented for memstore") +} + +func (r *repo) UpsertOperationalIntent(_ context.Context, operation *scdmodels.OperationalIntent) (*scdmodels.OperationalIntent, error) { + return nil, stacktrace.NewErrorWithCode(dsserr.NotImplemented, "UpsertOperationalIntent not implemented for memstore") +} + +func (r *repo) SearchOperationalIntents(_ context.Context, v4d *dssmodels.Volume4D) ([]*scdmodels.OperationalIntent, error) { + return nil, stacktrace.NewErrorWithCode(dsserr.NotImplemented, "SearchOperationalIntents not implemented for memstore") +} + +func (r *repo) GetDependentOperationalIntents(_ context.Context, subscriptionID dssmodels.ID) ([]dssmodels.ID, error) { + return nil, stacktrace.NewErrorWithCode(dsserr.NotImplemented, "GetDependentOperationalIntents not implemented for memstore") +} + +func (r *repo) ListExpiredOperationalIntents(_ context.Context, threshold time.Time) ([]*scdmodels.OperationalIntent, error) { + return nil, stacktrace.NewErrorWithCode(dsserr.NotImplemented, "ListExpiredOperationalIntents not implemented for memstore") +} + +func (r *repo) CountOperationalIntents(_ context.Context) (int64, error) { + return 0, stacktrace.NewErrorWithCode(dsserr.NotImplemented, "CountOperationalIntents not implemented for memstore") +} diff --git a/pkg/scd/store/memstore/store.go b/pkg/scd/store/memstore/store.go new file mode 100644 index 000000000..45365614a --- /dev/null +++ b/pkg/scd/store/memstore/store.go @@ -0,0 +1,18 @@ +package memstore + +import ( + "context" + + "github.com/interuss/dss/pkg/memstore" + "github.com/interuss/dss/pkg/scd/repos" + "go.uber.org/zap" +) + +// repo is a full implementation of scd.repos.Repository for memory-based storage. +type repo struct{} + +func Init(ctx context.Context, logger *zap.Logger) (*memstore.Store[repos.Repository], error) { + return memstore.Init(ctx, logger, "scd", &repo{}) +} + +func (r *repo) GetRepo() repos.Repository { return r } diff --git a/pkg/scd/store/memstore/subscriptions.go b/pkg/scd/store/memstore/subscriptions.go new file mode 100644 index 000000000..c284e6df8 --- /dev/null +++ b/pkg/scd/store/memstore/subscriptions.go @@ -0,0 +1,48 @@ +package memstore + +import ( + "context" + "time" + + "github.com/golang/geo/s2" + dsserr "github.com/interuss/dss/pkg/errors" + dssmodels "github.com/interuss/dss/pkg/models" + scdmodels "github.com/interuss/dss/pkg/scd/models" + "github.com/interuss/stacktrace" +) + +func (r *repo) SearchSubscriptions(_ context.Context, v4d *dssmodels.Volume4D) ([]*scdmodels.Subscription, error) { + return nil, stacktrace.NewErrorWithCode(dsserr.NotImplemented, "SearchSubscriptions not implemented for memstore") +} + +func (r *repo) GetSubscription(_ context.Context, id dssmodels.ID) (*scdmodels.Subscription, error) { + return nil, stacktrace.NewErrorWithCode(dsserr.NotImplemented, "GetSubscription not implemented for memstore") +} + +func (r *repo) UpsertSubscription(_ context.Context, sub *scdmodels.Subscription) (*scdmodels.Subscription, error) { + return nil, stacktrace.NewErrorWithCode(dsserr.NotImplemented, "UpsertSubscription not implemented for memstore") +} + +func (r *repo) DeleteSubscription(_ context.Context, id dssmodels.ID) error { + return stacktrace.NewErrorWithCode(dsserr.NotImplemented, "DeleteSubscription not implemented for memstore") +} + +func (r *repo) IncrementNotificationIndicesForOperationalIntents(_ context.Context, v4d *dssmodels.Volume4D) ([]*scdmodels.Subscription, error) { + return nil, stacktrace.NewErrorWithCode(dsserr.NotImplemented, "IncrementNotificationIndicesForOperationalIntents not implemented for memstore") +} + +func (r *repo) IncrementNotificationIndicesForConstraints(_ context.Context, v4d *dssmodels.Volume4D) ([]*scdmodels.Subscription, error) { + return nil, stacktrace.NewErrorWithCode(dsserr.NotImplemented, "IncrementNotificationIndicesForConstraints not implemented for memstore") +} + +func (r *repo) LockSubscriptionsOnCells(_ context.Context, cells s2.CellUnion, subscriptionIds []dssmodels.ID, startTime *time.Time, endTime *time.Time) error { + return stacktrace.NewErrorWithCode(dsserr.NotImplemented, "LockSubscriptionsOnCells not implemented for memstore") +} + +func (r *repo) ListExpiredSubscriptions(_ context.Context, threshold time.Time) ([]*scdmodels.Subscription, error) { + return nil, stacktrace.NewErrorWithCode(dsserr.NotImplemented, "ListExpiredSubscriptions not implemented for memstore") +} + +func (r *repo) CountSubscriptions(_ context.Context) (int64, error) { + return 0, stacktrace.NewErrorWithCode(dsserr.NotImplemented, "CountSubscriptions not implemented for memstore") +} diff --git a/pkg/scd/store/store.go b/pkg/scd/store/store.go index 23df63840..0cdc48b93 100644 --- a/pkg/scd/store/store.go +++ b/pkg/scd/store/store.go @@ -4,6 +4,7 @@ import ( "context" "github.com/interuss/dss/pkg/scd/repos" + scdmemstore "github.com/interuss/dss/pkg/scd/store/memstore" scdraftstore "github.com/interuss/dss/pkg/scd/store/raftstore" scdsqlstore "github.com/interuss/dss/pkg/scd/store/sqlstore" dssstore "github.com/interuss/dss/pkg/store" @@ -24,6 +25,8 @@ func Init(ctx context.Context, logger *zap.Logger, withCheckCron bool) (Store, e return scdsqlstore.Init(ctx, logger, withCheckCron) case params.RaftStoreType: return scdraftstore.Init(ctx, logger) + case params.MemStoreType: + return scdmemstore.Init(ctx, logger) default: return nil, stacktrace.NewError("Unsupported store type %q for scd", storeType) } diff --git a/pkg/store/params/params.go b/pkg/store/params/params.go index 80c5733e7..9b84f8bd9 100644 --- a/pkg/store/params/params.go +++ b/pkg/store/params/params.go @@ -8,6 +8,7 @@ import ( const ( RaftStoreType = "raft" SQLStoreType = "sql" + MemStoreType = "mem" ) type ( @@ -29,6 +30,7 @@ var ( ) func init() { + // NB: Memstore is not available here on purpose, as it is only to be used internally. flag.StringVar(&storeParameters.StoreType, "store_type", SQLStoreType, fmt.Sprintf("Store type. Use '%s' for CockroachDB/YugabyteDB and '%s' for Raft-based store.", SQLStoreType, RaftStoreType)) flag.BoolVar(&storeOptions.GlobalLock, "enable_scd_global_lock", false, "Experimental: Use a global lock when working with SCD subscriptions. Reduce global throughput but improve throughput with lot of subscriptions in the same areas.") flag.BoolVar(&storeOptions.TimeBasedNotificationIndex, "enable_time_based_notification_index", false, "Use a time-based notification index when working with RID and SCD subscriptions.") From 8a25c3bb87c933dd6aa7c1bd6037c7e6ee5318e1 Mon Sep 17 00:00:00 2001 From: Maximilien Cuony Date: Mon, 15 Jun 2026 15:07:26 +0200 Subject: [PATCH 5/8] [raft/memstore] Add aux memstore --- pkg/aux_/store/memstore/dss.go | 69 ++++++++++++++- pkg/aux_/store/memstore/dss_test.go | 125 ++++++++++++++++++++++++++++ pkg/aux_/store/memstore/store.go | 28 ++++++- 3 files changed, 216 insertions(+), 6 deletions(-) create mode 100644 pkg/aux_/store/memstore/dss_test.go diff --git a/pkg/aux_/store/memstore/dss.go b/pkg/aux_/store/memstore/dss.go index 38fa4d5bf..7198b3936 100644 --- a/pkg/aux_/store/memstore/dss.go +++ b/pkg/aux_/store/memstore/dss.go @@ -2,6 +2,8 @@ package memstore import ( "context" + "database/sql" + "time" auxmodels "github.com/interuss/dss/pkg/aux_/models" dsserr "github.com/interuss/dss/pkg/errors" @@ -9,17 +11,76 @@ import ( ) func (r *repo) SaveOwnMetadata(_ context.Context, locality string, publicEndpoint string) error { - return stacktrace.NewErrorWithCode(dsserr.NotImplemented, "SaveOwnMetadata not implemented for memstore") + if locality == "" { + return stacktrace.NewErrorWithCode(dsserr.BadRequest, "Locality not set") + } + if publicEndpoint == "" { + return stacktrace.NewErrorWithCode(dsserr.BadRequest, "Public endpoint not set") + } + + r.participants[locality] = &participant{ + publicEndpoint: publicEndpoint, + updatedAt: time.Now(), + } + return nil } func (r *repo) GetDSSMetadata(_ context.Context) ([]*auxmodels.DSSMetadata, error) { - return nil, stacktrace.NewErrorWithCode(dsserr.NotImplemented, "GetDSSMetadata not implemented for memstore") + metadata := make([]*auxmodels.DSSMetadata, 0, len(r.participants)) + for locality, p := range r.participants { + updatedAt := p.updatedAt + m := &auxmodels.DSSMetadata{ + Locality: locality, + PublicEndpoint: p.publicEndpoint, + UpdatedAt: &updatedAt, + } + + // Find the latest heartbeat across all sources for this locality. + var latest auxmodels.Heartbeat + found := false + for key, hb := range r.heartbeats { + if key.locality != locality { + continue + } + if !found || hb.Timestamp.After(*latest.Timestamp) { + latest = hb + found = true + } + } + + if found { + m.LatestTimestamp.Source = sql.NullString{String: latest.Source, Valid: true} + m.LatestTimestamp.Timestamp = latest.Timestamp + m.LatestTimestamp.NextHeartbeatExpectedBefore = latest.NextHeartbeatExpectedBefore + m.LatestTimestamp.Reporter = sql.NullString{String: latest.Reporter, Valid: true} + } + + metadata = append(metadata, m) + } + return metadata, nil } func (r *repo) RecordHeartbeat(_ context.Context, heartbeat auxmodels.Heartbeat) error { - return stacktrace.NewErrorWithCode(dsserr.NotImplemented, "RecordHeartbeat not implemented for memstore") + if heartbeat.Locality == "" { + return stacktrace.NewErrorWithCode(dsserr.BadRequest, "Locality not set") + } + if heartbeat.Source == "" { + return stacktrace.NewErrorWithCode(dsserr.BadRequest, "Source not set") + } + + if heartbeat.Timestamp == nil { + now := time.Now() + heartbeat.Timestamp = &now + } + + if heartbeat.NextHeartbeatExpectedBefore != nil && heartbeat.NextHeartbeatExpectedBefore.Before(*heartbeat.Timestamp) { + return stacktrace.NewErrorWithCode(dsserr.BadRequest, "Cannot expect the timestamp of the next heartbeat before the timestamp of the new heartbeat") + } + + r.heartbeats[heartbeatKey{locality: heartbeat.Locality, source: heartbeat.Source}] = heartbeat + return nil } func (r *repo) GetDSSAirspaceRepresentationID(_ context.Context) (string, error) { - return "", stacktrace.NewErrorWithCode(dsserr.NotImplemented, "GetDSSAirspaceRepresentationID not implemented for memstore") + return "", stacktrace.NewErrorWithCode(dsserr.NotImplemented, "GetDSSAirspaceRepresentationID not implementable for memstore") } diff --git a/pkg/aux_/store/memstore/dss_test.go b/pkg/aux_/store/memstore/dss_test.go new file mode 100644 index 000000000..43563d27b --- /dev/null +++ b/pkg/aux_/store/memstore/dss_test.go @@ -0,0 +1,125 @@ +package memstore + +import ( + "context" + "testing" + "time" + + auxmodels "github.com/interuss/dss/pkg/aux_/models" + dsserr "github.com/interuss/dss/pkg/errors" + "github.com/interuss/stacktrace" + "github.com/stretchr/testify/require" +) + +func TestSaveOwnMetadataValidation(t *testing.T) { + ctx := context.Background() + r := newRepo() + + require.Equal(t, dsserr.BadRequest, stacktrace.GetCode(r.SaveOwnMetadata(ctx, "", "https://example.com"))) + require.Equal(t, dsserr.BadRequest, stacktrace.GetCode(r.SaveOwnMetadata(ctx, "dss-1", ""))) +} + +func TestSaveOwnMetadataRoundTrip(t *testing.T) { + ctx := context.Background() + r := newRepo() + + require.NoError(t, r.SaveOwnMetadata(ctx, "dss-1", "https://example.com")) + + md, err := r.GetDSSMetadata(ctx) + require.NoError(t, err) + + require.Len(t, md, 1) + require.Equal(t, "dss-1", md[0].Locality) + require.Equal(t, "https://example.com", md[0].PublicEndpoint) + require.NotNil(t, md[0].UpdatedAt) + + // No heartbeat recorded yet. + require.False(t, md[0].LatestTimestamp.Source.Valid) + require.Nil(t, md[0].LatestTimestamp.Timestamp) +} + +func TestSaveOwnMetadataUpsert(t *testing.T) { + ctx := context.Background() + r := newRepo() + + require.NoError(t, r.SaveOwnMetadata(ctx, "dss-1", "https://old.example.com")) + require.NoError(t, r.SaveOwnMetadata(ctx, "dss-1", "https://new.example.com")) + + md, err := r.GetDSSMetadata(ctx) + require.NoError(t, err) + + require.Len(t, md, 1) + require.Equal(t, "https://new.example.com", md[0].PublicEndpoint) +} + +func TestRecordHeartbeatValidation(t *testing.T) { + ctx := context.Background() + r := newRepo() + + require.Equal(t, dsserr.BadRequest, stacktrace.GetCode(r.RecordHeartbeat(ctx, auxmodels.Heartbeat{Source: "source1"}))) + require.Equal(t, dsserr.BadRequest, stacktrace.GetCode(r.RecordHeartbeat(ctx, auxmodels.Heartbeat{Locality: "dss-1"}))) + + ts := time.Now() + before := ts.Add(-time.Minute) + err := r.RecordHeartbeat(ctx, auxmodels.Heartbeat{ + Locality: "dss-1", + Source: "source1", + Timestamp: &ts, + NextHeartbeatExpectedBefore: &before, + }) + + require.Equal(t, dsserr.BadRequest, stacktrace.GetCode(err)) +} + +func TestRecordHeartbeatDefaultsTimestamp(t *testing.T) { + ctx := context.Background() + r := newRepo() + + require.NoError(t, r.SaveOwnMetadata(ctx, "dss-1", "https://example.com")) + require.NoError(t, r.RecordHeartbeat(ctx, auxmodels.Heartbeat{Locality: "dss-1", Source: "source1"})) + + md, err := r.GetDSSMetadata(ctx) + require.NoError(t, err) + + require.Len(t, md, 1) + require.True(t, md[0].LatestTimestamp.Source.Valid) + require.NotNil(t, md[0].LatestTimestamp.Timestamp) +} + +func TestGetDSSMetadataPicksLatestHeartbeat(t *testing.T) { + ctx := context.Background() + r := newRepo() + + require.NoError(t, r.SaveOwnMetadata(ctx, "dss-1", "https://example.com")) + + older := time.Now().Add(-time.Hour) + newer := time.Now() + require.NoError(t, r.RecordHeartbeat(ctx, auxmodels.Heartbeat{Locality: "dss-1", Source: "source1", Timestamp: &older, Reporter: "uss1"})) + require.NoError(t, r.RecordHeartbeat(ctx, auxmodels.Heartbeat{Locality: "dss-1", Source: "source2", Timestamp: &newer, Reporter: "uss2"})) + + md, err := r.GetDSSMetadata(ctx) + require.NoError(t, err) + + require.Len(t, md, 1) + require.True(t, md[0].LatestTimestamp.Timestamp.Equal(newer)) + require.Equal(t, "source2", md[0].LatestTimestamp.Source.String) + require.Equal(t, "uss2", md[0].LatestTimestamp.Reporter.String) +} + +func TestGetDSSMetadataUpdatesHeartbeatPerSource(t *testing.T) { + ctx := context.Background() + r := newRepo() + + require.NoError(t, r.SaveOwnMetadata(ctx, "dss-1", "https://example.com")) + + first := time.Now().Add(-time.Hour) + second := time.Now() + require.NoError(t, r.RecordHeartbeat(ctx, auxmodels.Heartbeat{Locality: "dss-1", Source: "source1", Timestamp: &first})) + require.NoError(t, r.RecordHeartbeat(ctx, auxmodels.Heartbeat{Locality: "dss-1", Source: "source1", Timestamp: &second})) + + md, err := r.GetDSSMetadata(ctx) + require.NoError(t, err) + + require.Len(t, md, 1) + require.True(t, md[0].LatestTimestamp.Timestamp.Equal(second)) +} diff --git a/pkg/aux_/store/memstore/store.go b/pkg/aux_/store/memstore/store.go index c2b875f9b..f94a1ffc2 100644 --- a/pkg/aux_/store/memstore/store.go +++ b/pkg/aux_/store/memstore/store.go @@ -2,17 +2,41 @@ package memstore import ( "context" + "time" + auxmodels "github.com/interuss/dss/pkg/aux_/models" "github.com/interuss/dss/pkg/aux_/repos" "github.com/interuss/dss/pkg/memstore" "go.uber.org/zap" ) // repo is a full implementation of aux_.repos.Repository for memory-based storage. -type repo struct{} +type repo struct { + // participants holds pool participants metadata, keyed by locality. + participants map[string]*participant + // heartbeats holds the latest heartbeat per (locality, source). + heartbeats map[heartbeatKey]auxmodels.Heartbeat +} + +type participant struct { + publicEndpoint string + updatedAt time.Time +} + +type heartbeatKey struct { + locality string + source string +} + +func newRepo() *repo { + return &repo{ + participants: map[string]*participant{}, + heartbeats: map[heartbeatKey]auxmodels.Heartbeat{}, + } +} func Init(ctx context.Context, logger *zap.Logger) (*memstore.Store[repos.Repository], error) { - return memstore.Init(ctx, logger, "aux_", &repo{}) + return memstore.Init(ctx, logger, "aux_", newRepo()) } func (r *repo) GetRepo() repos.Repository { return r } From a41e90612fdb89b6c7a950ab696a33c32c82ba1b Mon Sep 17 00:00:00 2001 From: Maximilien Cuony Date: Tue, 16 Jun 2026 11:05:14 +0200 Subject: [PATCH 6/8] [raft/memstore] Add snapshop capability to aux memstore --- pkg/aux_/store/memstore/dss.go | 22 ++++----- pkg/aux_/store/memstore/snapshot.go | 35 ++++++++++++++ pkg/aux_/store/memstore/snapshot_test.go | 59 ++++++++++++++++++++++++ pkg/aux_/store/memstore/store.go | 24 +++++++--- pkg/memstore/store.go | 2 + pkg/rid/store/memstore/snapshot.go | 13 ++++++ pkg/scd/store/memstore/snapshot.go | 13 ++++++ 7 files changed, 150 insertions(+), 18 deletions(-) create mode 100644 pkg/aux_/store/memstore/snapshot.go create mode 100644 pkg/aux_/store/memstore/snapshot_test.go create mode 100644 pkg/rid/store/memstore/snapshot.go create mode 100644 pkg/scd/store/memstore/snapshot.go diff --git a/pkg/aux_/store/memstore/dss.go b/pkg/aux_/store/memstore/dss.go index 7198b3936..79aeeee3a 100644 --- a/pkg/aux_/store/memstore/dss.go +++ b/pkg/aux_/store/memstore/dss.go @@ -18,28 +18,28 @@ func (r *repo) SaveOwnMetadata(_ context.Context, locality string, publicEndpoin return stacktrace.NewErrorWithCode(dsserr.BadRequest, "Public endpoint not set") } - r.participants[locality] = &participant{ - publicEndpoint: publicEndpoint, - updatedAt: time.Now(), + r.state.Participants[locality] = &participant{ + PublicEndpoint: publicEndpoint, + UpdatedAt: time.Now().UTC(), } return nil } func (r *repo) GetDSSMetadata(_ context.Context) ([]*auxmodels.DSSMetadata, error) { - metadata := make([]*auxmodels.DSSMetadata, 0, len(r.participants)) - for locality, p := range r.participants { - updatedAt := p.updatedAt + metadata := make([]*auxmodels.DSSMetadata, 0, len(r.state.Participants)) + for locality, p := range r.state.Participants { + updatedAt := p.UpdatedAt m := &auxmodels.DSSMetadata{ Locality: locality, - PublicEndpoint: p.publicEndpoint, + PublicEndpoint: p.PublicEndpoint, UpdatedAt: &updatedAt, } // Find the latest heartbeat across all sources for this locality. var latest auxmodels.Heartbeat found := false - for key, hb := range r.heartbeats { - if key.locality != locality { + for key, hb := range r.state.Heartbeats { + if key.Locality != locality { continue } if !found || hb.Timestamp.After(*latest.Timestamp) { @@ -69,7 +69,7 @@ func (r *repo) RecordHeartbeat(_ context.Context, heartbeat auxmodels.Heartbeat) } if heartbeat.Timestamp == nil { - now := time.Now() + now := time.Now().UTC() heartbeat.Timestamp = &now } @@ -77,7 +77,7 @@ func (r *repo) RecordHeartbeat(_ context.Context, heartbeat auxmodels.Heartbeat) return stacktrace.NewErrorWithCode(dsserr.BadRequest, "Cannot expect the timestamp of the next heartbeat before the timestamp of the new heartbeat") } - r.heartbeats[heartbeatKey{locality: heartbeat.Locality, source: heartbeat.Source}] = heartbeat + r.state.Heartbeats[heartbeatKey{Locality: heartbeat.Locality, Source: heartbeat.Source}] = heartbeat return nil } diff --git a/pkg/aux_/store/memstore/snapshot.go b/pkg/aux_/store/memstore/snapshot.go new file mode 100644 index 000000000..15deb274b --- /dev/null +++ b/pkg/aux_/store/memstore/snapshot.go @@ -0,0 +1,35 @@ +package memstore + +import ( + "bytes" + "encoding/gob" + + "github.com/interuss/stacktrace" +) + +const snapshotVersion = 1 + +type snapshotEnvelope struct { + Version int + State state +} + +func (r *repo) GetSnapshot() ([]byte, error) { + var buf bytes.Buffer + if err := gob.NewEncoder(&buf).Encode(snapshotEnvelope{Version: snapshotVersion, State: r.state}); err != nil { + return nil, stacktrace.Propagate(err, "Failed to encode memstore snapshot") + } + return buf.Bytes(), nil +} + +func (r *repo) RestoreFromSnapshot(data []byte) error { + var env snapshotEnvelope + if err := gob.NewDecoder(bytes.NewReader(data)).Decode(&env); err != nil { + return stacktrace.Propagate(err, "Failed to decode memstore snapshot") + } + if env.Version != snapshotVersion { + return stacktrace.NewError("Unsupported memstore snapshot version %d, expected %d", env.Version, snapshotVersion) + } + r.state = env.State + return nil +} diff --git a/pkg/aux_/store/memstore/snapshot_test.go b/pkg/aux_/store/memstore/snapshot_test.go new file mode 100644 index 000000000..2085e2488 --- /dev/null +++ b/pkg/aux_/store/memstore/snapshot_test.go @@ -0,0 +1,59 @@ +package memstore + +import ( + "bytes" + "context" + "encoding/gob" + "testing" + "time" + + auxmodels "github.com/interuss/dss/pkg/aux_/models" + "github.com/stretchr/testify/require" +) + +func TestSnapshotRoundTrip(t *testing.T) { + ctx := context.Background() + src := newRepo() + require.NoError(t, src.SaveOwnMetadata(ctx, "dss-1", "https://example.com")) + ts := time.Now().UTC() + require.NoError(t, src.RecordHeartbeat(ctx, auxmodels.Heartbeat{Locality: "dss-1", Source: "source-1", Timestamp: &ts, Reporter: "uss-1"})) + + data, err := src.GetSnapshot() + require.NoError(t, err) + + dst := newRepo() + require.NoError(t, dst.RestoreFromSnapshot(data)) + + want, err := src.GetDSSMetadata(ctx) + require.NoError(t, err) + got, err := dst.GetDSSMetadata(ctx) + require.NoError(t, err) + require.Equal(t, want, got) +} + +func TestRestoreFromSnapshotReplacesState(t *testing.T) { + ctx := context.Background() + src := newRepo() + require.NoError(t, src.SaveOwnMetadata(ctx, "dss-1", "https://example.com")) + data, err := src.GetSnapshot() + require.NoError(t, err) + + dst := newRepo() + require.NoError(t, dst.SaveOwnMetadata(ctx, "dss-2", "https://other.example.com")) + require.NoError(t, dst.RestoreFromSnapshot(data)) + + md, err := dst.GetDSSMetadata(ctx) + require.NoError(t, err) + require.Len(t, md, 1) + require.Equal(t, "dss-1", md[0].Locality) +} + +func TestRestoreFromSnapshotInvalidData(t *testing.T) { + require.Error(t, newRepo().RestoreFromSnapshot([]byte("random value that is definitely not valid"))) +} + +func TestRestoreFromSnapshotVersionMismatch(t *testing.T) { + var buf bytes.Buffer + require.NoError(t, gob.NewEncoder(&buf).Encode(snapshotEnvelope{Version: snapshotVersion + 1})) + require.Error(t, newRepo().RestoreFromSnapshot(buf.Bytes())) +} diff --git a/pkg/aux_/store/memstore/store.go b/pkg/aux_/store/memstore/store.go index f94a1ffc2..52ed59d12 100644 --- a/pkg/aux_/store/memstore/store.go +++ b/pkg/aux_/store/memstore/store.go @@ -12,6 +12,15 @@ import ( // repo is a full implementation of aux_.repos.Repository for memory-based storage. type repo struct { + state state +} + +// state is the serializable in-memory state. +type state struct { + // Participants holds pool participants metadata, keyed by locality. + Participants map[string]*participant + // Heartbeats holds the latest heartbeat per (locality, source). + Heartbeats map[heartbeatKey]auxmodels.Heartbeat // participants holds pool participants metadata, keyed by locality. participants map[string]*participant // heartbeats holds the latest heartbeat per (locality, source). @@ -19,20 +28,21 @@ type repo struct { } type participant struct { - publicEndpoint string - updatedAt time.Time + PublicEndpoint string + UpdatedAt time.Time } type heartbeatKey struct { - locality string - source string + Locality string + Source string } func newRepo() *repo { return &repo{ - participants: map[string]*participant{}, - heartbeats: map[heartbeatKey]auxmodels.Heartbeat{}, - } + state: state{ + Participants: map[string]*participant{}, + Heartbeats: map[heartbeatKey]auxmodels.Heartbeat{}, + }} } func Init(ctx context.Context, logger *zap.Logger) (*memstore.Store[repos.Repository], error) { diff --git a/pkg/memstore/store.go b/pkg/memstore/store.go index 9087b49de..ecdcc6a99 100644 --- a/pkg/memstore/store.go +++ b/pkg/memstore/store.go @@ -19,6 +19,8 @@ import ( type MemRepo[R any] interface { GetRepo() R + GetSnapshot() ([]byte, error) + RestoreFromSnapshot([]byte) error } type Store[R any] struct { diff --git a/pkg/rid/store/memstore/snapshot.go b/pkg/rid/store/memstore/snapshot.go new file mode 100644 index 000000000..dea64e6a4 --- /dev/null +++ b/pkg/rid/store/memstore/snapshot.go @@ -0,0 +1,13 @@ +package memstore + +import ( + "github.com/interuss/stacktrace" +) + +func (r *repo) GetSnapshot() ([]byte, error) { + return nil, stacktrace.NewError("GetSnapshot not yet implemented for rid") +} + +func (r *repo) RestoreFromSnapshot(data []byte) error { + return stacktrace.NewError("RestoreFromSnapshot not yet implemented for rid") +} diff --git a/pkg/scd/store/memstore/snapshot.go b/pkg/scd/store/memstore/snapshot.go new file mode 100644 index 000000000..dea64e6a4 --- /dev/null +++ b/pkg/scd/store/memstore/snapshot.go @@ -0,0 +1,13 @@ +package memstore + +import ( + "github.com/interuss/stacktrace" +) + +func (r *repo) GetSnapshot() ([]byte, error) { + return nil, stacktrace.NewError("GetSnapshot not yet implemented for rid") +} + +func (r *repo) RestoreFromSnapshot(data []byte) error { + return stacktrace.NewError("RestoreFromSnapshot not yet implemented for rid") +} From 316e965d1b565270882064bbab7eb3456c399c6d Mon Sep 17 00:00:00 2001 From: Maximilien Cuony Date: Tue, 16 Jun 2026 11:06:01 +0200 Subject: [PATCH 7/8] [raft/memstore] Add rid memstore --- go.mod | 5 +- .../memstore/identification_service_area.go | 129 ++++++- .../identification_service_area_test.go | 321 ++++++++++++++++ pkg/rid/store/memstore/snapshot.go | 34 +- pkg/rid/store/memstore/snapshot_test.go | 82 ++++ pkg/rid/store/memstore/store.go | 124 +++++- pkg/rid/store/memstore/store_test.go | 49 +++ pkg/rid/store/memstore/subscriptions.go | 187 ++++++++- pkg/rid/store/memstore/subscriptions_test.go | 360 ++++++++++++++++++ 9 files changed, 1263 insertions(+), 28 deletions(-) create mode 100644 pkg/rid/store/memstore/identification_service_area_test.go create mode 100644 pkg/rid/store/memstore/snapshot_test.go create mode 100644 pkg/rid/store/memstore/store_test.go create mode 100644 pkg/rid/store/memstore/subscriptions_test.go diff --git a/go.mod b/go.mod index a8fa1da50..d0722521b 100644 --- a/go.mod +++ b/go.mod @@ -10,6 +10,7 @@ require ( github.com/go-jose/go-jose/v4 v4.1.4 github.com/golang-jwt/jwt/v4 v4.5.2 github.com/golang/geo v0.0.0-20230421003525-6adc56603217 + github.com/google/go-cmp v0.7.0 github.com/google/uuid v1.6.0 github.com/interuss/stacktrace v1.0.0 github.com/jackc/pgx/v5 v5.9.2 @@ -27,11 +28,13 @@ require ( go.opentelemetry.io/otel v1.43.0 go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.39.0 go.opentelemetry.io/otel/exporters/prometheus v0.65.0 + go.opentelemetry.io/otel/metric v1.43.0 go.opentelemetry.io/otel/sdk v1.43.0 go.opentelemetry.io/otel/sdk/metric v1.43.0 go.opentelemetry.io/otel/trace v1.43.0 go.uber.org/multierr v1.11.0 go.uber.org/zap v1.27.0 + golang.org/x/sync v0.20.0 ) require ( @@ -71,13 +74,11 @@ require ( go.opentelemetry.io/auto/sdk v1.2.1 // indirect go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.61.0 // indirect go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.39.0 // indirect - go.opentelemetry.io/otel/metric v1.43.0 // indirect go.opentelemetry.io/proto/otlp v1.9.0 // indirect go.yaml.in/yaml/v2 v2.4.4 // indirect golang.org/x/crypto v0.49.0 // indirect golang.org/x/net v0.52.0 // indirect golang.org/x/oauth2 v0.34.0 // indirect - golang.org/x/sync v0.20.0 // indirect golang.org/x/sys v0.42.0 // indirect golang.org/x/text v0.36.0 // indirect golang.org/x/time v0.14.0 // indirect diff --git a/pkg/rid/store/memstore/identification_service_area.go b/pkg/rid/store/memstore/identification_service_area.go index 3bd315fb7..dde19d9ff 100644 --- a/pkg/rid/store/memstore/identification_service_area.go +++ b/pkg/rid/store/memstore/identification_service_area.go @@ -11,30 +11,141 @@ import ( "github.com/interuss/stacktrace" ) -func (r *repo) GetISA(_ context.Context, id dssmodels.ID, forUpdate bool) (*ridmodels.IdentificationServiceArea, error) { - return nil, stacktrace.NewErrorWithCode(dsserr.NotImplemented, "GetISA not implemented for memstore") +func isaRecordFromModel(isa *ridmodels.IdentificationServiceArea, updatedAt time.Time) *isaRecord { + return &isaRecord{ + ID: isa.ID, + URL: isa.URL, + Owner: isa.Owner, + Cells: cloneCells(isa.Cells), + StartTime: cloneTime(isa.StartTime), + EndTime: cloneTime(isa.EndTime), + AltitudeHi: cloneFloat32(isa.AltitudeHi), + AltitudeLo: cloneFloat32(isa.AltitudeLo), + Writer: isa.Writer, + UpdatedAt: updatedAt, + } } -func (r *repo) DeleteISA(_ context.Context, isa *ridmodels.IdentificationServiceArea) (*ridmodels.IdentificationServiceArea, error) { - return nil, stacktrace.NewErrorWithCode(dsserr.NotImplemented, "DeleteISA not implemented for memstore") +// toModel rebuilds the ISA model +func (rec *isaRecord) toModel() *ridmodels.IdentificationServiceArea { + return &ridmodels.IdentificationServiceArea{ + ID: rec.ID, + URL: rec.URL, + Owner: rec.Owner, + Cells: cloneCells(rec.Cells), + StartTime: cloneTime(rec.StartTime), + EndTime: cloneTime(rec.EndTime), + Version: dssmodels.VersionFromTime(rec.UpdatedAt), + AltitudeHi: cloneFloat32(rec.AltitudeHi), + AltitudeLo: cloneFloat32(rec.AltitudeLo), + Writer: rec.Writer, + } +} + +func (r *repo) GetISA(_ context.Context, id dssmodels.ID, _ bool) (*ridmodels.IdentificationServiceArea, error) { + rec, ok := r.state.ISAs[id] + if !ok { + return nil, nil + } + return rec.toModel(), nil } func (r *repo) InsertISA(_ context.Context, isa *ridmodels.IdentificationServiceArea) (*ridmodels.IdentificationServiceArea, error) { - return nil, stacktrace.NewErrorWithCode(dsserr.NotImplemented, "InsertISA not implemented for memstore") + if err := validateWriteData(isa.Cells, isa.StartTime, isa.EndTime); err != nil { + return nil, err + } + if _, ok := r.state.ISAs[isa.ID]; ok { + return nil, stacktrace.NewError("ISA with id %s already exists", isa.ID) + } + rec := isaRecordFromModel(isa, r.clock.Now()) + r.state.ISAs[isa.ID] = rec + return rec.toModel(), nil } func (r *repo) UpdateISA(_ context.Context, isa *ridmodels.IdentificationServiceArea) (*ridmodels.IdentificationServiceArea, error) { - return nil, stacktrace.NewErrorWithCode(dsserr.NotImplemented, "UpdateISA not implemented for memstore") + if err := validateWriteData(isa.Cells, isa.StartTime, isa.EndTime); err != nil { + return nil, err + } + prev, ok := r.state.ISAs[isa.ID] + if !ok { + return nil, nil + } + if !dssmodels.VersionFromTime(prev.UpdatedAt).Matches(isa.Version) { + return nil, nil + } + rec := isaRecordFromModel(isa, r.clock.Now()) + rec.Owner = prev.Owner + r.state.ISAs[isa.ID] = rec + return rec.toModel(), nil +} + +func (r *repo) DeleteISA(_ context.Context, isa *ridmodels.IdentificationServiceArea) (*ridmodels.IdentificationServiceArea, error) { + rec, ok := r.state.ISAs[isa.ID] + if !ok { + return nil, nil + } + if !dssmodels.VersionFromTime(rec.UpdatedAt).Matches(isa.Version) { + return nil, nil + } + out := rec.toModel() + delete(r.state.ISAs, isa.ID) + return out, nil } func (r *repo) SearchISAs(_ context.Context, cells s2.CellUnion, earliest *time.Time, latest *time.Time) ([]*ridmodels.IdentificationServiceArea, error) { - return nil, stacktrace.NewErrorWithCode(dsserr.NotImplemented, "SearchISAs not implemented for memstore") + if len(cells) == 0 { + return nil, stacktrace.NewErrorWithCode(dsserr.BadRequest, "Missing cell IDs for query") + } + if earliest == nil { + return nil, stacktrace.NewError("Earliest start time is missing") + } + + want := cellSet(cells) + var out []*ridmodels.IdentificationServiceArea + for _, rec := range r.state.ISAs { + // ends_at >= earliest + if rec.EndTime == nil || rec.EndTime.Before(*earliest) { + continue + } + // COALESCE(starts_at <= latest, true) + if latest != nil && rec.StartTime != nil && rec.StartTime.After(*latest) { + continue + } + if !overlaps(rec.Cells, want) { + continue + } + out = append(out, rec.toModel()) + + if len(out) > dssmodels.MaxResultLimit { // This miminc sqlstore behaviour, but it's not very good. + break + } + } + return out, nil } func (r *repo) ListExpiredISAs(_ context.Context, writer string, threshold time.Time) ([]*ridmodels.IdentificationServiceArea, error) { - return nil, stacktrace.NewErrorWithCode(dsserr.NotImplemented, "ListExpiredISAs not implemented for memstore") + var out []*ridmodels.IdentificationServiceArea + for _, rec := range r.state.ISAs { + // ends_at <= threshold + if rec.EndTime == nil || rec.EndTime.After(threshold) { + continue + } + if writer == "" { + if rec.Writer != "" { + continue + } + } else if rec.Writer != writer { + continue + } + out = append(out, rec.toModel()) + + if len(out) > dssmodels.MaxResultLimit { // This miminc sqlstore behaviour, but it's not very good. + break + } + } + return out, nil } func (r *repo) CountISAs(_ context.Context) (int64, error) { - return 0, stacktrace.NewErrorWithCode(dsserr.NotImplemented, "CountISAs not implemented for memstore") + return int64(len(r.state.ISAs)), nil } diff --git a/pkg/rid/store/memstore/identification_service_area_test.go b/pkg/rid/store/memstore/identification_service_area_test.go new file mode 100644 index 000000000..70699e415 --- /dev/null +++ b/pkg/rid/store/memstore/identification_service_area_test.go @@ -0,0 +1,321 @@ +package memstore + +import ( + "context" + "testing" + "time" + + "github.com/golang/geo/s2" + "github.com/google/uuid" + dssmodels "github.com/interuss/dss/pkg/models" + ridmodels "github.com/interuss/dss/pkg/rid/models" + "github.com/interuss/dss/pkg/rid/repos" + "github.com/jonboulle/clockwork" + "github.com/stretchr/testify/require" +) + +var ( + // Ensure the struct conforms to the interface + _ repos.ISA = &repo{} + overflow = uint64(17106221850767130624) // face 5 L13 overflows + serviceArea = &ridmodels.IdentificationServiceArea{ + ID: dssmodels.ID(uuid.New().String()), + Owner: dssmodels.Owner(uuid.New().String()), + URL: "https://no/place/like/home/for/flights", + StartTime: &startTime, + EndTime: &endTime, + Writer: writer, + Cells: s2.CellUnion{ + s2.CellID(uint64(overflow)), + s2.CellID(17106221850767130624), + }, + } +) + +func TestStoreSearchISAs(t *testing.T) { + ctx := context.Background() + cells := s2.CellUnion{ + s2.CellID(17106221850767130624), + s2.CellID(17106221885126868992), + s2.CellID(17106221919486607360), + s2.CellID(uint64(overflow)), + } + repo := setUpStore(t) + + isa := *serviceArea + isa.Cells = cells + saOut, err := repo.InsertISA(ctx, &isa) + require.NoError(t, err) + require.NotNil(t, saOut) + require.Equal(t, isa.ID, saOut.ID) + + for _, r := range []struct { + name string + cells s2.CellUnion + timestampMutator func(time.Time, time.Time) (*time.Time, *time.Time) + expectedLen int + }{ + { + name: "search for empty cell", + cells: s2.CellUnion{s2.CellID(17106221953846345728)}, + timestampMutator: func(start time.Time, end time.Time) (*time.Time, *time.Time) { + return &start, nil + }, + expectedLen: 0, + }, + { + name: "search for only one cell", + cells: s2.CellUnion{s2.CellID(17106221850767130624)}, + timestampMutator: func(start time.Time, end time.Time) (*time.Time, *time.Time) { + return &start, nil + }, + expectedLen: 1, + }, + { + name: "search for only one cell with high bit set", + cells: s2.CellUnion{s2.CellID(uint64(overflow))}, + timestampMutator: func(start time.Time, end time.Time) (*time.Time, *time.Time) { + return &start, nil + }, + expectedLen: 1, + }, + { + name: "search with nil ends_at", + cells: cells, + timestampMutator: func(start time.Time, end time.Time) (*time.Time, *time.Time) { + return &start, nil + }, + expectedLen: 1, + }, + { + name: "search with exact timestamps", + cells: cells, + timestampMutator: func(start time.Time, end time.Time) (*time.Time, *time.Time) { + return &start, &end + }, + expectedLen: 1, + }, + { + name: "search with non-matching time span", + cells: cells, + timestampMutator: func(start time.Time, end time.Time) (*time.Time, *time.Time) { + var ( + offset = time.Duration(100 * time.Second) + earliest = end.Add(offset) + latest = end.Add(offset * 2) + ) + + return &earliest, &latest + }, + expectedLen: 0, + }, + { + name: "search with expanded time span", + cells: cells, + timestampMutator: func(start time.Time, end time.Time) (*time.Time, *time.Time) { + var ( + offset = time.Duration(100 * time.Second) + earliest = start.Add(-offset) + latest = end.Add(offset) + ) + + return &earliest, &latest + }, + expectedLen: 1, + }, + } { + t.Run(r.name, func(t *testing.T) { + earliest, latest := r.timestampMutator(*saOut.StartTime, *saOut.EndTime) + + serviceAreas, err := repo.SearchISAs(ctx, r.cells, earliest, latest) + require.NoError(t, err) + require.Len(t, serviceAreas, r.expectedLen) + }) + } +} + +func TestBadVersion(t *testing.T) { + ctx := context.Background() + repo := setUpStore(t) + + saOut1, err := repo.InsertISA(ctx, serviceArea) + require.NoError(t, err) + require.NotNil(t, saOut1) + + // Rewriting service area should fail + saOut2, err := repo.UpdateISA(ctx, serviceArea) + require.NoError(t, err) + require.Nil(t, saOut2) + + // Rewriting, but with the correct version should work. + newEndTime := saOut1.EndTime.Add(time.Minute) + saOut1.EndTime = &newEndTime + saOut3, err := repo.UpdateISA(ctx, saOut1) + require.NoError(t, err) + require.NotNil(t, saOut3) +} + +func TestStoreExpiredISA(t *testing.T) { + ctx := context.Background() + repo := setUpStore(t) + + saOut, err := repo.InsertISA(ctx, serviceArea) + require.NoError(t, err) + require.NotNil(t, saOut) + + // The ISA's endTime is one hour from now. + fakeClock.Advance(59 * time.Minute) + + // We should still be able to find the ISA by searching and by ID. + now := fakeClock.Now() + serviceAreas, err := repo.SearchISAs(ctx, serviceArea.Cells, &now, nil) + require.NoError(t, err) + require.Len(t, serviceAreas, 1) + + ret, err := repo.GetISA(ctx, serviceArea.ID, false) + require.NoError(t, err) + require.NotNil(t, ret) + + // But now the ISA has expired. + fakeClock.Advance(2 * time.Minute) + now = fakeClock.Now() + + serviceAreas, err = repo.SearchISAs(ctx, serviceArea.Cells, &now, nil) + require.NoError(t, err) + require.Len(t, serviceAreas, 0) + + // A get should work even if it is expired. + ret, err = repo.GetISA(ctx, serviceArea.ID, false) + require.NoError(t, err) + require.NotNil(t, ret) +} + +func TestStoreDeleteISAs(t *testing.T) { + ctx := context.Background() + repo := setUpStore(t) + + // Insert the ISA. + copy := *serviceArea + isa, err := repo.InsertISA(ctx, ©) + require.NoError(t, err) + require.NotNil(t, isa) + + // Delete the ISA. + // Ensure a fresh Get, then delete still updates the sub indexes + isa, err = repo.GetISA(ctx, isa.ID, false) + require.NoError(t, err) + + serviceAreaOut, err := repo.DeleteISA(ctx, isa) + require.NoError(t, err) + require.Equal(t, isa, serviceAreaOut) +} + +func TestStoreISAWithNoGeoData(t *testing.T) { + ctx := context.Background() + repo := setUpStore(t) + + endTime := fakeClock.Now().Add(24 * time.Hour) + sub := &ridmodels.IdentificationServiceArea{ + ID: dssmodels.ID(uuid.New().String()), + Owner: dssmodels.Owner("original owner"), + EndTime: &endTime, + } + _, err := repo.InsertISA(ctx, sub) + require.Error(t, err) +} + +func TestListExpiredISAs(t *testing.T) { + ctx := context.Background() + repo := setUpStore(t) + + fakeClock := clockwork.NewFakeClockAt(time.Now()) + + // Insert ISA with endtime 1 day from now + isa1 := *serviceArea + startTime := fakeClock.Now() + isa1.StartTime = &startTime + endTime := fakeClock.Now().Add(24 * time.Hour) + isa1.EndTime = &endTime + saOut1, err := repo.InsertISA(ctx, &isa1) + require.NoError(t, err) + require.NotNil(t, saOut1) + + // Insert ISA with endtime to 30 minutes ago + isa2 := *serviceArea + startTime = fakeClock.Now().Add(-1 * time.Hour) + isa2.StartTime = &startTime + endTime = fakeClock.Now().Add(-30 * time.Minute) + isa2.EndTime = &endTime + isa2.ID = dssmodels.ID(uuid.New().String()) + saOut2, err := repo.InsertISA(ctx, &isa2) + require.NoError(t, err) + require.NotNil(t, saOut2) + + serviceAreas, err := repo.ListExpiredISAs(ctx, writer, fakeClock.Now().Add(-30*time.Minute)) + require.NoError(t, err) + require.Len(t, serviceAreas, 1) +} + +func TestListExpiredISAsWithEmptyWriter(t *testing.T) { + ctx := context.Background() + repo := setUpStore(t) + + fakeClock := clockwork.NewFakeClockAt(time.Now()) + + // Insert ISA with endtime 1 day from now + isa1 := *serviceArea + startTime := fakeClock.Now() + isa1.StartTime = &startTime + endTime := fakeClock.Now().Add(24 * time.Hour) + isa1.EndTime = &endTime + isa1.Writer = "" + saOut1, err := repo.InsertISA(ctx, &isa1) + require.NoError(t, err) + require.NotNil(t, saOut1) + + // Insert ISA with endtime to 30 minutes ago + isa2 := *serviceArea + startTime = fakeClock.Now().Add(-1 * time.Hour) + isa2.StartTime = &startTime + endTime = fakeClock.Now().Add(-30 * time.Minute) + isa2.EndTime = &endTime + isa2.ID = dssmodels.ID(uuid.New().String()) + isa2.Writer = "" + saOut2, err := repo.InsertISA(ctx, &isa2) + require.NoError(t, err) + require.NotNil(t, saOut2) + + serviceAreas, err := repo.ListExpiredISAs(ctx, "", fakeClock.Now().Add(-30*time.Minute)) + require.NoError(t, err) + require.Len(t, serviceAreas, 1) +} + +func TestStoreCountISAs(t *testing.T) { + ctx := context.Background() + repo := setUpStore(t) + + // Insert the ISA. + copy := *serviceArea + isa, err := repo.InsertISA(ctx, ©) + require.NoError(t, err) + require.NotNil(t, isa) + + //Cound should be one + count, err := repo.CountISAs(ctx) + require.NoError(t, err) + require.Equal(t, count, int64(1)) + + // Delete the ISA. + // Ensure a fresh Get, then delete still updates the sub indexes + isa, err = repo.GetISA(ctx, isa.ID, false) + require.NoError(t, err) + + serviceAreaOut, err := repo.DeleteISA(ctx, isa) + require.NoError(t, err) + require.Equal(t, isa, serviceAreaOut) + + //Cound should be zero + count, err = repo.CountISAs(ctx) + require.NoError(t, err) + require.Equal(t, count, int64(0)) +} diff --git a/pkg/rid/store/memstore/snapshot.go b/pkg/rid/store/memstore/snapshot.go index dea64e6a4..fe2aa2c70 100644 --- a/pkg/rid/store/memstore/snapshot.go +++ b/pkg/rid/store/memstore/snapshot.go @@ -1,13 +1,43 @@ package memstore import ( + "bytes" + "encoding/gob" + + dssmodels "github.com/interuss/dss/pkg/models" "github.com/interuss/stacktrace" ) +const snapshotVersion = 1 + +type snapshotEnvelope struct { + Version int + State state +} + func (r *repo) GetSnapshot() ([]byte, error) { - return nil, stacktrace.NewError("GetSnapshot not yet implemented for rid") + var buf bytes.Buffer + if err := gob.NewEncoder(&buf).Encode(snapshotEnvelope{Version: snapshotVersion, State: r.state}); err != nil { + return nil, stacktrace.Propagate(err, "Failed to encode memstore snapshot") + } + return buf.Bytes(), nil } func (r *repo) RestoreFromSnapshot(data []byte) error { - return stacktrace.NewError("RestoreFromSnapshot not yet implemented for rid") + var env snapshotEnvelope + if err := gob.NewDecoder(bytes.NewReader(data)).Decode(&env); err != nil { + return stacktrace.Propagate(err, "Failed to decode memstore snapshot") + } + if env.Version != snapshotVersion { + return stacktrace.NewError("Unsupported memstore snapshot version %d, expected %d", env.Version, snapshotVersion) + } + r.state = env.State + // gob decodes an empty map as nil; re-initialize to keep the repo writable. + if r.state.ISAs == nil { + r.state.ISAs = map[dssmodels.ID]*isaRecord{} + } + if r.state.Subscriptions == nil { + r.state.Subscriptions = map[dssmodels.ID]*subscriptionRecord{} + } + return nil } diff --git a/pkg/rid/store/memstore/snapshot_test.go b/pkg/rid/store/memstore/snapshot_test.go new file mode 100644 index 000000000..a9764f812 --- /dev/null +++ b/pkg/rid/store/memstore/snapshot_test.go @@ -0,0 +1,82 @@ +package memstore + +import ( + "bytes" + "context" + "encoding/gob" + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "github.com/interuss/dss/pkg/models" + "github.com/stretchr/testify/require" +) + +func TestSnapshotRoundTrip(t *testing.T) { + ctx := context.Background() + src := setUpStore(t) + _, err := src.InsertISA(ctx, serviceArea) + require.NoError(t, err) + _, err = src.InsertSubscription(ctx, subscriptionsPool[0].input) + require.NoError(t, err) + + data, err := src.GetSnapshot() + require.NoError(t, err) + + dst := setUpStore(t) + require.NoError(t, dst.RestoreFromSnapshot(data)) + + wantISA, err := src.GetISA(ctx, serviceArea.ID, false) + require.NoError(t, err) + gotISA, err := dst.GetISA(ctx, serviceArea.ID, false) + require.NoError(t, err) + + if diff := cmp.Diff(wantISA, gotISA, cmpopts.EquateApproxTime(0), cmp.AllowUnexported(models.Version{})); diff != "" { + t.Errorf("IdentificationServiceArea mismatch (-want +got):\n%s", diff) + } + + wantSub, err := src.GetSubscription(ctx, subscriptionsPool[0].input.ID) + require.NoError(t, err) + gotSub, err := dst.GetSubscription(ctx, subscriptionsPool[0].input.ID) + require.NoError(t, err) + + if diff := cmp.Diff(wantSub, gotSub, cmpopts.EquateApproxTime(0), cmp.AllowUnexported(models.Version{})); diff != "" { + t.Errorf("Subscription mismatch (-want +got):\n%s", diff) + } +} + +func TestRestoreFromSnapshotReplacesState(t *testing.T) { + ctx := context.Background() + src := setUpStore(t) + _, err := src.InsertISA(ctx, serviceArea) + require.NoError(t, err) + data, err := src.GetSnapshot() + require.NoError(t, err) + + dst := setUpStore(t) + other := *serviceArea + other.ID = "00000000-0000-4000-8000-000000000002" + _, err = dst.InsertISA(ctx, &other) + require.NoError(t, err) + require.NoError(t, dst.RestoreFromSnapshot(data)) + + count, err := dst.CountISAs(ctx) + require.NoError(t, err) + require.Equal(t, int64(1), count) + got, err := dst.GetISA(ctx, serviceArea.ID, false) + require.NoError(t, err) + require.NotNil(t, got) + gone, err := dst.GetISA(ctx, other.ID, false) + require.NoError(t, err) + require.Nil(t, gone) +} + +func TestRestoreFromSnapshotInvalidData(t *testing.T) { + require.Error(t, setUpStore(t).RestoreFromSnapshot([]byte("random value that is definitely not valid"))) +} + +func TestRestoreFromSnapshotVersionMismatch(t *testing.T) { + var buf bytes.Buffer + require.NoError(t, gob.NewEncoder(&buf).Encode(snapshotEnvelope{Version: snapshotVersion + 1})) + require.Error(t, setUpStore(t).RestoreFromSnapshot(buf.Bytes())) +} diff --git a/pkg/rid/store/memstore/store.go b/pkg/rid/store/memstore/store.go index b50c2b169..824ea4901 100644 --- a/pkg/rid/store/memstore/store.go +++ b/pkg/rid/store/memstore/store.go @@ -2,17 +2,137 @@ package memstore import ( "context" + "time" + "github.com/golang/geo/s2" + "github.com/interuss/dss/pkg/geo" "github.com/interuss/dss/pkg/memstore" + dssmodels "github.com/interuss/dss/pkg/models" "github.com/interuss/dss/pkg/rid/repos" + "github.com/interuss/stacktrace" + "github.com/jonboulle/clockwork" "go.uber.org/zap" ) // repo is a full implementation of rid.repos.Repository for memory-based storage. -type repo struct{} +type repo struct { + state state + clock clockwork.Clock +} + +// state is the serializable in-memory state. +type state struct { + // ISAs holds the stored ISAs keyed by ID. + ISAs map[dssmodels.ID]*isaRecord + // Subscriptions holds the stored subscriptions keyed by ID. + Subscriptions map[dssmodels.ID]*subscriptionRecord +} + +// isaRecord is the gob-serializable representation of an ISA. It intentionally +// stores only primitive fields: the model's Version is never persisted, it is +// derived from UpdatedAt on read. +type isaRecord struct { + ID dssmodels.ID + URL string + Owner dssmodels.Owner + Cells s2.CellUnion + StartTime *time.Time + EndTime *time.Time + AltitudeHi *float32 + AltitudeLo *float32 + Writer string + UpdatedAt time.Time +} + +// subscriptionRecord is the gob-serializable representation of a Subscription. +type subscriptionRecord struct { + ID dssmodels.ID + URL string + NotificationIndex int + Owner dssmodels.Owner + Cells s2.CellUnion + StartTime *time.Time + EndTime *time.Time + AltitudeHi *float32 + AltitudeLo *float32 + Writer string + UpdatedAt time.Time +} + +func newRepo() *repo { + r := &repo{clock: clockwork.NewRealClock()} + r.resetState() + return r +} + +func (r *repo) resetState() { + r.state = state{ + ISAs: map[dssmodels.ID]*isaRecord{}, + Subscriptions: map[dssmodels.ID]*subscriptionRecord{}, + } +} func Init(ctx context.Context, logger *zap.Logger) (*memstore.Store[repos.Repository], error) { - return memstore.Init(ctx, logger, "rid", &repo{}) + return memstore.Init(ctx, logger, "rid", newRepo()) } func (r *repo) GetRepo() repos.Repository { return r } + +// validateWriteData validate constraints on an ISA +func validateWriteData(cells s2.CellUnion, start, end *time.Time) error { + if len(cells) == 0 { + return stacktrace.NewError("At least one cell must be provided") + } + for _, c := range cells { + if err := geo.ValidateCell(c); err != nil { + return stacktrace.Propagate(err, "Error validating cell") + } + } + if start != nil && end != nil && !start.Before(*end) { + return stacktrace.NewError("Start time must be strictly before end time") + } + return nil +} + +// cellSet builds a lookup set from a cell union. +func cellSet(cells s2.CellUnion) map[s2.CellID]struct{} { + set := make(map[s2.CellID]struct{}, len(cells)) + for _, c := range cells { + set[c] = struct{}{} + } + return set +} + +// overlaps reports whether any cell is present in set (equivalent to the SQL +// "cells && $x" array-overlap operator). +func overlaps(cells s2.CellUnion, set map[s2.CellID]struct{}) bool { + for _, c := range cells { + if _, ok := set[c]; ok { + return true + } + } + return false +} + +func cloneCells(cells s2.CellUnion) s2.CellUnion { + if cells == nil { + return nil + } + return append(s2.CellUnion(nil), cells...) +} + +func cloneTime(t *time.Time) *time.Time { + if t == nil { + return nil + } + v := *t + return &v +} + +func cloneFloat32(f *float32) *float32 { + if f == nil { + return nil + } + v := *f + return &v +} diff --git a/pkg/rid/store/memstore/store_test.go b/pkg/rid/store/memstore/store_test.go new file mode 100644 index 000000000..05a78d0c6 --- /dev/null +++ b/pkg/rid/store/memstore/store_test.go @@ -0,0 +1,49 @@ +package memstore + +import ( + "context" + "testing" + "time" + + "github.com/google/uuid" + dssmodels "github.com/interuss/dss/pkg/models" + ridmodels "github.com/interuss/dss/pkg/rid/models" + "github.com/jonboulle/clockwork" + "github.com/stretchr/testify/require" +) + +var ( + fakeClock = clockwork.NewFakeClock() + startTime = fakeClock.Now().UTC().Add(-time.Minute) + endTime = fakeClock.Now().UTC().Add(time.Hour) + writer = "writer" +) + +// setUpStore returns a fresh in-memory repo whose clock is the (reset) package +// fakeClock, so tests can advance time deterministically. +func setUpStore(t *testing.T) *repo { + t.Helper() + fakeClock = clockwork.NewFakeClock() + r := newRepo() + r.clock = fakeClock + return r +} + +func TestDatabaseEnsuresBeginsBeforeExpires(t *testing.T) { + ctx := context.Background() + repo := setUpStore(t) + + var ( + begins = time.Now().UTC() + expires = begins.Add(-5 * time.Minute) + ) + _, err := repo.InsertSubscription(ctx, &ridmodels.Subscription{ + ID: dssmodels.ID(uuid.New().String()), + Owner: "me-myself-and-i", + URL: "https://no/place/like/home", + NotificationIndex: 42, + StartTime: &begins, + EndTime: &expires, + }) + require.Error(t, err) +} diff --git a/pkg/rid/store/memstore/subscriptions.go b/pkg/rid/store/memstore/subscriptions.go index ac744ab10..addf17e9a 100644 --- a/pkg/rid/store/memstore/subscriptions.go +++ b/pkg/rid/store/memstore/subscriptions.go @@ -11,42 +11,203 @@ import ( "github.com/interuss/stacktrace" ) +func subRecordFromModel(s *ridmodels.Subscription, updatedAt time.Time) *subscriptionRecord { + return &subscriptionRecord{ + ID: s.ID, + URL: s.URL, + NotificationIndex: s.NotificationIndex, + Owner: s.Owner, + Cells: cloneCells(s.Cells), + StartTime: cloneTime(s.StartTime), + EndTime: cloneTime(s.EndTime), + AltitudeHi: cloneFloat32(s.AltitudeHi), + AltitudeLo: cloneFloat32(s.AltitudeLo), + Writer: s.Writer, + UpdatedAt: updatedAt, + } +} + +func (rec *subscriptionRecord) toModel() *ridmodels.Subscription { + return &ridmodels.Subscription{ + ID: rec.ID, + URL: rec.URL, + NotificationIndex: rec.NotificationIndex, + Owner: rec.Owner, + Cells: cloneCells(rec.Cells), + StartTime: cloneTime(rec.StartTime), + EndTime: cloneTime(rec.EndTime), + Version: dssmodels.VersionFromTime(rec.UpdatedAt), + AltitudeHi: cloneFloat32(rec.AltitudeHi), + AltitudeLo: cloneFloat32(rec.AltitudeLo), + Writer: rec.Writer, + } +} + func (r *repo) GetSubscription(_ context.Context, id dssmodels.ID) (*ridmodels.Subscription, error) { - return nil, stacktrace.NewErrorWithCode(dsserr.NotImplemented, "GetSubscription not implemented for memstore") + rec, ok := r.state.Subscriptions[id] + if !ok { + return nil, nil + } + return rec.toModel(), nil } -func (r *repo) DeleteSubscription(_ context.Context, sub *ridmodels.Subscription) (*ridmodels.Subscription, error) { - return nil, stacktrace.NewErrorWithCode(dsserr.NotImplemented, "DeleteSubscription not implemented for memstore") +func (r *repo) InsertSubscription(_ context.Context, s *ridmodels.Subscription) (*ridmodels.Subscription, error) { + if err := validateWriteData(s.Cells, s.StartTime, s.EndTime); err != nil { + return nil, err + } + if _, ok := r.state.Subscriptions[s.ID]; ok { + return nil, stacktrace.NewError("Subscription with id %s already exists", s.ID) + } + rec := subRecordFromModel(s, r.clock.Now()) + r.state.Subscriptions[s.ID] = rec + return rec.toModel(), nil } -func (r *repo) InsertSubscription(_ context.Context, sub *ridmodels.Subscription) (*ridmodels.Subscription, error) { - return nil, stacktrace.NewErrorWithCode(dsserr.NotImplemented, "InsertSubscription not implemented for memstore") +func (r *repo) UpdateSubscription(_ context.Context, s *ridmodels.Subscription) (*ridmodels.Subscription, error) { + if err := validateWriteData(s.Cells, s.StartTime, s.EndTime); err != nil { + return nil, err + } + prev, ok := r.state.Subscriptions[s.ID] + if !ok { + return nil, nil + } + if !dssmodels.VersionFromTime(prev.UpdatedAt).Matches(s.Version) { + return nil, nil + } + rec := subRecordFromModel(s, r.clock.Now()) + rec.Owner = prev.Owner + r.state.Subscriptions[s.ID] = rec + return rec.toModel(), nil } -func (r *repo) UpdateSubscription(_ context.Context, sub *ridmodels.Subscription) (*ridmodels.Subscription, error) { - return nil, stacktrace.NewErrorWithCode(dsserr.NotImplemented, "UpdateSubscription not implemented for memstore") +func (r *repo) DeleteSubscription(_ context.Context, s *ridmodels.Subscription) (*ridmodels.Subscription, error) { + rec, ok := r.state.Subscriptions[s.ID] + if !ok { + return nil, nil + } + if !dssmodels.VersionFromTime(rec.UpdatedAt).Matches(s.Version) { + return nil, nil + } + out := rec.toModel() + delete(r.state.Subscriptions, s.ID) + return out, nil } func (r *repo) SearchSubscriptions(_ context.Context, cells s2.CellUnion) ([]*ridmodels.Subscription, error) { - return nil, stacktrace.NewErrorWithCode(dsserr.NotImplemented, "SearchSubscriptions not implemented for memstore") + if len(cells) == 0 { + return nil, stacktrace.NewErrorWithCode(dsserr.BadRequest, "no location provided") + } + now := r.clock.Now() + want := cellSet(cells) + var out []*ridmodels.Subscription + for _, rec := range r.state.Subscriptions { + if rec.EndTime == nil || rec.EndTime.Before(now) { + continue + } + if !overlaps(rec.Cells, want) { + continue + } + out = append(out, rec.toModel()) + + if len(out) > dssmodels.MaxResultLimit { // This miminc sqlstore behaviour, but it's not very good. + break + } + } + return out, nil } func (r *repo) SearchSubscriptionsByOwner(_ context.Context, cells s2.CellUnion, owner dssmodels.Owner) ([]*ridmodels.Subscription, error) { - return nil, stacktrace.NewErrorWithCode(dsserr.NotImplemented, "SearchSubscriptionsByOwner not implemented for memstore") + if len(cells) == 0 { + return nil, stacktrace.NewErrorWithCode(dsserr.BadRequest, "no location provided") + } + now := r.clock.Now() + want := cellSet(cells) + var out []*ridmodels.Subscription + for _, rec := range r.state.Subscriptions { + if rec.Owner != owner { + continue + } + if rec.EndTime == nil || rec.EndTime.Before(now) { + continue + } + if !overlaps(rec.Cells, want) { + continue + } + out = append(out, rec.toModel()) + + if len(out) > dssmodels.MaxResultLimit { // This miminc sqlstore behaviour, but it's not very good. + break + } + } + return out, nil } +// UpdateNotificationIdxsInCells increments the notification index for each +// subscription in the given cells. func (r *repo) UpdateNotificationIdxsInCells(_ context.Context, cells s2.CellUnion) ([]*ridmodels.Subscription, error) { - return nil, stacktrace.NewErrorWithCode(dsserr.NotImplemented, "UpdateNotificationIdxsInCells not implemented for memstore") + now := r.clock.Now() + want := cellSet(cells) + var out []*ridmodels.Subscription + for _, rec := range r.state.Subscriptions { + if rec.EndTime == nil || rec.EndTime.Before(now) { + continue + } + if !overlaps(rec.Cells, want) { + continue + } + rec.NotificationIndex++ + out = append(out, rec.toModel()) + } + return out, nil } func (r *repo) MaxSubscriptionCountInCellsByOwner(_ context.Context, cells s2.CellUnion, owner dssmodels.Owner) (int, error) { - return 0, stacktrace.NewErrorWithCode(dsserr.NotImplemented, "MaxSubscriptionCountInCellsByOwner not implemented for memstore") + now := r.clock.Now() + want := cellSet(cells) + counts := make(map[s2.CellID]int, len(cells)) + for _, rec := range r.state.Subscriptions { + if rec.Owner != owner { + continue + } + if rec.EndTime == nil || rec.EndTime.Before(now) { + continue + } + for _, c := range rec.Cells { + if _, ok := want[c]; ok { + counts[c]++ + } + } + } + best := 0 + for _, n := range counts { + if n > best { + best = n + } + } + return best, nil } func (r *repo) ListExpiredSubscriptions(_ context.Context, writer string, threshold time.Time) ([]*ridmodels.Subscription, error) { - return nil, stacktrace.NewErrorWithCode(dsserr.NotImplemented, "ListExpiredSubscriptions not implemented for memstore") + var out []*ridmodels.Subscription + for _, rec := range r.state.Subscriptions { + // ends_at <= threshold + if rec.EndTime == nil || rec.EndTime.After(threshold) { + continue + } + if writer == "" { + if rec.Writer != "" { + continue + } + } else if rec.Writer != writer { + continue + } + out = append(out, rec.toModel()) + + // TODO: This miminc sqlstore inconsistency of not limiting results there, comparted to ISAs. Should it be normalized? + } + return out, nil } func (r *repo) CountSubscriptions(_ context.Context) (int64, error) { - return 0, stacktrace.NewErrorWithCode(dsserr.NotImplemented, "CountSubscriptions not implemented for memstore") + return int64(len(r.state.Subscriptions)), nil } diff --git a/pkg/rid/store/memstore/subscriptions_test.go b/pkg/rid/store/memstore/subscriptions_test.go new file mode 100644 index 000000000..6f7832379 --- /dev/null +++ b/pkg/rid/store/memstore/subscriptions_test.go @@ -0,0 +1,360 @@ +package memstore + +import ( + "context" + "testing" + "time" + + "github.com/golang/geo/s2" + "github.com/google/uuid" + dssmodels "github.com/interuss/dss/pkg/models" + ridmodels "github.com/interuss/dss/pkg/rid/models" + "github.com/interuss/dss/pkg/rid/repos" + "github.com/jonboulle/clockwork" + "github.com/stretchr/testify/require" +) + +var ( + // Ensure the struct conforms to the interface + _ repos.Subscription = &repo{} + subscriptionsPool = []struct { + name string + input *ridmodels.Subscription + }{ + { + name: "a subscription with startTime and endTime", + input: &ridmodels.Subscription{ + ID: dssmodels.ID(uuid.New().String()), + Owner: "myself", + URL: "https://no/place/like/home", + StartTime: &startTime, + EndTime: &endTime, + NotificationIndex: 42, + Writer: writer, + Cells: s2.CellUnion{ + s2.CellID(uint64(overflow)), + 12494535935418957824, + }, + }, + }, + { + name: "a subscription without startTime and with endTime", + input: &ridmodels.Subscription{ + ID: dssmodels.ID(uuid.New().String()), + Owner: "myself", + URL: "https://no/place/like/home", + EndTime: &endTime, + NotificationIndex: 42, + Cells: s2.CellUnion{ + 12494535935418957824, + }, + }, + }, + { + name: "a subscription without startTime and with endTime", + input: &ridmodels.Subscription{ + ID: dssmodels.ID(uuid.New().String()), + Owner: "me", + URL: "https://no/place/like/home", + StartTime: &startTime, + EndTime: &endTime, + NotificationIndex: 42, + Cells: s2.CellUnion{ + 12494535935418957824, + }, + }, + }, + } +) + +func TestStoreGetSubscription(t *testing.T) { + ctx := context.Background() + repo := setUpStore(t) + + for _, r := range subscriptionsPool { + t.Run(r.name, func(t *testing.T) { + sub1, err := repo.InsertSubscription(ctx, r.input) + require.NoError(t, err) + require.NotNil(t, sub1) + + sub2, err := repo.GetSubscription(ctx, sub1.ID) + require.NoError(t, err) + require.NotNil(t, sub2) + + require.Equal(t, *sub1, *sub2) + }) + } +} + +func TestStoreInsertSubscription(t *testing.T) { + ctx := context.Background() + repo := setUpStore(t) + + for _, r := range subscriptionsPool { + t.Run(r.name, func(t *testing.T) { + sub1, err := repo.InsertSubscription(ctx, r.input) + require.NoError(t, err) + require.NotNil(t, sub1) + + // Test changes without the version differing. + r2 := *sub1 + r2.URL = "new url" + sub2, err := repo.UpdateSubscription(ctx, &r2) + require.NoError(t, err) + require.NotNil(t, sub2) + require.Equal(t, "new url", sub2.URL) + + // Test it doesn't work when Version is nil. + r3 := *sub2 + r3.URL = "new url 2" + r3.Version = nil + sub3, err := repo.UpdateSubscription(ctx, &r3) + require.NoError(t, err) + require.Nil(t, sub3) + + // Bad version doesn't work. + r4 := *sub2 + r4.URL = "new url 3" + r4.Version = dssmodels.VersionFromTime(time.Now()) + sub4, err := repo.UpdateSubscription(ctx, &r4) + require.NoError(t, err) + require.Nil(t, sub4) + + sub5, err := repo.GetSubscription(ctx, sub1.ID) + require.NoError(t, err) + require.NotNil(t, sub5) + + require.Equal(t, *sub2, *sub5) + }) + } +} + +func TestStoreDeleteSubscription(t *testing.T) { + ctx := context.Background() + repo := setUpStore(t) + + for _, r := range subscriptionsPool { + t.Run(r.name, func(t *testing.T) { + sub1, err := repo.InsertSubscription(ctx, r.input) + require.NoError(t, err) + require.NotNil(t, sub1) + + // Ensure mismatched versions returns nothing + sub1BadVersion := *sub1 + sub1BadVersion.Version, err = dssmodels.VersionFromString("a3cg3tcuhk00") + require.NoError(t, err) + sub2, err := repo.DeleteSubscription(ctx, &sub1BadVersion) + require.NoError(t, err) + require.Nil(t, sub2) + + sub4, err := repo.DeleteSubscription(ctx, sub1) + require.NoError(t, err) + require.NotNil(t, sub4) + + require.Equal(t, *sub1, *sub4) + }) + } +} + +func TestStoreSearchSubscription(t *testing.T) { + ctx := context.Background() + repo := setUpStore(t) + + var ( + // pick an L13 value that overflows. + overflow = uint64(17106221850767130624) + + cells = s2.CellUnion{ + s2.CellID(12494535935418957824), + s2.CellID(12494535866699481088), + s2.CellID(12494535901059219456), + s2.CellID(12494535866699481088), + s2.CellID(overflow), + } + owners = []dssmodels.Owner{ + "me", + "my", + "self", + } + ) + + for i, r := range subscriptionsPool { + subscription := *r.input + subscription.Owner = owners[i] + subscription.Cells = cells[:i+1] + sub1, err := repo.InsertSubscription(ctx, &subscription) + require.NoError(t, err) + require.NotNil(t, sub1) + } + // Test normal search + found, err := repo.SearchSubscriptions(ctx, cells) + require.NoError(t, err) + require.Len(t, found, 3) + for _, owner := range owners { + found, err := repo.SearchSubscriptionsByOwner(ctx, cells, owner) + require.NoError(t, err) + require.NotNil(t, found) + // We insert one subscription per owner. Hence, no matter how many cells are touched by the subscription, + // the result should always be 1. + require.Len(t, found, 1) + } +} + +func TestStoreExpiredSubscription(t *testing.T) { + ctx := context.Background() + repo := setUpStore(t) + + endTime := fakeClock.Now().Add(24 * time.Hour) + sub := &ridmodels.Subscription{ + ID: dssmodels.ID(uuid.New().String()), + Owner: dssmodels.Owner("original owner"), + Cells: s2.CellUnion{s2.CellID(12494535866699481088)}, + EndTime: &endTime, + } + _, err := repo.InsertSubscription(ctx, sub) + require.NoError(t, err) + + // The subscription's endTime is 24 hours from now. + fakeClock.Advance(23 * time.Hour) + + // We should still be able to find the subscription by searching and by ID. + subs, err := repo.SearchSubscriptionsByOwner(ctx, sub.Cells, "original owner") + require.NoError(t, err) + require.Len(t, subs, 1) + + ret, err := repo.GetSubscription(ctx, sub.ID) + require.NoError(t, err) + require.NotNil(t, &ret) + + // But now the subscription has expired. + fakeClock.Advance(2 * time.Hour) + + subs, err = repo.SearchSubscriptionsByOwner(ctx, sub.Cells, "original owner") + require.NoError(t, err) + require.Len(t, subs, 0) + + ret, err = repo.GetSubscription(ctx, sub.ID) + require.NotNil(t, ret) + require.NoError(t, err) +} + +func TestStoreSubscriptionWithNoGeoData(t *testing.T) { + ctx := context.Background() + repo := setUpStore(t) + + endTime := fakeClock.Now().Add(24 * time.Hour) + sub := &ridmodels.Subscription{ + ID: dssmodels.ID(uuid.New().String()), + Owner: dssmodels.Owner("original owner"), + EndTime: &endTime, + } + _, err := repo.InsertSubscription(ctx, sub) + require.Error(t, err) +} + +func TestMaxSubscriptionCountInCellsByOwner(t *testing.T) { + ctx := context.Background() + repo := setUpStore(t) + + for _, s := range subscriptionsPool { + _, err := repo.InsertSubscription(ctx, s.input) + require.NoError(t, err) + } + + count, err := repo.MaxSubscriptionCountInCellsByOwner(ctx, s2.CellUnion{12494535935418957824}, "myself") + require.NoError(t, err) + require.Equal(t, 2, count) +} + +func TestListExpiredSubscriptions(t *testing.T) { + ctx := context.Background() + repo := setUpStore(t) + + fakeClock := clockwork.NewFakeClockAt(time.Now()) + + // Insert Subscription with endtime 1 day from now + subscripiton1 := *subscriptionsPool[0].input + startTime := fakeClock.Now() + subscripiton1.StartTime = &startTime + endTime := fakeClock.Now().Add(24 * time.Hour) + subscripiton1.EndTime = &endTime + subOut1, err := repo.InsertSubscription(ctx, &subscripiton1) + require.NoError(t, err) + require.NotNil(t, subOut1) + + // Insert Subscription with endtime to 30 minutes ago + subscripiton2 := *subscriptionsPool[0].input + startTime = fakeClock.Now().Add(-1 * time.Hour) + subscripiton2.StartTime = &startTime + endTime = fakeClock.Now().Add(-30 * time.Minute) + subscripiton2.EndTime = &endTime + subscripiton2.ID = dssmodels.ID(uuid.New().String()) + subOut2, err := repo.InsertSubscription(ctx, &subscripiton2) + require.NoError(t, err) + require.NotNil(t, subOut2) + + subscriptions, err := repo.ListExpiredSubscriptions(ctx, writer, fakeClock.Now().Add(-30*time.Minute)) + require.NoError(t, err) + require.Len(t, subscriptions, 1) +} + +func TestListExpiredSubscriptionsWithEmptyWriter(t *testing.T) { + ctx := context.Background() + repo := setUpStore(t) + + fakeClock := clockwork.NewFakeClockAt(time.Now()) + + // Insert Subscription with endtime 1 day from now + subscripiton1 := *subscriptionsPool[0].input + startTime := fakeClock.Now() + subscripiton1.StartTime = &startTime + endTime := fakeClock.Now().Add(24 * time.Hour) + subscripiton1.EndTime = &endTime + subscripiton1.Writer = "" + subOut1, err := repo.InsertSubscription(ctx, &subscripiton1) + require.NoError(t, err) + require.NotNil(t, subOut1) + + // Insert Subscription with endtime to 30 minutes ago + subscripiton2 := *subscriptionsPool[0].input + startTime = fakeClock.Now().Add(-1 * time.Hour) + subscripiton2.StartTime = &startTime + endTime = fakeClock.Now().Add(-30 * time.Minute) + subscripiton2.EndTime = &endTime + subscripiton2.ID = dssmodels.ID(uuid.New().String()) + subscripiton2.Writer = "" + subOut2, err := repo.InsertSubscription(ctx, &subscripiton2) + require.NoError(t, err) + require.NotNil(t, subOut2) + + subscriptions, err := repo.ListExpiredSubscriptions(ctx, "", fakeClock.Now().Add(-30*time.Minute)) + require.NoError(t, err) + require.Len(t, subscriptions, 1) +} + +func TestStoreCountSubscription(t *testing.T) { + ctx := context.Background() + repo := setUpStore(t) + + for _, r := range subscriptionsPool { + t.Run(r.name, func(t *testing.T) { + sub1, err := repo.InsertSubscription(ctx, r.input) + require.NoError(t, err) + require.NotNil(t, sub1) + + //Cound should be one + count, err := repo.CountSubscriptions(ctx) + require.NoError(t, err) + require.Equal(t, count, int64(1)) + + sub4, err := repo.DeleteSubscription(ctx, sub1) + require.NoError(t, err) + require.NotNil(t, sub4) + + //Cound should be zero + count, err = repo.CountSubscriptions(ctx) + require.NoError(t, err) + require.Equal(t, count, int64(0)) + }) + } +} From d5e48e1ea5a389465d31acbca4f22295a5bed3ba Mon Sep 17 00:00:00 2001 From: Maximilien Cuony Date: Tue, 16 Jun 2026 11:31:52 +0200 Subject: [PATCH 8/8] [raft/memstore] Use now from context --- pkg/aux_/store/memstore/dss.go | 10 ++++---- .../memstore/identification_service_area.go | 9 ++++--- pkg/rid/store/memstore/store.go | 4 +-- pkg/rid/store/memstore/store_test.go | 2 -- pkg/rid/store/memstore/subscriptions.go | 25 ++++++++++--------- pkg/rid/store/memstore/subscriptions_test.go | 3 +++ 6 files changed, 27 insertions(+), 26 deletions(-) diff --git a/pkg/aux_/store/memstore/dss.go b/pkg/aux_/store/memstore/dss.go index 79aeeee3a..4f19699a5 100644 --- a/pkg/aux_/store/memstore/dss.go +++ b/pkg/aux_/store/memstore/dss.go @@ -3,14 +3,14 @@ package memstore import ( "context" "database/sql" - "time" auxmodels "github.com/interuss/dss/pkg/aux_/models" dsserr "github.com/interuss/dss/pkg/errors" + "github.com/interuss/dss/pkg/timestamp" "github.com/interuss/stacktrace" ) -func (r *repo) SaveOwnMetadata(_ context.Context, locality string, publicEndpoint string) error { +func (r *repo) SaveOwnMetadata(ctx context.Context, locality string, publicEndpoint string) error { if locality == "" { return stacktrace.NewErrorWithCode(dsserr.BadRequest, "Locality not set") } @@ -20,7 +20,7 @@ func (r *repo) SaveOwnMetadata(_ context.Context, locality string, publicEndpoin r.state.Participants[locality] = &participant{ PublicEndpoint: publicEndpoint, - UpdatedAt: time.Now().UTC(), + UpdatedAt: timestamp.NowFromContext(ctx), } return nil } @@ -60,7 +60,7 @@ func (r *repo) GetDSSMetadata(_ context.Context) ([]*auxmodels.DSSMetadata, erro return metadata, nil } -func (r *repo) RecordHeartbeat(_ context.Context, heartbeat auxmodels.Heartbeat) error { +func (r *repo) RecordHeartbeat(ctx context.Context, heartbeat auxmodels.Heartbeat) error { if heartbeat.Locality == "" { return stacktrace.NewErrorWithCode(dsserr.BadRequest, "Locality not set") } @@ -69,7 +69,7 @@ func (r *repo) RecordHeartbeat(_ context.Context, heartbeat auxmodels.Heartbeat) } if heartbeat.Timestamp == nil { - now := time.Now().UTC() + now := timestamp.NowFromContext(ctx) heartbeat.Timestamp = &now } diff --git a/pkg/rid/store/memstore/identification_service_area.go b/pkg/rid/store/memstore/identification_service_area.go index dde19d9ff..29055426a 100644 --- a/pkg/rid/store/memstore/identification_service_area.go +++ b/pkg/rid/store/memstore/identification_service_area.go @@ -8,6 +8,7 @@ import ( dsserr "github.com/interuss/dss/pkg/errors" dssmodels "github.com/interuss/dss/pkg/models" ridmodels "github.com/interuss/dss/pkg/rid/models" + "github.com/interuss/dss/pkg/timestamp" "github.com/interuss/stacktrace" ) @@ -50,19 +51,19 @@ func (r *repo) GetISA(_ context.Context, id dssmodels.ID, _ bool) (*ridmodels.Id return rec.toModel(), nil } -func (r *repo) InsertISA(_ context.Context, isa *ridmodels.IdentificationServiceArea) (*ridmodels.IdentificationServiceArea, error) { +func (r *repo) InsertISA(ctx context.Context, isa *ridmodels.IdentificationServiceArea) (*ridmodels.IdentificationServiceArea, error) { if err := validateWriteData(isa.Cells, isa.StartTime, isa.EndTime); err != nil { return nil, err } if _, ok := r.state.ISAs[isa.ID]; ok { return nil, stacktrace.NewError("ISA with id %s already exists", isa.ID) } - rec := isaRecordFromModel(isa, r.clock.Now()) + rec := isaRecordFromModel(isa, timestamp.NowFromContext(ctx)) r.state.ISAs[isa.ID] = rec return rec.toModel(), nil } -func (r *repo) UpdateISA(_ context.Context, isa *ridmodels.IdentificationServiceArea) (*ridmodels.IdentificationServiceArea, error) { +func (r *repo) UpdateISA(ctx context.Context, isa *ridmodels.IdentificationServiceArea) (*ridmodels.IdentificationServiceArea, error) { if err := validateWriteData(isa.Cells, isa.StartTime, isa.EndTime); err != nil { return nil, err } @@ -73,7 +74,7 @@ func (r *repo) UpdateISA(_ context.Context, isa *ridmodels.IdentificationService if !dssmodels.VersionFromTime(prev.UpdatedAt).Matches(isa.Version) { return nil, nil } - rec := isaRecordFromModel(isa, r.clock.Now()) + rec := isaRecordFromModel(isa, timestamp.NowFromContext(ctx)) rec.Owner = prev.Owner r.state.ISAs[isa.ID] = rec return rec.toModel(), nil diff --git a/pkg/rid/store/memstore/store.go b/pkg/rid/store/memstore/store.go index 824ea4901..7cd02e161 100644 --- a/pkg/rid/store/memstore/store.go +++ b/pkg/rid/store/memstore/store.go @@ -10,14 +10,12 @@ import ( dssmodels "github.com/interuss/dss/pkg/models" "github.com/interuss/dss/pkg/rid/repos" "github.com/interuss/stacktrace" - "github.com/jonboulle/clockwork" "go.uber.org/zap" ) // repo is a full implementation of rid.repos.Repository for memory-based storage. type repo struct { state state - clock clockwork.Clock } // state is the serializable in-memory state. @@ -60,7 +58,7 @@ type subscriptionRecord struct { } func newRepo() *repo { - r := &repo{clock: clockwork.NewRealClock()} + r := &repo{} r.resetState() return r } diff --git a/pkg/rid/store/memstore/store_test.go b/pkg/rid/store/memstore/store_test.go index 05a78d0c6..3bf484463 100644 --- a/pkg/rid/store/memstore/store_test.go +++ b/pkg/rid/store/memstore/store_test.go @@ -23,9 +23,7 @@ var ( // fakeClock, so tests can advance time deterministically. func setUpStore(t *testing.T) *repo { t.Helper() - fakeClock = clockwork.NewFakeClock() r := newRepo() - r.clock = fakeClock return r } diff --git a/pkg/rid/store/memstore/subscriptions.go b/pkg/rid/store/memstore/subscriptions.go index addf17e9a..5923c3927 100644 --- a/pkg/rid/store/memstore/subscriptions.go +++ b/pkg/rid/store/memstore/subscriptions.go @@ -8,6 +8,7 @@ import ( dsserr "github.com/interuss/dss/pkg/errors" dssmodels "github.com/interuss/dss/pkg/models" ridmodels "github.com/interuss/dss/pkg/rid/models" + "github.com/interuss/dss/pkg/timestamp" "github.com/interuss/stacktrace" ) @@ -51,19 +52,19 @@ func (r *repo) GetSubscription(_ context.Context, id dssmodels.ID) (*ridmodels.S return rec.toModel(), nil } -func (r *repo) InsertSubscription(_ context.Context, s *ridmodels.Subscription) (*ridmodels.Subscription, error) { +func (r *repo) InsertSubscription(ctx context.Context, s *ridmodels.Subscription) (*ridmodels.Subscription, error) { if err := validateWriteData(s.Cells, s.StartTime, s.EndTime); err != nil { return nil, err } if _, ok := r.state.Subscriptions[s.ID]; ok { return nil, stacktrace.NewError("Subscription with id %s already exists", s.ID) } - rec := subRecordFromModel(s, r.clock.Now()) + rec := subRecordFromModel(s, timestamp.NowFromContext(ctx)) r.state.Subscriptions[s.ID] = rec return rec.toModel(), nil } -func (r *repo) UpdateSubscription(_ context.Context, s *ridmodels.Subscription) (*ridmodels.Subscription, error) { +func (r *repo) UpdateSubscription(ctx context.Context, s *ridmodels.Subscription) (*ridmodels.Subscription, error) { if err := validateWriteData(s.Cells, s.StartTime, s.EndTime); err != nil { return nil, err } @@ -74,7 +75,7 @@ func (r *repo) UpdateSubscription(_ context.Context, s *ridmodels.Subscription) if !dssmodels.VersionFromTime(prev.UpdatedAt).Matches(s.Version) { return nil, nil } - rec := subRecordFromModel(s, r.clock.Now()) + rec := subRecordFromModel(s, timestamp.NowFromContext(ctx)) rec.Owner = prev.Owner r.state.Subscriptions[s.ID] = rec return rec.toModel(), nil @@ -93,11 +94,11 @@ func (r *repo) DeleteSubscription(_ context.Context, s *ridmodels.Subscription) return out, nil } -func (r *repo) SearchSubscriptions(_ context.Context, cells s2.CellUnion) ([]*ridmodels.Subscription, error) { +func (r *repo) SearchSubscriptions(ctx context.Context, cells s2.CellUnion) ([]*ridmodels.Subscription, error) { if len(cells) == 0 { return nil, stacktrace.NewErrorWithCode(dsserr.BadRequest, "no location provided") } - now := r.clock.Now() + now := timestamp.NowFromContext(ctx) want := cellSet(cells) var out []*ridmodels.Subscription for _, rec := range r.state.Subscriptions { @@ -116,11 +117,11 @@ func (r *repo) SearchSubscriptions(_ context.Context, cells s2.CellUnion) ([]*ri return out, nil } -func (r *repo) SearchSubscriptionsByOwner(_ context.Context, cells s2.CellUnion, owner dssmodels.Owner) ([]*ridmodels.Subscription, error) { +func (r *repo) SearchSubscriptionsByOwner(ctx context.Context, cells s2.CellUnion, owner dssmodels.Owner) ([]*ridmodels.Subscription, error) { if len(cells) == 0 { return nil, stacktrace.NewErrorWithCode(dsserr.BadRequest, "no location provided") } - now := r.clock.Now() + now := timestamp.NowFromContext(ctx) want := cellSet(cells) var out []*ridmodels.Subscription for _, rec := range r.state.Subscriptions { @@ -144,8 +145,8 @@ func (r *repo) SearchSubscriptionsByOwner(_ context.Context, cells s2.CellUnion, // UpdateNotificationIdxsInCells increments the notification index for each // subscription in the given cells. -func (r *repo) UpdateNotificationIdxsInCells(_ context.Context, cells s2.CellUnion) ([]*ridmodels.Subscription, error) { - now := r.clock.Now() +func (r *repo) UpdateNotificationIdxsInCells(ctx context.Context, cells s2.CellUnion) ([]*ridmodels.Subscription, error) { + now := timestamp.NowFromContext(ctx) want := cellSet(cells) var out []*ridmodels.Subscription for _, rec := range r.state.Subscriptions { @@ -161,8 +162,8 @@ func (r *repo) UpdateNotificationIdxsInCells(_ context.Context, cells s2.CellUni return out, nil } -func (r *repo) MaxSubscriptionCountInCellsByOwner(_ context.Context, cells s2.CellUnion, owner dssmodels.Owner) (int, error) { - now := r.clock.Now() +func (r *repo) MaxSubscriptionCountInCellsByOwner(ctx context.Context, cells s2.CellUnion, owner dssmodels.Owner) (int, error) { + now := timestamp.NowFromContext(ctx) want := cellSet(cells) counts := make(map[s2.CellID]int, len(cells)) for _, rec := range r.state.Subscriptions { diff --git a/pkg/rid/store/memstore/subscriptions_test.go b/pkg/rid/store/memstore/subscriptions_test.go index 6f7832379..3b90cf0e0 100644 --- a/pkg/rid/store/memstore/subscriptions_test.go +++ b/pkg/rid/store/memstore/subscriptions_test.go @@ -10,6 +10,7 @@ import ( dssmodels "github.com/interuss/dss/pkg/models" ridmodels "github.com/interuss/dss/pkg/rid/models" "github.com/interuss/dss/pkg/rid/repos" + "github.com/interuss/dss/pkg/timestamp" "github.com/jonboulle/clockwork" "github.com/stretchr/testify/require" ) @@ -216,6 +217,7 @@ func TestStoreExpiredSubscription(t *testing.T) { // The subscription's endTime is 24 hours from now. fakeClock.Advance(23 * time.Hour) + ctx = timestamp.WithTimestamp(ctx, fakeClock.Now()) // We should still be able to find the subscription by searching and by ID. subs, err := repo.SearchSubscriptionsByOwner(ctx, sub.Cells, "original owner") @@ -228,6 +230,7 @@ func TestStoreExpiredSubscription(t *testing.T) { // But now the subscription has expired. fakeClock.Advance(2 * time.Hour) + ctx = timestamp.WithTimestamp(ctx, fakeClock.Now()) subs, err = repo.SearchSubscriptionsByOwner(ctx, sub.Cells, "original owner") require.NoError(t, err)