diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 6cb9f1dac..6b1f506ea 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -127,6 +127,41 @@ jobs: - name: Bring down local DSS instance run: make down-locally + dss-tests-with-raft: + name: DSS tests with Raft + runs-on: ubuntu-latest + env: + COMPOSE_PROFILES: with-raft + steps: + - name: Job information + run: | + echo "Job information" + echo "Trigger: ${{ github.event_name }}" + echo "Host: ${{ runner.os }}" + echo "Repository: ${{ github.repository }}" + echo "Branch: ${{ github.ref }}" + docker images + go env + - name: Checkout + uses: actions/checkout@v6 + with: + submodules: true + - name: Build dss image + run: make build-dss + - name: Tear down any pre-existing local DSS instance + run: make down-locally + - name: Start local DSS instance + run: make start-locally + - name: Probe local DSS instance + run: make probe-locally + - name: Run Qualifier against local DSS instance + run: make qualify-locally + # Todo: Re add evict tests here once evict is implemented. + - name: Run security tests against local DSS instance + run: make security-locally + - name: Bring down local DSS instance + run: make down-locally + certificates-management-tests: name: Certificate management tests runs-on: ubuntu-latest diff --git a/build/dev/docker-compose_dss.yaml b/build/dev/docker-compose_dss.yaml index deb7b0c19..084433225 100644 --- a/build/dev/docker-compose_dss.yaml +++ b/build/dev/docker-compose_dss.yaml @@ -111,6 +111,7 @@ services: - $PWD/../test-certs:/var/test-certs:ro - $PWD/startup/core_service.sh:/startup/core_service.sh:ro - $PWD/startup/coverdata:/startup/coverdata:rw # we will save coverage info here + - raftdata:/raftdata environment: COMPOSE_PROFILES: ${COMPOSE_PROFILES} # Note: requires the Dockerfile to have been built with "-cover" in the EXTRA_GO_INSTALL_FLAGS var @@ -142,7 +143,7 @@ services: interval: 3m start_period: 30s start_interval: 5s - profiles: ["", "with-yugabyte"] + profiles: ["", "with-yugabyte", "with-raft"] local-dss-dummy-oauth: build: @@ -166,3 +167,4 @@ networks: volumes: local-dss-data: + raftdata: diff --git a/build/dev/startup/core_service.sh b/build/dev/startup/core_service.sh index f09bda597..36b72b3bc 100755 --- a/build/dev/startup/core_service.sh +++ b/build/dev/startup/core_service.sh @@ -10,8 +10,13 @@ if [ "${COMPOSE_PROFILES#*"with-yugabyte"}" != "${COMPOSE_PROFILES}" ]; then echo "Using Yugabyte" DATASTORE_CONNECTION="-datastore_host local-dss-ybdb -datastore_user yugabyte --datastore_port 5433" else - echo "Using CockroachDB" - DATASTORE_CONNECTION="-datastore_host local-dss-crdb" + if [ "${COMPOSE_PROFILES#*"with-raft"}" != "${COMPOSE_PROFILES}" ]; then + echo "Using raft" + DATASTORE_CONNECTION="-store_type raft -raft_node_id=1 -rid_raft_peers=1=http://127.0.0.1:9011 -scd_raft_peers=1=http://127.0.0.1:9021 -aux_raft_peers=1=http://127.0.0.1:9031 -raft_datadir /raftdata" + else + echo "Using CockroachDB" + DATASTORE_CONNECTION="-datastore_host local-dss-crdb" + fi fi if [ "$DEBUG_ON" = "1" ]; then diff --git a/cmds/core-service/main.go b/cmds/core-service/main.go index 7931bf045..cfeccae75 100644 --- a/cmds/core-service/main.go +++ b/cmds/core-service/main.go @@ -32,6 +32,7 @@ import ( scds "github.com/interuss/dss/pkg/scd/store" "github.com/interuss/dss/pkg/store" "github.com/interuss/dss/pkg/store/params" + "github.com/interuss/dss/pkg/timestamp" "github.com/interuss/dss/pkg/version" "github.com/interuss/dss/pkg/versioning" "github.com/interuss/stacktrace" @@ -340,6 +341,7 @@ func RunHTTPServer(ctx context.Context, ctxCanceler func(), address, locality st handler = authorizer.TokenMiddleware(handler) handler = http.TimeoutHandler(handler, *timeout, "request timeout") handler = logging.HTTPMiddleware(logger, *dumpRequests, handler) + handler = timestamp.Middleware(handler) if *enableMetrics || *enableTracing { // We use the default settings; the APIRouter handler will override the span value accordingly, as it has more information. diff --git a/cmds/db-manager/cleanup/evict.go b/cmds/db-manager/cleanup/evict.go index 9584ef620..e605cb4c8 100644 --- a/cmds/db-manager/cleanup/evict.go +++ b/cmds/db-manager/cleanup/evict.go @@ -100,7 +100,7 @@ func evict(cmd *cobra.Command, _ []string) error { } return nil } - if err = scdStore.Transact(ctx, scdAction); err != nil { + if _, err = scdStore.Transact(ctx, "", nil, scdAction); err != nil { return fmt.Errorf("failed to execute SCD transaction: %w", err) } @@ -145,7 +145,7 @@ func evict(cmd *cobra.Command, _ []string) error { return nil } - if err = ridStore.Transact(ctx, ridAction); err != nil { + if _, err = ridStore.Transact(ctx, "", nil, ridAction); err != nil { return fmt.Errorf("failed to execute RID transaction: %w", err) } diff --git a/go.mod b/go.mod index a8fa1da50..d0722521b 100644 --- a/go.mod +++ b/go.mod @@ -10,6 +10,7 @@ require ( github.com/go-jose/go-jose/v4 v4.1.4 github.com/golang-jwt/jwt/v4 v4.5.2 github.com/golang/geo v0.0.0-20230421003525-6adc56603217 + github.com/google/go-cmp v0.7.0 github.com/google/uuid v1.6.0 github.com/interuss/stacktrace v1.0.0 github.com/jackc/pgx/v5 v5.9.2 @@ -27,11 +28,13 @@ require ( go.opentelemetry.io/otel v1.43.0 go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.39.0 go.opentelemetry.io/otel/exporters/prometheus v0.65.0 + go.opentelemetry.io/otel/metric v1.43.0 go.opentelemetry.io/otel/sdk v1.43.0 go.opentelemetry.io/otel/sdk/metric v1.43.0 go.opentelemetry.io/otel/trace v1.43.0 go.uber.org/multierr v1.11.0 go.uber.org/zap v1.27.0 + golang.org/x/sync v0.20.0 ) require ( @@ -71,13 +74,11 @@ require ( go.opentelemetry.io/auto/sdk v1.2.1 // indirect go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.61.0 // indirect go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.39.0 // indirect - go.opentelemetry.io/otel/metric v1.43.0 // indirect go.opentelemetry.io/proto/otlp v1.9.0 // indirect go.yaml.in/yaml/v2 v2.4.4 // indirect golang.org/x/crypto v0.49.0 // indirect golang.org/x/net v0.52.0 // indirect golang.org/x/oauth2 v0.34.0 // indirect - golang.org/x/sync v0.20.0 // indirect golang.org/x/sys v0.42.0 // indirect golang.org/x/text v0.36.0 // indirect golang.org/x/time v0.14.0 // indirect diff --git a/pkg/aux_/store/memstore/doc.go b/pkg/aux_/store/memstore/doc.go new file mode 100644 index 000000000..0fe2e87de --- /dev/null +++ b/pkg/aux_/store/memstore/doc.go @@ -0,0 +1,3 @@ +// Package aux_.store.memstore provides a full implementation of store.Store[aux_.repos.Repository] +// storing data in memory. It is meant to be used by raftstore. +package memstore diff --git a/pkg/aux_/store/memstore/dss.go b/pkg/aux_/store/memstore/dss.go new file mode 100644 index 000000000..4f19699a5 --- /dev/null +++ b/pkg/aux_/store/memstore/dss.go @@ -0,0 +1,86 @@ +package memstore + +import ( + "context" + "database/sql" + + auxmodels "github.com/interuss/dss/pkg/aux_/models" + dsserr "github.com/interuss/dss/pkg/errors" + "github.com/interuss/dss/pkg/timestamp" + "github.com/interuss/stacktrace" +) + +func (r *repo) SaveOwnMetadata(ctx context.Context, locality string, publicEndpoint string) error { + if locality == "" { + return stacktrace.NewErrorWithCode(dsserr.BadRequest, "Locality not set") + } + if publicEndpoint == "" { + return stacktrace.NewErrorWithCode(dsserr.BadRequest, "Public endpoint not set") + } + + r.state.Participants[locality] = &participant{ + PublicEndpoint: publicEndpoint, + UpdatedAt: timestamp.NowFromContext(ctx), + } + return nil +} + +func (r *repo) GetDSSMetadata(_ context.Context) ([]*auxmodels.DSSMetadata, error) { + metadata := make([]*auxmodels.DSSMetadata, 0, len(r.state.Participants)) + for locality, p := range r.state.Participants { + updatedAt := p.UpdatedAt + m := &auxmodels.DSSMetadata{ + Locality: locality, + PublicEndpoint: p.PublicEndpoint, + UpdatedAt: &updatedAt, + } + + // Find the latest heartbeat across all sources for this locality. + var latest auxmodels.Heartbeat + found := false + for key, hb := range r.state.Heartbeats { + if key.Locality != locality { + continue + } + if !found || hb.Timestamp.After(*latest.Timestamp) { + latest = hb + found = true + } + } + + if found { + m.LatestTimestamp.Source = sql.NullString{String: latest.Source, Valid: true} + m.LatestTimestamp.Timestamp = latest.Timestamp + m.LatestTimestamp.NextHeartbeatExpectedBefore = latest.NextHeartbeatExpectedBefore + m.LatestTimestamp.Reporter = sql.NullString{String: latest.Reporter, Valid: true} + } + + metadata = append(metadata, m) + } + return metadata, nil +} + +func (r *repo) RecordHeartbeat(ctx context.Context, heartbeat auxmodels.Heartbeat) error { + if heartbeat.Locality == "" { + return stacktrace.NewErrorWithCode(dsserr.BadRequest, "Locality not set") + } + if heartbeat.Source == "" { + return stacktrace.NewErrorWithCode(dsserr.BadRequest, "Source not set") + } + + if heartbeat.Timestamp == nil { + now := timestamp.NowFromContext(ctx) + heartbeat.Timestamp = &now + } + + if heartbeat.NextHeartbeatExpectedBefore != nil && heartbeat.NextHeartbeatExpectedBefore.Before(*heartbeat.Timestamp) { + return stacktrace.NewErrorWithCode(dsserr.BadRequest, "Cannot expect the timestamp of the next heartbeat before the timestamp of the new heartbeat") + } + + r.state.Heartbeats[heartbeatKey{Locality: heartbeat.Locality, Source: heartbeat.Source}] = heartbeat + return nil +} + +func (r *repo) GetDSSAirspaceRepresentationID(_ context.Context) (string, error) { + return "", stacktrace.NewErrorWithCode(dsserr.NotImplemented, "GetDSSAirspaceRepresentationID not implementable for memstore") +} diff --git a/pkg/aux_/store/memstore/dss_test.go b/pkg/aux_/store/memstore/dss_test.go new file mode 100644 index 000000000..43563d27b --- /dev/null +++ b/pkg/aux_/store/memstore/dss_test.go @@ -0,0 +1,125 @@ +package memstore + +import ( + "context" + "testing" + "time" + + auxmodels "github.com/interuss/dss/pkg/aux_/models" + dsserr "github.com/interuss/dss/pkg/errors" + "github.com/interuss/stacktrace" + "github.com/stretchr/testify/require" +) + +func TestSaveOwnMetadataValidation(t *testing.T) { + ctx := context.Background() + r := newRepo() + + require.Equal(t, dsserr.BadRequest, stacktrace.GetCode(r.SaveOwnMetadata(ctx, "", "https://example.com"))) + require.Equal(t, dsserr.BadRequest, stacktrace.GetCode(r.SaveOwnMetadata(ctx, "dss-1", ""))) +} + +func TestSaveOwnMetadataRoundTrip(t *testing.T) { + ctx := context.Background() + r := newRepo() + + require.NoError(t, r.SaveOwnMetadata(ctx, "dss-1", "https://example.com")) + + md, err := r.GetDSSMetadata(ctx) + require.NoError(t, err) + + require.Len(t, md, 1) + require.Equal(t, "dss-1", md[0].Locality) + require.Equal(t, "https://example.com", md[0].PublicEndpoint) + require.NotNil(t, md[0].UpdatedAt) + + // No heartbeat recorded yet. + require.False(t, md[0].LatestTimestamp.Source.Valid) + require.Nil(t, md[0].LatestTimestamp.Timestamp) +} + +func TestSaveOwnMetadataUpsert(t *testing.T) { + ctx := context.Background() + r := newRepo() + + require.NoError(t, r.SaveOwnMetadata(ctx, "dss-1", "https://old.example.com")) + require.NoError(t, r.SaveOwnMetadata(ctx, "dss-1", "https://new.example.com")) + + md, err := r.GetDSSMetadata(ctx) + require.NoError(t, err) + + require.Len(t, md, 1) + require.Equal(t, "https://new.example.com", md[0].PublicEndpoint) +} + +func TestRecordHeartbeatValidation(t *testing.T) { + ctx := context.Background() + r := newRepo() + + require.Equal(t, dsserr.BadRequest, stacktrace.GetCode(r.RecordHeartbeat(ctx, auxmodels.Heartbeat{Source: "source1"}))) + require.Equal(t, dsserr.BadRequest, stacktrace.GetCode(r.RecordHeartbeat(ctx, auxmodels.Heartbeat{Locality: "dss-1"}))) + + ts := time.Now() + before := ts.Add(-time.Minute) + err := r.RecordHeartbeat(ctx, auxmodels.Heartbeat{ + Locality: "dss-1", + Source: "source1", + Timestamp: &ts, + NextHeartbeatExpectedBefore: &before, + }) + + require.Equal(t, dsserr.BadRequest, stacktrace.GetCode(err)) +} + +func TestRecordHeartbeatDefaultsTimestamp(t *testing.T) { + ctx := context.Background() + r := newRepo() + + require.NoError(t, r.SaveOwnMetadata(ctx, "dss-1", "https://example.com")) + require.NoError(t, r.RecordHeartbeat(ctx, auxmodels.Heartbeat{Locality: "dss-1", Source: "source1"})) + + md, err := r.GetDSSMetadata(ctx) + require.NoError(t, err) + + require.Len(t, md, 1) + require.True(t, md[0].LatestTimestamp.Source.Valid) + require.NotNil(t, md[0].LatestTimestamp.Timestamp) +} + +func TestGetDSSMetadataPicksLatestHeartbeat(t *testing.T) { + ctx := context.Background() + r := newRepo() + + require.NoError(t, r.SaveOwnMetadata(ctx, "dss-1", "https://example.com")) + + older := time.Now().Add(-time.Hour) + newer := time.Now() + require.NoError(t, r.RecordHeartbeat(ctx, auxmodels.Heartbeat{Locality: "dss-1", Source: "source1", Timestamp: &older, Reporter: "uss1"})) + require.NoError(t, r.RecordHeartbeat(ctx, auxmodels.Heartbeat{Locality: "dss-1", Source: "source2", Timestamp: &newer, Reporter: "uss2"})) + + md, err := r.GetDSSMetadata(ctx) + require.NoError(t, err) + + require.Len(t, md, 1) + require.True(t, md[0].LatestTimestamp.Timestamp.Equal(newer)) + require.Equal(t, "source2", md[0].LatestTimestamp.Source.String) + require.Equal(t, "uss2", md[0].LatestTimestamp.Reporter.String) +} + +func TestGetDSSMetadataUpdatesHeartbeatPerSource(t *testing.T) { + ctx := context.Background() + r := newRepo() + + require.NoError(t, r.SaveOwnMetadata(ctx, "dss-1", "https://example.com")) + + first := time.Now().Add(-time.Hour) + second := time.Now() + require.NoError(t, r.RecordHeartbeat(ctx, auxmodels.Heartbeat{Locality: "dss-1", Source: "source1", Timestamp: &first})) + require.NoError(t, r.RecordHeartbeat(ctx, auxmodels.Heartbeat{Locality: "dss-1", Source: "source1", Timestamp: &second})) + + md, err := r.GetDSSMetadata(ctx) + require.NoError(t, err) + + require.Len(t, md, 1) + require.True(t, md[0].LatestTimestamp.Timestamp.Equal(second)) +} diff --git a/pkg/aux_/store/memstore/snapshot.go b/pkg/aux_/store/memstore/snapshot.go new file mode 100644 index 000000000..15deb274b --- /dev/null +++ b/pkg/aux_/store/memstore/snapshot.go @@ -0,0 +1,35 @@ +package memstore + +import ( + "bytes" + "encoding/gob" + + "github.com/interuss/stacktrace" +) + +const snapshotVersion = 1 + +type snapshotEnvelope struct { + Version int + State state +} + +func (r *repo) GetSnapshot() ([]byte, error) { + var buf bytes.Buffer + if err := gob.NewEncoder(&buf).Encode(snapshotEnvelope{Version: snapshotVersion, State: r.state}); err != nil { + return nil, stacktrace.Propagate(err, "Failed to encode memstore snapshot") + } + return buf.Bytes(), nil +} + +func (r *repo) RestoreFromSnapshot(data []byte) error { + var env snapshotEnvelope + if err := gob.NewDecoder(bytes.NewReader(data)).Decode(&env); err != nil { + return stacktrace.Propagate(err, "Failed to decode memstore snapshot") + } + if env.Version != snapshotVersion { + return stacktrace.NewError("Unsupported memstore snapshot version %d, expected %d", env.Version, snapshotVersion) + } + r.state = env.State + return nil +} diff --git a/pkg/aux_/store/memstore/snapshot_test.go b/pkg/aux_/store/memstore/snapshot_test.go new file mode 100644 index 000000000..2085e2488 --- /dev/null +++ b/pkg/aux_/store/memstore/snapshot_test.go @@ -0,0 +1,59 @@ +package memstore + +import ( + "bytes" + "context" + "encoding/gob" + "testing" + "time" + + auxmodels "github.com/interuss/dss/pkg/aux_/models" + "github.com/stretchr/testify/require" +) + +func TestSnapshotRoundTrip(t *testing.T) { + ctx := context.Background() + src := newRepo() + require.NoError(t, src.SaveOwnMetadata(ctx, "dss-1", "https://example.com")) + ts := time.Now().UTC() + require.NoError(t, src.RecordHeartbeat(ctx, auxmodels.Heartbeat{Locality: "dss-1", Source: "source-1", Timestamp: &ts, Reporter: "uss-1"})) + + data, err := src.GetSnapshot() + require.NoError(t, err) + + dst := newRepo() + require.NoError(t, dst.RestoreFromSnapshot(data)) + + want, err := src.GetDSSMetadata(ctx) + require.NoError(t, err) + got, err := dst.GetDSSMetadata(ctx) + require.NoError(t, err) + require.Equal(t, want, got) +} + +func TestRestoreFromSnapshotReplacesState(t *testing.T) { + ctx := context.Background() + src := newRepo() + require.NoError(t, src.SaveOwnMetadata(ctx, "dss-1", "https://example.com")) + data, err := src.GetSnapshot() + require.NoError(t, err) + + dst := newRepo() + require.NoError(t, dst.SaveOwnMetadata(ctx, "dss-2", "https://other.example.com")) + require.NoError(t, dst.RestoreFromSnapshot(data)) + + md, err := dst.GetDSSMetadata(ctx) + require.NoError(t, err) + require.Len(t, md, 1) + require.Equal(t, "dss-1", md[0].Locality) +} + +func TestRestoreFromSnapshotInvalidData(t *testing.T) { + require.Error(t, newRepo().RestoreFromSnapshot([]byte("random value that is definitely not valid"))) +} + +func TestRestoreFromSnapshotVersionMismatch(t *testing.T) { + var buf bytes.Buffer + require.NoError(t, gob.NewEncoder(&buf).Encode(snapshotEnvelope{Version: snapshotVersion + 1})) + require.Error(t, newRepo().RestoreFromSnapshot(buf.Bytes())) +} diff --git a/pkg/aux_/store/memstore/store.go b/pkg/aux_/store/memstore/store.go new file mode 100644 index 000000000..bdb8d35d3 --- /dev/null +++ b/pkg/aux_/store/memstore/store.go @@ -0,0 +1,83 @@ +package memstore + +import ( + "context" + "time" + + auxmodels "github.com/interuss/dss/pkg/aux_/models" + "github.com/interuss/dss/pkg/aux_/repos" + "github.com/interuss/dss/pkg/memstore" + "github.com/interuss/stacktrace" + "go.uber.org/zap" +) + +// repo is a full implementation of aux_.repos.Repository for memory-based storage. +type repo struct { + state state +} + +// state is the serializable in-memory state. +type state struct { + // Participants holds pool participants metadata, keyed by locality. + Participants map[string]*participant + // Heartbeats holds the latest heartbeat per (locality, source). + Heartbeats map[heartbeatKey]auxmodels.Heartbeat + // participants holds pool participants metadata, keyed by locality. + participants map[string]*participant + // heartbeats holds the latest heartbeat per (locality, source). + heartbeats map[heartbeatKey]auxmodels.Heartbeat +} + +type participant struct { + PublicEndpoint string + UpdatedAt time.Time +} + +type heartbeatKey struct { + Locality string + Source string +} + +func newRepo() *repo { + return &repo{ + state: state{ + Participants: map[string]*participant{}, + Heartbeats: map[heartbeatKey]auxmodels.Heartbeat{}, + }} +} + +func Init(ctx context.Context, logger *zap.Logger) (*memstore.Store[repos.Repository], error) { + return memstore.Init(ctx, logger, "aux_", newRepo()) +} + +func (r *repo) GetRepo() repos.Repository { return r } + +// clone returns a copy of s with independent maps and participant records. +func (s state) clone() state { + ps := make(map[string]*participant, len(s.Participants)) + for k, v := range s.Participants { + cp := *v + ps[k] = &cp + } + hb := make(map[heartbeatKey]auxmodels.Heartbeat, len(s.Heartbeats)) + for k, v := range s.Heartbeats { + hb[k] = v + } + return state{Participants: ps, Heartbeats: hb} +} + +// Checkpoint returns a fast, restorable in-memory copy of the current state. +func (r *repo) Checkpoint() any { + return r.state.clone() +} + +// Restore replaces the current state with a checkpoint previously returned by +// Checkpoint. The checkpoint is copied, so it stays reusable. +func (r *repo) Restore(cp any) error { + s, ok := cp.(state) + if !ok { + return stacktrace.NewError("Invalid checkpoint type %T", cp) + } + r.state = s.clone() + return nil +} diff --git a/pkg/aux_/store/memstore/store_test.go b/pkg/aux_/store/memstore/store_test.go new file mode 100644 index 000000000..737c40d25 --- /dev/null +++ b/pkg/aux_/store/memstore/store_test.go @@ -0,0 +1,51 @@ +package memstore + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestCheckpointRestore(t *testing.T) { + ctx := context.Background() + r := newRepo() + + require.NoError(t, r.SaveOwnMetadata(ctx, "dss-1", "https://example.com")) + + cp := r.Checkpoint() + + // Mutate after the checkpoint. + require.NoError(t, r.SaveOwnMetadata(ctx, "dss-2", "https://other.example.com")) + md, err := r.GetDSSMetadata(ctx) + require.NoError(t, err) + require.Len(t, md, 2) + + // Restore drops dss-2. + require.NoError(t, r.Restore(cp)) + md, err = r.GetDSSMetadata(ctx) + require.NoError(t, err) + require.Len(t, md, 1) + require.Equal(t, "dss-1", md[0].Locality) +} + +func TestCheckpointIsolatesUpsert(t *testing.T) { + ctx := context.Background() + r := newRepo() + + require.NoError(t, r.SaveOwnMetadata(ctx, "dss-1", "https://old.example.com")) + + cp := r.Checkpoint() + + require.NoError(t, r.SaveOwnMetadata(ctx, "dss-1", "https://new.example.com")) + + require.NoError(t, r.Restore(cp)) + md, err := r.GetDSSMetadata(ctx) + require.NoError(t, err) + require.Len(t, md, 1) + require.Equal(t, "https://old.example.com", md[0].PublicEndpoint) +} + +func TestRestoreInvalidType(t *testing.T) { + require.Error(t, newRepo().Restore("not a checkpoint")) +} diff --git a/pkg/aux_/store/raftstore/dss.go b/pkg/aux_/store/raftstore/dss.go index 871d06394..32feff5e4 100644 --- a/pkg/aux_/store/raftstore/dss.go +++ b/pkg/aux_/store/raftstore/dss.go @@ -2,25 +2,67 @@ package raftstore import ( "context" + "strconv" auxmodels "github.com/interuss/dss/pkg/aux_/models" dsserr "github.com/interuss/dss/pkg/errors" + raftparams "github.com/interuss/dss/pkg/raftstore/params" + "github.com/interuss/dss/pkg/timestamp" "github.com/interuss/stacktrace" ) -// SaveOwnMetadata returns nil instead of dsserr.NotImplemented because it is needed to allow the server to startup. -func (r *repo) SaveOwnMetadata(_ context.Context, locality string, publicEndpoint string) error { - return nil +type saveOwnMetadataPayload struct { + Locality string + PublicEndpoint string } -func (r *repo) GetDSSMetadata(_ context.Context) ([]*auxmodels.DSSMetadata, error) { - return nil, stacktrace.NewErrorWithCode(dsserr.NotImplemented, "GetDSSMetadata not implemented for raftstore") +func (r *repo) SaveOwnMetadata(ctx context.Context, locality string, publicEndpoint string) error { + if locality == "" { + return stacktrace.NewErrorWithCode(dsserr.BadRequest, "Locality not set") + } + if publicEndpoint == "" { + return stacktrace.NewErrorWithCode(dsserr.BadRequest, "Public endpoint not set") + } + + _, err := r.consensus.ProposeValue(ctx, saveOwnMetadata, saveOwnMetadataPayload{ + Locality: locality, + PublicEndpoint: publicEndpoint, + }, false) + return err +} + +func (r *repo) GetDSSMetadata(ctx context.Context) ([]*auxmodels.DSSMetadata, error) { + result, err := r.consensus.ProposeValue(ctx, getDSSMetadata, nil, true) + if err != nil { + return nil, stacktrace.Propagate(err, "failed to propose %s", getDSSMetadata) + } + if result == nil { + return nil, nil + } + return result.([]*auxmodels.DSSMetadata), nil } -func (r *repo) RecordHeartbeat(_ context.Context, heartbeat auxmodels.Heartbeat) error { - return stacktrace.NewErrorWithCode(dsserr.NotImplemented, "RecordHeartbeat not implemented for raftstore") +func (r *repo) RecordHeartbeat(ctx context.Context, heartbeat auxmodels.Heartbeat) error { + if heartbeat.Locality == "" { + return stacktrace.NewErrorWithCode(dsserr.BadRequest, "Locality not set") + } + if heartbeat.Source == "" { + return stacktrace.NewErrorWithCode(dsserr.BadRequest, "Source not set") + } + + if heartbeat.Timestamp == nil { + now := timestamp.NowFromContext(ctx) + heartbeat.Timestamp = &now + } + + if heartbeat.NextHeartbeatExpectedBefore != nil && heartbeat.NextHeartbeatExpectedBefore.Before(*heartbeat.Timestamp) { + return stacktrace.NewErrorWithCode(dsserr.BadRequest, "Cannot expect the timestamp of the next heartbeat before the timestamp of the new heartbeat") + } + + _, err := r.consensus.ProposeValue(ctx, recordHeartbeat, heartbeat, false) + return err } func (r *repo) GetDSSAirspaceRepresentationID(_ context.Context) (string, error) { - return "", stacktrace.NewErrorWithCode(dsserr.NotImplemented, "GetDSSAirspaceRepresentationID not implemented for raftstore") + return strconv.Itoa(int(raftparams.GetClusterID())), nil } diff --git a/pkg/aux_/store/raftstore/params/params.go b/pkg/aux_/store/raftstore/params/params.go new file mode 100644 index 000000000..3cd731a4d --- /dev/null +++ b/pkg/aux_/store/raftstore/params/params.go @@ -0,0 +1,29 @@ +package params + +import ( + "flag" + + raftparams "github.com/interuss/dss/pkg/raftstore/params" + "github.com/interuss/stacktrace" +) + +const peersFlag = "aux_raft_peers" + +var peers string + +func init() { + flag.StringVar(&peers, peersFlag, "", `Comma-separated "nodeID=peerURL" pairs for the aux store, e.g. "1=http://node1:9031,2=http://node2:9031,3=http://node3:9031"`) +} + +func GetConnectParameters() (raftparams.ConnectParameters, error) { + if peers == "" { + return raftparams.ConnectParameters{}, stacktrace.NewError("--%s is required", peersFlag) + } + + p, err := raftparams.GetConnectParameters("aux") + if err != nil { + return raftparams.ConnectParameters{}, err + } + p.Peers = peers + return p, nil +} diff --git a/pkg/aux_/store/raftstore/store.go b/pkg/aux_/store/raftstore/store.go index ae74ad990..5f8939669 100644 --- a/pkg/aux_/store/raftstore/store.go +++ b/pkg/aux_/store/raftstore/store.go @@ -2,15 +2,96 @@ package raftstore import ( "context" + "encoding/json" + auxmodels "github.com/interuss/dss/pkg/aux_/models" "github.com/interuss/dss/pkg/aux_/repos" + auxmemstore "github.com/interuss/dss/pkg/aux_/store/memstore" + auxraftparams "github.com/interuss/dss/pkg/aux_/store/raftstore/params" + "github.com/interuss/dss/pkg/memstore" "github.com/interuss/dss/pkg/raftstore" + "github.com/interuss/dss/pkg/raftstore/consensus" + "github.com/interuss/stacktrace" "go.uber.org/zap" ) +const storeID = "aux_" + +const ( + saveOwnMetadata raftstore.RequestType = "saveOwnMetadata" + getDSSMetadata raftstore.RequestType = "getDSSMetadata" + recordHeartbeat raftstore.RequestType = "recordHeartbeat" +) + // repo is a full implementation of aux_.repos.Repository for Raft-based storage. -type repo struct{} +type repo struct { + consensus *consensus.Consensus + memStore *memstore.Store[repos.Repository] +} func Init(ctx context.Context, logger *zap.Logger) (*raftstore.Store[repos.Repository], error) { - return raftstore.Init[repos.Repository](ctx, logger, func() repos.Repository { return &repo{} }) + params, err := auxraftparams.GetConnectParameters() + if err != nil { + return nil, stacktrace.Propagate(err, "failed to get aux raft parameters") + } + + memStore, err := auxmemstore.Init(ctx, logger) + if err != nil { + return nil, stacktrace.Propagate(err, "failed to initialize aux memstore") + } + + r := &repo{memStore: memStore} + store, err := raftstore.Init(ctx, logger, params, r) + if err != nil { + return nil, stacktrace.Propagate(err, "failed to initialize aux raftstore") + } + + r.consensus = store.Consensus + + return store, nil +} + +func (r *repo) GetRepo() repos.Repository { return r } + +func (r *repo) IsReadOnly(requestType raftstore.RequestType) bool { + return requestType == getDSSMetadata +} + +func (r *repo) GetSnapshot() ([]byte, error) { + return r.memStore.GetSnapshot() +} + +func (r *repo) RestoreFromSnapshot(data []byte) error { + return r.memStore.RestoreFromSnapshot(data) +} + +func (r *repo) Apply(ctx context.Context, proposal consensus.Proposal) (any, error) { + mem, err := r.memStore.Interact(ctx) + if err != nil { + return nil, stacktrace.Propagate(err, "failed to obtain aux memstore repository") + } + + switch raftstore.RequestType(proposal.RequestType) { + case saveOwnMetadata: + var p saveOwnMetadataPayload + if err := json.Unmarshal(proposal.Value, &p); err != nil { + return nil, stacktrace.Propagate(err, "failed to unmarshal %s payload", saveOwnMetadata) + } + + return nil, mem.SaveOwnMetadata(ctx, p.Locality, p.PublicEndpoint) + + case getDSSMetadata: + return mem.GetDSSMetadata(ctx) + + case recordHeartbeat: + var hb auxmodels.Heartbeat + if err := json.Unmarshal(proposal.Value, &hb); err != nil { + return nil, stacktrace.Propagate(err, "failed to unmarshal %s payload", recordHeartbeat) + } + + return nil, mem.RecordHeartbeat(ctx, hb) + + default: + return nil, stacktrace.NewError("unknown request type: %q", proposal.RequestType) + } } diff --git a/pkg/aux_/store/store.go b/pkg/aux_/store/store.go index 4a9049327..2532842a7 100644 --- a/pkg/aux_/store/store.go +++ b/pkg/aux_/store/store.go @@ -4,6 +4,7 @@ import ( "context" "github.com/interuss/dss/pkg/aux_/repos" + auxmemstore "github.com/interuss/dss/pkg/aux_/store/memstore" auxraftstore "github.com/interuss/dss/pkg/aux_/store/raftstore" auxsqlstore "github.com/interuss/dss/pkg/aux_/store/sqlstore" dssstore "github.com/interuss/dss/pkg/store" @@ -24,6 +25,8 @@ func Init(ctx context.Context, logger *zap.Logger, withCheckCron bool) (Store, e return auxsqlstore.Init(ctx, logger, withCheckCron) case params.RaftStoreType: return auxraftstore.Init(ctx, logger) + case params.MemStoreType: + return auxmemstore.Init(ctx, logger) default: return nil, stacktrace.NewError("Unsupported store type %q for aux", storeType) } diff --git a/pkg/memstore/store.go b/pkg/memstore/store.go new file mode 100644 index 000000000..f11ecf409 --- /dev/null +++ b/pkg/memstore/store.go @@ -0,0 +1,87 @@ +package memstore + +// Memstore are a special kind of store: +// Store instances store data in memory. There is no persistent storage. +// Store instances are a singleton. +// Repository usage is not thread-safe. +// +// As of now, they are made to be used by raftstorage. +// Adaptations could be done to use them directly in the future. + +import ( + "context" + "sync" + + dsserr "github.com/interuss/dss/pkg/errors" + "github.com/interuss/dss/pkg/logging" + "github.com/interuss/stacktrace" + "go.uber.org/zap" +) + +type MemRepo[R any] interface { + GetRepo() R + GetSnapshot() ([]byte, error) + RestoreFromSnapshot([]byte) error + Checkpoint() any + Restore(any) error +} + +type Store[R any] struct { + logger *zap.Logger + + name string + memRepo MemRepo[R] +} + +var ( + stores = map[string]any{} + storesMu sync.Mutex +) + +func Init[R any](ctx context.Context, logger *zap.Logger, name string, r MemRepo[R]) (*Store[R], error) { + + storesMu.Lock() + defer storesMu.Unlock() + if s, ok := stores[name]; ok { + return s.(*Store[R]), nil + } + + store := &Store[R]{ + name: name, + logger: logging.WithValuesFromContext(ctx, logger), + memRepo: r, + } + + stores[name] = store + return store, nil +} + +func (s *Store[R]) Transact(ctx context.Context, requestType string, payload any, _ func(context.Context, R) error) (any, error) { + return nil, stacktrace.NewErrorWithCode(dsserr.NotImplemented, "Transact not implemented for memstore") +} + +func (s *Store[R]) Interact(_ context.Context) (R, error) { + return s.memRepo.GetRepo(), nil +} + +// Checkpoint returns a fast, restorable in-memory copy of the current state. +func (s *Store[R]) Checkpoint() any { + return s.memRepo.Checkpoint() +} + +// Restore replaces the current state with a checkpoint returned by Checkpoint. +func (s *Store[R]) Restore(cp any) error { + return s.memRepo.Restore(cp) +} + +func (s *Store[R]) GetSnapshot() ([]byte, error) { + return s.memRepo.GetSnapshot() +} + +func (s *Store[R]) RestoreFromSnapshot(data []byte) error { + return s.memRepo.RestoreFromSnapshot(data) +} + +func (s *Store[R]) Close() error { + return nil +} diff --git a/pkg/models/geo.go b/pkg/models/geo.go index d132c450e..29ad1cd0a 100644 --- a/pkg/models/geo.go +++ b/pkg/models/geo.go @@ -1,6 +1,8 @@ package models import ( + "encoding/json" + "fmt" "time" "github.com/golang/geo/s2" @@ -72,6 +74,83 @@ type Volume3D struct { Footprint Geometry } +type Volume3DJSON struct { + AltitudeHi *float32 `json:"AltitudeHi,omitempty"` + AltitudeLo *float32 `json:"AltitudeLo,omitempty"` + Footprint *geometryJSON `json:"Footprint,omitempty"` +} + +type geometryType string + +const ( + circle geometryType = "circle" + polygon geometryType = "polygon" + cells geometryType = "cells" +) + +// geometryJSON is a helper struct for marshaling and unmarshaling Geometry types to/from JSON. +type geometryJSON struct { + Type geometryType `json:"type"` + Polygon *GeoPolygon `json:"polygon,omitempty"` + Circle *GeoCircle `json:"circle,omitempty"` + Cells []s2.CellID `json:"cells,omitempty"` +} + +func (v Volume3D) MarshalJSON() ([]byte, error) { + w := Volume3DJSON{AltitudeHi: v.AltitudeHi, AltitudeLo: v.AltitudeLo} + if v.Footprint != nil { + switch f := v.Footprint.(type) { + case *GeoPolygon: + w.Footprint = &geometryJSON{Type: polygon, Polygon: f} + + case *GeoCircle: + w.Footprint = &geometryJSON{Type: circle, Circle: f} + + case precomputedCellGeometry: + cellsResult := make([]s2.CellID, 0, len(f)) + for id := range f { + cellsResult = append(cellsResult, id) + } + w.Footprint = &geometryJSON{Type: cells, Cells: cellsResult} + + default: + return nil, fmt.Errorf("Volume3D: unsupported Footprint type %T for JSON marshaling", v.Footprint) + } + } + + return json.Marshal(w) +} + +func (v *Volume3D) UnmarshalJSON(data []byte) error { + var w Volume3DJSON + if err := json.Unmarshal(data, &w); err != nil { + return err + } + v.AltitudeHi = w.AltitudeHi + v.AltitudeLo = w.AltitudeLo + if w.Footprint != nil { + switch w.Footprint.Type { + case polygon: + v.Footprint = w.Footprint.Polygon + + case circle: + v.Footprint = w.Footprint.Circle + + case cells: + pcg := make(precomputedCellGeometry, len(w.Footprint.Cells)) + for _, id := range w.Footprint.Cells { + pcg[id] = struct{}{} + } + + v.Footprint = pcg + default: + return fmt.Errorf("Volume3D: unknown geometry type %q", w.Footprint.Type) + } + } + + return nil +} + // Geometry models a geometry. type Geometry interface { // CalculateCovering returns an s2 cell covering for a geometry. diff --git a/pkg/models/models.go b/pkg/models/models.go index 7ec4c91e5..18d3a2b79 100644 --- a/pkg/models/models.go +++ b/pkg/models/models.go @@ -1,6 +1,7 @@ package models import ( + "encoding/json" "strconv" "time" @@ -175,3 +176,26 @@ func (v *Version) ToTimestamp() *time.Time { } return &v.t } + +func (v *Version) MarshalJSON() ([]byte, error) { + return json.Marshal(v.String()) +} + +func (v *Version) UnmarshalJSON(data []byte) error { + var s string + if err := json.Unmarshal(data, &s); err != nil { + return err + } + + if s == "" { + return nil + } + + parsed, err := VersionFromString(s) + if err != nil { + return stacktrace.Propagate(err, "failed to unmarshal version") + } + + *v = *parsed + return nil +} diff --git a/pkg/raftstore/consensus/consensus.go b/pkg/raftstore/consensus/consensus.go index 8377e22d3..4d6c2367e 100644 --- a/pkg/raftstore/consensus/consensus.go +++ b/pkg/raftstore/consensus/consensus.go @@ -2,10 +2,12 @@ package consensus import ( "context" + "encoding/json" "errors" "fmt" "net/http" "net/url" + "sync" "time" "github.com/interuss/dss/pkg/logging" @@ -29,18 +31,30 @@ type Consensus struct { server *http.Server storage *storage + commitC chan<- EntryCommit + + tracker *proposalsTracker + stopOnce sync.Once + + once sync.Once + shutdownTimeout time.Duration confState raftpb.ConfState snapshotIndex uint64 appliedIndex uint64 } -func NewConsensus(ctx context.Context, logger *zap.Logger, peers map[uint64]*url.URL, connectParams params.ConnectParameters) (*Consensus, error) { - storage, old, err := newStorage(ctx, logger.With(zap.String("component", "storage")), connectParams.DataDir, connectParams.NodeID, connectParams.SnapshotCatchupEntries) +func NewConsensus(ctx context.Context, logger *zap.Logger, connectParams params.ConnectParameters, provider snapshotProvider, commitC chan<- EntryCommit) (*Consensus, error) { + storage, old, err := newStorage(ctx, logger.With(zap.String("component", "storage")), connectParams.DataDir, connectParams.NodeID, provider, connectParams.SnapshotCatchupEntries) if err != nil { return nil, stacktrace.Propagate(err, "failed to initialize storage") } + peers, err := connectParams.PeerMap() + if err != nil { + return nil, stacktrace.Propagate(err, "failed to parse peer map") + } + nodeUrl, ok := peers[connectParams.NodeID] if !ok { return nil, stacktrace.NewError("node ID %d not found in peers map", connectParams.NodeID) @@ -63,6 +77,10 @@ func NewConsensus(ctx context.Context, logger *zap.Logger, peers map[uint64]*url node: node, storage: storage, + commitC: commitC, + tracker: newProposalsTracker(), + + shutdownTimeout: 2 * connectParams.ElectionInterval(), } err = consensus.initTransport(ctx, connectParams.NodeID, connectParams.ClusterID, peers) @@ -80,26 +98,62 @@ func NewConsensus(ctx context.Context, logger *zap.Logger, peers map[uint64]*url consensus.appliedIndex = snap.Metadata.Index go func() { - err := consensus.handleReady(connectParams.TickInterval) + err := consensus.handleReady(connectParams.TickInterval, connectParams.SnapshotIntervalEntries) if err != nil { consensus.logger.Error("handleReady exited with error, shutting down consensus", zap.Error(err)) } - shutdownCtx, cancel := context.WithTimeout(context.Background(), 2*connectParams.ElectionInterval()) + consensus.Stop(context.Background()) + }() + + return consensus, nil +} + +func (c *Consensus) Stop(ctx context.Context) { + c.once.Do(func() { + shutdownCtx, cancel := context.WithTimeout(context.Background(), c.shutdownTimeout) defer cancel() - if shutdownErr := consensus.server.Shutdown(shutdownCtx); shutdownErr != nil { - consensus.logger.Error("failed to shutdown http server", zap.Error(shutdownErr)) + if shutdownErr := c.server.Shutdown(shutdownCtx); shutdownErr != nil { + c.logger.Error("failed to shutdown http server", zap.Error(shutdownErr)) } else { - consensus.logger.Info("http server shutdown complete") + c.logger.Info("http server shutdown complete") } - consensus.transport.Stop() - consensus.logger.Info("transport stopped") - consensus.node.Stop() - consensus.logger.Info("raft node stopped") - }() + c.transport.Stop() + c.logger.Info("transport stopped") + c.node.Stop() + c.logger.Info("raft node stopped") + }) +} - return consensus, nil +// ProposeValue blocks until the proposal is committed and applied / dropped or until ctx is cancelled. +func (c *Consensus) ProposeValue(ctx context.Context, requestType string, payload any, readOnly bool) (any, error) { + proposal, err := newProposal(ctx, requestType, payload, readOnly) + if err != nil { + return nil, stacktrace.Propagate(err, "failed to create proposal") + } + + buf, err := json.Marshal(proposal) + if err != nil { + return nil, stacktrace.Propagate(err, "failed to marshal proposal") + } + + applied := c.tracker.track(proposal.ID) + + err = c.node.Propose(ctx, buf) + if err != nil { + c.tracker.untrack(proposal.ID, ProposalResult{Error: err}) + return nil, stacktrace.Propagate(err, "failed to propose value to Raft") + } + + select { + case res := <-applied: + return res.Result, res.Error + + case <-ctx.Done(): + c.tracker.untrack(proposal.ID, ProposalResult{Error: ctx.Err()}) + return nil, ctx.Err() + } } func peersList(peers map[uint64]*url.URL) []raft.Peer { @@ -120,7 +174,7 @@ func (c *Consensus) initTransport(ctx context.Context, nodeID uint64, clusterID Raft: c, ServerStats: v2stats.NewServerStats(nodeIDStr, nodeIDStr), LeaderStats: v2stats.NewLeaderStats(c.logger, nodeIDStr), - ErrorC: make(chan error), + ErrorC: make(chan error, 1), } err := transport.Start() @@ -142,6 +196,8 @@ func (c *Consensus) initTransport(ctx context.Context, nodeID uint64, clusterID return stacktrace.NewError("node ID %d not found in peers map", nodeID) } + c.transport = transport + c.server = &http.Server{ Addr: listeningAddr, Handler: transport.Handler(), @@ -155,12 +211,11 @@ func (c *Consensus) initTransport(ctx context.Context, nodeID uint64, clusterID } }() - c.transport = transport return nil } // handleReady processes the Ready channel of the Raft node and applies committed entries to the state machine -func (c *Consensus) handleReady(tickInterval time.Duration) error { +func (c *Consensus) handleReady(tickInterval time.Duration, snapshotInterval uint64) error { ticker := time.NewTicker(tickInterval) defer ticker.Stop() @@ -185,10 +240,7 @@ func (c *Consensus) handleReady(tickInterval time.Duration) error { return stacktrace.NewError("snapshot index %d shall be greater than current applied index %d", rd.Snapshot.Metadata.Index, c.appliedIndex) } - err = c.dispatchSnapshot(rd.Snapshot.Data) - if err != nil { - return stacktrace.Propagate(err, "failed to dispatch snapshot") - } + c.commitC <- EntryCommit{SnapshotData: rd.Snapshot.Data} c.confState = rd.Snapshot.Metadata.ConfState c.snapshotIndex = rd.Snapshot.Metadata.Index @@ -203,7 +255,7 @@ func (c *Consensus) handleReady(tickInterval time.Duration) error { return stacktrace.Propagate(err, "failed to get entries to apply") } - err = c.publishEntries(entries) + err = c.publishEntries(entries, snapshotInterval) if err != nil { return stacktrace.Propagate(err, "failed to publish entries") } @@ -213,16 +265,104 @@ func (c *Consensus) handleReady(tickInterval time.Duration) error { } } -// TODO implement -func (c *Consensus) publishEntries(_ []raftpb.Entry) error { +func (c *Consensus) publishEntries(entries []raftpb.Entry, snapshotInterval uint64) error { + if len(entries) == 0 { + return nil + } + + c.logger.Info("publishing entries", zap.Int("numEntries", len(entries)), zap.Uint64("firstIndex", entries[0].Index), zap.Uint64("lastIndex", entries[len(entries)-1].Index)) + + var triggerSnapshot bool + var err error + var wg sync.WaitGroup + for _, entry := range entries { + switch entry.Type { + case raftpb.EntryNormal: + err := c.processNormalEntry(entry.Data, &wg) + if err != nil { + return stacktrace.Propagate(err, "failed to process normal entry") + } + case raftpb.EntryConfChange: + err := c.processConfigChangeEntry(entry.Data) + if err != nil { + return stacktrace.Propagate(err, "failed to process config change entry") + } + case raftpb.EntryConfChangeV2: + triggerSnapshot, err = c.processConfigChangeV2Entry(entry.Data) + if err != nil { + return stacktrace.Propagate(err, "failed to process config change v2 entry") + } + } + } + + // wait for all entries to be applied before updating the applied index and potentially triggering a snapshot + wg.Wait() + c.appliedIndex = entries[len(entries)-1].Index + + if triggerSnapshot || c.appliedIndex-c.snapshotIndex >= snapshotInterval { + err := c.storage.triggerSnapshot(c.appliedIndex, &c.confState) + if err != nil { + return stacktrace.Propagate(err, "failed to trigger snapshot") + } + + c.snapshotIndex = c.appliedIndex + } + return nil } -// TODO implement -func (c *Consensus) dispatchSnapshot(_ []byte) error { +// processNormalEntry passes the proposal to the store and waits for the result to be returned before untracking it. +func (c *Consensus) processNormalEntry(data []byte, wg *sync.WaitGroup) error { + if len(data) <= 0 { + return nil + } + + prop := Proposal{} + err := json.Unmarshal(data, &prop) + if err != nil { + return stacktrace.Propagate(err, "failed to unmarshal committed proposal") + } + + //if readOnly proposal and we did not initiate it, skip it (noop) + if prop.ReadOnly && !c.tracker.isPending(prop.ID) { + return nil + } + + applyDoneC := make(chan ProposalResult, 1) + wg.Go(func() { + c.tracker.untrack(prop.ID, <-applyDoneC) + }) + + c.commitC <- EntryCommit{Prop: prop, Done: applyDoneC} + return nil +} + +// raftpb.ConfChange is still used internally by Raft, we just need to apply the change to the node. +// Changes requested by clients are processed by processConfigChangeV2Entry. +func (c *Consensus) processConfigChangeEntry(data []byte) error { + var cc raftpb.ConfChange + err := cc.Unmarshal(data) + if err != nil { + return stacktrace.Propagate(err, "failed to unmarshal config change data") + } + + c.confState = *c.node.ApplyConfChange(cc) return nil } +func (c *Consensus) processConfigChangeV2Entry(data []byte) (bool, error) { + var cc raftpb.ConfChangeV2 + err := cc.Unmarshal(data) + if err != nil { + return false, stacktrace.Propagate(err, "failed to unmarshal config change data") + } + + c.confState = *c.node.ApplyConfChange(cc) + + // TODO - implement config changes when triggered by a proposal + return false, nil +} + func (c *Consensus) entriesToApply(entries []raftpb.Entry) ([]raftpb.Entry, error) { if len(entries) == 0 { return entries, nil @@ -272,8 +412,3 @@ func (c *Consensus) ReportUnreachable(id uint64) { func (c *Consensus) ReportSnapshot(id uint64, status raft.SnapshotStatus) { c.node.ReportSnapshot(id, status) } - -// RegisterStore allows registering a snapshot provider function for a specific store -func (c *Consensus) RegisterStore(name string, provider snapshotProvider) { - c.storage.registerSnapshotProvider(name, provider) -} diff --git a/pkg/raftstore/consensus/proposal.go b/pkg/raftstore/consensus/proposal.go new file mode 100644 index 000000000..d05f4b7e5 --- /dev/null +++ b/pkg/raftstore/consensus/proposal.go @@ -0,0 +1,93 @@ +package consensus + +import ( + "context" + "encoding/json" + "sync" + "time" + + "github.com/google/uuid" + "github.com/interuss/dss/pkg/timestamp" + "github.com/interuss/stacktrace" +) + +type EntryCommit struct { + Prop Proposal + Done chan ProposalResult + + SnapshotData []byte +} + +type Proposal struct { + ID string `json:"id"` + Timestamp time.Time `json:"timestamp"` + RequestType string `json:"request_type"` + Value []byte `json:"value"` + ReadOnly bool `json:"read_only"` +} + +type ProposalResult struct { + Result any + Error error +} + +func newProposal(ctx context.Context, requestType string, payload any, readOnly bool) (Proposal, error) { + proposalTimestamp := timestamp.NowFromContext(ctx) + if proposalTimestamp.IsZero() { + proposalTimestamp = time.Now().UTC() + } + + value, err := json.Marshal(payload) + if err != nil { + return Proposal{}, stacktrace.Propagate(err, "failed to serialize proposal payload") + } + + return Proposal{ + ID: uuid.NewString(), + Timestamp: proposalTimestamp, + RequestType: requestType, + Value: value, + ReadOnly: readOnly, + }, nil +} + +type proposalsTracker struct { + sync.Mutex + pending map[string]chan ProposalResult +} + +func newProposalsTracker() *proposalsTracker { + return &proposalsTracker{ + pending: make(map[string]chan ProposalResult), + } +} + +func (p *proposalsTracker) isPending(id string) bool { + p.Lock() + defer p.Unlock() + + _, ok := p.pending[id] + return ok +} + +func (p *proposalsTracker) track(id string) chan ProposalResult { + p.Lock() + defer p.Unlock() + + applied := make(chan ProposalResult, 1) + p.pending[id] = applied + return applied +} + +func (p *proposalsTracker) untrack(id string, result ProposalResult) { + p.Lock() + defer p.Unlock() + + applied, ok := p.pending[id] + if !ok { + return + } + + applied <- result + delete(p.pending, id) +} diff --git a/pkg/raftstore/consensus/storage.go b/pkg/raftstore/consensus/storage.go index 97f2ff718..1d9893882 100644 --- a/pkg/raftstore/consensus/storage.go +++ b/pkg/raftstore/consensus/storage.go @@ -2,7 +2,6 @@ package consensus import ( "context" - "encoding/json" "errors" "fmt" "os" @@ -31,8 +30,8 @@ type storage struct { wal *wal.WAL - snapper *snap.Snapshotter - providers map[string]snapshotProvider + snapper *snap.Snapshotter + snapshot snapshotProvider snapshotCatchUpEntries uint64 } @@ -40,7 +39,7 @@ type storage struct { // newStorage initializes the storage by loading the latest snapshot and wal entries from the disk // and applies them to the Raft memory storage. // It returns the initialized storage, a boolean indicating whether the storage was pre-existent or an error. -func newStorage(ctx context.Context, logger *zap.Logger, dataDir string, nodeID uint64, snapshotCatchUpEntries uint64) (*storage, bool, error) { +func newStorage(ctx context.Context, logger *zap.Logger, dataDir string, nodeID uint64, provider snapshotProvider, snapshotCatchUpEntries uint64) (*storage, bool, error) { logger = logging.WithValuesFromContext(ctx, logger) // load the latest snapshot @@ -124,8 +123,9 @@ func newStorage(ctx context.Context, logger *zap.Logger, dataDir string, nodeID wal: w, - snapper: snapper, - providers: make(map[string]snapshotProvider), + snapper: snapper, + snapshot: provider, + snapshotCatchUpEntries: snapshotCatchUpEntries, }, ok, nil } @@ -177,33 +177,9 @@ func (s *storage) save(snapshot raftpb.Snapshot) error { return s.wal.ReleaseLockTo(snapshot.Metadata.Index) } -func (s *storage) registerSnapshotProvider(name string, provider func() ([]byte, error)) { - s.providers[name] = provider -} - -// getSnapshot calls all registered snapshot providers and combines their data into a single snapshot. -func (s *storage) getSnapshot() ([]byte, error) { - parts := make(map[string][]byte) - for name, provider := range s.providers { - data, err := provider() - if err != nil { - return nil, stacktrace.Propagate(err, "failed to get snapshot data from %q", name) - } - - parts[name] = data - } - - return json.Marshal(parts) -} - -// snapshotter returns the snapshotter used by the storage. -func (s *storage) snapshotter() *snap.Snapshotter { - return s.snapper -} - func (s *storage) triggerSnapshot(appliedIndex uint64, confState *raftpb.ConfState) error { s.logger.Info("triggering snapshot", zap.Uint64("appliedIndex", appliedIndex)) - data, err := s.getSnapshot() + data, err := s.snapshot() if err != nil { return stacktrace.Propagate(err, "failed to get snapshot data") } diff --git a/pkg/raftstore/params/params.go b/pkg/raftstore/params/params.go index 3a2b6fd6a..6adf994f6 100644 --- a/pkg/raftstore/params/params.go +++ b/pkg/raftstore/params/params.go @@ -123,7 +123,6 @@ var ( func init() { flag.Uint64Var(&connectParameters.NodeID, "raft_node_id", 0, "Raft node ID for this instance (must be non-zero and unique within the cluster).") flag.Uint64Var(&connectParameters.ClusterID, "raft_cluster_id", 1, "ID of the cluster, used to isolate different Raft clusters running in the same network (must be the same for all nodes in the cluster).") - flag.StringVar(&connectParameters.Peers, "raft_peers", "", `Comma-separated "nodeID=peerURL" pairs for all cluster members, including the current node, e.g. "1=http://node1:9021,2=http://node2:9021,3=http://node3:9021"`) flag.StringVar(&connectParameters.DataDir, "raft_datadir", defaultDataDir, "Directory for raft data (WAL segments and snapshots), required for restarts. These should not be deleted while the node is running or across restarts unless the node is being permanently shut down.") flag.Uint64Var(&connectParameters.SnapshotCatchupEntries, "raft_snapshot_catchup_entries", defaultSnapshotCatchupEntries, @@ -142,6 +141,15 @@ func init() { } // GetConnectParameters returns a ConnectParameters instance that gets populated from well-known CLI flags. -func GetConnectParameters() ConnectParameters { - return connectParameters +func GetConnectParameters(subfolder string) (ConnectParameters, error) { + if connectParameters.NodeID == 0 { + return ConnectParameters{}, stacktrace.NewError("--raft_node_id is required and must be non-zero") + } + p := connectParameters + p.DataDir = connectParameters.DataDir + "/" + subfolder + return p, nil +} + +func GetClusterID() uint64 { + return connectParameters.ClusterID } diff --git a/pkg/raftstore/store.go b/pkg/raftstore/store.go index 3ef62c157..5f9c74112 100644 --- a/pkg/raftstore/store.go +++ b/pkg/raftstore/store.go @@ -2,68 +2,96 @@ package raftstore import ( "context" - "sync" + "github.com/interuss/dss/pkg/logging" "github.com/interuss/dss/pkg/raftstore/consensus" raftparams "github.com/interuss/dss/pkg/raftstore/params" + "github.com/interuss/dss/pkg/timestamp" "github.com/interuss/stacktrace" "go.uber.org/zap" ) -var ( - sharedConsensus *consensus.Consensus - sharedConsensusOnce sync.Once - sharedConsensusErr error -) +type RequestType = string + +type RaftRepo[R any] interface { + GetRepo() R + // Apply is called on every committed entry. The proposal must be applied atomically. + Apply(ctx context.Context, proposal consensus.Proposal) (any, error) + GetSnapshot() ([]byte, error) + RestoreFromSnapshot(data []byte) error + IsReadOnly(requestType RequestType) bool +} type Store[R any] struct { - newRepo func() R - consensus *consensus.Consensus + logger *zap.Logger + + raftRepo RaftRepo[R] + cancel context.CancelFunc + + Consensus *consensus.Consensus } -func Init[R any](ctx context.Context, logger *zap.Logger, newRepo func() R) (*Store[R], error) { - // scd, rid and aux will share the same consensus instance, so we initialize it once. - sharedConsensusOnce.Do(func() { - params := raftparams.GetConnectParameters() - peers, err := params.PeerMap() - if err != nil { - sharedConsensusErr = stacktrace.Propagate(err, "failed to parse peer map") - return - } +func Init[R any](ctx context.Context, logger *zap.Logger, params raftparams.ConnectParameters, r RaftRepo[R]) (*Store[R], error) { + ctx, cancel := context.WithCancel(ctx) - sharedConsensus, sharedConsensusErr = consensus.NewConsensus(ctx, logger, peers, params) - if sharedConsensusErr != nil { - sharedConsensusErr = stacktrace.Propagate(sharedConsensusErr, "failed to initialize consensus") - } - }) - if sharedConsensusErr != nil { - return nil, sharedConsensusErr + store := &Store[R]{ + raftRepo: r, + logger: logging.WithValuesFromContext(ctx, logger), + cancel: cancel, + } + commitC := make(chan consensus.EntryCommit) + go store.processCommits(ctx, commitC) + + consensusInstance, err := consensus.NewConsensus(ctx, logger, params, func() ([]byte, error) { return nil, nil }, commitC) + if err != nil { + return nil, stacktrace.Propagate(err, "failed to initialize consensus") } - // TODO: implement - sharedConsensus.RegisterStore("provider", func() ([]byte, error) { - return nil, nil - }) + store.Consensus = consensusInstance - return &Store[R]{ - newRepo: newRepo, - consensus: sharedConsensus, - }, nil + return store, nil } -// Transact proposes the entry to Raft and blocks until it is committed and applied. -func (s *Store[R]) Transact(ctx context.Context, f func(context.Context, R) error) error { - // TODO: implement - return nil +// Transact proposes an entry to Raft and blocks until it is committed and applied. +// The processCommits loop will call Apply on the proposal when it is committed. +func (s *Store[R]) Transact(ctx context.Context, requestType RequestType, payload any, _ func(context.Context, R) error) (any, error) { + return s.Consensus.ProposeValue(ctx, requestType, payload, s.raftRepo.IsReadOnly(requestType)) } -// Interact returns a repository that can be used to query the store without proposing a Raft entry. func (s *Store[R]) Interact(_ context.Context) (R, error) { - return s.newRepo(), nil + return s.raftRepo.GetRepo(), nil } -// Close shuts down the consensus instance. +// Close shuts down the consensus instance and processCommits loop. func (s *Store[R]) Close() error { - // TODO: implement + s.Consensus.Stop(context.Background()) + s.cancel() return nil } + +// processCommits reads committed entries from the consensus layer and applies them via Apply. +func (s *Store[R]) processCommits(ctx context.Context, commitCh <-chan consensus.EntryCommit) { + for { + select { + case <-ctx.Done(): + s.logger.Info("stopping commit processing loop") + return + case commit, ok := <-commitCh: + if !ok { + s.logger.Info("commit channel closed, stopping commit processing loop") + return + } + + if commit.SnapshotData != nil { + if err := s.raftRepo.RestoreFromSnapshot(commit.SnapshotData); err != nil { + s.logger.Error("failed to restore from snapshot", zap.Error(err)) + } + continue + } + + ctx = timestamp.WithTimestamp(ctx, commit.Prop.Timestamp) + result, err := s.raftRepo.Apply(ctx, commit.Prop) + commit.Done <- consensus.ProposalResult{Result: result, Error: err} + } + } +} diff --git a/pkg/rid/application/application_test.go b/pkg/rid/application/application_test.go index 0cc3fd885..11b785bf4 100644 --- a/pkg/rid/application/application_test.go +++ b/pkg/rid/application/application_test.go @@ -35,8 +35,8 @@ func (s *mockRepo) Interact(ctx context.Context) (repos.Repository, error) { return s, nil } -func (s *mockRepo) Transact(ctx context.Context, f func(ctx context.Context, repo repos.Repository) error) error { - return f(ctx, s) +func (s *mockRepo) Transact(ctx context.Context, _ string, _ any, f func(ctx context.Context, repo repos.Repository) error) (any, error) { + return nil, f(ctx, s) } func (s *mockRepo) Close() error { diff --git a/pkg/rid/application/isa.go b/pkg/rid/application/isa.go index ea402af60..ada032dea 100644 --- a/pkg/rid/application/isa.go +++ b/pkg/rid/application/isa.go @@ -10,6 +10,7 @@ import ( dssmodels "github.com/interuss/dss/pkg/models" ridmodels "github.com/interuss/dss/pkg/rid/models" "github.com/interuss/dss/pkg/rid/repos" + ridraftstore "github.com/interuss/dss/pkg/rid/store/raftstore" "github.com/interuss/stacktrace" ) @@ -62,8 +63,13 @@ func (a *app) DeleteISA(ctx context.Context, id dssmodels.ID, owner dssmodels.Ow ret *ridmodels.IdentificationServiceArea subs []*ridmodels.Subscription ) + // The following will automatically retry TXN retry errors. - err := a.store.Transact(ctx, func(ctx context.Context, repo repos.Repository) error { + raftResult, err := a.store.Transact(ctx, ridraftstore.DeleteISATransaction, ridraftstore.DeleteISATransactionPayload{ + ID: id, + Owner: owner, + Version: version, + }, func(ctx context.Context, repo repos.Repository) error { old, err := repo.GetISA(ctx, id, true) switch { case err != nil: @@ -89,6 +95,20 @@ func (a *app) DeleteISA(ctx context.Context, id dssmodels.ID, owner dssmodels.Ow } return nil }) + + if err == nil && raftResult != nil { + if result, ok := raftResult.(*ridraftstore.ISATransactionResult); ok { + if result.Ret != nil { + ret = result.Ret + } + if result.Subs != nil { + subs = result.Subs + } + } else { + return nil, nil, stacktrace.NewError("invalid result type: %T", raftResult) + } + } + return ret, subs, err // No need to Propagate this error as this stack layer does not add useful information } @@ -104,7 +124,7 @@ func (a *app) InsertISA(ctx context.Context, isa *ridmodels.IdentificationServic subs []*ridmodels.Subscription ) // The following will automatically retry TXN retry errors. - err := a.store.Transact(ctx, func(ctx context.Context, repo repos.Repository) error { + raftResult, err := a.store.Transact(ctx, ridraftstore.InsertISATransaction, isa, func(ctx context.Context, repo repos.Repository) error { // ensure it doesn't exist yet old, err := repo.GetISA(ctx, isa.ID, false) if err != nil { @@ -127,6 +147,14 @@ func (a *app) InsertISA(ctx context.Context, isa *ridmodels.IdentificationServic } return nil }) + if err == nil && raftResult != nil { + if result, ok := raftResult.(*ridraftstore.ISATransactionResult); ok { + ret = result.Ret + subs = result.Subs + } else { + return nil, nil, stacktrace.NewError("invalid result type: %T", raftResult) + } + } return ret, subs, err // No need to Propagate this error as this stack layer does not add useful information } @@ -138,7 +166,7 @@ func (a *app) UpdateISA(ctx context.Context, isa *ridmodels.IdentificationServic subs []*ridmodels.Subscription ) // The following will automatically retry TXN retry errors. - err := a.store.Transact(ctx, func(ctx context.Context, repo repos.Repository) error { + raftResult, err := a.store.Transact(ctx, ridraftstore.UpdateISATransaction, isa, func(ctx context.Context, repo repos.Repository) error { var err error old, err := repo.GetISA(ctx, isa.ID, true) @@ -178,5 +206,14 @@ func (a *app) UpdateISA(ctx context.Context, isa *ridmodels.IdentificationServic return nil }) + if err == nil && raftResult != nil { + if result, ok := raftResult.(*ridraftstore.ISATransactionResult); ok { + ret = result.Ret + subs = result.Subs + } else { + return nil, nil, stacktrace.NewError("invalid result type: %T", raftResult) + } + } + return ret, subs, err // No need to Propagate this error as this stack layer does not add useful information } diff --git a/pkg/rid/application/subscription.go b/pkg/rid/application/subscription.go index 33ba2af36..4100584e7 100644 --- a/pkg/rid/application/subscription.go +++ b/pkg/rid/application/subscription.go @@ -8,15 +8,11 @@ import ( dssmodels "github.com/interuss/dss/pkg/models" ridmodels "github.com/interuss/dss/pkg/rid/models" "github.com/interuss/dss/pkg/rid/repos" + ridraftstore "github.com/interuss/dss/pkg/rid/store/raftstore" "github.com/interuss/stacktrace" "go.uber.org/zap" ) -const ( - // Defined in requirement DSS0030. - maxSubscriptionsPerArea = 10 -) - // SubscriptionApp provides the interface to the application logic for Subscription entities // AppInterface provides the interface to the application logic for ISA entities // Note that there is no need for the applciation layer to have the same API as @@ -60,7 +56,7 @@ func (a *app) InsertSubscription(ctx context.Context, s *ridmodels.Subscription) return nil, stacktrace.Propagate(err, "Unable to adjust time range") } var sub *ridmodels.Subscription - err := a.store.Transact(ctx, func(ctx context.Context, repo repos.Repository) error { + raftResult, err := a.store.Transact(ctx, ridraftstore.InsertSubscriptionTransaction, s, func(ctx context.Context, repo repos.Repository) error { // ensure it doesn't exist yet old, err := repo.GetSubscription(ctx, s.ID) @@ -78,7 +74,7 @@ func (a *app) InsertSubscription(ctx context.Context, s *ridmodels.Subscription) return stacktrace.Propagate(err, "Failed to fetch subscription count, rejecting request") } - if count >= maxSubscriptionsPerArea { + if count >= ridmodels.MaxSubscriptionsPerArea { return stacktrace.Propagate( stacktrace.NewErrorWithCode(dsserr.Exhausted, "Too many existing subscriptions in this area already"), "%s had %d subscriptions in the area", s.Owner, count) @@ -91,6 +87,15 @@ func (a *app) InsertSubscription(ctx context.Context, s *ridmodels.Subscription) return nil }) + + if raftResult != nil { + var ok bool + sub, ok = raftResult.(*ridmodels.Subscription) + if !ok { + return nil, stacktrace.NewError("invalid result type: %T", raftResult) + } + } + return sub, err } @@ -98,7 +103,7 @@ func (a *app) InsertSubscription(ctx context.Context, s *ridmodels.Subscription) func (a *app) UpdateSubscription(ctx context.Context, s *ridmodels.Subscription) (*ridmodels.Subscription, error) { var sub *ridmodels.Subscription - err := a.store.Transact(ctx, func(ctx context.Context, repo repos.Repository) error { + raftResult, err := a.store.Transact(ctx, ridraftstore.UpdateSubscriptionTransaction, s, func(ctx context.Context, repo repos.Repository) error { old, err := repo.GetSubscription(ctx, s.ID) switch { case err != nil: @@ -128,7 +133,7 @@ func (a *app) UpdateSubscription(ctx context.Context, s *ridmodels.Subscription) return stacktrace.Propagate(err, "Failed to fetch subscription count, rejecting request") } - if count >= maxSubscriptionsPerArea { + if count >= ridmodels.MaxSubscriptionsPerArea { return stacktrace.Propagate( stacktrace.NewErrorWithCode(dsserr.Exhausted, "Too many existing subscriptions in this area already"), "%s had %d subscriptions in the area", s.Owner, count) @@ -139,13 +144,26 @@ func (a *app) UpdateSubscription(ctx context.Context, s *ridmodels.Subscription) } return nil }) + + if raftResult != nil { + var ok bool + sub, ok = raftResult.(*ridmodels.Subscription) + if !ok { + return nil, stacktrace.NewError("invalid result type: %T", raftResult) + } + } + return sub, err } // DeleteSubscription deletes the Subscription identified by "id" and owned by "owner". func (a *app) DeleteSubscription(ctx context.Context, id dssmodels.ID, owner dssmodels.Owner, version *dssmodels.Version) (*ridmodels.Subscription, error) { var ret *ridmodels.Subscription - err := a.store.Transact(ctx, func(ctx context.Context, repo repos.Repository) error { + raftResult, err := a.store.Transact(ctx, ridraftstore.DeleteSubscriptionTransaction, ridraftstore.DeleteSubscriptionPayload{ + ID: id, + Owner: owner, + Version: version, + }, func(ctx context.Context, repo repos.Repository) error { var err error old, err := repo.GetSubscription(ctx, id) switch { @@ -169,5 +187,14 @@ func (a *app) DeleteSubscription(ctx context.Context, id dssmodels.ID, owner dss } return nil }) + + if raftResult != nil { + var ok bool + ret, ok = raftResult.(*ridmodels.Subscription) + if !ok { + return nil, stacktrace.NewError("invalid result type: %T", raftResult) + } + } + return ret, err } diff --git a/pkg/rid/models/models.go b/pkg/rid/models/models.go index 2221c7e7d..2603f3758 100644 --- a/pkg/rid/models/models.go +++ b/pkg/rid/models/models.go @@ -1,8 +1,14 @@ package models import ( - "github.com/interuss/stacktrace" "net/url" + + "github.com/interuss/stacktrace" +) + +const ( + // Defined in requirement DSS0030. + MaxSubscriptionsPerArea = 10 ) // ValidateURL ensures https diff --git a/pkg/rid/store/memstore/doc.go b/pkg/rid/store/memstore/doc.go new file mode 100644 index 000000000..e42a19ad4 --- /dev/null +++ b/pkg/rid/store/memstore/doc.go @@ -0,0 +1,3 @@ +// Package rid.store.memstore provides a full implementation of store.Store[rid.repos.Repository] +// storing data in memory. It as meant to be used by raftstore. +package memstore diff --git a/pkg/rid/store/memstore/identification_service_area.go b/pkg/rid/store/memstore/identification_service_area.go new file mode 100644 index 000000000..29055426a --- /dev/null +++ b/pkg/rid/store/memstore/identification_service_area.go @@ -0,0 +1,152 @@ +package memstore + +import ( + "context" + "time" + + "github.com/golang/geo/s2" + dsserr "github.com/interuss/dss/pkg/errors" + dssmodels "github.com/interuss/dss/pkg/models" + ridmodels "github.com/interuss/dss/pkg/rid/models" + "github.com/interuss/dss/pkg/timestamp" + "github.com/interuss/stacktrace" +) + +func isaRecordFromModel(isa *ridmodels.IdentificationServiceArea, updatedAt time.Time) *isaRecord { + return &isaRecord{ + ID: isa.ID, + URL: isa.URL, + Owner: isa.Owner, + Cells: cloneCells(isa.Cells), + StartTime: cloneTime(isa.StartTime), + EndTime: cloneTime(isa.EndTime), + AltitudeHi: cloneFloat32(isa.AltitudeHi), + AltitudeLo: cloneFloat32(isa.AltitudeLo), + Writer: isa.Writer, + UpdatedAt: updatedAt, + } +} + +// toModel rebuilds the ISA model +func (rec *isaRecord) toModel() *ridmodels.IdentificationServiceArea { + return &ridmodels.IdentificationServiceArea{ + ID: rec.ID, + URL: rec.URL, + Owner: rec.Owner, + Cells: cloneCells(rec.Cells), + StartTime: cloneTime(rec.StartTime), + EndTime: cloneTime(rec.EndTime), + Version: dssmodels.VersionFromTime(rec.UpdatedAt), + AltitudeHi: cloneFloat32(rec.AltitudeHi), + AltitudeLo: cloneFloat32(rec.AltitudeLo), + Writer: rec.Writer, + } +} + +func (r *repo) GetISA(_ context.Context, id dssmodels.ID, _ bool) (*ridmodels.IdentificationServiceArea, error) { + rec, ok := r.state.ISAs[id] + if !ok { + return nil, nil + } + return rec.toModel(), nil +} + +func (r *repo) InsertISA(ctx context.Context, isa *ridmodels.IdentificationServiceArea) (*ridmodels.IdentificationServiceArea, error) { + if err := validateWriteData(isa.Cells, isa.StartTime, isa.EndTime); err != nil { + return nil, err + } + if _, ok := r.state.ISAs[isa.ID]; ok { + return nil, stacktrace.NewError("ISA with id %s already exists", isa.ID) + } + rec := isaRecordFromModel(isa, timestamp.NowFromContext(ctx)) + r.state.ISAs[isa.ID] = rec + return rec.toModel(), nil +} + +func (r *repo) UpdateISA(ctx context.Context, isa *ridmodels.IdentificationServiceArea) (*ridmodels.IdentificationServiceArea, error) { + if err := validateWriteData(isa.Cells, isa.StartTime, isa.EndTime); err != nil { + return nil, err + } + prev, ok := r.state.ISAs[isa.ID] + if !ok { + return nil, nil + } + if !dssmodels.VersionFromTime(prev.UpdatedAt).Matches(isa.Version) { + return nil, nil + } + rec := isaRecordFromModel(isa, timestamp.NowFromContext(ctx)) + rec.Owner = prev.Owner + r.state.ISAs[isa.ID] = rec + return rec.toModel(), nil +} + +func (r *repo) DeleteISA(_ context.Context, isa *ridmodels.IdentificationServiceArea) (*ridmodels.IdentificationServiceArea, error) { + rec, ok := r.state.ISAs[isa.ID] + if !ok { + return nil, nil + } + if !dssmodels.VersionFromTime(rec.UpdatedAt).Matches(isa.Version) { + return nil, nil + } + out := rec.toModel() + delete(r.state.ISAs, isa.ID) + return out, nil +} + +func (r *repo) SearchISAs(_ context.Context, cells s2.CellUnion, earliest *time.Time, latest *time.Time) ([]*ridmodels.IdentificationServiceArea, error) { + if len(cells) == 0 { + return nil, stacktrace.NewErrorWithCode(dsserr.BadRequest, "Missing cell IDs for query") + } + if earliest == nil { + return nil, stacktrace.NewError("Earliest start time is missing") + } + + want := cellSet(cells) + var out []*ridmodels.IdentificationServiceArea + for _, rec := range r.state.ISAs { + // ends_at >= earliest + if rec.EndTime == nil || rec.EndTime.Before(*earliest) { + continue + } + // COALESCE(starts_at <= latest, true) + if latest != nil && rec.StartTime != nil && rec.StartTime.After(*latest) { + continue + } + if !overlaps(rec.Cells, want) { + continue + } + out = append(out, rec.toModel()) + + if len(out) > dssmodels.MaxResultLimit { // This miminc sqlstore behaviour, but it's not very good. + break + } + } + return out, nil +} + +func (r *repo) ListExpiredISAs(_ context.Context, writer string, threshold time.Time) ([]*ridmodels.IdentificationServiceArea, error) { + var out []*ridmodels.IdentificationServiceArea + for _, rec := range r.state.ISAs { + // ends_at <= threshold + if rec.EndTime == nil || rec.EndTime.After(threshold) { + continue + } + if writer == "" { + if rec.Writer != "" { + continue + } + } else if rec.Writer != writer { + continue + } + out = append(out, rec.toModel()) + + if len(out) > dssmodels.MaxResultLimit { // This miminc sqlstore behaviour, but it's not very good. + break + } + } + return out, nil +} + +func (r *repo) CountISAs(_ context.Context) (int64, error) { + return int64(len(r.state.ISAs)), nil +} diff --git a/pkg/rid/store/memstore/identification_service_area_test.go b/pkg/rid/store/memstore/identification_service_area_test.go new file mode 100644 index 000000000..70699e415 --- /dev/null +++ b/pkg/rid/store/memstore/identification_service_area_test.go @@ -0,0 +1,321 @@ +package memstore + +import ( + "context" + "testing" + "time" + + "github.com/golang/geo/s2" + "github.com/google/uuid" + dssmodels "github.com/interuss/dss/pkg/models" + ridmodels "github.com/interuss/dss/pkg/rid/models" + "github.com/interuss/dss/pkg/rid/repos" + "github.com/jonboulle/clockwork" + "github.com/stretchr/testify/require" +) + +var ( + // Ensure the struct conforms to the interface + _ repos.ISA = &repo{} + overflow = uint64(17106221850767130624) // face 5 L13 overflows + serviceArea = &ridmodels.IdentificationServiceArea{ + ID: dssmodels.ID(uuid.New().String()), + Owner: dssmodels.Owner(uuid.New().String()), + URL: "https://no/place/like/home/for/flights", + StartTime: &startTime, + EndTime: &endTime, + Writer: writer, + Cells: s2.CellUnion{ + s2.CellID(uint64(overflow)), + s2.CellID(17106221850767130624), + }, + } +) + +func TestStoreSearchISAs(t *testing.T) { + ctx := context.Background() + cells := s2.CellUnion{ + s2.CellID(17106221850767130624), + s2.CellID(17106221885126868992), + s2.CellID(17106221919486607360), + s2.CellID(uint64(overflow)), + } + repo := setUpStore(t) + + isa := *serviceArea + isa.Cells = cells + saOut, err := repo.InsertISA(ctx, &isa) + require.NoError(t, err) + require.NotNil(t, saOut) + require.Equal(t, isa.ID, saOut.ID) + + for _, r := range []struct { + name string + cells s2.CellUnion + timestampMutator func(time.Time, time.Time) (*time.Time, *time.Time) + expectedLen int + }{ + { + name: "search for empty cell", + cells: s2.CellUnion{s2.CellID(17106221953846345728)}, + timestampMutator: func(start time.Time, end time.Time) (*time.Time, *time.Time) { + return &start, nil + }, + expectedLen: 0, + }, + { + name: "search for only one cell", + cells: s2.CellUnion{s2.CellID(17106221850767130624)}, + timestampMutator: func(start time.Time, end time.Time) (*time.Time, *time.Time) { + return &start, nil + }, + expectedLen: 1, + }, + { + name: "search for only one cell with high bit set", + cells: s2.CellUnion{s2.CellID(uint64(overflow))}, + timestampMutator: func(start time.Time, end time.Time) (*time.Time, *time.Time) { + return &start, nil + }, + expectedLen: 1, + }, + { + name: "search with nil ends_at", + cells: cells, + timestampMutator: func(start time.Time, end time.Time) (*time.Time, *time.Time) { + return &start, nil + }, + expectedLen: 1, + }, + { + name: "search with exact timestamps", + cells: cells, + timestampMutator: func(start time.Time, end time.Time) (*time.Time, *time.Time) { + return &start, &end + }, + expectedLen: 1, + }, + { + name: "search with non-matching time span", + cells: cells, + timestampMutator: func(start time.Time, end time.Time) (*time.Time, *time.Time) { + var ( + offset = time.Duration(100 * time.Second) + earliest = end.Add(offset) + latest = end.Add(offset * 2) + ) + + return &earliest, &latest + }, + expectedLen: 0, + }, + { + name: "search with expanded time span", + cells: cells, + timestampMutator: func(start time.Time, end time.Time) (*time.Time, *time.Time) { + var ( + offset = time.Duration(100 * time.Second) + earliest = start.Add(-offset) + latest = end.Add(offset) + ) + + return &earliest, &latest + }, + expectedLen: 1, + }, + } { + t.Run(r.name, func(t *testing.T) { + earliest, latest := r.timestampMutator(*saOut.StartTime, *saOut.EndTime) + + serviceAreas, err := repo.SearchISAs(ctx, r.cells, earliest, latest) + require.NoError(t, err) + require.Len(t, serviceAreas, r.expectedLen) + }) + } +} + +func TestBadVersion(t *testing.T) { + ctx := context.Background() + repo := setUpStore(t) + + saOut1, err := repo.InsertISA(ctx, serviceArea) + require.NoError(t, err) + require.NotNil(t, saOut1) + + // Rewriting service area should fail + saOut2, err := repo.UpdateISA(ctx, serviceArea) + require.NoError(t, err) + require.Nil(t, saOut2) + + // Rewriting, but with the correct version should work. + newEndTime := saOut1.EndTime.Add(time.Minute) + saOut1.EndTime = &newEndTime + saOut3, err := repo.UpdateISA(ctx, saOut1) + require.NoError(t, err) + require.NotNil(t, saOut3) +} + +func TestStoreExpiredISA(t *testing.T) { + ctx := context.Background() + repo := setUpStore(t) + + saOut, err := repo.InsertISA(ctx, serviceArea) + require.NoError(t, err) + require.NotNil(t, saOut) + + // The ISA's endTime is one hour from now. + fakeClock.Advance(59 * time.Minute) + + // We should still be able to find the ISA by searching and by ID. + now := fakeClock.Now() + serviceAreas, err := repo.SearchISAs(ctx, serviceArea.Cells, &now, nil) + require.NoError(t, err) + require.Len(t, serviceAreas, 1) + + ret, err := repo.GetISA(ctx, serviceArea.ID, false) + require.NoError(t, err) + require.NotNil(t, ret) + + // But now the ISA has expired. + fakeClock.Advance(2 * time.Minute) + now = fakeClock.Now() + + serviceAreas, err = repo.SearchISAs(ctx, serviceArea.Cells, &now, nil) + require.NoError(t, err) + require.Len(t, serviceAreas, 0) + + // A get should work even if it is expired. + ret, err = repo.GetISA(ctx, serviceArea.ID, false) + require.NoError(t, err) + require.NotNil(t, ret) +} + +func TestStoreDeleteISAs(t *testing.T) { + ctx := context.Background() + repo := setUpStore(t) + + // Insert the ISA. + copy := *serviceArea + isa, err := repo.InsertISA(ctx, ©) + require.NoError(t, err) + require.NotNil(t, isa) + + // Delete the ISA. + // Ensure a fresh Get, then delete still updates the sub indexes + isa, err = repo.GetISA(ctx, isa.ID, false) + require.NoError(t, err) + + serviceAreaOut, err := repo.DeleteISA(ctx, isa) + require.NoError(t, err) + require.Equal(t, isa, serviceAreaOut) +} + +func TestStoreISAWithNoGeoData(t *testing.T) { + ctx := context.Background() + repo := setUpStore(t) + + endTime := fakeClock.Now().Add(24 * time.Hour) + sub := &ridmodels.IdentificationServiceArea{ + ID: dssmodels.ID(uuid.New().String()), + Owner: dssmodels.Owner("original owner"), + EndTime: &endTime, + } + _, err := repo.InsertISA(ctx, sub) + require.Error(t, err) +} + +func TestListExpiredISAs(t *testing.T) { + ctx := context.Background() + repo := setUpStore(t) + + fakeClock := clockwork.NewFakeClockAt(time.Now()) + + // Insert ISA with endtime 1 day from now + isa1 := *serviceArea + startTime := fakeClock.Now() + isa1.StartTime = &startTime + endTime := fakeClock.Now().Add(24 * time.Hour) + isa1.EndTime = &endTime + saOut1, err := repo.InsertISA(ctx, &isa1) + require.NoError(t, err) + require.NotNil(t, saOut1) + + // Insert ISA with endtime to 30 minutes ago + isa2 := *serviceArea + startTime = fakeClock.Now().Add(-1 * time.Hour) + isa2.StartTime = &startTime + endTime = fakeClock.Now().Add(-30 * time.Minute) + isa2.EndTime = &endTime + isa2.ID = dssmodels.ID(uuid.New().String()) + saOut2, err := repo.InsertISA(ctx, &isa2) + require.NoError(t, err) + require.NotNil(t, saOut2) + + serviceAreas, err := repo.ListExpiredISAs(ctx, writer, fakeClock.Now().Add(-30*time.Minute)) + require.NoError(t, err) + require.Len(t, serviceAreas, 1) +} + +func TestListExpiredISAsWithEmptyWriter(t *testing.T) { + ctx := context.Background() + repo := setUpStore(t) + + fakeClock := clockwork.NewFakeClockAt(time.Now()) + + // Insert ISA with endtime 1 day from now + isa1 := *serviceArea + startTime := fakeClock.Now() + isa1.StartTime = &startTime + endTime := fakeClock.Now().Add(24 * time.Hour) + isa1.EndTime = &endTime + isa1.Writer = "" + saOut1, err := repo.InsertISA(ctx, &isa1) + require.NoError(t, err) + require.NotNil(t, saOut1) + + // Insert ISA with endtime to 30 minutes ago + isa2 := *serviceArea + startTime = fakeClock.Now().Add(-1 * time.Hour) + isa2.StartTime = &startTime + endTime = fakeClock.Now().Add(-30 * time.Minute) + isa2.EndTime = &endTime + isa2.ID = dssmodels.ID(uuid.New().String()) + isa2.Writer = "" + saOut2, err := repo.InsertISA(ctx, &isa2) + require.NoError(t, err) + require.NotNil(t, saOut2) + + serviceAreas, err := repo.ListExpiredISAs(ctx, "", fakeClock.Now().Add(-30*time.Minute)) + require.NoError(t, err) + require.Len(t, serviceAreas, 1) +} + +func TestStoreCountISAs(t *testing.T) { + ctx := context.Background() + repo := setUpStore(t) + + // Insert the ISA. + copy := *serviceArea + isa, err := repo.InsertISA(ctx, ©) + require.NoError(t, err) + require.NotNil(t, isa) + + //Cound should be one + count, err := repo.CountISAs(ctx) + require.NoError(t, err) + require.Equal(t, count, int64(1)) + + // Delete the ISA. + // Ensure a fresh Get, then delete still updates the sub indexes + isa, err = repo.GetISA(ctx, isa.ID, false) + require.NoError(t, err) + + serviceAreaOut, err := repo.DeleteISA(ctx, isa) + require.NoError(t, err) + require.Equal(t, isa, serviceAreaOut) + + //Cound should be zero + count, err = repo.CountISAs(ctx) + require.NoError(t, err) + require.Equal(t, count, int64(0)) +} diff --git a/pkg/rid/store/memstore/snapshot.go b/pkg/rid/store/memstore/snapshot.go new file mode 100644 index 000000000..fe2aa2c70 --- /dev/null +++ b/pkg/rid/store/memstore/snapshot.go @@ -0,0 +1,43 @@ +package memstore + +import ( + "bytes" + "encoding/gob" + + dssmodels "github.com/interuss/dss/pkg/models" + "github.com/interuss/stacktrace" +) + +const snapshotVersion = 1 + +type snapshotEnvelope struct { + Version int + State state +} + +func (r *repo) GetSnapshot() ([]byte, error) { + var buf bytes.Buffer + if err := gob.NewEncoder(&buf).Encode(snapshotEnvelope{Version: snapshotVersion, State: r.state}); err != nil { + return nil, stacktrace.Propagate(err, "Failed to encode memstore snapshot") + } + return buf.Bytes(), nil +} + +func (r *repo) RestoreFromSnapshot(data []byte) error { + var env snapshotEnvelope + if err := gob.NewDecoder(bytes.NewReader(data)).Decode(&env); err != nil { + return stacktrace.Propagate(err, "Failed to decode memstore snapshot") + } + if env.Version != snapshotVersion { + return stacktrace.NewError("Unsupported memstore snapshot version %d, expected %d", env.Version, snapshotVersion) + } + r.state = env.State + // gob decodes an empty map as nil; re-initialize to keep the repo writable. + if r.state.ISAs == nil { + r.state.ISAs = map[dssmodels.ID]*isaRecord{} + } + if r.state.Subscriptions == nil { + r.state.Subscriptions = map[dssmodels.ID]*subscriptionRecord{} + } + return nil +} diff --git a/pkg/rid/store/memstore/snapshot_test.go b/pkg/rid/store/memstore/snapshot_test.go new file mode 100644 index 000000000..a9764f812 --- /dev/null +++ b/pkg/rid/store/memstore/snapshot_test.go @@ -0,0 +1,82 @@ +package memstore + +import ( + "bytes" + "context" + "encoding/gob" + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "github.com/interuss/dss/pkg/models" + "github.com/stretchr/testify/require" +) + +func TestSnapshotRoundTrip(t *testing.T) { + ctx := context.Background() + src := setUpStore(t) + _, err := src.InsertISA(ctx, serviceArea) + require.NoError(t, err) + _, err = src.InsertSubscription(ctx, subscriptionsPool[0].input) + require.NoError(t, err) + + data, err := src.GetSnapshot() + require.NoError(t, err) + + dst := setUpStore(t) + require.NoError(t, dst.RestoreFromSnapshot(data)) + + wantISA, err := src.GetISA(ctx, serviceArea.ID, false) + require.NoError(t, err) + gotISA, err := dst.GetISA(ctx, serviceArea.ID, false) + require.NoError(t, err) + + if diff := cmp.Diff(wantISA, gotISA, cmpopts.EquateApproxTime(0), cmp.AllowUnexported(models.Version{})); diff != "" { + t.Errorf("IdentificationServiceArea mismatch (-want +got):\n%s", diff) + } + + wantSub, err := src.GetSubscription(ctx, subscriptionsPool[0].input.ID) + require.NoError(t, err) + gotSub, err := dst.GetSubscription(ctx, subscriptionsPool[0].input.ID) + require.NoError(t, err) + + if diff := cmp.Diff(wantSub, gotSub, cmpopts.EquateApproxTime(0), cmp.AllowUnexported(models.Version{})); diff != "" { + t.Errorf("Subscription mismatch (-want +got):\n%s", diff) + } +} + +func TestRestoreFromSnapshotReplacesState(t *testing.T) { + ctx := context.Background() + src := setUpStore(t) + _, err := src.InsertISA(ctx, serviceArea) + require.NoError(t, err) + data, err := src.GetSnapshot() + require.NoError(t, err) + + dst := setUpStore(t) + other := *serviceArea + other.ID = "00000000-0000-4000-8000-000000000002" + _, err = dst.InsertISA(ctx, &other) + require.NoError(t, err) + require.NoError(t, dst.RestoreFromSnapshot(data)) + + count, err := dst.CountISAs(ctx) + require.NoError(t, err) + require.Equal(t, int64(1), count) + got, err := dst.GetISA(ctx, serviceArea.ID, false) + require.NoError(t, err) + require.NotNil(t, got) + gone, err := dst.GetISA(ctx, other.ID, false) + require.NoError(t, err) + require.Nil(t, gone) +} + +func TestRestoreFromSnapshotInvalidData(t *testing.T) { + require.Error(t, setUpStore(t).RestoreFromSnapshot([]byte("random value that is definitely not valid"))) +} + +func TestRestoreFromSnapshotVersionMismatch(t *testing.T) { + var buf bytes.Buffer + require.NoError(t, gob.NewEncoder(&buf).Encode(snapshotEnvelope{Version: snapshotVersion + 1})) + require.Error(t, setUpStore(t).RestoreFromSnapshot(buf.Bytes())) +} diff --git a/pkg/rid/store/memstore/store.go b/pkg/rid/store/memstore/store.go new file mode 100644 index 000000000..cefd0bb4c --- /dev/null +++ b/pkg/rid/store/memstore/store.go @@ -0,0 +1,170 @@ +package memstore + +import ( + "context" + "time" + + "github.com/golang/geo/s2" + "github.com/interuss/dss/pkg/geo" + "github.com/interuss/dss/pkg/memstore" + dssmodels "github.com/interuss/dss/pkg/models" + "github.com/interuss/dss/pkg/rid/repos" + "github.com/interuss/stacktrace" + "go.uber.org/zap" +) + +// repo is a full implementation of rid.repos.Repository for memory-based storage. +type repo struct { + state state +} + +// state is the serializable in-memory state. +type state struct { + // ISAs holds the stored ISAs keyed by ID. + ISAs map[dssmodels.ID]*isaRecord + // Subscriptions holds the stored subscriptions keyed by ID. + Subscriptions map[dssmodels.ID]*subscriptionRecord +} + +// isaRecord is the gob-serializable representation of an ISA. It intentionally +// stores only primitive fields: the model's Version is never persisted, it is +// derived from UpdatedAt on read. +type isaRecord struct { + ID dssmodels.ID + URL string + Owner dssmodels.Owner + Cells s2.CellUnion + StartTime *time.Time + EndTime *time.Time + AltitudeHi *float32 + AltitudeLo *float32 + Writer string + UpdatedAt time.Time +} + +// subscriptionRecord is the gob-serializable representation of a Subscription. +type subscriptionRecord struct { + ID dssmodels.ID + URL string + NotificationIndex int + Owner dssmodels.Owner + Cells s2.CellUnion + StartTime *time.Time + EndTime *time.Time + AltitudeHi *float32 + AltitudeLo *float32 + Writer string + UpdatedAt time.Time +} + +func newRepo() *repo { + r := &repo{} + r.resetState() + return r +} + +func (r *repo) resetState() { + r.state = state{ + ISAs: map[dssmodels.ID]*isaRecord{}, + Subscriptions: map[dssmodels.ID]*subscriptionRecord{}, + } +} + +func Init(ctx context.Context, logger *zap.Logger) (*memstore.Store[repos.Repository], error) { + return memstore.Init(ctx, logger, "rid", newRepo()) +} + +func (r *repo) GetRepo() repos.Repository { return r } + +// validateWriteData validate constraints on an ISA +func validateWriteData(cells s2.CellUnion, start, end *time.Time) error { + if len(cells) == 0 { + return stacktrace.NewError("At least one cell must be provided") + } + for _, c := range cells { + if err := geo.ValidateCell(c); err != nil { + return stacktrace.Propagate(err, "Error validating cell") + } + } + if start != nil && end != nil && !start.Before(*end) { + return stacktrace.NewError("Start time must be strictly before end time") + } + return nil +} + +// cellSet builds a lookup set from a cell union. +func cellSet(cells s2.CellUnion) map[s2.CellID]struct{} { + set := make(map[s2.CellID]struct{}, len(cells)) + for _, c := range cells { + set[c] = struct{}{} + } + return set +} + +// overlaps reports whether any cell is present in set (equivalent to the SQL +// "cells && $x" array-overlap operator). +func overlaps(cells s2.CellUnion, set map[s2.CellID]struct{}) bool { + for _, c := range cells { + if _, ok := set[c]; ok { + return true + } + } + return false +} + +func cloneCells(cells s2.CellUnion) s2.CellUnion { + if cells == nil { + return nil + } + return append(s2.CellUnion(nil), cells...) +} + +func cloneTime(t *time.Time) *time.Time { + if t == nil { + return nil + } + v := *t + return &v +} + +func cloneFloat32(f *float32) *float32 { + if f == nil { + return nil + } + v := *f + return &v +} + +// clone returns a copy of s with independent maps and records. Cell slices and +// time pointers are shared, as they are never mutated in place. +func (s state) clone() state { + isas := make(map[dssmodels.ID]*isaRecord, len(s.ISAs)) + for id, rec := range s.ISAs { + cp := *rec + isas[id] = &cp + } + subs := make(map[dssmodels.ID]*subscriptionRecord, len(s.Subscriptions)) + for id, rec := range s.Subscriptions { + cp := *rec + subs[id] = &cp + } + return state{ISAs: isas, Subscriptions: subs} +} + +// Checkpoint returns a fast, restorable in-memory copy of the current state. +// Unlike GetSnapshot it does not serialize, so it is cheap but only valid +// in-process. +func (r *repo) Checkpoint() any { + return r.state.clone() +} + +// Restore replaces the current state with a checkpoint previously returned by +// Checkpoint. The checkpoint is copied, so it stays reusable. +func (r *repo) Restore(cp any) error { + s, ok := cp.(state) + if !ok { + return stacktrace.NewError("Invalid checkpoint type %T", cp) + } + r.state = s.clone() + return nil +} diff --git a/pkg/rid/store/memstore/store_test.go b/pkg/rid/store/memstore/store_test.go new file mode 100644 index 000000000..ef6638513 --- /dev/null +++ b/pkg/rid/store/memstore/store_test.go @@ -0,0 +1,97 @@ +package memstore + +import ( + "context" + "testing" + "time" + + "github.com/google/uuid" + dssmodels "github.com/interuss/dss/pkg/models" + ridmodels "github.com/interuss/dss/pkg/rid/models" + "github.com/jonboulle/clockwork" + "github.com/stretchr/testify/require" +) + +var ( + fakeClock = clockwork.NewFakeClock() + startTime = fakeClock.Now().UTC().Add(-time.Minute) + endTime = fakeClock.Now().UTC().Add(time.Hour) + writer = "writer" +) + +// setUpStore returns a fresh in-memory repo whose clock is the (reset) package +// fakeClock, so tests can advance time deterministically. +func setUpStore(t *testing.T) *repo { + t.Helper() + r := newRepo() + return r +} + +func TestDatabaseEnsuresBeginsBeforeExpires(t *testing.T) { + ctx := context.Background() + repo := setUpStore(t) + + var ( + begins = time.Now().UTC() + expires = begins.Add(-5 * time.Minute) + ) + _, err := repo.InsertSubscription(ctx, &ridmodels.Subscription{ + ID: dssmodels.ID(uuid.New().String()), + Owner: "me-myself-and-i", + URL: "https://no/place/like/home", + NotificationIndex: 42, + StartTime: &begins, + EndTime: &expires, + }) + require.Error(t, err) +} + +func TestCheckpointRestoreISA(t *testing.T) { + ctx := context.Background() + repo := setUpStore(t) + + _, err := repo.InsertISA(ctx, serviceArea) + require.NoError(t, err) + + cp := repo.Checkpoint() + + // Mutate after the checkpoint. + isa, err := repo.GetISA(ctx, serviceArea.ID, false) + require.NoError(t, err) + _, err = repo.DeleteISA(ctx, isa) + require.NoError(t, err) + gone, err := repo.GetISA(ctx, serviceArea.ID, false) + require.NoError(t, err) + require.Nil(t, gone) + + // Restore brings it back. + require.NoError(t, repo.Restore(cp)) + back, err := repo.GetISA(ctx, serviceArea.ID, false) + require.NoError(t, err) + require.NotNil(t, back) +} + +func TestCheckpointIsolatesNotificationIndex(t *testing.T) { + ctx := context.Background() + repo := setUpStore(t) + + sub, err := repo.InsertSubscription(ctx, subscriptionsPool[0].input) + require.NoError(t, err) + + cp := repo.Checkpoint() + + // In-place notification-index bump must not leak into the checkpoint. + updated, err := repo.UpdateNotificationIdxsInCells(ctx, sub.Cells) + require.NoError(t, err) + require.Len(t, updated, 1) + require.Equal(t, sub.NotificationIndex+1, updated[0].NotificationIndex) + + require.NoError(t, repo.Restore(cp)) + restored, err := repo.GetSubscription(ctx, sub.ID) + require.NoError(t, err) + require.Equal(t, sub.NotificationIndex, restored.NotificationIndex) +} + +func TestRestoreInvalidType(t *testing.T) { + require.Error(t, setUpStore(t).Restore("not a checkpoint")) +} diff --git a/pkg/rid/store/memstore/subscriptions.go b/pkg/rid/store/memstore/subscriptions.go new file mode 100644 index 000000000..5923c3927 --- /dev/null +++ b/pkg/rid/store/memstore/subscriptions.go @@ -0,0 +1,214 @@ +package memstore + +import ( + "context" + "time" + + "github.com/golang/geo/s2" + dsserr "github.com/interuss/dss/pkg/errors" + dssmodels "github.com/interuss/dss/pkg/models" + ridmodels "github.com/interuss/dss/pkg/rid/models" + "github.com/interuss/dss/pkg/timestamp" + "github.com/interuss/stacktrace" +) + +func subRecordFromModel(s *ridmodels.Subscription, updatedAt time.Time) *subscriptionRecord { + return &subscriptionRecord{ + ID: s.ID, + URL: s.URL, + NotificationIndex: s.NotificationIndex, + Owner: s.Owner, + Cells: cloneCells(s.Cells), + StartTime: cloneTime(s.StartTime), + EndTime: cloneTime(s.EndTime), + AltitudeHi: cloneFloat32(s.AltitudeHi), + AltitudeLo: cloneFloat32(s.AltitudeLo), + Writer: s.Writer, + UpdatedAt: updatedAt, + } +} + +func (rec *subscriptionRecord) toModel() *ridmodels.Subscription { + return &ridmodels.Subscription{ + ID: rec.ID, + URL: rec.URL, + NotificationIndex: rec.NotificationIndex, + Owner: rec.Owner, + Cells: cloneCells(rec.Cells), + StartTime: cloneTime(rec.StartTime), + EndTime: cloneTime(rec.EndTime), + Version: dssmodels.VersionFromTime(rec.UpdatedAt), + AltitudeHi: cloneFloat32(rec.AltitudeHi), + AltitudeLo: cloneFloat32(rec.AltitudeLo), + Writer: rec.Writer, + } +} + +func (r *repo) GetSubscription(_ context.Context, id dssmodels.ID) (*ridmodels.Subscription, error) { + rec, ok := r.state.Subscriptions[id] + if !ok { + return nil, nil + } + return rec.toModel(), nil +} + +func (r *repo) InsertSubscription(ctx context.Context, s *ridmodels.Subscription) (*ridmodels.Subscription, error) { + if err := validateWriteData(s.Cells, s.StartTime, s.EndTime); err != nil { + return nil, err + } + if _, ok := r.state.Subscriptions[s.ID]; ok { + return nil, stacktrace.NewError("Subscription with id %s already exists", s.ID) + } + rec := subRecordFromModel(s, timestamp.NowFromContext(ctx)) + r.state.Subscriptions[s.ID] = rec + return rec.toModel(), nil +} + +func (r *repo) UpdateSubscription(ctx context.Context, s *ridmodels.Subscription) (*ridmodels.Subscription, error) { + if err := validateWriteData(s.Cells, s.StartTime, s.EndTime); err != nil { + return nil, err + } + prev, ok := r.state.Subscriptions[s.ID] + if !ok { + return nil, nil + } + if !dssmodels.VersionFromTime(prev.UpdatedAt).Matches(s.Version) { + return nil, nil + } + rec := subRecordFromModel(s, timestamp.NowFromContext(ctx)) + rec.Owner = prev.Owner + r.state.Subscriptions[s.ID] = rec + return rec.toModel(), nil +} + +func (r *repo) DeleteSubscription(_ context.Context, s *ridmodels.Subscription) (*ridmodels.Subscription, error) { + rec, ok := r.state.Subscriptions[s.ID] + if !ok { + return nil, nil + } + if !dssmodels.VersionFromTime(rec.UpdatedAt).Matches(s.Version) { + return nil, nil + } + out := rec.toModel() + delete(r.state.Subscriptions, s.ID) + return out, nil +} + +func (r *repo) SearchSubscriptions(ctx context.Context, cells s2.CellUnion) ([]*ridmodels.Subscription, error) { + if len(cells) == 0 { + return nil, stacktrace.NewErrorWithCode(dsserr.BadRequest, "no location provided") + } + now := timestamp.NowFromContext(ctx) + want := cellSet(cells) + var out []*ridmodels.Subscription + for _, rec := range r.state.Subscriptions { + if rec.EndTime == nil || rec.EndTime.Before(now) { + continue + } + if !overlaps(rec.Cells, want) { + continue + } + out = append(out, rec.toModel()) + + if len(out) > dssmodels.MaxResultLimit { // This miminc sqlstore behaviour, but it's not very good. + break + } + } + return out, nil +} + +func (r *repo) SearchSubscriptionsByOwner(ctx context.Context, cells s2.CellUnion, owner dssmodels.Owner) ([]*ridmodels.Subscription, error) { + if len(cells) == 0 { + return nil, stacktrace.NewErrorWithCode(dsserr.BadRequest, "no location provided") + } + now := timestamp.NowFromContext(ctx) + want := cellSet(cells) + var out []*ridmodels.Subscription + for _, rec := range r.state.Subscriptions { + if rec.Owner != owner { + continue + } + if rec.EndTime == nil || rec.EndTime.Before(now) { + continue + } + if !overlaps(rec.Cells, want) { + continue + } + out = append(out, rec.toModel()) + + if len(out) > dssmodels.MaxResultLimit { // This miminc sqlstore behaviour, but it's not very good. + break + } + } + return out, nil +} + +// UpdateNotificationIdxsInCells increments the notification index for each +// subscription in the given cells. +func (r *repo) UpdateNotificationIdxsInCells(ctx context.Context, cells s2.CellUnion) ([]*ridmodels.Subscription, error) { + now := timestamp.NowFromContext(ctx) + want := cellSet(cells) + var out []*ridmodels.Subscription + for _, rec := range r.state.Subscriptions { + if rec.EndTime == nil || rec.EndTime.Before(now) { + continue + } + if !overlaps(rec.Cells, want) { + continue + } + rec.NotificationIndex++ + out = append(out, rec.toModel()) + } + return out, nil +} + +func (r *repo) MaxSubscriptionCountInCellsByOwner(ctx context.Context, cells s2.CellUnion, owner dssmodels.Owner) (int, error) { + now := timestamp.NowFromContext(ctx) + want := cellSet(cells) + counts := make(map[s2.CellID]int, len(cells)) + for _, rec := range r.state.Subscriptions { + if rec.Owner != owner { + continue + } + if rec.EndTime == nil || rec.EndTime.Before(now) { + continue + } + for _, c := range rec.Cells { + if _, ok := want[c]; ok { + counts[c]++ + } + } + } + best := 0 + for _, n := range counts { + if n > best { + best = n + } + } + return best, nil +} + +func (r *repo) ListExpiredSubscriptions(_ context.Context, writer string, threshold time.Time) ([]*ridmodels.Subscription, error) { + var out []*ridmodels.Subscription + for _, rec := range r.state.Subscriptions { + // ends_at <= threshold + if rec.EndTime == nil || rec.EndTime.After(threshold) { + continue + } + if writer == "" { + if rec.Writer != "" { + continue + } + } else if rec.Writer != writer { + continue + } + out = append(out, rec.toModel()) + + // TODO: This miminc sqlstore inconsistency of not limiting results there, comparted to ISAs. Should it be normalized? + } + return out, nil +} + +func (r *repo) CountSubscriptions(_ context.Context) (int64, error) { + return int64(len(r.state.Subscriptions)), nil +} diff --git a/pkg/rid/store/memstore/subscriptions_test.go b/pkg/rid/store/memstore/subscriptions_test.go new file mode 100644 index 000000000..3b90cf0e0 --- /dev/null +++ b/pkg/rid/store/memstore/subscriptions_test.go @@ -0,0 +1,363 @@ +package memstore + +import ( + "context" + "testing" + "time" + + "github.com/golang/geo/s2" + "github.com/google/uuid" + dssmodels "github.com/interuss/dss/pkg/models" + ridmodels "github.com/interuss/dss/pkg/rid/models" + "github.com/interuss/dss/pkg/rid/repos" + "github.com/interuss/dss/pkg/timestamp" + "github.com/jonboulle/clockwork" + "github.com/stretchr/testify/require" +) + +var ( + // Ensure the struct conforms to the interface + _ repos.Subscription = &repo{} + subscriptionsPool = []struct { + name string + input *ridmodels.Subscription + }{ + { + name: "a subscription with startTime and endTime", + input: &ridmodels.Subscription{ + ID: dssmodels.ID(uuid.New().String()), + Owner: "myself", + URL: "https://no/place/like/home", + StartTime: &startTime, + EndTime: &endTime, + NotificationIndex: 42, + Writer: writer, + Cells: s2.CellUnion{ + s2.CellID(uint64(overflow)), + 12494535935418957824, + }, + }, + }, + { + name: "a subscription without startTime and with endTime", + input: &ridmodels.Subscription{ + ID: dssmodels.ID(uuid.New().String()), + Owner: "myself", + URL: "https://no/place/like/home", + EndTime: &endTime, + NotificationIndex: 42, + Cells: s2.CellUnion{ + 12494535935418957824, + }, + }, + }, + { + name: "a subscription without startTime and with endTime", + input: &ridmodels.Subscription{ + ID: dssmodels.ID(uuid.New().String()), + Owner: "me", + URL: "https://no/place/like/home", + StartTime: &startTime, + EndTime: &endTime, + NotificationIndex: 42, + Cells: s2.CellUnion{ + 12494535935418957824, + }, + }, + }, + } +) + +func TestStoreGetSubscription(t *testing.T) { + ctx := context.Background() + repo := setUpStore(t) + + for _, r := range subscriptionsPool { + t.Run(r.name, func(t *testing.T) { + sub1, err := repo.InsertSubscription(ctx, r.input) + require.NoError(t, err) + require.NotNil(t, sub1) + + sub2, err := repo.GetSubscription(ctx, sub1.ID) + require.NoError(t, err) + require.NotNil(t, sub2) + + require.Equal(t, *sub1, *sub2) + }) + } +} + +func TestStoreInsertSubscription(t *testing.T) { + ctx := context.Background() + repo := setUpStore(t) + + for _, r := range subscriptionsPool { + t.Run(r.name, func(t *testing.T) { + sub1, err := repo.InsertSubscription(ctx, r.input) + require.NoError(t, err) + require.NotNil(t, sub1) + + // Test changes without the version differing. + r2 := *sub1 + r2.URL = "new url" + sub2, err := repo.UpdateSubscription(ctx, &r2) + require.NoError(t, err) + require.NotNil(t, sub2) + require.Equal(t, "new url", sub2.URL) + + // Test it doesn't work when Version is nil. + r3 := *sub2 + r3.URL = "new url 2" + r3.Version = nil + sub3, err := repo.UpdateSubscription(ctx, &r3) + require.NoError(t, err) + require.Nil(t, sub3) + + // Bad version doesn't work. + r4 := *sub2 + r4.URL = "new url 3" + r4.Version = dssmodels.VersionFromTime(time.Now()) + sub4, err := repo.UpdateSubscription(ctx, &r4) + require.NoError(t, err) + require.Nil(t, sub4) + + sub5, err := repo.GetSubscription(ctx, sub1.ID) + require.NoError(t, err) + require.NotNil(t, sub5) + + require.Equal(t, *sub2, *sub5) + }) + } +} + +func TestStoreDeleteSubscription(t *testing.T) { + ctx := context.Background() + repo := setUpStore(t) + + for _, r := range subscriptionsPool { + t.Run(r.name, func(t *testing.T) { + sub1, err := repo.InsertSubscription(ctx, r.input) + require.NoError(t, err) + require.NotNil(t, sub1) + + // Ensure mismatched versions returns nothing + sub1BadVersion := *sub1 + sub1BadVersion.Version, err = dssmodels.VersionFromString("a3cg3tcuhk00") + require.NoError(t, err) + sub2, err := repo.DeleteSubscription(ctx, &sub1BadVersion) + require.NoError(t, err) + require.Nil(t, sub2) + + sub4, err := repo.DeleteSubscription(ctx, sub1) + require.NoError(t, err) + require.NotNil(t, sub4) + + require.Equal(t, *sub1, *sub4) + }) + } +} + +func TestStoreSearchSubscription(t *testing.T) { + ctx := context.Background() + repo := setUpStore(t) + + var ( + // pick an L13 value that overflows. + overflow = uint64(17106221850767130624) + + cells = s2.CellUnion{ + s2.CellID(12494535935418957824), + s2.CellID(12494535866699481088), + s2.CellID(12494535901059219456), + s2.CellID(12494535866699481088), + s2.CellID(overflow), + } + owners = []dssmodels.Owner{ + "me", + "my", + "self", + } + ) + + for i, r := range subscriptionsPool { + subscription := *r.input + subscription.Owner = owners[i] + subscription.Cells = cells[:i+1] + sub1, err := repo.InsertSubscription(ctx, &subscription) + require.NoError(t, err) + require.NotNil(t, sub1) + } + // Test normal search + found, err := repo.SearchSubscriptions(ctx, cells) + require.NoError(t, err) + require.Len(t, found, 3) + for _, owner := range owners { + found, err := repo.SearchSubscriptionsByOwner(ctx, cells, owner) + require.NoError(t, err) + require.NotNil(t, found) + // We insert one subscription per owner. Hence, no matter how many cells are touched by the subscription, + // the result should always be 1. + require.Len(t, found, 1) + } +} + +func TestStoreExpiredSubscription(t *testing.T) { + ctx := context.Background() + repo := setUpStore(t) + + endTime := fakeClock.Now().Add(24 * time.Hour) + sub := &ridmodels.Subscription{ + ID: dssmodels.ID(uuid.New().String()), + Owner: dssmodels.Owner("original owner"), + Cells: s2.CellUnion{s2.CellID(12494535866699481088)}, + EndTime: &endTime, + } + _, err := repo.InsertSubscription(ctx, sub) + require.NoError(t, err) + + // The subscription's endTime is 24 hours from now. + fakeClock.Advance(23 * time.Hour) + ctx = timestamp.WithTimestamp(ctx, fakeClock.Now()) + + // We should still be able to find the subscription by searching and by ID. + subs, err := repo.SearchSubscriptionsByOwner(ctx, sub.Cells, "original owner") + require.NoError(t, err) + require.Len(t, subs, 1) + + ret, err := repo.GetSubscription(ctx, sub.ID) + require.NoError(t, err) + require.NotNil(t, &ret) + + // But now the subscription has expired. + fakeClock.Advance(2 * time.Hour) + ctx = timestamp.WithTimestamp(ctx, fakeClock.Now()) + + subs, err = repo.SearchSubscriptionsByOwner(ctx, sub.Cells, "original owner") + require.NoError(t, err) + require.Len(t, subs, 0) + + ret, err = repo.GetSubscription(ctx, sub.ID) + require.NotNil(t, ret) + require.NoError(t, err) +} + +func TestStoreSubscriptionWithNoGeoData(t *testing.T) { + ctx := context.Background() + repo := setUpStore(t) + + endTime := fakeClock.Now().Add(24 * time.Hour) + sub := &ridmodels.Subscription{ + ID: dssmodels.ID(uuid.New().String()), + Owner: dssmodels.Owner("original owner"), + EndTime: &endTime, + } + _, err := repo.InsertSubscription(ctx, sub) + require.Error(t, err) +} + +func TestMaxSubscriptionCountInCellsByOwner(t *testing.T) { + ctx := context.Background() + repo := setUpStore(t) + + for _, s := range subscriptionsPool { + _, err := repo.InsertSubscription(ctx, s.input) + require.NoError(t, err) + } + + count, err := repo.MaxSubscriptionCountInCellsByOwner(ctx, s2.CellUnion{12494535935418957824}, "myself") + require.NoError(t, err) + require.Equal(t, 2, count) +} + +func TestListExpiredSubscriptions(t *testing.T) { + ctx := context.Background() + repo := setUpStore(t) + + fakeClock := clockwork.NewFakeClockAt(time.Now()) + + // Insert Subscription with endtime 1 day from now + subscripiton1 := *subscriptionsPool[0].input + startTime := fakeClock.Now() + subscripiton1.StartTime = &startTime + endTime := fakeClock.Now().Add(24 * time.Hour) + subscripiton1.EndTime = &endTime + subOut1, err := repo.InsertSubscription(ctx, &subscripiton1) + require.NoError(t, err) + require.NotNil(t, subOut1) + + // Insert Subscription with endtime to 30 minutes ago + subscripiton2 := *subscriptionsPool[0].input + startTime = fakeClock.Now().Add(-1 * time.Hour) + subscripiton2.StartTime = &startTime + endTime = fakeClock.Now().Add(-30 * time.Minute) + subscripiton2.EndTime = &endTime + subscripiton2.ID = dssmodels.ID(uuid.New().String()) + subOut2, err := repo.InsertSubscription(ctx, &subscripiton2) + require.NoError(t, err) + require.NotNil(t, subOut2) + + subscriptions, err := repo.ListExpiredSubscriptions(ctx, writer, fakeClock.Now().Add(-30*time.Minute)) + require.NoError(t, err) + require.Len(t, subscriptions, 1) +} + +func TestListExpiredSubscriptionsWithEmptyWriter(t *testing.T) { + ctx := context.Background() + repo := setUpStore(t) + + fakeClock := clockwork.NewFakeClockAt(time.Now()) + + // Insert Subscription with endtime 1 day from now + subscripiton1 := *subscriptionsPool[0].input + startTime := fakeClock.Now() + subscripiton1.StartTime = &startTime + endTime := fakeClock.Now().Add(24 * time.Hour) + subscripiton1.EndTime = &endTime + subscripiton1.Writer = "" + subOut1, err := repo.InsertSubscription(ctx, &subscripiton1) + require.NoError(t, err) + require.NotNil(t, subOut1) + + // Insert Subscription with endtime to 30 minutes ago + subscripiton2 := *subscriptionsPool[0].input + startTime = fakeClock.Now().Add(-1 * time.Hour) + subscripiton2.StartTime = &startTime + endTime = fakeClock.Now().Add(-30 * time.Minute) + subscripiton2.EndTime = &endTime + subscripiton2.ID = dssmodels.ID(uuid.New().String()) + subscripiton2.Writer = "" + subOut2, err := repo.InsertSubscription(ctx, &subscripiton2) + require.NoError(t, err) + require.NotNil(t, subOut2) + + subscriptions, err := repo.ListExpiredSubscriptions(ctx, "", fakeClock.Now().Add(-30*time.Minute)) + require.NoError(t, err) + require.Len(t, subscriptions, 1) +} + +func TestStoreCountSubscription(t *testing.T) { + ctx := context.Background() + repo := setUpStore(t) + + for _, r := range subscriptionsPool { + t.Run(r.name, func(t *testing.T) { + sub1, err := repo.InsertSubscription(ctx, r.input) + require.NoError(t, err) + require.NotNil(t, sub1) + + //Cound should be one + count, err := repo.CountSubscriptions(ctx) + require.NoError(t, err) + require.Equal(t, count, int64(1)) + + sub4, err := repo.DeleteSubscription(ctx, sub1) + require.NoError(t, err) + require.NotNil(t, sub4) + + //Cound should be zero + count, err = repo.CountSubscriptions(ctx) + require.NoError(t, err) + require.Equal(t, count, int64(0)) + }) + } +} diff --git a/pkg/rid/store/raftstore/identification_service_area.go b/pkg/rid/store/raftstore/identification_service_area.go index b9f7222a5..f68c668b5 100644 --- a/pkg/rid/store/raftstore/identification_service_area.go +++ b/pkg/rid/store/raftstore/identification_service_area.go @@ -5,36 +5,149 @@ import ( "time" "github.com/golang/geo/s2" - dsserr "github.com/interuss/dss/pkg/errors" dssmodels "github.com/interuss/dss/pkg/models" ridmodels "github.com/interuss/dss/pkg/rid/models" "github.com/interuss/stacktrace" ) -func (r *repo) GetISA(_ context.Context, id dssmodels.ID, forUpdate bool) (*ridmodels.IdentificationServiceArea, error) { - return nil, stacktrace.NewErrorWithCode(dsserr.NotImplemented, "GetISA not implemented for raftstore") +type getISAPayload struct { + ID dssmodels.ID + ForUpdate bool } -func (r *repo) DeleteISA(_ context.Context, isa *ridmodels.IdentificationServiceArea) (*ridmodels.IdentificationServiceArea, error) { - return nil, stacktrace.NewErrorWithCode(dsserr.NotImplemented, "DeleteISA not implemented for raftstore") +func (r *repo) GetISA(ctx context.Context, id dssmodels.ID, forUpdate bool) (*ridmodels.IdentificationServiceArea, error) { + result, err := r.consensus.ProposeValue(ctx, getISA, &getISAPayload{ID: id, ForUpdate: forUpdate}, true) + if err != nil { + return nil, err + } + + if result == nil { + return nil, nil + } + + isa, ok := result.(*ridmodels.IdentificationServiceArea) + if !ok { + return nil, stacktrace.NewError("invalid result type: %T", result) + } + + return isa, nil } -func (r *repo) InsertISA(_ context.Context, isa *ridmodels.IdentificationServiceArea) (*ridmodels.IdentificationServiceArea, error) { - return nil, stacktrace.NewErrorWithCode(dsserr.NotImplemented, "InsertISA not implemented for raftstore") +func (r *repo) DeleteISA(ctx context.Context, isa *ridmodels.IdentificationServiceArea) (*ridmodels.IdentificationServiceArea, error) { + result, err := r.consensus.ProposeValue(ctx, deleteISA, isa, false) + if err != nil { + return nil, err + } + + if result == nil { + return nil, nil + } + + isa, ok := result.(*ridmodels.IdentificationServiceArea) + if !ok { + return nil, stacktrace.NewError("invalid result type: %T", result) + } + + return isa, nil } -func (r *repo) UpdateISA(_ context.Context, isa *ridmodels.IdentificationServiceArea) (*ridmodels.IdentificationServiceArea, error) { - return nil, stacktrace.NewErrorWithCode(dsserr.NotImplemented, "UpdateISA not implemented for raftstore") +func (r *repo) InsertISA(ctx context.Context, isa *ridmodels.IdentificationServiceArea) (*ridmodels.IdentificationServiceArea, error) { + result, err := r.consensus.ProposeValue(ctx, insertISA, isa, false) + if err != nil { + return nil, err + } + + if result == nil { + return nil, nil + } + + isa, ok := result.(*ridmodels.IdentificationServiceArea) + if !ok { + return nil, stacktrace.NewError("invalid result type: %T", result) + } + + return isa, nil } -func (r *repo) SearchISAs(_ context.Context, cells s2.CellUnion, earliest *time.Time, latest *time.Time) ([]*ridmodels.IdentificationServiceArea, error) { - return nil, stacktrace.NewErrorWithCode(dsserr.NotImplemented, "SearchISAs not implemented for raftstore") +func (r *repo) UpdateISA(ctx context.Context, isa *ridmodels.IdentificationServiceArea) (*ridmodels.IdentificationServiceArea, error) { + result, err := r.consensus.ProposeValue(ctx, updateISA, isa, false) + if err != nil { + return nil, err + } + + if result == nil { + return nil, nil + } + + isa, ok := result.(*ridmodels.IdentificationServiceArea) + if !ok { + return nil, stacktrace.NewError("invalid result type: %T", result) + } + + return isa, nil } -func (r *repo) ListExpiredISAs(_ context.Context, writer string, threshold time.Time) ([]*ridmodels.IdentificationServiceArea, error) { - return nil, stacktrace.NewErrorWithCode(dsserr.NotImplemented, "ListExpiredISAs not implemented for raftstore") +type searchISAsPayload struct { + Cells s2.CellUnion + Earliest *time.Time + Latest *time.Time } -func (r *repo) CountISAs(_ context.Context) (int64, error) { - return 0, stacktrace.NewErrorWithCode(dsserr.NotImplemented, "CountISAs not implemented for raftstore") +func (r *repo) SearchISAs(ctx context.Context, cells s2.CellUnion, earliest *time.Time, latest *time.Time) ([]*ridmodels.IdentificationServiceArea, error) { + result, err := r.consensus.ProposeValue(ctx, searchISAs, &searchISAsPayload{Cells: cells, Earliest: earliest, Latest: latest}, true) + if err != nil { + return nil, err + } + + if result == nil { + return nil, nil + } + + isa, ok := result.([]*ridmodels.IdentificationServiceArea) + if !ok { + return nil, stacktrace.NewError("invalid result type: %T", result) + } + + return isa, nil +} + +type listExpiredISAsPayload struct { + Writer string + Threshold time.Time +} + +func (r *repo) ListExpiredISAs(ctx context.Context, writer string, threshold time.Time) ([]*ridmodels.IdentificationServiceArea, error) { + result, err := r.consensus.ProposeValue(ctx, listExpiredISAs, &listExpiredISAsPayload{Writer: writer, Threshold: threshold}, true) + if err != nil { + return nil, err + } + + if result == nil { + return nil, nil + } + + isa, ok := result.([]*ridmodels.IdentificationServiceArea) + if !ok { + return nil, stacktrace.NewError("invalid result type: %T", result) + } + + return isa, nil +} + +func (r *repo) CountISAs(ctx context.Context) (int64, error) { + result, err := r.consensus.ProposeValue(ctx, countISAs, nil, true) + if err != nil { + return 0, err + } + + if result == nil { + return 0, nil + } + + count, ok := result.(int64) + if !ok { + return 0, stacktrace.NewError("invalid result type: %T", result) + } + + return count, nil } diff --git a/pkg/rid/store/raftstore/identitfication_service_area_appliers.go b/pkg/rid/store/raftstore/identitfication_service_area_appliers.go new file mode 100644 index 000000000..5049406d6 --- /dev/null +++ b/pkg/rid/store/raftstore/identitfication_service_area_appliers.go @@ -0,0 +1,162 @@ +package raftstore + +import ( + "context" + "encoding/json" + + "github.com/golang/geo/s2" + + dsserr "github.com/interuss/dss/pkg/errors" + "github.com/interuss/dss/pkg/geo" + dssmodels "github.com/interuss/dss/pkg/models" + "github.com/interuss/dss/pkg/raftstore/consensus" + ridmodels "github.com/interuss/dss/pkg/rid/models" + "github.com/interuss/dss/pkg/rid/repos" + "github.com/interuss/stacktrace" +) + +type ISATransactionResult struct { + Ret *ridmodels.IdentificationServiceArea + Subs []*ridmodels.Subscription +} + +type DeleteISATransactionPayload struct { + ID dssmodels.ID + Owner dssmodels.Owner + Version *dssmodels.Version +} + +func (r *repo) deleteISATransactionApplier(ctx context.Context, proposal consensus.Proposal, mem repos.Repository) (*ISATransactionResult, error) { + var payload DeleteISATransactionPayload + err := json.Unmarshal(proposal.Value, &payload) + if err != nil { + return nil, stacktrace.Propagate(err, "failed to unmarshal payload") + } + + old, err := mem.GetISA(ctx, payload.ID, true) + switch { + case err != nil: + return nil, stacktrace.Propagate(err, "Error getting ISA") + case old == nil: + return nil, stacktrace.NewErrorWithCode(dsserr.NotFound, "ISA %s not found", payload.ID.String()) + case !payload.Version.Matches(old.Version): + return nil, stacktrace.NewErrorWithCode(dsserr.VersionMismatch, + "ISA currently at version %s but client specified %s", old.Version, payload.Version) + case old.Owner != payload.Owner: + return nil, stacktrace.NewErrorWithCode(dsserr.PermissionDenied, + "ISA owned by %s, but %s attempted to delete", old.Owner, payload.Owner) + } + + checkpoint := r.memStore.Checkpoint() + ret, err := mem.DeleteISA(ctx, old) + if ret == nil || err != nil { + restoreErr := r.memStore.Restore(checkpoint) + if restoreErr != nil { + return nil, stacktrace.Propagate(restoreErr, "Error restoring store") + } + + return nil, stacktrace.Propagate(err, "Error deleting ISA") + } + + subs, err := mem.UpdateNotificationIdxsInCells(ctx, old.Cells) + if err != nil { + restoreErr := r.memStore.Restore(checkpoint) + if restoreErr != nil { + return nil, stacktrace.Propagate(restoreErr, "Error restoring store") + } + + return nil, stacktrace.Propagate(err, "Error updating notification indices") + } + + return &ISATransactionResult{Ret: ret, Subs: subs}, nil +} + +func (r *repo) insertISATransactionApplier(ctx context.Context, proposal consensus.Proposal, mem repos.Repository) (*ISATransactionResult, error) { + var isa *ridmodels.IdentificationServiceArea + err := json.Unmarshal(proposal.Value, &isa) + if err != nil { + return nil, stacktrace.Propagate(err, "failed to unmarshal payload") + } + + old, err := mem.GetISA(ctx, isa.ID, false) + if err != nil { + return nil, stacktrace.Propagate(err, "Error getting ISA") + } + if old != nil { + return nil, stacktrace.NewErrorWithCode(dsserr.AlreadyExists, "ISA %s already exists", isa.ID) + } + + checkpoint := r.memStore.Checkpoint() + subs, err := mem.UpdateNotificationIdxsInCells(ctx, isa.Cells) + if err != nil { + restoreErr := r.memStore.Restore(checkpoint) + if restoreErr != nil { + return nil, stacktrace.Propagate(restoreErr, "Error restoring store") + } + + return nil, stacktrace.Propagate(err, "Error updating notification indices") + } + + ret, err := mem.InsertISA(ctx, isa) + if err != nil { + restoreErr := r.memStore.Restore(checkpoint) + if restoreErr != nil { + return nil, stacktrace.Propagate(restoreErr, "Error restoring store") + } + + return nil, stacktrace.Propagate(err, "Error inserting ISA") + } + + return &ISATransactionResult{Ret: ret, Subs: subs}, nil +} + +func (r *repo) updateISATransactionApplier(ctx context.Context, proposal consensus.Proposal, mem repos.Repository) (*ISATransactionResult, error) { + var isa *ridmodels.IdentificationServiceArea + err := json.Unmarshal(proposal.Value, &isa) + if err != nil { + return nil, stacktrace.Propagate(err, "failed to unmarshal payload") + } + + old, err := mem.GetISA(ctx, isa.ID, true) + switch { + case err != nil: + return nil, stacktrace.Propagate(err, "Error getting ISA") + case old == nil: + return nil, stacktrace.NewErrorWithCode(dsserr.NotFound, "ISA %s not found", isa.ID) + case old.Owner != isa.Owner: + return nil, stacktrace.NewErrorWithCode(dsserr.PermissionDenied, + "ISA owned by %s, but %s attempted to modify", old.Owner, isa.Owner) + case !old.Version.Matches(isa.Version): + return nil, stacktrace.NewErrorWithCode(dsserr.VersionMismatch, + "ISA currently at version %s but client specified %s", old.Version, isa.Version) + } + + if err := isa.AdjustTimeRange(proposal.Timestamp, old); err != nil { + return nil, stacktrace.Propagate(err, "Error adjusting time range") + } + + checkpoint := r.memStore.Checkpoint() + ret, err := mem.UpdateISA(ctx, isa) + if err != nil { + restoreErr := r.memStore.Restore(checkpoint) + if restoreErr != nil { + return nil, stacktrace.Propagate(restoreErr, "Error restoring store") + } + + return nil, stacktrace.Propagate(err, "Error updating ISA") + } + + cells := s2.CellUnionFromUnion(old.Cells, isa.Cells) + geo.Levelify(&cells) + subs, err := mem.UpdateNotificationIdxsInCells(ctx, cells) + if err != nil { + restoreErr := r.memStore.Restore(checkpoint) + if restoreErr != nil { + return nil, stacktrace.Propagate(restoreErr, "Error restoring store") + } + + return nil, stacktrace.Propagate(err, "Error updating notification indices") + } + + return &ISATransactionResult{Ret: ret, Subs: subs}, nil +} diff --git a/pkg/rid/store/raftstore/params/params.go b/pkg/rid/store/raftstore/params/params.go new file mode 100644 index 000000000..272b6e789 --- /dev/null +++ b/pkg/rid/store/raftstore/params/params.go @@ -0,0 +1,29 @@ +package params + +import ( + "flag" + + raftparams "github.com/interuss/dss/pkg/raftstore/params" + "github.com/interuss/stacktrace" +) + +const peersFlag = "rid_raft_peers" + +var peers string + +func init() { + flag.StringVar(&peers, peersFlag, "", `Comma-separated "nodeID=peerURL" pairs for the rid store, e.g. "1=http://node1:9011,2=http://node2:9011,3=http://node3:9011"`) +} + +func GetConnectParameters() (raftparams.ConnectParameters, error) { + if peers == "" { + return raftparams.ConnectParameters{}, stacktrace.NewError("--%s is required", peersFlag) + } + + p, err := raftparams.GetConnectParameters("rid") + if err != nil { + return raftparams.ConnectParameters{}, err + } + p.Peers = peers + return p, nil +} diff --git a/pkg/rid/store/raftstore/store.go b/pkg/rid/store/raftstore/store.go index 3022c1251..a0c4a8447 100644 --- a/pkg/rid/store/raftstore/store.go +++ b/pkg/rid/store/raftstore/store.go @@ -2,15 +2,250 @@ package raftstore import ( "context" + "encoding/json" + "slices" + "github.com/golang/geo/s2" + "github.com/interuss/dss/pkg/memstore" + dssmodels "github.com/interuss/dss/pkg/models" "github.com/interuss/dss/pkg/raftstore" + "github.com/interuss/dss/pkg/raftstore/consensus" + ridmodels "github.com/interuss/dss/pkg/rid/models" "github.com/interuss/dss/pkg/rid/repos" + ridmemstore "github.com/interuss/dss/pkg/rid/store/memstore" + ridraftparams "github.com/interuss/dss/pkg/rid/store/raftstore/params" + "github.com/interuss/stacktrace" "go.uber.org/zap" ) +const storeID = "rid" + +const ( + getISA raftstore.RequestType = "getISA" + deleteISA raftstore.RequestType = "deleteISA" + insertISA raftstore.RequestType = "insertISA" + updateISA raftstore.RequestType = "updateISA" + searchISAs raftstore.RequestType = "searchISAs" + listExpiredISAs raftstore.RequestType = "listExpiredISAs" + countISAs raftstore.RequestType = "countISAs" + + DeleteISATransaction raftstore.RequestType = "deleteISATransaction" + InsertISATransaction raftstore.RequestType = "insertISATransaction" + UpdateISATransaction raftstore.RequestType = "updateISATransaction" + + getSubscription raftstore.RequestType = "getSubscription" + deleteSubscription raftstore.RequestType = "deleteSubscription" + insertSubscription raftstore.RequestType = "insertSubscription" + updateSubscription raftstore.RequestType = "updateSubscription" + searchSubscriptions raftstore.RequestType = "searchSubscriptions" + searchSubscriptionsByOwner raftstore.RequestType = "searchSubscriptionsByOwner" + updateNotificationIdxsInCells raftstore.RequestType = "updateNotificationIdxsInCells" + maxSubscriptionCountInCellsByOwner raftstore.RequestType = "maxSubscriptionCountInCellsByOwner" + listExpiredSubscriptions raftstore.RequestType = "listExpiredSubscriptions" + countSubscriptions raftstore.RequestType = "countSubscriptions" + + DeleteSubscriptionTransaction raftstore.RequestType = "deleteSubscriptionTransaction" + InsertSubscriptionTransaction raftstore.RequestType = "insertSubscriptionTransaction" + UpdateSubscriptionTransaction raftstore.RequestType = "updateSubscriptionTransaction" +) + +var readOnlyRequests = []raftstore.RequestType{ + getISA, + searchISAs, + listExpiredISAs, + countISAs, + + getSubscription, + searchSubscriptions, + searchSubscriptionsByOwner, + maxSubscriptionCountInCellsByOwner, + listExpiredSubscriptions, + countSubscriptions, +} + // repo is a full implementation of rid.repos.Repository for Raft-based storage. -type repo struct{} +type repo struct { + consensus *consensus.Consensus + memStore *memstore.Store[repos.Repository] +} func Init(ctx context.Context, logger *zap.Logger) (*raftstore.Store[repos.Repository], error) { - return raftstore.Init[repos.Repository](ctx, logger, func() repos.Repository { return &repo{} }) + params, err := ridraftparams.GetConnectParameters() + if err != nil { + return nil, stacktrace.Propagate(err, "failed to get rid raft parameters") + } + + memStore, err := ridmemstore.Init(ctx, logger) + if err != nil { + return nil, stacktrace.Propagate(err, "failed to initialize RID memstore") + } + + r := &repo{memStore: memStore} + store, err := raftstore.Init(ctx, logger, params, r) + if err != nil { + return nil, stacktrace.Propagate(err, "failed to initialize RID raftstore") + } + + r.consensus = store.Consensus + + return store, nil +} + +func (r *repo) GetRepo() repos.Repository { return r } + +func (r *repo) IsReadOnly(requestType raftstore.RequestType) bool { + return slices.Contains(readOnlyRequests, requestType) +} + +func (r *repo) GetSnapshot() ([]byte, error) { + return r.memStore.GetSnapshot() +} + +func (r *repo) RestoreFromSnapshot(data []byte) error { + return r.memStore.RestoreFromSnapshot(data) +} + +func (r *repo) Apply(ctx context.Context, proposal consensus.Proposal) (any, error) { + mem, err := r.memStore.Interact(ctx) + if err != nil { + return nil, stacktrace.Propagate(err, "failed to obtain rid memstore repository") + } + + switch raftstore.RequestType(proposal.RequestType) { + // ISAs + + case getISA: + var p getISAPayload + if err := json.Unmarshal(proposal.Value, &p); err != nil { + return nil, stacktrace.Propagate(err, "failed to unmarshal %s payload", getISA) + } + return mem.GetISA(ctx, p.ID, p.ForUpdate) + + case deleteISA: + var isa ridmodels.IdentificationServiceArea + if err := json.Unmarshal(proposal.Value, &isa); err != nil { + return nil, stacktrace.Propagate(err, "failed to unmarshal %s payload", deleteISA) + } + return mem.DeleteISA(ctx, &isa) + + case insertISA: + var isa ridmodels.IdentificationServiceArea + if err := json.Unmarshal(proposal.Value, &isa); err != nil { + return nil, stacktrace.Propagate(err, "failed to unmarshal %s payload", insertISA) + } + return mem.InsertISA(ctx, &isa) + + case updateISA: + var isa ridmodels.IdentificationServiceArea + if err := json.Unmarshal(proposal.Value, &isa); err != nil { + return nil, stacktrace.Propagate(err, "failed to unmarshal %s payload", updateISA) + } + return mem.UpdateISA(ctx, &isa) + + case searchISAs: + var p searchISAsPayload + if err := json.Unmarshal(proposal.Value, &p); err != nil { + return nil, stacktrace.Propagate(err, "failed to unmarshal %s payload", searchISAs) + } + return mem.SearchISAs(ctx, p.Cells, p.Earliest, p.Latest) + + case listExpiredISAs: + var p listExpiredISAsPayload + if err := json.Unmarshal(proposal.Value, &p); err != nil { + return nil, stacktrace.Propagate(err, "failed to unmarshal %s payload", listExpiredISAs) + } + return mem.ListExpiredISAs(ctx, p.Writer, p.Threshold) + + case countISAs: + return mem.CountISAs(ctx) + + case DeleteISATransaction: + return r.deleteISATransactionApplier(ctx, proposal, mem) + + case InsertISATransaction: + return r.insertISATransactionApplier(ctx, proposal, mem) + + case UpdateISATransaction: + return r.updateISATransactionApplier(ctx, proposal, mem) + + // Subscriptions + + case getSubscription: + var id dssmodels.ID + if err := json.Unmarshal(proposal.Value, &id); err != nil { + return nil, stacktrace.Propagate(err, "failed to unmarshal %s payload", getSubscription) + } + return mem.GetSubscription(ctx, id) + + case deleteSubscription: + var sub ridmodels.Subscription + if err := json.Unmarshal(proposal.Value, &sub); err != nil { + return nil, stacktrace.Propagate(err, "failed to unmarshal %s payload", deleteSubscription) + } + return mem.DeleteSubscription(ctx, &sub) + + case insertSubscription: + var sub ridmodels.Subscription + if err := json.Unmarshal(proposal.Value, &sub); err != nil { + return nil, stacktrace.Propagate(err, "failed to unmarshal %s payload", insertSubscription) + } + return mem.InsertSubscription(ctx, &sub) + + case updateSubscription: + var sub ridmodels.Subscription + if err := json.Unmarshal(proposal.Value, &sub); err != nil { + return nil, stacktrace.Propagate(err, "failed to unmarshal %s payload", updateSubscription) + } + return mem.UpdateSubscription(ctx, &sub) + + case searchSubscriptions: + var cells s2.CellUnion + if err := json.Unmarshal(proposal.Value, &cells); err != nil { + return nil, stacktrace.Propagate(err, "failed to unmarshal %s payload", searchSubscriptions) + } + return mem.SearchSubscriptions(ctx, cells) + + case searchSubscriptionsByOwner: + var p searchSubscriptionsByOwnerPayload + if err := json.Unmarshal(proposal.Value, &p); err != nil { + return nil, stacktrace.Propagate(err, "failed to unmarshal %s payload", searchSubscriptionsByOwner) + } + return mem.SearchSubscriptionsByOwner(ctx, p.Cells, p.Owner) + + case updateNotificationIdxsInCells: + var cells s2.CellUnion + if err := json.Unmarshal(proposal.Value, &cells); err != nil { + return nil, stacktrace.Propagate(err, "failed to unmarshal %s payload", updateNotificationIdxsInCells) + } + return mem.UpdateNotificationIdxsInCells(ctx, cells) + + case maxSubscriptionCountInCellsByOwner: + var p maxSubscriptionCountInCellsByOwnerPayload + if err := json.Unmarshal(proposal.Value, &p); err != nil { + return nil, stacktrace.Propagate(err, "failed to unmarshal %s payload", maxSubscriptionCountInCellsByOwner) + } + return mem.MaxSubscriptionCountInCellsByOwner(ctx, p.Cells, p.Owner) + + case listExpiredSubscriptions: + var p listExpiredSubscriptionsPayload + if err := json.Unmarshal(proposal.Value, &p); err != nil { + return nil, stacktrace.Propagate(err, "failed to unmarshal %s payload", listExpiredSubscriptions) + } + return mem.ListExpiredSubscriptions(ctx, p.Writer, p.Threshold) + + case countSubscriptions: + return mem.CountSubscriptions(ctx) + + case DeleteSubscriptionTransaction: + return r.deleteSubscriptionTransactionApplier(ctx, proposal, mem) + + case InsertSubscriptionTransaction: + return r.insertSubscriptionTransactionApplier(ctx, proposal, mem) + + case UpdateSubscriptionTransaction: + return r.updateSubscriptionTransactionApplier(ctx, proposal, mem) + + default: + return nil, stacktrace.NewError("unknown request type: %q", proposal.RequestType) + } } diff --git a/pkg/rid/store/raftstore/subscriptions.go b/pkg/rid/store/raftstore/subscriptions.go index 0c9c261dd..1ab57362f 100644 --- a/pkg/rid/store/raftstore/subscriptions.go +++ b/pkg/rid/store/raftstore/subscriptions.go @@ -5,48 +5,202 @@ import ( "time" "github.com/golang/geo/s2" - dsserr "github.com/interuss/dss/pkg/errors" dssmodels "github.com/interuss/dss/pkg/models" ridmodels "github.com/interuss/dss/pkg/rid/models" "github.com/interuss/stacktrace" ) -func (r *repo) GetSubscription(_ context.Context, id dssmodels.ID) (*ridmodels.Subscription, error) { - return nil, stacktrace.NewErrorWithCode(dsserr.NotImplemented, "GetSubscription not implemented for raftstore") +func (r *repo) GetSubscription(ctx context.Context, id dssmodels.ID) (*ridmodels.Subscription, error) { + result, err := r.consensus.ProposeValue(ctx, getSubscription, id, true) + if err != nil { + return nil, err + } + + if result == nil { + return nil, nil + } + + sub, ok := result.(*ridmodels.Subscription) + if !ok { + return nil, stacktrace.NewError("invalid result type: %T", result) + } + + return sub, nil +} + +func (r *repo) DeleteSubscription(ctx context.Context, sub *ridmodels.Subscription) (*ridmodels.Subscription, error) { + result, err := r.consensus.ProposeValue(ctx, deleteSubscription, sub, false) + if err != nil { + return nil, err + } + + if result == nil { + return nil, nil + } + + out, ok := result.(*ridmodels.Subscription) + if !ok { + return nil, stacktrace.NewError("invalid result type: %T", result) + } + + return out, nil +} + +func (r *repo) InsertSubscription(ctx context.Context, sub *ridmodels.Subscription) (*ridmodels.Subscription, error) { + result, err := r.consensus.ProposeValue(ctx, insertSubscription, sub, false) + if err != nil { + return nil, err + } + + if result == nil { + return nil, nil + } + + out, ok := result.(*ridmodels.Subscription) + if !ok { + return nil, stacktrace.NewError("invalid result type: %T", result) + } + + return out, nil +} + +func (r *repo) UpdateSubscription(ctx context.Context, sub *ridmodels.Subscription) (*ridmodels.Subscription, error) { + result, err := r.consensus.ProposeValue(ctx, updateSubscription, sub, false) + if err != nil { + return nil, err + } + + if result == nil { + return nil, nil + } + + out, ok := result.(*ridmodels.Subscription) + if !ok { + return nil, stacktrace.NewError("invalid result type: %T", result) + } + + return out, nil } -func (r *repo) DeleteSubscription(_ context.Context, sub *ridmodels.Subscription) (*ridmodels.Subscription, error) { - return nil, stacktrace.NewErrorWithCode(dsserr.NotImplemented, "DeleteSubscription not implemented for raftstore") +func (r *repo) SearchSubscriptions(ctx context.Context, cells s2.CellUnion) ([]*ridmodels.Subscription, error) { + result, err := r.consensus.ProposeValue(ctx, searchSubscriptions, cells, true) + if err != nil { + return nil, err + } + + if result == nil { + return nil, nil + } + + out, ok := result.([]*ridmodels.Subscription) + if !ok { + return nil, stacktrace.NewError("invalid result type: %T", result) + } + + return out, nil } -func (r *repo) InsertSubscription(_ context.Context, sub *ridmodels.Subscription) (*ridmodels.Subscription, error) { - return nil, stacktrace.NewErrorWithCode(dsserr.NotImplemented, "InsertSubscription not implemented for raftstore") +type searchSubscriptionsByOwnerPayload struct { + Cells s2.CellUnion `json:"cells"` + Owner dssmodels.Owner `json:"owner"` } -func (r *repo) UpdateSubscription(_ context.Context, sub *ridmodels.Subscription) (*ridmodels.Subscription, error) { - return nil, stacktrace.NewErrorWithCode(dsserr.NotImplemented, "UpdateSubscription not implemented for raftstore") +func (r *repo) SearchSubscriptionsByOwner(ctx context.Context, cells s2.CellUnion, owner dssmodels.Owner) ([]*ridmodels.Subscription, error) { + result, err := r.consensus.ProposeValue(ctx, searchSubscriptionsByOwner, &searchSubscriptionsByOwnerPayload{Cells: cells, Owner: owner}, true) + if err != nil { + return nil, err + } + + if result == nil { + return nil, nil + } + + out, ok := result.([]*ridmodels.Subscription) + if !ok { + return nil, stacktrace.NewError("invalid result type: %T", result) + } + + return out, nil } -func (r *repo) SearchSubscriptions(_ context.Context, cells s2.CellUnion) ([]*ridmodels.Subscription, error) { - return nil, stacktrace.NewErrorWithCode(dsserr.NotImplemented, "SearchSubscriptions not implemented for raftstore") +func (r *repo) UpdateNotificationIdxsInCells(ctx context.Context, cells s2.CellUnion) ([]*ridmodels.Subscription, error) { + result, err := r.consensus.ProposeValue(ctx, updateNotificationIdxsInCells, cells, false) + if err != nil { + return nil, err + } + + if result == nil { + return nil, nil + } + + out, ok := result.([]*ridmodels.Subscription) + if !ok { + return nil, stacktrace.NewError("invalid result type: %T", result) + } + + return out, nil } -func (r *repo) SearchSubscriptionsByOwner(_ context.Context, cells s2.CellUnion, owner dssmodels.Owner) ([]*ridmodels.Subscription, error) { - return nil, stacktrace.NewErrorWithCode(dsserr.NotImplemented, "SearchSubscriptionsByOwner not implemented for raftstore") +type maxSubscriptionCountInCellsByOwnerPayload struct { + Cells s2.CellUnion `json:"cells"` + Owner dssmodels.Owner `json:"owner"` } -func (r *repo) UpdateNotificationIdxsInCells(_ context.Context, cells s2.CellUnion) ([]*ridmodels.Subscription, error) { - return nil, stacktrace.NewErrorWithCode(dsserr.NotImplemented, "UpdateNotificationIdxsInCells not implemented for raftstore") +func (r *repo) MaxSubscriptionCountInCellsByOwner(ctx context.Context, cells s2.CellUnion, owner dssmodels.Owner) (int, error) { + result, err := r.consensus.ProposeValue(ctx, maxSubscriptionCountInCellsByOwner, &maxSubscriptionCountInCellsByOwnerPayload{Cells: cells, Owner: owner}, true) + if err != nil { + return 0, err + } + + if result == nil { + return 0, nil + } + + count, ok := result.(int) + if !ok { + return 0, stacktrace.NewError("invalid result type: %T", result) + } + + return count, nil } -func (r *repo) MaxSubscriptionCountInCellsByOwner(_ context.Context, cells s2.CellUnion, owner dssmodels.Owner) (int, error) { - return 0, stacktrace.NewErrorWithCode(dsserr.NotImplemented, "MaxSubscriptionCountInCellsByOwner not implemented for raftstore") +type listExpiredSubscriptionsPayload struct { + Writer string `json:"writer"` + Threshold time.Time `json:"threshold"` } -func (r *repo) ListExpiredSubscriptions(_ context.Context, writer string, threshold time.Time) ([]*ridmodels.Subscription, error) { - return nil, stacktrace.NewErrorWithCode(dsserr.NotImplemented, "ListExpiredSubscriptions not implemented for raftstore") +func (r *repo) ListExpiredSubscriptions(ctx context.Context, writer string, threshold time.Time) ([]*ridmodels.Subscription, error) { + result, err := r.consensus.ProposeValue(ctx, listExpiredSubscriptions, &listExpiredSubscriptionsPayload{Writer: writer, Threshold: threshold}, true) + if err != nil { + return nil, err + } + + if result == nil { + return nil, nil + } + + out, ok := result.([]*ridmodels.Subscription) + if !ok { + return nil, stacktrace.NewError("invalid result type: %T", result) + } + + return out, nil } -func (r *repo) CountSubscriptions(_ context.Context) (int64, error) { - return 0, stacktrace.NewErrorWithCode(dsserr.NotImplemented, "CountSubscriptions not implemented for raftstore") +func (r *repo) CountSubscriptions(ctx context.Context) (int64, error) { + result, err := r.consensus.ProposeValue(ctx, countSubscriptions, nil, true) + if err != nil { + return 0, err + } + + if result == nil { + return 0, nil + } + + count, ok := result.(int64) + if !ok { + return 0, stacktrace.NewError("invalid result type: %T", result) + } + + return count, nil } diff --git a/pkg/rid/store/raftstore/subscriptions_appliers.go b/pkg/rid/store/raftstore/subscriptions_appliers.go new file mode 100644 index 000000000..021aac9f8 --- /dev/null +++ b/pkg/rid/store/raftstore/subscriptions_appliers.go @@ -0,0 +1,141 @@ +package raftstore + +import ( + "context" + "encoding/json" + + dsserr "github.com/interuss/dss/pkg/errors" + dssmodels "github.com/interuss/dss/pkg/models" + "github.com/interuss/dss/pkg/raftstore/consensus" + ridmodels "github.com/interuss/dss/pkg/rid/models" + "github.com/interuss/dss/pkg/rid/repos" + "github.com/interuss/stacktrace" +) + +func (r *repo) insertSubscriptionTransactionApplier(ctx context.Context, proposal consensus.Proposal, mem repos.Repository) (*ridmodels.Subscription, error) { + var payload *ridmodels.Subscription + if err := json.Unmarshal(proposal.Value, &payload); err != nil { + return nil, stacktrace.Propagate(err, "failed to unmarshal %s payload", insertSubscription) + } + + old, err := mem.GetSubscription(ctx, payload.ID) + if err != nil { + return nil, stacktrace.Propagate(err, "error getting Subscription from repo") + } + if old != nil { + return nil, stacktrace.NewErrorWithCode(dsserr.AlreadyExists, "Subscription %s already exists", payload.ID) + } + + count, err := mem.MaxSubscriptionCountInCellsByOwner(ctx, payload.Cells, payload.Owner) + if err != nil { + return nil, stacktrace.Propagate(err, + "Failed to fetch subscription count, rejecting request") + } + if count >= ridmodels.MaxSubscriptionsPerArea { + return nil, stacktrace.Propagate( + stacktrace.NewErrorWithCode(dsserr.Exhausted, "too many existing subscriptions in this area already"), + "%s had %d subscriptions in the area", payload.Owner, count) + } + + checkpoint := r.memStore.Checkpoint() + ret, err := mem.InsertSubscription(ctx, payload) + if err != nil { + restoreErr := r.memStore.Restore(checkpoint) + if restoreErr != nil { + return nil, stacktrace.Propagate(restoreErr, "Error restoring store") + } + + return nil, stacktrace.Propagate(err, "Error inserting Subscription") + } + return ret, nil +} + +func (r *repo) updateSubscriptionTransactionApplier(ctx context.Context, proposal consensus.Proposal, mem repos.Repository) (*ridmodels.Subscription, error) { + var payload *ridmodels.Subscription + if err := json.Unmarshal(proposal.Value, &payload); err != nil { + return nil, stacktrace.Propagate(err, "failed to unmarshal %s payload", updateSubscription) + } + + old, err := mem.GetSubscription(ctx, payload.ID) + switch { + case err != nil: + return nil, stacktrace.Propagate(err, "Error getting Subscription from repo") + case old == nil: + return nil, stacktrace.NewErrorWithCode(dsserr.NotFound, "Subscription %s not found", payload.ID.String()) + case !payload.Version.Matches(old.Version): + return nil, stacktrace.Propagate( + stacktrace.NewErrorWithCode(dsserr.VersionMismatch, "Subscription version %s is not current", payload.Version), + "Subscription currently at version %s but client specified %s", old.Version, payload.Version) + case old.Owner != payload.Owner: + return nil, stacktrace.Propagate( + stacktrace.NewErrorWithCode(dsserr.PermissionDenied, "Subscription is owned by different client"), + "Subscription owned by %s, but %s attempted to update", old.Owner, payload.Owner) + } + if err := payload.AdjustTimeRange(proposal.Timestamp, old); err != nil { + return nil, stacktrace.Propagate(err, "Error adjusting time range") + } + + count, err := mem.MaxSubscriptionCountInCellsByOwner(ctx, payload.Cells, payload.Owner) + if err != nil { + return nil, stacktrace.Propagate(err, + "Failed to fetch subscription count, rejecting request") + } + if count >= ridmodels.MaxSubscriptionsPerArea { + return nil, stacktrace.Propagate( + stacktrace.NewErrorWithCode(dsserr.Exhausted, "Too many existing subscriptions in this area already"), + "%s had %d subscriptions in the area", payload.Owner, count) + } + + checkpoint := r.memStore.Checkpoint() + ret, err := mem.UpdateSubscription(ctx, payload) + if err != nil { + restoreErr := r.memStore.Restore(checkpoint) + if restoreErr != nil { + return nil, stacktrace.Propagate(restoreErr, "Error restoring store") + } + + return nil, stacktrace.Propagate(err, "Error updating Subscription") + } + return ret, nil +} + +type DeleteSubscriptionPayload struct { + ID dssmodels.ID `json:"id"` + Owner dssmodels.Owner `json:"owner"` + Version *dssmodels.Version `json:"version"` +} + +func (r *repo) deleteSubscriptionTransactionApplier(ctx context.Context, proposal consensus.Proposal, mem repos.Repository) (*ridmodels.Subscription, error) { + var payload *DeleteSubscriptionPayload + if err := json.Unmarshal(proposal.Value, &payload); err != nil { + return nil, stacktrace.Propagate(err, "failed to unmarshal %s payload", deleteSubscription) + } + + old, err := mem.GetSubscription(ctx, payload.ID) + switch { + case err != nil: + return nil, stacktrace.Propagate(err, "Error getting Subscription from repo") + case old == nil: + return nil, stacktrace.NewErrorWithCode(dsserr.NotFound, "Subscription %s not found", payload.ID.String()) + case !payload.Version.Matches(old.Version): + return nil, stacktrace.Propagate( + stacktrace.NewErrorWithCode(dsserr.VersionMismatch, "Subscription version %s is not current", payload.Version), + "Subscription currently at version %s but client specified %s", old.Version, payload.Version) + case old.Owner != payload.Owner: + return nil, stacktrace.Propagate( + stacktrace.NewErrorWithCode(dsserr.PermissionDenied, "Subscription is owned by different client"), + "Subscription owned by %s, but %s attempted to delete", old.Owner, payload.Owner) + } + + checkpoint := r.memStore.Checkpoint() + ret, err := mem.DeleteSubscription(ctx, old) + if err != nil { + restoreErr := r.memStore.Restore(checkpoint) + if restoreErr != nil { + return nil, stacktrace.Propagate(restoreErr, "Error restoring store") + } + + return nil, stacktrace.Propagate(err, "Error deleting Subscription") + } + return ret, nil +} diff --git a/pkg/rid/store/sqlstore/store_test.go b/pkg/rid/store/sqlstore/store_test.go index 275a1a332..3a57d20da 100644 --- a/pkg/rid/store/sqlstore/store_test.go +++ b/pkg/rid/store/sqlstore/store_test.go @@ -97,7 +97,7 @@ func TestTxnRetrier(t *testing.T) { require.NotNil(t, store) defer tearDownStore() - err := store.Transact(ctx, func(ctx context.Context, repo repos.Repository) error { + _, err := store.Transact(ctx, "", nil, func(ctx context.Context, repo repos.Repository) error { // can query within this isa, err := repo.InsertISA(ctx, serviceArea) require.NotNil(t, isa) @@ -118,7 +118,7 @@ func TestTxnRetrier(t *testing.T) { ctx, cancel := context.WithTimeout(ctx, 20*time.Millisecond) defer cancel() count := 0 - err = store.Transact(ctx, func(ctx context.Context, repo repos.Repository) error { + _, err = store.Transact(ctx, "", nil, func(ctx context.Context, repo repos.Repository) error { // can query within this count++ // Postgre retryable error @@ -141,13 +141,13 @@ func TestTransactor(t *testing.T) { subscription2 := subscriptionsPool[1].input txnCount := 0 - err := store.Transact(ctx, func(ctx context.Context, s1 repos.Repository) error { + _, err := store.Transact(ctx, "", nil, func(ctx context.Context, s1 repos.Repository) error { // We should get to this retry, then return nothing. if txnCount > 0 { return errors.New("already failed") } txnCount++ - err := store.Transact(ctx, func(ctx context.Context, s2 repos.Repository) error { + _, err := store.Transact(ctx, "", nil, func(ctx context.Context, s2 repos.Repository) error { subs, err := s1.SearchSubscriptions(ctx, subscription1.Cells) require.NoError(t, err) require.Len(t, subs, 0) diff --git a/pkg/rid/store/store.go b/pkg/rid/store/store.go index aa561a092..7f2dc6c20 100644 --- a/pkg/rid/store/store.go +++ b/pkg/rid/store/store.go @@ -4,6 +4,7 @@ import ( "context" "github.com/interuss/dss/pkg/rid/repos" + ridmemstore "github.com/interuss/dss/pkg/rid/store/memstore" ridraftstore "github.com/interuss/dss/pkg/rid/store/raftstore" ridsqlstore "github.com/interuss/dss/pkg/rid/store/sqlstore" dssstore "github.com/interuss/dss/pkg/store" @@ -24,6 +25,8 @@ func Init(ctx context.Context, logger *zap.Logger, withCheckCron bool) (Store, e return ridsqlstore.Init(ctx, logger, withCheckCron) case params.RaftStoreType: return ridraftstore.Init(ctx, logger) + case params.MemStoreType: + return ridmemstore.Init(ctx, logger) default: return nil, stacktrace.NewError("Unsupported store type %q for rid", storeType) } diff --git a/pkg/scd/constraints_handler.go b/pkg/scd/constraints_handler.go index 0688176c7..965cf60c9 100644 --- a/pkg/scd/constraints_handler.go +++ b/pkg/scd/constraints_handler.go @@ -11,6 +11,7 @@ import ( dssmodels "github.com/interuss/dss/pkg/models" scdmodels "github.com/interuss/dss/pkg/scd/models" "github.com/interuss/dss/pkg/scd/repos" + scdraftstore "github.com/interuss/dss/pkg/scd/store/raftstore" "github.com/interuss/stacktrace" "github.com/jackc/pgx/v5" ) @@ -82,13 +83,13 @@ func (a *Server) DeleteConstraintReference(ctx context.Context, req *restapi.Del // Return response to client response = &restapi.ChangeConstraintReferenceResponse{ ConstraintReference: *old.ToRest(), - Subscribers: makeSubscribersToNotify(subs), + Subscribers: repos.MakeSubscribersToNotify(subs), } return nil } - err = a.Store.Transact(ctx, action) + raftResult, err := a.Store.Transact(ctx, scdraftstore.DeleteConstraintTransaction, req, action) if err != nil { err = stacktrace.Propagate(err, "Could not delete constraint") errResp := &restapi.ErrorResponse{Message: dsserr.Handle(ctx, err)} @@ -106,6 +107,14 @@ func (a *Server) DeleteConstraintReference(ctx context.Context, req *restapi.Del ErrorMessage: *dsserr.Handle(ctx, stacktrace.Propagate(err, "Got an unexpected error"))}} } } + if raftResult != nil { + deleteConstraintResponse, ok := raftResult.(*restapi.ChangeConstraintReferenceResponse) + if !ok { + return restapi.DeleteConstraintReferenceResponseSet{Response500: &api.InternalServerErrorBody{ + ErrorMessage: *dsserr.Handle(ctx, stacktrace.NewError("invalid result type"))}} + } + response = deleteConstraintResponse + } return restapi.DeleteConstraintReferenceResponseSet{Response200: response} } @@ -147,7 +156,7 @@ func (a *Server) GetConstraintReference(ctx context.Context, req *restapi.GetCon return nil } - err = a.Store.Transact(ctx, action) + raftResult, err := a.Store.Transact(ctx, scdraftstore.GetConstraintTransaction, req, action) if err != nil { err = stacktrace.Propagate(err, "Could not get constraint") if stacktrace.GetCode(err) == dsserr.NotFound { @@ -156,6 +165,14 @@ func (a *Server) GetConstraintReference(ctx context.Context, req *restapi.GetCon return restapi.GetConstraintReferenceResponseSet{Response500: &api.InternalServerErrorBody{ ErrorMessage: *dsserr.Handle(ctx, stacktrace.Propagate(err, "Got an unexpected error"))}} } + if raftResult != nil { + getConstraintResponse, ok := raftResult.(*restapi.GetConstraintReferenceResponse) + if !ok { + return restapi.GetConstraintReferenceResponseSet{Response500: &api.InternalServerErrorBody{ + ErrorMessage: *dsserr.Handle(ctx, stacktrace.NewError("invalid result type"))}} + } + response = getConstraintResponse + } return restapi.GetConstraintReferenceResponseSet{Response200: response} } @@ -233,6 +250,18 @@ func (a *Server) PutConstraintReference(ctx context.Context, manager string, ent return nil, stacktrace.PropagateWithCode(err, dsserr.BadRequest, "Failed to validate Constraint upsert parameters") } + payload := &scdraftstore.UpsertConstraintTransactionPayload{ + Manager: dssmodels.Manager(manager), + ID: validParams.id, + Ovn: scdmodels.OVN(ovn), + USSBaseURL: validParams.ussBaseURL, + StartTime: validParams.uExtent.StartTime, + EndTime: validParams.uExtent.EndTime, + AltitudeLo: validParams.uExtent.SpatialVolume.AltitudeLo, + AltitudeHi: validParams.uExtent.SpatialVolume.AltitudeHi, + Cells: validParams.cells, + } + var response *restapi.ChangeConstraintReferenceResponse action := func(ctx context.Context, r repos.Repository) (err error) { version := scdmodels.VersionNumber(1) @@ -300,17 +329,25 @@ func (a *Server) PutConstraintReference(ctx context.Context, manager string, ent // Return response to client response = &restapi.ChangeConstraintReferenceResponse{ ConstraintReference: *constraint.ToRest(), - Subscribers: makeSubscribersToNotify(subs), + Subscribers: repos.MakeSubscribersToNotify(subs), } return nil } - err = a.Store.Transact(ctx, action) + raftResult, err := a.Store.Transact(ctx, scdraftstore.UpsertConstraintTransaction, payload, action) if err != nil { return nil, err // No need to Propagate this error as this is not a useful stacktrace line } + if raftResult != nil { + upsertConstraintResponse, ok := raftResult.(*restapi.ChangeConstraintReferenceResponse) + if !ok { + return nil, stacktrace.NewError("invalid result type") + } + response = upsertConstraintResponse + } + return response, nil } @@ -434,11 +471,19 @@ func (a *Server) QueryConstraintReferences(ctx context.Context, req *restapi.Que return nil } - err = a.Store.Transact(ctx, action) + raftResult, err := a.Store.Transact(ctx, scdraftstore.QueryConstraintTransaction, req, action) if err != nil { return restapi.QueryConstraintReferencesResponseSet{Response500: &api.InternalServerErrorBody{ ErrorMessage: *dsserr.Handle(ctx, stacktrace.Propagate(err, "Got an unexpected error"))}} } + if raftResult != nil { + queryConstraintResponse, ok := raftResult.(*restapi.QueryConstraintReferencesResponse) + if !ok { + return restapi.QueryConstraintReferencesResponseSet{Response500: &api.InternalServerErrorBody{ + ErrorMessage: *dsserr.Handle(ctx, stacktrace.NewError("invalid result type"))}} + } + response = queryConstraintResponse + } return restapi.QueryConstraintReferencesResponseSet{Response200: response} } diff --git a/pkg/scd/operational_intents_handler.go b/pkg/scd/operational_intents_handler.go index f611797d0..bf96b5517 100644 --- a/pkg/scd/operational_intents_handler.go +++ b/pkg/scd/operational_intents_handler.go @@ -13,40 +13,10 @@ import ( dssmodels "github.com/interuss/dss/pkg/models" scdmodels "github.com/interuss/dss/pkg/scd/models" "github.com/interuss/dss/pkg/scd/repos" + scdraftstore "github.com/interuss/dss/pkg/scd/store/raftstore" "github.com/interuss/stacktrace" ) -// subscriptionIsImplicitAndOnlyAttachedToOIR will check if: -// - the subscription is defined and is implicit -// - the subscription is attached to the specified operational intent -// - the subscription is not attached to any other operational intent -// -// This is to be used in contexts where an implicit subscription may need to be cleaned up: if true is returned, -// the subscription can be safely removed after the operational intent is deleted or attached to another subscription. -// -// NOTE: this should eventually be pushed down the datastore as part of the queries being executed in the callers of this method. -// -// See https://github.com/interuss/dss/issues/1059 for more details -func subscriptionIsImplicitAndOnlyAttachedToOIR(ctx context.Context, r repos.Repository, oirID dssmodels.ID, subscription *scdmodels.Subscription) (bool, error) { - if subscription == nil { - return false, nil - } - if !subscription.ImplicitSubscription { - return false, nil - } - // Get the Subscription's dependent OperationalIntents - dependentOps, err := r.GetDependentOperationalIntents(ctx, subscription.ID) - if err != nil { - return false, stacktrace.Propagate(err, "Could not find dependent OperationalIntents") - } - if len(dependentOps) == 0 { - return false, stacktrace.NewError("An implicit Subscription had no dependent OperationalIntents") - } else if len(dependentOps) == 1 && dependentOps[0] == oirID { - return true, nil - } - return false, nil -} - // DeleteOperationalIntentReference deletes a single operational intent ref for a given ID at // the specified version. func (a *Server) DeleteOperationalIntentReference(ctx context.Context, req *restapi.DeleteOperationalIntentReferenceRequest, @@ -120,7 +90,7 @@ func (a *Server) DeleteOperationalIntentReference(ctx context.Context, req *rest } } - removeImplicitSubscription, err := subscriptionIsImplicitAndOnlyAttachedToOIR(ctx, r, id, previousSubscription) + removeImplicitSubscription, err := repos.SubscriptionIsImplicitAndOnlyAttachedToOIR(ctx, r, id, previousSubscription) if err != nil { return stacktrace.Propagate(err, "Could not determine if Subscription can be removed") } @@ -137,7 +107,7 @@ func (a *Server) DeleteOperationalIntentReference(ctx context.Context, req *rest }), }} - subsToNotify, err := getRelevantSubscriptionsAndIncrementIndices(ctx, r, notifyVolume) + subsToNotify, err := repos.GetRelevantSubscriptionsAndIncrementIndices(ctx, r, notifyVolume) if err != nil { return stacktrace.Propagate(err, "could not obtain relevant subscriptions") } @@ -159,13 +129,13 @@ func (a *Server) DeleteOperationalIntentReference(ctx context.Context, req *rest // Return response to client response = &restapi.ChangeOperationalIntentReferenceResponse{ OperationalIntentReference: *old.ToRest(), - Subscribers: makeSubscribersToNotify(subsToNotify), + Subscribers: repos.MakeSubscribersToNotify(subsToNotify), } return nil } - err = a.Store.Transact(ctx, action) + raftResult, err := a.Store.Transact(ctx, scdraftstore.DeleteOperationalIntentTransaction, req, action) if err != nil { err = stacktrace.Propagate(err, "Could not delete operational intent") errResp := &restapi.ErrorResponse{Message: dsserr.Handle(ctx, err)} @@ -182,6 +152,15 @@ func (a *Server) DeleteOperationalIntentReference(ctx context.Context, req *rest } } + if raftResult != nil { + deleteOIRResponse, ok := raftResult.(*restapi.ChangeOperationalIntentReferenceResponse) + if !ok { + return restapi.DeleteOperationalIntentReferenceResponseSet{Response500: &api.InternalServerErrorBody{ + ErrorMessage: *dsserr.Handle(ctx, stacktrace.NewError("invalid result type"))}} + } + response = deleteOIRResponse + } + return restapi.DeleteOperationalIntentReferenceResponseSet{Response200: response} } @@ -221,7 +200,7 @@ func (a *Server) GetOperationalIntentReference(ctx context.Context, req *restapi return nil } - err = a.Store.Transact(ctx, action) + raftResult, err := a.Store.Transact(ctx, scdraftstore.GetOperationalIntentTransaction, req, action) if err != nil { err = stacktrace.Propagate(err, "Could not get operational intent") if stacktrace.GetCode(err) == dsserr.NotFound { @@ -231,6 +210,16 @@ func (a *Server) GetOperationalIntentReference(ctx context.Context, req *restapi ErrorMessage: *dsserr.Handle(ctx, stacktrace.Propagate(err, "Got an unexpected error"))}} } + if raftResult != nil { + getOIRResponse, ok := raftResult.(*restapi.GetOperationalIntentReferenceResponse) + if !ok { + return restapi.GetOperationalIntentReferenceResponseSet{Response500: &api.InternalServerErrorBody{ + ErrorMessage: *dsserr.Handle(ctx, stacktrace.NewError("invalid result type"))}} + } + + response = getOIRResponse + } + return restapi.GetOperationalIntentReferenceResponseSet{Response200: response} } @@ -288,7 +277,7 @@ func (a *Server) QueryOperationalIntentReferences(ctx context.Context, req *rest return nil } - err = a.Store.Transact(ctx, action) + raftResult, err := a.Store.Transact(ctx, scdraftstore.QueryOperationalIntentTransaction, req, action) if err != nil { err = stacktrace.Propagate(err, "Could not query operational intent") if stacktrace.GetCode(err) == dsserr.BadRequest { @@ -298,6 +287,16 @@ func (a *Server) QueryOperationalIntentReferences(ctx context.Context, req *rest ErrorMessage: *dsserr.Handle(ctx, stacktrace.Propagate(err, "Got an unexpected error"))}} } + if raftResult != nil { + queryOIRResponse, ok := raftResult.(*restapi.QueryOperationalIntentReferenceResponse) + if !ok { + return restapi.QueryOperationalIntentReferencesResponseSet{Response500: &api.InternalServerErrorBody{ + ErrorMessage: *dsserr.Handle(ctx, stacktrace.NewError("Invalid result type"))}} + } + + response = queryOIRResponse + } + return restapi.QueryOperationalIntentReferencesResponseSet{Response200: response} } @@ -363,51 +362,6 @@ func (a *Server) UpdateOperationalIntentReference(ctx context.Context, req *rest return restapi.UpdateOperationalIntentReferenceResponseSet{Response200: respOK} } -type validOIRParams struct { - id dssmodels.ID - ovn scdmodels.OVN - newOVN scdmodels.OVN - state scdmodels.OperationalIntentState - uExtent *dssmodels.Volume4D - cells s2.CellUnion - subscriptionID dssmodels.ID - ussBaseURL string - implicitSubscription struct { - requested bool - baseURL string - forConstraints bool - } - key map[scdmodels.OVN]bool -} - -func (vp *validOIRParams) toOIR(manager dssmodels.Manager, attachedSub *scdmodels.Subscription, version scdmodels.VersionNumber, pastOVNs []scdmodels.OVN) *scdmodels.OperationalIntent { - // For OIR's in the accepted state, we may not have a attachedSub available, - // in such cases the attachedSub ID on scdmodels.OperationalIntent will be nil - // and will be replaced with the 'NullV4UUID' when sent over to a client. - var subID *dssmodels.ID - if attachedSub != nil { - // Note: do _not_ use vp.subscriptionID here, as it may be empty - subID = &attachedSub.ID - } - return &scdmodels.OperationalIntent{ - ID: vp.id, - Manager: manager, - Version: version, - OVN: vp.newOVN, // non-empty only if the USS has requested an OVN - PastOVNs: pastOVNs, - - StartTime: vp.uExtent.StartTime, - EndTime: vp.uExtent.EndTime, - AltitudeLower: vp.uExtent.SpatialVolume.AltitudeLo, - AltitudeUpper: vp.uExtent.SpatialVolume.AltitudeHi, - Cells: vp.cells, - - USSBaseURL: vp.ussBaseURL, - SubscriptionID: subID, - State: vp.state, - } -} - // validateAndReturnOIRUpsertParams checks that the parameters for an Operational Intent Reference upsert are valid. // Note that this does NOT check for anything related to access controls: any error returned should be labeled // as a dsserr.BadRequest. @@ -417,12 +371,12 @@ func validateAndReturnOIRUpsertParams( ovn restapi.EntityOVN, params *restapi.PutOperationalIntentReferenceParameters, allowHTTPBaseUrls bool, -) (*validOIRParams, error) { +) (*repos.ValidOIRParams, error) { - valid := &validOIRParams{} + valid := &repos.ValidOIRParams{} var err error - valid.id, err = dssmodels.IDFromString(string(entityid)) + valid.ID, err = dssmodels.IDFromString(string(entityid)) if err != nil { return nil, stacktrace.NewError("Invalid ID format: `%s`", entityid) } @@ -431,10 +385,10 @@ func validateAndReturnOIRUpsertParams( return nil, stacktrace.NewError("Missing required UssBaseUrl") } - valid.ussBaseURL = string(params.UssBaseUrl) + valid.USSBaseURL = string(params.UssBaseUrl) if params.SubscriptionId != nil { - valid.subscriptionID, err = dssmodels.IDFromOptionalString(string(*params.SubscriptionId)) + valid.SubscriptionID, err = dssmodels.IDFromOptionalString(string(*params.SubscriptionId)) if err != nil { return nil, stacktrace.NewError("Invalid ID format for Subscription ID: `%s`", *params.SubscriptionId) } @@ -450,11 +404,12 @@ func validateAndReturnOIRUpsertParams( if params.SubscriptionId != nil { return nil, stacktrace.NewError("Cannot provide both a Subscription ID and request an implicit subscription") } - valid.implicitSubscription.requested = true - valid.implicitSubscription.baseURL = string(params.NewSubscription.UssBaseUrl) + valid.ImplicitSubscription.Requested = true + valid.ImplicitSubscription.ID = dssmodels.ID(uuid.New().String()) + valid.ImplicitSubscription.BaseURL = string(params.NewSubscription.UssBaseUrl) // notify for constraints defaults to false if not specified if params.NewSubscription.NotifyForConstraints != nil { - valid.implicitSubscription.forConstraints = *params.NewSubscription.NotifyForConstraints + valid.ImplicitSubscription.ForConstraints = *params.NewSubscription.NotifyForConstraints } } @@ -465,21 +420,21 @@ func validateAndReturnOIRUpsertParams( } if params.NewSubscription != nil { - err := scdmodels.ValidateUSSBaseURL(valid.implicitSubscription.baseURL) + err := scdmodels.ValidateUSSBaseURL(valid.ImplicitSubscription.BaseURL) if err != nil { return nil, stacktrace.Propagate(err, "Failed to validate USS base URL for subscription (in parameters for implicit subscription)") } } } - valid.state = scdmodels.OperationalIntentState(params.State) - if !valid.state.IsValidInDSS() { + valid.State = scdmodels.OperationalIntentState(params.State) + if !valid.State.IsValidInDSS() { return nil, stacktrace.NewError("Invalid OperationalIntent state: %s", params.State) } // Start and end times, as well as lower and upper altitudes, are required for each volume // The end time may not be in the past. - valid.uExtent, err = dssmodels.UnionVolumes4DFromSCDRest( + valid.UExtent, err = dssmodels.UnionVolumes4DFromSCDRest( params.Extents, dssmodels.WithRequireTimeBounds(), dssmodels.WithRequireAltitudeBounds(), @@ -488,7 +443,7 @@ func validateAndReturnOIRUpsertParams( if err != nil { return nil, stacktrace.Propagate(err, "Invalid extents") } - valid.cells, err = valid.uExtent.CalculateSpatialCovering() + valid.Cells, err = valid.UExtent.CalculateSpatialCovering() if err != nil { return nil, stacktrace.Propagate(err, "Invalid area") } @@ -496,10 +451,10 @@ func validateAndReturnOIRUpsertParams( if ovn == "" && params.State != restapi.OperationalIntentState_Accepted { return nil, stacktrace.NewError("Invalid state for initial version: `%s`", params.State) } - valid.ovn = scdmodels.OVN(ovn) + valid.OVN = scdmodels.OVN(ovn) if params.RequestedOvnSuffix != nil { - valid.newOVN, err = scdmodels.NewOVNFromUUIDv7Suffix(now, valid.id, string(*params.RequestedOvnSuffix)) + valid.NewOVN, err = scdmodels.NewOVNFromUUIDv7Suffix(now, valid.ID, string(*params.RequestedOvnSuffix)) if err != nil { return nil, stacktrace.Propagate(err, "Invalid requested OVN suffix") } @@ -507,18 +462,18 @@ func validateAndReturnOIRUpsertParams( // Check if a subscription is required for this request: // OIRs in an accepted state do not need a subscription. - if valid.state.RequiresSubscription() && - valid.subscriptionID.Empty() && + if valid.State.RequiresSubscription() && + valid.SubscriptionID.Empty() && (params.NewSubscription == nil || params.NewSubscription.UssBaseUrl == "") { - return nil, stacktrace.NewError("Provided Operational Intent Reference state `%s` requires either a subscription ID or information to create an implicit subscription", valid.state) + return nil, stacktrace.NewError("Provided Operational Intent Reference state `%s` requires either a subscription ID or information to create an implicit subscription", valid.State) } // Construct a hash set of OVNs as the key - valid.key = map[scdmodels.OVN]bool{} + valid.Key = map[scdmodels.OVN]bool{} if params.Key != nil { for _, ovn := range *params.Key { - valid.key[scdmodels.OVN(ovn)] = true + valid.Key[scdmodels.OVN(ovn)] = true } } @@ -537,228 +492,6 @@ func checkUpsertPermissionsAndReturnManager(authorizedManager *api.Authorization return dssmodels.Manager(*authorizedManager.ClientID), nil } -// validateUpsertRequestAgainstPreviousOIR checks that the client requesting an OIR upsert has the necessary permissions and that the request is valid. -// On success, the version of the OIR is returned: -// - upon initial creation (if no previous OIR exists), it is 0 -// - otherwise, it is the version of the previous OIR -func validateUpsertRequestAgainstPreviousOIR( - requestingManager dssmodels.Manager, - providedOVN scdmodels.OVN, - previousOIR *scdmodels.OperationalIntent, -) error { - - if previousOIR != nil { - if previousOIR.Manager != requestingManager { - return stacktrace.NewErrorWithCode(dsserr.PermissionDenied, - "OperationalIntent owned by %s, but %s attempted to modify", previousOIR.Manager, requestingManager) - } - if previousOIR.OVN != providedOVN { - return stacktrace.NewErrorWithCode(dsserr.VersionMismatch, - "Current version is %s but client specified version %s", previousOIR.OVN, providedOVN) - } - - return nil - } - - if providedOVN != "" { - return stacktrace.NewErrorWithCode(dsserr.NotFound, "OperationalIntent does not exist and therefore is not version %s", providedOVN) - } - - return nil -} - -// createAndStoreNewImplicitSubscription will create a brand new implicit subscription based on the provided parameters, -// store it and return it. -func createAndStoreNewImplicitSubscription(ctx context.Context, r repos.Repository, manager dssmodels.Manager, validParams *validOIRParams) (*scdmodels.Subscription, error) { - subToUpsert := scdmodels.Subscription{ - ID: dssmodels.ID(uuid.New().String()), - Manager: manager, - StartTime: validParams.uExtent.StartTime, - EndTime: validParams.uExtent.EndTime, - AltitudeLo: validParams.uExtent.SpatialVolume.AltitudeLo, - AltitudeHi: validParams.uExtent.SpatialVolume.AltitudeHi, - Cells: validParams.cells, - USSBaseURL: validParams.implicitSubscription.baseURL, - NotifyForOperationalIntents: true, - NotifyForConstraints: validParams.implicitSubscription.forConstraints, - ImplicitSubscription: true, - } - - return r.UpsertSubscription(ctx, &subToUpsert) -} - -// computeNotificationVolume computes the volume that needs to be queried for subscriptions -// given the requested extent and the (possibly nil) previous operational intent. -// The returned volume is either the union of the requested extent and the previous OIR's extent, or just the requested extent -// if the previous OIR is nil. -func computeNotificationVolume( - previousOIR *scdmodels.OperationalIntent, - requestedExtent *dssmodels.Volume4D) (*dssmodels.Volume4D, error) { - - if previousOIR == nil { - return requestedExtent, nil - } - - // Compute total affected Volume4D for notification purposes - oldVolume := &dssmodels.Volume4D{ - StartTime: previousOIR.StartTime, - EndTime: previousOIR.EndTime, - SpatialVolume: &dssmodels.Volume3D{ - AltitudeHi: previousOIR.AltitudeUpper, - AltitudeLo: previousOIR.AltitudeLower, - Footprint: dssmodels.GeometryFunc(func() (s2.CellUnion, error) { - return previousOIR.Cells, nil - }), - }, - } - notifyVolume, err := dssmodels.UnionVolumes4D(requestedExtent, oldVolume) - if err != nil { - return nil, stacktrace.Propagate(err, "Error constructing 4D volumes union") - } - - return notifyVolume, nil -} - -// getRelevantSubscriptionsAndIncrementIndices retrieves the subscriptions relevant to the passed volume and increments their notification indices -// before returning them. -func getRelevantSubscriptionsAndIncrementIndices( - ctx context.Context, - r repos.Repository, - notifyVolume *dssmodels.Volume4D, -) (repos.Subscriptions, error) { - - // Find the Subscriptions interested in OperationalIntents and increment their - // notification indices - subs, err := r.IncrementNotificationIndicesForOperationalIntents(ctx, notifyVolume) - - if err != nil { - return nil, stacktrace.Propagate(err, "Failed to increment notification indices of relevant subscriptions") - } - - return subs, nil -} - -// validateKeyAndProvideConflictResponse ensures that the provided key contains all the necessary OVNs relevant for the area covered by the OperationalIntent. -// - If all required keys are provided, (nil, nil) will be returned. -// - If keys are missing, the conflict response to be sent back as well as an error with the dsserr.MissingOVNs code will be returned. -// - In case of any other error, (nil, error) will be returned. -func validateKeyAndProvideConflictResponse( - ctx context.Context, - r repos.Repository, - requestingManager dssmodels.Manager, - params *validOIRParams, - attachedSubscription *scdmodels.Subscription, -) (*restapi.AirspaceConflictResponse, error) { - - // Identify OperationalIntents missing from the key - var missingOps []*scdmodels.OperationalIntent - relevantOps, err := r.SearchOperationalIntents(ctx, params.uExtent) - if err != nil { - return nil, stacktrace.Propagate(err, "Unable to SearchOperations") - } - for _, relevantOp := range relevantOps { - _, ok := params.key[relevantOp.OVN] - // Note: The OIR being mutated does not need to be specified in the key: - if !ok && relevantOp.RequiresKey() && relevantOp.ID != params.id { - missingOps = append(missingOps, relevantOp) - } - } - - // Identify Constraints missing from the key - var missingConstraints []*scdmodels.Constraint - if attachedSubscription != nil && attachedSubscription.NotifyForConstraints { - constraints, err := r.SearchConstraints(ctx, params.uExtent) - if err != nil { - return nil, stacktrace.Propagate(err, "Unable to SearchConstraints") - } - for _, relevantConstraint := range constraints { - if _, ok := params.key[relevantConstraint.OVN]; !ok { - missingConstraints = append(missingConstraints, relevantConstraint) - } - } - } - - // If the client is missing some OVNs, provide the pointers to the - // information they need - if len(missingOps) > 0 || len(missingConstraints) > 0 { - msg := "Current OVNs not provided for one or more OperationalIntents or Constraints" - responseConflict := &restapi.AirspaceConflictResponse{Message: &msg} - - if len(missingOps) > 0 { - responseConflict.MissingOperationalIntents = new([]restapi.OperationalIntentReference) - for _, missingOp := range missingOps { - p := missingOp.ToRest() - // We scrub the OVNs of entities not owned by the requesting manager to make sure - // they have really contacted the managing USS - if missingOp.Manager != requestingManager { - noOvnPhrase := restapi.EntityOVN(scdmodels.NoOvnPhrase) - p.Ovn = &noOvnPhrase - } - *responseConflict.MissingOperationalIntents = append(*responseConflict.MissingOperationalIntents, *p) - } - } - - if len(missingConstraints) > 0 { - responseConflict.MissingConstraints = new([]restapi.ConstraintReference) - for _, missingConstraint := range missingConstraints { - c := missingConstraint.ToRest() - // We scrub the OVNs of entities not owned by the requesting manager to make sure - // they have really contacted the managing USS - if missingConstraint.Manager != requestingManager { - noOvnPhrase := restapi.EntityOVN(scdmodels.NoOvnPhrase) - c.Ovn = &noOvnPhrase - } - *responseConflict.MissingConstraints = append(*responseConflict.MissingConstraints, *c) - } - } - - return responseConflict, stacktrace.NewErrorWithCode(dsserr.MissingOVNs, "Missing OVNs: %v", msg) - } - - return nil, nil -} - -// ensureSubscriptionCoversOIR ensures that the subscription covers the requested geo-temporal extent, extending it if both possible and required, -// or failing otherwise. -// After this method returns successfully, the subscription will cover the requested geo-temporal extent. -func ensureSubscriptionCoversOIR(ctx context.Context, r repos.Repository, sub *scdmodels.Subscription, params *validOIRParams) (*scdmodels.Subscription, error) { - - updateSub := false - if sub.StartTime != nil && sub.StartTime.After(*params.uExtent.StartTime) { - if sub.ImplicitSubscription { - sub.StartTime = params.uExtent.StartTime - updateSub = true - } else { - return nil, stacktrace.NewErrorWithCode(dsserr.BadRequest, "Subscription does not begin until after the OperationalIntent starts") - } - } - if sub.EndTime != nil && sub.EndTime.Before(*params.uExtent.EndTime) { - if sub.ImplicitSubscription { - sub.EndTime = params.uExtent.EndTime - updateSub = true - } else { - return nil, stacktrace.NewErrorWithCode(dsserr.BadRequest, "Subscription ends before the OperationalIntent ends") - } - } - if !sub.Cells.Contains(params.cells) { - if sub.ImplicitSubscription { - sub.Cells = s2.CellUnionFromUnion(sub.Cells, params.cells) - updateSub = true - } else { - return nil, stacktrace.NewErrorWithCode(dsserr.BadRequest, "Subscription does not cover entire spatial area of the OperationalIntent") - } - } - if updateSub { - upsertedSub, err := r.UpsertSubscription(ctx, sub) - if err != nil { - return nil, stacktrace.Propagate(err, "Failed to update existing Subscription") - } - return upsertedSub, nil - } - - return sub, nil -} - // upsertOperationalIntentReference inserts or updates an Operational Intent. // If the ovn argument is empty (""), it will attempt to create a new Operational Intent. func (a *Server) upsertOperationalIntentReference(ctx context.Context, now time.Time, authorizedManager *api.AuthorizationResult, entityid restapi.EntityID, ovn restapi.EntityOVN, params *restapi.PutOperationalIntentReferenceParameters, @@ -769,17 +502,25 @@ func (a *Server) upsertOperationalIntentReference(ctx context.Context, now time. if err != nil { return nil, nil, stacktrace.PropagateWithCode(err, dsserr.BadRequest, "Failed to validate Operational Intent Reference upsert parameters") } - manager, err := checkUpsertPermissionsAndReturnManager(authorizedManager, validParams.state) + manager, err := checkUpsertPermissionsAndReturnManager(authorizedManager, validParams.State) if err != nil { return nil, nil, stacktrace.PropagateWithCode(err, dsserr.PermissionDenied, "Caller is not allowed to upsert with the requested state") } + payload := &scdraftstore.UpsertOperationalIntentTransactionPayload{ + Manager: manager, + ValidParams: validParams, + } + for ovn := range validParams.Key { + payload.Key = append(payload.Key, ovn) + } + var responseOK *restapi.ChangeOperationalIntentReferenceResponse var responseConflict *restapi.AirspaceConflictResponse action := func(ctx context.Context, r repos.Repository) (err error) { // Get existing OperationalIntent, if any - old, err := r.GetOperationalIntent(ctx, validParams.id) + old, err := r.GetOperationalIntent(ctx, validParams.ID) if err != nil { return stacktrace.Propagate(err, "Could not get OperationalIntent from repo") } @@ -793,17 +534,17 @@ func (a *Server) upsertOperationalIntentReference(ctx context.Context, now time. subscriptionIds = append(subscriptionIds, *old.SubscriptionID) } - if !validParams.subscriptionID.Empty() { - subscriptionIds = append(subscriptionIds, validParams.subscriptionID) + if !validParams.SubscriptionID.Empty() { + subscriptionIds = append(subscriptionIds, validParams.SubscriptionID) } - err = r.LockSubscriptionsOnCells(ctx, validParams.cells, subscriptionIds, validParams.uExtent.StartTime, validParams.uExtent.EndTime) + err = r.LockSubscriptionsOnCells(ctx, validParams.Cells, subscriptionIds, validParams.UExtent.StartTime, validParams.UExtent.EndTime) if err != nil { return stacktrace.Propagate(err, "Unable to acquire lock") } // Validate the request against the previous OIR - if err := validateUpsertRequestAgainstPreviousOIR(manager, validParams.ovn, old); err != nil { + if err := repos.ValidateUpsertRequestAgainstPreviousOIR(manager, validParams.OVN, old); err != nil { return stacktrace.PropagateWithCode(err, stacktrace.GetCode(err), "Request validation failed") } @@ -814,7 +555,7 @@ func (a *Server) upsertOperationalIntentReference(ctx context.Context, now time. ) if old != nil { version = old.Version + 1 - pastOVNs = append(old.PastOVNs, validParams.ovn) + pastOVNs = append(old.PastOVNs, validParams.OVN) // Fetch the previous OIR's subscription if it exists if old.SubscriptionID != nil { @@ -826,10 +567,10 @@ func (a *Server) upsertOperationalIntentReference(ctx context.Context, now time. } // Determine if the previous subscription is being replaced and if it will need to be cleaned up - previousSubIsBeingReplaced := previousSub != nil && validParams.subscriptionID != previousSub.ID + previousSubIsBeingReplaced := previousSub != nil && validParams.SubscriptionID != previousSub.ID removePreviousImplicitSubscription := false if previousSubIsBeingReplaced { - removePreviousImplicitSubscription, err = subscriptionIsImplicitAndOnlyAttachedToOIR(ctx, r, validParams.id, previousSub) + removePreviousImplicitSubscription, err = repos.SubscriptionIsImplicitAndOnlyAttachedToOIR(ctx, r, validParams.ID, previousSub) if err != nil { return stacktrace.Propagate(err, "Could not determine if previous Subscription can be removed") } @@ -838,14 +579,14 @@ func (a *Server) upsertOperationalIntentReference(ctx context.Context, now time. // attachedSub is the subscription that will end up being attached to the OIR // it defaults to the previous subscription (which may be nil), and may be updated if required by the parameters attachedSub := previousSub - if validParams.subscriptionID.Empty() { + if validParams.SubscriptionID.Empty() { // No subscription ID was provided: // check if an implicit subscription should be created, otherwise do nothing - if validParams.implicitSubscription.requested { + if validParams.ImplicitSubscription.Requested { // Parameters for a new implicit subscription have been passed: we will create // a new implicit subscription even if another subscription was attached to this OIR before, // regardless of whether it was an implicit subscription or not. - if attachedSub, err = createAndStoreNewImplicitSubscription(ctx, r, manager, validParams); err != nil { + if attachedSub, err = repos.CreateAndStoreNewImplicitSubscription(ctx, r, manager, validParams); err != nil { return stacktrace.Propagate(err, "Failed to create implicit subscription") } } else { @@ -859,12 +600,12 @@ func (a *Server) upsertOperationalIntentReference(ctx context.Context, now time. // in order to ensure it correctly covers the OIR. // We do the check below in order to avoid re-fetching the subscription if it has not changed if attachedSub == nil || previousSubIsBeingReplaced { - attachedSub, err = r.GetSubscription(ctx, validParams.subscriptionID) + attachedSub, err = r.GetSubscription(ctx, validParams.SubscriptionID) if err != nil { return stacktrace.Propagate(err, "Unable to get requested Subscription from store") } if attachedSub == nil { - return stacktrace.NewErrorWithCode(dsserr.BadRequest, "Specified Subscription %s does not exist", validParams.subscriptionID) + return stacktrace.NewErrorWithCode(dsserr.BadRequest, "Specified Subscription %s does not exist", validParams.SubscriptionID) } } @@ -877,28 +618,28 @@ func (a *Server) upsertOperationalIntentReference(ctx context.Context, now time. dsserr.PermissionDenied, "Specificed Subscription is owned by different client"), // The propagation message will end in the logs and help with debugging. "Subscription %s owned by %s, but %s attempted to use it for an OperationalIntent", - validParams.subscriptionID, + validParams.SubscriptionID, attachedSub.Manager, manager, ) } // We need to ensure the subscription covers the OIR's geo-temporal extent - attachedSub, err = ensureSubscriptionCoversOIR(ctx, r, attachedSub, validParams) + attachedSub, err = repos.EnsureSubscriptionCoversOIR(ctx, r, attachedSub, validParams) if err != nil { return stacktrace.Propagate(err, "Failed to ensure subscription covers OIR") } } - if validParams.state.RequiresKey() { - responseConflict, err = validateKeyAndProvideConflictResponse(ctx, r, manager, validParams, attachedSub) + if validParams.State.RequiresKey() { + responseConflict, err = repos.ValidateKeyAndProvideConflictResponse(ctx, r, manager, validParams, attachedSub) if err != nil { return stacktrace.PropagateWithCode(err, stacktrace.GetCode(err), "Failed to validate key") } } // Construct the new OperationalIntent - op := validParams.toOIR(manager, attachedSub, version, pastOVNs) + op := validParams.ToOIR(manager, attachedSub, version, pastOVNs) // Upsert the OperationalIntent op, err = r.UpsertOperationalIntent(ctx, op) @@ -914,13 +655,13 @@ func (a *Server) upsertOperationalIntentReference(ctx context.Context, now time. } } - notifyVolume, err := computeNotificationVolume(old, validParams.uExtent) + notifyVolume, err := repos.ComputeNotificationVolume(old, validParams.UExtent) if err != nil { return stacktrace.Propagate(err, "Failed to compute notification volume") } // Notify relevant Subscriptions - subsToNotify, err := getRelevantSubscriptionsAndIncrementIndices(ctx, r, notifyVolume) + subsToNotify, err := repos.GetRelevantSubscriptionsAndIncrementIndices(ctx, r, notifyVolume) if err != nil { return stacktrace.Propagate(err, "Failed to notify relevant Subscriptions") } @@ -928,16 +669,27 @@ func (a *Server) upsertOperationalIntentReference(ctx context.Context, now time. // Return response to client responseOK = &restapi.ChangeOperationalIntentReferenceResponse{ OperationalIntentReference: *op.ToRest(), - Subscribers: makeSubscribersToNotify(subsToNotify), + Subscribers: repos.MakeSubscribersToNotify(subsToNotify), } return nil } - err = a.Store.Transact(ctx, action) + raftResult, err := a.Store.Transact(ctx, scdraftstore.UpsertOperationalIntentTransaction, payload, action) if err != nil { + if raftResult != nil { + if upsertRes, ok := raftResult.(*scdraftstore.UpsertOperationalIntentTransactionResult); ok { + responseConflict = upsertRes.ResponseConflict + } + } return nil, responseConflict, err // No need to Propagate this error as this is not a useful stacktrace line } + if raftResult != nil { + if upsertRes, ok := raftResult.(*scdraftstore.UpsertOperationalIntentTransactionResult); ok { + responseOK = upsertRes.ResponseOK + } + } + return responseOK, responseConflict, nil } diff --git a/pkg/scd/repos/helpers.go b/pkg/scd/repos/helpers.go new file mode 100644 index 000000000..513a07283 --- /dev/null +++ b/pkg/scd/repos/helpers.go @@ -0,0 +1,334 @@ +package repos + +import ( + "context" + + "github.com/golang/geo/s2" + + restapi "github.com/interuss/dss/pkg/api/scdv1" + + dsserr "github.com/interuss/dss/pkg/errors" + dssmodels "github.com/interuss/dss/pkg/models" + scdmodels "github.com/interuss/dss/pkg/scd/models" + "github.com/interuss/stacktrace" +) + +type ValidOIRParams struct { + ID dssmodels.ID + OVN scdmodels.OVN + NewOVN scdmodels.OVN + State scdmodels.OperationalIntentState + UExtent *dssmodels.Volume4D + Cells s2.CellUnion + SubscriptionID dssmodels.ID + USSBaseURL string + ImplicitSubscription struct { + Requested bool + ID dssmodels.ID + BaseURL string + ForConstraints bool + } + Key map[scdmodels.OVN]bool +} + +func (vp *ValidOIRParams) ToOIR(manager dssmodels.Manager, attachedSub *scdmodels.Subscription, version scdmodels.VersionNumber, pastOVNs []scdmodels.OVN) *scdmodels.OperationalIntent { + // For OIR's in the accepted state, we may not have a attachedSub available, + // in such cases the attachedSub ID on scdmodels.OperationalIntent will be nil + // and will be replaced with the 'NullV4UUID' when sent over to a client. + var subID *dssmodels.ID + if attachedSub != nil { + // Note: do _not_ use vp.subscriptionID here, as it may be empty + subID = &attachedSub.ID + } + return &scdmodels.OperationalIntent{ + ID: vp.ID, + Manager: manager, + Version: version, + OVN: vp.NewOVN, // non-empty only if the USS has requested an OVN + PastOVNs: pastOVNs, + + StartTime: vp.UExtent.StartTime, + EndTime: vp.UExtent.EndTime, + AltitudeLower: vp.UExtent.SpatialVolume.AltitudeLo, + AltitudeUpper: vp.UExtent.SpatialVolume.AltitudeHi, + Cells: vp.Cells, + + USSBaseURL: vp.USSBaseURL, + SubscriptionID: subID, + State: vp.State, + } +} + +// SubscriptionIsImplicitAndOnlyAttachedToOIR will check if: +// - the subscription is defined and is implicit +// - the subscription is attached to the specified operational intent +// - the subscription is not attached to any other operational intent +// +// This is to be used in contexts where an implicit subscription may need to be cleaned up: if true is returned, +// the subscription can be safely removed after the operational intent is deleted or attached to another subscription. +// +// NOTE: this should eventually be pushed down the datastore as part of the queries being executed in the callers of this method. +// +// See https://github.com/interuss/dss/issues/1059 for more details +func SubscriptionIsImplicitAndOnlyAttachedToOIR(ctx context.Context, r Repository, oirID dssmodels.ID, subscription *scdmodels.Subscription) (bool, error) { + if subscription == nil { + return false, nil + } + if !subscription.ImplicitSubscription { + return false, nil + } + // Get the Subscription's dependent OperationalIntents + dependentOps, err := r.GetDependentOperationalIntents(ctx, subscription.ID) + if err != nil { + return false, stacktrace.Propagate(err, "Could not find dependent OperationalIntents") + } + if len(dependentOps) == 0 { + return false, stacktrace.NewError("An implicit Subscription had no dependent OperationalIntents") + } else if len(dependentOps) == 1 && dependentOps[0] == oirID { + return true, nil + } + return false, nil +} + +// ValidateUpsertRequestAgainstPreviousOIR checks that the client requesting an OIR upsert has the necessary permissions and that the request is valid. +// On success, the version of the OIR is returned: +// - upon initial creation (if no previous OIR exists), it is 0 +// - otherwise, it is the version of the previous OIR +func ValidateUpsertRequestAgainstPreviousOIR( + requestingManager dssmodels.Manager, + providedOVN scdmodels.OVN, + previousOIR *scdmodels.OperationalIntent, +) error { + + if previousOIR != nil { + if previousOIR.Manager != requestingManager { + return stacktrace.NewErrorWithCode(dsserr.PermissionDenied, + "OperationalIntent owned by %s, but %s attempted to modify", previousOIR.Manager, requestingManager) + } + if previousOIR.OVN != providedOVN { + return stacktrace.NewErrorWithCode(dsserr.VersionMismatch, + "Current version is %s but client specified version %s", previousOIR.OVN, providedOVN) + } + + return nil + } + + if providedOVN != "" { + return stacktrace.NewErrorWithCode(dsserr.NotFound, "OperationalIntent does not exist and therefore is not version %s", providedOVN) + } + + return nil +} + +// CreateAndStoreNewImplicitSubscription will create a brand new implicit subscription based on the provided parameters, +// store it and return it. +func CreateAndStoreNewImplicitSubscription(ctx context.Context, r Repository, manager dssmodels.Manager, validParams *ValidOIRParams) (*scdmodels.Subscription, error) { + subToUpsert := scdmodels.Subscription{ + ID: validParams.ImplicitSubscription.ID, + Manager: manager, + StartTime: validParams.UExtent.StartTime, + EndTime: validParams.UExtent.EndTime, + AltitudeLo: validParams.UExtent.SpatialVolume.AltitudeLo, + AltitudeHi: validParams.UExtent.SpatialVolume.AltitudeHi, + Cells: validParams.Cells, + USSBaseURL: validParams.ImplicitSubscription.BaseURL, + NotifyForOperationalIntents: true, + NotifyForConstraints: validParams.ImplicitSubscription.ForConstraints, + ImplicitSubscription: true, + } + + return r.UpsertSubscription(ctx, &subToUpsert) +} + +// ComputeNotificationVolume computes the volume that needs to be queried for subscriptions +// given the requested extent and the (possibly nil) previous operational intent. +// The returned volume is either the union of the requested extent and the previous OIR's extent, or just the requested extent +// if the previous OIR is nil. +func ComputeNotificationVolume( + previousOIR *scdmodels.OperationalIntent, + requestedExtent *dssmodels.Volume4D) (*dssmodels.Volume4D, error) { + + if previousOIR == nil { + return requestedExtent, nil + } + + // Compute total affected Volume4D for notification purposes + oldVolume := &dssmodels.Volume4D{ + StartTime: previousOIR.StartTime, + EndTime: previousOIR.EndTime, + SpatialVolume: &dssmodels.Volume3D{ + AltitudeHi: previousOIR.AltitudeUpper, + AltitudeLo: previousOIR.AltitudeLower, + Footprint: dssmodels.GeometryFunc(func() (s2.CellUnion, error) { + return previousOIR.Cells, nil + }), + }, + } + notifyVolume, err := dssmodels.UnionVolumes4D(requestedExtent, oldVolume) + if err != nil { + return nil, stacktrace.Propagate(err, "Error constructing 4D volumes union") + } + + return notifyVolume, nil +} + +// GetRelevantSubscriptionsAndIncrementIndices retrieves the subscriptions relevant to the passed volume and increments their notification indices +// before returning them. +func GetRelevantSubscriptionsAndIncrementIndices( + ctx context.Context, + r Repository, + notifyVolume *dssmodels.Volume4D, +) (Subscriptions, error) { + + // Find the Subscriptions interested in OperationalIntents and increment their + // notification indices + subs, err := r.IncrementNotificationIndicesForOperationalIntents(ctx, notifyVolume) + + if err != nil { + return nil, stacktrace.Propagate(err, "Failed to increment notification indices of relevant subscriptions") + } + + return subs, nil +} + +// ValidateKeyAndProvideConflictResponse ensures that the provided key contains all the necessary OVNs relevant for the area covered by the OperationalIntent. +// - If all required keys are provided, (nil, nil) will be returned. +// - If keys are missing, the conflict response to be sent back as well as an error with the dsserr.MissingOVNs code will be returned. +// - In case of any other error, (nil, error) will be returned. +func ValidateKeyAndProvideConflictResponse( + ctx context.Context, + r Repository, + requestingManager dssmodels.Manager, + params *ValidOIRParams, + attachedSubscription *scdmodels.Subscription, +) (*restapi.AirspaceConflictResponse, error) { + + // Identify OperationalIntents missing from the key + var missingOps []*scdmodels.OperationalIntent + relevantOps, err := r.SearchOperationalIntents(ctx, params.UExtent) + if err != nil { + return nil, stacktrace.Propagate(err, "Unable to SearchOperations") + } + for _, relevantOp := range relevantOps { + _, ok := params.Key[relevantOp.OVN] + // Note: The OIR being mutated does not need to be specified in the key: + if !ok && relevantOp.RequiresKey() && relevantOp.ID != params.ID { + missingOps = append(missingOps, relevantOp) + } + } + + // Identify Constraints missing from the key + var missingConstraints []*scdmodels.Constraint + if attachedSubscription != nil && attachedSubscription.NotifyForConstraints { + constraints, err := r.SearchConstraints(ctx, params.UExtent) + if err != nil { + return nil, stacktrace.Propagate(err, "Unable to SearchConstraints") + } + for _, relevantConstraint := range constraints { + if _, ok := params.Key[relevantConstraint.OVN]; !ok { + missingConstraints = append(missingConstraints, relevantConstraint) + } + } + } + + // If the client is missing some OVNs, provide the pointers to the + // information they need + if len(missingOps) > 0 || len(missingConstraints) > 0 { + msg := "Current OVNs not provided for one or more OperationalIntents or Constraints" + responseConflict := &restapi.AirspaceConflictResponse{Message: &msg} + + if len(missingOps) > 0 { + responseConflict.MissingOperationalIntents = new([]restapi.OperationalIntentReference) + for _, missingOp := range missingOps { + p := missingOp.ToRest() + // We scrub the OVNs of entities not owned by the requesting manager to make sure + // they have really contacted the managing USS + if missingOp.Manager != requestingManager { + noOvnPhrase := restapi.EntityOVN(scdmodels.NoOvnPhrase) + p.Ovn = &noOvnPhrase + } + *responseConflict.MissingOperationalIntents = append(*responseConflict.MissingOperationalIntents, *p) + } + } + + if len(missingConstraints) > 0 { + responseConflict.MissingConstraints = new([]restapi.ConstraintReference) + for _, missingConstraint := range missingConstraints { + c := missingConstraint.ToRest() + // We scrub the OVNs of entities not owned by the requesting manager to make sure + // they have really contacted the managing USS + if missingConstraint.Manager != requestingManager { + noOvnPhrase := restapi.EntityOVN(scdmodels.NoOvnPhrase) + c.Ovn = &noOvnPhrase + } + *responseConflict.MissingConstraints = append(*responseConflict.MissingConstraints, *c) + } + } + + return responseConflict, stacktrace.NewErrorWithCode(dsserr.MissingOVNs, "Missing OVNs: %v", msg) + } + + return nil, nil +} + +// EnsureSubscriptionCoversOIR ensures that the subscription covers the requested geo-temporal extent, extending it if both possible and required, +// or failing otherwise. +// After this method returns successfully, the subscription will cover the requested geo-temporal extent. +func EnsureSubscriptionCoversOIR(ctx context.Context, r Repository, sub *scdmodels.Subscription, params *ValidOIRParams) (*scdmodels.Subscription, error) { + + updateSub := false + if sub.StartTime != nil && sub.StartTime.After(*params.UExtent.StartTime) { + if sub.ImplicitSubscription { + sub.StartTime = params.UExtent.StartTime + updateSub = true + } else { + return nil, stacktrace.NewErrorWithCode(dsserr.BadRequest, "Subscription does not begin until after the OperationalIntent starts") + } + } + if sub.EndTime != nil && sub.EndTime.Before(*params.UExtent.EndTime) { + if sub.ImplicitSubscription { + sub.EndTime = params.UExtent.EndTime + updateSub = true + } else { + return nil, stacktrace.NewErrorWithCode(dsserr.BadRequest, "Subscription ends before the OperationalIntent ends") + } + } + if !sub.Cells.Contains(params.Cells) { + if sub.ImplicitSubscription { + sub.Cells = s2.CellUnionFromUnion(sub.Cells, params.Cells) + updateSub = true + } else { + return nil, stacktrace.NewErrorWithCode(dsserr.BadRequest, "Subscription does not cover entire spatial area of the OperationalIntent") + } + } + if updateSub { + upsertedSub, err := r.UpsertSubscription(ctx, sub) + if err != nil { + return nil, stacktrace.Propagate(err, "Failed to update existing Subscription") + } + return upsertedSub, nil + } + + return sub, nil +} + +func MakeSubscribersToNotify(subscriptions []*scdmodels.Subscription) []restapi.SubscriberToNotify { + result := []restapi.SubscriberToNotify{} + + subscriptionsByURL := map[string][]restapi.SubscriptionState{} + for _, sub := range subscriptions { + subState := restapi.SubscriptionState{ + SubscriptionId: restapi.SubscriptionID(sub.ID.String()), + NotificationIndex: restapi.SubscriptionNotificationIndex(sub.NotificationIndex), + } + subscriptionsByURL[sub.USSBaseURL] = append(subscriptionsByURL[sub.USSBaseURL], subState) + } + for url, states := range subscriptionsByURL { + result = append(result, restapi.SubscriberToNotify{ + UssBaseUrl: restapi.SubscriptionUssBaseURL(url), + Subscriptions: states, + }) + } + + return result +} diff --git a/pkg/scd/server.go b/pkg/scd/server.go index a8eab0aaa..9dea0d488 100644 --- a/pkg/scd/server.go +++ b/pkg/scd/server.go @@ -1,32 +1,9 @@ package scd import ( - restapi "github.com/interuss/dss/pkg/api/scdv1" - scdmodels "github.com/interuss/dss/pkg/scd/models" scdstore "github.com/interuss/dss/pkg/scd/store" ) -func makeSubscribersToNotify(subscriptions []*scdmodels.Subscription) []restapi.SubscriberToNotify { - result := []restapi.SubscriberToNotify{} - - subscriptionsByURL := map[string][]restapi.SubscriptionState{} - for _, sub := range subscriptions { - subState := restapi.SubscriptionState{ - SubscriptionId: restapi.SubscriptionID(sub.ID.String()), - NotificationIndex: restapi.SubscriptionNotificationIndex(sub.NotificationIndex), - } - subscriptionsByURL[sub.USSBaseURL] = append(subscriptionsByURL[sub.USSBaseURL], subState) - } - for url, states := range subscriptionsByURL { - result = append(result, restapi.SubscriberToNotify{ - UssBaseUrl: restapi.SubscriptionUssBaseURL(url), - Subscriptions: states, - }) - } - - return result -} - // Server implements scdv1.Implementation. type Server struct { Store scdstore.Store diff --git a/pkg/scd/store/memstore/availability.go b/pkg/scd/store/memstore/availability.go new file mode 100644 index 000000000..7be9ded67 --- /dev/null +++ b/pkg/scd/store/memstore/availability.go @@ -0,0 +1,38 @@ +package memstore + +import ( + "context" + + dssmodels "github.com/interuss/dss/pkg/models" + scdmodels "github.com/interuss/dss/pkg/scd/models" + "github.com/interuss/dss/pkg/timestamp" + "github.com/jackc/pgx/v5" +) + +func (rec *availabilityRecord) toModel() *scdmodels.UssAvailabilityStatus { + return &scdmodels.UssAvailabilityStatus{ + Uss: rec.Uss, + Availability: rec.Availability, + Version: scdmodels.NewOVNFromTime(rec.UpdatedAt, rec.Uss.String()), + } +} + +// GetUssAvailability implements scd.repos.UssAvailability.GetUssAvailability. +func (r *repo) GetUssAvailability(_ context.Context, id dssmodels.Manager) (*scdmodels.UssAvailabilityStatus, error) { + rec, ok := r.state.Availabilities[id] + if !ok { + return nil, pgx.ErrNoRows + } + return rec.toModel(), nil +} + +// UpsertUssAvailability implements scd.repos.UssAvailability.UpsertUssAvailability. +func (r *repo) UpsertUssAvailability(ctx context.Context, s *scdmodels.UssAvailabilityStatus) (*scdmodels.UssAvailabilityStatus, error) { + rec := &availabilityRecord{ + Uss: s.Uss, + Availability: s.Availability, + UpdatedAt: timestamp.NowFromContext(ctx), + } + r.state.Availabilities[s.Uss] = rec + return rec.toModel(), nil +} diff --git a/pkg/scd/store/memstore/availability_test.go b/pkg/scd/store/memstore/availability_test.go new file mode 100644 index 000000000..aa4b33e5e --- /dev/null +++ b/pkg/scd/store/memstore/availability_test.go @@ -0,0 +1,32 @@ +package memstore + +import ( + "errors" + "testing" + + scdmodels "github.com/interuss/dss/pkg/scd/models" + "github.com/jackc/pgx/v5" + "github.com/stretchr/testify/require" +) + +func TestUssAvailabilityUpsertGet(t *testing.T) { + ctx := writeCtx() + r := setUpStore(t) + + got, err := r.UpsertUssAvailability(ctx, sampleAvailability()) + require.NoError(t, err) + require.Equal(t, manager, got.Uss) + require.Equal(t, scdmodels.UssAvailabilityStateNormal, got.Availability) + require.NotEmpty(t, got.Version) + + fetched, err := r.GetUssAvailability(ctx, manager) + require.NoError(t, err) + require.Equal(t, got.Version, fetched.Version) + require.Equal(t, scdmodels.UssAvailabilityStateNormal, fetched.Availability) +} + +func TestGetUssAvailabilityMissingReturnsErrNoRows(t *testing.T) { + r := setUpStore(t) + _, err := r.GetUssAvailability(writeCtx(), manager) + require.True(t, errors.Is(err, pgx.ErrNoRows)) +} diff --git a/pkg/scd/store/memstore/constraints.go b/pkg/scd/store/memstore/constraints.go new file mode 100644 index 000000000..1a848a0c4 --- /dev/null +++ b/pkg/scd/store/memstore/constraints.go @@ -0,0 +1,103 @@ +package memstore + +import ( + "context" + + dssmodels "github.com/interuss/dss/pkg/models" + scdmodels "github.com/interuss/dss/pkg/scd/models" + dsssql "github.com/interuss/dss/pkg/sql" + "github.com/interuss/dss/pkg/timestamp" + "github.com/interuss/stacktrace" + "github.com/jackc/pgx/v5" +) + +func (rec *constraintRecord) toModel() *scdmodels.Constraint { + return &scdmodels.Constraint{ + ID: rec.ID, + Manager: rec.Manager, + Version: rec.Version, + OVN: scdmodels.NewOVNFromTime(rec.UpdatedAt, rec.ID.String()), + StartTime: cloneTime(rec.StartTime), + EndTime: cloneTime(rec.EndTime), + USSBaseURL: rec.USSBaseURL, + AltitudeLower: cloneFloat32(rec.AltitudeLower), + AltitudeUpper: cloneFloat32(rec.AltitudeUpper), + Cells: cloneCells(rec.Cells), + } +} + +// SearchConstraints implements scd.repos.Constraint.SearchConstraints. +func (r *repo) SearchConstraints(_ context.Context, v4d *dssmodels.Volume4D) ([]*scdmodels.Constraint, error) { + cells, err := v4d.CalculateSpatialCovering() + if err != nil { + return nil, stacktrace.Propagate(err, "Could not calculate spatial covering") + } + if len(cells) == 0 { + return []*scdmodels.Constraint{}, nil + } + + want := cellSet(cells) + var out []*scdmodels.Constraint + for _, rec := range r.state.Constraints { + if !overlaps(rec.Cells, want) { + continue + } + // COALESCE(starts_at <= $3, true) with $3 = v4d.EndTime + if rec.StartTime != nil && v4d.EndTime != nil && rec.StartTime.After(*v4d.EndTime) { + continue + } + // COALESCE(ends_at >= $2, true) with $2 = v4d.StartTime + if rec.EndTime != nil && v4d.StartTime != nil && rec.EndTime.Before(*v4d.StartTime) { + continue + } + out = append(out, rec.toModel()) + if len(out) >= dssmodels.MaxResultLimit { // mirror SQL "LIMIT MaxResultLimit" + break + } + } + return out, nil +} + +// GetConstraint implements scd.repos.Constraint.GetConstraint. +func (r *repo) GetConstraint(_ context.Context, id dssmodels.ID) (*scdmodels.Constraint, error) { + rec, ok := r.state.Constraints[id] + if !ok { + return nil, pgx.ErrNoRows + } + return rec.toModel(), nil +} + +// UpsertConstraint implements scd.repos.Constraint.UpsertConstraint. +func (r *repo) UpsertConstraint(ctx context.Context, s *scdmodels.Constraint) (*scdmodels.Constraint, error) { + if _, err := dsssql.CellUnionToCellIdsWithValidation(s.Cells); err != nil { + return nil, stacktrace.Propagate(err, "Failed to convert array to jackc/pgtype") + } + + rec := &constraintRecord{ + ID: s.ID, + Manager: s.Manager, + Version: s.Version, + StartTime: cloneTime(s.StartTime), + EndTime: cloneTime(s.EndTime), + USSBaseURL: s.USSBaseURL, + AltitudeLower: cloneFloat32(s.AltitudeLower), + AltitudeUpper: cloneFloat32(s.AltitudeUpper), + Cells: cloneCells(s.Cells), + UpdatedAt: timestamp.NowFromContext(ctx), + } + r.state.Constraints[s.ID] = rec + return rec.toModel(), nil +} + +// DeleteConstraint implements scd.repos.Constraint.DeleteConstraint. +func (r *repo) DeleteConstraint(_ context.Context, id dssmodels.ID) error { + if _, ok := r.state.Constraints[id]; !ok { + return pgx.ErrNoRows + } + delete(r.state.Constraints, id) + return nil +} + +func (r *repo) CountConstraints(_ context.Context) (int64, error) { + return int64(len(r.state.Constraints)), nil +} diff --git a/pkg/scd/store/memstore/constraints_test.go b/pkg/scd/store/memstore/constraints_test.go new file mode 100644 index 000000000..3515820c5 --- /dev/null +++ b/pkg/scd/store/memstore/constraints_test.go @@ -0,0 +1,73 @@ +package memstore + +import ( + "errors" + "testing" + "time" + + "github.com/golang/geo/s2" + "github.com/jackc/pgx/v5" + "github.com/stretchr/testify/require" +) + +func TestConstraintUpsertGetDelete(t *testing.T) { + ctx := writeCtx() + r := setUpStore(t) + + got, err := r.UpsertConstraint(ctx, sampleConstraint()) + require.NoError(t, err) + require.Equal(t, constraintId, got.ID) + require.Equal(t, manager, got.Manager) + require.NotEmpty(t, got.OVN) + + fetched, err := r.GetConstraint(ctx, constraintId) + require.NoError(t, err) + require.Equal(t, got.OVN, fetched.OVN) + require.Equal(t, cells, fetched.Cells) + + count, err := r.CountConstraints(ctx) + require.NoError(t, err) + require.Equal(t, int64(1), count) + + require.NoError(t, r.DeleteConstraint(ctx, constraintId)) + + _, err = r.GetConstraint(ctx, constraintId) + require.True(t, errors.Is(err, pgx.ErrNoRows)) +} + +func TestConstraintGetMissingReturnsErrNoRows(t *testing.T) { + r := setUpStore(t) + _, err := r.GetConstraint(writeCtx(), constraintId) + require.True(t, errors.Is(err, pgx.ErrNoRows)) +} + +func TestConstraintDeleteMissingReturnsErrNoRows(t *testing.T) { + r := setUpStore(t) + err := r.DeleteConstraint(writeCtx(), constraintId) + require.True(t, errors.Is(err, pgx.ErrNoRows)) +} + +func TestSearchConstraints(t *testing.T) { + ctx := writeCtx() + r := setUpStore(t) + _, err := r.UpsertConstraint(ctx, sampleConstraint()) + require.NoError(t, err) + + // Overlapping volume with no time bounds matches. + res, err := r.SearchConstraints(ctx, volume4D(cells, nil, nil, nil, nil)) + require.NoError(t, err) + require.Len(t, res, 1) + + // Time window after the constraint's end excludes it. + afterStart := endTime.Add(time.Hour) + afterEnd := afterStart.Add(time.Hour) + res, err = r.SearchConstraints(ctx, volume4D(cells, &afterStart, &afterEnd, nil, nil)) + require.NoError(t, err) + require.Empty(t, res) + + // No covering cells returns an empty (non-nil) slice. + res, err = r.SearchConstraints(ctx, volume4D(s2.CellUnion{}, nil, nil, nil, nil)) + require.NoError(t, err) + require.NotNil(t, res) + require.Empty(t, res) +} diff --git a/pkg/scd/store/memstore/doc.go b/pkg/scd/store/memstore/doc.go new file mode 100644 index 000000000..f86b85987 --- /dev/null +++ b/pkg/scd/store/memstore/doc.go @@ -0,0 +1,3 @@ +// Package scd.store.memstore provides a full implementation of store.Store[scd.repos.Repository] +// storing data in memory. It is meant to be used by raftstore. +package memstore diff --git a/pkg/scd/store/memstore/operational_intents.go b/pkg/scd/store/memstore/operational_intents.go new file mode 100644 index 000000000..15dbed4a3 --- /dev/null +++ b/pkg/scd/store/memstore/operational_intents.go @@ -0,0 +1,205 @@ +package memstore + +import ( + "context" + "errors" + "time" + + dsserr "github.com/interuss/dss/pkg/errors" + dssmodels "github.com/interuss/dss/pkg/models" + scdmodels "github.com/interuss/dss/pkg/scd/models" + "github.com/interuss/dss/pkg/timestamp" + "github.com/interuss/stacktrace" + "github.com/jackc/pgx/v5" +) + +// toModel rebuilds the OperationalIntent model without its UssAvailability, +// which is attached separately (see buildOperationalIntents). +func (rec *operationalIntentRecord) toModel() *scdmodels.OperationalIntent { + // If the managing USS has requested a specific OVN it is persisted, otherwise + // a default DSS-generated OVN based on the last update time is used. + var ovn scdmodels.OVN + if rec.USSRequestedOVN != "" { + ovn = scdmodels.OVN(rec.USSRequestedOVN) + } else { + ovn = scdmodels.NewOVNFromTime(rec.UpdatedAt, rec.ID.String()) + } + return &scdmodels.OperationalIntent{ + ID: rec.ID, + Manager: rec.Manager, + Version: rec.Version, + State: rec.State, + OVN: ovn, + PastOVNs: clonePastOVNs(rec.PastOVNs), + StartTime: cloneTime(rec.StartTime), + EndTime: cloneTime(rec.EndTime), + USSBaseURL: rec.USSBaseURL, + SubscriptionID: cloneID(rec.SubscriptionID), + AltitudeLower: cloneFloat32(rec.AltitudeLower), + AltitudeUpper: cloneFloat32(rec.AltitudeUpper), + Cells: cloneCells(rec.Cells), + } +} + +// buildOperationalIntents converts records to models and attaches the +// UssAvailability of each managing USS +func (r *repo) buildOperationalIntents(ctx context.Context, recs []*operationalIntentRecord) ([]*scdmodels.OperationalIntent, error) { + ussAvailabilities := map[dssmodels.Manager]scdmodels.UssAvailabilityState{} + payload := make([]*scdmodels.OperationalIntent, 0, len(recs)) + for _, rec := range recs { + o := rec.toModel() + ussAvailabilities[o.Manager] = scdmodels.UssAvailabilityStateUnknown + payload = append(payload, o) + } + + for manager := range ussAvailabilities { + ussAvailability, err := r.GetUssAvailability(ctx, manager) + if err != nil && !errors.Is(err, pgx.ErrNoRows) { + return nil, stacktrace.Propagate(err, "Error getting USS availability of %s", manager) + } + if ussAvailability != nil { + ussAvailabilities[manager] = ussAvailability.Availability + } + } + + for _, op := range payload { + op.UssAvailability = ussAvailabilities[op.Manager] + } + return payload, nil +} + +// GetOperationalIntent implements scd.repos.OperationalIntent.GetOperationalIntent. +func (r *repo) GetOperationalIntent(ctx context.Context, id dssmodels.ID) (*scdmodels.OperationalIntent, error) { + rec, ok := r.state.OperationalIntents[id] + if !ok { + return nil, nil + } + built, err := r.buildOperationalIntents(ctx, []*operationalIntentRecord{rec}) + if err != nil { + return nil, err + } + return built[0], nil +} + +// DeleteOperationalIntent implements scd.repos.OperationalIntent.DeleteOperationalIntent. +func (r *repo) DeleteOperationalIntent(_ context.Context, id dssmodels.ID) error { + if _, ok := r.state.OperationalIntents[id]; !ok { + return stacktrace.NewError("Could not delete Operation that does not exist") + } + delete(r.state.OperationalIntents, id) + return nil +} + +// UpsertOperationalIntent implements scd.repos.OperationalIntent.UpsertOperationalIntent. +func (r *repo) UpsertOperationalIntent(ctx context.Context, operation *scdmodels.OperationalIntent) (*scdmodels.OperationalIntent, error) { + // An empty OVN means the DSS generates it; it is persisted as NULL in the + // sqlstore (represented here by an empty USSRequestedOVN). + var ussRequestedOVN string + if operation.OVN != "" { + ussRequestedOVN = operation.OVN.String() + } + + rec := &operationalIntentRecord{ + ID: operation.ID, + Manager: operation.Manager, + Version: operation.Version, + State: operation.State, + StartTime: cloneTime(operation.StartTime), + EndTime: cloneTime(operation.EndTime), + USSBaseURL: operation.USSBaseURL, + SubscriptionID: cloneID(operation.SubscriptionID), + AltitudeLower: cloneFloat32(operation.AltitudeLower), + AltitudeUpper: cloneFloat32(operation.AltitudeUpper), + Cells: cloneCells(operation.Cells), + USSRequestedOVN: ussRequestedOVN, + PastOVNs: clonePastOVNs(operation.PastOVNs), + UpdatedAt: timestamp.NowFromContext(ctx), + } + r.state.OperationalIntents[operation.ID] = rec + + built, err := r.buildOperationalIntents(ctx, []*operationalIntentRecord{rec}) + if err != nil { + return nil, err + } + return built[0], nil +} + +// SearchOperationalIntents implements scd.repos.OperationalIntent.SearchOperationalIntents. +func (r *repo) SearchOperationalIntents(ctx context.Context, v4d *dssmodels.Volume4D) ([]*scdmodels.OperationalIntent, error) { + if v4d.SpatialVolume == nil || v4d.SpatialVolume.Footprint == nil { + return nil, stacktrace.NewErrorWithCode(dsserr.BadRequest, "Missing geospatial footprint for query") + } + cells, err := v4d.SpatialVolume.Footprint.CalculateCovering() + if err != nil { + return nil, stacktrace.PropagateWithCode(err, dsserr.BadRequest, "Failed to calculate footprint covering") + } + if len(cells) == 0 { + return nil, stacktrace.NewErrorWithCode(dsserr.BadRequest, "Missing cell IDs for query") + } + + want := cellSet(cells) + var matched []*operationalIntentRecord + for _, rec := range r.state.OperationalIntents { + if !overlaps(rec.Cells, want) { + continue + } + // COALESCE(altitude_upper >= $2, true) with $2 = SpatialVolume.AltitudeLo + if rec.AltitudeUpper != nil && v4d.SpatialVolume.AltitudeLo != nil && *rec.AltitudeUpper < *v4d.SpatialVolume.AltitudeLo { + continue + } + // COALESCE(altitude_lower <= $3, true) with $3 = SpatialVolume.AltitudeHi + if rec.AltitudeLower != nil && v4d.SpatialVolume.AltitudeHi != nil && *rec.AltitudeLower > *v4d.SpatialVolume.AltitudeHi { + continue + } + // COALESCE(ends_at >= $4, true) with $4 = v4d.StartTime + if rec.EndTime != nil && v4d.StartTime != nil && rec.EndTime.Before(*v4d.StartTime) { + continue + } + // COALESCE(starts_at <= $5, true) with $5 = v4d.EndTime + if rec.StartTime != nil && v4d.EndTime != nil && rec.StartTime.After(*v4d.EndTime) { + continue + } + matched = append(matched, rec) + if len(matched) >= dssmodels.MaxResultLimit { // mirror SQL "LIMIT MaxResultLimit" + break + } + } + return r.buildOperationalIntents(ctx, matched) +} + +// GetDependentOperationalIntents implements scd.repos.OperationalIntent.GetDependentOperationalIntents. +func (r *repo) GetDependentOperationalIntents(_ context.Context, subscriptionID dssmodels.ID) ([]dssmodels.ID, error) { + var dependentOps []dssmodels.ID + for _, rec := range r.state.OperationalIntents { + if rec.SubscriptionID != nil && *rec.SubscriptionID == subscriptionID { + dependentOps = append(dependentOps, rec.ID) + } + } + return dependentOps, nil +} + +// ListExpiredOperationalIntents implements scd.repos.OperationalIntent.ListExpiredOperationalIntents. +func (r *repo) ListExpiredOperationalIntents(ctx context.Context, threshold time.Time) ([]*scdmodels.OperationalIntent, error) { + var matched []*operationalIntentRecord + for _, rec := range r.state.OperationalIntents { + // (ends_at IS NOT NULL AND ends_at <= threshold) OR (ends_at IS NULL AND updated_at <= threshold) + var expired bool + if rec.EndTime != nil { + expired = !rec.EndTime.After(threshold) + } else { + expired = !rec.UpdatedAt.After(threshold) + } + if !expired { + continue + } + matched = append(matched, rec) + if len(matched) >= dssmodels.MaxResultLimit { // mirror SQL "LIMIT MaxResultLimit" + break + } + } + return r.buildOperationalIntents(ctx, matched) +} + +func (r *repo) CountOperationalIntents(_ context.Context) (int64, error) { + return int64(len(r.state.OperationalIntents)), nil +} diff --git a/pkg/scd/store/memstore/operational_intents_test.go b/pkg/scd/store/memstore/operational_intents_test.go new file mode 100644 index 000000000..4eaabb575 --- /dev/null +++ b/pkg/scd/store/memstore/operational_intents_test.go @@ -0,0 +1,208 @@ +package memstore + +import ( + "testing" + "time" + + dssmodels "github.com/interuss/dss/pkg/models" + scdmodels "github.com/interuss/dss/pkg/scd/models" + "github.com/stretchr/testify/require" +) + +func TestOperationalIntentUpsertGetDelete(t *testing.T) { + ctx := writeCtx() + r := setUpStore(t) + + got, err := r.UpsertOperationalIntent(ctx, sampleOperationalIntent()) + require.NoError(t, err) + require.Equal(t, operationalIntentId, got.ID) + require.Equal(t, scdmodels.OperationalIntentStateAccepted, got.State) + require.NotEmpty(t, got.OVN) + // No availability stored yet: defaults to Unknown. + require.Equal(t, scdmodels.UssAvailabilityStateUnknown, got.UssAvailability) + + count, err := r.CountOperationalIntents(ctx) + require.NoError(t, err) + require.Equal(t, int64(1), count) + + require.NoError(t, r.DeleteOperationalIntent(ctx, operationalIntentId)) + gone, err := r.GetOperationalIntent(ctx, operationalIntentId) + require.NoError(t, err) + require.Nil(t, gone) +} + +func TestOperationalIntentGetMissingReturnsNil(t *testing.T) { + r := setUpStore(t) + got, err := r.GetOperationalIntent(writeCtx(), operationalIntentId) + require.NoError(t, err) + require.Nil(t, got) +} + +func TestOperationalIntentDeleteMissingErrors(t *testing.T) { + r := setUpStore(t) + require.Error(t, r.DeleteOperationalIntent(writeCtx(), operationalIntentId)) +} + +func TestOperationalIntentUssAvailabilityAttached(t *testing.T) { + ctx := writeCtx() + r := setUpStore(t) + _, err := r.UpsertUssAvailability(ctx, sampleAvailability()) + require.NoError(t, err) + _, err = r.UpsertOperationalIntent(ctx, sampleOperationalIntent()) + require.NoError(t, err) + + got, err := r.GetOperationalIntent(ctx, operationalIntentId) + require.NoError(t, err) + require.Equal(t, scdmodels.UssAvailabilityStateNormal, got.UssAvailability) +} + +func TestSearchOperationalIntents(t *testing.T) { + ctx := writeCtx() + r := setUpStore(t) + _, err := r.UpsertOperationalIntent(ctx, sampleOperationalIntent()) + require.NoError(t, err) + + res, err := r.SearchOperationalIntents(ctx, volume4D(cells, nil, nil, nil, nil)) + require.NoError(t, err) + require.Len(t, res, 1) + + // Altitude window entirely above the operational intent excludes it. + var lo float32 = 200 + res, err = r.SearchOperationalIntents(ctx, volume4D(cells, nil, nil, &lo, nil)) + require.NoError(t, err) + require.Empty(t, res) + + // Missing footprint is a bad request. + _, err = r.SearchOperationalIntents(ctx, &dssmodels.Volume4D{}) + require.Error(t, err) +} + +func TestGetDependentOperationalIntents(t *testing.T) { + ctx := writeCtx() + r := setUpStore(t) + _, err := r.UpsertOperationalIntent(ctx, sampleOperationalIntent()) + require.NoError(t, err) + + deps, err := r.GetDependentOperationalIntents(ctx, subscriptionId) + require.NoError(t, err) + require.Equal(t, []dssmodels.ID{operationalIntentId}, deps) + + deps, err = r.GetDependentOperationalIntents(ctx, "other") + require.NoError(t, err) + require.Nil(t, deps) +} + +var ( + oi1ID = dssmodels.ID("00000185-e36d-40be-8d38-beca6ca30000") + oi2ID = dssmodels.ID("00000185-e36d-40be-8d38-beca6ca30001") + oi3ID = dssmodels.ID("00000185-e36d-40be-8d38-beca6ca30003") + + start1 = time.Date(2024, time.August, 14, 15, 48, 36, 0, time.UTC) + end1 = start1.Add(time.Hour) + start2 = time.Date(2024, time.September, 15, 15, 48, 36, 0, time.UTC) + end2 = start2.Add(time.Hour) + start3 = time.Date(2024, time.September, 16, 15, 48, 36, 0, time.UTC) + end3 = start3.Add(time.Hour) +) + +var ( + oi1 = &scdmodels.OperationalIntent{ + ID: oi1ID, + Manager: "unittest", + Version: 1, + State: scdmodels.OperationalIntentStateAccepted, + StartTime: &start1, + EndTime: &end1, + USSBaseURL: "https://dummy.uss", + SubscriptionID: &sub1ID, + AltitudeLower: &altLow, + AltitudeUpper: &altHigh, + Cells: cells, + } + oi2 = &scdmodels.OperationalIntent{ + ID: oi2ID, + Manager: "unittest", + Version: 1, + State: scdmodels.OperationalIntentStateAccepted, + StartTime: &start2, + EndTime: &end2, + USSBaseURL: "https://dummy.uss", + SubscriptionID: &sub2ID, + AltitudeLower: &altLow, + AltitudeUpper: &altHigh, + Cells: cells, + } + oi3 = &scdmodels.OperationalIntent{ + ID: oi3ID, + Manager: "unittest", + Version: 1, + State: scdmodels.OperationalIntentStateAccepted, + StartTime: &start3, + EndTime: &end3, + USSBaseURL: "https://dummy.uss", + SubscriptionID: &sub3ID, + AltitudeLower: &altLow, + AltitudeUpper: &altHigh, + Cells: cells, + } +) + +func TestListExpiredOperationalIntents(t *testing.T) { + ctx := writeCtx() + r := setUpStore(t) + + _, err := r.UpsertSubscription(ctx, sub1) + require.NoError(t, err) + _, err = r.UpsertOperationalIntent(ctx, oi1) + require.NoError(t, err) + + _, err = r.UpsertSubscription(ctx, sub2) + require.NoError(t, err) + _, err = r.UpsertOperationalIntent(ctx, oi2) + require.NoError(t, err) + + _, err = r.UpsertSubscription(ctx, sub3) + require.NoError(t, err) + _, err = r.UpsertOperationalIntent(ctx, oi3) + require.NoError(t, err) + + testCases := []struct { + name string + timeRef time.Time + ttl time.Duration + expired []dssmodels.ID + }{{ + name: "none expired, one in close past", + timeRef: time.Date(2024, time.August, 25, 15, 0, 0, 0, time.UTC), + ttl: time.Hour * 24 * 30, + expired: []dssmodels.ID{}, + }, { + name: "one recently expired, one current, one in future", + timeRef: time.Date(2024, time.September, 15, 16, 0, 0, 0, time.UTC), + ttl: time.Hour * 24 * 30, + expired: []dssmodels.ID{oi1ID}, + }, { + name: "two expired, one in future", + timeRef: time.Date(2024, time.September, 16, 16, 0, 0, 0, time.UTC), + ttl: time.Hour * 2, + expired: []dssmodels.ID{oi1ID, oi2ID}, + }, { + name: "all expired", + timeRef: time.Date(2024, time.December, 15, 15, 0, 0, 0, time.UTC), + ttl: time.Hour * 24 * 30, + expired: []dssmodels.ID{oi1ID, oi2ID, oi3ID}, + }} + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + threshold := testCase.timeRef.Add(-testCase.ttl) + expired, err := r.ListExpiredOperationalIntents(ctx, threshold) + require.NoError(t, err) + + expiredIDs := make([]dssmodels.ID, 0, len(expired)) + for _, expiredOi := range expired { + expiredIDs = append(expiredIDs, expiredOi.ID) + } + require.ElementsMatch(t, expiredIDs, testCase.expired) + }) + } +} diff --git a/pkg/scd/store/memstore/snapshot.go b/pkg/scd/store/memstore/snapshot.go new file mode 100644 index 000000000..6688121f8 --- /dev/null +++ b/pkg/scd/store/memstore/snapshot.go @@ -0,0 +1,49 @@ +package memstore + +import ( + "bytes" + "encoding/gob" + + dssmodels "github.com/interuss/dss/pkg/models" + "github.com/interuss/stacktrace" +) + +const snapshotVersion = 1 + +type snapshotEnvelope struct { + Version int + State state +} + +func (r *repo) GetSnapshot() ([]byte, error) { + var buf bytes.Buffer + if err := gob.NewEncoder(&buf).Encode(snapshotEnvelope{Version: snapshotVersion, State: r.state}); err != nil { + return nil, stacktrace.Propagate(err, "Failed to encode memstore snapshot") + } + return buf.Bytes(), nil +} + +func (r *repo) RestoreFromSnapshot(data []byte) error { + var env snapshotEnvelope + if err := gob.NewDecoder(bytes.NewReader(data)).Decode(&env); err != nil { + return stacktrace.Propagate(err, "Failed to decode memstore snapshot") + } + if env.Version != snapshotVersion { + return stacktrace.NewError("Unsupported memstore snapshot version %d, expected %d", env.Version, snapshotVersion) + } + r.state = env.State + // gob decodes an empty map as nil; re-initialize to keep the repo writable. + if r.state.Constraints == nil { + r.state.Constraints = map[dssmodels.ID]*constraintRecord{} + } + if r.state.Subscriptions == nil { + r.state.Subscriptions = map[dssmodels.ID]*subscriptionRecord{} + } + if r.state.OperationalIntents == nil { + r.state.OperationalIntents = map[dssmodels.ID]*operationalIntentRecord{} + } + if r.state.Availabilities == nil { + r.state.Availabilities = map[dssmodels.Manager]*availabilityRecord{} + } + return nil +} diff --git a/pkg/scd/store/memstore/snapshot_test.go b/pkg/scd/store/memstore/snapshot_test.go new file mode 100644 index 000000000..5f7d24380 --- /dev/null +++ b/pkg/scd/store/memstore/snapshot_test.go @@ -0,0 +1,99 @@ +package memstore + +import ( + "bytes" + "encoding/gob" + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "github.com/stretchr/testify/require" +) + +func TestSnapshotRoundTrip(t *testing.T) { + ctx := writeCtx() + src := setUpStore(t) + _, err := src.UpsertConstraint(ctx, sampleConstraint()) + require.NoError(t, err) + _, err = src.UpsertSubscription(ctx, sampleSubscription()) + require.NoError(t, err) + _, err = src.UpsertOperationalIntent(ctx, sampleOperationalIntent()) + require.NoError(t, err) + _, err = src.UpsertUssAvailability(ctx, sampleAvailability()) + require.NoError(t, err) + + data, err := src.GetSnapshot() + require.NoError(t, err) + + dst := setUpStore(t) + require.NoError(t, dst.RestoreFromSnapshot(data)) + + opt := cmpopts.EquateApproxTime(0) + + wantCon, err := src.GetConstraint(ctx, constraintId) + require.NoError(t, err) + gotCon, err := dst.GetConstraint(ctx, constraintId) + require.NoError(t, err) + if diff := cmp.Diff(wantCon, gotCon, opt); diff != "" { + t.Errorf("Constraint mismatch (-want +got):\n%s", diff) + } + + wantSub, err := src.GetSubscription(ctx, subscriptionId) + require.NoError(t, err) + gotSub, err := dst.GetSubscription(ctx, subscriptionId) + require.NoError(t, err) + if diff := cmp.Diff(wantSub, gotSub, opt); diff != "" { + t.Errorf("Subscription mismatch (-want +got):\n%s", diff) + } + + wantOI, err := src.GetOperationalIntent(ctx, operationalIntentId) + require.NoError(t, err) + gotOI, err := dst.GetOperationalIntent(ctx, operationalIntentId) + require.NoError(t, err) + if diff := cmp.Diff(wantOI, gotOI, opt); diff != "" { + t.Errorf("OperationalIntent mismatch (-want +got):\n%s", diff) + } + + wantAvail, err := src.GetUssAvailability(ctx, manager) + require.NoError(t, err) + gotAvail, err := dst.GetUssAvailability(ctx, manager) + require.NoError(t, err) + if diff := cmp.Diff(wantAvail, gotAvail, opt); diff != "" { + t.Errorf("UssAvailability mismatch (-want +got):\n%s", diff) + } +} + +func TestRestoreFromSnapshotReplacesState(t *testing.T) { + ctx := writeCtx() + src := setUpStore(t) + _, err := src.UpsertConstraint(ctx, sampleConstraint()) + require.NoError(t, err) + data, err := src.GetSnapshot() + require.NoError(t, err) + + dst := setUpStore(t) + other := sampleConstraint() + other.ID = "00000185-e36d-40be-8d38-beca6ca39999" + _, err = dst.UpsertConstraint(ctx, other) + require.NoError(t, err) + require.NoError(t, dst.RestoreFromSnapshot(data)) + + count, err := dst.CountConstraints(ctx) + require.NoError(t, err) + require.Equal(t, int64(1), count) + got, err := dst.GetConstraint(ctx, constraintId) + require.NoError(t, err) + require.NotNil(t, got) + _, err = dst.GetConstraint(ctx, other.ID) + require.Error(t, err) +} + +func TestRestoreFromSnapshotInvalidData(t *testing.T) { + require.Error(t, setUpStore(t).RestoreFromSnapshot([]byte("random value that is definitely not valid"))) +} + +func TestRestoreFromSnapshotVersionMismatch(t *testing.T) { + var buf bytes.Buffer + require.NoError(t, gob.NewEncoder(&buf).Encode(snapshotEnvelope{Version: snapshotVersion + 1})) + require.Error(t, setUpStore(t).RestoreFromSnapshot(buf.Bytes())) +} diff --git a/pkg/scd/store/memstore/store.go b/pkg/scd/store/memstore/store.go new file mode 100644 index 000000000..f2d6e7e0c --- /dev/null +++ b/pkg/scd/store/memstore/store.go @@ -0,0 +1,218 @@ +package memstore + +import ( + "context" + "time" + + "github.com/golang/geo/s2" + "github.com/interuss/dss/pkg/memstore" + dssmodels "github.com/interuss/dss/pkg/models" + scdmodels "github.com/interuss/dss/pkg/scd/models" + "github.com/interuss/dss/pkg/scd/repos" + "github.com/interuss/stacktrace" + "go.uber.org/zap" +) + +// repo is a full implementation of scd.repos.Repository for memory-based storage. +type repo struct { + state state +} + +// state is the serializable in-memory state. +type state struct { + // Constraints holds the stored constraints keyed by ID. + Constraints map[dssmodels.ID]*constraintRecord + // Subscriptions holds the stored subscriptions keyed by ID. + Subscriptions map[dssmodels.ID]*subscriptionRecord + // OperationalIntents holds the stored operational intents keyed by ID. + OperationalIntents map[dssmodels.ID]*operationalIntentRecord + // Availabilities holds the stored USS availabilities keyed by USS Manager. + Availabilities map[dssmodels.Manager]*availabilityRecord +} + +// constraintRecord is the gob-serializable representation of a Constraint. The +// model's OVN is never persisted: it is derived from UpdatedAt on read +type constraintRecord struct { + ID dssmodels.ID + Manager dssmodels.Manager + Version scdmodels.VersionNumber + StartTime *time.Time + EndTime *time.Time + USSBaseURL string + AltitudeLower *float32 + AltitudeUpper *float32 + Cells s2.CellUnion + UpdatedAt time.Time +} + +// subscriptionRecord is the gob-serializable representation of a Subscription. +// The sqlstore stores the version column but always writes 0 and discards it on +// read (the model Version is derived from UpdatedAt), so it is not kept here. +type subscriptionRecord struct { + ID dssmodels.ID + Manager dssmodels.Manager + NotificationIndex int + USSBaseURL string + NotifyForOperationalIntents bool + NotifyForConstraints bool + ImplicitSubscription bool + StartTime *time.Time + EndTime *time.Time + Cells s2.CellUnion + UpdatedAt time.Time +} + +// operationalIntentRecord is the gob-serializable representation of an +// OperationalIntent. USSRequestedOVN is empty when the OVN is DSS-generated. +type operationalIntentRecord struct { + ID dssmodels.ID + Manager dssmodels.Manager + Version scdmodels.VersionNumber + State scdmodels.OperationalIntentState + StartTime *time.Time + EndTime *time.Time + USSBaseURL string + SubscriptionID *dssmodels.ID + AltitudeLower *float32 + AltitudeUpper *float32 + Cells s2.CellUnion + USSRequestedOVN string + PastOVNs []scdmodels.OVN + UpdatedAt time.Time +} + +// availabilityRecord is the gob-serializable representation of a +// UssAvailabilityStatus. The model's Version is derived from UpdatedAt on read. +type availabilityRecord struct { + Uss dssmodels.Manager + Availability scdmodels.UssAvailabilityState + UpdatedAt time.Time +} + +func newRepo() *repo { + r := &repo{} + r.resetState() + return r +} + +func (r *repo) resetState() { + r.state = state{ + Constraints: map[dssmodels.ID]*constraintRecord{}, + Subscriptions: map[dssmodels.ID]*subscriptionRecord{}, + OperationalIntents: map[dssmodels.ID]*operationalIntentRecord{}, + Availabilities: map[dssmodels.Manager]*availabilityRecord{}, + } +} + +func Init(ctx context.Context, logger *zap.Logger) (*memstore.Store[repos.Repository], error) { + return memstore.Init(ctx, logger, "scd", newRepo()) +} + +func (r *repo) GetRepo() repos.Repository { return r } + +// cellSet builds a lookup set from a cell union. +func cellSet(cells s2.CellUnion) map[s2.CellID]struct{} { + set := make(map[s2.CellID]struct{}, len(cells)) + for _, c := range cells { + set[c] = struct{}{} + } + return set +} + +// overlaps reports whether any cell is present in set (equivalent to the SQL +// "cells && $x" array-overlap operator). +func overlaps(cells s2.CellUnion, set map[s2.CellID]struct{}) bool { + for _, c := range cells { + if _, ok := set[c]; ok { + return true + } + } + return false +} + +func cloneCells(cells s2.CellUnion) s2.CellUnion { + if cells == nil { + return nil + } + return append(s2.CellUnion(nil), cells...) +} + +func cloneTime(t *time.Time) *time.Time { + if t == nil { + return nil + } + v := *t + return &v +} + +func cloneFloat32(f *float32) *float32 { + if f == nil { + return nil + } + v := *f + return &v +} + +func cloneID(id *dssmodels.ID) *dssmodels.ID { + if id == nil { + return nil + } + v := *id + return &v +} + +func clonePastOVNs(ovns []scdmodels.OVN) []scdmodels.OVN { + if ovns == nil { + return nil + } + return append([]scdmodels.OVN(nil), ovns...) +} + +// clone returns a copy of s with independent maps and records. Cell slices, +// time pointers and OVN slices are shared, as they are never mutated in place. +func (s state) clone() state { + constraints := make(map[dssmodels.ID]*constraintRecord, len(s.Constraints)) + for id, rec := range s.Constraints { + cp := *rec + constraints[id] = &cp + } + subs := make(map[dssmodels.ID]*subscriptionRecord, len(s.Subscriptions)) + for id, rec := range s.Subscriptions { + cp := *rec + subs[id] = &cp + } + ois := make(map[dssmodels.ID]*operationalIntentRecord, len(s.OperationalIntents)) + for id, rec := range s.OperationalIntents { + cp := *rec + ois[id] = &cp + } + avails := make(map[dssmodels.Manager]*availabilityRecord, len(s.Availabilities)) + for id, rec := range s.Availabilities { + cp := *rec + avails[id] = &cp + } + return state{ + Constraints: constraints, + Subscriptions: subs, + OperationalIntents: ois, + Availabilities: avails, + } +} + +// Checkpoint returns a fast, restorable in-memory copy of the current state. +// Unlike GetSnapshot it does not serialize, so it is cheap but only valid +// in-process. +func (r *repo) Checkpoint() any { + return r.state.clone() +} + +// Restore replaces the current state with a checkpoint previously returned by +// Checkpoint. The checkpoint is copied, so it stays reusable. +func (r *repo) Restore(cp any) error { + s, ok := cp.(state) + if !ok { + return stacktrace.NewError("Invalid checkpoint type %T", cp) + } + r.state = s.clone() + return nil +} diff --git a/pkg/scd/store/memstore/store_test.go b/pkg/scd/store/memstore/store_test.go new file mode 100644 index 000000000..b60afdafe --- /dev/null +++ b/pkg/scd/store/memstore/store_test.go @@ -0,0 +1,170 @@ +package memstore + +import ( + "context" + "testing" + "time" + + "github.com/golang/geo/s2" + dssmodels "github.com/interuss/dss/pkg/models" + scdmodels "github.com/interuss/dss/pkg/scd/models" + "github.com/interuss/dss/pkg/timestamp" + "github.com/stretchr/testify/require" +) + +var ( + manager = dssmodels.Manager("unittest") + + constraintId = dssmodels.ID("00000185-e36d-40be-8d38-beca6ca31000") + subscriptionId = dssmodels.ID("00000185-e36d-40be-8d38-beca6ca31001") + operationalIntentId = dssmodels.ID("00000185-e36d-40be-8d38-beca6ca31002") + + cells = s2.CellUnion{ + s2.CellID(int64(8768904281496485888)), + s2.CellID(int64(8768904178417270784)), + } + + startTime = time.Date(2024, time.August, 14, 15, 48, 36, 0, time.UTC) + endTime = startTime.Add(time.Hour) + writeTime = time.Date(2024, time.August, 1, 0, 0, 0, 0, time.UTC) + + altLow, altHigh float32 = 84, 169 +) + +// setUpStore returns a fresh in-memory repo. +func setUpStore(t *testing.T) *repo { + t.Helper() + return newRepo() +} + +// writeCtx returns a context carrying a deterministic write timestamp so that +// updated_at is controlled in tests. +func writeCtx() context.Context { + return timestamp.WithTimestamp(context.Background(), writeTime) +} + +func sampleConstraint() *scdmodels.Constraint { + return &scdmodels.Constraint{ + ID: constraintId, + Manager: manager, + Version: 1, + StartTime: &startTime, + EndTime: &endTime, + USSBaseURL: "https://dummy.uss", + AltitudeLower: &altLow, + AltitudeUpper: &altHigh, + Cells: cells, + } +} + +func sampleSubscription() *scdmodels.Subscription { + return &scdmodels.Subscription{ + ID: subscriptionId, + Manager: manager, + NotificationIndex: 1, + USSBaseURL: "https://dummy.uss", + NotifyForOperationalIntents: true, + NotifyForConstraints: true, + StartTime: &startTime, + EndTime: &endTime, + Cells: cells, + } +} + +func sampleOperationalIntent() *scdmodels.OperationalIntent { + sid := subscriptionId + return &scdmodels.OperationalIntent{ + ID: operationalIntentId, + Manager: manager, + Version: 1, + State: scdmodels.OperationalIntentStateAccepted, + StartTime: &startTime, + EndTime: &endTime, + USSBaseURL: "https://dummy.uss", + SubscriptionID: &sid, + AltitudeLower: &altLow, + AltitudeUpper: &altHigh, + Cells: cells, + } +} + +func sampleAvailability() *scdmodels.UssAvailabilityStatus { + return &scdmodels.UssAvailabilityStatus{ + Uss: manager, + Availability: scdmodels.UssAvailabilityStateNormal, + } +} + +// volume4D builds a Volume4D whose footprint covers the provided cells. +func volume4D(cu s2.CellUnion, start, end *time.Time, altLo, altHi *float32) *dssmodels.Volume4D { + return &dssmodels.Volume4D{ + StartTime: start, + EndTime: end, + SpatialVolume: &dssmodels.Volume3D{ + AltitudeLo: altLo, + AltitudeHi: altHi, + Footprint: dssmodels.GeometryFunc(func() (s2.CellUnion, error) { + return cu, nil + }), + }, + } +} + +func TestCheckpointRestoreRoundTrip(t *testing.T) { + ctx := writeCtx() + r := setUpStore(t) + + _, err := r.UpsertConstraint(ctx, sampleConstraint()) + require.NoError(t, err) + _, err = r.UpsertSubscription(ctx, sampleSubscription()) + require.NoError(t, err) + _, err = r.UpsertOperationalIntent(ctx, sampleOperationalIntent()) + require.NoError(t, err) + _, err = r.UpsertUssAvailability(ctx, sampleAvailability()) + require.NoError(t, err) + + cp := r.Checkpoint() + + // Mutate after the checkpoint. + require.NoError(t, r.DeleteConstraint(ctx, constraintId)) + require.NoError(t, r.DeleteSubscription(ctx, subscriptionId)) + require.NoError(t, r.DeleteOperationalIntent(ctx, operationalIntentId)) + + // Restore brings everything back. + require.NoError(t, r.Restore(cp)) + + con, err := r.GetConstraint(ctx, constraintId) + require.NoError(t, err) + require.NotNil(t, con) + sub, err := r.GetSubscription(ctx, subscriptionId) + require.NoError(t, err) + require.NotNil(t, sub) + oi, err := r.GetOperationalIntent(ctx, operationalIntentId) + require.NoError(t, err) + require.NotNil(t, oi) +} + +func TestCheckpointIsolatesNotificationIndex(t *testing.T) { + ctx := writeCtx() + r := setUpStore(t) + + sub, err := r.UpsertSubscription(ctx, sampleSubscription()) + require.NoError(t, err) + + cp := r.Checkpoint() + + // In-place notification-index bump must not leak into the checkpoint. + bumped, err := r.IncrementNotificationIndicesForOperationalIntents(ctx, volume4D(cells, nil, nil, nil, nil)) + require.NoError(t, err) + require.Len(t, bumped, 1) + require.Equal(t, sub.NotificationIndex+1, bumped[0].NotificationIndex) + + require.NoError(t, r.Restore(cp)) + restored, err := r.GetSubscription(ctx, subscriptionId) + require.NoError(t, err) + require.Equal(t, sub.NotificationIndex, restored.NotificationIndex) +} + +func TestRestoreInvalidType(t *testing.T) { + require.Error(t, setUpStore(t).Restore("not a checkpoint")) +} diff --git a/pkg/scd/store/memstore/subscriptions.go b/pkg/scd/store/memstore/subscriptions.go new file mode 100644 index 000000000..7a483aa94 --- /dev/null +++ b/pkg/scd/store/memstore/subscriptions.go @@ -0,0 +1,197 @@ +package memstore + +import ( + "context" + "time" + + "github.com/golang/geo/s2" + dssmodels "github.com/interuss/dss/pkg/models" + scdmodels "github.com/interuss/dss/pkg/scd/models" + "github.com/interuss/dss/pkg/timestamp" + "github.com/interuss/stacktrace" +) + +func (rec *subscriptionRecord) toModel() *scdmodels.Subscription { + return &scdmodels.Subscription{ + ID: rec.ID, + Version: scdmodels.NewOVNFromTime(rec.UpdatedAt, rec.ID.String()), + NotificationIndex: rec.NotificationIndex, + Manager: rec.Manager, + StartTime: cloneTime(rec.StartTime), + EndTime: cloneTime(rec.EndTime), + USSBaseURL: rec.USSBaseURL, + NotifyForOperationalIntents: rec.NotifyForOperationalIntents, + NotifyForConstraints: rec.NotifyForConstraints, + ImplicitSubscription: rec.ImplicitSubscription, + Cells: cloneCells(rec.Cells), + } +} + +// SearchSubscriptions implements scd.repos.Subscription.SearchSubscriptions. +func (r *repo) SearchSubscriptions(_ context.Context, v4d *dssmodels.Volume4D) ([]*scdmodels.Subscription, error) { + cells, err := v4d.CalculateSpatialCovering() + if err != nil { + return nil, stacktrace.Propagate(err, "Could not calculate spatial covering") + } + if len(cells) == 0 { + return nil, nil + } + + want := cellSet(cells) + var out []*scdmodels.Subscription + for _, rec := range r.state.Subscriptions { + if !overlaps(rec.Cells, want) { + continue + } + // COALESCE(starts_at <= $3, true) with $3 = v4d.EndTime + if rec.StartTime != nil && v4d.EndTime != nil && rec.StartTime.After(*v4d.EndTime) { + continue + } + // COALESCE(ends_at >= $2, true) with $2 = v4d.StartTime + if rec.EndTime != nil && v4d.StartTime != nil && rec.EndTime.Before(*v4d.StartTime) { + continue + } + out = append(out, rec.toModel()) + if len(out) >= dssmodels.MaxResultLimit { // mirror SQL "LIMIT MaxResultLimit" + break + } + } + return out, nil +} + +// GetSubscription implements scd.repos.Subscription.GetSubscription. +func (r *repo) GetSubscription(_ context.Context, id dssmodels.ID) (*scdmodels.Subscription, error) { + rec, ok := r.state.Subscriptions[id] + if !ok { + return nil, nil + } + return rec.toModel(), nil +} + +// UpsertSubscription implements scd.repos.Subscription.UpsertSubscription. +func (r *repo) UpsertSubscription(ctx context.Context, s *scdmodels.Subscription) (*scdmodels.Subscription, error) { + rec := &subscriptionRecord{ + ID: s.ID, + Manager: s.Manager, + NotificationIndex: s.NotificationIndex, + USSBaseURL: s.USSBaseURL, + NotifyForOperationalIntents: s.NotifyForOperationalIntents, + NotifyForConstraints: s.NotifyForConstraints, + ImplicitSubscription: s.ImplicitSubscription, + StartTime: cloneTime(s.StartTime), + EndTime: cloneTime(s.EndTime), + Cells: cloneCells(s.Cells), + UpdatedAt: timestamp.NowFromContext(ctx), + } + r.state.Subscriptions[s.ID] = rec + return rec.toModel(), nil +} + +// DeleteSubscription implements scd.repos.Subscription.DeleteSubscription. +func (r *repo) DeleteSubscription(_ context.Context, id dssmodels.ID) error { + if _, ok := r.state.Subscriptions[id]; !ok { + return stacktrace.NewError("Attempted to delete non-existent Subscription") + } + delete(r.state.Subscriptions, id) + return nil +} + +// IncrementNotificationIndicesForOperationalIntents implements +// scd.repos.Subscription.IncrementNotificationIndicesForOperationalIntents. +func (r *repo) IncrementNotificationIndicesForOperationalIntents(_ context.Context, v4d *dssmodels.Volume4D) ([]*scdmodels.Subscription, error) { + cells, err := v4d.CalculateSpatialCovering() + if err != nil { + return nil, stacktrace.Propagate(err, "Could not calculate spatial covering") + } + if len(cells) == 0 { + return nil, nil + } + + want := cellSet(cells) + var out []*scdmodels.Subscription + for _, rec := range r.state.Subscriptions { + if !overlaps(rec.Cells, want) { + continue + } + if !rec.NotifyForOperationalIntents { + continue + } + // COALESCE(starts_at <= $3, true) with $3 = v4d.EndTime + if rec.StartTime != nil && v4d.EndTime != nil && rec.StartTime.After(*v4d.EndTime) { + continue + } + // COALESCE(ends_at >= $2, true) with $2 = v4d.StartTime + if rec.EndTime != nil && v4d.StartTime != nil && rec.EndTime.Before(*v4d.StartTime) { + continue + } + rec.NotificationIndex++ + out = append(out, rec.toModel()) + } + return out, nil +} + +// IncrementNotificationIndicesForConstraints implements +// scd.repos.Subscription.IncrementNotificationIndicesForConstraints. +func (r *repo) IncrementNotificationIndicesForConstraints(_ context.Context, v4d *dssmodels.Volume4D) ([]*scdmodels.Subscription, error) { + cells, err := v4d.CalculateSpatialCovering() + if err != nil { + return nil, stacktrace.Propagate(err, "Could not calculate spatial covering") + } + if len(cells) == 0 { + return nil, nil + } + + want := cellSet(cells) + var out []*scdmodels.Subscription + for _, rec := range r.state.Subscriptions { + if !overlaps(rec.Cells, want) { + continue + } + if !rec.NotifyForConstraints { + continue + } + // COALESCE(starts_at <= $3, true) with $3 = v4d.EndTime + if rec.StartTime != nil && v4d.EndTime != nil && rec.StartTime.After(*v4d.EndTime) { + continue + } + // COALESCE(ends_at >= $2, true) with $2 = v4d.StartTime + if rec.EndTime != nil && v4d.StartTime != nil && rec.EndTime.Before(*v4d.StartTime) { + continue + } + rec.NotificationIndex++ + out = append(out, rec.toModel()) + } + return out, nil +} + +// LockSubscriptionsOnCells implements scd.repos.Subscription.LockSubscriptionsOnCells. +func (r *repo) LockSubscriptionsOnCells(_ context.Context, _ s2.CellUnion, _ []dssmodels.ID, _ *time.Time, _ *time.Time) error { + // For the memory store, that a no-op + return nil +} + +// ListExpiredSubscriptions implements scd.repos.Subscription.ListExpiredSubscriptions. +func (r *repo) ListExpiredSubscriptions(_ context.Context, threshold time.Time) ([]*scdmodels.Subscription, error) { + var out []*scdmodels.Subscription + for _, rec := range r.state.Subscriptions { + // (ends_at IS NOT NULL AND ends_at <= threshold) OR (ends_at IS NULL AND updated_at <= threshold) + var expired bool + if rec.EndTime != nil { + expired = !rec.EndTime.After(threshold) + } else { + expired = !rec.UpdatedAt.After(threshold) + } + if !expired { + continue + } + out = append(out, rec.toModel()) + if len(out) >= dssmodels.MaxResultLimit { // mirror SQL "LIMIT MaxResultLimit" + break + } + } + return out, nil +} + +func (r *repo) CountSubscriptions(_ context.Context) (int64, error) { + return int64(len(r.state.Subscriptions)), nil +} diff --git a/pkg/scd/store/memstore/subscriptions_test.go b/pkg/scd/store/memstore/subscriptions_test.go new file mode 100644 index 000000000..2176e411b --- /dev/null +++ b/pkg/scd/store/memstore/subscriptions_test.go @@ -0,0 +1,234 @@ +package memstore + +import ( + "testing" + "time" + + "github.com/golang/geo/s2" + dssmodels "github.com/interuss/dss/pkg/models" + scdmodels "github.com/interuss/dss/pkg/scd/models" + "github.com/stretchr/testify/require" +) + +func TestSubscriptionUpsertGetDelete(t *testing.T) { + ctx := writeCtx() + r := setUpStore(t) + + got, err := r.UpsertSubscription(ctx, sampleSubscription()) + require.NoError(t, err) + require.Equal(t, subscriptionId, got.ID) + require.Equal(t, 1, got.NotificationIndex) + require.NotEmpty(t, got.Version) + + fetched, err := r.GetSubscription(ctx, subscriptionId) + require.NoError(t, err) + require.Equal(t, got.Version, fetched.Version) + require.True(t, fetched.NotifyForOperationalIntents) + + count, err := r.CountSubscriptions(ctx) + require.NoError(t, err) + require.Equal(t, int64(1), count) + + require.NoError(t, r.DeleteSubscription(ctx, subscriptionId)) + gone, err := r.GetSubscription(ctx, subscriptionId) + require.NoError(t, err) + require.Nil(t, gone) +} + +func TestSubscriptionGetMissingReturnsNil(t *testing.T) { + r := setUpStore(t) + got, err := r.GetSubscription(writeCtx(), subscriptionId) + require.NoError(t, err) + require.Nil(t, got) +} + +func TestSubscriptionDeleteMissingErrors(t *testing.T) { + r := setUpStore(t) + require.Error(t, r.DeleteSubscription(writeCtx(), subscriptionId)) +} + +func TestSearchSubscriptions(t *testing.T) { + ctx := writeCtx() + r := setUpStore(t) + _, err := r.UpsertSubscription(ctx, sampleSubscription()) + require.NoError(t, err) + + res, err := r.SearchSubscriptions(ctx, volume4D(cells, nil, nil, nil, nil)) + require.NoError(t, err) + require.Len(t, res, 1) + + // No covering cells returns nil. + res, err = r.SearchSubscriptions(ctx, volume4D(s2.CellUnion{}, nil, nil, nil, nil)) + require.NoError(t, err) + require.Nil(t, res) +} + +func TestIncrementNotificationIndicesForOperationalIntents(t *testing.T) { + ctx := writeCtx() + r := setUpStore(t) + + // notify_for_operations = true. + opSub := sampleSubscription() + _, err := r.UpsertSubscription(ctx, opSub) + require.NoError(t, err) + + // A second subscription that only wants constraint notifications must be skipped. + conSub := sampleSubscription() + conSub.ID = "00000185-e36d-40be-8d38-beca6ca31aaa" + conSub.NotifyForOperationalIntents = false + conSub.NotifyForConstraints = true + _, err = r.UpsertSubscription(ctx, conSub) + require.NoError(t, err) + + got, err := r.IncrementNotificationIndicesForOperationalIntents(ctx, volume4D(cells, nil, nil, nil, nil)) + require.NoError(t, err) + require.Len(t, got, 1) + require.Equal(t, opSub.ID, got[0].ID) + require.Equal(t, opSub.NotificationIndex+1, got[0].NotificationIndex) + + // The bump is persisted. + fetched, err := r.GetSubscription(ctx, opSub.ID) + require.NoError(t, err) + require.Equal(t, opSub.NotificationIndex+1, fetched.NotificationIndex) + + // The constraint-only subscription was untouched. + other, err := r.GetSubscription(ctx, conSub.ID) + require.NoError(t, err) + require.Equal(t, conSub.NotificationIndex, other.NotificationIndex) + + // No covering cells returns nil. + got, err = r.IncrementNotificationIndicesForOperationalIntents(ctx, volume4D(s2.CellUnion{}, nil, nil, nil, nil)) + require.NoError(t, err) + require.Nil(t, got) +} + +func TestIncrementNotificationIndicesForConstraints(t *testing.T) { + ctx := writeCtx() + r := setUpStore(t) + + // notify_for_constraints = true (sample sets both notify flags). + conSub := sampleSubscription() + _, err := r.UpsertSubscription(ctx, conSub) + require.NoError(t, err) + + // A subscription that does not want constraint notifications must be skipped. + opSub := sampleSubscription() + opSub.ID = "00000185-e36d-40be-8d38-beca6ca31bbb" + opSub.NotifyForConstraints = false + _, err = r.UpsertSubscription(ctx, opSub) + require.NoError(t, err) + + got, err := r.IncrementNotificationIndicesForConstraints(ctx, volume4D(cells, nil, nil, nil, nil)) + require.NoError(t, err) + require.Len(t, got, 1) + require.Equal(t, conSub.ID, got[0].ID) + require.Equal(t, conSub.NotificationIndex+1, got[0].NotificationIndex) + + other, err := r.GetSubscription(ctx, opSub.ID) + require.NoError(t, err) + require.Equal(t, opSub.NotificationIndex, other.NotificationIndex) +} + +func TestLockSubscriptionsOnCellsNoop(t *testing.T) { + r := setUpStore(t) + require.NoError(t, r.LockSubscriptionsOnCells(writeCtx(), cells, []dssmodels.ID{subscriptionId}, nil, nil)) +} + +var ( + sub1ID = dssmodels.ID("189ec22f-5e61-418a-940b-36de2d201fd5") + sub2ID = dssmodels.ID("78f98cc5-94f3-4c04-8da9-a8398feba3f3") + sub3ID = dssmodels.ID("9f0d4575-b275-4a4c-a261-e1e04d324565") +) + +var ( + sub1 = &scdmodels.Subscription{ + ID: sub1ID, + NotificationIndex: 1, + Manager: "unittest", + StartTime: &start1, + EndTime: &end1, + USSBaseURL: "https://dummy.uss", + NotifyForOperationalIntents: true, + NotifyForConstraints: false, + ImplicitSubscription: true, + Cells: cells, + } + sub2 = &scdmodels.Subscription{ + ID: sub2ID, + NotificationIndex: 1, + Manager: "unittest", + StartTime: &start2, + EndTime: &end2, + USSBaseURL: "https://dummy.uss", + NotifyForOperationalIntents: true, + NotifyForConstraints: false, + ImplicitSubscription: true, + Cells: cells, + } + sub3 = &scdmodels.Subscription{ + ID: sub3ID, + NotificationIndex: 1, + Manager: "unittest", + StartTime: &start3, + EndTime: &end3, + USSBaseURL: "https://dummy.uss", + NotifyForOperationalIntents: true, + NotifyForConstraints: false, + ImplicitSubscription: true, + Cells: cells, + } +) + +func TestListExpiredSubscriptions(t *testing.T) { + ctx := writeCtx() + r := setUpStore(t) + + _, err := r.UpsertSubscription(ctx, sub1) + require.NoError(t, err) + + _, err = r.UpsertSubscription(ctx, sub2) + require.NoError(t, err) + + _, err = r.UpsertSubscription(ctx, sub3) + require.NoError(t, err) + + testCases := []struct { + name string + timeRef time.Time + ttl time.Duration + expired []dssmodels.ID + }{{ + name: "none expired, one in close past", + timeRef: time.Date(2024, time.August, 25, 15, 0, 0, 0, time.UTC), + ttl: time.Hour * 24 * 30, + expired: []dssmodels.ID{}, + }, { + name: "one recently expired, one current, one in future", + timeRef: time.Date(2024, time.September, 15, 16, 0, 0, 0, time.UTC), + ttl: time.Hour * 24 * 30, + expired: []dssmodels.ID{sub1ID}, + }, { + name: "two expired, one in future", + timeRef: time.Date(2024, time.September, 16, 16, 0, 0, 0, time.UTC), + ttl: time.Hour * 2, + expired: []dssmodels.ID{sub1ID, sub2ID}, + }, { + name: "all expired", + timeRef: time.Date(2024, time.December, 15, 15, 0, 0, 0, time.UTC), + ttl: time.Hour * 24 * 30, + expired: []dssmodels.ID{sub1ID, sub2ID, sub3ID}, + }} + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + threshold := testCase.timeRef.Add(-testCase.ttl) + expired, err := r.ListExpiredSubscriptions(ctx, threshold) + require.NoError(t, err) + + expiredIDs := make([]dssmodels.ID, 0, len(expired)) + for _, expiredSub := range expired { + expiredIDs = append(expiredIDs, expiredSub.ID) + } + require.ElementsMatch(t, expiredIDs, testCase.expired) + }) + } +} diff --git a/pkg/scd/store/raftstore/availability.go b/pkg/scd/store/raftstore/availability.go index 7e45f9c4b..8c361f0b2 100644 --- a/pkg/scd/store/raftstore/availability.go +++ b/pkg/scd/store/raftstore/availability.go @@ -3,16 +3,43 @@ package raftstore import ( "context" - dsserr "github.com/interuss/dss/pkg/errors" dssmodels "github.com/interuss/dss/pkg/models" scdmodels "github.com/interuss/dss/pkg/scd/models" "github.com/interuss/stacktrace" ) -func (r *repo) GetUssAvailability(_ context.Context, id dssmodels.Manager) (*scdmodels.UssAvailabilityStatus, error) { - return nil, stacktrace.NewErrorWithCode(dsserr.NotImplemented, "GetUssAvailability not implemented for raftstore") +func (r *repo) GetUssAvailability(ctx context.Context, id dssmodels.Manager) (*scdmodels.UssAvailabilityStatus, error) { + result, err := r.consensus.ProposeValue(ctx, string(getUSSAvailability), id, true) + if err != nil { + return nil, stacktrace.Propagate(err, "failed to propose getUSSAvailability") + } + + if result == nil { + return nil, nil + } + + status, ok := result.(*scdmodels.UssAvailabilityStatus) + if !ok { + return nil, stacktrace.NewError("invalid result type: %T", result) + } + + return status, nil } -func (r *repo) UpsertUssAvailability(_ context.Context, ussa *scdmodels.UssAvailabilityStatus) (*scdmodels.UssAvailabilityStatus, error) { - return nil, stacktrace.NewErrorWithCode(dsserr.NotImplemented, "UpsertUssAvailability not implemented for raftstore") +func (r *repo) UpsertUssAvailability(ctx context.Context, ussa *scdmodels.UssAvailabilityStatus) (*scdmodels.UssAvailabilityStatus, error) { + result, err := r.consensus.ProposeValue(ctx, string(upsertUSSAvailability), ussa, false) + if err != nil { + return nil, stacktrace.Propagate(err, "failed to propose upsertUSSAvailability") + } + + if result == nil { + return nil, nil + } + + status, ok := result.(*scdmodels.UssAvailabilityStatus) + if !ok { + return nil, stacktrace.NewError("invalid result type: %T", result) + } + + return status, nil } diff --git a/pkg/scd/store/raftstore/availability_appliers.go b/pkg/scd/store/raftstore/availability_appliers.go new file mode 100644 index 000000000..d3ef4ece0 --- /dev/null +++ b/pkg/scd/store/raftstore/availability_appliers.go @@ -0,0 +1,103 @@ +package raftstore + +import ( + "context" + "encoding/json" + + restapi "github.com/interuss/dss/pkg/api/scdv1" + dsserr "github.com/interuss/dss/pkg/errors" + dssmodels "github.com/interuss/dss/pkg/models" + "github.com/interuss/dss/pkg/raftstore/consensus" + scdmodels "github.com/interuss/dss/pkg/scd/models" + "github.com/interuss/dss/pkg/scd/repos" + "github.com/interuss/stacktrace" + "github.com/jackc/pgx/v5" +) + +func (r *repo) getUSSAvailabilityTransactionApplier(ctx context.Context, proposal consensus.Proposal, mem repos.Repository) (*restapi.UssAvailabilityStatusResponse, error) { + var req *restapi.GetUssAvailabilityRequest + if err := json.Unmarshal(proposal.Value, &req); err != nil { + return nil, stacktrace.Propagate(err, "failed to unmarshal get USS availability request") + } + + id := dssmodels.ManagerFromString(req.UssId) + if id == "" { + return nil, stacktrace.NewErrorWithCode(dsserr.BadRequest, "UssId not provided") + } + + ussa, err := mem.GetUssAvailability(ctx, id) + if err != nil && err != pgx.ErrNoRows { + return nil, stacktrace.Propagate(err, "Could not get USS availability from repo") + } + if ussa == nil { + return &restapi.UssAvailabilityStatusResponse{ + Status: restapi.UssAvailabilityStatus{ + Availability: restapi.UssAvailabilityState_Unknown, + Uss: id.String(), + }, + Version: "", + }, nil + } + + return &restapi.UssAvailabilityStatusResponse{ + Status: *ussa.ToRest(), + Version: ussa.Version.String(), + }, nil +} + +func (r *repo) setUSSAvailabilityTransactionApplier(ctx context.Context, proposal consensus.Proposal, mem repos.Repository) (*restapi.UssAvailabilityStatusResponse, error) { + var req *restapi.SetUssAvailabilityRequest + if err := json.Unmarshal(proposal.Value, &req); err != nil { + return nil, stacktrace.Propagate(err, "failed to unmarshal set USS availability request") + } + + if req.UssId == "" { + return nil, stacktrace.NewErrorWithCode(dsserr.BadRequest, "ussID not provided") + } + + availability, err := scdmodels.UssAvailabilityStateFromRest(req.Body.Availability) + if err != nil { + return nil, stacktrace.NewErrorWithCode(dsserr.BadRequest, "Invalid availability state") + } + + id := dssmodels.ManagerFromString(req.UssId) + version := scdmodels.OVN(req.Body.OldVersion) + + old, err := mem.GetUssAvailability(ctx, id) + if err != nil && err != pgx.ErrNoRows { + return nil, stacktrace.Propagate(err, "Could not get USS availability from repo") + } + + switch { + case old == nil && !version.Empty(): + return nil, stacktrace.NewErrorWithCode(dsserr.AlreadyExists, "availability for USS %s already exists", id.String()) + case old != nil && old.Version != version: + return nil, stacktrace.Propagate( + stacktrace.NewErrorWithCode(dsserr.VersionMismatch, "USS availability version %s is not current", version), + "Current version is %s but client specified version %s", old.Version, version) + } + + cp := r.memStore.Checkpoint() + + ussa, err := mem.UpsertUssAvailability(ctx, &scdmodels.UssAvailabilityStatus{ + Uss: id, + Availability: availability, + }) + if err != nil { + if restoreErr := r.memStore.Restore(cp); restoreErr != nil { + return nil, stacktrace.Propagate(restoreErr, "Failed to restore store") + } + return nil, stacktrace.Propagate(err, "Could not upsert USS Availability into repo") + } + if ussa == nil { + if restoreErr := r.memStore.Restore(cp); restoreErr != nil { + return nil, stacktrace.Propagate(restoreErr, "Failed to restore store") + } + return nil, stacktrace.NewError("UpsertUssAvailability returned no USS availability for ID: %s", id) + } + + return &restapi.UssAvailabilityStatusResponse{ + Status: *ussa.ToRest(), + Version: ussa.Version.String(), + }, nil +} diff --git a/pkg/scd/store/raftstore/constraints.go b/pkg/scd/store/raftstore/constraints.go index 0983add14..dff8173f0 100644 --- a/pkg/scd/store/raftstore/constraints.go +++ b/pkg/scd/store/raftstore/constraints.go @@ -3,28 +3,84 @@ package raftstore import ( "context" - dsserr "github.com/interuss/dss/pkg/errors" dssmodels "github.com/interuss/dss/pkg/models" scdmodels "github.com/interuss/dss/pkg/scd/models" "github.com/interuss/stacktrace" ) -func (r *repo) SearchConstraints(_ context.Context, v4d *dssmodels.Volume4D) ([]*scdmodels.Constraint, error) { - return nil, stacktrace.NewErrorWithCode(dsserr.NotImplemented, "SearchConstraints not implemented for raftstore") +func (r *repo) SearchConstraints(ctx context.Context, v4d *dssmodels.Volume4D) ([]*scdmodels.Constraint, error) { + result, err := r.consensus.ProposeValue(ctx, string(searchConstraints), v4d, true) + if err != nil { + return nil, stacktrace.Propagate(err, "failed to propose searchConstraints") + } + + if result == nil { + return nil, nil + } + + constraints, ok := result.([]*scdmodels.Constraint) + if !ok { + return nil, stacktrace.NewError("invalid result type: %T", result) + } + + return constraints, nil } -func (r *repo) GetConstraint(_ context.Context, id dssmodels.ID) (*scdmodels.Constraint, error) { - return nil, stacktrace.NewErrorWithCode(dsserr.NotImplemented, "GetConstraint not implemented for raftstore") +func (r *repo) GetConstraint(ctx context.Context, id dssmodels.ID) (*scdmodels.Constraint, error) { + result, err := r.consensus.ProposeValue(ctx, string(getConstraint), id, true) + if err != nil { + return nil, stacktrace.Propagate(err, "failed to propose getConstraint") + } + + if result == nil { + return nil, nil + } + + constraint, ok := result.(*scdmodels.Constraint) + if !ok { + return nil, stacktrace.NewError("invalid result type: %T", result) + } + + return constraint, nil } -func (r *repo) UpsertConstraint(_ context.Context, constraint *scdmodels.Constraint) (*scdmodels.Constraint, error) { - return nil, stacktrace.NewErrorWithCode(dsserr.NotImplemented, "UpsertConstraint not implemented for raftstore") +func (r *repo) UpsertConstraint(ctx context.Context, constraint *scdmodels.Constraint) (*scdmodels.Constraint, error) { + result, err := r.consensus.ProposeValue(ctx, string(upsertConstraint), constraint, false) + if err != nil { + return nil, stacktrace.Propagate(err, "failed to propose upsertConstraint") + } + + if result == nil { + return nil, nil + } + + upserted, ok := result.(*scdmodels.Constraint) + if !ok { + return nil, stacktrace.NewError("invalid result type: %T", result) + } + + return upserted, nil } -func (r *repo) DeleteConstraint(_ context.Context, id dssmodels.ID) error { - return stacktrace.NewErrorWithCode(dsserr.NotImplemented, "DeleteConstraint not implemented for raftstore") +func (r *repo) DeleteConstraint(ctx context.Context, id dssmodels.ID) error { + _, err := r.consensus.ProposeValue(ctx, string(deleteConstraint), id, false) + return err } -func (r *repo) CountConstraints(_ context.Context) (int64, error) { - return 0, stacktrace.NewErrorWithCode(dsserr.NotImplemented, "CountConstraint not implemented for raftstore") +func (r *repo) CountConstraints(ctx context.Context) (int64, error) { + result, err := r.consensus.ProposeValue(ctx, string(countConstraints), nil, true) + if err != nil { + return 0, stacktrace.Propagate(err, "failed to propose countConstraints") + } + + if result == nil { + return 0, nil + } + + count, ok := result.(int64) + if !ok { + return 0, stacktrace.NewError("invalid result type: %T", result) + } + + return count, nil } diff --git a/pkg/scd/store/raftstore/constraints_appliers.go b/pkg/scd/store/raftstore/constraints_appliers.go new file mode 100644 index 000000000..051cd7a35 --- /dev/null +++ b/pkg/scd/store/raftstore/constraints_appliers.go @@ -0,0 +1,264 @@ +package raftstore + +import ( + "context" + "encoding/json" + "time" + + "github.com/golang/geo/s2" + restapi "github.com/interuss/dss/pkg/api/scdv1" + dsserr "github.com/interuss/dss/pkg/errors" + dssmodels "github.com/interuss/dss/pkg/models" + "github.com/interuss/dss/pkg/raftstore/consensus" + scdmodels "github.com/interuss/dss/pkg/scd/models" + "github.com/interuss/dss/pkg/scd/repos" + "github.com/interuss/stacktrace" + "github.com/jackc/pgx/v5" +) + +type UpsertConstraintTransactionPayload struct { + Manager dssmodels.Manager + ID dssmodels.ID + Ovn scdmodels.OVN + USSBaseURL string + StartTime *time.Time + EndTime *time.Time + AltitudeLo *float32 + AltitudeHi *float32 + Cells s2.CellUnion +} + +func (r *repo) deleteConstraintTransactionApplier(ctx context.Context, proposal consensus.Proposal, mem repos.Repository) (*restapi.ChangeConstraintReferenceResponse, error) { + var req *restapi.DeleteConstraintReferenceRequest + if err := json.Unmarshal(proposal.Value, &req); err != nil { + return nil, stacktrace.Propagate(err, "failed to unmarshal delete constraint reference request") + } + + id, err := dssmodels.IDFromString(string(req.Entityid)) + if err != nil { + return nil, stacktrace.NewErrorWithCode(dsserr.BadRequest, "Invalid ID format: `%s`", req.Entityid) + } + + ovn := scdmodels.OVN(req.Ovn) + if ovn == "" { + return nil, stacktrace.NewErrorWithCode(dsserr.BadRequest, "Missing OVN for constraint to modify") + } + + if req.Auth.ClientID == nil { + return nil, stacktrace.NewErrorWithCode(dsserr.PermissionDenied, "Missing manager") + } + + old, err := mem.GetConstraint(ctx, id) + switch { + case err == pgx.ErrNoRows: + return nil, stacktrace.NewErrorWithCode(dsserr.NotFound, "Constraint %s not found", id) + case err != nil: + return nil, stacktrace.Propagate(err, "Unable to get Constraint from repo") + case old == nil: + return nil, stacktrace.NewErrorWithCode(dsserr.NotFound, "Constraint %s not found", id) + case old.Manager != dssmodels.Manager(*req.Auth.ClientID): + return nil, stacktrace.NewErrorWithCode(dsserr.PermissionDenied, + "Constraint owned by %s, but %s attempted to delete", old.Manager, *req.Auth.ClientID) + case old.OVN != ovn: + return nil, stacktrace.NewErrorWithCode(dsserr.VersionMismatch, + "Current version is %s but client specified version %s", old.OVN, ovn) + } + + notifyVolume := &dssmodels.Volume4D{ + StartTime: old.StartTime, + EndTime: old.EndTime, + SpatialVolume: &dssmodels.Volume3D{ + AltitudeHi: old.AltitudeUpper, + AltitudeLo: old.AltitudeLower, + Footprint: dssmodels.GeometryFunc(func() (s2.CellUnion, error) { + return old.Cells, nil + }), + }, + } + + cp := r.memStore.Checkpoint() + + if err := mem.DeleteConstraint(ctx, id); err != nil { + if restoreErr := r.memStore.Restore(cp); restoreErr != nil { + return nil, stacktrace.Propagate(restoreErr, "Failed to restore store") + } + return nil, stacktrace.Propagate(err, "Unable to delete Constraint from repo") + } + + subs, err := mem.IncrementNotificationIndicesForConstraints(ctx, notifyVolume) + if err != nil { + if restoreErr := r.memStore.Restore(cp); restoreErr != nil { + return nil, stacktrace.Propagate(restoreErr, "Failed to restore store") + } + return nil, stacktrace.Propagate(err, "Unable to increment notification indices") + } + + return &restapi.ChangeConstraintReferenceResponse{ + ConstraintReference: *old.ToRest(), + Subscribers: repos.MakeSubscribersToNotify(subs), + }, nil +} + +func (r *repo) getConstraintTransactionApplier(ctx context.Context, proposal consensus.Proposal, mem repos.Repository) (*restapi.GetConstraintReferenceResponse, error) { + var req *restapi.GetConstraintReferenceRequest + if err := json.Unmarshal(proposal.Value, &req); err != nil { + return nil, stacktrace.Propagate(err, "failed to unmarshal get constraint reference request") + } + + id, err := dssmodels.IDFromString(string(req.Entityid)) + if err != nil { + return nil, stacktrace.NewErrorWithCode(dsserr.BadRequest, "Invalid ID format: `%s`", req.Entityid) + } + + if req.Auth.ClientID == nil { + return nil, stacktrace.NewErrorWithCode(dsserr.PermissionDenied, "Missing manager") + } + + constraint, err := mem.GetConstraint(ctx, id) + if err == pgx.ErrNoRows || constraint == nil { + return nil, stacktrace.NewErrorWithCode(dsserr.NotFound, "Constraint %s not found", id) + } + if err != nil { + return nil, stacktrace.Propagate(err, "Unable to get Constraint from repo") + } + + if constraint.Manager != dssmodels.Manager(*req.Auth.ClientID) { + constraint.OVN = scdmodels.NoOvnPhrase + } + + return &restapi.GetConstraintReferenceResponse{ + ConstraintReference: *constraint.ToRest(), + }, nil +} + +func (r *repo) queryConstraintTransactionApplier(ctx context.Context, proposal consensus.Proposal, mem repos.Repository) (*restapi.QueryConstraintReferencesResponse, error) { + var req *restapi.QueryConstraintReferencesRequest + if err := json.Unmarshal(proposal.Value, &req); err != nil { + return nil, stacktrace.Propagate(err, "failed to unmarshal query constraint references request") + } + + if req.Auth.ClientID == nil { + return nil, stacktrace.NewErrorWithCode(dsserr.PermissionDenied, "Missing manager") + } + + vol4, err := dssmodels.Volume4DFromSCDRest(req.Body.AreaOfInterest) + if err != nil { + return nil, stacktrace.PropagateWithCode(err, dsserr.BadRequest, "Error parsing geometry") + } + + constraints, err := mem.SearchConstraints(ctx, vol4) + if err != nil { + return nil, stacktrace.Propagate(err, "Unable to query for Constraints in repo") + } + + response := &restapi.QueryConstraintReferencesResponse{ + ConstraintReferences: make([]restapi.ConstraintReference, 0, len(constraints)), + } + for _, constraint := range constraints { + p := constraint.ToRest() + if constraint.Manager != dssmodels.Manager(*req.Auth.ClientID) { + noOvnPhrase := restapi.EntityOVN(scdmodels.NoOvnPhrase) + p.Ovn = &noOvnPhrase + } + response.ConstraintReferences = append(response.ConstraintReferences, *p) + } + + return response, nil +} + +func (r *repo) upsertConstraintTransactionApplier(ctx context.Context, proposal consensus.Proposal, mem repos.Repository) (*restapi.ChangeConstraintReferenceResponse, error) { + var payload *UpsertConstraintTransactionPayload + if err := json.Unmarshal(proposal.Value, &payload); err != nil { + return nil, stacktrace.Propagate(err, "failed to unmarshal upsert constraint request") + } + + uExtent := &dssmodels.Volume4D{ + StartTime: payload.StartTime, + EndTime: payload.EndTime, + SpatialVolume: &dssmodels.Volume3D{ + AltitudeHi: payload.AltitudeHi, + AltitudeLo: payload.AltitudeLo, + Footprint: dssmodels.GeometryFunc(func() (s2.CellUnion, error) { + return payload.Cells, nil + }), + }, + } + + old, err := mem.GetConstraint(ctx, payload.ID) + if err != nil && err != pgx.ErrNoRows { + return nil, stacktrace.Propagate(err, "Could not get Constraint from repo") + } + + version := scdmodels.VersionNumber(1) + if old == nil { + if payload.Ovn != "" { + return nil, stacktrace.NewErrorWithCode(dsserr.VersionMismatch, "Old version %s does not exist", payload.Ovn) + } + } else { + if old.Manager != payload.Manager { + return nil, stacktrace.NewErrorWithCode(dsserr.PermissionDenied, + "Constraint owned by %s, but %s attempted to modify", old.Manager, payload.Manager) + } + if old.OVN != payload.Ovn { + return nil, stacktrace.NewErrorWithCode(dsserr.VersionMismatch, + "Current version is %s but client specified version %s", old.OVN, payload.Ovn) + } + version = old.Version + 1 + } + + var notifyVol4 *dssmodels.Volume4D + if old == nil { + notifyVol4 = uExtent + } else { + oldVol4 := &dssmodels.Volume4D{ + StartTime: old.StartTime, + EndTime: old.EndTime, + SpatialVolume: &dssmodels.Volume3D{ + AltitudeHi: old.AltitudeUpper, + AltitudeLo: old.AltitudeLower, + Footprint: dssmodels.GeometryFunc(func() (s2.CellUnion, error) { + return old.Cells, nil + }), + }, + } + notifyVol4, err = dssmodels.UnionVolumes4D(uExtent, oldVol4) + if err != nil { + return nil, stacktrace.Propagate(err, "Error constructing 4D volumes union") + } + } + + constraint := &scdmodels.Constraint{ + ID: payload.ID, + Manager: payload.Manager, + Version: version, + StartTime: uExtent.StartTime, + EndTime: uExtent.EndTime, + AltitudeLower: uExtent.SpatialVolume.AltitudeLo, + AltitudeUpper: uExtent.SpatialVolume.AltitudeHi, + USSBaseURL: payload.USSBaseURL, + Cells: payload.Cells, + } + + cp := r.memStore.Checkpoint() + + constraint, err = mem.UpsertConstraint(ctx, constraint) + if err != nil { + if restoreErr := r.memStore.Restore(cp); restoreErr != nil { + return nil, stacktrace.Propagate(restoreErr, "Failed to restore store") + } + return nil, stacktrace.Propagate(err, "Failed to upsert Constraint in repo") + } + + subs, err := mem.IncrementNotificationIndicesForConstraints(ctx, notifyVol4) + if err != nil { + if restoreErr := r.memStore.Restore(cp); restoreErr != nil { + return nil, stacktrace.Propagate(restoreErr, "Failed to restore store") + } + return nil, stacktrace.Propagate(err, "Unable to increment notification indices") + } + + return &restapi.ChangeConstraintReferenceResponse{ + ConstraintReference: *constraint.ToRest(), + Subscribers: repos.MakeSubscribersToNotify(subs), + }, nil +} diff --git a/pkg/scd/store/raftstore/operational_intents.go b/pkg/scd/store/raftstore/operational_intents.go index 3a6af6acd..b0565a2ba 100644 --- a/pkg/scd/store/raftstore/operational_intents.go +++ b/pkg/scd/store/raftstore/operational_intents.go @@ -4,36 +4,120 @@ import ( "context" "time" - dsserr "github.com/interuss/dss/pkg/errors" dssmodels "github.com/interuss/dss/pkg/models" scdmodels "github.com/interuss/dss/pkg/scd/models" "github.com/interuss/stacktrace" ) -func (r *repo) GetOperationalIntent(_ context.Context, id dssmodels.ID) (*scdmodels.OperationalIntent, error) { - return nil, stacktrace.NewErrorWithCode(dsserr.NotImplemented, "GetOperationalIntent not implemented for raftstore") +func (r *repo) GetOperationalIntent(ctx context.Context, id dssmodels.ID) (*scdmodels.OperationalIntent, error) { + result, err := r.consensus.ProposeValue(ctx, string(getOperationalIntent), id, true) + if err != nil { + return nil, stacktrace.Propagate(err, "failed to propose getOperationalIntent") + } + + if result == nil { + return nil, nil + } + + intent, ok := result.(*scdmodels.OperationalIntent) + if !ok { + return nil, stacktrace.NewError("invalid result type: %T", result) + } + + return intent, nil } -func (r *repo) DeleteOperationalIntent(_ context.Context, id dssmodels.ID) error { - return stacktrace.NewErrorWithCode(dsserr.NotImplemented, "DeleteOperationalIntent not implemented for raftstore") +func (r *repo) DeleteOperationalIntent(ctx context.Context, id dssmodels.ID) error { + _, err := r.consensus.ProposeValue(ctx, string(deleteOperationalIntent), id, false) + return err } -func (r *repo) UpsertOperationalIntent(_ context.Context, operation *scdmodels.OperationalIntent) (*scdmodels.OperationalIntent, error) { - return nil, stacktrace.NewErrorWithCode(dsserr.NotImplemented, "UpsertOperationalIntent not implemented for raftstore") +func (r *repo) UpsertOperationalIntent(ctx context.Context, operation *scdmodels.OperationalIntent) (*scdmodels.OperationalIntent, error) { + result, err := r.consensus.ProposeValue(ctx, string(upsertOperationalIntent), operation, false) + if err != nil { + return nil, stacktrace.Propagate(err, "failed to propose upsertOperationalIntent") + } + + if result == nil { + return nil, nil + } + + intent, ok := result.(*scdmodels.OperationalIntent) + if !ok { + return nil, stacktrace.NewError("invalid result type: %T", result) + } + + return intent, nil } -func (r *repo) SearchOperationalIntents(_ context.Context, v4d *dssmodels.Volume4D) ([]*scdmodels.OperationalIntent, error) { - return nil, stacktrace.NewErrorWithCode(dsserr.NotImplemented, "SearchOperationalIntents not implemented for raftstore") +func (r *repo) SearchOperationalIntents(ctx context.Context, v4d *dssmodels.Volume4D) ([]*scdmodels.OperationalIntent, error) { + result, err := r.consensus.ProposeValue(ctx, string(searchOperationalIntents), v4d, true) + if err != nil { + return nil, stacktrace.Propagate(err, "failed to propose searchOperationalIntents") + } + + if result == nil { + return nil, nil + } + + intents, ok := result.([]*scdmodels.OperationalIntent) + if !ok { + return nil, stacktrace.NewError("invalid result type: %T", result) + } + + return intents, nil } -func (r *repo) GetDependentOperationalIntents(_ context.Context, subscriptionID dssmodels.ID) ([]dssmodels.ID, error) { - return nil, stacktrace.NewErrorWithCode(dsserr.NotImplemented, "GetDependentOperationalIntents not implemented for raftstore") +func (r *repo) GetDependentOperationalIntents(ctx context.Context, subscriptionID dssmodels.ID) ([]dssmodels.ID, error) { + result, err := r.consensus.ProposeValue(ctx, string(getDependentOperationalIntents), subscriptionID, true) + if err != nil { + return nil, stacktrace.Propagate(err, "failed to propose getDependentOperationalIntents") + } + + if result == nil { + return nil, nil + } + + idList, ok := result.([]dssmodels.ID) + if !ok { + return nil, stacktrace.NewError("invalid result type: %T", result) + } + + return idList, nil } -func (r *repo) ListExpiredOperationalIntents(_ context.Context, threshold time.Time) ([]*scdmodels.OperationalIntent, error) { - return nil, stacktrace.NewErrorWithCode(dsserr.NotImplemented, "ListExpiredOperationalIntents not implemented for raftstore") +func (r *repo) ListExpiredOperationalIntents(ctx context.Context, threshold time.Time) ([]*scdmodels.OperationalIntent, error) { + result, err := r.consensus.ProposeValue(ctx, string(listExpiredOperationalIntents), threshold, true) + if err != nil { + return nil, stacktrace.Propagate(err, "failed to propose listExpiredOperationalIntents") + } + + if result == nil { + return nil, nil + } + + intents, ok := result.([]*scdmodels.OperationalIntent) + if !ok { + return nil, stacktrace.NewError("invalid result type: %T", result) + } + + return intents, nil } -func (r *repo) CountOperationalIntents(_ context.Context) (int64, error) { - return 0, stacktrace.NewErrorWithCode(dsserr.NotImplemented, "CountOperationalIntents not implemented for raftstore") +func (r *repo) CountOperationalIntents(ctx context.Context) (int64, error) { + result, err := r.consensus.ProposeValue(ctx, string(countOperationalIntents), nil, true) + if err != nil { + return 0, stacktrace.Propagate(err, "failed to propose countOperationalIntents") + } + + if result == nil { + return 0, nil + } + + count, ok := result.(int64) + if !ok { + return 0, stacktrace.NewError("invalid result type: %T", result) + } + + return count, nil } diff --git a/pkg/scd/store/raftstore/operational_intents_appliers.go b/pkg/scd/store/raftstore/operational_intents_appliers.go new file mode 100644 index 000000000..d4db139a3 --- /dev/null +++ b/pkg/scd/store/raftstore/operational_intents_appliers.go @@ -0,0 +1,358 @@ +package raftstore + +import ( + "context" + "encoding/json" + + "github.com/golang/geo/s2" + restapi "github.com/interuss/dss/pkg/api/scdv1" + dsserr "github.com/interuss/dss/pkg/errors" + dssmodels "github.com/interuss/dss/pkg/models" + "github.com/interuss/dss/pkg/raftstore/consensus" + scdmodels "github.com/interuss/dss/pkg/scd/models" + "github.com/interuss/dss/pkg/scd/repos" + "github.com/interuss/stacktrace" +) + +func (r *repo) deleteOperationalIntentTransactionApplier(ctx context.Context, proposal consensus.Proposal, mem repos.Repository) (*restapi.ChangeOperationalIntentReferenceResponse, error) { + var req *restapi.DeleteOperationalIntentReferenceRequest + if err := json.Unmarshal(proposal.Value, &req); err != nil { + return nil, stacktrace.Propagate(err, "failed to unmarshal delete operational intent request") + } + + id, err := dssmodels.IDFromString(string(req.Entityid)) + if err != nil { + return nil, stacktrace.NewErrorWithCode(dsserr.BadRequest, "Invalid ID format: `%s`", req.Entityid) + } + + ovn := scdmodels.OVN(req.Ovn) + if ovn == "" { + return nil, stacktrace.NewErrorWithCode(dsserr.BadRequest, "Missing OVN for operational intent to modify") + } + + old, err := mem.GetOperationalIntent(ctx, id) + if err != nil { + return nil, stacktrace.Propagate(err, "Unable to get OperationIntent from repo") + } + if old == nil { + return nil, stacktrace.NewErrorWithCode(dsserr.NotFound, "OperationalIntent %s not found", id) + } + + if old.Manager != dssmodels.Manager(*req.Auth.ClientID) { + return nil, stacktrace.NewErrorWithCode(dsserr.PermissionDenied, + "OperationalIntent owned by %s, but %s attempted to delete", old.Manager, *req.Auth.ClientID) + } + + if old.OVN != ovn { + return nil, stacktrace.NewErrorWithCode(dsserr.VersionMismatch, + "Current version is %s but client specified version %s", old.OVN, ovn) + } + + // Get the Subscription supporting the OperationalIntent, if one is defined + var previousSubscription *scdmodels.Subscription + if old.SubscriptionID != nil { + previousSubscription, err = mem.GetSubscription(ctx, *old.SubscriptionID) + if err != nil { + return nil, stacktrace.Propagate(err, "Unable to get OperationalIntent's Subscription from repo") + } + if previousSubscription == nil { + return nil, stacktrace.NewError("OperationalIntent's Subscription missing from repo") + } + } + + removeImplicitSubscription, err := repos.SubscriptionIsImplicitAndOnlyAttachedToOIR(ctx, mem, id, previousSubscription) + if err != nil { + return nil, stacktrace.Propagate(err, "Could not determine if Subscription can be removed") + } + + // Gather the subscriptions that need to be notified + notifyVolume := &dssmodels.Volume4D{ + StartTime: old.StartTime, + EndTime: old.EndTime, + SpatialVolume: &dssmodels.Volume3D{ + AltitudeHi: old.AltitudeUpper, + AltitudeLo: old.AltitudeLower, + Footprint: dssmodels.GeometryFunc(func() (s2.CellUnion, error) { + return old.Cells, nil + }), + }} + + cp := r.memStore.Checkpoint() + subsToNotify, err := repos.GetRelevantSubscriptionsAndIncrementIndices(ctx, mem, notifyVolume) + if err != nil { + restoreErr := r.memStore.Restore(cp) + if restoreErr != nil { + return nil, stacktrace.Propagate(restoreErr, "Failed to restore store") + } + + return nil, stacktrace.Propagate(err, "could not obtain relevant subscriptions") + } + + if err := mem.DeleteOperationalIntent(ctx, id); err != nil { + restoreErr := r.memStore.Restore(cp) + if restoreErr != nil { + return nil, stacktrace.Propagate(restoreErr, "Failed to restore store") + } + + return nil, stacktrace.Propagate(err, "Unable to delete OperationalIntent from repo") + } + + // removeImplicitSubscription is only true if the OIR had a subscription defined + if removeImplicitSubscription { + // Automatically remove a now-unused implicit Subscription + err = mem.DeleteSubscription(ctx, previousSubscription.ID) + if err != nil { + restoreErr := r.memStore.Restore(cp) + if restoreErr != nil { + return nil, stacktrace.Propagate(restoreErr, "Failed to restore store") + } + + return nil, stacktrace.Propagate(err, "Unable to delete associated implicit Subscription") + } + } + + return &restapi.ChangeOperationalIntentReferenceResponse{ + OperationalIntentReference: *old.ToRest(), + Subscribers: repos.MakeSubscribersToNotify(subsToNotify), + }, nil +} + +func (r *repo) getOperationalIntentTransactionApplier(ctx context.Context, proposal consensus.Proposal, mem repos.Repository) (*restapi.GetOperationalIntentReferenceResponse, error) { + var req *restapi.GetOperationalIntentReferenceRequest + if err := json.Unmarshal(proposal.Value, &req); err != nil { + return nil, stacktrace.Propagate(err, "failed to unmarshal get operational intent request") + } + + id, err := dssmodels.IDFromString(string(req.Entityid)) + if err != nil { + return nil, stacktrace.NewErrorWithCode(dsserr.BadRequest, "Invalid ID format: `%s`", req.Entityid) + } + + op, err := mem.GetOperationalIntent(ctx, id) + if err != nil { + return nil, stacktrace.Propagate(err, "Unable to get OperationalIntent from repo") + } + if op == nil { + return nil, stacktrace.NewErrorWithCode(dsserr.NotFound, "OperationalIntent %s not found", id) + } + + if op.Manager != dssmodels.Manager(*req.Auth.ClientID) { + op.OVN = scdmodels.NoOvnPhrase + } + + return &restapi.GetOperationalIntentReferenceResponse{ + OperationalIntentReference: *op.ToRest(), + }, nil +} + +func (r *repo) queryOperationalIntentTransactionApplier(ctx context.Context, proposal consensus.Proposal, mem repos.Repository) (*restapi.QueryOperationalIntentReferenceResponse, error) { + var req *restapi.QueryOperationalIntentReferencesRequest + if err := json.Unmarshal(proposal.Value, &req); err != nil { + return nil, stacktrace.Propagate(err, "failed to unmarshal query operational intent request") + } + + vol4, err := dssmodels.Volume4DFromSCDRest(req.Body.AreaOfInterest) + if err != nil { + return nil, stacktrace.PropagateWithCode(err, dsserr.BadRequest, "Error parsing geometry") + } + + ops, err := mem.SearchOperationalIntents(ctx, vol4) + if err != nil { + return nil, stacktrace.Propagate(err, "Unable to query for OperationalIntents in repo") + } + + response := &restapi.QueryOperationalIntentReferenceResponse{ + OperationalIntentReferences: make([]restapi.OperationalIntentReference, 0, len(ops)), + } + for _, op := range ops { + p := op.ToRest() + if op.Manager != dssmodels.Manager(*req.Auth.ClientID) { + noOvnPhrase := restapi.EntityOVN(scdmodels.NoOvnPhrase) + p.Ovn = &noOvnPhrase + } + response.OperationalIntentReferences = append(response.OperationalIntentReferences, *p) + } + + return response, nil +} + +type UpsertOperationalIntentTransactionPayload struct { + Manager dssmodels.Manager + ValidParams *repos.ValidOIRParams + Key []scdmodels.OVN +} + +type UpsertOperationalIntentTransactionResult struct { + ResponseOK *restapi.ChangeOperationalIntentReferenceResponse + ResponseConflict *restapi.AirspaceConflictResponse +} + +func (r *repo) upsertOperationalIntentTransactionApplier(ctx context.Context, proposal consensus.Proposal, mem repos.Repository) (*UpsertOperationalIntentTransactionResult, error) { + var payload *UpsertOperationalIntentTransactionPayload + if err := json.Unmarshal(proposal.Value, &payload); err != nil { + return nil, stacktrace.Propagate(err, "failed to unmarshal upsert operational intent request") + } + + upsertResult := &UpsertOperationalIntentTransactionResult{} + + key := make(map[scdmodels.OVN]bool, len(payload.Key)) + for _, ovn := range payload.Key { + key[ovn] = true + } + + old, err := mem.GetOperationalIntent(ctx, payload.ValidParams.ID) + if err != nil { + return upsertResult, stacktrace.Propagate(err, "Could not get OperationalIntent from repo") + } + + if err := repos.ValidateUpsertRequestAgainstPreviousOIR(payload.Manager, payload.ValidParams.OVN, old); err != nil { + return upsertResult, stacktrace.PropagateWithCode(err, stacktrace.GetCode(err), "Request validation failed") + } + + var ( + version = scdmodels.VersionNumber(1) + pastOVNs = make([]scdmodels.OVN, 0) + previousSub *scdmodels.Subscription + ) + if old != nil { + version = old.Version + 1 + pastOVNs = append(old.PastOVNs, payload.ValidParams.OVN) + + if old.SubscriptionID != nil { + previousSub, err = mem.GetSubscription(ctx, *old.SubscriptionID) + if err != nil { + return upsertResult, stacktrace.Propagate(err, "Unable to get OperationalIntent's Subscription from repo") + } + } + } + + previousSubIsBeingReplaced := previousSub != nil && payload.ValidParams.SubscriptionID != previousSub.ID + removePreviousImplicitSubscription := false + if previousSubIsBeingReplaced { + removePreviousImplicitSubscription, err = repos.SubscriptionIsImplicitAndOnlyAttachedToOIR(ctx, mem, payload.ValidParams.ID, previousSub) + if err != nil { + return upsertResult, stacktrace.Propagate(err, "Could not determine if previous Subscription can be removed") + } + } + + // Every error path after this point must restore the store to the checkpoint, since we may have already written to the store. + cp := r.memStore.Checkpoint() + + attachedSub := previousSub + if payload.ValidParams.SubscriptionID.Empty() { + if payload.ValidParams.ImplicitSubscription.Requested { + if attachedSub, err = repos.CreateAndStoreNewImplicitSubscription(ctx, mem, payload.Manager, payload.ValidParams); err != nil { + restoreErr := r.memStore.Restore(cp) + if restoreErr != nil { + return nil, stacktrace.Propagate(restoreErr, "Failed to restore store") + } + return upsertResult, stacktrace.Propagate(err, "Failed to create implicit subscription") + } + } else { + attachedSub = nil + } + } else { + if attachedSub == nil || previousSubIsBeingReplaced { + attachedSub, err = mem.GetSubscription(ctx, payload.ValidParams.SubscriptionID) + if err != nil { + restoreErr := r.memStore.Restore(cp) + if restoreErr != nil { + return nil, stacktrace.Propagate(restoreErr, "Failed to restore store") + } + return upsertResult, stacktrace.Propagate(err, "Failed to ensure subscription covers OIR") + } + + if attachedSub == nil { + restoreErr := r.memStore.Restore(cp) + if restoreErr != nil { + return nil, stacktrace.Propagate(restoreErr, "Failed to restore store") + } + return upsertResult, stacktrace.NewErrorWithCode(dsserr.BadRequest, "Specified Subscription %s does not exist", payload.ValidParams.SubscriptionID) + } + } + + if attachedSub.Manager != payload.Manager { + restoreErr := r.memStore.Restore(cp) + if restoreErr != nil { + return nil, stacktrace.Propagate(restoreErr, "Failed to restore store") + } + + return upsertResult, stacktrace.Propagate( + stacktrace.NewErrorWithCode( + dsserr.PermissionDenied, "Specificed Subscription is owned by different client"), + "Subscription %s owned by %s, but %s attempted to use it for an OperationalIntent", + payload.ValidParams.SubscriptionID, + attachedSub.Manager, + payload.Manager, + ) + } + + attachedSub, err = repos.EnsureSubscriptionCoversOIR(ctx, mem, attachedSub, payload.ValidParams) + if err != nil { + restoreErr := r.memStore.Restore(cp) + if restoreErr != nil { + return nil, stacktrace.Propagate(restoreErr, "Failed to restore store") + } + + return upsertResult, stacktrace.Propagate(err, "Failed to ensure subscription covers OIR") + } + } + + if payload.ValidParams.State.RequiresKey() { + upsertResult.ResponseConflict, err = repos.ValidateKeyAndProvideConflictResponse(ctx, mem, payload.Manager, payload.ValidParams, attachedSub) + if err != nil { + restoreErr := r.memStore.Restore(cp) + if restoreErr != nil { + return nil, stacktrace.Propagate(restoreErr, "Failed to restore store") + } + + return upsertResult, stacktrace.PropagateWithCode(err, stacktrace.GetCode(err), "Failed to validate key") + } + } + + op := payload.ValidParams.ToOIR(payload.Manager, attachedSub, version, pastOVNs) + + op, err = mem.UpsertOperationalIntent(ctx, op) + if err != nil { + restoreErr := r.memStore.Restore(cp) + if restoreErr != nil { + return nil, stacktrace.Propagate(restoreErr, "Failed to restore store") + } + return upsertResult, stacktrace.Propagate(err, "Failed to upsert OperationalIntent in repo") + } + + if removePreviousImplicitSubscription { + if err = mem.DeleteSubscription(ctx, previousSub.ID); err != nil { + restoreErr := r.memStore.Restore(cp) + if restoreErr != nil { + return nil, stacktrace.Propagate(restoreErr, "Failed to restore store") + } + return upsertResult, stacktrace.Propagate(err, "Unable to delete previous implicit Subscription") + } + } + + notifyVolume, err := repos.ComputeNotificationVolume(old, payload.ValidParams.UExtent) + if err != nil { + restoreErr := r.memStore.Restore(cp) + if restoreErr != nil { + return nil, stacktrace.Propagate(restoreErr, "Failed to restore store") + } + return upsertResult, stacktrace.Propagate(err, "Failed to compute notification volume") + } + + subsToNotify, err := repos.GetRelevantSubscriptionsAndIncrementIndices(ctx, mem, notifyVolume) + if err != nil { + restoreErr := r.memStore.Restore(cp) + if restoreErr != nil { + return nil, stacktrace.Propagate(restoreErr, "Failed to restore store") + } + return upsertResult, stacktrace.Propagate(err, "Failed to notify relevant Subscriptions") + } + + upsertResult.ResponseOK = &restapi.ChangeOperationalIntentReferenceResponse{ + OperationalIntentReference: *op.ToRest(), + Subscribers: repos.MakeSubscribersToNotify(subsToNotify), + } + + return upsertResult, nil +} diff --git a/pkg/scd/store/raftstore/params/params.go b/pkg/scd/store/raftstore/params/params.go new file mode 100644 index 000000000..8642a7efc --- /dev/null +++ b/pkg/scd/store/raftstore/params/params.go @@ -0,0 +1,29 @@ +package params + +import ( + "flag" + + raftparams "github.com/interuss/dss/pkg/raftstore/params" + "github.com/interuss/stacktrace" +) + +const peersFlag = "scd_raft_peers" + +var peers string + +func init() { + flag.StringVar(&peers, peersFlag, "", `Comma-separated "nodeID=peerURL" pairs for the scd store, e.g. "1=http://node1:9021,2=http://node2:9021,3=http://node3:9021"`) +} + +func GetConnectParameters() (raftparams.ConnectParameters, error) { + if peers == "" { + return raftparams.ConnectParameters{}, stacktrace.NewError("--%s is required", peersFlag) + } + + p, err := raftparams.GetConnectParameters("scd") + if err != nil { + return raftparams.ConnectParameters{}, err + } + p.Peers = peers + return p, nil +} diff --git a/pkg/scd/store/raftstore/store.go b/pkg/scd/store/raftstore/store.go index 75dd1ff5c..5bf520801 100644 --- a/pkg/scd/store/raftstore/store.go +++ b/pkg/scd/store/raftstore/store.go @@ -2,15 +2,352 @@ package raftstore import ( "context" + "encoding/json" + "slices" + "time" + "github.com/interuss/dss/pkg/memstore" + dssmodels "github.com/interuss/dss/pkg/models" "github.com/interuss/dss/pkg/raftstore" + "github.com/interuss/dss/pkg/raftstore/consensus" + scdmodels "github.com/interuss/dss/pkg/scd/models" "github.com/interuss/dss/pkg/scd/repos" + scdmemstore "github.com/interuss/dss/pkg/scd/store/memstore" + scdraftparams "github.com/interuss/dss/pkg/scd/store/raftstore/params" + "github.com/interuss/stacktrace" + "go.uber.org/zap" ) +const ( + getOperationalIntent raftstore.RequestType = "getOperationalIntent" + deleteOperationalIntent raftstore.RequestType = "deleteOperationalIntent" + upsertOperationalIntent raftstore.RequestType = "upsertOperationalIntent" + searchOperationalIntents raftstore.RequestType = "searchOperationalIntents" + getDependentOperationalIntents raftstore.RequestType = "getDependentOperationalIntents" + listExpiredOperationalIntents raftstore.RequestType = "listExpiredOperationalIntents" + countOperationalIntents raftstore.RequestType = "countOperationalIntents" + + DeleteOperationalIntentTransaction raftstore.RequestType = "deleteOperationalIntentTransaction" + GetOperationalIntentTransaction raftstore.RequestType = "getOperationalIntentTransaction" + QueryOperationalIntentTransaction raftstore.RequestType = "queryOperationalIntentTransaction" + UpsertOperationalIntentTransaction raftstore.RequestType = "upsertOperationalIntentTransaction" + + searchSubscriptions raftstore.RequestType = "searchSubscriptions" + getSubscription raftstore.RequestType = "getSubscription" + upsertSubscription raftstore.RequestType = "upsertSubscription" + deleteSubscription raftstore.RequestType = "deleteSubscription" + incrementNotificationForOIs raftstore.RequestType = "incrementNotificationForOIs" + incrementNotificationForConstraints raftstore.RequestType = "incrementNotificationForConstraints" + listExpiredSubscriptions raftstore.RequestType = "listExpiredSubscriptions" + countSubscriptions raftstore.RequestType = "countSubscriptions" + + DeleteSubscriptionTransaction raftstore.RequestType = "deleteSubscriptionTransaction" + GetSubscriptionTransaction raftstore.RequestType = "getSubscriptionTransaction" + QuerySubscriptionTransaction raftstore.RequestType = "querySubscriptionTransaction" + UpsertSubscriptionTransaction raftstore.RequestType = "upsertSubscriptionTransaction" + + searchConstraints raftstore.RequestType = "searchConstraints" + getConstraint raftstore.RequestType = "getConstraint" + upsertConstraint raftstore.RequestType = "upsertConstraint" + deleteConstraint raftstore.RequestType = "deleteConstraint" + countConstraints raftstore.RequestType = "countConstraints" + + DeleteConstraintTransaction raftstore.RequestType = "deleteConstraintTransaction" + GetConstraintTransaction raftstore.RequestType = "getConstraintTransaction" + QueryConstraintTransaction raftstore.RequestType = "queryConstraintTransaction" + UpsertConstraintTransaction raftstore.RequestType = "upsertConstraintTransaction" + + getUSSAvailability raftstore.RequestType = "getUSSAvailability" + upsertUSSAvailability raftstore.RequestType = "upsertUSSAvailability" + + GetUSSAvailabilityTransaction raftstore.RequestType = "getUSSAvailabilityTransaction" + SetUSSAvailabilityTransaction raftstore.RequestType = "setUSSAvailabilityTransaction" +) + +var readOnlyRequests = []raftstore.RequestType{ + getOperationalIntent, + searchOperationalIntents, + getDependentOperationalIntents, + listExpiredOperationalIntents, + countOperationalIntents, + + GetOperationalIntentTransaction, + QueryOperationalIntentTransaction, + + searchSubscriptions, + getSubscription, + listExpiredSubscriptions, + countSubscriptions, + + GetSubscriptionTransaction, + QuerySubscriptionTransaction, + + searchConstraints, + getConstraint, + countConstraints, + + GetConstraintTransaction, + QueryConstraintTransaction, + + getUSSAvailability, + GetUSSAvailabilityTransaction, +} + // repo is a full implementation of scd.repos.Repository for Raft-based storage. -type repo struct{} +type repo struct { + consensus *consensus.Consensus + + memStore *memstore.Store[repos.Repository] +} func Init(ctx context.Context, logger *zap.Logger) (*raftstore.Store[repos.Repository], error) { - return raftstore.Init[repos.Repository](ctx, logger, func() repos.Repository { return &repo{} }) + params, err := scdraftparams.GetConnectParameters() + if err != nil { + return nil, stacktrace.Propagate(err, "failed to get scd raft parameters") + } + + memStore, err := scdmemstore.Init(ctx, logger) + if err != nil { + return nil, stacktrace.Propagate(err, "failed to initialize SCD memstore") + } + + r := &repo{memStore: memStore} + store, err := raftstore.Init(ctx, logger, params, r) + if err != nil { + return nil, stacktrace.Propagate(err, "failed to initialize SCDs raftstore") + } + + r.consensus = store.Consensus + + return store, nil +} + +func (r *repo) GetRepo() repos.Repository { return r } + +func (r *repo) IsReadOnly(requestType raftstore.RequestType) bool { + return slices.Contains(readOnlyRequests, requestType) +} + +func (r *repo) GetSnapshot() ([]byte, error) { + return r.memStore.GetSnapshot() +} + +func (r *repo) RestoreFromSnapshot(data []byte) error { + return r.memStore.RestoreFromSnapshot(data) +} + +func (r *repo) Apply(ctx context.Context, proposal consensus.Proposal) (any, error) { + mem, err := r.memStore.Interact(ctx) + if err != nil { + return nil, stacktrace.Propagate(err, "failed to obtain scd memstore repository") + } + + switch raftstore.RequestType(proposal.RequestType) { + case getOperationalIntent: + var id dssmodels.ID + if err := json.Unmarshal(proposal.Value, &id); err != nil { + return nil, stacktrace.Propagate(err, "failed to unmarshal %s proposal value", getOperationalIntent) + } + + return mem.GetOperationalIntent(ctx, id) + + case deleteOperationalIntent: + var id dssmodels.ID + if err := json.Unmarshal(proposal.Value, &id); err != nil { + return nil, stacktrace.Propagate(err, "failed to unmarshal %s proposal value", deleteOperationalIntent) + } + + return nil, mem.DeleteOperationalIntent(ctx, id) + + case upsertOperationalIntent: + var operation *scdmodels.OperationalIntent + if err := json.Unmarshal(proposal.Value, &operation); err != nil { + return nil, stacktrace.Propagate(err, "failed to unmarshal %s proposal value", upsertOperationalIntent) + } + + return mem.UpsertOperationalIntent(ctx, operation) + + case searchOperationalIntents: + var v4d *dssmodels.Volume4D + if err := json.Unmarshal(proposal.Value, &v4d); err != nil { + return nil, stacktrace.Propagate(err, "failed to unmarshal %s proposal value", searchOperationalIntents) + } + + return mem.SearchOperationalIntents(ctx, v4d) + + case getDependentOperationalIntents: + var subscriptionID dssmodels.ID + if err := json.Unmarshal(proposal.Value, &subscriptionID); err != nil { + return nil, stacktrace.Propagate(err, "failed to unmarshal %s proposal value", getDependentOperationalIntents) + } + + return mem.GetDependentOperationalIntents(ctx, subscriptionID) + + case listExpiredOperationalIntents: + var threshold time.Time + if err := json.Unmarshal(proposal.Value, &threshold); err != nil { + return nil, stacktrace.Propagate(err, "failed to unmarshal %s proposal value", listExpiredOperationalIntents) + } + + return mem.ListExpiredOperationalIntents(ctx, threshold) + + case countOperationalIntents: + return mem.CountOperationalIntents(ctx) + + case DeleteOperationalIntentTransaction: + return r.deleteOperationalIntentTransactionApplier(ctx, proposal, mem) + + case GetOperationalIntentTransaction: + return r.getOperationalIntentTransactionApplier(ctx, proposal, mem) + + case QueryOperationalIntentTransaction: + return r.queryOperationalIntentTransactionApplier(ctx, proposal, mem) + + case UpsertOperationalIntentTransaction: + return r.upsertOperationalIntentTransactionApplier(ctx, proposal, mem) + + case searchSubscriptions: + var v4d *dssmodels.Volume4D + if err := json.Unmarshal(proposal.Value, &v4d); err != nil { + return nil, stacktrace.Propagate(err, "failed to unmarshal %s proposal value", searchSubscriptions) + } + + return mem.SearchSubscriptions(ctx, v4d) + + case getSubscription: + var id dssmodels.ID + if err := json.Unmarshal(proposal.Value, &id); err != nil { + return nil, stacktrace.Propagate(err, "failed to unmarshal %s proposal value", getSubscription) + } + + return mem.GetSubscription(ctx, id) + + case upsertSubscription: + var sub *scdmodels.Subscription + if err := json.Unmarshal(proposal.Value, &sub); err != nil { + return nil, stacktrace.Propagate(err, "failed to unmarshal %s proposal value", upsertSubscription) + } + + return mem.UpsertSubscription(ctx, sub) + + case deleteSubscription: + var id dssmodels.ID + if err := json.Unmarshal(proposal.Value, &id); err != nil { + return nil, stacktrace.Propagate(err, "failed to unmarshal %s proposal value", deleteSubscription) + } + + return nil, mem.DeleteSubscription(ctx, id) + + case incrementNotificationForOIs: + var v4d *dssmodels.Volume4D + if err := json.Unmarshal(proposal.Value, &v4d); err != nil { + return nil, stacktrace.Propagate(err, "failed to unmarshal %s proposal value", incrementNotificationForOIs) + } + + return mem.IncrementNotificationIndicesForOperationalIntents(ctx, v4d) + + case incrementNotificationForConstraints: + var v4d *dssmodels.Volume4D + if err := json.Unmarshal(proposal.Value, &v4d); err != nil { + return nil, stacktrace.Propagate(err, "failed to unmarshal %s proposal value", incrementNotificationForConstraints) + } + + return mem.IncrementNotificationIndicesForConstraints(ctx, v4d) + + case listExpiredSubscriptions: + var threshold time.Time + if err := json.Unmarshal(proposal.Value, &threshold); err != nil { + return nil, stacktrace.Propagate(err, "failed to unmarshal %s proposal value", listExpiredSubscriptions) + } + + return mem.ListExpiredSubscriptions(ctx, threshold) + + case countSubscriptions: + return mem.CountSubscriptions(ctx) + + case UpsertSubscriptionTransaction: + return r.upsertSubscriptionTransactionApplier(ctx, proposal, mem) + + case DeleteSubscriptionTransaction: + return r.deleteSubscriptionTransactionApplier(ctx, proposal, mem) + + case GetSubscriptionTransaction: + return r.getSubscriptionTransactionApplier(ctx, proposal, mem) + + case QuerySubscriptionTransaction: + return r.querySubscriptionTransactionApplier(ctx, proposal, mem) + + case searchConstraints: + var v4d *dssmodels.Volume4D + if err := json.Unmarshal(proposal.Value, &v4d); err != nil { + return nil, stacktrace.Propagate(err, "failed to unmarshal %s proposal value", searchConstraints) + } + + return mem.SearchConstraints(ctx, v4d) + + case getConstraint: + var id dssmodels.ID + if err := json.Unmarshal(proposal.Value, &id); err != nil { + return nil, stacktrace.Propagate(err, "failed to unmarshal %s proposal value", getConstraint) + } + + return mem.GetConstraint(ctx, id) + + case upsertConstraint: + var constraint *scdmodels.Constraint + if err := json.Unmarshal(proposal.Value, &constraint); err != nil { + return nil, stacktrace.Propagate(err, "failed to unmarshal %s proposal value", upsertConstraint) + } + + return mem.UpsertConstraint(ctx, constraint) + + case deleteConstraint: + var id dssmodels.ID + if err := json.Unmarshal(proposal.Value, &id); err != nil { + return nil, stacktrace.Propagate(err, "failed to unmarshal %s proposal value", deleteConstraint) + } + + return nil, mem.DeleteConstraint(ctx, id) + + case countConstraints: + return mem.CountConstraints(ctx) + + case DeleteConstraintTransaction: + return r.deleteConstraintTransactionApplier(ctx, proposal, mem) + + case GetConstraintTransaction: + return r.getConstraintTransactionApplier(ctx, proposal, mem) + + case QueryConstraintTransaction: + return r.queryConstraintTransactionApplier(ctx, proposal, mem) + + case UpsertConstraintTransaction: + return r.upsertConstraintTransactionApplier(ctx, proposal, mem) + + case getUSSAvailability: + var id dssmodels.Manager + if err := json.Unmarshal(proposal.Value, &id); err != nil { + return nil, stacktrace.Propagate(err, "failed to unmarshal %s proposal value", getUSSAvailability) + } + + return mem.GetUssAvailability(ctx, id) + + case upsertUSSAvailability: + var ussa scdmodels.UssAvailabilityStatus + if err := json.Unmarshal(proposal.Value, &ussa); err != nil { + return nil, stacktrace.Propagate(err, "failed to unmarshal %s proposal value", upsertUSSAvailability) + } + + return mem.UpsertUssAvailability(ctx, &ussa) + + case GetUSSAvailabilityTransaction: + return r.getUSSAvailabilityTransactionApplier(ctx, proposal, mem) + + case SetUSSAvailabilityTransaction: + return r.setUSSAvailabilityTransactionApplier(ctx, proposal, mem) + + default: + return nil, stacktrace.NewError("unknown request type: %q", proposal.RequestType) + } } diff --git a/pkg/scd/store/raftstore/subscriptions.go b/pkg/scd/store/raftstore/subscriptions.go index 2e5609335..e0b5c6cb3 100644 --- a/pkg/scd/store/raftstore/subscriptions.go +++ b/pkg/scd/store/raftstore/subscriptions.go @@ -5,44 +5,143 @@ import ( "time" "github.com/golang/geo/s2" - dsserr "github.com/interuss/dss/pkg/errors" dssmodels "github.com/interuss/dss/pkg/models" scdmodels "github.com/interuss/dss/pkg/scd/models" "github.com/interuss/stacktrace" ) -func (r *repo) SearchSubscriptions(_ context.Context, v4d *dssmodels.Volume4D) ([]*scdmodels.Subscription, error) { - return nil, stacktrace.NewErrorWithCode(dsserr.NotImplemented, "SearchSubscriptions not implemented for raftstore") +func (r *repo) SearchSubscriptions(ctx context.Context, v4d *dssmodels.Volume4D) ([]*scdmodels.Subscription, error) { + result, err := r.consensus.ProposeValue(ctx, string(searchSubscriptions), v4d, true) + if err != nil { + return nil, stacktrace.Propagate(err, "failed to propose searchSubscriptions") + } + + if result == nil { + return nil, nil + } + + subs, ok := result.([]*scdmodels.Subscription) + if !ok { + return nil, stacktrace.NewError("invalid result type: %T", result) + } + + return subs, nil } -func (r *repo) GetSubscription(_ context.Context, id dssmodels.ID) (*scdmodels.Subscription, error) { - return nil, stacktrace.NewErrorWithCode(dsserr.NotImplemented, "GetSubscription not implemented for raftstore") +func (r *repo) GetSubscription(ctx context.Context, id dssmodels.ID) (*scdmodels.Subscription, error) { + result, err := r.consensus.ProposeValue(ctx, string(getSubscription), id, true) + if err != nil { + return nil, stacktrace.Propagate(err, "failed to propose getSubscription") + } + + if result == nil { + return nil, nil + } + + sub, ok := result.(*scdmodels.Subscription) + if !ok { + return nil, stacktrace.NewError("invalid result type: %T", result) + } + + return sub, nil } -func (r *repo) UpsertSubscription(_ context.Context, sub *scdmodels.Subscription) (*scdmodels.Subscription, error) { - return nil, stacktrace.NewErrorWithCode(dsserr.NotImplemented, "UpsertSubscription not implemented for raftstore") +func (r *repo) UpsertSubscription(ctx context.Context, sub *scdmodels.Subscription) (*scdmodels.Subscription, error) { + result, err := r.consensus.ProposeValue(ctx, string(upsertSubscription), sub, false) + if err != nil { + return nil, stacktrace.Propagate(err, "failed to propose upsertSubscription") + } + + if result == nil { + return nil, nil + } + + upserted, ok := result.(*scdmodels.Subscription) + if !ok { + return nil, stacktrace.NewError("invalid result type: %T", result) + } + + return upserted, nil } -func (r *repo) DeleteSubscription(_ context.Context, id dssmodels.ID) error { - return stacktrace.NewErrorWithCode(dsserr.NotImplemented, "DeleteSubscription not implemented for raftstore") +func (r *repo) DeleteSubscription(ctx context.Context, id dssmodels.ID) error { + _, err := r.consensus.ProposeValue(ctx, string(deleteSubscription), id, false) + return err } -func (r *repo) IncrementNotificationIndicesForOperationalIntents(_ context.Context, v4d *dssmodels.Volume4D) ([]*scdmodels.Subscription, error) { - return nil, stacktrace.NewErrorWithCode(dsserr.NotImplemented, "IncrementNotificationIndicesForOperationalIntents not implemented for raftstore") +func (r *repo) IncrementNotificationIndicesForOperationalIntents(ctx context.Context, v4d *dssmodels.Volume4D) ([]*scdmodels.Subscription, error) { + result, err := r.consensus.ProposeValue(ctx, string(incrementNotificationForOIs), v4d, false) + if err != nil { + return nil, stacktrace.Propagate(err, "failed to propose incrementNotificationForOIs") + } + + if result == nil { + return nil, nil + } + + subs, ok := result.([]*scdmodels.Subscription) + if !ok { + return nil, stacktrace.NewError("invalid result type: %T", result) + } + + return subs, nil } -func (r *repo) IncrementNotificationIndicesForConstraints(_ context.Context, v4d *dssmodels.Volume4D) ([]*scdmodels.Subscription, error) { - return nil, stacktrace.NewErrorWithCode(dsserr.NotImplemented, "IncrementNotificationIndicesForConstraints not implemented for raftstore") +func (r *repo) IncrementNotificationIndicesForConstraints(ctx context.Context, v4d *dssmodels.Volume4D) ([]*scdmodels.Subscription, error) { + result, err := r.consensus.ProposeValue(ctx, string(incrementNotificationForConstraints), v4d, false) + if err != nil { + return nil, stacktrace.Propagate(err, "failed to propose incrementNotificationForConstraints") + } + + if result == nil { + return nil, nil + } + + subs, ok := result.([]*scdmodels.Subscription) + if !ok { + return nil, stacktrace.NewError("invalid result type: %T", result) + } + + return subs, nil } -func (r *repo) LockSubscriptionsOnCells(_ context.Context, cells s2.CellUnion, subscriptionIds []dssmodels.ID, startTime *time.Time, endTime *time.Time) error { - return stacktrace.NewErrorWithCode(dsserr.NotImplemented, "LockSubscriptionsOnCells not implemented for raftstore") +func (r *repo) LockSubscriptionsOnCells(_ context.Context, _ s2.CellUnion, _ []dssmodels.ID, _ *time.Time, _ *time.Time) error { + // for the raftstore, LockSubscriptionsOnCells is a no-op + return nil } -func (r *repo) ListExpiredSubscriptions(_ context.Context, threshold time.Time) ([]*scdmodels.Subscription, error) { - return nil, stacktrace.NewErrorWithCode(dsserr.NotImplemented, "ListExpiredSubscriptions not implemented for raftstore") +func (r *repo) ListExpiredSubscriptions(ctx context.Context, threshold time.Time) ([]*scdmodels.Subscription, error) { + result, err := r.consensus.ProposeValue(ctx, string(listExpiredSubscriptions), threshold, true) + if err != nil { + return nil, stacktrace.Propagate(err, "failed to propose listExpiredSubscriptions") + } + + if result == nil { + return nil, nil + } + + subs, ok := result.([]*scdmodels.Subscription) + if !ok { + return nil, stacktrace.NewError("invalid result type: %T", result) + } + + return subs, nil } -func (r *repo) CountSubscriptions(_ context.Context) (int64, error) { - return 0, stacktrace.NewErrorWithCode(dsserr.NotImplemented, "CountSubscriptions not implemented for raftstore") +func (r *repo) CountSubscriptions(ctx context.Context) (int64, error) { + result, err := r.consensus.ProposeValue(ctx, string(countSubscriptions), nil, true) + if err != nil { + return 0, stacktrace.Propagate(err, "failed to propose countSubscriptions") + } + + if result == nil { + return 0, nil + } + + count, ok := result.(int64) + if !ok { + return 0, stacktrace.NewError("invalid result type: %T", result) + } + + return count, nil } diff --git a/pkg/scd/store/raftstore/subscriptions_appliers.go b/pkg/scd/store/raftstore/subscriptions_appliers.go new file mode 100644 index 000000000..c36ff3741 --- /dev/null +++ b/pkg/scd/store/raftstore/subscriptions_appliers.go @@ -0,0 +1,312 @@ +package raftstore + +import ( + "context" + "encoding/json" + + "github.com/golang/geo/s2" + restapi "github.com/interuss/dss/pkg/api/scdv1" + dsserr "github.com/interuss/dss/pkg/errors" + dssmodels "github.com/interuss/dss/pkg/models" + "github.com/interuss/dss/pkg/raftstore/consensus" + scdmodels "github.com/interuss/dss/pkg/scd/models" + "github.com/interuss/dss/pkg/scd/repos" + "github.com/interuss/stacktrace" +) + +type UpsertSubscriptionTransactionPayload struct { + Subreq *scdmodels.Subscription `json:"subreq"` + Extents *dssmodels.Volume4D `json:"extents"` +} + +func (r *repo) upsertSubscriptionTransactionApplier(ctx context.Context, proposal consensus.Proposal, mem repos.Repository) (*restapi.PutSubscriptionResponse, error) { + var payload *UpsertSubscriptionTransactionPayload + if err := json.Unmarshal(proposal.Value, &payload); err != nil { + return nil, stacktrace.Propagate(err, "failed to unmarshal upsert subscription transaction payload") + } + + old, err := mem.GetSubscription(ctx, payload.Subreq.ID) + if err != nil { + return nil, stacktrace.Propagate(err, "failed to get existing subscription from repo") + } + + if err := payload.Subreq.AdjustTimeRange(proposal.Timestamp, old); err != nil { + return nil, stacktrace.Propagate(err, "failed to adjust subscription time range") + } + + var dependentOpIds []dssmodels.ID + + if old == nil { + if payload.Subreq.Version.String() != "" { + return nil, stacktrace.NewErrorWithCode(dsserr.NotFound, "Subscription %s not found", payload.Subreq.ID.String()) + } + } else { + switch { + case payload.Subreq.Version.String() == "": + return nil, stacktrace.NewErrorWithCode(dsserr.AlreadyExists, "Subscription %s already exists", payload.Subreq.ID.String()) + case payload.Subreq.Version.String() != old.Version.String(): + return nil, stacktrace.Propagate( + stacktrace.NewErrorWithCode(dsserr.VersionMismatch, "Subscription version %s is not current", payload.Subreq.Version), + "Current version is %s but client specified version %s", old.Version, payload.Subreq.Version) + case old.Manager != payload.Subreq.Manager: + return nil, stacktrace.Propagate( + stacktrace.NewErrorWithCode(dsserr.PermissionDenied, "Subscription is owned by different client"), + "Subscription owned by %s, but %s attempted to modify", old.Manager, payload.Subreq.Manager) + } + + payload.Subreq.NotificationIndex = old.NotificationIndex + + dependentOpIds, err = mem.GetDependentOperationalIntents(ctx, payload.Subreq.ID) + if err != nil { + return nil, stacktrace.Propagate(err, "Could not find dependent Operation Ids") + } + + var operations []*scdmodels.OperationalIntent + for _, opID := range dependentOpIds { + operation, err := mem.GetOperationalIntent(ctx, opID) + if err != nil { + return nil, stacktrace.Propagate(err, "Could not retrieve dependent Operation %s", opID) + } + operations = append(operations, operation) + } + + if err := payload.Subreq.ValidateDependentOps(operations); err != nil { + return nil, err + } + } + + cp := r.memStore.Checkpoint() + + sub, err := mem.UpsertSubscription(ctx, payload.Subreq) + if err != nil { + if restoreErr := r.memStore.Restore(cp); restoreErr != nil { + return nil, stacktrace.Propagate(restoreErr, "Failed to restore store") + } + return nil, stacktrace.Propagate(err, "Failed to upsert Subscription in repo") + } + if sub == nil { + if restoreErr := r.memStore.Restore(cp); restoreErr != nil { + return nil, stacktrace.Propagate(restoreErr, "Failed to restore store") + } + return nil, stacktrace.NewError("UpsertSubscription returned no Subscription for ID: %s", payload.Subreq.ID) + } + + var relevantOperations []*scdmodels.OperationalIntent + if len(sub.Cells) > 0 { + ops, err := mem.SearchOperationalIntents(ctx, &dssmodels.Volume4D{ + StartTime: sub.StartTime, + EndTime: sub.EndTime, + SpatialVolume: &dssmodels.Volume3D{ + AltitudeLo: sub.AltitudeLo, + AltitudeHi: sub.AltitudeHi, + Footprint: dssmodels.GeometryFunc(func() (s2.CellUnion, error) { + return sub.Cells, nil + }), + }, + }) + if err != nil { + if restoreErr := r.memStore.Restore(cp); restoreErr != nil { + return nil, stacktrace.Propagate(restoreErr, "Failed to restore store") + } + return nil, stacktrace.Propagate(err, "Could not search Operations in repo") + } + relevantOperations = ops + } + + p, err := sub.ToRest(dependentOpIds) + if err != nil { + if restoreErr := r.memStore.Restore(cp); restoreErr != nil { + return nil, stacktrace.Propagate(restoreErr, "Failed to restore store") + } + return nil, stacktrace.Propagate(err, "Could not convert Subscription to REST model") + } + + result := &restapi.PutSubscriptionResponse{ + Subscription: *p, + } + + if sub.NotifyForOperationalIntents { + opIntentRefs := make([]restapi.OperationalIntentReference, 0, len(relevantOperations)) + for _, op := range relevantOperations { + if op.Manager != dssmodels.Manager(payload.Subreq.Manager) { + op.OVN = scdmodels.NoOvnPhrase + } + + opIntentRefs = append(opIntentRefs, *op.ToRest()) + } + result.OperationalIntentReferences = &opIntentRefs + } + + if sub.NotifyForConstraints { + constraints, err := mem.SearchConstraints(ctx, payload.Extents) + if err != nil { + if restoreErr := r.memStore.Restore(cp); restoreErr != nil { + return nil, stacktrace.Propagate(restoreErr, "Failed to restore store") + } + return nil, stacktrace.Propagate(err, "Could not search Constraints in repo") + } + + constraintRefs := make([]restapi.ConstraintReference, 0, len(constraints)) + for _, constraint := range constraints { + p := constraint.ToRest() + if constraint.Manager != dssmodels.Manager(payload.Subreq.Manager) { + noOvnPhrase := restapi.EntityOVN(scdmodels.NoOvnPhrase) + p.Ovn = &noOvnPhrase + } + + constraintRefs = append(constraintRefs, *p) + } + result.ConstraintReferences = &constraintRefs + } + + return result, nil +} + +func (r *repo) getSubscriptionTransactionApplier(ctx context.Context, proposal consensus.Proposal, mem repos.Repository) (*restapi.GetSubscriptionResponse, error) { + var req *restapi.GetSubscriptionRequest + if err := json.Unmarshal(proposal.Value, &req); err != nil { + return nil, stacktrace.Propagate(err, "failed to unmarshal get subscription request") + } + + id, err := dssmodels.IDFromString(string(req.Subscriptionid)) + if err != nil { + return nil, stacktrace.NewErrorWithCode(dsserr.BadRequest, "Invalid ID format: `%s`", req.Subscriptionid) + } + + if req.Auth.ClientID == nil { + return nil, stacktrace.NewErrorWithCode(dsserr.PermissionDenied, "Missing owner") + } + + sub, err := mem.GetSubscription(ctx, id) + if err != nil { + return nil, stacktrace.Propagate(err, "Could not get Subscription from repo") + } + if sub == nil { + return nil, stacktrace.NewErrorWithCode(dsserr.NotFound, "Subscription %s not found", id.String()) + } + + if dssmodels.Manager(*req.Auth.ClientID) != sub.Manager { + return nil, stacktrace.NewErrorWithCode(dsserr.PermissionDenied, + "Subscription owned by %s, but %s attempted to view", sub.Manager, *req.Auth.ClientID) + } + + dependentOps, err := mem.GetDependentOperationalIntents(ctx, id) + if err != nil { + return nil, stacktrace.Propagate(err, "Could not find dependent Operations") + } + + p, err := sub.ToRest(dependentOps) + if err != nil { + return nil, stacktrace.Propagate(err, "Unable to convert Subscription to REST") + } + + return &restapi.GetSubscriptionResponse{Subscription: *p}, nil +} + +func (r *repo) querySubscriptionTransactionApplier(ctx context.Context, proposal consensus.Proposal, mem repos.Repository) (*restapi.QuerySubscriptionsResponse, error) { + var req *restapi.QuerySubscriptionsRequest + if err := json.Unmarshal(proposal.Value, &req); err != nil { + return nil, stacktrace.Propagate(err, "failed to unmarshal query subscriptions request") + } + + if req.Auth.ClientID == nil { + return nil, stacktrace.NewErrorWithCode(dsserr.PermissionDenied, "Missing owner") + } + + aoi := req.Body.AreaOfInterest + if aoi == nil { + return nil, stacktrace.NewErrorWithCode(dsserr.BadRequest, "Missing area_of_interest") + } + + vol4, err := dssmodels.Volume4DFromSCDRest(aoi) + if err != nil { + return nil, stacktrace.PropagateWithCode(err, dsserr.BadRequest, "Failed to convert to internal geometry model") + } + + subs, err := mem.SearchSubscriptions(ctx, vol4) + if err != nil { + return nil, stacktrace.Propagate(err, "Error searching Subscriptions in repo") + } + + response := &restapi.QuerySubscriptionsResponse{ + Subscriptions: make([]restapi.Subscription, 0), + } + for _, sub := range subs { + if sub.EndTime.Before(proposal.Timestamp) || sub.Manager != dssmodels.Manager(*req.Auth.ClientID) { + continue + } + dependentOps, err := mem.GetDependentOperationalIntents(ctx, sub.ID) + if err != nil { + return nil, stacktrace.Propagate(err, "Could not find dependent Operations") + } + p, err := sub.ToRest(dependentOps) + if err != nil { + return nil, stacktrace.Propagate(err, "Error converting Subscription model to REST") + } + response.Subscriptions = append(response.Subscriptions, *p) + } + + return response, nil +} + +func (r *repo) deleteSubscriptionTransactionApplier(ctx context.Context, proposal consensus.Proposal, mem repos.Repository) (*restapi.DeleteSubscriptionResponse, error) { + var req *restapi.DeleteSubscriptionRequest + if err := json.Unmarshal(proposal.Value, &req); err != nil { + return nil, stacktrace.Propagate(err, "failed to unmarshal delete subscription request") + } + + id, err := dssmodels.IDFromString(string(req.Subscriptionid)) + if err != nil { + return nil, stacktrace.NewErrorWithCode(dsserr.BadRequest, "Invalid ID format: `%s`", req.Subscriptionid) + } + + version := scdmodels.OVN(req.Version) + if version == "" { + return nil, stacktrace.NewErrorWithCode(dsserr.BadRequest, "Missing version") + } + + if req.Auth.ClientID == nil { + return nil, stacktrace.NewErrorWithCode(dsserr.PermissionDenied, "Missing owner") + } + + old, err := mem.GetSubscription(ctx, id) + switch { + case err != nil: + return nil, stacktrace.Propagate(err, "Could not get Subscription from repo") + case old == nil: + return nil, stacktrace.NewErrorWithCode(dsserr.NotFound, "Subscription %s not found", id.String()) + case old.Manager != dssmodels.Manager(*req.Auth.ClientID): + return nil, stacktrace.NewErrorWithCode(dsserr.PermissionDenied, + "Subscription owned by %s, but %s attempted to delete", old.Manager, *req.Auth.ClientID) + case old.Version != version: + return nil, stacktrace.NewErrorWithCode(dsserr.VersionMismatch, "Subscription version %s is not current", version) + } + + dependentOps, err := mem.GetDependentOperationalIntents(ctx, id) + if err != nil { + return nil, stacktrace.Propagate(err, "Could not find dependent Operations") + } + if len(dependentOps) > 0 { + return nil, stacktrace.NewErrorWithCode(dsserr.BadRequest, "Subscriptions with dependent Operations may not be removed") + } + + cp := r.memStore.Checkpoint() + + if err = mem.DeleteSubscription(ctx, id); err != nil { + if restoreErr := r.memStore.Restore(cp); restoreErr != nil { + return nil, stacktrace.Propagate(restoreErr, "Failed to restore store") + } + return nil, stacktrace.Propagate(err, "Could not delete Subscription from repo") + } + + p, err := old.ToRest(dependentOps) + if err != nil { + if restoreErr := r.memStore.Restore(cp); restoreErr != nil { + return nil, stacktrace.Propagate(restoreErr, "Failed to restore store") + } + + return nil, stacktrace.Propagate(err, "Error converting Subscription model to REST") + } + + return &restapi.DeleteSubscriptionResponse{Subscription: *p}, nil +} diff --git a/pkg/scd/store/store.go b/pkg/scd/store/store.go index a8603fb42..9f313d511 100644 --- a/pkg/scd/store/store.go +++ b/pkg/scd/store/store.go @@ -4,6 +4,7 @@ import ( "context" "github.com/interuss/dss/pkg/scd/repos" + scdmemstore "github.com/interuss/dss/pkg/scd/store/memstore" scdraftstore "github.com/interuss/dss/pkg/scd/store/raftstore" scdsqlstore "github.com/interuss/dss/pkg/scd/store/sqlstore" dssstore "github.com/interuss/dss/pkg/store" @@ -24,6 +25,8 @@ func Init(ctx context.Context, logger *zap.Logger, withCheckCron bool, globalLoc return scdsqlstore.Init(ctx, logger, withCheckCron, globalLock) case params.RaftStoreType: return scdraftstore.Init(ctx, logger) + case params.MemStoreType: + return scdmemstore.Init(ctx, logger) default: return nil, stacktrace.NewError("Unsupported store type %q for scd", storeType) } diff --git a/pkg/scd/subscriptions_handler.go b/pkg/scd/subscriptions_handler.go index 98ab9fed2..c4c14d4f2 100644 --- a/pkg/scd/subscriptions_handler.go +++ b/pkg/scd/subscriptions_handler.go @@ -12,6 +12,7 @@ import ( dssmodels "github.com/interuss/dss/pkg/models" scdmodels "github.com/interuss/dss/pkg/scd/models" "github.com/interuss/dss/pkg/scd/repos" + scdraftstore "github.com/interuss/dss/pkg/scd/store/raftstore" "github.com/interuss/stacktrace" "github.com/jonboulle/clockwork" ) @@ -276,11 +277,23 @@ func (a *Server) PutSubscription(ctx context.Context, manager string, subscripti return nil } - err = a.Store.Transact(ctx, action) + raftResult, err := a.Store.Transact(ctx, scdraftstore.UpsertSubscriptionTransaction, &scdraftstore.UpsertSubscriptionTransactionPayload{ + Subreq: subreq, + Extents: extents, + }, action) if err != nil { return nil, err // No need to Propagate this error as this is not a useful stacktrace line } + if raftResult != nil { + putSubResp, ok := raftResult.(*restapi.PutSubscriptionResponse) + if !ok { + return nil, stacktrace.NewError("invalid result type: %T", raftResult) + } + + result = putSubResp + } + // Return response to client return result, nil } @@ -340,7 +353,7 @@ func (a *Server) GetSubscription(ctx context.Context, req *restapi.GetSubscripti return nil } - err = a.Store.Transact(ctx, action) + raftResult, err := a.Store.Transact(ctx, scdraftstore.GetSubscriptionTransaction, req, action) if err != nil { err = stacktrace.Propagate(err, "Could not get subscription") errResp := &restapi.ErrorResponse{Message: dsserr.Handle(ctx, err)} @@ -356,6 +369,14 @@ func (a *Server) GetSubscription(ctx context.Context, req *restapi.GetSubscripti ErrorMessage: *dsserr.Handle(ctx, stacktrace.Propagate(err, "Got an unexpected error"))}} } } + if raftResult != nil { + getSubResponse, ok := raftResult.(*restapi.GetSubscriptionResponse) + if !ok { + return restapi.GetSubscriptionResponseSet{Response500: &api.InternalServerErrorBody{ + ErrorMessage: *dsserr.Handle(ctx, stacktrace.NewError("invalid result type: %T", raftResult))}} + } + response = getSubResponse + } return restapi.GetSubscriptionResponseSet{Response200: response} } @@ -424,7 +445,7 @@ func (a *Server) QuerySubscriptions(ctx context.Context, req *restapi.QuerySubsc return nil } - err = a.Store.Transact(ctx, action) + raftResult, err := a.Store.Transact(ctx, scdraftstore.QuerySubscriptionTransaction, req, action) if err != nil { errResp := &restapi.ErrorResponse{Message: dsserr.Handle(ctx, err)} @@ -441,6 +462,14 @@ func (a *Server) QuerySubscriptions(ctx context.Context, req *restapi.QuerySubsc } } + if raftResult != nil { + querySubResponse, ok := raftResult.(*restapi.QuerySubscriptionsResponse) + if !ok { + return restapi.QuerySubscriptionsResponseSet{Response500: &api.InternalServerErrorBody{ + ErrorMessage: *dsserr.Handle(ctx, stacktrace.NewError("invalid result type: %T", raftResult))}} + } + response = querySubResponse + } return restapi.QuerySubscriptionsResponseSet{Response200: response} } @@ -517,7 +546,7 @@ func (a *Server) DeleteSubscription(ctx context.Context, req *restapi.DeleteSubs return nil } - err = a.Store.Transact(ctx, action) + raftResult, err := a.Store.Transact(ctx, scdraftstore.DeleteSubscriptionTransaction, req, action) if err != nil { err = stacktrace.Propagate(err, "Could not delete subscription") errResp := &restapi.ErrorResponse{Message: dsserr.Handle(ctx, err)} @@ -535,6 +564,14 @@ func (a *Server) DeleteSubscription(ctx context.Context, req *restapi.DeleteSubs ErrorMessage: *dsserr.Handle(ctx, stacktrace.Propagate(err, "Got an unexpected error"))}} } } + if raftResult != nil { + deleteSubResponse, ok := raftResult.(*restapi.DeleteSubscriptionResponse) + if !ok { + return restapi.DeleteSubscriptionResponseSet{Response500: &api.InternalServerErrorBody{ + ErrorMessage: *dsserr.Handle(ctx, stacktrace.NewError("invalid result type: %T", raftResult))}} + } + response = deleteSubResponse + } return restapi.DeleteSubscriptionResponseSet{Response200: response} } diff --git a/pkg/scd/uss_availability_handler.go b/pkg/scd/uss_availability_handler.go index 517a9859a..32455bc91 100644 --- a/pkg/scd/uss_availability_handler.go +++ b/pkg/scd/uss_availability_handler.go @@ -10,6 +10,7 @@ import ( dssmodels "github.com/interuss/dss/pkg/models" scdmodels "github.com/interuss/dss/pkg/scd/models" "github.com/interuss/dss/pkg/scd/repos" + scdraftstore "github.com/interuss/dss/pkg/scd/store/raftstore" "github.com/interuss/stacktrace" "github.com/jackc/pgx/v5" ) @@ -51,7 +52,7 @@ func (a *Server) GetUssAvailability(ctx context.Context, req *restapi.GetUssAvai return nil } - err := a.Store.Transact(ctx, action) + raftResult, err := a.Store.Transact(ctx, scdraftstore.GetUSSAvailabilityTransaction, req, action) if err != nil { // In case of older DB versions where availability table doesn't exist if strings.Contains(err.Error(), "does not exist") { @@ -62,6 +63,15 @@ func (a *Server) GetUssAvailability(ctx context.Context, req *restapi.GetUssAvai ErrorMessage: *dsserr.Handle(ctx, err)}} } } + if raftResult != nil { + getAvailResponse, ok := raftResult.(*restapi.UssAvailabilityStatusResponse) + if !ok { + return restapi.GetUssAvailabilityResponseSet{Response500: &api.InternalServerErrorBody{ + ErrorMessage: *dsserr.Handle(ctx, stacktrace.NewError("invalid result type: %T", raftResult))}} + } + response = getAvailResponse + } + return restapi.GetUssAvailabilityResponseSet{Response200: response} } @@ -121,7 +131,7 @@ func (a *Server) SetUssAvailability(ctx context.Context, req *restapi.SetUssAvai } return nil } - err = a.Store.Transact(ctx, action) + raftResult, err := a.Store.Transact(ctx, scdraftstore.SetUSSAvailabilityTransaction, req, action) if err != nil { // In case of older DB versions where availability table doesn't exist if strings.Contains(err.Error(), "does not exist") { @@ -138,6 +148,14 @@ func (a *Server) SetUssAvailability(ctx context.Context, req *restapi.SetUssAvai } } } + if raftResult != nil { + setAvailResponse, ok := raftResult.(*restapi.UssAvailabilityStatusResponse) + if !ok { + return restapi.SetUssAvailabilityResponseSet{Response500: &api.InternalServerErrorBody{ + ErrorMessage: *dsserr.Handle(ctx, stacktrace.NewError("invalid result type: %T", raftResult))}} + } + result = setAvailResponse + } // Return response to client return restapi.SetUssAvailabilityResponseSet{Response200: result} diff --git a/pkg/sqlstore/store.go b/pkg/sqlstore/store.go index 0255620c6..28c3a1cbb 100644 --- a/pkg/sqlstore/store.go +++ b/pkg/sqlstore/store.go @@ -173,9 +173,9 @@ func (s *Store[R]) Interact(_ context.Context) (R, error) { return s.newRepo(s.Pool), nil } -func (s *Store[R]) Transact(ctx context.Context, f func(context.Context, R) error) error { +func (s *Store[R]) Transact(ctx context.Context, _ string, _ any, f func(context.Context, R) error) (any, error) { ctx = crdb.WithMaxRetries(ctx, s.maxRetries) - return crdbpgx.ExecuteTx(ctx, s.Pool, pgx.TxOptions{IsoLevel: pgx.Serializable}, func(tx pgx.Tx) error { + return nil, crdbpgx.ExecuteTx(ctx, s.Pool, pgx.TxOptions{IsoLevel: pgx.Serializable}, func(tx pgx.Tx) error { return f(ctx, s.newRepo(tx)) }) } diff --git a/pkg/store/params/params.go b/pkg/store/params/params.go index 7a25e29f6..3bf49bfa1 100644 --- a/pkg/store/params/params.go +++ b/pkg/store/params/params.go @@ -8,6 +8,7 @@ import ( const ( RaftStoreType = "raft" SQLStoreType = "sql" + MemStoreType = "mem" ) type ( @@ -22,6 +23,7 @@ var ( ) func init() { + // NB: Memstore not available there on purpose. flag.StringVar(&storeParameters.StoreType, "store_type", SQLStoreType, fmt.Sprintf("Store type. Use '%s' for CockroachDB/YugabyteDB and '%s' for Raft-based store.", SQLStoreType, RaftStoreType)) } diff --git a/pkg/store/store.go b/pkg/store/store.go index e5e95c4b0..e59318a9e 100644 --- a/pkg/store/store.go +++ b/pkg/store/store.go @@ -16,7 +16,9 @@ type Store[R any] interface { Interact(context.Context) (R, error) // Attempt to apply the operations in f to the R Repo it is supplied. All operations performed // on the R Repo by f will be applied or rejected atomically. - Transact(ctx context.Context, f func(context.Context, R) error) error + // requestType and payload are used by the Raftstore to build the proposal. + // The returned any is the proposal result (also Raftstore only). + Transact(ctx context.Context, requestType string, payload any, f func(context.Context, R) error) (any, error) } const ( diff --git a/pkg/timestamp/timestamp.go b/pkg/timestamp/timestamp.go new file mode 100644 index 000000000..6ac952a16 --- /dev/null +++ b/pkg/timestamp/timestamp.go @@ -0,0 +1,32 @@ +package timestamp + +import ( + "context" + "net/http" + "time" +) + +type timestampKey struct{} + +// NowFromContext returns the timestamp from the context, or zero if not present. +func NowFromContext(ctx context.Context) time.Time { + t, ok := ctx.Value(timestampKey{}).(time.Time) + if !ok { + return time.Time{} + } + + return t +} + +// WithTimestamp returns a new context with the given timestamp. +func WithTimestamp(ctx context.Context, t time.Time) context.Context { + return context.WithValue(ctx, timestampKey{}, t) +} + +// Middleware is an HTTP middleware that adds a timestamp to the request context. +func Middleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx := WithTimestamp(r.Context(), time.Now().UTC()) + next.ServeHTTP(w, r.WithContext(ctx)) + }) +}