diff --git a/cmd/ateapi/internal/controlapi/functional_test.go b/cmd/ateapi/internal/controlapi/functional_test.go index 46e3feba..2219e6ad 100644 --- a/cmd/ateapi/internal/controlapi/functional_test.go +++ b/cmd/ateapi/internal/controlapi/functional_test.go @@ -26,6 +26,7 @@ import ( "testing" "time" + "github.com/agent-substrate/substrate/cmd/ateapi/internal/store" "github.com/agent-substrate/substrate/cmd/ateapi/internal/store/ateredis" "github.com/agent-substrate/substrate/internal/ateinterceptors" "github.com/agent-substrate/substrate/internal/proto/ateletpb" @@ -1203,7 +1204,7 @@ func TestSuspendActor_DanglingWorker(t *testing.T) { deleteWorkerPod(t, tc, ns, "worker-1") // 3. Call SuspendActor -> Should succeed (our fix skips missing pod execution) - actors, _ := tc.persistence.ListActors(context.Background()) + actors, _ := tc.persistence.ListActors(context.Background(), store.ListOptions{}) t.Logf("Actors in Redis before Suspend: %d", len(actors)) for _, a := range actors { t.Logf(" Actor: %s/%s/%s", a.GetActorTemplateNamespace(), a.GetActorTemplateName(), a.GetActorId()) diff --git a/cmd/ateapi/internal/controlapi/list_actors.go b/cmd/ateapi/internal/controlapi/list_actors.go index 3cc1128a..5311a72d 100644 --- a/cmd/ateapi/internal/controlapi/list_actors.go +++ b/cmd/ateapi/internal/controlapi/list_actors.go @@ -18,6 +18,7 @@ import ( "context" "fmt" + "github.com/agent-substrate/substrate/cmd/ateapi/internal/store" "github.com/agent-substrate/substrate/pkg/proto/ateapipb" ) @@ -25,7 +26,7 @@ func (s *Service) ListActors(ctx context.Context, req *ateapipb.ListActorsReques if err := validateListActorsRequest(req); err != nil { return nil, err } - actors, err := s.persistence.ListActors(ctx) + actors, err := s.persistence.ListActors(ctx, store.ListOptions{}) if err != nil { return nil, fmt.Errorf("while listing actors in db: %w", err) } diff --git a/cmd/ateapi/internal/controlapi/list_workers.go b/cmd/ateapi/internal/controlapi/list_workers.go index 
94406bf6..20ca0758 100644 --- a/cmd/ateapi/internal/controlapi/list_workers.go +++ b/cmd/ateapi/internal/controlapi/list_workers.go @@ -18,6 +18,7 @@ import ( "context" "fmt" + "github.com/agent-substrate/substrate/cmd/ateapi/internal/store" "github.com/agent-substrate/substrate/pkg/proto/ateapipb" ) @@ -25,7 +26,7 @@ func (s *Service) ListWorkers(ctx context.Context, req *ateapipb.ListWorkersRequ if err := validateListWorkersRequest(req); err != nil { return nil, err } - workers, err := s.persistence.ListWorkers(ctx) + workers, err := s.persistence.ListWorkers(ctx, store.ListOptions{}) if err != nil { return nil, fmt.Errorf("while listing workers in db: %w", err) } diff --git a/cmd/ateapi/internal/controlapi/syncer_test.go b/cmd/ateapi/internal/controlapi/syncer_test.go index 6276c53a..3516d3ab 100644 --- a/cmd/ateapi/internal/controlapi/syncer_test.go +++ b/cmd/ateapi/internal/controlapi/syncer_test.go @@ -60,7 +60,7 @@ func TestSyncer_Lifecycle(t *testing.T) { poolName := "pool1" // 1. 
Verify no workers in Redis initially - workers, err := persistence.ListWorkers(context.Background()) + workers, err := persistence.ListWorkers(context.Background(), store.ListOptions{}) if err != nil { t.Fatalf("failed to list workers: %v", err) } diff --git a/cmd/ateapi/internal/controlapi/workflow_resume.go b/cmd/ateapi/internal/controlapi/workflow_resume.go index 1401461c..317fed69 100644 --- a/cmd/ateapi/internal/controlapi/workflow_resume.go +++ b/cmd/ateapi/internal/controlapi/workflow_resume.go @@ -85,7 +85,7 @@ func (s *AssignWorkerStep) IsComplete(ctx context.Context, input *ResumeInput, s return state.Actor.GetStatus() == ateapipb.Actor_STATUS_RUNNING, nil } func (s *AssignWorkerStep) Execute(ctx context.Context, input *ResumeInput, state *ResumeState) error { - workers, err := s.store.ListWorkers(ctx) + workers, err := s.store.ListWorkers(ctx, store.ListOptions{}) if err != nil { return fmt.Errorf("while listing workers: %w", err) } diff --git a/cmd/ateapi/internal/store/ateredis/ateredis.go b/cmd/ateapi/internal/store/ateredis/ateredis.go index 7b4b109a..19368768 100644 --- a/cmd/ateapi/internal/store/ateredis/ateredis.go +++ b/cmd/ateapi/internal/store/ateredis/ateredis.go @@ -353,7 +353,7 @@ func (s *Persistence) UpdateActor(ctx context.Context, actor *ateapipb.Actor, ex return nil } -func (s *Persistence) ListWorkers(ctx context.Context) ([]*ateapipb.Worker, error) { +func (s *Persistence) ListWorkers(ctx context.Context, opts store.ListOptions) ([]*ateapipb.Worker, error) { var result []*ateapipb.Worker var mu sync.Mutex @@ -377,6 +377,10 @@ func (s *Persistence) ListWorkers(ctx context.Context) ([]*ateapipb.Worker, erro return fmt.Errorf("in protojson.Unmarshal: %w", err) } + if !matchesWorker(worker, opts) { + continue + } + mu.Lock() result = append(result, worker) mu.Unlock() @@ -393,7 +397,7 @@ func (s *Persistence) ListWorkers(ctx context.Context) ([]*ateapipb.Worker, erro return result, nil } -func (s *Persistence) ListActors(ctx 
context.Context) ([]*ateapipb.Actor, error) { +func (s *Persistence) ListActors(ctx context.Context, opts store.ListOptions) ([]*ateapipb.Actor, error) { var result []*ateapipb.Actor var mu sync.Mutex @@ -416,6 +420,10 @@ func (s *Persistence) ListActors(ctx context.Context) ([]*ateapipb.Actor, error) return fmt.Errorf("in protojson.Unmarshal: %w", err) } + if !matchesActor(actor, opts) { + continue + } + mu.Lock() result = append(result, actor) mu.Unlock() @@ -429,6 +437,88 @@ func (s *Persistence) ListActors(ctx context.Context) ([]*ateapipb.Actor, error) return result, nil } +func matchesWorker(w *ateapipb.Worker, opts store.ListOptions) bool { + if len(opts.FieldSelector) == 0 { + return true + } + for k, v := range opts.FieldSelector { + switch k { + case "worker_namespace": + if w.GetWorkerNamespace() != v { + return false + } + case "worker_pool": + if w.GetWorkerPool() != v { + return false + } + case "worker_pod": + if w.GetWorkerPod() != v { + return false + } + case "actor_namespace": + if w.GetActorNamespace() != v { + return false + } + case "actor_template": + if w.GetActorTemplate() != v { + return false + } + case "actor_id": + if w.GetActorId() != v { + return false + } + case "ip": + if w.GetIp() != v { + return false + } + default: + return false + } + } + return true +} + +func matchesActor(a *ateapipb.Actor, opts store.ListOptions) bool { + if len(opts.FieldSelector) == 0 { + return true + } + for k, v := range opts.FieldSelector { + switch k { + case "actor_id": + if a.GetActorId() != v { + return false + } + case "actor_template_namespace": + if a.GetActorTemplateNamespace() != v { + return false + } + case "actor_template_name": + if a.GetActorTemplateName() != v { + return false + } + case "status": + if a.GetStatus().String() != v { + return false + } + case "ateom_pod_namespace": + if a.GetAteomPodNamespace() != v { + return false + } + case "ateom_pod_name": + if a.GetAteomPodName() != v { + return false + } + case "ateom_pod_ip": + if 
a.GetAteomPodIp() != v { + return false + } + default: + return false + } + } + return true +} + func (s *Persistence) AcquireLock(ctx context.Context, key string, value string, ttl time.Duration) (bool, error) { ok, err := s.rdb.SetNX(ctx, key, value, ttl).Result() if err != nil { diff --git a/cmd/ateapi/internal/store/ateredis/ateredis_test.go b/cmd/ateapi/internal/store/ateredis/ateredis_test.go index a5b5e862..7c78001f 100644 --- a/cmd/ateapi/internal/store/ateredis/ateredis_test.go +++ b/cmd/ateapi/internal/store/ateredis/ateredis_test.go @@ -373,7 +373,7 @@ func TestListWorkers(t *testing.T) { t.Fatalf("failed to create worker2: %v", err) } - workers, err := s.ListWorkers(ctx) + workers, err := s.ListWorkers(ctx, store.ListOptions{}) if err != nil { t.Fatalf("ListWorkers failed: %v", err) } @@ -424,7 +424,7 @@ func TestListActors(t *testing.T) { t.Fatalf("failed to create actor2: %v", err) } - actors, err := s.ListActors(ctx) + actors, err := s.ListActors(ctx, store.ListOptions{}) if err != nil { t.Fatalf("ListActors failed: %v", err) } @@ -515,7 +515,7 @@ func TestListWorkers_Empty(t *testing.T) { mr, s, ctx := setupTest(t) defer mr.Close() - workers, err := s.ListWorkers(ctx) + workers, err := s.ListWorkers(ctx, store.ListOptions{}) if err != nil { t.Fatalf("ListWorkers failed: %v", err) } @@ -529,7 +529,7 @@ func TestListActors_Empty(t *testing.T) { mr, s, ctx := setupTest(t) defer mr.Close() - actors, err := s.ListActors(ctx) + actors, err := s.ListActors(ctx, store.ListOptions{}) if err != nil { t.Fatalf("ListActors failed: %v", err) } @@ -721,3 +721,144 @@ func TestAcquireLock_NonReentry(t *testing.T) { t.Errorf("expected second lock acquisition to fail (non-reentrant)") } } + +func TestListWorkers_Filtering(t *testing.T) { + mr, s, ctx := setupTest(t) + defer mr.Close() + + worker1 := &ateapipb.Worker{ + WorkerNamespace: "ns1", + WorkerPool: "pool1", + WorkerPod: "pod1", + ActorId: "actor1", + } + worker2 := &ateapipb.Worker{ + WorkerNamespace: "ns2", 
+ WorkerPool: "pool2", + WorkerPod: "pod2", + ActorId: "", + } + if err := s.CreateWorker(ctx, worker1); err != nil { + t.Fatalf("failed to create worker1: %v", err) + } + if err := s.CreateWorker(ctx, worker2); err != nil { + t.Fatalf("failed to create worker2: %v", err) + } + + tests := []struct { + name string + selector map[string]string + expectedPodIDs []string + }{ + { + name: "match pool1", + selector: map[string]string{"worker_pool": "pool1"}, + expectedPodIDs: []string{"pod1"}, + }, + { + name: "match empty actor_id (idle worker)", + selector: map[string]string{"actor_id": ""}, + expectedPodIDs: []string{"pod2"}, + }, + { + name: "match worker namespace and pool", + selector: map[string]string{"worker_namespace": "ns2", "worker_pool": "pool2"}, + expectedPodIDs: []string{"pod2"}, + }, + { + name: "no match", + selector: map[string]string{"worker_pool": "non-existent"}, + expectedPodIDs: []string{}, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + opts := store.ListOptions{FieldSelector: tc.selector} + workers, err := s.ListWorkers(ctx, opts) + if err != nil { + t.Fatalf("ListWorkers failed: %v", err) + } + + if len(workers) != len(tc.expectedPodIDs) { + t.Fatalf("expected %d workers, got %d", len(tc.expectedPodIDs), len(workers)) + } + for i, w := range workers { + if w.GetWorkerPod() != tc.expectedPodIDs[i] { + t.Errorf("expected worker %s, got %s", tc.expectedPodIDs[i], w.GetWorkerPod()) + } + } + }) + } +} + +func TestListActors_Filtering(t *testing.T) { + mr, s, ctx := setupTest(t) + defer mr.Close() + + actor1 := &ateapipb.Actor{ + ActorId: "id1", + ActorTemplateNamespace: "ns1", + ActorTemplateName: "tmpl1", + Status: ateapipb.Actor_STATUS_RUNNING, + } + actor2 := &ateapipb.Actor{ + ActorId: "id2", + ActorTemplateNamespace: "ns2", + ActorTemplateName: "tmpl2", + Status: ateapipb.Actor_STATUS_SUSPENDED, + } + + if err := s.CreateActor(ctx, actor1); err != nil { + t.Fatalf("failed to create actor1: %v", err) + } + if err 
:= s.CreateActor(ctx, actor2); err != nil { + t.Fatalf("failed to create actor2: %v", err) + } + + tests := []struct { + name string + selector map[string]string + expectedActorIDs []string + }{ + { + name: "match status running", + selector: map[string]string{"status": "STATUS_RUNNING"}, + expectedActorIDs: []string{"id1"}, + }, + { + name: "match status suspended", + selector: map[string]string{"status": "STATUS_SUSPENDED"}, + expectedActorIDs: []string{"id2"}, + }, + { + name: "match template name", + selector: map[string]string{"actor_template_name": "tmpl1"}, + expectedActorIDs: []string{"id1"}, + }, + { + name: "no match", + selector: map[string]string{"status": "STATUS_UNSPECIFIED"}, + expectedActorIDs: []string{}, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + opts := store.ListOptions{FieldSelector: tc.selector} + actors, err := s.ListActors(ctx, opts) + if err != nil { + t.Fatalf("ListActors failed: %v", err) + } + + if len(actors) != len(tc.expectedActorIDs) { + t.Fatalf("expected %d actors, got %d", len(tc.expectedActorIDs), len(actors)) + } + for i, a := range actors { + if a.GetActorId() != tc.expectedActorIDs[i] { + t.Errorf("expected actor %s, got %s", tc.expectedActorIDs[i], a.GetActorId()) + } + } + }) + } +} diff --git a/cmd/ateapi/internal/store/store.go b/cmd/ateapi/internal/store/store.go index 52553c00..41c70543 100644 --- a/cmd/ateapi/internal/store/store.go +++ b/cmd/ateapi/internal/store/store.go @@ -37,6 +37,12 @@ var ( ErrFailedPrecondition = errors.New("persistence: failed precondition") ) +// ListOptions holds parameters for filtering List results. +type ListOptions struct { + // FieldSelector filters results by exact match on field name/value pairs (e.g., {"worker_pool": "cpu-pool"}). All entries must match; an empty selector matches everything. + FieldSelector map[string]string +} + // Interface defines the contract for the persistence layer storing actor state. type Interface interface { // Fetches an actor by id. Returns ErrNotFound if missing. 
@@ -51,8 +57,8 @@ type Interface interface { // Removes an actor. Returns ErrNotFound if missing, or ErrFailedPrecondition if not suspended. DeleteActor(ctx context.Context, id string) error - // Lists all known actors. Returns nil if none found. - ListActors(ctx context.Context) ([]*ateapipb.Actor, error) + // Lists all known actors matching ListOptions. Returns nil if none found. + ListActors(ctx context.Context, opts ListOptions) ([]*ateapipb.Actor, error) // Fetches worker state by namespace, pool, and pod name. Returns ErrNotFound if missing. GetWorker(ctx context.Context, namespace, pool, pod string) (*ateapipb.Worker, error) @@ -66,8 +72,8 @@ type Interface interface { // Removes a worker. Idempotent: does nothing if worker is not found. DeleteWorker(ctx context.Context, namespace, pool, pod string) error - // Lists all known workers. Returns nil if none found. - ListWorkers(ctx context.Context) ([]*ateapipb.Worker, error) + // Lists all known workers matching ListOptions. Returns nil if none found. + ListWorkers(ctx context.Context, opts ListOptions) ([]*ateapipb.Worker, error) // AcquireLock attempts to acquire a distributed lock with a TTL. // Returns true if the lock was successfully acquired.