diff --git a/cmds/core-service/main.go b/cmds/core-service/main.go index 7931bf045..cfeccae75 100644 --- a/cmds/core-service/main.go +++ b/cmds/core-service/main.go @@ -32,6 +32,7 @@ import ( scds "github.com/interuss/dss/pkg/scd/store" "github.com/interuss/dss/pkg/store" "github.com/interuss/dss/pkg/store/params" + "github.com/interuss/dss/pkg/timestamp" "github.com/interuss/dss/pkg/version" "github.com/interuss/dss/pkg/versioning" "github.com/interuss/stacktrace" @@ -340,6 +341,7 @@ func RunHTTPServer(ctx context.Context, ctxCanceler func(), address, locality st handler = authorizer.TokenMiddleware(handler) handler = http.TimeoutHandler(handler, *timeout, "request timeout") handler = logging.HTTPMiddleware(logger, *dumpRequests, handler) + handler = timestamp.Middleware(handler) if *enableMetrics || *enableTracing { // We use the default settings; the APIRouter handler will override the span value accordingly, as it has more information. diff --git a/cmds/db-manager/cleanup/evict.go b/cmds/db-manager/cleanup/evict.go index 9584ef620..e605cb4c8 100644 --- a/cmds/db-manager/cleanup/evict.go +++ b/cmds/db-manager/cleanup/evict.go @@ -100,7 +100,7 @@ func evict(cmd *cobra.Command, _ []string) error { } return nil } - if err = scdStore.Transact(ctx, scdAction); err != nil { + if _, err = scdStore.Transact(ctx, "", nil, scdAction); err != nil { return fmt.Errorf("failed to execute SCD transaction: %w", err) } @@ -145,7 +145,7 @@ func evict(cmd *cobra.Command, _ []string) error { return nil } - if err = ridStore.Transact(ctx, ridAction); err != nil { + if _, err = ridStore.Transact(ctx, "", nil, ridAction); err != nil { return fmt.Errorf("failed to execute RID transaction: %w", err) } diff --git a/go.mod b/go.mod index a8fa1da50..d0722521b 100644 --- a/go.mod +++ b/go.mod @@ -10,6 +10,7 @@ require ( github.com/go-jose/go-jose/v4 v4.1.4 github.com/golang-jwt/jwt/v4 v4.5.2 github.com/golang/geo v0.0.0-20230421003525-6adc56603217 + github.com/google/go-cmp v0.7.0 github.com/google/uuid v1.6.0 github.com/interuss/stacktrace v1.0.0 github.com/jackc/pgx/v5 v5.9.2 @@ -27,11 +28,13 @@ require ( go.opentelemetry.io/otel v1.43.0 go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.39.0 go.opentelemetry.io/otel/exporters/prometheus v0.65.0 + go.opentelemetry.io/otel/metric v1.43.0 go.opentelemetry.io/otel/sdk v1.43.0 go.opentelemetry.io/otel/sdk/metric v1.43.0 go.opentelemetry.io/otel/trace v1.43.0 go.uber.org/multierr v1.11.0 go.uber.org/zap v1.27.0 + golang.org/x/sync v0.20.0 ) require ( @@ -71,13 +74,11 @@ require ( go.opentelemetry.io/auto/sdk v1.2.1 // indirect go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.61.0 // indirect go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.39.0 // indirect - go.opentelemetry.io/otel/metric v1.43.0 // indirect go.opentelemetry.io/proto/otlp v1.9.0 // indirect go.yaml.in/yaml/v2 v2.4.4 // indirect golang.org/x/crypto v0.49.0 // indirect golang.org/x/net v0.52.0 // indirect golang.org/x/oauth2 v0.34.0 // indirect - golang.org/x/sync v0.20.0 // indirect golang.org/x/sys v0.42.0 // indirect golang.org/x/text v0.36.0 // indirect golang.org/x/time v0.14.0 // indirect diff --git a/pkg/aux_/store/memstore/doc.go b/pkg/aux_/store/memstore/doc.go new file mode 100644 index 000000000..0fe2e87de --- /dev/null +++ b/pkg/aux_/store/memstore/doc.go @@ -0,0 +1,3 @@ +// Package aux_.store.memstore provides a full implementation of store.Store[aux_.repos.Repository] +// storing data in memory. It is meant to be used by raftstore. +package memstore diff --git a/pkg/aux_/store/memstore/dss.go b/pkg/aux_/store/memstore/dss.go new file mode 100644 index 000000000..4f19699a5 --- /dev/null +++ b/pkg/aux_/store/memstore/dss.go @@ -0,0 +1,86 @@ +package memstore + +import ( + "context" + "database/sql" + + auxmodels "github.com/interuss/dss/pkg/aux_/models" + dsserr "github.com/interuss/dss/pkg/errors" + "github.com/interuss/dss/pkg/timestamp" + "github.com/interuss/stacktrace" +) + +func (r *repo) SaveOwnMetadata(ctx context.Context, locality string, publicEndpoint string) error { + if locality == "" { + return stacktrace.NewErrorWithCode(dsserr.BadRequest, "Locality not set") + } + if publicEndpoint == "" { + return stacktrace.NewErrorWithCode(dsserr.BadRequest, "Public endpoint not set") + } + + r.state.Participants[locality] = &participant{ + PublicEndpoint: publicEndpoint, + UpdatedAt: timestamp.NowFromContext(ctx), + } + return nil +} + +func (r *repo) GetDSSMetadata(_ context.Context) ([]*auxmodels.DSSMetadata, error) { + metadata := make([]*auxmodels.DSSMetadata, 0, len(r.state.Participants)) + for locality, p := range r.state.Participants { + updatedAt := p.UpdatedAt + m := &auxmodels.DSSMetadata{ + Locality: locality, + PublicEndpoint: p.PublicEndpoint, + UpdatedAt: &updatedAt, + } + + // Find the latest heartbeat across all sources for this locality. + var latest auxmodels.Heartbeat + found := false + for key, hb := range r.state.Heartbeats { + if key.Locality != locality { + continue + } + if !found || hb.Timestamp.After(*latest.Timestamp) { + latest = hb + found = true + } + } + + if found { + m.LatestTimestamp.Source = sql.NullString{String: latest.Source, Valid: true} + m.LatestTimestamp.Timestamp = latest.Timestamp + m.LatestTimestamp.NextHeartbeatExpectedBefore = latest.NextHeartbeatExpectedBefore + m.LatestTimestamp.Reporter = sql.NullString{String: latest.Reporter, Valid: true} + } + + metadata = append(metadata, m) + } + return metadata, nil +} + +func (r *repo) RecordHeartbeat(ctx context.Context, heartbeat auxmodels.Heartbeat) error { + if heartbeat.Locality == "" { + return stacktrace.NewErrorWithCode(dsserr.BadRequest, "Locality not set") + } + if heartbeat.Source == "" { + return stacktrace.NewErrorWithCode(dsserr.BadRequest, "Source not set") + } + + if heartbeat.Timestamp == nil { + now := timestamp.NowFromContext(ctx) + heartbeat.Timestamp = &now + } + + if heartbeat.NextHeartbeatExpectedBefore != nil && heartbeat.NextHeartbeatExpectedBefore.Before(*heartbeat.Timestamp) { + return stacktrace.NewErrorWithCode(dsserr.BadRequest, "Cannot expect the timestamp of the next heartbeat before the timestamp of the new heartbeat") + } + + r.state.Heartbeats[heartbeatKey{Locality: heartbeat.Locality, Source: heartbeat.Source}] = heartbeat + return nil +} + +func (r *repo) GetDSSAirspaceRepresentationID(_ context.Context) (string, error) { + return "", stacktrace.NewErrorWithCode(dsserr.NotImplemented, "GetDSSAirspaceRepresentationID not implementable for memstore") +} diff --git a/pkg/aux_/store/memstore/dss_test.go b/pkg/aux_/store/memstore/dss_test.go new file mode 100644 index 000000000..43563d27b --- /dev/null +++ b/pkg/aux_/store/memstore/dss_test.go @@ -0,0 +1,125 @@ +package memstore + +import ( + "context" + "testing" + "time" + + auxmodels "github.com/interuss/dss/pkg/aux_/models" + dsserr "github.com/interuss/dss/pkg/errors" + "github.com/interuss/stacktrace" + "github.com/stretchr/testify/require" +) + +func TestSaveOwnMetadataValidation(t *testing.T) { + ctx := context.Background() + r := newRepo() + + require.Equal(t, dsserr.BadRequest, stacktrace.GetCode(r.SaveOwnMetadata(ctx, "", "https://example.com"))) + require.Equal(t, dsserr.BadRequest, stacktrace.GetCode(r.SaveOwnMetadata(ctx, "dss-1", ""))) +} + +func TestSaveOwnMetadataRoundTrip(t *testing.T) { + ctx := context.Background() + r := newRepo() + + require.NoError(t, r.SaveOwnMetadata(ctx, "dss-1", "https://example.com")) + + md, err := r.GetDSSMetadata(ctx) + require.NoError(t, err) + + require.Len(t, md, 1) + require.Equal(t, "dss-1", md[0].Locality) + require.Equal(t, "https://example.com", md[0].PublicEndpoint) + require.NotNil(t, md[0].UpdatedAt) + + // No heartbeat recorded yet. + require.False(t, md[0].LatestTimestamp.Source.Valid) + require.Nil(t, md[0].LatestTimestamp.Timestamp) +} + +func TestSaveOwnMetadataUpsert(t *testing.T) { + ctx := context.Background() + r := newRepo() + + require.NoError(t, r.SaveOwnMetadata(ctx, "dss-1", "https://old.example.com")) + require.NoError(t, r.SaveOwnMetadata(ctx, "dss-1", "https://new.example.com")) + + md, err := r.GetDSSMetadata(ctx) + require.NoError(t, err) + + require.Len(t, md, 1) + require.Equal(t, "https://new.example.com", md[0].PublicEndpoint) +} + +func TestRecordHeartbeatValidation(t *testing.T) { + ctx := context.Background() + r := newRepo() + + require.Equal(t, dsserr.BadRequest, stacktrace.GetCode(r.RecordHeartbeat(ctx, auxmodels.Heartbeat{Source: "source1"}))) + require.Equal(t, dsserr.BadRequest, stacktrace.GetCode(r.RecordHeartbeat(ctx, auxmodels.Heartbeat{Locality: "dss-1"}))) + + ts := time.Now() + before := ts.Add(-time.Minute) + err := r.RecordHeartbeat(ctx, auxmodels.Heartbeat{ + Locality: "dss-1", + Source: "source1", + Timestamp: &ts, + NextHeartbeatExpectedBefore: &before, + }) + + require.Equal(t, dsserr.BadRequest, stacktrace.GetCode(err)) +} + +func TestRecordHeartbeatDefaultsTimestamp(t *testing.T) { + ctx := context.Background() + r := newRepo() + + require.NoError(t, r.SaveOwnMetadata(ctx, "dss-1", "https://example.com")) + require.NoError(t, r.RecordHeartbeat(ctx, auxmodels.Heartbeat{Locality: "dss-1", Source: "source1"})) + + md, err := r.GetDSSMetadata(ctx) + require.NoError(t, err) + + require.Len(t, md, 1) + require.True(t, md[0].LatestTimestamp.Source.Valid) + require.NotNil(t, md[0].LatestTimestamp.Timestamp) +} + +func TestGetDSSMetadataPicksLatestHeartbeat(t *testing.T) { + ctx := context.Background() + r := newRepo() + + require.NoError(t, r.SaveOwnMetadata(ctx, "dss-1", "https://example.com")) + + older := time.Now().Add(-time.Hour) + newer := time.Now() + require.NoError(t, r.RecordHeartbeat(ctx, auxmodels.Heartbeat{Locality: "dss-1", Source: "source1", Timestamp: &older, Reporter: "uss1"})) + require.NoError(t, r.RecordHeartbeat(ctx, auxmodels.Heartbeat{Locality: "dss-1", Source: "source2", Timestamp: &newer, Reporter: "uss2"})) + + md, err := r.GetDSSMetadata(ctx) + require.NoError(t, err) + + require.Len(t, md, 1) + require.True(t, md[0].LatestTimestamp.Timestamp.Equal(newer)) + require.Equal(t, "source2", md[0].LatestTimestamp.Source.String) + require.Equal(t, "uss2", md[0].LatestTimestamp.Reporter.String) +} + +func TestGetDSSMetadataUpdatesHeartbeatPerSource(t *testing.T) { + ctx := context.Background() + r := newRepo() + + require.NoError(t, r.SaveOwnMetadata(ctx, "dss-1", "https://example.com")) + + first := time.Now().Add(-time.Hour) + second := time.Now() + require.NoError(t, r.RecordHeartbeat(ctx, auxmodels.Heartbeat{Locality: "dss-1", Source: "source1", Timestamp: &first})) + require.NoError(t, r.RecordHeartbeat(ctx, auxmodels.Heartbeat{Locality: "dss-1", Source: "source1", Timestamp: &second})) + + md, err := r.GetDSSMetadata(ctx) + require.NoError(t, err) + + require.Len(t, md, 1) + require.True(t, md[0].LatestTimestamp.Timestamp.Equal(second)) +} diff --git a/pkg/aux_/store/memstore/snapshot.go b/pkg/aux_/store/memstore/snapshot.go new file mode 100644 index 000000000..15deb274b --- /dev/null +++ b/pkg/aux_/store/memstore/snapshot.go @@ -0,0 +1,35 @@ +package memstore + +import ( + "bytes" + "encoding/gob" + + "github.com/interuss/stacktrace" +) + +const snapshotVersion = 1 + +type snapshotEnvelope struct { + Version int + State state +} + +func (r *repo) GetSnapshot() ([]byte, error) { + var buf bytes.Buffer + if err := gob.NewEncoder(&buf).Encode(snapshotEnvelope{Version: snapshotVersion, State: r.state}); err != nil { + return nil, stacktrace.Propagate(err, "Failed to encode memstore snapshot") + } + return buf.Bytes(), nil +} + +func (r *repo) RestoreFromSnapshot(data []byte) error { + var env snapshotEnvelope + if err := gob.NewDecoder(bytes.NewReader(data)).Decode(&env); err != nil { + return stacktrace.Propagate(err, "Failed to decode memstore snapshot") + } + if env.Version != snapshotVersion { + return stacktrace.NewError("Unsupported memstore snapshot version %d, expected %d", env.Version, snapshotVersion) + } + r.state = env.State + return nil +} diff --git a/pkg/aux_/store/memstore/snapshot_test.go b/pkg/aux_/store/memstore/snapshot_test.go new file mode 100644 index 000000000..2085e2488 --- /dev/null +++ b/pkg/aux_/store/memstore/snapshot_test.go @@ -0,0 +1,59 @@ +package memstore + +import ( + "bytes" + "context" + "encoding/gob" + "testing" + "time" + + auxmodels "github.com/interuss/dss/pkg/aux_/models" + "github.com/stretchr/testify/require" +) + +func TestSnapshotRoundTrip(t *testing.T) { + ctx := context.Background() + src := newRepo() + require.NoError(t, src.SaveOwnMetadata(ctx, "dss-1", "https://example.com")) + ts := time.Now().UTC() + require.NoError(t, src.RecordHeartbeat(ctx, auxmodels.Heartbeat{Locality: "dss-1", Source: "source-1", Timestamp: &ts, Reporter: "uss-1"})) + + data, err := src.GetSnapshot() + require.NoError(t, err) + + dst := newRepo() + require.NoError(t, dst.RestoreFromSnapshot(data)) + + want, err := src.GetDSSMetadata(ctx) + require.NoError(t, err) + got, err := dst.GetDSSMetadata(ctx) + require.NoError(t, err) + require.Equal(t, want, got) +} + +func TestRestoreFromSnapshotReplacesState(t *testing.T) { + ctx := context.Background() + src := newRepo() + require.NoError(t, src.SaveOwnMetadata(ctx, "dss-1", "https://example.com")) + data, err := src.GetSnapshot() + require.NoError(t, err) + + dst := newRepo() + require.NoError(t, dst.SaveOwnMetadata(ctx, "dss-2", "https://other.example.com")) + require.NoError(t, dst.RestoreFromSnapshot(data)) + + md, err := dst.GetDSSMetadata(ctx) + require.NoError(t, err) + require.Len(t, md, 1) + require.Equal(t, "dss-1", md[0].Locality) +} + +func TestRestoreFromSnapshotInvalidData(t *testing.T) { + require.Error(t, newRepo().RestoreFromSnapshot([]byte("random value that is definitely not valid"))) +} + +func TestRestoreFromSnapshotVersionMismatch(t *testing.T) { + var buf bytes.Buffer + require.NoError(t, gob.NewEncoder(&buf).Encode(snapshotEnvelope{Version: snapshotVersion + 1})) + require.Error(t, newRepo().RestoreFromSnapshot(buf.Bytes())) +} diff --git a/pkg/aux_/store/memstore/store.go b/pkg/aux_/store/memstore/store.go new file mode 100644 index 000000000..52ed59d12 --- /dev/null +++ b/pkg/aux_/store/memstore/store.go @@ -0,0 +1,52 @@ +package memstore + +import ( + "context" + "time" + + auxmodels "github.com/interuss/dss/pkg/aux_/models" + "github.com/interuss/dss/pkg/aux_/repos" + "github.com/interuss/dss/pkg/memstore" + "go.uber.org/zap" +) + +// repo is a full implementation of aux_.repos.Repository for memory-based storage. +type repo struct { + state state +} + +// state is the serializable in-memory state. +type state struct { + // Participants holds pool participants metadata, keyed by locality. + Participants map[string]*participant + // Heartbeats holds the latest heartbeat per (locality, source). + Heartbeats map[heartbeatKey]auxmodels.Heartbeat + // participants holds pool participants metadata, keyed by locality. + participants map[string]*participant + // heartbeats holds the latest heartbeat per (locality, source). + heartbeats map[heartbeatKey]auxmodels.Heartbeat +} + +type participant struct { + PublicEndpoint string + UpdatedAt time.Time +} + +type heartbeatKey struct { + Locality string + Source string +} + +func newRepo() *repo { + return &repo{ + state: state{ + Participants: map[string]*participant{}, + Heartbeats: map[heartbeatKey]auxmodels.Heartbeat{}, + }} +} + +func Init(ctx context.Context, logger *zap.Logger) (*memstore.Store[repos.Repository], error) { + return memstore.Init(ctx, logger, "aux_", newRepo()) +} + +func (r *repo) GetRepo() repos.Repository { return r } diff --git a/pkg/aux_/store/raftstore/params/params.go b/pkg/aux_/store/raftstore/params/params.go new file mode 100644 index 000000000..3cd731a4d --- /dev/null +++ b/pkg/aux_/store/raftstore/params/params.go @@ -0,0 +1,29 @@ +package params + +import ( + "flag" + + raftparams "github.com/interuss/dss/pkg/raftstore/params" + "github.com/interuss/stacktrace" +) + +const peersFlag = "aux_raft_peers" + +var peers string + +func init() { + flag.StringVar(&peers, peersFlag, "", `Comma-separated "nodeID=peerURL" pairs for the aux store, e.g. "1=http://node1:9031,2=http://node2:9031,3=http://node3:9031"`) +} + +func GetConnectParameters() (raftparams.ConnectParameters, error) { + if peers == "" { + return raftparams.ConnectParameters{}, stacktrace.NewError("--%s is required", peersFlag) + } + + p, err := raftparams.GetConnectParameters("aux") + if err != nil { + return raftparams.ConnectParameters{}, err + } + p.Peers = peers + return p, nil +} diff --git a/pkg/aux_/store/raftstore/store.go b/pkg/aux_/store/raftstore/store.go index ae74ad990..5e68bc964 100644 --- a/pkg/aux_/store/raftstore/store.go +++ b/pkg/aux_/store/raftstore/store.go @@ -4,7 +4,11 @@ import ( "context" "github.com/interuss/dss/pkg/aux_/repos" + auxraftparams "github.com/interuss/dss/pkg/aux_/store/raftstore/params" + dsserr "github.com/interuss/dss/pkg/errors" "github.com/interuss/dss/pkg/raftstore" + "github.com/interuss/dss/pkg/raftstore/consensus" + "github.com/interuss/stacktrace" "go.uber.org/zap" ) @@ -12,5 +16,25 @@ import ( type repo struct{} func Init(ctx context.Context, logger *zap.Logger) (*raftstore.Store[repos.Repository], error) { - return raftstore.Init[repos.Repository](ctx, logger, func() repos.Repository { return &repo{} }) + params, err := auxraftparams.GetConnectParameters() + if err != nil { + return nil, stacktrace.Propagate(err, "failed to get aux raft parameters") + } + return raftstore.Init(ctx, logger, params, "aux", &repo{}) +} + +func (r *repo) GetRepo() repos.Repository { return r } + +func (r *repo) IsReadOnly(_ raftstore.RequestType) bool { return false } + +func (r *repo) GetSnapshot() ([]byte, error) { + return nil, stacktrace.NewErrorWithCode(dsserr.NotImplemented, "not implemented yet") +} + +func (r *repo) RestoreFromSnapshot([]byte) error { + return stacktrace.NewErrorWithCode(dsserr.NotImplemented, "not implemented yet") +} + +func (r *repo) Apply(_ context.Context, _ consensus.Proposal) (any, error) { + return nil, stacktrace.NewErrorWithCode(dsserr.NotImplemented, "not implemented yet") } diff --git a/pkg/aux_/store/store.go b/pkg/aux_/store/store.go index 4a9049327..2532842a7 100644 --- a/pkg/aux_/store/store.go +++ b/pkg/aux_/store/store.go @@ -4,6 +4,7 @@ import ( "context" "github.com/interuss/dss/pkg/aux_/repos" + auxmemstore "github.com/interuss/dss/pkg/aux_/store/memstore" auxraftstore "github.com/interuss/dss/pkg/aux_/store/raftstore" auxsqlstore "github.com/interuss/dss/pkg/aux_/store/sqlstore" dssstore "github.com/interuss/dss/pkg/store" @@ -24,6 +25,8 @@ func Init(ctx context.Context, logger *zap.Logger, withCheckCron bool) (Store, e return auxsqlstore.Init(ctx, logger, withCheckCron) case params.RaftStoreType: return auxraftstore.Init(ctx, logger) + case params.MemStoreType: + return auxmemstore.Init(ctx, logger) default: return nil, stacktrace.NewError("Unsupported store type %q for aux", storeType) } diff --git a/pkg/memstore/store.go b/pkg/memstore/store.go new file mode 100644 index 000000000..baf6a1799 --- /dev/null +++ b/pkg/memstore/store.go @@ -0,0 +1,67 @@ +package memstore + +// Memstore are a special kind of store: +// Store instances store data in memory. There is no persistent storage. +// Store instances are a singleton. +// Repository usage is not thread-safe. +// +// As of now, they are made to be used by raftstorage. +// Adaptations could be done to use them directly in the future. + +import ( + "context" + "sync" + + dsserr "github.com/interuss/dss/pkg/errors" + "github.com/interuss/dss/pkg/logging" + "github.com/interuss/stacktrace" + "go.uber.org/zap" +) + +type MemRepo[R any] interface { + GetRepo() R + GetSnapshot() ([]byte, error) + RestoreFromSnapshot([]byte) error +} + +type Store[R any] struct { + logger *zap.Logger + + name string + memRepo MemRepo[R] +} + +var ( + stores = map[string]any{} + storesMu sync.Mutex +) + +func Init[R any](ctx context.Context, logger *zap.Logger, name string, r MemRepo[R]) (*Store[R], error) { + + storesMu.Lock() + defer storesMu.Unlock() + if s, ok := stores[name]; ok { + return s.(*Store[R]), nil + } + + store := &Store[R]{ + name: name, + logger: logging.WithValuesFromContext(ctx, logger), + memRepo: r, + } + + stores[name] = store + return store, nil +} + +func (s *Store[R]) Transact(ctx context.Context, requestType string, payload any, _ func(context.Context, R) error) (any, error) { + return nil, stacktrace.NewErrorWithCode(dsserr.NotImplemented, "Transact not implemented for memstore") +} + +func (s *Store[R]) Interact(_ context.Context) (R, error) { + return s.memRepo.GetRepo(), nil +} + +func (s *Store[R]) Close() error { + return nil +} diff --git a/pkg/raftstore/consensus/consensus.go b/pkg/raftstore/consensus/consensus.go index 8377e22d3..4d6c2367e 100644 --- a/pkg/raftstore/consensus/consensus.go +++ b/pkg/raftstore/consensus/consensus.go @@ -2,10 +2,12 @@ package consensus import ( "context" + "encoding/json" "errors" "fmt" "net/http" "net/url" + "sync" "time" "github.com/interuss/dss/pkg/logging" @@ -29,18 +31,30 @@ type Consensus struct { server *http.Server storage *storage + commitC chan<- EntryCommit + + tracker *proposalsTracker + stopOnce sync.Once + + once sync.Once + shutdownTimeout time.Duration confState raftpb.ConfState snapshotIndex uint64 appliedIndex uint64 } -func NewConsensus(ctx context.Context, logger *zap.Logger, peers map[uint64]*url.URL, connectParams params.ConnectParameters) (*Consensus, error) { - storage, old, err := newStorage(ctx, logger.With(zap.String("component", "storage")), connectParams.DataDir, connectParams.NodeID, connectParams.SnapshotCatchupEntries) +func NewConsensus(ctx context.Context, logger *zap.Logger, connectParams params.ConnectParameters, provider snapshotProvider, commitC chan<- EntryCommit) (*Consensus, error) { + storage, old, err := newStorage(ctx, logger.With(zap.String("component", "storage")), connectParams.DataDir, connectParams.NodeID, provider, connectParams.SnapshotCatchupEntries) if err != nil { return nil, stacktrace.Propagate(err, "failed to initialize storage") } + peers, err := connectParams.PeerMap() + if err != nil { + return nil, stacktrace.Propagate(err, "failed to parse peer map") + } + nodeUrl, ok := peers[connectParams.NodeID] if !ok { return nil, stacktrace.NewError("node ID %d not found in peers map", connectParams.NodeID) @@ -63,6 +77,10 @@ func NewConsensus(ctx context.Context, logger *zap.Logger, peers map[uint64]*url node: node, storage: storage, + commitC: commitC, + tracker: newProposalsTracker(), + + shutdownTimeout: 2 * connectParams.ElectionInterval(), } err = consensus.initTransport(ctx, connectParams.NodeID, connectParams.ClusterID, peers) @@ -80,26 +98,62 @@ func NewConsensus(ctx context.Context, logger *zap.Logger, peers map[uint64]*url consensus.appliedIndex = snap.Metadata.Index go func() { - err := consensus.handleReady(connectParams.TickInterval) + err := consensus.handleReady(connectParams.TickInterval, connectParams.SnapshotIntervalEntries) if err != nil { consensus.logger.Error("handleReady exited with error, shutting down consensus", zap.Error(err)) } - shutdownCtx, cancel := context.WithTimeout(context.Background(), 2*connectParams.ElectionInterval()) + consensus.Stop(context.Background()) + }() + + return consensus, nil +} + +func (c *Consensus) Stop(ctx context.Context) { + c.once.Do(func() { + shutdownCtx, cancel := context.WithTimeout(context.Background(), c.shutdownTimeout) defer cancel() - if shutdownErr := consensus.server.Shutdown(shutdownCtx); shutdownErr != nil { - consensus.logger.Error("failed to shutdown http server", zap.Error(shutdownErr)) + if shutdownErr := c.server.Shutdown(shutdownCtx); shutdownErr != nil { + c.logger.Error("failed to shutdown http server", zap.Error(shutdownErr)) } else { - consensus.logger.Info("http server shutdown complete") + c.logger.Info("http server shutdown complete") } - consensus.transport.Stop() - consensus.logger.Info("transport stopped") - consensus.node.Stop() - consensus.logger.Info("raft node stopped") - }() + c.transport.Stop() + c.logger.Info("transport stopped") + c.node.Stop() + c.logger.Info("raft node stopped") + }) +} - return consensus, nil +// ProposeValue blocks until the proposal is committed and applied / dropped or until ctx is cancelled. +func (c *Consensus) ProposeValue(ctx context.Context, requestType string, payload any, readOnly bool) (any, error) { + proposal, err := newProposal(ctx, requestType, payload, readOnly) + if err != nil { + return nil, stacktrace.Propagate(err, "failed to create proposal") + } + + buf, err := json.Marshal(proposal) + if err != nil { + return nil, stacktrace.Propagate(err, "failed to marshal proposal") + } + + applied := c.tracker.track(proposal.ID) + + err = c.node.Propose(ctx, buf) + if err != nil { + c.tracker.untrack(proposal.ID, ProposalResult{Error: err}) + return nil, stacktrace.Propagate(err, "failed to propose value to Raft") + } + + select { + case res := <-applied: + return res.Result, res.Error + + case <-ctx.Done(): + c.tracker.untrack(proposal.ID, ProposalResult{Error: ctx.Err()}) + return nil, ctx.Err() + } } func peersList(peers map[uint64]*url.URL) []raft.Peer { @@ -120,7 +174,7 @@ func (c *Consensus) initTransport(ctx context.Context, nodeID uint64, clusterID Raft: c, ServerStats: v2stats.NewServerStats(nodeIDStr, nodeIDStr), LeaderStats: v2stats.NewLeaderStats(c.logger, nodeIDStr), - ErrorC: make(chan error), + ErrorC: make(chan error, 1), } err := transport.Start() @@ -142,6 +196,8 @@ func (c *Consensus) initTransport(ctx context.Context, nodeID uint64, clusterID return stacktrace.NewError("node ID %d not found in peers map", nodeID) } + c.transport = transport + c.server = &http.Server{ Addr: listeningAddr, Handler: transport.Handler(), @@ -155,12 +211,11 @@ func (c *Consensus) initTransport(ctx context.Context, nodeID uint64, clusterID } }() - c.transport = transport return nil } // handleReady processes the Ready channel of the Raft node and applies committed entries to the state machine -func (c *Consensus) handleReady(tickInterval time.Duration) error { +func (c *Consensus) handleReady(tickInterval time.Duration, snapshotInterval uint64) error { ticker := time.NewTicker(tickInterval) defer ticker.Stop() @@ -185,10 +240,7 @@ func (c *Consensus) handleReady(tickInterval time.Duration) error { return stacktrace.NewError("snapshot index %d shall be greater than current applied index %d", rd.Snapshot.Metadata.Index, c.appliedIndex) } - err = c.dispatchSnapshot(rd.Snapshot.Data) - if err != nil { - return stacktrace.Propagate(err, "failed to dispatch snapshot") - } + c.commitC <- EntryCommit{SnapshotData: rd.Snapshot.Data} c.confState = rd.Snapshot.Metadata.ConfState c.snapshotIndex = rd.Snapshot.Metadata.Index @@ -203,7 +255,7 @@ func (c *Consensus) handleReady(tickInterval time.Duration) error { return stacktrace.Propagate(err, "failed to get entries to apply") } - err = c.publishEntries(entries) + err = c.publishEntries(entries, snapshotInterval) if err != nil { return stacktrace.Propagate(err, "failed to publish entries") } @@ -213,16 +265,104 @@ func (c *Consensus) handleReady(tickInterval time.Duration) error { } } -// TODO implement -func (c *Consensus) publishEntries(_ []raftpb.Entry) error { +func (c *Consensus) publishEntries(entries []raftpb.Entry, snapshotInterval uint64) error { + if len(entries) == 0 { + return nil + } + + c.logger.Info("publishing entries", zap.Int("numEntries", len(entries)), zap.Uint64("firstIndex", entries[0].Index), zap.Uint64("lastIndex", entries[len(entries)-1].Index)) + + var triggerSnapshot bool + var err error + var wg sync.WaitGroup + for _, entry := range entries { + switch entry.Type { + case raftpb.EntryNormal: + err := c.processNormalEntry(entry.Data, &wg) + if err != nil { + return stacktrace.Propagate(err, "failed to process normal entry") + } + case raftpb.EntryConfChange: + err := c.processConfigChangeEntry(entry.Data) + if err != nil { + return stacktrace.Propagate(err, "failed to process config change entry") + } + case raftpb.EntryConfChangeV2: + triggerSnapshot, err = c.processConfigChangeV2Entry(entry.Data) + if err != nil { + return stacktrace.Propagate(err, "failed to process config change v2 entry") + } + } + } + + // wait for all entries to be applied before updating the applied index and potentially triggering a snapshot + wg.Wait() + c.appliedIndex = entries[len(entries)-1].Index + + if triggerSnapshot || c.appliedIndex-c.snapshotIndex >= snapshotInterval { + err := c.storage.triggerSnapshot(c.appliedIndex, &c.confState) + if err != nil { + return stacktrace.Propagate(err, "failed to trigger snapshot") + } + + c.snapshotIndex = c.appliedIndex + } + return nil } -// TODO implement -func (c *Consensus) dispatchSnapshot(_ []byte) error { +// processNormalEntry passes the proposal to the store and waits for the result to be returned before untracking it. +func (c *Consensus) processNormalEntry(data []byte, wg *sync.WaitGroup) error { + if len(data) <= 0 { + return nil + } + + prop := Proposal{} + err := json.Unmarshal(data, &prop) + if err != nil { + return stacktrace.Propagate(err, "failed to unmarshal committed proposal") + } + + //if readOnly proposal and we did not initiate it, skip it (noop) + if prop.ReadOnly && !c.tracker.isPending(prop.ID) { + return nil + } + + applyDoneC := make(chan ProposalResult, 1) + wg.Go(func() { + c.tracker.untrack(prop.ID, <-applyDoneC) + }) + + c.commitC <- EntryCommit{Prop: prop, Done: applyDoneC} + return nil +} + +// raftpb.ConfChange is still used internally by Raft, we just need to apply the change to the node. +// Changes requested by clients are processed by processConfigChangeV2Entry. +func (c *Consensus) processConfigChangeEntry(data []byte) error { + var cc raftpb.ConfChange + err := cc.Unmarshal(data) + if err != nil { + return stacktrace.Propagate(err, "failed to unmarshal config change data") + } + + c.confState = *c.node.ApplyConfChange(cc) return nil } +func (c *Consensus) processConfigChangeV2Entry(data []byte) (bool, error) { + var cc raftpb.ConfChangeV2 + err := cc.Unmarshal(data) + if err != nil { + return false, stacktrace.Propagate(err, "failed to unmarshal config change data") + } + + c.confState = *c.node.ApplyConfChange(cc) + + // TODO - implement config changes when triggered by a proposal + return false, nil +} + func (c *Consensus) entriesToApply(entries []raftpb.Entry) ([]raftpb.Entry, error) { if len(entries) == 0 { return entries, nil @@ -272,8 +412,3 @@ func (c *Consensus) ReportUnreachable(id uint64) { func (c *Consensus) ReportSnapshot(id uint64, status raft.SnapshotStatus) { c.node.ReportSnapshot(id, status) } - -// RegisterStore allows registering a snapshot provider function for a specific store -func (c *Consensus) RegisterStore(name string, provider snapshotProvider) { - c.storage.registerSnapshotProvider(name, provider) -} diff --git a/pkg/raftstore/consensus/proposal.go b/pkg/raftstore/consensus/proposal.go new file mode 100644 index 000000000..d05f4b7e5 --- /dev/null +++ b/pkg/raftstore/consensus/proposal.go @@ -0,0 +1,93 @@ +package consensus + +import ( + "context" + "encoding/json" + "sync" + "time" + + "github.com/google/uuid" + "github.com/interuss/dss/pkg/timestamp" + "github.com/interuss/stacktrace" +) + +type EntryCommit struct { + Prop Proposal + Done chan ProposalResult + + SnapshotData []byte +} + +type Proposal struct { + ID string `json:"id"` + Timestamp time.Time `json:"timestamp"` + RequestType string `json:"request_type"` + Value []byte `json:"value"` + ReadOnly bool `json:"read_only"` +} + +type ProposalResult struct { + Result any + Error error +} + +func newProposal(ctx context.Context, requestType string, payload any, readOnly bool) (Proposal, error) { + proposalTimestamp := timestamp.NowFromContext(ctx) + if proposalTimestamp.IsZero() { + proposalTimestamp = time.Now().UTC() + } + + value, err := json.Marshal(payload) + if err != nil { + return Proposal{}, stacktrace.Propagate(err, "failed to serialize proposal payload") + } + + return Proposal{ + ID: uuid.NewString(), + Timestamp: proposalTimestamp, + RequestType: requestType, + Value: value, + ReadOnly: readOnly, + }, nil +} + +type proposalsTracker struct { + sync.Mutex + pending map[string]chan ProposalResult +} + +func newProposalsTracker() *proposalsTracker { + return &proposalsTracker{ + pending: make(map[string]chan ProposalResult), + } +} + +func (p *proposalsTracker) isPending(id string) bool { + p.Lock() + defer p.Unlock() + + _, ok := p.pending[id] + return ok +} + +func (p *proposalsTracker) track(id string) chan ProposalResult { + p.Lock() + defer p.Unlock() + + applied := make(chan ProposalResult, 1) + p.pending[id] = applied + return applied +} + +func (p *proposalsTracker) untrack(id string, result ProposalResult) { + p.Lock() + defer p.Unlock() + + applied, ok := p.pending[id] + if !ok { + return + } + + applied <- result + delete(p.pending, id) +} diff --git a/pkg/raftstore/consensus/storage.go b/pkg/raftstore/consensus/storage.go index 97f2ff718..1d9893882 100644 --- a/pkg/raftstore/consensus/storage.go +++ b/pkg/raftstore/consensus/storage.go @@ -2,7 +2,6 @@ package consensus import ( "context" - "encoding/json" "errors" "fmt" "os" @@ -31,8 +30,8 @@ type storage struct { wal *wal.WAL - snapper *snap.Snapshotter - providers map[string]snapshotProvider + snapper *snap.Snapshotter + snapshot snapshotProvider snapshotCatchUpEntries uint64 } @@ -40,7 +39,7 @@ type storage struct { // newStorage initializes the storage by loading the latest snapshot and wal entries from the disk // and applies them to the Raft memory storage. // It returns the initialized storage, a boolean indicating whether the storage was pre-existent or an error. -func newStorage(ctx context.Context, logger *zap.Logger, dataDir string, nodeID uint64, snapshotCatchUpEntries uint64) (*storage, bool, error) { +func newStorage(ctx context.Context, logger *zap.Logger, dataDir string, nodeID uint64, provider snapshotProvider, snapshotCatchUpEntries uint64) (*storage, bool, error) { logger = logging.WithValuesFromContext(ctx, logger) // load the latest snapshot @@ -124,8 +123,9 @@ func newStorage(ctx context.Context, logger *zap.Logger, dataDir string, nodeID wal: w, - snapper: snapper, - providers: make(map[string]snapshotProvider), + snapper: snapper, + snapshot: provider, + snapshotCatchUpEntries: snapshotCatchUpEntries, }, ok, nil } @@ -177,33 +177,9 @@ func (s *storage) save(snapshot raftpb.Snapshot) error { return s.wal.ReleaseLockTo(snapshot.Metadata.Index) } -func (s *storage) registerSnapshotProvider(name string, provider func() ([]byte, error)) { - s.providers[name] = provider -} - -// getSnapshot calls all registered snapshot providers and combines their data into a single snapshot. -func (s *storage) getSnapshot() ([]byte, error) { - parts := make(map[string][]byte) - for name, provider := range s.providers { - data, err := provider() - if err != nil { - return nil, stacktrace.Propagate(err, "failed to get snapshot data from %q", name) - } - - parts[name] = data - } - - return json.Marshal(parts) -} - -// snapshotter returns the snapshotter used by the storage. -func (s *storage) snapshotter() *snap.Snapshotter { - return s.snapper -} - func (s *storage) triggerSnapshot(appliedIndex uint64, confState *raftpb.ConfState) error { s.logger.Info("triggering snapshot", zap.Uint64("appliedIndex", appliedIndex)) - data, err := s.getSnapshot() + data, err := s.snapshot() if err != nil { return stacktrace.Propagate(err, "failed to get snapshot data") } diff --git a/pkg/raftstore/params/params.go b/pkg/raftstore/params/params.go index 3a2b6fd6a..85b0551b2 100644 --- a/pkg/raftstore/params/params.go +++ b/pkg/raftstore/params/params.go @@ -123,7 +123,6 @@ var ( func init() { flag.Uint64Var(&connectParameters.NodeID, "raft_node_id", 0, "Raft node ID for this instance (must be non-zero and unique within the cluster).") flag.Uint64Var(&connectParameters.ClusterID, "raft_cluster_id", 1, "ID of the cluster, used to isolate different Raft clusters running in the same network (must be the same for all nodes in the cluster).") - flag.StringVar(&connectParameters.Peers, "raft_peers", "", `Comma-separated "nodeID=peerURL" pairs for all cluster members, including the current node, e.g. "1=http://node1:9021,2=http://node2:9021,3=http://node3:9021"`) flag.StringVar(&connectParameters.DataDir, "raft_datadir", defaultDataDir, "Directory for raft data (WAL segments and snapshots), required for restarts. These should not be deleted while the node is running or across restarts unless the node is being permanently shut down.") flag.Uint64Var(&connectParameters.SnapshotCatchupEntries, "raft_snapshot_catchup_entries", defaultSnapshotCatchupEntries, @@ -142,6 +141,11 @@ func init() { } // GetConnectParameters returns a ConnectParameters instance that gets populated from well-known CLI flags. -func GetConnectParameters() ConnectParameters { - return connectParameters +func GetConnectParameters(subfolder string) (ConnectParameters, error) { + if connectParameters.NodeID == 0 { + return ConnectParameters{}, stacktrace.NewError("--raft_node_id is required and must be non-zero") + } + p := connectParameters + p.DataDir = connectParameters.DataDir + "/" + subfolder + return p, nil } diff --git a/pkg/raftstore/store.go b/pkg/raftstore/store.go index 3ef62c157..8bbdb5110 100644 --- a/pkg/raftstore/store.go +++ b/pkg/raftstore/store.go @@ -2,68 +2,98 @@ package raftstore import ( "context" - "sync" + "github.com/interuss/dss/pkg/logging" "github.com/interuss/dss/pkg/raftstore/consensus" raftparams "github.com/interuss/dss/pkg/raftstore/params" + "github.com/interuss/dss/pkg/timestamp" "github.com/interuss/stacktrace" "go.uber.org/zap" ) -var ( - sharedConsensus *consensus.Consensus - sharedConsensusOnce sync.Once - sharedConsensusErr error -) +type RequestType = string + +type RaftRepo[R any] interface { + GetRepo() R + // Apply is called on every committed entry. The proposal must be applied atomically. + Apply(ctx context.Context, proposal consensus.Proposal) (any, error) + GetSnapshot() ([]byte, error) + RestoreFromSnapshot(data []byte) error + IsReadOnly(requestType RequestType) bool +} type Store[R any] struct { - newRepo func() R - consensus *consensus.Consensus + logger *zap.Logger + + name string + raftRepo RaftRepo[R] + cancel context.CancelFunc + + Consensus *consensus.Consensus } -func Init[R any](ctx context.Context, logger *zap.Logger, newRepo func() R) (*Store[R], error) { - // scd, rid and aux will share the same consensus instance, so we initialize it once. - sharedConsensusOnce.Do(func() { - params := raftparams.GetConnectParameters() - peers, err := params.PeerMap() - if err != nil { - sharedConsensusErr = stacktrace.Propagate(err, "failed to parse peer map") - return - } +func Init[R any](ctx context.Context, logger *zap.Logger, params raftparams.ConnectParameters, name string, r RaftRepo[R]) (*Store[R], error) { + ctx, cancel := context.WithCancel(ctx) - sharedConsensus, sharedConsensusErr = consensus.NewConsensus(ctx, logger, peers, params) - if sharedConsensusErr != nil { - sharedConsensusErr = stacktrace.Propagate(sharedConsensusErr, "failed to initialize consensus") - } - }) - if sharedConsensusErr != nil { - return nil, sharedConsensusErr + store := &Store[R]{ + name: name, + raftRepo: r, + logger: logging.WithValuesFromContext(ctx, logger), + cancel: cancel, + } + commitC := make(chan consensus.EntryCommit) + go store.processCommits(ctx, commitC) + + consensusInstance, err := consensus.NewConsensus(ctx, logger, params, func() ([]byte, error) { return nil, nil }, commitC) + if err != nil { + return nil, stacktrace.Propagate(err, "failed to initialize consensus") } - // TODO: implement - sharedConsensus.RegisterStore("provider", func() ([]byte, error) { - return nil, nil - }) + store.Consensus = consensusInstance - return &Store[R]{ - newRepo: newRepo, - consensus: sharedConsensus, - }, nil + return store, nil } -// Transact proposes the entry to Raft and blocks until it is committed and applied. -func (s *Store[R]) Transact(ctx context.Context, f func(context.Context, R) error) error { - // TODO: implement - return nil +// Transact proposes an entry to Raft and blocks until it is committed and applied. +// The processCommits loop will call Apply on the proposal when it is committed. +func (s *Store[R]) Transact(ctx context.Context, requestType RequestType, payload any, _ func(context.Context, R) error) (any, error) { + return s.Consensus.ProposeValue(ctx, requestType, payload, s.raftRepo.IsReadOnly(requestType)) } -// Interact returns a repository that can be used to query the store without proposing a Raft entry. func (s *Store[R]) Interact(_ context.Context) (R, error) { - return s.newRepo(), nil + return s.raftRepo.GetRepo(), nil } -// Close shuts down the consensus instance. +// Close shuts down the consensus instance and processCommits loop. func (s *Store[R]) Close() error { - // TODO: implement + s.Consensus.Stop(context.Background()) + s.cancel() return nil } + +// processCommits reads committed entries from the consensus layer and applies them via Apply. +func (s *Store[R]) processCommits(ctx context.Context, commitCh <-chan consensus.EntryCommit) { + for { + select { + case <-ctx.Done(): + s.logger.Info("stopping commit processing loop") + return + case commit, ok := <-commitCh: + if !ok { + s.logger.Info("commit channel closed, stopping commit processing loop") + return + } + + if commit.SnapshotData != nil { + if err := s.raftRepo.RestoreFromSnapshot(commit.SnapshotData); err != nil { + s.logger.Error("failed to restore from snapshot", zap.Error(err)) + } + continue + } + + ctx = timestamp.WithTimestamp(ctx, commit.Prop.Timestamp) + result, err := s.raftRepo.Apply(ctx, commit.Prop) + commit.Done <- consensus.ProposalResult{Result: result, Error: err} + } + } +} diff --git a/pkg/rid/application/application_test.go b/pkg/rid/application/application_test.go index 0cc3fd885..11b785bf4 100644 --- a/pkg/rid/application/application_test.go +++ b/pkg/rid/application/application_test.go @@ -35,8 +35,8 @@ func (s *mockRepo) Interact(ctx context.Context) (repos.Repository, error) { return s, nil } -func (s *mockRepo) Transact(ctx context.Context, f func(ctx context.Context, repo repos.Repository) error) error { - return f(ctx, s) +func (s *mockRepo) Transact(ctx context.Context, _ string, _ any, f func(ctx context.Context, repo repos.Repository) error) (any, error) { + return nil, f(ctx, s) } func (s *mockRepo) Close() error { diff --git a/pkg/rid/application/isa.go b/pkg/rid/application/isa.go index ea402af60..c6a265419 100644 --- a/pkg/rid/application/isa.go +++ b/pkg/rid/application/isa.go @@ -63,7 +63,7 @@ func (a *app) DeleteISA(ctx context.Context, id dssmodels.ID, owner dssmodels.Ow subs []*ridmodels.Subscription ) // The following will automatically retry TXN retry errors. - err := a.store.Transact(ctx, func(ctx context.Context, repo repos.Repository) error { + _, err := a.store.Transact(ctx, "", nil, func(ctx context.Context, repo repos.Repository) error { old, err := repo.GetISA(ctx, id, true) switch { case err != nil: @@ -104,7 +104,7 @@ func (a *app) InsertISA(ctx context.Context, isa *ridmodels.IdentificationServic subs []*ridmodels.Subscription ) // The following will automatically retry TXN retry errors. - err := a.store.Transact(ctx, func(ctx context.Context, repo repos.Repository) error { + _, err := a.store.Transact(ctx, "", nil, func(ctx context.Context, repo repos.Repository) error { // ensure it doesn't exist yet old, err := repo.GetISA(ctx, isa.ID, false) if err != nil { @@ -138,7 +138,7 @@ func (a *app) UpdateISA(ctx context.Context, isa *ridmodels.IdentificationServic subs []*ridmodels.Subscription ) // The following will automatically retry TXN retry errors. - err := a.store.Transact(ctx, func(ctx context.Context, repo repos.Repository) error { + _, err := a.store.Transact(ctx, "", nil, func(ctx context.Context, repo repos.Repository) error { var err error old, err := repo.GetISA(ctx, isa.ID, true) diff --git a/pkg/rid/application/subscription.go b/pkg/rid/application/subscription.go index 33ba2af36..411d36ea2 100644 --- a/pkg/rid/application/subscription.go +++ b/pkg/rid/application/subscription.go @@ -60,7 +60,7 @@ func (a *app) InsertSubscription(ctx context.Context, s *ridmodels.Subscription) return nil, stacktrace.Propagate(err, "Unable to adjust time range") } var sub *ridmodels.Subscription - err := a.store.Transact(ctx, func(ctx context.Context, repo repos.Repository) error { + _, err := a.store.Transact(ctx, "", nil, func(ctx context.Context, repo repos.Repository) error { // ensure it doesn't exist yet old, err := repo.GetSubscription(ctx, s.ID) @@ -98,7 +98,7 @@ func (a *app) InsertSubscription(ctx context.Context, s *ridmodels.Subscription) func (a *app) UpdateSubscription(ctx context.Context, s *ridmodels.Subscription) (*ridmodels.Subscription, error) { var sub *ridmodels.Subscription - err := a.store.Transact(ctx, func(ctx context.Context, repo repos.Repository) error { + _, err := a.store.Transact(ctx, "", nil, func(ctx context.Context, repo repos.Repository) error { old, err := repo.GetSubscription(ctx, s.ID) switch { case err != nil: @@ -145,7 +145,7 @@ func (a *app) UpdateSubscription(ctx context.Context, s *ridmodels.Subscription) // DeleteSubscription deletes the Subscription identified by "id" and owned by "owner". func (a *app) DeleteSubscription(ctx context.Context, id dssmodels.ID, owner dssmodels.Owner, version *dssmodels.Version) (*ridmodels.Subscription, error) { var ret *ridmodels.Subscription - err := a.store.Transact(ctx, func(ctx context.Context, repo repos.Repository) error { + _, err := a.store.Transact(ctx, "", nil, func(ctx context.Context, repo repos.Repository) error { var err error old, err := repo.GetSubscription(ctx, id) switch { diff --git a/pkg/rid/store/memstore/doc.go b/pkg/rid/store/memstore/doc.go new file mode 100644 index 000000000..e42a19ad4 --- /dev/null +++ b/pkg/rid/store/memstore/doc.go @@ -0,0 +1,3 @@ +// Package rid.store.memstore provides a full implementation of store.Store[rid.repos.Repository] +// storing data in memory. It as meant to be used by raftstore. +package memstore diff --git a/pkg/rid/store/memstore/identification_service_area.go b/pkg/rid/store/memstore/identification_service_area.go new file mode 100644 index 000000000..29055426a --- /dev/null +++ b/pkg/rid/store/memstore/identification_service_area.go @@ -0,0 +1,152 @@ +package memstore + +import ( + "context" + "time" + + "github.com/golang/geo/s2" + dsserr "github.com/interuss/dss/pkg/errors" + dssmodels "github.com/interuss/dss/pkg/models" + ridmodels "github.com/interuss/dss/pkg/rid/models" + "github.com/interuss/dss/pkg/timestamp" + "github.com/interuss/stacktrace" +) + +func isaRecordFromModel(isa *ridmodels.IdentificationServiceArea, updatedAt time.Time) *isaRecord { + return &isaRecord{ + ID: isa.ID, + URL: isa.URL, + Owner: isa.Owner, + Cells: cloneCells(isa.Cells), + StartTime: cloneTime(isa.StartTime), + EndTime: cloneTime(isa.EndTime), + AltitudeHi: cloneFloat32(isa.AltitudeHi), + AltitudeLo: cloneFloat32(isa.AltitudeLo), + Writer: isa.Writer, + UpdatedAt: updatedAt, + } +} + +// toModel rebuilds the ISA model +func (rec *isaRecord) toModel() *ridmodels.IdentificationServiceArea { + return &ridmodels.IdentificationServiceArea{ + ID: rec.ID, + URL: rec.URL, + Owner: rec.Owner, + Cells: cloneCells(rec.Cells), + StartTime: cloneTime(rec.StartTime), + EndTime: cloneTime(rec.EndTime), + Version: dssmodels.VersionFromTime(rec.UpdatedAt), + AltitudeHi: cloneFloat32(rec.AltitudeHi), + AltitudeLo: cloneFloat32(rec.AltitudeLo), + Writer: rec.Writer, + } +} + +func (r *repo) GetISA(_ context.Context, id dssmodels.ID, _ bool) (*ridmodels.IdentificationServiceArea, error) { + rec, ok := r.state.ISAs[id] + if !ok { + return nil, nil + } + return rec.toModel(), nil +} + +func (r *repo) InsertISA(ctx context.Context, isa *ridmodels.IdentificationServiceArea) (*ridmodels.IdentificationServiceArea, error) { + if err := validateWriteData(isa.Cells, isa.StartTime, isa.EndTime); err != nil { + return nil, err + } + if _, ok := r.state.ISAs[isa.ID]; ok { + return nil, stacktrace.NewError("ISA with id %s already exists", isa.ID) + } + rec := isaRecordFromModel(isa, timestamp.NowFromContext(ctx)) + r.state.ISAs[isa.ID] = rec + return rec.toModel(), nil +} + +func (r *repo) UpdateISA(ctx context.Context, isa *ridmodels.IdentificationServiceArea) (*ridmodels.IdentificationServiceArea, error) { + if err := validateWriteData(isa.Cells, isa.StartTime, isa.EndTime); err != nil { + return nil, err + } + prev, ok := r.state.ISAs[isa.ID] + if !ok { + return nil, nil + } + if !dssmodels.VersionFromTime(prev.UpdatedAt).Matches(isa.Version) { + return nil, nil + } + rec := isaRecordFromModel(isa, timestamp.NowFromContext(ctx)) + rec.Owner = prev.Owner + r.state.ISAs[isa.ID] = rec + return rec.toModel(), nil +} + +func (r *repo) DeleteISA(_ context.Context, isa *ridmodels.IdentificationServiceArea) (*ridmodels.IdentificationServiceArea, error) { + rec, ok := r.state.ISAs[isa.ID] + if !ok { + return nil, nil + } + if !dssmodels.VersionFromTime(rec.UpdatedAt).Matches(isa.Version) { + return nil, nil + } + out := rec.toModel() + delete(r.state.ISAs, isa.ID) + return out, nil +} + +func (r *repo) SearchISAs(_ context.Context, cells s2.CellUnion, earliest *time.Time, latest *time.Time) ([]*ridmodels.IdentificationServiceArea, error) { + if len(cells) == 0 { + return nil, stacktrace.NewErrorWithCode(dsserr.BadRequest, "Missing cell IDs for query") + } + if earliest == nil { + return nil, stacktrace.NewError("Earliest start time is missing") + } + + want := cellSet(cells) + var out []*ridmodels.IdentificationServiceArea + for _, rec := range r.state.ISAs { + // ends_at >= earliest + if rec.EndTime == nil || rec.EndTime.Before(*earliest) { + continue + } + // COALESCE(starts_at <= latest, true) + if latest != nil && rec.StartTime != nil && rec.StartTime.After(*latest) { + continue + } + if !overlaps(rec.Cells, want) { + continue + } + out = append(out, rec.toModel()) + + if len(out) > dssmodels.MaxResultLimit { // This miminc sqlstore behaviour, but it's not very good. + break + } + } + return out, nil +} + +func (r *repo) ListExpiredISAs(_ context.Context, writer string, threshold time.Time) ([]*ridmodels.IdentificationServiceArea, error) { + var out []*ridmodels.IdentificationServiceArea + for _, rec := range r.state.ISAs { + // ends_at <= threshold + if rec.EndTime == nil || rec.EndTime.After(threshold) { + continue + } + if writer == "" { + if rec.Writer != "" { + continue + } + } else if rec.Writer != writer { + continue + } + out = append(out, rec.toModel()) + + if len(out) > dssmodels.MaxResultLimit { // This miminc sqlstore behaviour, but it's not very good. + break + } + } + return out, nil +} + +func (r *repo) CountISAs(_ context.Context) (int64, error) { + return int64(len(r.state.ISAs)), nil +} diff --git a/pkg/rid/store/memstore/identification_service_area_test.go b/pkg/rid/store/memstore/identification_service_area_test.go new file mode 100644 index 000000000..70699e415 --- /dev/null +++ b/pkg/rid/store/memstore/identification_service_area_test.go @@ -0,0 +1,321 @@ +package memstore + +import ( + "context" + "testing" + "time" + + "github.com/golang/geo/s2" + "github.com/google/uuid" + dssmodels "github.com/interuss/dss/pkg/models" + ridmodels "github.com/interuss/dss/pkg/rid/models" + "github.com/interuss/dss/pkg/rid/repos" + "github.com/jonboulle/clockwork" + "github.com/stretchr/testify/require" +) + +var ( + // Ensure the struct conforms to the interface + _ repos.ISA = &repo{} + overflow = uint64(17106221850767130624) // face 5 L13 overflows + serviceArea = &ridmodels.IdentificationServiceArea{ + ID: dssmodels.ID(uuid.New().String()), + Owner: dssmodels.Owner(uuid.New().String()), + URL: "https://no/place/like/home/for/flights", + StartTime: &startTime, + EndTime: &endTime, + Writer: writer, + Cells: s2.CellUnion{ + s2.CellID(uint64(overflow)), + s2.CellID(17106221850767130624), + }, + } +) + +func TestStoreSearchISAs(t *testing.T) { + ctx := context.Background() + cells := s2.CellUnion{ + s2.CellID(17106221850767130624), + s2.CellID(17106221885126868992), + s2.CellID(17106221919486607360), + s2.CellID(uint64(overflow)), + } + repo := setUpStore(t) + + isa := *serviceArea + isa.Cells = cells + saOut, err := repo.InsertISA(ctx, &isa) + require.NoError(t, err) + require.NotNil(t, saOut) + require.Equal(t, isa.ID, saOut.ID) + + for _, r := range []struct { + name string + cells s2.CellUnion + timestampMutator func(time.Time, time.Time) (*time.Time, *time.Time) + expectedLen int + }{ + { + name: "search for empty cell", + cells: s2.CellUnion{s2.CellID(17106221953846345728)}, + timestampMutator: func(start time.Time, end time.Time) (*time.Time, *time.Time) { + return &start, nil + }, + expectedLen: 0, + }, + { + name: "search for only one cell", + cells: s2.CellUnion{s2.CellID(17106221850767130624)}, + timestampMutator: func(start time.Time, end time.Time) (*time.Time, *time.Time) { + return &start, nil + }, + expectedLen: 1, + }, + { + name: "search for only one cell with high bit set", + cells: s2.CellUnion{s2.CellID(uint64(overflow))}, + timestampMutator: func(start time.Time, end time.Time) (*time.Time, *time.Time) { + return &start, nil + }, + expectedLen: 1, + }, + { + name: "search with nil ends_at", + cells: cells, + timestampMutator: func(start time.Time, end time.Time) (*time.Time, *time.Time) { + return &start, nil + }, + expectedLen: 1, + }, + { + name: "search with exact timestamps", + cells: cells, + timestampMutator: func(start time.Time, end time.Time) (*time.Time, *time.Time) { + return &start, &end + }, + expectedLen: 1, + }, + { + name: "search with non-matching time span", + cells: cells, + timestampMutator: func(start time.Time, end time.Time) (*time.Time, *time.Time) { + var ( + offset = time.Duration(100 * time.Second) + earliest = end.Add(offset) + latest = end.Add(offset * 2) + ) + + return &earliest, &latest + }, + expectedLen: 0, + }, + { + name: "search with expanded time span", + cells: cells, + timestampMutator: func(start time.Time, end time.Time) (*time.Time, *time.Time) { + var ( + offset = time.Duration(100 * time.Second) + earliest = start.Add(-offset) + latest = end.Add(offset) + ) + + return &earliest, &latest + }, + expectedLen: 1, + }, + } { + t.Run(r.name, func(t *testing.T) { + earliest, latest := r.timestampMutator(*saOut.StartTime, *saOut.EndTime) + + serviceAreas, err := repo.SearchISAs(ctx, r.cells, earliest, latest) + require.NoError(t, err) + require.Len(t, serviceAreas, r.expectedLen) + }) + } +} + +func TestBadVersion(t *testing.T) { + ctx := context.Background() + repo := setUpStore(t) + + saOut1, err := repo.InsertISA(ctx, serviceArea) + require.NoError(t, err) + require.NotNil(t, saOut1) + + // Rewriting service area should fail + saOut2, err := repo.UpdateISA(ctx, serviceArea) + require.NoError(t, err) + require.Nil(t, saOut2) + + // Rewriting, but with the correct version should work. + newEndTime := saOut1.EndTime.Add(time.Minute) + saOut1.EndTime = &newEndTime + saOut3, err := repo.UpdateISA(ctx, saOut1) + require.NoError(t, err) + require.NotNil(t, saOut3) +} + +func TestStoreExpiredISA(t *testing.T) { + ctx := context.Background() + repo := setUpStore(t) + + saOut, err := repo.InsertISA(ctx, serviceArea) + require.NoError(t, err) + require.NotNil(t, saOut) + + // The ISA's endTime is one hour from now. + fakeClock.Advance(59 * time.Minute) + + // We should still be able to find the ISA by searching and by ID. + now := fakeClock.Now() + serviceAreas, err := repo.SearchISAs(ctx, serviceArea.Cells, &now, nil) + require.NoError(t, err) + require.Len(t, serviceAreas, 1) + + ret, err := repo.GetISA(ctx, serviceArea.ID, false) + require.NoError(t, err) + require.NotNil(t, ret) + + // But now the ISA has expired. + fakeClock.Advance(2 * time.Minute) + now = fakeClock.Now() + + serviceAreas, err = repo.SearchISAs(ctx, serviceArea.Cells, &now, nil) + require.NoError(t, err) + require.Len(t, serviceAreas, 0) + + // A get should work even if it is expired. + ret, err = repo.GetISA(ctx, serviceArea.ID, false) + require.NoError(t, err) + require.NotNil(t, ret) +} + +func TestStoreDeleteISAs(t *testing.T) { + ctx := context.Background() + repo := setUpStore(t) + + // Insert the ISA. + copy := *serviceArea + isa, err := repo.InsertISA(ctx, ©) + require.NoError(t, err) + require.NotNil(t, isa) + + // Delete the ISA. + // Ensure a fresh Get, then delete still updates the sub indexes + isa, err = repo.GetISA(ctx, isa.ID, false) + require.NoError(t, err) + + serviceAreaOut, err := repo.DeleteISA(ctx, isa) + require.NoError(t, err) + require.Equal(t, isa, serviceAreaOut) +} + +func TestStoreISAWithNoGeoData(t *testing.T) { + ctx := context.Background() + repo := setUpStore(t) + + endTime := fakeClock.Now().Add(24 * time.Hour) + sub := &ridmodels.IdentificationServiceArea{ + ID: dssmodels.ID(uuid.New().String()), + Owner: dssmodels.Owner("original owner"), + EndTime: &endTime, + } + _, err := repo.InsertISA(ctx, sub) + require.Error(t, err) +} + +func TestListExpiredISAs(t *testing.T) { + ctx := context.Background() + repo := setUpStore(t) + + fakeClock := clockwork.NewFakeClockAt(time.Now()) + + // Insert ISA with endtime 1 day from now + isa1 := *serviceArea + startTime := fakeClock.Now() + isa1.StartTime = &startTime + endTime := fakeClock.Now().Add(24 * time.Hour) + isa1.EndTime = &endTime + saOut1, err := repo.InsertISA(ctx, &isa1) + require.NoError(t, err) + require.NotNil(t, saOut1) + + // Insert ISA with endtime to 30 minutes ago + isa2 := *serviceArea + startTime = fakeClock.Now().Add(-1 * time.Hour) + isa2.StartTime = &startTime + endTime = fakeClock.Now().Add(-30 * time.Minute) + isa2.EndTime = &endTime + isa2.ID = dssmodels.ID(uuid.New().String()) + saOut2, err := repo.InsertISA(ctx, &isa2) + require.NoError(t, err) + require.NotNil(t, saOut2) + + serviceAreas, err := repo.ListExpiredISAs(ctx, writer, fakeClock.Now().Add(-30*time.Minute)) + require.NoError(t, err) + require.Len(t, serviceAreas, 1) +} + +func TestListExpiredISAsWithEmptyWriter(t *testing.T) { + ctx := context.Background() + repo := setUpStore(t) + + fakeClock := clockwork.NewFakeClockAt(time.Now()) + + // Insert ISA with endtime 1 day from now + isa1 := *serviceArea + startTime := fakeClock.Now() + isa1.StartTime = &startTime + endTime := fakeClock.Now().Add(24 * time.Hour) + isa1.EndTime = &endTime + isa1.Writer = "" + saOut1, err := repo.InsertISA(ctx, &isa1) + require.NoError(t, err) + require.NotNil(t, saOut1) + + // Insert ISA with endtime to 30 minutes ago + isa2 := *serviceArea + startTime = fakeClock.Now().Add(-1 * time.Hour) + isa2.StartTime = &startTime + endTime = fakeClock.Now().Add(-30 * time.Minute) + isa2.EndTime = &endTime + isa2.ID = dssmodels.ID(uuid.New().String()) + isa2.Writer = "" + saOut2, err := repo.InsertISA(ctx, &isa2) + require.NoError(t, err) + require.NotNil(t, saOut2) + + serviceAreas, err := repo.ListExpiredISAs(ctx, "", fakeClock.Now().Add(-30*time.Minute)) + require.NoError(t, err) + require.Len(t, serviceAreas, 1) +} + +func TestStoreCountISAs(t *testing.T) { + ctx := context.Background() + repo := setUpStore(t) + + // Insert the ISA. + copy := *serviceArea + isa, err := repo.InsertISA(ctx, ©) + require.NoError(t, err) + require.NotNil(t, isa) + + //Cound should be one + count, err := repo.CountISAs(ctx) + require.NoError(t, err) + require.Equal(t, count, int64(1)) + + // Delete the ISA. + // Ensure a fresh Get, then delete still updates the sub indexes + isa, err = repo.GetISA(ctx, isa.ID, false) + require.NoError(t, err) + + serviceAreaOut, err := repo.DeleteISA(ctx, isa) + require.NoError(t, err) + require.Equal(t, isa, serviceAreaOut) + + //Cound should be zero + count, err = repo.CountISAs(ctx) + require.NoError(t, err) + require.Equal(t, count, int64(0)) +} diff --git a/pkg/rid/store/memstore/snapshot.go b/pkg/rid/store/memstore/snapshot.go new file mode 100644 index 000000000..fe2aa2c70 --- /dev/null +++ b/pkg/rid/store/memstore/snapshot.go @@ -0,0 +1,43 @@ +package memstore + +import ( + "bytes" + "encoding/gob" + + dssmodels "github.com/interuss/dss/pkg/models" + "github.com/interuss/stacktrace" +) + +const snapshotVersion = 1 + +type snapshotEnvelope struct { + Version int + State state +} + +func (r *repo) GetSnapshot() ([]byte, error) { + var buf bytes.Buffer + if err := gob.NewEncoder(&buf).Encode(snapshotEnvelope{Version: snapshotVersion, State: r.state}); err != nil { + return nil, stacktrace.Propagate(err, "Failed to encode memstore snapshot") + } + return buf.Bytes(), nil +} + +func (r *repo) RestoreFromSnapshot(data []byte) error { + var env snapshotEnvelope + if err := gob.NewDecoder(bytes.NewReader(data)).Decode(&env); err != nil { + return stacktrace.Propagate(err, "Failed to decode memstore snapshot") + } + if env.Version != snapshotVersion { + return stacktrace.NewError("Unsupported memstore snapshot version %d, expected %d", env.Version, snapshotVersion) + } + r.state = env.State + // gob decodes an empty map as nil; re-initialize to keep the repo writable. + if r.state.ISAs == nil { + r.state.ISAs = map[dssmodels.ID]*isaRecord{} + } + if r.state.Subscriptions == nil { + r.state.Subscriptions = map[dssmodels.ID]*subscriptionRecord{} + } + return nil +} diff --git a/pkg/rid/store/memstore/snapshot_test.go b/pkg/rid/store/memstore/snapshot_test.go new file mode 100644 index 000000000..a9764f812 --- /dev/null +++ b/pkg/rid/store/memstore/snapshot_test.go @@ -0,0 +1,82 @@ +package memstore + +import ( + "bytes" + "context" + "encoding/gob" + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "github.com/interuss/dss/pkg/models" + "github.com/stretchr/testify/require" +) + +func TestSnapshotRoundTrip(t *testing.T) { + ctx := context.Background() + src := setUpStore(t) + _, err := src.InsertISA(ctx, serviceArea) + require.NoError(t, err) + _, err = src.InsertSubscription(ctx, subscriptionsPool[0].input) + require.NoError(t, err) + + data, err := src.GetSnapshot() + require.NoError(t, err) + + dst := setUpStore(t) + require.NoError(t, dst.RestoreFromSnapshot(data)) + + wantISA, err := src.GetISA(ctx, serviceArea.ID, false) + require.NoError(t, err) + gotISA, err := dst.GetISA(ctx, serviceArea.ID, false) + require.NoError(t, err) + + if diff := cmp.Diff(wantISA, gotISA, cmpopts.EquateApproxTime(0), cmp.AllowUnexported(models.Version{})); diff != "" { + t.Errorf("IdentificationServiceArea mismatch (-want +got):\n%s", diff) + } + + wantSub, err := src.GetSubscription(ctx, subscriptionsPool[0].input.ID) + require.NoError(t, err) + gotSub, err := dst.GetSubscription(ctx, subscriptionsPool[0].input.ID) + require.NoError(t, err) + + if diff := cmp.Diff(wantSub, gotSub, cmpopts.EquateApproxTime(0), cmp.AllowUnexported(models.Version{})); diff != "" { + t.Errorf("Subscription mismatch (-want +got):\n%s", diff) + } +} + +func TestRestoreFromSnapshotReplacesState(t *testing.T) { + ctx := context.Background() + src := setUpStore(t) + _, err := src.InsertISA(ctx, serviceArea) + require.NoError(t, err) + data, err := src.GetSnapshot() + require.NoError(t, err) + + dst := setUpStore(t) + other := *serviceArea + other.ID = "00000000-0000-4000-8000-000000000002" + _, err = dst.InsertISA(ctx, &other) + require.NoError(t, err) + require.NoError(t, dst.RestoreFromSnapshot(data)) + + count, err := dst.CountISAs(ctx) + require.NoError(t, err) + require.Equal(t, int64(1), count) + got, err := dst.GetISA(ctx, serviceArea.ID, false) + require.NoError(t, err) + require.NotNil(t, got) + gone, err := dst.GetISA(ctx, other.ID, false) + require.NoError(t, err) + require.Nil(t, gone) +} + +func TestRestoreFromSnapshotInvalidData(t *testing.T) { + require.Error(t, setUpStore(t).RestoreFromSnapshot([]byte("random value that is definitely not valid"))) +} + +func TestRestoreFromSnapshotVersionMismatch(t *testing.T) { + var buf bytes.Buffer + require.NoError(t, gob.NewEncoder(&buf).Encode(snapshotEnvelope{Version: snapshotVersion + 1})) + require.Error(t, setUpStore(t).RestoreFromSnapshot(buf.Bytes())) +} diff --git a/pkg/rid/store/memstore/store.go b/pkg/rid/store/memstore/store.go new file mode 100644 index 000000000..7cd02e161 --- /dev/null +++ b/pkg/rid/store/memstore/store.go @@ -0,0 +1,136 @@ +package memstore + +import ( + "context" + "time" + + "github.com/golang/geo/s2" + "github.com/interuss/dss/pkg/geo" + "github.com/interuss/dss/pkg/memstore" + dssmodels "github.com/interuss/dss/pkg/models" + "github.com/interuss/dss/pkg/rid/repos" + "github.com/interuss/stacktrace" + "go.uber.org/zap" +) + +// repo is a full implementation of rid.repos.Repository for memory-based storage. +type repo struct { + state state +} + +// state is the serializable in-memory state. +type state struct { + // ISAs holds the stored ISAs keyed by ID. + ISAs map[dssmodels.ID]*isaRecord + // Subscriptions holds the stored subscriptions keyed by ID. + Subscriptions map[dssmodels.ID]*subscriptionRecord +} + +// isaRecord is the gob-serializable representation of an ISA. It intentionally +// stores only primitive fields: the model's Version is never persisted, it is +// derived from UpdatedAt on read. +type isaRecord struct { + ID dssmodels.ID + URL string + Owner dssmodels.Owner + Cells s2.CellUnion + StartTime *time.Time + EndTime *time.Time + AltitudeHi *float32 + AltitudeLo *float32 + Writer string + UpdatedAt time.Time +} + +// subscriptionRecord is the gob-serializable representation of a Subscription. +type subscriptionRecord struct { + ID dssmodels.ID + URL string + NotificationIndex int + Owner dssmodels.Owner + Cells s2.CellUnion + StartTime *time.Time + EndTime *time.Time + AltitudeHi *float32 + AltitudeLo *float32 + Writer string + UpdatedAt time.Time +} + +func newRepo() *repo { + r := &repo{} + r.resetState() + return r +} + +func (r *repo) resetState() { + r.state = state{ + ISAs: map[dssmodels.ID]*isaRecord{}, + Subscriptions: map[dssmodels.ID]*subscriptionRecord{}, + } +} + +func Init(ctx context.Context, logger *zap.Logger) (*memstore.Store[repos.Repository], error) { + return memstore.Init(ctx, logger, "rid", newRepo()) +} + +func (r *repo) GetRepo() repos.Repository { return r } + +// validateWriteData validate constraints on an ISA +func validateWriteData(cells s2.CellUnion, start, end *time.Time) error { + if len(cells) == 0 { + return stacktrace.NewError("At least one cell must be provided") + } + for _, c := range cells { + if err := geo.ValidateCell(c); err != nil { + return stacktrace.Propagate(err, "Error validating cell") + } + } + if start != nil && end != nil && !start.Before(*end) { + return stacktrace.NewError("Start time must be strictly before end time") + } + return nil +} + +// cellSet builds a lookup set from a cell union. +func cellSet(cells s2.CellUnion) map[s2.CellID]struct{} { + set := make(map[s2.CellID]struct{}, len(cells)) + for _, c := range cells { + set[c] = struct{}{} + } + return set +} + +// overlaps reports whether any cell is present in set (equivalent to the SQL +// "cells && $x" array-overlap operator). +func overlaps(cells s2.CellUnion, set map[s2.CellID]struct{}) bool { + for _, c := range cells { + if _, ok := set[c]; ok { + return true + } + } + return false +} + +func cloneCells(cells s2.CellUnion) s2.CellUnion { + if cells == nil { + return nil + } + return append(s2.CellUnion(nil), cells...) +} + +func cloneTime(t *time.Time) *time.Time { + if t == nil { + return nil + } + v := *t + return &v +} + +func cloneFloat32(f *float32) *float32 { + if f == nil { + return nil + } + v := *f + return &v +} diff --git a/pkg/rid/store/memstore/store_test.go b/pkg/rid/store/memstore/store_test.go new file mode 100644 index 000000000..3bf484463 --- /dev/null +++ b/pkg/rid/store/memstore/store_test.go @@ -0,0 +1,47 @@ +package memstore + +import ( + "context" + "testing" + "time" + + "github.com/google/uuid" + dssmodels "github.com/interuss/dss/pkg/models" + ridmodels "github.com/interuss/dss/pkg/rid/models" + "github.com/jonboulle/clockwork" + "github.com/stretchr/testify/require" +) + +var ( + fakeClock = clockwork.NewFakeClock() + startTime = fakeClock.Now().UTC().Add(-time.Minute) + endTime = fakeClock.Now().UTC().Add(time.Hour) + writer = "writer" +) + +// setUpStore returns a fresh in-memory repo whose clock is the (reset) package +// fakeClock, so tests can advance time deterministically. +func setUpStore(t *testing.T) *repo { + t.Helper() + r := newRepo() + return r +} + +func TestDatabaseEnsuresBeginsBeforeExpires(t *testing.T) { + ctx := context.Background() + repo := setUpStore(t) + + var ( + begins = time.Now().UTC() + expires = begins.Add(-5 * time.Minute) + ) + _, err := repo.InsertSubscription(ctx, &ridmodels.Subscription{ + ID: dssmodels.ID(uuid.New().String()), + Owner: "me-myself-and-i", + URL: "https://no/place/like/home", + NotificationIndex: 42, + StartTime: &begins, + EndTime: &expires, + }) + require.Error(t, err) +} diff --git a/pkg/rid/store/memstore/subscriptions.go b/pkg/rid/store/memstore/subscriptions.go new file mode 100644 index 000000000..5923c3927 --- /dev/null +++ b/pkg/rid/store/memstore/subscriptions.go @@ -0,0 +1,214 @@ +package memstore + +import ( + "context" + "time" + + "github.com/golang/geo/s2" + dsserr "github.com/interuss/dss/pkg/errors" + dssmodels "github.com/interuss/dss/pkg/models" + ridmodels "github.com/interuss/dss/pkg/rid/models" + "github.com/interuss/dss/pkg/timestamp" + "github.com/interuss/stacktrace" +) + +func subRecordFromModel(s *ridmodels.Subscription, updatedAt time.Time) *subscriptionRecord { + return &subscriptionRecord{ + ID: s.ID, + URL: s.URL, + NotificationIndex: s.NotificationIndex, + Owner: s.Owner, + Cells: cloneCells(s.Cells), + StartTime: cloneTime(s.StartTime), + EndTime: cloneTime(s.EndTime), + AltitudeHi: cloneFloat32(s.AltitudeHi), + AltitudeLo: cloneFloat32(s.AltitudeLo), + Writer: s.Writer, + UpdatedAt: updatedAt, + } +} + +func (rec *subscriptionRecord) toModel() *ridmodels.Subscription { + return &ridmodels.Subscription{ + ID: rec.ID, + URL: rec.URL, + NotificationIndex: rec.NotificationIndex, + Owner: rec.Owner, + Cells: cloneCells(rec.Cells), + StartTime: cloneTime(rec.StartTime), + EndTime: cloneTime(rec.EndTime), + Version: dssmodels.VersionFromTime(rec.UpdatedAt), + AltitudeHi: cloneFloat32(rec.AltitudeHi), + AltitudeLo: cloneFloat32(rec.AltitudeLo), + Writer: rec.Writer, + } +} + +func (r *repo) GetSubscription(_ context.Context, id dssmodels.ID) (*ridmodels.Subscription, error) { + rec, ok := r.state.Subscriptions[id] + if !ok { + return nil, nil + } + return rec.toModel(), nil +} + +func (r *repo) InsertSubscription(ctx context.Context, s *ridmodels.Subscription) (*ridmodels.Subscription, error) { + if err := validateWriteData(s.Cells, s.StartTime, s.EndTime); err != nil { + return nil, err + } + if _, ok := r.state.Subscriptions[s.ID]; ok { + return nil, stacktrace.NewError("Subscription with id %s already exists", s.ID) + } + rec := subRecordFromModel(s, timestamp.NowFromContext(ctx)) + r.state.Subscriptions[s.ID] = rec + return rec.toModel(), nil +} + +func (r *repo) UpdateSubscription(ctx context.Context, s *ridmodels.Subscription) (*ridmodels.Subscription, error) { + if err := validateWriteData(s.Cells, s.StartTime, s.EndTime); err != nil { + return nil, err + } + prev, ok := r.state.Subscriptions[s.ID] + if !ok { + return nil, nil + } + if !dssmodels.VersionFromTime(prev.UpdatedAt).Matches(s.Version) { + return nil, nil + } + rec := subRecordFromModel(s, timestamp.NowFromContext(ctx)) + rec.Owner = prev.Owner + r.state.Subscriptions[s.ID] = rec + return rec.toModel(), nil +} + +func (r *repo) DeleteSubscription(_ context.Context, s *ridmodels.Subscription) (*ridmodels.Subscription, error) { + rec, ok := r.state.Subscriptions[s.ID] + if !ok { + return nil, nil + } + if !dssmodels.VersionFromTime(rec.UpdatedAt).Matches(s.Version) { + return nil, nil + } + out := rec.toModel() + delete(r.state.Subscriptions, s.ID) + return out, nil +} + +func (r *repo) SearchSubscriptions(ctx context.Context, cells s2.CellUnion) ([]*ridmodels.Subscription, error) { + if len(cells) == 0 { + return nil, stacktrace.NewErrorWithCode(dsserr.BadRequest, "no location provided") + } + now := timestamp.NowFromContext(ctx) + want := cellSet(cells) + var out []*ridmodels.Subscription + for _, rec := range r.state.Subscriptions { + if rec.EndTime == nil || rec.EndTime.Before(now) { + continue + } + if !overlaps(rec.Cells, want) { + continue + } + out = append(out, rec.toModel()) + + if len(out) > dssmodels.MaxResultLimit { // This miminc sqlstore behaviour, but it's not very good. + break + } + } + return out, nil +} + +func (r *repo) SearchSubscriptionsByOwner(ctx context.Context, cells s2.CellUnion, owner dssmodels.Owner) ([]*ridmodels.Subscription, error) { + if len(cells) == 0 { + return nil, stacktrace.NewErrorWithCode(dsserr.BadRequest, "no location provided") + } + now := timestamp.NowFromContext(ctx) + want := cellSet(cells) + var out []*ridmodels.Subscription + for _, rec := range r.state.Subscriptions { + if rec.Owner != owner { + continue + } + if rec.EndTime == nil || rec.EndTime.Before(now) { + continue + } + if !overlaps(rec.Cells, want) { + continue + } + out = append(out, rec.toModel()) + + if len(out) > dssmodels.MaxResultLimit { // This miminc sqlstore behaviour, but it's not very good. + break + } + } + return out, nil +} + +// UpdateNotificationIdxsInCells increments the notification index for each +// subscription in the given cells. +func (r *repo) UpdateNotificationIdxsInCells(ctx context.Context, cells s2.CellUnion) ([]*ridmodels.Subscription, error) { + now := timestamp.NowFromContext(ctx) + want := cellSet(cells) + var out []*ridmodels.Subscription + for _, rec := range r.state.Subscriptions { + if rec.EndTime == nil || rec.EndTime.Before(now) { + continue + } + if !overlaps(rec.Cells, want) { + continue + } + rec.NotificationIndex++ + out = append(out, rec.toModel()) + } + return out, nil +} + +func (r *repo) MaxSubscriptionCountInCellsByOwner(ctx context.Context, cells s2.CellUnion, owner dssmodels.Owner) (int, error) { + now := timestamp.NowFromContext(ctx) + want := cellSet(cells) + counts := make(map[s2.CellID]int, len(cells)) + for _, rec := range r.state.Subscriptions { + if rec.Owner != owner { + continue + } + if rec.EndTime == nil || rec.EndTime.Before(now) { + continue + } + for _, c := range rec.Cells { + if _, ok := want[c]; ok { + counts[c]++ + } + } + } + best := 0 + for _, n := range counts { + if n > best { + best = n + } + } + return best, nil +} + +func (r *repo) ListExpiredSubscriptions(_ context.Context, writer string, threshold time.Time) ([]*ridmodels.Subscription, error) { + var out []*ridmodels.Subscription + for _, rec := range r.state.Subscriptions { + // ends_at <= threshold + if rec.EndTime == nil || rec.EndTime.After(threshold) { + continue + } + if writer == "" { + if rec.Writer != "" { + continue + } + } else if rec.Writer != writer { + continue + } + out = append(out, rec.toModel()) + + // TODO: This miminc sqlstore inconsistency of not limiting results there, comparted to ISAs. Should it be normalized? + } + return out, nil +} + +func (r *repo) CountSubscriptions(_ context.Context) (int64, error) { + return int64(len(r.state.Subscriptions)), nil +} diff --git a/pkg/rid/store/memstore/subscriptions_test.go b/pkg/rid/store/memstore/subscriptions_test.go new file mode 100644 index 000000000..3b90cf0e0 --- /dev/null +++ b/pkg/rid/store/memstore/subscriptions_test.go @@ -0,0 +1,363 @@ +package memstore + +import ( + "context" + "testing" + "time" + + "github.com/golang/geo/s2" + "github.com/google/uuid" + dssmodels "github.com/interuss/dss/pkg/models" + ridmodels "github.com/interuss/dss/pkg/rid/models" + "github.com/interuss/dss/pkg/rid/repos" + "github.com/interuss/dss/pkg/timestamp" + "github.com/jonboulle/clockwork" + "github.com/stretchr/testify/require" +) + +var ( + // Ensure the struct conforms to the interface + _ repos.Subscription = &repo{} + subscriptionsPool = []struct { + name string + input *ridmodels.Subscription + }{ + { + name: "a subscription with startTime and endTime", + input: &ridmodels.Subscription{ + ID: dssmodels.ID(uuid.New().String()), + Owner: "myself", + URL: "https://no/place/like/home", + StartTime: &startTime, + EndTime: &endTime, + NotificationIndex: 42, + Writer: writer, + Cells: s2.CellUnion{ + s2.CellID(uint64(overflow)), + 12494535935418957824, + }, + }, + }, + { + name: "a subscription without startTime and with endTime", + input: &ridmodels.Subscription{ + ID: dssmodels.ID(uuid.New().String()), + Owner: "myself", + URL: "https://no/place/like/home", + EndTime: &endTime, + NotificationIndex: 42, + Cells: s2.CellUnion{ + 12494535935418957824, + }, + }, + }, + { + name: "a subscription without startTime and with endTime", + input: &ridmodels.Subscription{ + ID: dssmodels.ID(uuid.New().String()), + Owner: "me", + URL: "https://no/place/like/home", + StartTime: &startTime, + EndTime: &endTime, + NotificationIndex: 42, + Cells: s2.CellUnion{ + 12494535935418957824, + }, + }, + }, + } +) + +func TestStoreGetSubscription(t *testing.T) { + ctx := context.Background() + repo := setUpStore(t) + + for _, r := range subscriptionsPool { + t.Run(r.name, func(t *testing.T) { + sub1, err := repo.InsertSubscription(ctx, r.input) + require.NoError(t, err) + require.NotNil(t, sub1) + + sub2, err := repo.GetSubscription(ctx, sub1.ID) + require.NoError(t, err) + require.NotNil(t, sub2) + + require.Equal(t, *sub1, *sub2) + }) + } +} + +func TestStoreInsertSubscription(t *testing.T) { + ctx := context.Background() + repo := setUpStore(t) + + for _, r := range subscriptionsPool { + t.Run(r.name, func(t *testing.T) { + sub1, err := repo.InsertSubscription(ctx, r.input) + require.NoError(t, err) + require.NotNil(t, sub1) + + // Test changes without the version differing. + r2 := *sub1 + r2.URL = "new url" + sub2, err := repo.UpdateSubscription(ctx, &r2) + require.NoError(t, err) + require.NotNil(t, sub2) + require.Equal(t, "new url", sub2.URL) + + // Test it doesn't work when Version is nil. + r3 := *sub2 + r3.URL = "new url 2" + r3.Version = nil + sub3, err := repo.UpdateSubscription(ctx, &r3) + require.NoError(t, err) + require.Nil(t, sub3) + + // Bad version doesn't work. + r4 := *sub2 + r4.URL = "new url 3" + r4.Version = dssmodels.VersionFromTime(time.Now()) + sub4, err := repo.UpdateSubscription(ctx, &r4) + require.NoError(t, err) + require.Nil(t, sub4) + + sub5, err := repo.GetSubscription(ctx, sub1.ID) + require.NoError(t, err) + require.NotNil(t, sub5) + + require.Equal(t, *sub2, *sub5) + }) + } +} + +func TestStoreDeleteSubscription(t *testing.T) { + ctx := context.Background() + repo := setUpStore(t) + + for _, r := range subscriptionsPool { + t.Run(r.name, func(t *testing.T) { + sub1, err := repo.InsertSubscription(ctx, r.input) + require.NoError(t, err) + require.NotNil(t, sub1) + + // Ensure mismatched versions returns nothing + sub1BadVersion := *sub1 + sub1BadVersion.Version, err = dssmodels.VersionFromString("a3cg3tcuhk00") + require.NoError(t, err) + sub2, err := repo.DeleteSubscription(ctx, &sub1BadVersion) + require.NoError(t, err) + require.Nil(t, sub2) + + sub4, err := repo.DeleteSubscription(ctx, sub1) + require.NoError(t, err) + require.NotNil(t, sub4) + + require.Equal(t, *sub1, *sub4) + }) + } +} + +func TestStoreSearchSubscription(t *testing.T) { + ctx := context.Background() + repo := setUpStore(t) + + var ( + // pick an L13 value that overflows. + overflow = uint64(17106221850767130624) + + cells = s2.CellUnion{ + s2.CellID(12494535935418957824), + s2.CellID(12494535866699481088), + s2.CellID(12494535901059219456), + s2.CellID(12494535866699481088), + s2.CellID(overflow), + } + owners = []dssmodels.Owner{ + "me", + "my", + "self", + } + ) + + for i, r := range subscriptionsPool { + subscription := *r.input + subscription.Owner = owners[i] + subscription.Cells = cells[:i+1] + sub1, err := repo.InsertSubscription(ctx, &subscription) + require.NoError(t, err) + require.NotNil(t, sub1) + } + // Test normal search + found, err := repo.SearchSubscriptions(ctx, cells) + require.NoError(t, err) + require.Len(t, found, 3) + for _, owner := range owners { + found, err := repo.SearchSubscriptionsByOwner(ctx, cells, owner) + require.NoError(t, err) + require.NotNil(t, found) + // We insert one subscription per owner. Hence, no matter how many cells are touched by the subscription, + // the result should always be 1. + require.Len(t, found, 1) + } +} + +func TestStoreExpiredSubscription(t *testing.T) { + ctx := context.Background() + repo := setUpStore(t) + + endTime := fakeClock.Now().Add(24 * time.Hour) + sub := &ridmodels.Subscription{ + ID: dssmodels.ID(uuid.New().String()), + Owner: dssmodels.Owner("original owner"), + Cells: s2.CellUnion{s2.CellID(12494535866699481088)}, + EndTime: &endTime, + } + _, err := repo.InsertSubscription(ctx, sub) + require.NoError(t, err) + + // The subscription's endTime is 24 hours from now. + fakeClock.Advance(23 * time.Hour) + ctx = timestamp.WithTimestamp(ctx, fakeClock.Now()) + + // We should still be able to find the subscription by searching and by ID. + subs, err := repo.SearchSubscriptionsByOwner(ctx, sub.Cells, "original owner") + require.NoError(t, err) + require.Len(t, subs, 1) + + ret, err := repo.GetSubscription(ctx, sub.ID) + require.NoError(t, err) + require.NotNil(t, &ret) + + // But now the subscription has expired. + fakeClock.Advance(2 * time.Hour) + ctx = timestamp.WithTimestamp(ctx, fakeClock.Now()) + + subs, err = repo.SearchSubscriptionsByOwner(ctx, sub.Cells, "original owner") + require.NoError(t, err) + require.Len(t, subs, 0) + + ret, err = repo.GetSubscription(ctx, sub.ID) + require.NotNil(t, ret) + require.NoError(t, err) +} + +func TestStoreSubscriptionWithNoGeoData(t *testing.T) { + ctx := context.Background() + repo := setUpStore(t) + + endTime := fakeClock.Now().Add(24 * time.Hour) + sub := &ridmodels.Subscription{ + ID: dssmodels.ID(uuid.New().String()), + Owner: dssmodels.Owner("original owner"), + EndTime: &endTime, + } + _, err := repo.InsertSubscription(ctx, sub) + require.Error(t, err) +} + +func TestMaxSubscriptionCountInCellsByOwner(t *testing.T) { + ctx := context.Background() + repo := setUpStore(t) + + for _, s := range subscriptionsPool { + _, err := repo.InsertSubscription(ctx, s.input) + require.NoError(t, err) + } + + count, err := repo.MaxSubscriptionCountInCellsByOwner(ctx, s2.CellUnion{12494535935418957824}, "myself") + require.NoError(t, err) + require.Equal(t, 2, count) +} + +func TestListExpiredSubscriptions(t *testing.T) { + ctx := context.Background() + repo := setUpStore(t) + + fakeClock := clockwork.NewFakeClockAt(time.Now()) + + // Insert Subscription with endtime 1 day from now + subscripiton1 := *subscriptionsPool[0].input + startTime := fakeClock.Now() + subscripiton1.StartTime = &startTime + endTime := fakeClock.Now().Add(24 * time.Hour) + subscripiton1.EndTime = &endTime + subOut1, err := repo.InsertSubscription(ctx, &subscripiton1) + require.NoError(t, err) + require.NotNil(t, subOut1) + + // Insert Subscription with endtime to 30 minutes ago + subscripiton2 := *subscriptionsPool[0].input + startTime = fakeClock.Now().Add(-1 * time.Hour) + subscripiton2.StartTime = &startTime + endTime = fakeClock.Now().Add(-30 * time.Minute) + subscripiton2.EndTime = &endTime + subscripiton2.ID = dssmodels.ID(uuid.New().String()) + subOut2, err := repo.InsertSubscription(ctx, &subscripiton2) + require.NoError(t, err) + require.NotNil(t, subOut2) + + subscriptions, err := repo.ListExpiredSubscriptions(ctx, writer, fakeClock.Now().Add(-30*time.Minute)) + require.NoError(t, err) + require.Len(t, subscriptions, 1) +} + +func TestListExpiredSubscriptionsWithEmptyWriter(t *testing.T) { + ctx := context.Background() + repo := setUpStore(t) + + fakeClock := clockwork.NewFakeClockAt(time.Now()) + + // Insert Subscription with endtime 1 day from now + subscripiton1 := *subscriptionsPool[0].input + startTime := fakeClock.Now() + subscripiton1.StartTime = &startTime + endTime := fakeClock.Now().Add(24 * time.Hour) + subscripiton1.EndTime = &endTime + subscripiton1.Writer = "" + subOut1, err := repo.InsertSubscription(ctx, &subscripiton1) + require.NoError(t, err) + require.NotNil(t, subOut1) + + // Insert Subscription with endtime to 30 minutes ago + subscripiton2 := *subscriptionsPool[0].input + startTime = fakeClock.Now().Add(-1 * time.Hour) + subscripiton2.StartTime = &startTime + endTime = fakeClock.Now().Add(-30 * time.Minute) + subscripiton2.EndTime = &endTime + subscripiton2.ID = dssmodels.ID(uuid.New().String()) + subscripiton2.Writer = "" + subOut2, err := repo.InsertSubscription(ctx, &subscripiton2) + require.NoError(t, err) + require.NotNil(t, subOut2) + + subscriptions, err := repo.ListExpiredSubscriptions(ctx, "", fakeClock.Now().Add(-30*time.Minute)) + require.NoError(t, err) + require.Len(t, subscriptions, 1) +} + +func TestStoreCountSubscription(t *testing.T) { + ctx := context.Background() + repo := setUpStore(t) + + for _, r := range subscriptionsPool { + t.Run(r.name, func(t *testing.T) { + sub1, err := repo.InsertSubscription(ctx, r.input) + require.NoError(t, err) + require.NotNil(t, sub1) + + //Cound should be one + count, err := repo.CountSubscriptions(ctx) + require.NoError(t, err) + require.Equal(t, count, int64(1)) + + sub4, err := repo.DeleteSubscription(ctx, sub1) + require.NoError(t, err) + require.NotNil(t, sub4) + + //Cound should be zero + count, err = repo.CountSubscriptions(ctx) + require.NoError(t, err) + require.Equal(t, count, int64(0)) + }) + } +} diff --git a/pkg/rid/store/raftstore/params/params.go b/pkg/rid/store/raftstore/params/params.go new file mode 100644 index 000000000..272b6e789 --- /dev/null +++ b/pkg/rid/store/raftstore/params/params.go @@ -0,0 +1,29 @@ +package params + +import ( + "flag" + + raftparams "github.com/interuss/dss/pkg/raftstore/params" + "github.com/interuss/stacktrace" +) + +const peersFlag = "rid_raft_peers" + +var peers string + +func init() { + flag.StringVar(&peers, peersFlag, "", `Comma-separated "nodeID=peerURL" pairs for the rid store, e.g. "1=http://node1:9011,2=http://node2:9011,3=http://node3:9011"`) +} + +func GetConnectParameters() (raftparams.ConnectParameters, error) { + if peers == "" { + return raftparams.ConnectParameters{}, stacktrace.NewError("--%s is required", peersFlag) + } + + p, err := raftparams.GetConnectParameters("rid") + if err != nil { + return raftparams.ConnectParameters{}, err + } + p.Peers = peers + return p, nil +} diff --git a/pkg/rid/store/raftstore/store.go b/pkg/rid/store/raftstore/store.go index 3022c1251..815202ce0 100644 --- a/pkg/rid/store/raftstore/store.go +++ b/pkg/rid/store/raftstore/store.go @@ -3,8 +3,12 @@ package raftstore import ( "context" + dsserr "github.com/interuss/dss/pkg/errors" "github.com/interuss/dss/pkg/raftstore" + "github.com/interuss/dss/pkg/raftstore/consensus" "github.com/interuss/dss/pkg/rid/repos" + ridraftparams "github.com/interuss/dss/pkg/rid/store/raftstore/params" + "github.com/interuss/stacktrace" "go.uber.org/zap" ) @@ -12,5 +16,25 @@ import ( type repo struct{} func Init(ctx context.Context, logger *zap.Logger) (*raftstore.Store[repos.Repository], error) { - return raftstore.Init[repos.Repository](ctx, logger, func() repos.Repository { return &repo{} }) + params, err := ridraftparams.GetConnectParameters() + if err != nil { + return nil, stacktrace.Propagate(err, "failed to get rid raft parameters") + } + return raftstore.Init(ctx, logger, params, "rid", &repo{}) +} + +func (r *repo) GetRepo() repos.Repository { return r } + +func (r *repo) IsReadOnly(_ raftstore.RequestType) bool { return false } + +func (r *repo) GetSnapshot() ([]byte, error) { + return nil, stacktrace.NewErrorWithCode(dsserr.NotImplemented, "not implemented yet") +} + +func (r *repo) RestoreFromSnapshot([]byte) error { + return stacktrace.NewErrorWithCode(dsserr.NotImplemented, "not implemented yet") +} + +func (r *repo) Apply(_ context.Context, _ consensus.Proposal) (any, error) { + return nil, stacktrace.NewErrorWithCode(dsserr.NotImplemented, "not implemented yet") } diff --git a/pkg/rid/store/sqlstore/store_test.go b/pkg/rid/store/sqlstore/store_test.go index 275a1a332..3a57d20da 100644 --- a/pkg/rid/store/sqlstore/store_test.go +++ b/pkg/rid/store/sqlstore/store_test.go @@ -97,7 +97,7 @@ func TestTxnRetrier(t *testing.T) { require.NotNil(t, store) defer tearDownStore() - err := store.Transact(ctx, func(ctx context.Context, repo repos.Repository) error { + _, err := store.Transact(ctx, "", nil, func(ctx context.Context, repo repos.Repository) error { // can query within this isa, err := repo.InsertISA(ctx, serviceArea) require.NotNil(t, isa) @@ -118,7 +118,7 @@ func TestTxnRetrier(t *testing.T) { ctx, cancel := context.WithTimeout(ctx, 20*time.Millisecond) defer cancel() count := 0 - err = store.Transact(ctx, func(ctx context.Context, repo repos.Repository) error { + _, err = store.Transact(ctx, "", nil, func(ctx context.Context, repo repos.Repository) error { // can query within this count++ // Postgre retryable error @@ -141,13 +141,13 @@ func TestTransactor(t *testing.T) { subscription2 := subscriptionsPool[1].input txnCount := 0 - err := store.Transact(ctx, func(ctx context.Context, s1 repos.Repository) error { + _, err := store.Transact(ctx, "", nil, func(ctx context.Context, s1 repos.Repository) error { // We should get to this retry, then return nothing. if txnCount > 0 { return errors.New("already failed") } txnCount++ - err := store.Transact(ctx, func(ctx context.Context, s2 repos.Repository) error { + _, err := store.Transact(ctx, "", nil, func(ctx context.Context, s2 repos.Repository) error { subs, err := s1.SearchSubscriptions(ctx, subscription1.Cells) require.NoError(t, err) require.Len(t, subs, 0) diff --git a/pkg/rid/store/store.go b/pkg/rid/store/store.go index aa561a092..7f2dc6c20 100644 --- a/pkg/rid/store/store.go +++ b/pkg/rid/store/store.go @@ -4,6 +4,7 @@ import ( "context" "github.com/interuss/dss/pkg/rid/repos" + ridmemstore "github.com/interuss/dss/pkg/rid/store/memstore" ridraftstore "github.com/interuss/dss/pkg/rid/store/raftstore" ridsqlstore "github.com/interuss/dss/pkg/rid/store/sqlstore" dssstore "github.com/interuss/dss/pkg/store" @@ -24,6 +25,8 @@ func Init(ctx context.Context, logger *zap.Logger, withCheckCron bool) (Store, e return ridsqlstore.Init(ctx, logger, withCheckCron) case params.RaftStoreType: return ridraftstore.Init(ctx, logger) + case params.MemStoreType: + return ridmemstore.Init(ctx, logger) default: return nil, stacktrace.NewError("Unsupported store type %q for rid", storeType) } diff --git a/pkg/scd/constraints_handler.go b/pkg/scd/constraints_handler.go index 0688176c7..aae6b887e 100644 --- a/pkg/scd/constraints_handler.go +++ b/pkg/scd/constraints_handler.go @@ -88,7 +88,7 @@ func (a *Server) DeleteConstraintReference(ctx context.Context, req *restapi.Del return nil } - err = a.Store.Transact(ctx, action) + _, err = a.Store.Transact(ctx, "", nil, action) if err != nil { err = stacktrace.Propagate(err, "Could not delete constraint") errResp := &restapi.ErrorResponse{Message: dsserr.Handle(ctx, err)} @@ -147,7 +147,7 @@ func (a *Server) GetConstraintReference(ctx context.Context, req *restapi.GetCon return nil } - err = a.Store.Transact(ctx, action) + _, err = a.Store.Transact(ctx, "", nil, action) if err != nil { err = stacktrace.Propagate(err, "Could not get constraint") if stacktrace.GetCode(err) == dsserr.NotFound { @@ -306,7 +306,7 @@ func (a *Server) PutConstraintReference(ctx context.Context, manager string, ent return nil } - err = a.Store.Transact(ctx, action) + _, err = a.Store.Transact(ctx, "", nil, action) if err != nil { return nil, err // No need to Propagate this error as this is not a useful stacktrace line } @@ -434,7 +434,7 @@ func (a *Server) QueryConstraintReferences(ctx context.Context, req *restapi.Que return nil } - err = a.Store.Transact(ctx, action) + _, err = a.Store.Transact(ctx, "", nil, action) if err != nil { return restapi.QueryConstraintReferencesResponseSet{Response500: &api.InternalServerErrorBody{ ErrorMessage: *dsserr.Handle(ctx, stacktrace.Propagate(err, "Got an unexpected error"))}} diff --git a/pkg/scd/operational_intents_handler.go b/pkg/scd/operational_intents_handler.go index f611797d0..35ac9c1ee 100644 --- a/pkg/scd/operational_intents_handler.go +++ b/pkg/scd/operational_intents_handler.go @@ -165,7 +165,7 @@ func (a *Server) DeleteOperationalIntentReference(ctx context.Context, req *rest return nil } - err = a.Store.Transact(ctx, action) + _, err = a.Store.Transact(ctx, "", nil, action) if err != nil { err = stacktrace.Propagate(err, "Could not delete operational intent") errResp := &restapi.ErrorResponse{Message: dsserr.Handle(ctx, err)} @@ -221,7 +221,7 @@ func (a *Server) GetOperationalIntentReference(ctx context.Context, req *restapi return nil } - err = a.Store.Transact(ctx, action) + _, err = a.Store.Transact(ctx, "", nil, action) if err != nil { err = stacktrace.Propagate(err, "Could not get operational intent") if stacktrace.GetCode(err) == dsserr.NotFound { @@ -288,7 +288,7 @@ func (a *Server) QueryOperationalIntentReferences(ctx context.Context, req *rest return nil } - err = a.Store.Transact(ctx, action) + _, err = a.Store.Transact(ctx, "", nil, action) if err != nil { err = stacktrace.Propagate(err, "Could not query operational intent") if stacktrace.GetCode(err) == dsserr.BadRequest { @@ -934,7 +934,7 @@ func (a *Server) upsertOperationalIntentReference(ctx context.Context, now time. return nil } - err = a.Store.Transact(ctx, action) + _, err = a.Store.Transact(ctx, "", nil, action) if err != nil { return nil, responseConflict, err // No need to Propagate this error as this is not a useful stacktrace line } diff --git a/pkg/scd/store/memstore/availability.go b/pkg/scd/store/memstore/availability.go new file mode 100644 index 000000000..3579672e1 --- /dev/null +++ b/pkg/scd/store/memstore/availability.go @@ -0,0 +1,18 @@ +package memstore + +import ( + "context" + + dsserr "github.com/interuss/dss/pkg/errors" + dssmodels "github.com/interuss/dss/pkg/models" + scdmodels "github.com/interuss/dss/pkg/scd/models" + "github.com/interuss/stacktrace" +) + +func (r *repo) GetUssAvailability(_ context.Context, id dssmodels.Manager) (*scdmodels.UssAvailabilityStatus, error) { + return nil, stacktrace.NewErrorWithCode(dsserr.NotImplemented, "GetUssAvailability not implemented for memstore") +} + +func (r *repo) UpsertUssAvailability(_ context.Context, ussa *scdmodels.UssAvailabilityStatus) (*scdmodels.UssAvailabilityStatus, error) { + return nil, stacktrace.NewErrorWithCode(dsserr.NotImplemented, "UpsertUssAvailability not implemented for memstore") +} diff --git a/pkg/scd/store/memstore/constraints.go b/pkg/scd/store/memstore/constraints.go new file mode 100644 index 000000000..be3e46eee --- /dev/null +++ b/pkg/scd/store/memstore/constraints.go @@ -0,0 +1,30 @@ +package memstore + +import ( + "context" + + dsserr "github.com/interuss/dss/pkg/errors" + dssmodels "github.com/interuss/dss/pkg/models" + scdmodels "github.com/interuss/dss/pkg/scd/models" + "github.com/interuss/stacktrace" +) + +func (r *repo) SearchConstraints(_ context.Context, v4d *dssmodels.Volume4D) ([]*scdmodels.Constraint, error) { + return nil, stacktrace.NewErrorWithCode(dsserr.NotImplemented, "SearchConstraints not implemented for memstore") +} + +func (r *repo) GetConstraint(_ context.Context, id dssmodels.ID) (*scdmodels.Constraint, error) { + return nil, stacktrace.NewErrorWithCode(dsserr.NotImplemented, "GetConstraint not implemented for memstore") +} + +func (r *repo) UpsertConstraint(_ context.Context, constraint *scdmodels.Constraint) (*scdmodels.Constraint, error) { + return nil, stacktrace.NewErrorWithCode(dsserr.NotImplemented, "UpsertConstraint not implemented for memstore") +} + +func (r *repo) DeleteConstraint(_ context.Context, id dssmodels.ID) error { + return stacktrace.NewErrorWithCode(dsserr.NotImplemented, "DeleteConstraint not implemented for memstore") +} + +func (r *repo) CountConstraints(_ context.Context) (int64, error) { + return 0, stacktrace.NewErrorWithCode(dsserr.NotImplemented, "CountConstraint not implemented for memstore") +} diff --git a/pkg/scd/store/memstore/doc.go b/pkg/scd/store/memstore/doc.go new file mode 100644 index 000000000..f86b85987 --- /dev/null +++ b/pkg/scd/store/memstore/doc.go @@ -0,0 +1,3 @@ +// Package scd.store.memstore provides a full implementation of store.Store[scd.repos.Repository] +// storing data in memory. It is meant to be used by raftstore. +package memstore diff --git a/pkg/scd/store/memstore/operational_intents.go b/pkg/scd/store/memstore/operational_intents.go new file mode 100644 index 000000000..b39f81793 --- /dev/null +++ b/pkg/scd/store/memstore/operational_intents.go @@ -0,0 +1,39 @@ +package memstore + +import ( + "context" + "time" + + dsserr "github.com/interuss/dss/pkg/errors" + dssmodels "github.com/interuss/dss/pkg/models" + scdmodels "github.com/interuss/dss/pkg/scd/models" + "github.com/interuss/stacktrace" +) + +func (r *repo) GetOperationalIntent(_ context.Context, id dssmodels.ID) (*scdmodels.OperationalIntent, error) { + return nil, stacktrace.NewErrorWithCode(dsserr.NotImplemented, "GetOperationalIntent not implemented for memstore") +} + +func (r *repo) DeleteOperationalIntent(_ context.Context, id dssmodels.ID) error { + return stacktrace.NewErrorWithCode(dsserr.NotImplemented, "DeleteOperationalIntent not implemented for memstore") +} + +func (r *repo) UpsertOperationalIntent(_ context.Context, operation *scdmodels.OperationalIntent) (*scdmodels.OperationalIntent, error) { + return nil, stacktrace.NewErrorWithCode(dsserr.NotImplemented, "UpsertOperationalIntent not implemented for memstore") +} + +func (r *repo) SearchOperationalIntents(_ context.Context, v4d *dssmodels.Volume4D) ([]*scdmodels.OperationalIntent, error) { + return nil, stacktrace.NewErrorWithCode(dsserr.NotImplemented, "SearchOperationalIntents not implemented for memstore") +} + +func (r *repo) GetDependentOperationalIntents(_ context.Context, subscriptionID dssmodels.ID) ([]dssmodels.ID, error) { + return nil, stacktrace.NewErrorWithCode(dsserr.NotImplemented, "GetDependentOperationalIntents not implemented for memstore") +} + +func (r *repo) ListExpiredOperationalIntents(_ context.Context, threshold time.Time) ([]*scdmodels.OperationalIntent, error) { + return nil, stacktrace.NewErrorWithCode(dsserr.NotImplemented, "ListExpiredOperationalIntents not implemented for memstore") +} + +func (r *repo) CountOperationalIntents(_ context.Context) (int64, error) { + return 0, stacktrace.NewErrorWithCode(dsserr.NotImplemented, "CountOperationalIntents not implemented for memstore") +} diff --git a/pkg/scd/store/memstore/snapshot.go b/pkg/scd/store/memstore/snapshot.go new file mode 100644 index 000000000..dea64e6a4 --- /dev/null +++ b/pkg/scd/store/memstore/snapshot.go @@ -0,0 +1,13 @@ +package memstore + +import ( + "github.com/interuss/stacktrace" +) + +func (r *repo) GetSnapshot() ([]byte, error) { + return nil, stacktrace.NewError("GetSnapshot not yet implemented for rid") +} + +func (r *repo) RestoreFromSnapshot(data []byte) error { + return stacktrace.NewError("RestoreFromSnapshot not yet implemented for rid") +} diff --git a/pkg/scd/store/memstore/store.go b/pkg/scd/store/memstore/store.go new file mode 100644 index 000000000..45365614a --- /dev/null +++ b/pkg/scd/store/memstore/store.go @@ -0,0 +1,18 @@ +package memstore + +import ( + "context" + + "github.com/interuss/dss/pkg/memstore" + "github.com/interuss/dss/pkg/scd/repos" + "go.uber.org/zap" +) + +// repo is a full implementation of scd.repos.Repository for memory-based storage. +type repo struct{} + +func Init(ctx context.Context, logger *zap.Logger) (*memstore.Store[repos.Repository], error) { + return memstore.Init(ctx, logger, "scd", &repo{}) +} + +func (r *repo) GetRepo() repos.Repository { return r } diff --git a/pkg/scd/store/memstore/subscriptions.go b/pkg/scd/store/memstore/subscriptions.go new file mode 100644 index 000000000..c284e6df8 --- /dev/null +++ b/pkg/scd/store/memstore/subscriptions.go @@ -0,0 +1,48 @@ +package memstore + +import ( + "context" + "time" + + "github.com/golang/geo/s2" + dsserr "github.com/interuss/dss/pkg/errors" + dssmodels "github.com/interuss/dss/pkg/models" + scdmodels "github.com/interuss/dss/pkg/scd/models" + "github.com/interuss/stacktrace" +) + +func (r *repo) SearchSubscriptions(_ context.Context, v4d *dssmodels.Volume4D) ([]*scdmodels.Subscription, error) { + return nil, stacktrace.NewErrorWithCode(dsserr.NotImplemented, "SearchSubscriptions not implemented for memstore") +} + +func (r *repo) GetSubscription(_ context.Context, id dssmodels.ID) (*scdmodels.Subscription, error) { + return nil, stacktrace.NewErrorWithCode(dsserr.NotImplemented, "GetSubscription not implemented for memstore") +} + +func (r *repo) UpsertSubscription(_ context.Context, sub *scdmodels.Subscription) (*scdmodels.Subscription, error) { + return nil, stacktrace.NewErrorWithCode(dsserr.NotImplemented, "UpsertSubscription not implemented for memstore") +} + +func (r *repo) DeleteSubscription(_ context.Context, id dssmodels.ID) error { + return stacktrace.NewErrorWithCode(dsserr.NotImplemented, "DeleteSubscription not implemented for memstore") +} + +func (r *repo) IncrementNotificationIndicesForOperationalIntents(_ context.Context, v4d *dssmodels.Volume4D) ([]*scdmodels.Subscription, error) { + return nil, stacktrace.NewErrorWithCode(dsserr.NotImplemented, "IncrementNotificationIndicesForOperationalIntents not implemented for memstore") +} + +func (r *repo) IncrementNotificationIndicesForConstraints(_ context.Context, v4d *dssmodels.Volume4D) ([]*scdmodels.Subscription, error) { + return nil, stacktrace.NewErrorWithCode(dsserr.NotImplemented, "IncrementNotificationIndicesForConstraints not implemented for memstore") +} + +func (r *repo) LockSubscriptionsOnCells(_ context.Context, cells s2.CellUnion, subscriptionIds []dssmodels.ID, startTime *time.Time, endTime *time.Time) error { + return stacktrace.NewErrorWithCode(dsserr.NotImplemented, "LockSubscriptionsOnCells not implemented for memstore") +} + +func (r *repo) ListExpiredSubscriptions(_ context.Context, threshold time.Time) ([]*scdmodels.Subscription, error) { + return nil, stacktrace.NewErrorWithCode(dsserr.NotImplemented, "ListExpiredSubscriptions not implemented for memstore") +} + +func (r *repo) CountSubscriptions(_ context.Context) (int64, error) { + return 0, stacktrace.NewErrorWithCode(dsserr.NotImplemented, "CountSubscriptions not implemented for memstore") +} diff --git a/pkg/scd/store/raftstore/params/params.go b/pkg/scd/store/raftstore/params/params.go new file mode 100644 index 000000000..8642a7efc --- /dev/null +++ b/pkg/scd/store/raftstore/params/params.go @@ -0,0 +1,29 @@ +package params + +import ( + "flag" + + raftparams "github.com/interuss/dss/pkg/raftstore/params" + "github.com/interuss/stacktrace" +) + +const peersFlag = "scd_raft_peers" + +var peers string + +func init() { + flag.StringVar(&peers, peersFlag, "", `Comma-separated "nodeID=peerURL" pairs for the scd store, e.g. "1=http://node1:9021,2=http://node2:9021,3=http://node3:9021"`) +} + +func GetConnectParameters() (raftparams.ConnectParameters, error) { + if peers == "" { + return raftparams.ConnectParameters{}, stacktrace.NewError("--%s is required", peersFlag) + } + + p, err := raftparams.GetConnectParameters("scd") + if err != nil { + return raftparams.ConnectParameters{}, err + } + p.Peers = peers + return p, nil +} diff --git a/pkg/scd/store/raftstore/store.go b/pkg/scd/store/raftstore/store.go index 75dd1ff5c..e533284c9 100644 --- a/pkg/scd/store/raftstore/store.go +++ b/pkg/scd/store/raftstore/store.go @@ -3,8 +3,12 @@ package raftstore import ( "context" + dsserr "github.com/interuss/dss/pkg/errors" "github.com/interuss/dss/pkg/raftstore" + "github.com/interuss/dss/pkg/raftstore/consensus" "github.com/interuss/dss/pkg/scd/repos" + scdraftparams "github.com/interuss/dss/pkg/scd/store/raftstore/params" + "github.com/interuss/stacktrace" "go.uber.org/zap" ) @@ -12,5 +16,25 @@ import ( type repo struct{} func Init(ctx context.Context, logger *zap.Logger) (*raftstore.Store[repos.Repository], error) { - return raftstore.Init[repos.Repository](ctx, logger, func() repos.Repository { return &repo{} }) + params, err := scdraftparams.GetConnectParameters() + if err != nil { + return nil, stacktrace.Propagate(err, "failed to get scd raft parameters") + } + return raftstore.Init(ctx, logger, params, "scd", &repo{}) +} + +func (r *repo) GetRepo() repos.Repository { return r } + +func (r *repo) IsReadOnly(_ raftstore.RequestType) bool { return false } + +func (r *repo) GetSnapshot() ([]byte, error) { + return nil, stacktrace.NewErrorWithCode(dsserr.NotImplemented, "not implemented yet") +} + +func (r *repo) RestoreFromSnapshot([]byte) error { + return stacktrace.NewErrorWithCode(dsserr.NotImplemented, "not implemented yet") +} + +func (r *repo) Apply(_ context.Context, _ consensus.Proposal) (any, error) { + return nil, stacktrace.NewErrorWithCode(dsserr.NotImplemented, "not implemented yet") } diff --git a/pkg/scd/store/store.go b/pkg/scd/store/store.go index a8603fb42..9f313d511 100644 --- a/pkg/scd/store/store.go +++ b/pkg/scd/store/store.go @@ -4,6 +4,7 @@ import ( "context" "github.com/interuss/dss/pkg/scd/repos" + scdmemstore "github.com/interuss/dss/pkg/scd/store/memstore" scdraftstore "github.com/interuss/dss/pkg/scd/store/raftstore" scdsqlstore "github.com/interuss/dss/pkg/scd/store/sqlstore" dssstore "github.com/interuss/dss/pkg/store" @@ -24,6 +25,8 @@ func Init(ctx context.Context, logger *zap.Logger, withCheckCron bool, globalLoc return scdsqlstore.Init(ctx, logger, withCheckCron, globalLock) case params.RaftStoreType: return scdraftstore.Init(ctx, logger) + case params.MemStoreType: + return scdmemstore.Init(ctx, logger) default: return nil, stacktrace.NewError("Unsupported store type %q for scd", storeType) } diff --git a/pkg/scd/subscriptions_handler.go b/pkg/scd/subscriptions_handler.go index 98ab9fed2..434f2bfea 100644 --- a/pkg/scd/subscriptions_handler.go +++ b/pkg/scd/subscriptions_handler.go @@ -276,7 +276,7 @@ func (a *Server) PutSubscription(ctx context.Context, manager string, subscripti return nil } - err = a.Store.Transact(ctx, action) + _, err = a.Store.Transact(ctx, "", nil, action) if err != nil { return nil, err // No need to Propagate this error as this is not a useful stacktrace line } @@ -340,7 +340,7 @@ func (a *Server) GetSubscription(ctx context.Context, req *restapi.GetSubscripti return nil } - err = a.Store.Transact(ctx, action) + _, err = a.Store.Transact(ctx, "", nil, action) if err != nil { err = stacktrace.Propagate(err, "Could not get subscription") errResp := &restapi.ErrorResponse{Message: dsserr.Handle(ctx, err)} @@ -424,7 +424,7 @@ func (a *Server) QuerySubscriptions(ctx context.Context, req *restapi.QuerySubsc return nil } - err = a.Store.Transact(ctx, action) + _, err = a.Store.Transact(ctx, "", nil, action) if err != nil { errResp := &restapi.ErrorResponse{Message: dsserr.Handle(ctx, err)} @@ -517,7 +517,7 @@ func (a *Server) DeleteSubscription(ctx context.Context, req *restapi.DeleteSubs return nil } - err = a.Store.Transact(ctx, action) + _, err = a.Store.Transact(ctx, "", nil, action) if err != nil { err = stacktrace.Propagate(err, "Could not delete subscription") errResp := &restapi.ErrorResponse{Message: dsserr.Handle(ctx, err)} diff --git a/pkg/scd/uss_availability_handler.go b/pkg/scd/uss_availability_handler.go index 517a9859a..0c9b05bb3 100644 --- a/pkg/scd/uss_availability_handler.go +++ b/pkg/scd/uss_availability_handler.go @@ -51,7 +51,7 @@ func (a *Server) GetUssAvailability(ctx context.Context, req *restapi.GetUssAvai return nil } - err := a.Store.Transact(ctx, action) + _, err := a.Store.Transact(ctx, "", nil, action) if err != nil { // In case of older DB versions where availability table doesn't exist if strings.Contains(err.Error(), "does not exist") { @@ -121,7 +121,7 @@ func (a *Server) SetUssAvailability(ctx context.Context, req *restapi.SetUssAvai } return nil } - err = a.Store.Transact(ctx, action) + _, err = a.Store.Transact(ctx, "", nil, action) if err != nil { // In case of older DB versions where availability table doesn't exist if strings.Contains(err.Error(), "does not exist") { diff --git a/pkg/sqlstore/store.go b/pkg/sqlstore/store.go index 0255620c6..28c3a1cbb 100644 --- a/pkg/sqlstore/store.go +++ b/pkg/sqlstore/store.go @@ -173,9 +173,9 @@ func (s *Store[R]) Interact(_ context.Context) (R, error) { return s.newRepo(s.Pool), nil } -func (s *Store[R]) Transact(ctx context.Context, f func(context.Context, R) error) error { +func (s *Store[R]) Transact(ctx context.Context, _ string, _ any, f func(context.Context, R) error) (any, error) { ctx = crdb.WithMaxRetries(ctx, s.maxRetries) - return crdbpgx.ExecuteTx(ctx, s.Pool, pgx.TxOptions{IsoLevel: pgx.Serializable}, func(tx pgx.Tx) error { + return nil, crdbpgx.ExecuteTx(ctx, s.Pool, pgx.TxOptions{IsoLevel: pgx.Serializable}, func(tx pgx.Tx) error { return f(ctx, s.newRepo(tx)) }) } diff --git a/pkg/store/params/params.go b/pkg/store/params/params.go index 7a25e29f6..3bf49bfa1 100644 --- a/pkg/store/params/params.go +++ b/pkg/store/params/params.go @@ -8,6 +8,7 @@ import ( const ( RaftStoreType = "raft" SQLStoreType = "sql" + MemStoreType = "mem" ) type ( @@ -22,6 +23,7 @@ var ( ) func init() { + // NB: Memstore not available there on purpose. flag.StringVar(&storeParameters.StoreType, "store_type", SQLStoreType, fmt.Sprintf("Store type. Use '%s' for CockroachDB/YugabyteDB and '%s' for Raft-based store.", SQLStoreType, RaftStoreType)) } diff --git a/pkg/store/store.go b/pkg/store/store.go index e5e95c4b0..e59318a9e 100644 --- a/pkg/store/store.go +++ b/pkg/store/store.go @@ -16,7 +16,9 @@ type Store[R any] interface { Interact(context.Context) (R, error) // Attempt to apply the operations in f to the R Repo it is supplied. All operations performed // on the R Repo by f will be applied or rejected atomically. - Transact(ctx context.Context, f func(context.Context, R) error) error + // requestType and payload are used by the Raftstore to build the proposal. + // The returned any is the proposal result (also Raftstore only). + Transact(ctx context.Context, requestType string, payload any, f func(context.Context, R) error) (any, error) } const ( diff --git a/pkg/timestamp/timestamp.go b/pkg/timestamp/timestamp.go new file mode 100644 index 000000000..6ac952a16 --- /dev/null +++ b/pkg/timestamp/timestamp.go @@ -0,0 +1,32 @@ +package timestamp + +import ( + "context" + "net/http" + "time" +) + +type timestampKey struct{} + +// NowFromContext returns the timestamp from the context, or zero if not present. +func NowFromContext(ctx context.Context) time.Time { + t, ok := ctx.Value(timestampKey{}).(time.Time) + if !ok { + return time.Time{} + } + + return t +} + +// WithTimestamp returns a new context with the given timestamp. +func WithTimestamp(ctx context.Context, t time.Time) context.Context { + return context.WithValue(ctx, timestampKey{}, t) +} + +// Middleware is an HTTP middleware that adds a timestamp to the request context. +func Middleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx := WithTimestamp(r.Context(), time.Now().UTC()) + next.ServeHTTP(w, r.WithContext(ctx)) + }) +}