From 64d64010b1d89c1e5a2b8fb33538975fd605aa7a Mon Sep 17 00:00:00 2001 From: Maximilien Cuony Date: Mon, 15 Jun 2026 15:07:26 +0200 Subject: [PATCH] [raft/memstore] Add aux memstore --- pkg/aux_/store/memstore/dss.go | 69 ++++++++++++++- pkg/aux_/store/memstore/dss_test.go | 125 ++++++++++++++++++++++++++++ pkg/aux_/store/memstore/store.go | 28 ++++++- 3 files changed, 216 insertions(+), 6 deletions(-) create mode 100644 pkg/aux_/store/memstore/dss_test.go diff --git a/pkg/aux_/store/memstore/dss.go b/pkg/aux_/store/memstore/dss.go index 38fa4d5bf..7198b3936 100644 --- a/pkg/aux_/store/memstore/dss.go +++ b/pkg/aux_/store/memstore/dss.go @@ -2,6 +2,8 @@ package memstore import ( "context" + "database/sql" + "time" auxmodels "github.com/interuss/dss/pkg/aux_/models" dsserr "github.com/interuss/dss/pkg/errors" @@ -9,17 +11,76 @@ import ( ) func (r *repo) SaveOwnMetadata(_ context.Context, locality string, publicEndpoint string) error { - return stacktrace.NewErrorWithCode(dsserr.NotImplemented, "SaveOwnMetadata not implemented for memstore") + if locality == "" { + return stacktrace.NewErrorWithCode(dsserr.BadRequest, "Locality not set") + } + if publicEndpoint == "" { + return stacktrace.NewErrorWithCode(dsserr.BadRequest, "Public endpoint not set") + } + + r.participants[locality] = &participant{ + publicEndpoint: publicEndpoint, + updatedAt: time.Now(), + } + return nil } func (r *repo) GetDSSMetadata(_ context.Context) ([]*auxmodels.DSSMetadata, error) { - return nil, stacktrace.NewErrorWithCode(dsserr.NotImplemented, "GetDSSMetadata not implemented for memstore") + metadata := make([]*auxmodels.DSSMetadata, 0, len(r.participants)) + for locality, p := range r.participants { + updatedAt := p.updatedAt + m := &auxmodels.DSSMetadata{ + Locality: locality, + PublicEndpoint: p.publicEndpoint, + UpdatedAt: &updatedAt, + } + + // Find the latest heartbeat across all sources for this locality. + var latest auxmodels.Heartbeat + found := false + for key, hb := range r.heartbeats { + if key.locality != locality { + continue + } + if !found || hb.Timestamp.After(*latest.Timestamp) { + latest = hb + found = true + } + } + + if found { + m.LatestTimestamp.Source = sql.NullString{String: latest.Source, Valid: true} + m.LatestTimestamp.Timestamp = latest.Timestamp + m.LatestTimestamp.NextHeartbeatExpectedBefore = latest.NextHeartbeatExpectedBefore + m.LatestTimestamp.Reporter = sql.NullString{String: latest.Reporter, Valid: true} + } + + metadata = append(metadata, m) + } + return metadata, nil } func (r *repo) RecordHeartbeat(_ context.Context, heartbeat auxmodels.Heartbeat) error { - return stacktrace.NewErrorWithCode(dsserr.NotImplemented, "RecordHeartbeat not implemented for memstore") + if heartbeat.Locality == "" { + return stacktrace.NewErrorWithCode(dsserr.BadRequest, "Locality not set") + } + if heartbeat.Source == "" { + return stacktrace.NewErrorWithCode(dsserr.BadRequest, "Source not set") + } + + if heartbeat.Timestamp == nil { + now := time.Now() + heartbeat.Timestamp = &now + } + + if heartbeat.NextHeartbeatExpectedBefore != nil && heartbeat.NextHeartbeatExpectedBefore.Before(*heartbeat.Timestamp) { + return stacktrace.NewErrorWithCode(dsserr.BadRequest, "Cannot expect the timestamp of the next heartbeat before the timestamp of the new heartbeat") + } + + r.heartbeats[heartbeatKey{locality: heartbeat.Locality, source: heartbeat.Source}] = heartbeat + return nil } func (r *repo) GetDSSAirspaceRepresentationID(_ context.Context) (string, error) { - return "", stacktrace.NewErrorWithCode(dsserr.NotImplemented, "GetDSSAirspaceRepresentationID not implemented for memstore") + return "", stacktrace.NewErrorWithCode(dsserr.NotImplemented, "GetDSSAirspaceRepresentationID not implementable for memstore") } diff --git a/pkg/aux_/store/memstore/dss_test.go b/pkg/aux_/store/memstore/dss_test.go new file mode 100644 index 000000000..43563d27b --- /dev/null +++ b/pkg/aux_/store/memstore/dss_test.go @@ -0,0 +1,125 @@ +package memstore + +import ( + "context" + "testing" + "time" + + auxmodels "github.com/interuss/dss/pkg/aux_/models" + dsserr "github.com/interuss/dss/pkg/errors" + "github.com/interuss/stacktrace" + "github.com/stretchr/testify/require" +) + +func TestSaveOwnMetadataValidation(t *testing.T) { + ctx := context.Background() + r := newRepo() + + require.Equal(t, dsserr.BadRequest, stacktrace.GetCode(r.SaveOwnMetadata(ctx, "", "https://example.com"))) + require.Equal(t, dsserr.BadRequest, stacktrace.GetCode(r.SaveOwnMetadata(ctx, "dss-1", ""))) +} + +func TestSaveOwnMetadataRoundTrip(t *testing.T) { + ctx := context.Background() + r := newRepo() + + require.NoError(t, r.SaveOwnMetadata(ctx, "dss-1", "https://example.com")) + + md, err := r.GetDSSMetadata(ctx) + require.NoError(t, err) + + require.Len(t, md, 1) + require.Equal(t, "dss-1", md[0].Locality) + require.Equal(t, "https://example.com", md[0].PublicEndpoint) + require.NotNil(t, md[0].UpdatedAt) + + // No heartbeat recorded yet. + require.False(t, md[0].LatestTimestamp.Source.Valid) + require.Nil(t, md[0].LatestTimestamp.Timestamp) +} + +func TestSaveOwnMetadataUpsert(t *testing.T) { + ctx := context.Background() + r := newRepo() + + require.NoError(t, r.SaveOwnMetadata(ctx, "dss-1", "https://old.example.com")) + require.NoError(t, r.SaveOwnMetadata(ctx, "dss-1", "https://new.example.com")) + + md, err := r.GetDSSMetadata(ctx) + require.NoError(t, err) + + require.Len(t, md, 1) + require.Equal(t, "https://new.example.com", md[0].PublicEndpoint) +} + +func TestRecordHeartbeatValidation(t *testing.T) { + ctx := context.Background() + r := newRepo() + + require.Equal(t, dsserr.BadRequest, stacktrace.GetCode(r.RecordHeartbeat(ctx, auxmodels.Heartbeat{Source: "source1"}))) + require.Equal(t, dsserr.BadRequest, stacktrace.GetCode(r.RecordHeartbeat(ctx, auxmodels.Heartbeat{Locality: "dss-1"}))) + + ts := time.Now() + before := ts.Add(-time.Minute) + err := r.RecordHeartbeat(ctx, auxmodels.Heartbeat{ + Locality: "dss-1", + Source: "source1", + Timestamp: &ts, + NextHeartbeatExpectedBefore: &before, + }) + + require.Equal(t, dsserr.BadRequest, stacktrace.GetCode(err)) +} + +func TestRecordHeartbeatDefaultsTimestamp(t *testing.T) { + ctx := context.Background() + r := newRepo() + + require.NoError(t, r.SaveOwnMetadata(ctx, "dss-1", "https://example.com")) + require.NoError(t, r.RecordHeartbeat(ctx, auxmodels.Heartbeat{Locality: "dss-1", Source: "source1"})) + + md, err := r.GetDSSMetadata(ctx) + require.NoError(t, err) + + require.Len(t, md, 1) + require.True(t, md[0].LatestTimestamp.Source.Valid) + require.NotNil(t, md[0].LatestTimestamp.Timestamp) +} + +func TestGetDSSMetadataPicksLatestHeartbeat(t *testing.T) { + ctx := context.Background() + r := newRepo() + + require.NoError(t, r.SaveOwnMetadata(ctx, "dss-1", "https://example.com")) + + older := time.Now().Add(-time.Hour) + newer := time.Now() + require.NoError(t, r.RecordHeartbeat(ctx, auxmodels.Heartbeat{Locality: "dss-1", Source: "source1", Timestamp: &older, Reporter: "uss1"})) + require.NoError(t, r.RecordHeartbeat(ctx, auxmodels.Heartbeat{Locality: "dss-1", Source: "source2", Timestamp: &newer, Reporter: "uss2"})) + + md, err := r.GetDSSMetadata(ctx) + require.NoError(t, err) + + require.Len(t, md, 1) + require.True(t, md[0].LatestTimestamp.Timestamp.Equal(newer)) + require.Equal(t, "source2", md[0].LatestTimestamp.Source.String) + require.Equal(t, "uss2", md[0].LatestTimestamp.Reporter.String) +} + +func TestGetDSSMetadataUpdatesHeartbeatPerSource(t *testing.T) { + ctx := context.Background() + r := newRepo() + + require.NoError(t, r.SaveOwnMetadata(ctx, "dss-1", "https://example.com")) + + first := time.Now().Add(-time.Hour) + second := time.Now() + require.NoError(t, r.RecordHeartbeat(ctx, auxmodels.Heartbeat{Locality: "dss-1", Source: "source1", Timestamp: &first})) + require.NoError(t, r.RecordHeartbeat(ctx, auxmodels.Heartbeat{Locality: "dss-1", Source: "source1", Timestamp: &second})) + + md, err := r.GetDSSMetadata(ctx) + require.NoError(t, err) + + require.Len(t, md, 1) + require.True(t, md[0].LatestTimestamp.Timestamp.Equal(second)) +} diff --git a/pkg/aux_/store/memstore/store.go b/pkg/aux_/store/memstore/store.go index c2b875f9b..f94a1ffc2 100644 --- a/pkg/aux_/store/memstore/store.go +++ b/pkg/aux_/store/memstore/store.go @@ -2,17 +2,41 @@ package memstore import ( "context" + "time" + auxmodels "github.com/interuss/dss/pkg/aux_/models" "github.com/interuss/dss/pkg/aux_/repos" "github.com/interuss/dss/pkg/memstore" "go.uber.org/zap" ) // repo is a full implementation of aux_.repos.Repository for memory-based storage. -type repo struct{} +type repo struct { + // participants holds pool participants metadata, keyed by locality. + participants map[string]*participant + // heartbeats holds the latest heartbeat per (locality, source). + heartbeats map[heartbeatKey]auxmodels.Heartbeat +} + +type participant struct { + publicEndpoint string + updatedAt time.Time +} + +type heartbeatKey struct { + locality string + source string +} + +func newRepo() *repo { + return &repo{ + participants: map[string]*participant{}, + heartbeats: map[heartbeatKey]auxmodels.Heartbeat{}, + } +} func Init(ctx context.Context, logger *zap.Logger) (*memstore.Store[repos.Repository], error) { - return memstore.Init(ctx, logger, "aux_", &repo{}) + return memstore.Init(ctx, logger, "aux_", newRepo()) } func (r *repo) GetRepo() repos.Repository { return r }